In [154]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

In [155]:
class SphereDatasetAnchor(Dataset):
    def __init__(self, extra_modes, num_points, transform_fn=None):
        self.n = num_points
        self.m = extra_modes + 1
        self.points = np.random.uniform(-1,1,size=(self.n, 3))
        mags = np.linalg.norm(self.points, axis=-1)
        self.points /= mags[...,None]
        self.disks = np.array([np.random.uniform(-1,1,size=(3)) for i in range(extra_modes)])
        self.disks = np.array([disk / np.linalg.norm(disk) for disk in self.disks])
        self.disk_points = torch.tensor(np.array([self.project_points(self.points, disk) for disk in self.disks])).float()
        self.points = torch.tensor(self.points).float()
        self.disks = torch.tensor(self.disks).float()
        self.transform_fn = transform_fn
        
    def project_points(self, points, disk):
        output_points = []
        for point in points:
            cosval = np.dot(point, disk)
            arrow = (point*1/np.dot(point, disk) - disk) * cosval # mind the sign of cosval
            arrow /= np.linalg.norm(arrow)
            output_points.append(arrow)
        return np.array(output_points)
    
    def __len__(self):
        return self.m-1
    
    def __getitem__(self, idx):
        e_points = self.disk_points[idx]
        if self.transform_fn is not None:
            return self.transform_fn.evaluate(self.points), self.transform_fn.evaluate(e_points)
        return self.points, e_points
    
    def all_points(self):
        return [self.points] + list(self.disk_points)

In [156]:
from torch.nn.functional import cross_entropy, relu
class Mapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 128)
        self.linear2 = nn.Linear(128, 128)
    
    def forward(self, X):
        return self.linear2(relu(self.linear(X)))

    def evaluate(self, X):
        self.eval()
        with torch.no_grad():
            return self.forward(X)

mapper = Mapper()

In [157]:
modes = 3
dataset = SphereDatasetAnchor(modes, 1000, transform_fn=mapper)

In [161]:
from torch.nn.functional import cross_entropy, relu
class ContrastiveModel(nn.Module):
    def __init__(self, modes):
        super().__init__()
        self.linears = nn.ModuleList([nn.Linear(128, 128) for i in range(modes)])
        self.linears2 = nn.ModuleList([nn.Linear(128, 64) for i in range(modes)])
        self.linears3 = nn.ModuleList([nn.Linear(64, 64) for i in range(modes)])
        self.linears4 = nn.ModuleList([nn.Linear(64, 3) for i in range(modes)])
        self.comps = nn.ParameterList([nn.Parameter(torch.zeros(3,3)) for i in range(modes-1)])
        for comp in self.comps:
            nn.init.xavier_uniform_(comp)
    
    def loss(self, corr_mat):
        losses_row = cross_entropy(corr_mat, torch.arange(0, len(corr_mat)).long())
        losses_col = cross_entropy(torch.transpose(corr_mat,0,1), torch.arange(0, len(corr_mat)).long())
        return torch.mean(losses_row) + torch.mean(losses_col)
    
    def compute_corr_dot(self, a_points, b_points):
        return torch.sum(a_points[:,None] * b_points[None,:], dim=-1)
    
    def compute_corr_func(self, a_points, b_points, b):
        return self.comps4[b-1](relu(self.comps3[b-1](relu(self.comps2[b-1](relu(self.comps[b-1](a_points[:,None] * b_points[None,:])))))))[...,0]
        #return self.compute_corr_dot(a_points_mapped, b_points)
    
    def forward(self, a, b, a_points, b_points):
        #a_points = self.linears2[a](relu(self.linears[a](a_points)))
        #b_points = self.linears2[b](relu(self.linears[b](b_points)))
        a_points = self.linears4[a](relu(self.linears3[a](relu(self.linears2[a](relu(self.linears[a](a_points)))))))
        b_points = self.linears4[b](relu(self.linears3[b](relu(self.linears2[b](relu(self.linears[b](b_points)))))))
        a_points = torch.tensordot(a_points, (self.comps[b-1] + self.comps[b-1].T)/2, dims=1)
        a_points = a_points / torch.linalg.norm(a_points, dim=-1, keepdim=True)
        b_points = b_points / torch.linalg.norm(b_points, dim=-1, keepdim=True)
        corr_mat = self.compute_corr_dot(a_points, b_points)
        return corr_mat
    
    def run_evaluate(self, i, points):
        #points = self.linears2[i](relu(self.linears[i](points)))
        points = self.linears4[i](relu(self.linears3[i](relu(self.linears2[i](relu(self.linears[i](points)))))))
        return points / torch.linalg.norm(points, dim=-1, keepdim=True)

