In [1]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from IPython.core.debugger import set_trace
import numpy as np

In [2]:
# Dataset for pairs
class BoxDataset(Dataset):
  """Dataset wrapping images and target labels for Kaggle - Planet Amazon from Space competition.

  Arguments:
      A CSV file path
  """

  def __init__(self, csv_path):
    data = np.loadtxt(csv_path)
    self.len = len(data)
    self.X_train = torch.from_numpy(data[:,:2].astype(np.long))
    self.y_train = torch.from_numpy(data[:,2].astype(np.float32))

  def __getitem__(self, index):
    return self.X_train[index], self.y_train[index]

  def __len__(self):
    return self.len

# Model = a tensor of boxes
class Boxes(nn.Module):
  def __init__(self, num_boxes, dim):
    super(Boxes, self).__init__()
    box_mins = torch.rand(num_boxes, dim)
    box_maxs = box_mins + torch.rand(num_boxes, dim) * (1 - box_mins)
    boxes = torch.stack([box_mins, box_maxs], dim=1)
    self.boxes = nn.Parameter(boxes)
    
  def forward(self, X):
    """Returns box embeddings for ids"""
    #set_trace()
    x = self.boxes[X]
    o = cond_probs(x[:,0,:,:], x[:,1,:,:])
    return o
    

In [3]:
MIN_IND, MAX_IND = 0, 1

def volumes(boxes):
  r = (boxes[:,MAX_IND,:] - boxes[:, MIN_IND,:]).prod(1)
  return r

def intersections(boxes1, boxes2):
  #set_trace()
  intersections_min = torch.max(boxes1[:, :, MIN_IND], boxes2[:, :, MIN_IND])
  intersections_max = torch.min(boxes1[:, :, MAX_IND], boxes2[:, :, MAX_IND])
  apap = torch.stack([intersections_min, intersections_max], 1)
  return apap

def cond_probs(boxes1, boxes2):
  return volumes(intersections(boxes1, boxes2))/volumes(boxes2)

In [4]:
train_ds = BoxDataset("data/sample/train.txt")
train_dl = DataLoader(train_ds, batch_size=18, shuffle=True, num_workers=4)

model = Boxes(6,4)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1.0)


In [5]:
N_EPOCHS = 500

