In [16]:
from os.path import join
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

In [33]:
class SpatialDataset(Dataset):
    def __init__(self, inputs, outputs):
        n_locs = inputs.shape[0] # number of (x,y) positions.
        n_contexts = outputs.shape[1]
        context_onehots = torch.eye(n_contexts).float()
        self.outputs = outputs.T.ravel().reshape(-1,1).float()
        self.inputs = inputs.repeat_interleave(n_contexts, dim=0)
        self.contexts = context_onehots.repeat_interleave(n_locs, dim=0)

    def __len__(self):
        return len(self.inputs)

    def __getitem__(self, idx):
        return self.inputs[idx], self.contexts[idx], self.outputs[idx]
    
    
class SpatialModel(nn.Module):
    def __init__(self, n_contexts=100):
        super().__init__()
        self.ctxt_ind = nn.Linear(2, 16)
        self.ctxt = nn.Linear(n_contexts, 16, bias=False)
        self.ctxt_dep = nn.Linear(32, 16)
        self.out = nn.Linear(16, 1)

    def forward(self, x, context):
        ctxt_indep_rep = F.relu(self.ctxt_ind(x))
        ctxt_rep = F.relu(self.ctxt(context))
        ctxt_dep_rep = F.relu(self.ctxt_dep(torch.hstack([ctxt_indep_rep, ctxt_rep])))
        return torch.sigmoid(self.out(ctxt_dep_rep))
    
    def loss(self, out, ys):
        return F.binary_cross_entropy(out, ys)
    
def train(model, device, train_loader, optimizer):
    model.train()
    train_loss = 0
    for batch_idx, (xs, tcs, ys) in enumerate(train_loader):
        optimizer.zero_grad()
        xs, tcs, ys = xs.to(device), tcs.to(device), ys.to(device)
        out = model(xs, tcs)
        loss = model.loss(out, ys)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    train_loss /= len(train_loader.dataset)
    return train_loss


def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for batch_idx, (xs, tcs, ys) in enumerate(test_loader):
            xs, tcs, ys = xs.to(device), tcs.to(device), ys.to(device)
            out = model(xs, tcs)
            loss = model.loss(out, ys)
            test_loss += loss.item()
        test_loss /= len(test_loader.dataset)
    return test_loss


def setup(n_contexts=80, batch_size=512, use_gpu=True, lr=1e-3):
    use_cuda = use_gpu and torch.cuda.is_available()
    device = torch.device('cuda' if use_cuda else 'cpu')
    inputs = torch.load('../data/spatial_inputs.pt')
    outputs = torch.load('../data/spatial_outputs.pt')[:, :n_contexts]
    train_dataset = SpatialDataset(inputs, outputs)
    train_loader = train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)
    #### COULD SET UP TEST DATASET HERE ####
    model = SpatialModel(n_contexts=n_contexts)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    return model, train_loader, optimizer, device

In [34]:
# Setup the environment.
model, train_loader, optimizer, device = setup()

nepochs = 100
log_interval = 1
train_losses = torch.zeros([nepochs])
for iepoch in range(nepochs):
    train_loss = train(model, device, train_loader, optimizer)
    train_losses[iepoch] = train_loss
    if iepoch%log_interval==0:
        print(f'Epoch {iepoch} train set average loss: {train_loss:.8f}')

Epoch 0 train set average loss: 0.00127853
Epoch 1 train set average loss: 0.00125425
Epoch 2 train set average loss: 0.00123302
Epoch 3 train set average loss: 0.00122800
Epoch 4 train set average loss: 0.00122670
Epoch 5 train set average loss: 0.00122591
Epoch 6 train set average loss: 0.00122522
Epoch 7 train set average loss: 0.00122458
Epoch 8 train set average loss: 0.00122407
Epoch 9 train set average loss: 0.00122328
Epoch 10 train set average loss: 0.00122256
Epoch 11 train set average loss: 0.00122205


KeyboardInterrupt: 

In [None]:
from sklearn.metrics.pairwise import cosine_similarity

fig, axes = plt.subplots(10, 10, sharex=True, sharey=True, tight_layout=True, figsize=(20,20))
test = cosine_similarity(all_gradients)
plt.imshow(test)