In [3]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

In [14]:
class SquareDatasetAnchor(Dataset):
    def __init__(self, extra_modes, num_points, transform_fn=None):
        self.n = num_points
        self.m = extra_modes + 1
        self.points = np.random.uniform(-1,1,size=(self.n, 3))
        self.directions = np.random.uniform(-1,1,size=(extra_modes, 3))
        self.directions /= np.linalg.norm(self.directions, axis=-1)[:, None]
        self.plane_points = [self.project_points(self.points, self.directions[i]) for i in range(extra_modes)]
        self.points = torch.tensor(self.points).float()
        self.plane_points = [torch.tensor(points).float() for points in self.plane_points]
        self.transform_fn = transform_fn
    
    def project_points(self, points, direction):
        output_points = []
        for point in points:
            output_points.append(point - np.dot(point, direction) * direction)
        return np.array(output_points)
        
    def __len__(self):
        return self.m-1
    
    def __getitem__(self, idx):
        e_points = self.plane_points[idx]
        if self.transform_fn is not None:
            return self.transform_fn.evaluate(self.points), self.transform_fn.evaluate(e_points)
        return self.points, e_points
    
    def all_points(self):
        return [self.points] + list(self.plane_points)[:self.m-1]

In [15]:
class Mapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 128)
    
    def forward(self, X):
        return self.linear(X)

    def evaluate(self, X):
        self.eval()
        with torch.no_grad():
            return self.forward(X)

mapper = Mapper()

In [59]:
extra_modes = 7
dataset = SquareDatasetAnchor(extra_modes, 1000, transform_fn=mapper)

In [70]:
from torch.nn.functional import cross_entropy, relu
class ContrastiveModel(nn.Module):
    def __init__(self, modes):
        super().__init__()
        self.linears = nn.ModuleList([nn.Linear(128, 64) for i in range(modes)])
        self.linears2 = nn.ModuleList([nn.Linear(64, 64) for i in range(modes)])
        self.linears3 = nn.ModuleList([nn.Linear(64, 64) for i in range(modes)])
        self.linears4 = nn.ModuleList([nn.Linear(64, 3) for i in range(modes)])
        self.comps = nn.ModuleList([nn.Linear(3*2, 1) for i in range(modes-1)])
    
    def loss(self, corr_mat):
        #print(corr_mat)
        losses_row = cross_entropy(corr_mat, torch.arange(0, len(corr_mat)).long())
        losses_col = cross_entropy(torch.transpose(corr_mat,0,1), torch.arange(0, len(corr_mat)).long())
        return torch.mean(losses_row)*2 + torch.mean(losses_col)
    
    def compute_corr_dot(self, a_points, b_points):
        return torch.sum(a_points[:,None,:] * b_points[None,:,:], dim=-1)
    
    def compute_corr_dist(self, a_points, b_points):
        return -torch.sum((a_points[:,None,:] - b_points[None,:,:])**2, dim=-1)
    
    def compute_corr_func(self, a_points, b_points, b):
        a_points_tiled = torch.tile(a_points[:,None], (1,len(b_points),1))
        b_points_tiled = torch.tile(b_points[None,:], (len(a_points),1,1))
        return self.comps[b-1](torch.concat([a_points_tiled, b_points_tiled], dim=-1))[:,:,0]
    
    def forward(self, a, b, a_points, b_points):
        #a_points = self.linears[a](a_points)
        #b_points = self.linears[b](b_points)
        a_points = self.linears4[a](relu(self.linears3[a](relu(self.linears2[a](relu(self.linears[a](a_points)))))))
        b_points = self.linears4[b](relu(self.linears3[b](relu(self.linears2[b](relu(self.linears[b](b_points)))))))
        corr_mat = self.compute_corr_dist(a_points, b_points)
        return corr_mat
    
    def run_evaluate(self, i, points):
        #return self.linears[i](points)
        return self.linears4[i](relu(self.linears3[i](relu(self.linears2[i](relu(self.linears[i](points)))))))

In [71]:
model = ContrastiveModel(extra_modes+1)

In [72]:
from torch.optim import Adam
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
optim = Adam(model.parameters(), lr=0.001)

In [73]:
def train_data(optim, model, ptrs1, ptrs2, a, b, iteration=10):
    avg_loss = 0
    for i in range(iteration):
        optim.zero_grad()
        corr_mat = model(a, b, ptrs1, ptrs2)
        loss = model.loss(corr_mat)
        loss.backward()
        optim.step()
        avg_loss += loss.item() / iteration
    return avg_loss

In [74]:
EPOCH = 100
for epoch in range(EPOCH):
    for i in range(extra_modes):
        ptrs1, ptrs2 = dataset[i]
        loss = train_data(optim, model, ptrs1, ptrs2, 0, i+1, iteration=10)
        print(f'{i+1}: {loss}')

