In [2]:
import re
import torch
import warnings
import numpy as np
import pandas as pd
import torch.nn.functional as F
import torch.nn as nn
from itertools import combinations
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader, SequentialSampler
from typing import Tuple, Iterator, List, Dict

# %pip install git+https://github.com/geoopt/geoopt.git
import geoopt
import itertools
import matplotlib.pyplot as plt
import pandas as pd

import panel as pn
import panel.widgets as pnw
import hvplot.pandas
pn.extension()

plt.rcParams["figure.figsize"] = (10,10)
pd.options.plotting.backend = 'holoviews'

In [5]:
# number of unique items in the data matrix
def get_nitems(train_triplets:torch.Tensor) -> int:
    """Return the number of items implied by the indices in ``train_triplets``.

    The count is the largest index found in the tensor; when the smallest
    index is 0 the indices are treated as zero-based and one is added.
    """
    largest = train_triplets.max().item()
    smallest = train_triplets.min().item()
    # zero-based indices 0..max describe max + 1 items
    return largest + 1 if smallest == 0 else largest


# returns batches of one-hot-encoded vectors of item triplets
class BatchGenerator(object):
    """Iterable over mini-batches of one-hot encoded item triplets.

    Each iteration optionally reshuffles the triplets (whenever
    ``sampling_method`` is not ``None``), cuts them into ``n_batches`` slices
    of ``batch_size`` triplets and encodes every slice with
    ``encode_as_onehot``. Triplets left over after the last full slice are
    not yielded.

    With ``sampling_method == 'soft'`` only the fraction ``p`` (a float) of
    the shuffled triplets is used, and ``n_batches`` shrinks accordingly.
    """

    def __init__(
                self,
                I:torch.tensor,
                dataset:torch.Tensor,
                batch_size:int,
                sampling_method:str='normal',
                p=None,
):
        self.I = I
        self.dataset = dataset
        self.batch_size = batch_size
        self.sampling_method = sampling_method
        self.p = p

        use_fraction = sampling_method == 'soft'
        if use_fraction:
            assert isinstance(self.p, float)
        # number of triplets available per epoch, before cutting into batches
        n_rows = int(len(self.dataset) * self.p) if use_fraction else len(self.dataset)
        self.n_batches = n_rows // self.batch_size

    def __len__(self) -> int:
        return self.n_batches

    def __iter__(self) -> Iterator[torch.Tensor]:
        return self.get_batches(self.I, self.dataset)

    def sampling(self, triplets:torch.Tensor) -> torch.Tensor:
        """Return the triplets in a fresh random order (a leading fraction ``p`` of them for 'soft')."""
        order = torch.randperm(len(triplets))
        if self.sampling_method == 'soft':
            n_keep = int(len(order) * self.p)
            order = order[:n_keep]
        return triplets[order]

    def get_batches(self, I:torch.Tensor, triplets:torch.Tensor) -> Iterator[torch.Tensor]:
        """Yield ``n_batches`` one-hot encoded slices of ``triplets``."""
        if self.sampling_method is not None:
            triplets = self.sampling(triplets)
        for batch_idx in range(self.n_batches):
            start = batch_idx * self.batch_size
            stop = (batch_idx + 1) * self.batch_size
            yield encode_as_onehot(I, triplets[start:stop])


def encode_as_onehot(I:torch.Tensor, triplets:torch.Tensor) -> torch.Tensor:
    """Look up one row of ``I`` for every index in ``triplets``.

    ``triplets`` is flattened first, so the result has one row per index, in
    flattened order. When ``I`` is an identity matrix each row is the one-hot
    vector of the corresponding item.
    """
    flat_indices = triplets.flatten()
    return I[flat_indices, :]

