In [52]:
%load_ext autoreload
%autoreload 2

import matplotlib.pyplot as plt
import numpy as np
import plotly.express as px
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import utils

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [74]:
class SpatialDataset(Dataset):
    def __init__(self, inputs, outputs,device=utils.set_torch_device()):
        n_locs = inputs.shape[0] # number of (x,y) positions.
        n_contexts = outputs.shape[1]
        context_onehots = torch.eye(n_contexts,device=device).float()
        self.outputs = outputs.T.ravel().reshape(-1,1).float()
        self.inputs = inputs.repeat_interleave(n_contexts, dim=0)
        self.contexts = context_onehots.repeat_interleave(n_locs, dim=0)


    def __len__(self):
        return len(self.inputs)


    def __getitem__(self, idx):
        return self.inputs[idx], self.contexts[idx], self.outputs[idx]    


class SpatialModel(nn.Module):
    def __init__(self, n_input=2, n_output=1, n_context_independent_hidden = 128, n_context_hidden = 128,
                 n_context_dependent_hidden = 128, n_contexts=100, n_independent_layers = 1,
                 device=utils.set_torch_device()):
        super().__init__()
        self.device = device
        if n_independent_layers == 1:
            self.ctxt_ind = nn.Linear(n_input, n_context_independent_hidden,device=self.device)
            nn.init.kaiming_normal_(self.ctxt_ind.weight, nonlinearity='relu')
        else:
            layers = []
            for i in range(n_independent_layers-1):
                layers.append(nn.Linear(n_input, n_context_independent_hidden,device=self.device))
                layers.append(nn.ReLU())
            layers.append(nn.Linear(n_context_independent_hidden, n_context_independent_hidden,device=self.device))
            self.ctxt_ind = nn.Sequential(*layers)
        self.ctxt = nn.Linear(n_contexts, n_context_hidden,bias=False,device=self.device)
        self.ctxt_dep = nn.Linear(n_context_independent_hidden+n_context_hidden,
                                  n_context_dependent_hidden,device=self.device)
        self.out = nn.Linear(n_context_independent_hidden, n_output,device=self.device)

        
        nn.init.kaiming_normal_(self.ctxt.weight, nonlinearity='relu')
        nn.init.kaiming_normal_(self.ctxt_dep.weight, nonlinearity='relu')


    def forward(self, x, context):
        ctxt_indep_rep = F.relu(self.ctxt_ind(x))
        ctxt_rep = F.relu(self.ctxt(context))
        ctxt_dep_rep = F.relu(self.ctxt_dep(torch.hstack([ctxt_indep_rep, ctxt_rep])))
        return torch.sigmoid(self.out(ctxt_dep_rep))


    def get_representations(self, x, context):
        ctxt_indep_rep = F.relu(self.ctxt_ind(x))
        ctxt_rep = F.relu(self.ctxt(context))
        ctxt_dep_rep = F.relu(self.ctxt_dep(torch.hstack([ctxt_indep_rep, ctxt_rep])))
        return ctxt_indep_rep, ctxt_rep, ctxt_dep_rep
 
    
    def loss(self, out, ys):
        return F.mse_loss(out, ys)


def train(model, train_loader, optimizer, epochs=1, log_interval=None, device=utils.set_torch_device()):
    if log_interval is None:
        log_interval = epochs//5
    
    train_losses = torch.zeros([epochs],device=device)
    model.train()
    for iepoch in range(epochs):
        train_loss = 0
        for batch_idx, (xs, tcs, ys) in enumerate(train_loader):
            optimizer.zero_grad()
            out = model(xs, tcs)
            loss = model.loss(out, ys)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()

        train_loss /= len(train_loader.dataset)
        train_losses[iepoch] = train_loss
        if iepoch%log_interval==0:
            print(f'Epoch {iepoch} train set average loss: {train_loss:.8f}')
    return train_losses


def test(model, test_loader):
    model.eval()
    test_loss = 0
    with torch.no_grad():
        for batch_idx, (xs, tcs, ys) in enumerate(test_loader):
            out = model(xs, tcs)
            loss = model.loss(out, ys)
            test_loss += loss.item()
        test_loss /= len(test_loader.dataset)
    return test_loss


def setup(n_contexts=10, n_layers=1, batch_size=1024, lr=1e-3, device=utils.set_torch_device()):
    inputs = torch.load('../data/spatial_inputs.pt',map_location=device)
    outputs = torch.load('../data/spatial_outputs.pt',map_location=device)[:, :n_contexts]
    train_dataset = SpatialDataset(inputs, outputs, device=device)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, pin_memory=True)

    #### COULD SET UP TEST DATASET HERE ####

    model = SpatialModel(n_contexts=n_contexts, n_independent_layers = n_layers, device=device)
    optimizer = optim.Adam(model.parameters(), lr=lr)
    return model, train_loader, optimizer

In [88]:
# Setup the environment.
n_contexts = 6
device = utils.set_torch_device(use_gpu=False)

model, train_loader, optimizer = setup(n_contexts=n_contexts, lr=.005,device=device, n_layers=2)

nepochs = 150
log_interval = 25
train_losses = train(model, train_loader, optimizer, nepochs, log_interval, device=device)

fig = px.line(train_losses.cpu().numpy(), title='Training Loss', labels={'x':'Epoch', 'y':'Loss'})
fig.show()

Epoch 0 train set average loss: 0.00001214


In [85]:
train_x, train_c, train_y = train_loader.dataset[:]
model_predictions = model(train_x, train_c).cpu().detach().numpy()
ctxt_indep_rep, ctxt_rep, ctxt_dep_rep = model.get_representations(train_x, train_c)
train_y = train_y.reshape(n_contexts,100,100)
model_predictions = model_predictions.reshape(n_contexts,100,100)

utils.plot_predictions(train_y, model_predictions, n_contexts)

In [86]:
#similarity in the context representation of the model
context_embeddings = ctxt_rep.reshape(n_contexts,100,100,ctxt_rep.shape[-1])[:,0,0].cpu().detach().numpy()
context_similarity = cosine_similarity(context_embeddings)
px.imshow(context_similarity)

In [87]:
#similarity based on the location of the Gaussian's mean
mean_locs = train_y.cpu().detach().numpy().reshape((n_contexts,100*100)).argmax(axis=1)
mean_locs = np.stack(np.unravel_index(mean_locs, (100,100))).T
px.imshow(-euclidean_distances(mean_locs))