1: 20.670342636108398
2: 20.513302421569826
3: 20.697291374206543
4: 20.491935539245606
5: 20.57996234893799
6: 20.564365196228025
7: 20.529852294921874
1: 18.1136962890625
2: 19.811268997192386
3: 18.85637741088867
4: 18.44685001373291
5: 20.53530673980713
6: 18.720100593566894
7: 18.293576431274413
1: 18.70912055969238
2: 20.324320220947264
3: 18.219822883605957
4: 18.311013984680176
5: 20.028289794921875
6: 20.639617919921875
7: 23.701905822753904
1: 23.269095230102536
2: 19.96127395629883
3: 18.561805725097656
4: 18.915280151367188
5: 18.471296691894533
6: 20.687884521484374
7: 19.79105796813965
1: 18.36461143493652
2: 18.728527450561522
3: 17.342563247680662
4: 18.798563385009768
5: 19.51927337646484
6: 19.08689785003662
7: 18.580598449707033
1: 16.566430664062498
2: 19.69428310394287
3: 15.597719955444335
4: 17.401248550415037
5: 20.119186019897462
6: 18.44501152038574
7: 18.537429809570312
1: 15.533378314971925
2: 19.823841857910153
3: 17.020412158966064
4: 18.311787796020507
5:

2: 18.026879501342773
3: 15.708403301239013
4: 16.30507335662842
5: 19.401101684570314
6: 16.073073291778563
7: 15.74032974243164
1: 14.130953502655032
2: 17.247898864746094
3: 14.912924671173094
4: 16.25548963546753
5: 17.699782562255862
6: 15.76763744354248
7: 14.99561939239502
1: 14.827679634094237
2: 16.919832611083987
3: 14.614955043792724
4: 15.586952972412108
5: 18.167854499816894
6: 15.40659484863281
7: 14.587972831726074
1: 14.535999298095701
2: 16.87560615539551
3: 14.740825462341311
4: 15.474749660491945
5: 17.92597141265869
6: 15.673756122589111
7: 15.344693088531493
1: 14.249291706085208
2: 17.211737251281736
3: 14.829003047943115
4: 15.478881740570069
5: 18.49251079559326
6: 15.163601303100586
7: 14.740923309326172
1: 14.317812728881835
2: 17.426529121398925
3: 14.800101566314696
4: 15.442823410034181
5: 18.441172790527347
6: 15.50099458694458
7: 14.66395797729492
1: 14.845840454101562
2: 17.552258872985842
3: 15.302947616577145
4: 15.489180374145507
5: 18.3244873046875
6

In [75]:
orig_points = dataset.all_points()

In [76]:
import plotly.express as px
import pandas as pd

In [77]:
df = pd.DataFrame(np.concatenate(orig_points, axis=0), columns=['x', 'y', 'z'])
df['type'] = sum([[f'{i}']*1000 for i in range(extra_modes+1)], [])
fig = px.scatter_3d(df, x='x', y='y', z='z', color='type', opacity=1)
fig.update_traces(marker={'size': 5})
fig.show()

In [78]:
pred_points = [model.run_evaluate(i, mapper(points)).detach().numpy() for i, points in enumerate(orig_points)]

In [79]:
df = pd.DataFrame(np.concatenate(pred_points, axis=0), columns=['x', 'y', 'z'])
df['type'] = sum([[f'{i}']*1000 for i in range(extra_modes+1)], [])
fig = px.scatter_3d(df, x='x', y='y', z='z', color='type')
fig.update_traces(marker={'size': 5})
fig.show()

In [81]:
def comp_stable_rank(ptrs):
    corr = np.matmul(ptrs.T, ptrs)/len(ptrs)
    sing = np.linalg.svd(corr)[1][0]
    norm = np.linalg.norm(corr)
    return norm**2/sing**2
    
print(comp_stable_rank(pred_points[0]))
print(comp_stable_rank(pred_points[1]))

1.5675443307129318
1.3129170819667515


In [82]:
ptrs1 = dataset.points
ptrs2 = dataset.plane_points[2]
loss_total1 = 0
loss_total2 = 0
for i,ptr2 in enumerate(ptrs2):
    loss_total1 += cross_entropy(torch.sum((ptrs1 - ptr2[None,:])**2, dim=-1), torch.tensor(i).long())
print(loss_total1*2 / len(ptrs2))
for i,ptr1 in enumerate(ptrs1):
    loss_total2 += cross_entropy(torch.sum((ptr1[None,:] - ptrs2)**2, dim=-1), torch.tensor(i).long())
print(loss_total2*2 / len(ptrs1))
print((loss_total1+loss_total2)/len(ptrs2))

tensor(17.8656)
tensor(17.9695)
tensor(17.9176)


In [161]:
l1 = cross_entropy(torch.sum((ptrs1[:,None]-ptrs2[None,:])**2, dim=-1), torch.arange(len(ptrs1)).long())
l2 = cross_entropy(torch.sum((ptrs2[:,None]-ptrs1[None,:])**2, dim=-1), torch.arange(len(ptrs1)).long())
l1+l2

tensor(13.0398)

In [273]:
print(dataset.disks[0])

tensor([-0.3848,  0.8510, -0.3575])