In [162]:
model = ContrastiveModel(modes+1)

In [163]:
from torch.optim import Adam
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
optim = Adam(model.parameters(), lr=0.0001)

In [164]:
def train_data(optim, model, ptrs1, ptrs2, a, b, iteration=10):
    avg_loss = 0
    for i in range(iteration):
        optim.zero_grad()
        corr_mat = model(a, b, ptrs1, ptrs2)
        loss = model.loss(corr_mat)
        loss.backward()
        optim.step()
        avg_loss += loss.item() / iteration
    return avg_loss

In [165]:
EPOCH = 1000
for epoch in range(EPOCH):
    for i in range(modes):
        ptrs1, ptrs2 = dataset[i]
        loss = train_data(optim, model, ptrs1, ptrs2, 0, i+1, iteration=1)
        print(f'{i+1}: {loss}')

1: 13.817824363708496
2: 13.817174911499023
3: 13.821137428283691
1: 13.811138153076172
2: 13.815933227539062
3: 13.817567825317383
1: 13.806256294250488
2: 13.815321922302246
3: 13.816680908203125
1: 13.799189567565918
2: 13.814983367919922
3: 13.816604614257812
1: 13.788528442382812
2: 13.814704895019531
3: 13.81640911102295
1: 13.773563385009766
2: 13.814369201660156
3: 13.815901756286621
1: 13.753643989562988
2: 13.813921928405762
3: 13.815116882324219
1: 13.728053092956543
2: 13.813328742980957
3: 13.814035415649414
1: 13.695987701416016
2: 13.812551498413086
3: 13.812616348266602
1: 13.656482696533203
2: 13.811542510986328
3: 13.810615539550781
1: 13.609369277954102
2: 13.81025505065918
3: 13.807601928710938
1: 13.555523872375488
2: 13.808631896972656
3: 13.803366661071777
1: 13.495311737060547
2: 13.80661392211914
3: 13.79775619506836
1: 13.428621292114258
2: 13.804121971130371
3: 13.790226936340332
1: 13.356083869934082
2: 13.80104923248291
3: 13.77995777130127
1: 13.2788963317

1: 12.376398086547852
2: 12.322245597839355
3: 12.335052490234375
1: 12.37503433227539
2: 12.322042465209961
3: 12.334497451782227
1: 12.373738288879395
2: 12.32183837890625
3: 12.333976745605469
1: 12.372493743896484
2: 12.32163143157959
3: 12.333499908447266
1: 12.371295928955078
2: 12.321435928344727
3: 12.333049774169922
1: 12.370138168334961
2: 12.321245193481445
3: 12.332642555236816
1: 12.369016647338867
2: 12.32105827331543
3: 12.332244873046875
1: 12.367959976196289
2: 12.320862770080566
3: 12.331854820251465
1: 12.366963386535645
2: 12.320669174194336
3: 12.331450462341309
1: 12.366031646728516
2: 12.320487976074219
3: 12.331046104431152
1: 12.36513900756836
2: 12.320317268371582
3: 12.330671310424805
1: 12.364248275756836
2: 12.320160865783691
3: 12.330327033996582
1: 12.363374710083008
2: 12.32000732421875
3: 12.329986572265625
1: 12.362552642822266
2: 12.319847106933594
3: 12.329635620117188
1: 12.361780166625977
2: 12.319682121276855
3: 12.32928466796875
1: 12.36103820800

