In [165]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

In [219]:
class SphereDatasetAnchor(Dataset):
    def __init__(self, extra_modes, num_points, transform_fn=None):
        self.n = num_points
        self.m = extra_modes + 1
        self.points = np.random.uniform(-1,1,size=(self.n, 3))
        mags = np.linalg.norm(self.points, axis=-1)
        self.points /= mags[...,None]
        self.disks = np.array([np.random.uniform(-1,1,size=(3)) for i in range(extra_modes)])
        self.disks = np.array([disk / np.linalg.norm(disk) for disk in self.disks])
        self.disk_points = torch.tensor([self.project_points(self.points, disk) for disk in self.disks]).float()
        self.points = torch.tensor(self.points).float()
        self.disks = torch.tensor(self.disks).float()
        self.transform_fn = transform_fn
        
    def project_points(self, points, disk):
        output_points = []
        for point in points:
            cosval = np.dot(point, disk)
            arrow = (point*1/np.dot(point, disk) - disk) * cosval # mind the sign of cosval
            arrow /= np.linalg.norm(arrow)
            output_points.append(arrow)
        return np.array(output_points)
    
    def __len__(self):
        return self.m-1
    
    def __getitem__(self, idx):
        e_points = self.disk_points[idx]
        if self.transform_fn is not None:
            return self.transform_fn.evaluate(self.points), self.transform_fn.evaluate(e_points)
        return self.points, e_points
    
    def all_points(self):
        return [self.points] + list(self.disk_points)

In [220]:
class Mapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 128)
    
    def forward(self, X):
        return self.linear(X)

    def evaluate(self, X):
        self.eval()
        with torch.no_grad():
            return self.forward(X)

mapper = Mapper()

In [221]:
dataset = SphereDatasetAnchor(1, 1000, transform_fn=mapper)

In [233]:
from torch.nn.functional import cross_entropy, relu
class ContrastiveModel(nn.Module):
    def __init__(self, modes):
        super().__init__()
        self.linears = nn.ModuleList([nn.Linear(128, 3) for i in range(modes)])
        self.linears2 = nn.ModuleList([nn.Linear(64, 64) for i in range(modes)])
        self.linears3 = nn.ModuleList([nn.Linear(64, 64) for i in range(modes)])
        self.linears4 = nn.ModuleList([nn.Linear(64, 3) for i in range(modes)])
        self.comps = nn.ModuleList([nn.Linear(3*2, 1) for i in range(modes-1)])
    
    def loss(self, corr_mat):
        losses_row = cross_entropy(corr_mat, torch.arange(0, len(corr_mat)).long())
        losses_col = cross_entropy(torch.transpose(corr_mat,0,1), torch.arange(0, len(corr_mat)).long())
        return torch.mean(losses_row) + torch.mean(losses_col)
    
    def compute_corr_dot(self, a_points, b_points):
        return torch.sum(a_points[:,None] * b_points[None,:], dim=-1)
    
    def compute_corr_func(self, a_points, b_points, b):
        a_points_tiled = torch.tile(a_points[:,None], (1,len(b_points),1))
        b_points_tiled = torch.tile(b_points[None,:], (len(a_points),1,1))
        return self.comps[b-1](torch.concat([a_points_tiled, b_points_tiled], dim=-1))[:,:,0]
    
    def forward(self, a, b, a_points, b_points):
        a_points = self.linears[a](a_points)
        b_points = self.linears[b](b_points)
        #a_points = self.linears4[a](relu(self.linears3[a](relu(self.linears2[a](relu(self.linears[a](a_points)))))))
        #b_points = self.linears4[b](relu(self.linears3[b](relu(self.linears2[b](relu(self.linears[b](b_points)))))))
        corr_mat = self.compute_corr_dot(a_points, b_points)
        return corr_mat
    
    def run_evaluate(self, i, points):
        return self.linears[i](points)

In [234]:
model = ContrastiveModel(2)

In [235]:
from torch.optim import Adam
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
optim = Adam(model.parameters(), lr=0.001)

