In [1]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

In [2]:
class SphereDatasetAnchor(Dataset):
    def __init__(self, extra_modes, num_points, transform_fn=None):
        self.n = num_points
        self.m = extra_modes + 1
        self.points = np.random.uniform(-1,1,size=(self.n, 3))
        mags = np.linalg.norm(self.points, axis=-1)
        self.points /= mags[...,None]
        self.disks = np.array([np.random.uniform(-1,1,size=(3)) for i in range(extra_modes)])
        self.disks = np.array([disk / np.linalg.norm(disk) for disk in self.disks])
        self.disk_points = torch.tensor(np.array([self.project_points(self.points, disk) for disk in self.disks])).float()
        self.points = torch.tensor(self.points).float()
        self.disks = torch.tensor(self.disks).float()
        self.transform_fn = transform_fn
        
    def project_points(self, points, disk):
        output_points = []
        for point in points:
            cosval = np.dot(point, disk)
            arrow = (point*1/np.dot(point, disk) - disk) * cosval # mind the sign of cosval
            arrow /= np.linalg.norm(arrow)
            output_points.append(arrow)
        return np.array(output_points)
    
    def __len__(self):
        return 1
    
    def __getitem__(self, idx):
        all_points = [self.points] + list(self.disk_points)
        if self.transform_fn is not None:
            return [self.transform_fn.evaluate(ptrs) for ptrs in all_points]
        return all_points
    
    def all_points(self):
        return [self.points] + list(self.disk_points)

In [3]:
from torch.nn.functional import cross_entropy, relu
class Mapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 128)
        self.linear2 = nn.Linear(128, 128)
    
    def forward(self, X):
        return self.linear2(relu(self.linear(X)))

    def evaluate(self, X):
        self.eval()
        with torch.no_grad():
            return self.forward(X)

mapper = Mapper()

In [4]:
modes = 3
dataset = SphereDatasetAnchor(modes, 1000, transform_fn=mapper)

In [38]:
from torch.nn.functional import cross_entropy, relu
class ContrastiveModel(nn.Module):
    def __init__(self, modes):
        super().__init__()
        self.linears = nn.ModuleList([nn.Linear(128, 128) for i in range(modes)])
        self.linears2 = nn.ModuleList([nn.Linear(128, 64) for i in range(modes)])
        self.linears3 = nn.ModuleList([nn.Linear(64, 64) for i in range(modes)])
        self.linears4 = nn.ModuleList([nn.Linear(64, 3) for i in range(modes)])
    
    def loss(self, corr_mats):
        losses_rows = torch.zeros((len(corr_mats)))
        for i, corr_mat in enumerate(corr_mats):
            losses_row = cross_entropy(corr_mat, torch.arange(0, len(corr_mat)).long()) # row is matching f_points[i] to correct o_points
            losses_rows[i] = losses_row
        return torch.mean(losses_rows)
    
    def compute_corr_dot(self, f_list_points):
        f_points = torch.stack(f_list_points, dim=0)
        corr_dots = []
        for i in range(len(f_points)):
            o_points = torch.sum(f_points, dim=0) - f_points[i]
            corr_dots.append(torch.tensordot(o_points, f_points[i], dims=([-1],[-1])))
        return corr_dots
    
    def forward(self, list_points):
        f_list_points = []
        for i, points in enumerate(list_points):
            f_list_points.append(self.run_evaluate(i, points))
        corr_mats = self.compute_corr_dot(f_list_points)
        return corr_mats
    
    def run_evaluate(self, i, points):
        #points = self.linears2[i](relu(self.linears[i](points)))
        points = self.linears4[i](relu(self.linears3[i](relu(self.linears2[i](relu(self.linears[i](points)))))))
        return points / torch.linalg.norm(points, dim=-1, keepdim=True)

In [39]:
model = ContrastiveModel(modes+1)

In [29]:
from torch.optim import Adam
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
optim = Adam(model.parameters(), lr=0.001)