2: 12.307126998901367
3: 12.313638687133789
1: 12.329197883605957
2: 12.307037353515625
3: 12.313578605651855
1: 12.32900619506836
2: 12.306970596313477
3: 12.313465118408203
1: 12.328826904296875
2: 12.306905746459961
3: 12.313436508178711
1: 12.328627586364746
2: 12.306842803955078
3: 12.313299179077148
1: 12.328483581542969
2: 12.306742668151855
3: 12.31329345703125
1: 12.328289031982422
2: 12.306675910949707
3: 12.313150405883789
1: 12.328140258789062
2: 12.306612014770508
3: 12.313186645507812
1: 12.327939987182617
2: 12.306567192077637
3: 12.312997817993164
1: 12.32783031463623
2: 12.306475639343262
3: 12.313161849975586
1: 12.327625274658203
2: 12.306427001953125
3: 12.313000679016113
1: 12.327526092529297
2: 12.306373596191406
3: 12.313512802124023
1: 12.327259063720703
2: 12.30640697479248
3: 12.313420295715332
1: 12.327187538146973
2: 12.306438446044922
3: 12.314796447753906
1: 12.326875686645508
2: 12.306482315063477
3: 12.315025329589844
1: 12.326898574829102
2: 12.30686569

3: 12.304712295532227
1: 12.311797142028809
2: 12.299585342407227
3: 12.304856300354004
1: 12.311624526977539
2: 12.299578666687012
3: 12.304685592651367
1: 12.311612129211426
2: 12.299525260925293
3: 12.304966926574707
1: 12.311424255371094
2: 12.29953384399414
3: 12.304819107055664
1: 12.3114652633667
2: 12.299478530883789
3: 12.305376052856445
1: 12.311267852783203
2: 12.299515724182129
3: 12.305262565612793
1: 12.311422348022461
2: 12.299505233764648
3: 12.306350708007812
1: 12.311223983764648
2: 12.299566268920898
3: 12.306342124938965
1: 12.311548233032227
2: 12.29971694946289
3: 12.308572769165039
1: 12.311445236206055
2: 12.299764633178711
3: 12.308568954467773
1: 12.312137603759766
2: 12.300086975097656
3: 12.312463760375977
1: 12.312286376953125
2: 12.299667358398438
3: 12.311286926269531
1: 12.313457489013672
2: 12.299867630004883
3: 12.314523696899414
1: 12.313727378845215
2: 12.299107551574707
3: 12.310199737548828
1: 12.314925193786621
2: 12.29893684387207
3: 12.309385299

3: 12.29987907409668
1: 12.300886154174805
2: 12.295796394348145
3: 12.299457550048828
1: 12.301294326782227
2: 12.295554161071777
3: 12.299578666687012
1: 12.301067352294922
2: 12.29544448852539
3: 12.298885345458984
1: 12.30115795135498
2: 12.295364379882812
3: 12.298802375793457
1: 12.300678253173828
2: 12.295373916625977
3: 12.29852294921875
1: 12.300520896911621
2: 12.295382499694824
3: 12.298477172851562
1: 12.300230026245117
2: 12.295385360717773
3: 12.29849910736084
1: 12.300139427185059
2: 12.295382499694824
3: 12.298433303833008
1: 12.30010986328125
2: 12.29532241821289
3: 12.298523902893066
1: 12.300025939941406
2: 12.295294761657715
3: 12.298377990722656
1: 12.300056457519531
2: 12.295236587524414
3: 12.29841423034668
1: 12.299932479858398
2: 12.295204162597656
3: 12.298298835754395
1: 12.299924850463867
2: 12.295162200927734
3: 12.29831600189209
1: 12.299785614013672
2: 12.29515266418457
3: 12.298210144042969
1: 12.29975700378418
2: 12.295140266418457
3: 12.298173904418945