for epoch in range(N_EPOCHS):
    
    # Train
    model.train()  # IMPORTANT
    
    running_loss, correct = 0.0, 0
    for X, y in train_dl:
      #set_trace()
      #X, y = X.to(device), y.to(device)

      optimizer.zero_grad()
      with torch.set_grad_enabled(True):
        y_ = model(X)
        loss = criterion(y_, y)

      loss.backward()
      optimizer.step()
      
      # Statistics
      print("    batch loss: "+str(loss.item()))
      #_, y_label_ = torch.max(y_, 1)
      #correct += (y_label_ == y).sum().item()
      running_loss += loss.item() * X.shape[0]

    print("  Train Loss: "+str(running_loss / len(train_dl.dataset)))
    print("  Train Acc:  "+str(correct / len(train_dl.dataset)))
    

    batch loss: 23193576.0
  Train Loss: 23193576.0
  Train Acc:  0
    batch loss: 0.776467859745
  Train Loss: 0.776467859745
  Train Acc:  0
    batch loss: 0.446060538292
  Train Loss: 0.446060538292
  Train Acc:  0
    batch loss: 0.393458276987
  Train Loss: 0.393458276987
  Train Acc:  0
    batch loss: 0.374850600958
  Train Loss: 0.374850600958
  Train Acc:  0
    batch loss: 0.365793436766
  Train Loss: 0.365793436766
  Train Acc:  0
    batch loss: 0.360561639071
  Train Loss: 0.360561639071
  Train Acc:  0
    batch loss: 0.357201397419
  Train Loss: 0.357201397419
  Train Acc:  0
    batch loss: 0.354882925749
  Train Loss: 0.354882925749
  Train Acc:  0
    batch loss: 0.353198826313
  Train Loss: 0.353198826313
  Train Acc:  0
    batch loss: 0.351927608252
  Train Loss: 0.351927608252
  Train Acc:  0
    batch loss: 0.350939184427
  Train Loss: 0.350939184427
  Train Acc:  0
    batch loss: 0.350152254105
  Train Loss: 0.350152254105
  Train Acc:  0
    batch loss: 0.34

    batch loss: 0.345482707024
  Train Loss: 0.345482707024
  Train Acc:  0
    batch loss: 0.345482647419
  Train Loss: 0.345482647419
  Train Acc:  0
    batch loss: 0.345482587814
  Train Loss: 0.345482587814
  Train Acc:  0
    batch loss: 0.34548252821
  Train Loss: 0.34548252821
  Train Acc:  0
    batch loss: 0.345482498407
  Train Loss: 0.345482498407
  Train Acc:  0
    batch loss: 0.345482409
  Train Loss: 0.345482409
  Train Acc:  0
    batch loss: 0.345482349396
  Train Loss: 0.345482349396
  Train Acc:  0
    batch loss: 0.345482319593
  Train Loss: 0.345482319593
  Train Acc:  0
    batch loss: 0.345482230186
  Train Loss: 0.345482230186
  Train Acc:  0
    batch loss: 0.345482170582
  Train Loss: 0.345482170582
  Train Acc:  0
    batch loss: 0.345482140779
  Train Loss: 0.345482140779
  Train Acc:  0
    batch loss: 0.345482081175
  Train Loss: 0.345482081175
  Train Acc:  0
    batch loss: 0.34548202157
  Train Loss: 0.34548202157
  Train Acc:  0
    batch loss: 0.3454

    batch loss: 0.345477759838
  Train Loss: 0.345477759838
  Train Acc:  0
    batch loss: 0.345477730036
  Train Loss: 0.345477730036
  Train Acc:  0
    batch loss: 0.345477670431
  Train Loss: 0.345477670431
  Train Acc:  0
    batch loss: 0.345477640629
  Train Loss: 0.345477640629
  Train Acc:  0
    batch loss: 0.345477581024
  Train Loss: 0.345477581024
  Train Acc:  0
    batch loss: 0.345477551222
  Train Loss: 0.345477551222
  Train Acc:  0
    batch loss: 0.345477491617
  Train Loss: 0.345477491617
  Train Acc:  0
    batch loss: 0.345477461815
  Train Loss: 0.345477461815
  Train Acc:  0
    batch loss: 0.34547740221
  Train Loss: 0.34547740221
  Train Acc:  0
    batch loss: 0.345477372408
  Train Loss: 0.345477372408
  Train Acc:  0
    batch loss: 0.345477312803
  Train Loss: 0.345477312803
  Train Acc:  0
    batch loss: 0.345477253199
  Train Loss: 0.345477253199
  Train Acc:  0
    batch loss: 0.345477223396
  Train Loss: 0.345477223396
  Train Acc:  0
    batch loss

  Train Loss: 0.34547200799
  Train Acc:  0
    batch loss: 0.345471948385
  Train Loss: 0.345471948385
  Train Acc:  0
    batch loss: 0.345471918583
  Train Loss: 0.345471918583
  Train Acc:  0
    batch loss: 0.345471858978
  Train Loss: 0.345471858978
  Train Acc:  0
    batch loss: 0.345471799374
  Train Loss: 0.345471799374
  Train Acc:  0
    batch loss: 0.345471739769
  Train Loss: 0.345471739769
  Train Acc:  0
    batch loss: 0.345471680164
  Train Loss: 0.345471680164
  Train Acc:  0
    batch loss: 0.34547162056
  Train Loss: 0.34547162056
  Train Acc:  0
    batch loss: 0.345471590757
  Train Loss: 0.345471590757
  Train Acc:  0
    batch loss: 0.345471531153
  Train Loss: 0.345471531153
  Train Acc:  0
    batch loss: 0.345471471548
  Train Loss: 0.345471471548
  Train Acc:  0
    batch loss: 0.345471411943
  Train Loss: 0.345471411943
  Train Acc:  0
    batch loss: 0.345471352339
  Train Loss: 0.345471352339
  Train Acc:  0
    batch loss: 0.345471292734
  Train Loss: 0

    batch loss: 0.345465451479
  Train Loss: 0.345465451479
  Train Acc:  0
    batch loss: 0.345465362072
  Train Loss: 0.345465362072
  Train Acc:  0
    batch loss: 0.34546533227
  Train Loss: 0.34546533227
  Train Acc:  0
    batch loss: 0.345465272665
  Train Loss: 0.345465272665
  Train Acc:  0
    batch loss: 0.345465183258
  Train Loss: 0.345465183258
  Train Acc:  0
    batch loss: 0.345465123653
  Train Loss: 0.345465123653
  Train Acc:  0
    batch loss: 0.345465064049
  Train Loss: 0.345465064049
  Train Acc:  0
    batch loss: 0.345465004444
  Train Loss: 0.345465004444
  Train Acc:  0
    batch loss: 0.345464944839
  Train Loss: 0.345464944839
  Train Acc:  0
    batch loss: 0.345464885235
  Train Loss: 0.345464885235
  Train Acc:  0
    batch loss: 0.345464795828
  Train Loss: 0.345464795828
  Train Acc:  0
    batch loss: 0.345464736223
  Train Loss: 0.345464736223
  Train Acc:  0
    batch loss: 0.345464676619
  Train Loss: 0.345464676619
  Train Acc:  0
    batch loss