In [236]:
def train_data(optim, model, ptrs1, ptrs2, a, b, iteration=10):
    avg_loss = 0
    for i in range(iteration):
        optim.zero_grad()
        corr_mat = model(a, b, ptrs1, ptrs2)
        loss = model.loss(corr_mat)
        loss.backward()
        optim.step()
        avg_loss += loss.item() / iteration
    return avg_loss

In [237]:
EPOCH = 100
for epoch in range(EPOCH):
    for i in range(1):
        ptrs1, ptrs2 = dataset[i]
        loss = train_data(optim, model, ptrs1, ptrs2, 0, i+1, iteration=10)
        print(f'{i+1}: {loss}')

1: 13.559563827514648
1: 12.968652057647706
1: 12.253328323364258
1: 11.631587028503418
1: 11.258765697479248
1: 11.087079715728759
1: 11.02204351425171
1: 11.00248851776123
1: 10.997598457336426
1: 10.997148513793945
1: 10.99716930389404
1: 10.997151660919188
1: 10.997091579437255
1: 10.99704837799072
1: 10.997030258178713
1: 10.997021865844726
1: 10.99701623916626
1: 10.997009944915774
1: 10.997002887725829
1: 10.996995162963865
1: 10.996987152099608
1: 10.996978187561036
1: 10.99696912765503
1: 10.996958827972412
1: 10.996948432922363
1: 10.996937561035155
1: 10.99692678451538
1: 10.996915531158447
1: 10.996903991699218
1: 10.99689292907715
1: 10.996881675720214
1: 10.996871662139892
1: 10.996861362457274
1: 10.996852493286132
1: 10.996844387054443
1: 10.996838283538821
1: 10.99683198928833
1: 10.996826553344725
1: 10.996822738647461
1: 10.996819591522218
1: 10.99681692123413
1: 10.996815395355224
1: 10.99681386947632
1: 10.996812534332276
1: 10.996811962127687
1: 10.9968111038208
1

In [238]:
orig_points = dataset.all_points()

In [239]:
import plotly.express as px
import pandas as pd

In [240]:
df = pd.DataFrame(np.concatenate(orig_points, axis=0), columns=['x', 'y', 'z'])
df['type'] = ['0']*1000 + ['1']*1000 #+ ['2']*1000 + ['3']*1000
fig = px.scatter_3d(df, x='x', y='y', z='z', color='type')
fig.show()

In [241]:
pred_points = [model.run_evaluate(i, mapper(points)).detach().numpy() for i, points in enumerate(orig_points)]

In [243]:
df = pd.DataFrame(np.concatenate(pred_points, axis=0), columns=['x', 'y', 'z'])
df['type'] = ['0']*1000 + ['1']*1000 #+ ['2']*1000 + ['3']*1000
fig = px.scatter_3d(df, x='x', y='y', z='z', color='type')
fig.show()

In [244]:
def comp_stable_rank(ptrs):
    corr = np.matmul(ptrs.T, ptrs)/len(ptrs)
    sing = np.linalg.svd(corr)[1][0]
    norm = np.linalg.norm(corr)
    return norm**2/sing**2
    
print(comp_stable_rank(pred_points[0]))
print(comp_stable_rank(pred_points[1]))

1.6527351575185716
1.720198008088884


In [164]:
ptrs1 = dataset.points
ptrs2 = dataset.disk_points[2]
loss_total = 0
for i,ptr2 in enumerate(ptrs2):
    loss_total += cross_entropy(torch.sum(ptrs1 * ptr2[None,:], dim=-1), torch.tensor(i).long())
for i,ptr1 in enumerate(ptrs1):
    loss_total += cross_entropy(torch.sum(ptr1[None,:] * ptrs2, dim=-1), torch.tensor(i).long())
print(loss_total/len(ptrs2))

tensor(12.5544)


In [272]:
l1 = cross_entropy(torch.sum(ptrs1[:,None]*ptrs2[None,:], dim=-1), torch.arange(len(ptrs1)).long())
l2 = cross_entropy(torch.sum(ptrs2[:,None]*ptrs1[None,:], dim=-1), torch.arange(len(ptrs1)).long())
l1+l2

tensor(12.5478)

In [273]:
print(dataset.disks[0])

tensor([-0.3848,  0.8510, -0.3575])