3: 12.301007270812988
1: 12.29802417755127
2: 12.294757843017578
3: 12.30660629272461
1: 12.298992156982422
2: 12.294140815734863
3: 12.305095672607422
1: 12.30174446105957
2: 12.294275283813477
3: 12.308489799499512
1: 12.303033828735352
2: 12.293240547180176
3: 12.300978660583496
1: 12.306058883666992
2: 12.293078422546387
3: 12.298169136047363
1: 12.30050277709961
2: 12.293601989746094
3: 12.295159339904785
1: 12.295799255371094
2: 12.293898582458496
3: 12.294868469238281
1: 12.293590545654297
2: 12.29378604888916
3: 12.296476364135742
1: 12.294256210327148
2: 12.293907165527344
3: 12.29639720916748
1: 12.295243263244629
2: 12.29344654083252
3: 12.29629898071289
1: 12.294839859008789
2: 12.29322338104248
3: 12.294885635375977
1: 12.294923782348633
2: 12.29314136505127
3: 12.29449462890625
1: 12.294198036193848
2: 12.293118476867676
3: 12.294499397277832
1: 12.293785095214844
2: 12.293079376220703
3: 12.294536590576172
1: 12.293645858764648
2: 12.293037414550781
3: 12.294744491577148

2: 12.291534423828125
3: 12.292292594909668
1: 12.290436744689941
2: 12.29151725769043
3: 12.292272567749023
1: 12.290422439575195
2: 12.291505813598633
3: 12.292253494262695
1: 12.29040813446045
2: 12.291496276855469
3: 12.29223346710205
1: 12.290390014648438
2: 12.291487693786621
3: 12.29222297668457
1: 12.290369033813477
2: 12.29148006439209
3: 12.292203903198242
1: 12.290355682373047
2: 12.291473388671875
3: 12.292177200317383
1: 12.290343284606934
2: 12.291459083557129
3: 12.292165756225586
1: 12.290327072143555
2: 12.291443824768066
3: 12.292155265808105
1: 12.290307998657227
2: 12.2914400100708
3: 12.292131423950195
1: 12.290292739868164
2: 12.291439056396484
3: 12.292108535766602
1: 12.290281295776367
2: 12.291435241699219
3: 12.29208755493164
1: 12.290266036987305
2: 12.291418075561523
3: 12.292078018188477
1: 12.290243148803711
2: 12.291397094726562
3: 12.292088508605957
1: 12.290223121643066
2: 12.291379928588867
3: 12.292074203491211
1: 12.290206909179688
2: 12.291378974914

3: 12.290701866149902
1: 12.289148330688477
2: 12.290483474731445
3: 12.290684700012207
1: 12.289140701293945
2: 12.29047966003418
3: 12.290674209594727
1: 12.289133071899414
2: 12.290468215942383
3: 12.290672302246094
1: 12.289127349853516
2: 12.290460586547852
3: 12.290658950805664
1: 12.289121627807617
2: 12.290456771850586
3: 12.290641784667969
1: 12.289115905761719
2: 12.29045295715332
3: 12.290634155273438
1: 12.28911018371582
2: 12.290447235107422
3: 12.290621757507324
1: 12.289101600646973
2: 12.29043960571289
3: 12.290614128112793
1: 12.28909683227539
2: 12.290428161621094
3: 12.290609359741211
1: 12.289093017578125
2: 12.290410995483398
3: 12.290603637695312
1: 12.289087295532227
2: 12.290396690368652
3: 12.29060173034668
1: 12.289082527160645
2: 12.290390014648438
3: 12.29058837890625
1: 12.289073944091797
2: 12.290390968322754
3: 12.290570259094238
1: 12.289068222045898
2: 12.290390968322754
3: 12.290552139282227
1: 12.2890625
2: 12.290385246276855
3: 12.290546417236328
1: 

In [166]:
orig_points = dataset.all_points()

In [167]:
import plotly.express as px
import pandas as pd