#load train and test mini-batches
def load_batches(
              test_triplets:torch.Tensor,
              n_items:int,
              batch_size:int,
              inference:bool=False,
              sampling_method:str=None,
              rnd_seed:int=None,
              multi_proc:bool=False,
              n_gpus:int=None,
              p=None,
              ):
    """Build the batch generator for the evaluation triplets.

    Only ``test_triplets``, ``n_items`` and ``batch_size`` are used. The
    remaining arguments are accepted but ignored: the generator is always
    created with ``sampling_method=None`` and ``p=None``, so triplets are
    served in their original order without shuffling.
    """
    # identity matrix of size n_items x n_items: row i is the one-hot vector of item i
    one_hot_lookup = torch.eye(n_items)
    return BatchGenerator(
        I=one_hot_lookup,
        dataset=test_triplets,
        batch_size=batch_size,
        sampling_method=None,
        p=None,
    )


# SPoSE model
class SPoSE(nn.Module):
    """Embedding model made of a single linear layer without bias.

    ``forward`` maps inputs of width ``in_size`` to outputs of width
    ``out_size``. With ``init_weights=True`` the layer weights are redrawn
    from a normal distribution with mean 0.1 and standard deviation 0.01.
    """

    def __init__(
                self,
                in_size:int,
                out_size:int,
                init_weights:bool=True,
                ):
        super().__init__()
        self.in_size = in_size
        self.out_size = out_size
        self.fc = nn.Linear(in_size, out_size, bias=False)

        if init_weights:
            self._initialize_weights()

    def forward(self, x:torch.Tensor) -> torch.Tensor:
        embedding = self.fc(x)
        return embedding

    def _initialize_weights(self) -> None:
        """Overwrite the weights of every ``nn.Linear`` submodule in place."""
        mean, std = .1, .01
        linear_layers = (module for module in self.modules() if isinstance(module, nn.Linear))
        for layer in linear_layers:
            layer.weight.data.normal_(mean, std)
                

def compute_similarities(hyperbolic, anchor:torch.Tensor, positive:torch.Tensor, negative:torch.Tensor, method:str, distance_metric:str = 'dot') -> Tuple:
    """Compute pairwise similarities between the members of each triplet.

    Args:
        hyperbolic: object exposing ``dist(x, y)``; only used when
            ``distance_metric == 'hyperbolic'`` and may be ``None`` otherwise.
        anchor, positive, negative: tensors holding one embedding per row
            (all reductions run over ``dim=1``).
        method: ``'odd_one_out'`` additionally returns the positive/negative
            similarity; any other value returns only the two anchor pairs.
        distance_metric: ``'dot'`` (dot product), ``'euclidean'`` (negative
            Euclidean distance) or ``'hyperbolic'`` (negative
            ``hyperbolic.dist``).

    Returns:
        ``(pos_sim, neg_sim, neg_sim_2)`` for ``'odd_one_out'``, otherwise
        ``(pos_sim, neg_sim)``.

    Raises:
        ValueError: if ``distance_metric`` is not one of the supported values,
            or if it is ``'hyperbolic'`` and no manifold object was supplied.
    """
    if distance_metric == 'dot':
        def similarity(u:torch.Tensor, v:torch.Tensor) -> torch.Tensor:
            return torch.sum(u * v, dim=1)
    elif distance_metric == 'euclidean':
        def similarity(u:torch.Tensor, v:torch.Tensor) -> torch.Tensor:
            # distances are negated so that larger always means "more similar"
            return -1*torch.sqrt(torch.sum(torch.square(torch.sub(u, v)), dim=1))
    elif distance_metric == 'hyperbolic':
        if hyperbolic is None:
            raise ValueError("distance_metric 'hyperbolic' requires a manifold object with a dist() method")

        def similarity(u:torch.Tensor, v:torch.Tensor) -> torch.Tensor:
            return -1*hyperbolic.dist(u, v)
    else:
        # previously this fell through and returned None, which only failed later
        raise ValueError(
            f"unknown distance_metric {distance_metric!r}; expected 'dot', 'euclidean' or 'hyperbolic'"
        )

    pos_sim = similarity(anchor, positive)
    neg_sim = similarity(anchor, negative)
    if method == 'odd_one_out':
        neg_sim_2 = similarity(positive, negative)
        return pos_sim, neg_sim, neg_sim_2
    return pos_sim, neg_sim