In [30]:
def train_data(optim, model, list_points, iteration=10):
    avg_loss = 0
    for i in range(iteration):
        optim.zero_grad()
        corr_mats = model(list_points)
        loss = model.loss(corr_mats)
        loss.backward()
        optim.step()
        avg_loss += loss.item() / iteration
    return avg_loss

In [31]:
EPOCH = 1000
for epoch in range(EPOCH):
    list_points = dataset[0]
    loss = train_data(optim, model, list_points, iteration=1)
    print(f'{epoch}: {loss}')

0: 6.908141613006592
1: 6.907873153686523
2: 6.9045209884643555
3: 6.896583557128906
4: 6.880324363708496
5: 6.852261543273926
6: 6.8091325759887695
7: 6.748734474182129
8: 6.673169136047363
9: 6.59132194519043
10: 6.519070625305176
11: 6.465191841125488
12: 6.424980163574219
13: 6.393402576446533
14: 6.365603446960449
15: 6.334010601043701
16: 6.298844337463379
17: 6.266374588012695
18: 6.235842704772949
19: 6.206605911254883
20: 6.18183708190918
21: 6.159780979156494
22: 6.137234687805176
23: 6.1155595779418945
24: 6.093151569366455
25: 6.069911956787109
26: 6.049690246582031
27: 6.032480239868164
28: 6.0135979652404785
29: 5.987823009490967
30: 5.95727014541626
31: 5.929038047790527
32: 5.906407356262207
33: 5.888138294219971
34: 5.871938705444336
35: 5.856729030609131
36: 5.843364715576172
37: 5.832629203796387
38: 5.823418140411377
39: 5.814898490905762
40: 5.807919979095459
41: 5.80217170715332
42: 5.795605182647705
43: 5.786746501922607
44: 5.77678108215332
45: 5.767820358276367

364: 5.632171630859375
365: 5.632110118865967
366: 5.6320695877075195
367: 5.632039546966553
368: 5.632102966308594
369: 5.632328033447266
370: 5.632772445678711
371: 5.633606433868408
372: 5.634686470031738
373: 5.6352643966674805
374: 5.634904384613037
375: 5.633554935455322
376: 5.632777690887451
377: 5.632362365722656
378: 5.6321845054626465
379: 5.632050514221191
380: 5.632187843322754
381: 5.632474899291992
382: 5.632413864135742
383: 5.632183074951172
384: 5.631689548492432
385: 5.631577968597412
386: 5.631626129150391
387: 5.631529331207275
388: 5.631513595581055
389: 5.631486415863037
390: 5.631457328796387
391: 5.631204128265381
392: 5.630981922149658
393: 5.631073951721191
394: 5.6314697265625
395: 5.631860733032227
396: 5.632470607757568
397: 5.633353233337402
398: 5.634341239929199
399: 5.635819911956787
400: 5.637721538543701
401: 5.638526916503906
402: 5.637091636657715
403: 5.632980823516846
404: 5.632457733154297
405: 5.634527683258057
406: 5.63355827331543
407: 5.6316

720: 5.622567653656006
721: 5.622879981994629
722: 5.623010635375977
723: 5.6228532791137695
724: 5.623366355895996
725: 5.624063491821289
726: 5.625176906585693
727: 5.626447677612305
728: 5.629246234893799
729: 5.629909992218018
730: 5.630033493041992
731: 5.626777172088623
732: 5.624302864074707
733: 5.625540733337402
734: 5.626832008361816
735: 5.625769138336182
736: 5.623169898986816
737: 5.623245716094971
738: 5.624742031097412
739: 5.624435901641846
740: 5.623330116271973
741: 5.623225688934326
742: 5.6250386238098145
743: 5.625827789306641
744: 5.6251726150512695
745: 5.625236511230469
746: 5.624958038330078
747: 5.622622013092041
748: 5.621349334716797
749: 5.622284889221191
750: 5.623723030090332
751: 5.624419212341309
752: 5.623321533203125
753: 5.6228179931640625
754: 5.623019695281982
755: 5.624082565307617
756: 5.624161720275879
757: 5.625671863555908
758: 5.627268314361572
759: 5.62742280960083
760: 5.625041961669922
761: 5.622406482696533
762: 5.622635841369629
763: 5.6