In [168]:
df = pd.DataFrame(np.concatenate(orig_points, axis=0), columns=['x', 'y', 'z'])
df['type'] = sum([[f'{i}']*1000 for i in range(modes+1)], [])
fig = px.scatter_3d(df, x='x', y='y', z='z', color='type')
fig.update_traces(marker={'size': 5})
fig.show()

In [169]:
pred_points = [model.run_evaluate(i, mapper(points)).detach().numpy() for i, points in enumerate(orig_points)]

In [170]:
df = pd.DataFrame(np.concatenate(pred_points, axis=0), columns=['x', 'y', 'z'])
df['type'] = sum([[f'{i}']*1000 for i in range(modes+1)], [])
fig = px.scatter_3d(df, x='x', y='y', z='z', color='type')
fig.update_traces(marker={'size': 5})
fig.show()

In [175]:
model.loss(model.compute_corr_dot(orig_points[1], orig_points[1]))

tensor(12.2903)

In [176]:
a = (model.comps[0]+model.comps[0].T)/2
print(a)
torch.linalg.eig(a)

tensor([[-0.0159,  0.0858, -0.0777],
        [ 0.0858, -0.5584,  0.2916],
        [-0.0777,  0.2916, -0.4733]], grad_fn=<DivBackward0>)


torch.return_types.linalg_eig(
eigenvalues=tensor([-8.2703e-01+0.j,  6.3556e-04+0.j, -2.2115e-01+0.j],
       grad_fn=<LinalgEigBackward0>),
eigenvectors=tensor([[ 0.1412+0.j, -0.9899+0.j,  0.0116+0.j],
        [-0.7485+0.j, -0.0991+0.j,  0.6557+0.j],
        [ 0.6480+0.j,  0.1013+0.j,  0.7549+0.j]], grad_fn=<LinalgEigBackward0>))

In [179]:
b=1
points = torch.tensor(pred_points[0])
points = torch.tensordot(points, (model.comps[b-1] + model.comps[b-1].T)/2, dims=1)
points /= torch.linalg.norm(points, dim=-1, keepdim=True)
new_points = [ptrs.copy() for ptrs in pred_points]
new_points[0] = points.detach()
df = pd.DataFrame(np.concatenate(new_points, axis=0), columns=['x', 'y', 'z'])
df['type'] = sum([[f'{i}']*1000 for i in range(modes+1)], [])
fig = px.scatter_3d(df, x='x', y='y', z='z', color='type')
fig.update_traces(marker={'size': 5})
fig.show()

In [59]:
def comp_stable_rank(ptrs):
    corr = np.matmul(ptrs.T, ptrs)/len(ptrs)
    sing = np.linalg.svd(corr)[1][0]
    norm = np.linalg.norm(corr)
    return norm**2/sing**2
    
print(comp_stable_rank(pred_points[0]))
print(comp_stable_rank(pred_points[1]))

1.4307200390811623
1.8832110405250448


In [60]:
ptrs1 = dataset.points
ptrs2 = dataset.disk_points[1]
loss_total = 0
for i,ptr2 in enumerate(ptrs2):
    loss_total += cross_entropy(torch.sum(ptrs1 * ptr2[None,:], dim=-1), torch.tensor(i).long())
for i,ptr1 in enumerate(ptrs1):
    loss_total += cross_entropy(torch.sum(ptr1[None,:] * ptrs2, dim=-1), torch.tensor(i).long())
print(loss_total/len(ptrs2))

tensor(12.5444)


In [61]:
l1 = cross_entropy(torch.sum(ptrs1[:,None]*ptrs2[None,:], dim=-1), torch.arange(len(ptrs1)).long())
l2 = cross_entropy(torch.sum(ptrs2[:,None]*ptrs1[None,:], dim=-1), torch.arange(len(ptrs1)).long())
l1+l2

tensor(12.5444)

In [62]:
model.loss(model.compute_corr_dot(ptrs2, ptrs2))

tensor(12.2896)