def softmax(sims:tuple, t:torch.Tensor) -> torch.Tensor:
    """Return the softmax probability of the first similarity in ``sims``.

    Computes ``exp(sims[0] / t) / sum_i exp(sims[i] / t)`` element-wise, where
    ``t`` is the temperature.

    The similarities are stacked and passed through ``torch.softmax`` instead
    of exponentiating directly: the direct form overflows to ``inf / inf``
    (nan) for large ``sims / t`` and yields ``0 / 0`` when every term
    underflows.
    """
    # divide each similarity by t separately so broadcasting matches sim / t
    scaled = torch.stack([sim / t for sim in sims])
    return torch.softmax(scaled, dim=0)[0]

## Run model on all objects

In [35]:
# every index 0..1853 exactly once, grouped three per row
test_triplets = torch.arange(1854).view(-1, 3)

In [36]:
# GENERAL SETTINGS

# hardware: use the GPU when one is available, otherwise the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# data location and task definition
triplets_dir = "master_thesis/SPoSE/data"
task = 'odd_one_out'
method = 'odd_one_out'

# sampling configuration passed on to run_model
sampling_method = 'normal'
p = None
rnd_seed = 42

# derived from the triplets: item count, and a single batch covering all triplets
n_items = get_nitems(test_triplets)
batch_size = test_triplets.shape[0]

In [194]:
def run_model(n_items, out_size, file, distance_metric, test_triplets, batch_size, sampling_method, rnd_seed, p):
    """Embed the items of ``test_triplets`` with a trained SPoSE model.

    Restores the weights stored under ``'model_state_dict'`` in ``file``, runs
    every batch produced by ``load_batches`` through the model and joins the
    resulting coordinates with the item names and categories read from
    ``items1854names.tsv`` and ``category53_longFormat.tsv``.

    Args:
        n_items: input width of the model (number of one-hot positions).
        out_size: embedding dimensionality of the stored model.
        file: path of the checkpoint to load.
        distance_metric: when ``'hyperbolic'`` the model output is passed
            through ``hyperbolic.expmap0`` (``hyperbolic`` is a notebook-level
            global); other values leave the output unchanged.
        test_triplets, batch_size, sampling_method, rnd_seed, p: forwarded to
            ``load_batches``.

    Returns:
        DataFrame with ``objectID``, one ``dim_<i>`` column per embedding
        dimension, ``uniqueID`` and ``category``. ``objectID`` is the row
        position in the model output, so it equals the item index only when
        ``test_triplets`` lists the items in ascending order. Items belonging
        to several categories appear once per category.

    Raises:
        ValueError: if ``load_batches`` yields no batch at all.
    """
    # build the network with the requested embedding size, then restore the trained weights
    model = SPoSE(in_size=n_items, out_size=out_size, init_weights=True)
    model.to(device)
    # NOTE(review): torch.load unpickles the file - only load checkpoints from a trusted source
    checkpoint = torch.load(file, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint['model_state_dict'])

    # evaluate model
    model.eval()
    embeddings = []
    with torch.no_grad():
        val_batches = load_batches( test_triplets=test_triplets,
                                    n_items=n_items,
                                    batch_size=batch_size,
                                    sampling_method=sampling_method,
                                    rnd_seed=rnd_seed,
                                    p=p
                                    )
        for batch in val_batches:
            logits = model(batch.to(device))
            if distance_metric == 'hyperbolic':
                logits = hyperbolic.expmap0(logits)
            # keep every batch (not just the last one) and move it off the GPU
            embeddings.append(logits.detach().cpu())

    if not embeddings:
        raise ValueError("load_batches produced no batches; check batch_size against the number of triplets")
    # pandas needs a host-side array, not a (possibly CUDA) torch tensor
    logits = torch.cat(embeddings, dim=0).numpy()

    # create df of output of model, objectID and uniqueID (= name of object)
    columns = ['dim_' + str(i) for i in range(logits.shape[1])]
    logits_df = pd.DataFrame(logits, columns=columns)
    logits_df.index.name = 'objectID'
    logits_df = logits_df.reset_index()

    wordnet_df = pd.read_csv("items1854names.tsv", sep="\t"
                          ).rename_axis('objectID').reset_index()[['objectID', 'uniqueID']]

    category_df = pd.read_csv("category53_longFormat.tsv", sep="\t")[['category', 'uniqueID']]

    merge_df = pd.merge(logits_df, wordnet_df, on='objectID')
    final_df = pd.merge(merge_df, category_df, on='uniqueID') # some objects belong to multiple categories
    return final_df