In [32]:
orig_points = dataset.all_points()

In [33]:
import plotly.express as px
import pandas as pd

In [34]:
df = pd.DataFrame(np.concatenate(orig_points, axis=0), columns=['x', 'y', 'z'])
df['type'] = sum([[f'{i}']*1000 for i in range(modes+1)], [])
fig = px.scatter_3d(df, x='x', y='y', z='z', color='type')
fig.update_traces(marker={'size': 5})
fig.show()

In [35]:
pred_points = [model.run_evaluate(i, mapper(points)).detach().numpy() for i, points in enumerate(orig_points)]

In [36]:
df = pd.DataFrame(np.concatenate(pred_points, axis=0), columns=['x', 'y', 'z'])
df['type'] = sum([[f'{i}']*1000 for i in range(modes+1)], [])
fig = px.scatter_3d(df, x='x', y='y', z='z', color='type')
fig.update_traces(marker={'size': 5})
fig.show()

In [40]:
model.loss(model.compute_corr_dot(orig_points))

tensor(5.7615)

In [44]:
a = torch.tensor([[1,0,0]]).float()
b = torch.tensor([[0,1,0]]).float()
c = torch.tensor([[0,1/1.414,1/1.414]]).float()

model.compute_corr_dot([a,b,c])

[tensor([[0.]]), tensor([[0.7072]]), tensor([[0.7072]])]

In [151]:
b=3
points = torch.tensor(pred_points[0])
points = model.comps4[b-1](relu(model.comps3[b-1](relu(model.comps2[b-1](relu(model.comps[b-1](points)))))))
points /= torch.linalg.norm(points, dim=-1, keepdim=True)
new_points = [ptrs.copy() for ptrs in pred_points]
new_points[0] = points.detach()
df = pd.DataFrame(np.concatenate(new_points, axis=0), columns=['x', 'y', 'z'])
df['type'] = sum([[f'{i}']*1000 for i in range(modes+1)], [])
fig = px.scatter_3d(df, x='x', y='y', z='z', color='type')
fig.update_traces(marker={'size': 5})
fig.show()

In [64]:
print(torch.linalg.det(model.comps[2].weight.detach()))
print(model.comps[0].weight)

tensor(0.0486)
Parameter containing:
tensor([[ 0.4028,  0.3115, -0.4583],
        [ 0.4075, -0.2247, -0.2875],
        [ 0.5676, -0.5145, -0.4094]], requires_grad=True)


In [59]:
def comp_stable_rank(ptrs):
    corr = np.matmul(ptrs.T, ptrs)/len(ptrs)
    sing = np.linalg.svd(corr)[1][0]
    norm = np.linalg.norm(corr)
    return norm**2/sing**2
    
print(comp_stable_rank(pred_points[0]))
print(comp_stable_rank(pred_points[1]))

1.4307200390811623
1.8832110405250448


In [60]:
ptrs1 = dataset.points
ptrs2 = dataset.disk_points[1]
loss_total = 0
for i,ptr2 in enumerate(ptrs2):
    loss_total += cross_entropy(torch.sum(ptrs1 * ptr2[None,:], dim=-1), torch.tensor(i).long())
for i,ptr1 in enumerate(ptrs1):
    loss_total += cross_entropy(torch.sum(ptr1[None,:] * ptrs2, dim=-1), torch.tensor(i).long())
print(loss_total/len(ptrs2))

tensor(12.5444)


In [61]:
l1 = cross_entropy(torch.sum(ptrs1[:,None]*ptrs2[None,:], dim=-1), torch.arange(len(ptrs1)).long())
l2 = cross_entropy(torch.sum(ptrs2[:,None]*ptrs1[None,:], dim=-1), torch.arange(len(ptrs1)).long())
l1+l2

tensor(12.5444)

In [62]:
model.loss(model.compute_corr_dot(ptrs2, ptrs2))

tensor(12.2896)