## Plot objects (2Dimensional)
### Show all categories
> Hyperbolic space

In [193]:
# SPECIFIC SETTINGS (change per model)
distance_metric = 'hyperbolic'
embed_dim = 2
file = 'model_hyperbolic_2dim.tar'  # embed dim 2 hyperbolic - with negative dimensions
hyperbolic = geoopt.PoincareBallExact(c=7.5)  # manifold run_model uses for the 'hyperbolic' metric

hyperbolic_df = run_model(
    n_items=n_items,
    out_size=embed_dim,
    file=file,
    distance_metric=distance_metric,
    test_triplets=test_triplets,
    batch_size=batch_size,
    sampling_method=sampling_method,
    rnd_seed=rnd_seed,
    p=p,
)

# one point per (object, category) row, coloured by category
hyperbolic_df.hvplot.scatter(
    x='dim_0',
    y='dim_1',
    by='category',
    hover_cols=['uniqueID', 'category'],
    legend=False,
    width=700,
    height=700,
    title='Objects in 2-dimensional hyperbolic space',
)

> Euclidean space

In [191]:
# SPECIFIC SETTINGS for the Euclidean model
distance_metric = 'euclidean'
embed_dim = 2
file = 'model_euclidean_2dim.tar'

euclidean_df = run_model(
    n_items=n_items,
    out_size=embed_dim,
    file=file,
    distance_metric=distance_metric,
    test_triplets=test_triplets,
    batch_size=batch_size,
    sampling_method=sampling_method,
    rnd_seed=rnd_seed,
    p=p,
)

# one point per (object, category) row, coloured by category
euclidean_df.hvplot.scatter(
    x='dim_0',
    y='dim_1',
    by='category',
    hover_cols=['uniqueID', 'category'],
    legend=False,
    width=700,
    height=700,
    title='Objects in 2-dimensional Euclidean space',
)

### select x categories
> Hyperbolic space

In [200]:
# multi-select widget offering every category present in the hyperbolic results
object_categories = pnw.MultiChoice(
    name='Category',
    value=['animal', 'toy'],
    options=sorted(hyperbolic_df['category'].unique()),
)

idf = hyperbolic_df.interactive()

# keep only the rows whose category is currently selected in the widget
data_pipeline = idf[idf.category.isin(object_categories)]

data_pipeline.hvplot(
    kind='scatter',
    x='dim_0',
    y='dim_1',
    by='category',
    hover_cols=['uniqueID', 'category'],
    legend=False,
    width=700,
    height=700,
    xlim=(-0.4, 0.4),
    ylim=(-0.4, 0.4),
    title='Objects in 2-dimensional hyperbolic space',
)

> Euclidean space

In [201]:
# multi-select widget offering every category present in the Euclidean results
object_categories = pnw.MultiChoice(
    name='Category',
    value=['animal', 'toy'],
    options=sorted(euclidean_df['category'].unique()),
)

idf = euclidean_df.interactive()

# keep only the rows whose category is currently selected in the widget
data_pipeline = idf[idf.category.isin(object_categories)]

data_pipeline.hvplot(
    kind='scatter',
    x='dim_0',
    y='dim_1',
    by='category',
    hover_cols=['uniqueID', 'category'],
    legend=False,
    width=700,
    height=700,
    xlim=(0, 2.5),
    ylim=(0, 2.5),
    title='Objects in 2-dimensional Euclidean space',
)