In [2]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

In [3]:
class CubeDatasetAnchor(Dataset):
    def __init__(self, extra_modes, num_points, transform_fn=None):
        self.n = num_points
        self.m = extra_modes + 1
        self.points = np.random.uniform(-1,1,size=(self.n, 3))
        self.directions = np.random.uniform(-1,1,size=(extra_modes, 3))
        self.directions /= np.linalg.norm(self.directions, axis=-1)[:, None]
        self.plane_points = [self.project_points(self.points, self.directions[i]) for i in range(extra_modes)]
        self.points = torch.tensor(self.points).float()
        self.plane_points = [torch.tensor(points).float() for points in self.plane_points]
        self.transform_fn = transform_fn
    
    def project_points(self, points, direction):
        output_points = []
        for point in points:
            output_points.append(point - np.dot(point, direction) * direction)
        return np.array(output_points)
        
    def __len__(self):
        return self.m-1
    
    def __getitem__(self, idx):
        e_points = self.plane_points[idx]
        if self.transform_fn is not None:
            return self.transform_fn.evaluate(self.points), self.transform_fn.evaluate(e_points)
        return self.points, e_points
    
    def all_points(self):
        return [self.points] + list(self.plane_points)[:self.m-1]

In [4]:
from torch.nn.functional import cross_entropy, relu
class Mapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 128)
        self.linear2 = nn.Linear(128, 128)
    
    def forward(self, X):
        return self.linear2(relu(self.linear(X)))

    def evaluate(self, X):
        self.eval()
        with torch.no_grad():
            return self.forward(X)

mapper = Mapper()

In [5]:
extra_modes = 3
dataset = CubeDatasetAnchor(extra_modes, 1000, transform_fn=mapper)

In [109]:
from torch.nn.functional import cross_entropy, relu
class ContrastiveModel(nn.Module):
    def __init__(self, modes):
        super().__init__()
        self.linears = nn.ModuleList([nn.Linear(128, 128) for i in range(modes)])
        self.linears2 = nn.ModuleList([nn.Linear(128, 64) for i in range(modes)])
        self.linears3 = nn.ModuleList([nn.Linear(64, 64) for i in range(modes)])
        self.linears4 = nn.ModuleList([nn.Linear(64, 3) for i in range(modes)])
        self.comps = nn.ModuleList([nn.Linear(3, 3, bias=False) for i in range(modes-1)])
        self.metrics = nn.ParameterList([nn.Parameter(torch.zeros(3,3)) for i in range(modes-1)])
        for metric in self.metrics:
            nn.init.xavier_uniform_(metric)
    
    def loss(self, corr_mat):
        #print(corr_mat)
        losses_row = cross_entropy(corr_mat, torch.arange(0, len(corr_mat)).long())
        losses_col = cross_entropy(torch.transpose(corr_mat,0,1), torch.arange(0, len(corr_mat)).long())
        return torch.mean(losses_row) + torch.mean(losses_col)
    
    def compute_corr_dot(self, a_points, b_points):
        return torch.sum(a_points[:,None,:] * b_points[None,:,:], dim=-1)
    
    def compute_corr_dist(self, a_points, b_points):
        return -torch.sum((a_points[:,None,:] - b_points[None,:,:])**2, dim=-1)
    
    def compute_corr_metric(self, a_points, b_points, b):
        diff = a_points[:,None,:] - b_points[None,:,:]
        metric = (self.metrics[b-1] + self.metrics[b-1].T)/2
        return -torch.sum(torch.tensordot(diff, metric, dims=([-1],[0])) * diff, dim=-1)
    
    def compute_corr_func(self, a_points, b_points, b):
        return -torch.sum((self.comps[b-1](a_points[:,None,:] - b_points[None,:,:]) + a_points[:,None,:] - b_points[None,:,:])**2, dim=-1)
    
    def normalize(self, points):
        return points / torch.sqrt(torch.mean(torch.sum(points**2, dim=-1), dim=0))
    
    def forward(self, a, b, a_points, b_points):
        #a_points = self.linears[a](a_points)
        #b_points = self.linears[b](b_points)
        #a_points = self.linears2[a](relu(self.linears[a](a_points)))
        #b_points = self.linears2[b](relu(self.linears[b](b_points)))
        a_points = self.run_evaluate(a, a_points)
        b_points = self.run_evaluate(b, b_points)
        #a_points = self.comps[b-1](a_points) #+ a_points 
        #a_points = self.normalize(a_points)
        corr_mat = self.compute_corr_metric(a_points, b_points, b)
        return corr_mat
    
    def run_evaluate(self, i, points):
        #return self.linears[i](points)
        #return self.linears2[i](relu(self.linears[i](points)))
        points = self.linears4[i](relu(self.linears3[i](relu(self.linears2[i](relu(self.linears[i](points)))))))
        points = self.normalize(points)
        return points

In [110]:
model = ContrastiveModel(extra_modes+1)

In [111]:
from torch.optim import Adam, SGD
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

print([name for name, param in model.named_parameters()])
#model.linears.require_grad = False
optim = Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)

['linears.0.weight', 'linears.0.bias', 'linears.1.weight', 'linears.1.bias', 'linears.2.weight', 'linears.2.bias', 'linears.3.weight', 'linears.3.bias', 'linears2.0.weight', 'linears2.0.bias', 'linears2.1.weight', 'linears2.1.bias', 'linears2.2.weight', 'linears2.2.bias', 'linears2.3.weight', 'linears2.3.bias', 'linears3.0.weight', 'linears3.0.bias', 'linears3.1.weight', 'linears3.1.bias', 'linears3.2.weight', 'linears3.2.bias', 'linears3.3.weight', 'linears3.3.bias', 'linears4.0.weight', 'linears4.0.bias', 'linears4.1.weight', 'linears4.1.bias', 'linears4.2.weight', 'linears4.2.bias', 'linears4.3.weight', 'linears4.3.bias', 'comps.0.weight', 'comps.1.weight', 'comps.2.weight', 'metrics.0', 'metrics.1', 'metrics.2']


In [112]:
def train_data(optim, model, ptrs1, ptrs2, a, b, iteration=10):
    avg_loss = 0
    for i in range(iteration):
        optim.zero_grad()
        corr_mat = model(a, b, ptrs1, ptrs2)
        loss = model.loss(corr_mat)
        loss.backward()
        optim.step()
        avg_loss += loss.item() / iteration
    return avg_loss

In [None]:
EPOCH = 1000
for epoch in range(EPOCH):
    for i in range(extra_modes):
        ptrs1, ptrs2 = dataset[i]
        loss = train_data(optim, model, ptrs1, ptrs2, 0, i+1, iteration=1)
        print(f'{i+1}: {loss}')

1: 13.815582275390625
2: 13.825799942016602
3: 13.81922721862793
1: 13.817854881286621
2: 13.819601058959961
3: 13.810500144958496
1: 13.766717910766602
2: 13.807927131652832
3: 13.769187927246094
1: 13.67255973815918
2: 13.776327133178711
3: 13.674503326416016
1: 13.452216148376465
2: 13.688179969787598
3: 13.46382999420166
1: 13.159002304077148
2: 13.545013427734375
3: 13.222200393676758
1: 13.13957405090332
2: 13.374258041381836
3: 13.118797302246094
1: 13.007994651794434
2: 13.20467758178711
3: 13.101990699768066
1: 12.902938842773438
2: 13.080304145812988
3: 13.117721557617188
1: 12.767486572265625
2: 12.951263427734375
3: 13.116846084594727
1: 12.603142738342285
2: 12.836276054382324
3: 13.126883506774902
1: 12.42593765258789
2: 12.678213119506836
3: 13.158319473266602
1: 12.18881607055664
2: 12.498350143432617
3: 13.193496704101562
1: 11.932104110717773
2: 12.360921859741211
3: 13.21630573272705
1: 11.805194854736328
2: 12.598522186279297
3: 13.187808990478516
1: 12.184955596923

3: 11.904345512390137
1: 11.24935531616211
2: 11.426614761352539
3: 11.901272773742676
1: 11.245126724243164
2: 11.437214851379395
3: 11.904255867004395
1: 11.24344253540039
2: 11.482330322265625
3: 11.905763626098633
1: 11.23965835571289
2: 11.578828811645508
3: 11.904499053955078
1: 11.28647518157959
2: 11.742210388183594
3: 11.902436256408691
1: 11.30199146270752
2: 11.727712631225586
3: 11.941814422607422
1: 11.279847145080566
2: 11.558710098266602
3: 11.937688827514648
1: 11.223688125610352
2: 11.444042205810547
3: 11.902694702148438
1: 11.209322929382324
2: 11.450578689575195
3: 11.860300064086914
1: 11.229816436767578
2: 11.472774505615234
3: 11.88313102722168
1: 11.201070785522461
2: 11.45378303527832
3: 11.879005432128906
1: 11.170109748840332
2: 11.427938461303711
3: 11.84475326538086
1: 11.172520637512207
2: 11.436473846435547
3: 11.840144157409668
1: 11.17291259765625
2: 11.441990852355957
3: 11.838129997253418
1: 11.167181015014648
2: 11.436162948608398
3: 11.8231582641601

2: 11.427602767944336
3: 11.268407821655273
1: 10.529494285583496
2: 11.426851272583008
3: 11.26365852355957
1: 10.52676010131836
2: 11.427492141723633
3: 11.261115074157715
1: 10.522124290466309
2: 11.4278564453125
3: 11.258978843688965
1: 10.520273208618164
2: 11.428378105163574
3: 11.258966445922852
1: 10.520572662353516
2: 11.431535720825195
3: 11.261317253112793
1: 10.525895118713379
2: 11.437776565551758
3: 11.27003002166748
1: 10.541969299316406
2: 11.452780723571777
3: 11.286605834960938
1: 10.573677062988281
2: 11.477828979492188
3: 11.314764976501465
1: 10.625011444091797
2: 11.518074035644531
3: 11.32695484161377
1: 10.649147033691406
2: 11.550352096557617
3: 11.305625915527344
1: 10.623873710632324
2: 11.536335945129395
3: 11.26637077331543
1: 10.543294906616211
2: 11.484737396240234
3: 11.230844497680664
1: 10.50993537902832
2: 11.426511764526367
3: 11.235492706298828
1: 10.509273529052734
2: 11.41183853149414
3: 11.238075256347656
1: 10.519952774047852
2: 11.4229526519775

3: 10.920360565185547
1: 10.17162799835205
2: 11.329629898071289
3: 10.91145133972168
1: 10.16290283203125
2: 11.331653594970703
3: 10.903696060180664
1: 10.162824630737305
2: 11.332806587219238
3: 10.89970588684082
1: 10.155391693115234
2: 11.331920623779297
3: 10.896918296813965
1: 10.155454635620117
2: 11.329690933227539
3: 10.893329620361328
1: 10.151058197021484
2: 11.327095031738281
3: 10.891443252563477
1: 10.15100383758545
2: 11.324311256408691
3: 10.888761520385742
1: 10.147829055786133
2: 11.32237720489502
3: 10.887982368469238
1: 10.147942543029785
2: 11.320629119873047
3: 10.887364387512207
1: 10.145002365112305
2: 11.320403099060059
3: 10.887012481689453
1: 10.14703369140625
2: 11.321109771728516
3: 10.888082504272461
1: 10.147451400756836
2: 11.323799133300781
3: 10.891090393066406
1: 10.158710479736328
2: 11.328590393066406
3: 10.899396896362305
1: 10.173839569091797
2: 11.340137481689453
3: 10.910801887512207
1: 10.215269088745117
2: 11.356733322143555
3: 10.92604446411

1: 9.924386978149414
2: 11.222343444824219
3: 10.64155101776123
1: 9.922683715820312
2: 11.221854209899902
3: 10.639104843139648
1: 9.920820236206055
2: 11.220182418823242
3: 10.637971878051758
1: 9.916656494140625
2: 11.215774536132812
3: 10.633161544799805
1: 9.914573669433594
2: 11.213322639465332
3: 10.63214111328125
1: 9.908851623535156
2: 11.212721824645996
3: 10.629945755004883
1: 9.907550811767578
2: 11.212762832641602
3: 10.62768268585205
1: 9.904114723205566
2: 11.211480140686035
3: 10.626256942749023
1: 9.903158187866211
2: 11.210006713867188
3: 10.623222351074219
1: 9.900616645812988
2: 11.208273887634277
3: 10.621541023254395
1: 9.899293899536133
2: 11.206470489501953
3: 10.619346618652344
1: 9.896759986877441
2: 11.205246925354004
3: 10.618074417114258
1: 9.894403457641602
2: 11.20456314086914
3: 10.616207122802734
1: 9.892578125
2: 11.203144073486328
3: 10.614461898803711
1: 9.8909912109375
2: 11.202265739440918
3: 10.612470626831055
1: 9.889360427856445
2: 11.2009906768

3: 10.512893676757812
1: 9.753288269042969
2: 11.169742584228516
3: 10.486844062805176
1: 9.740541458129883
2: 11.125041961669922
3: 10.453739166259766
1: 9.733064651489258
2: 11.092784881591797
3: 10.425652503967285
1: 9.735682487487793
2: 11.087289810180664
3: 10.421930313110352
1: 9.72441291809082
2: 11.097868919372559
3: 10.419604301452637
1: 9.723808288574219
2: 11.10270881652832
3: 10.415645599365234
1: 9.717077255249023
2: 11.102678298950195
3: 10.416040420532227
1: 9.719450950622559
2: 11.096467971801758
3: 10.415063858032227
1: 9.72041130065918
2: 11.091634750366211
3: 10.415542602539062
1: 9.724451065063477
2: 11.08590316772461
3: 10.412148475646973
1: 9.72561264038086
2: 11.082780838012695
3: 10.411178588867188
1: 9.727187156677246
2: 11.080739974975586
3: 10.409601211547852
1: 9.725704193115234
2: 11.079961776733398
3: 10.409099578857422
1: 9.727693557739258
2: 11.080026626586914
3: 10.40885066986084
1: 9.724498748779297
2: 11.08356761932373
3: 10.405679702758789
1: 9.72841

In [None]:
orig_points = dataset.all_points()

In [None]:
import plotly.express as px
import pandas as pd

In [None]:
df = pd.DataFrame(np.concatenate(orig_points, axis=0), columns=['x', 'y', 'z'])
df['type'] = sum([[f'{i}']*1000 for i in range(extra_modes+1)], [])
fig = px.scatter_3d(df, x='x', y='y', z='z', color='type', opacity=1)
fig.update_traces(marker={'size': 5})
fig.show()

In [None]:
pred_points = [model.run_evaluate(i, mapper(points)).detach().numpy() for i, points in enumerate(orig_points)]

In [None]:
df = pd.DataFrame(np.concatenate(pred_points, axis=0), columns=['x', 'y', 'z'])
df['type'] = sum([[f'{i}']*1000 for i in range(extra_modes+1)], [])
fig = px.scatter_3d(df, x='x', y='y', z='z', color='type')
fig.update_traces(marker={'size': 5})
fig.show()

In [None]:
o_ptrs0 = torch.tensor(pred_points[1])
o_ptrs1 = torch.tensor(pred_points[1])
print([metric for metric in model.metrics])
model.loss(model.compute_corr_metric(o_ptrs0, o_ptrs1,1))

In [None]:
print(torch.linalg.det(model.comps[0].weight.detach()))
print(model.comps[0].weight)

In [84]:
points = pred_points[0]
points = model.comps[2](torch.tensor(points)) #+ torch.tensor(points)
points = model.normalize(points)
new_points = [ptrs.copy() for ptrs in pred_points]
new_points[0] = points.detach()
df = pd.DataFrame(np.concatenate(new_points, axis=0), columns=['x', 'y', 'z'])
df['type'] = sum([[f'{i}']*1000 for i in range(extra_modes+1)], [])
fig = px.scatter_3d(df, x='x', y='y', z='z', color='type')
fig.update_traces(marker={'size': 5})
fig.show()

In [44]:
class CompressModel(nn.Module):
    def __init__(self, modes):
        super().__init__()
        self.comps = nn.ModuleList([nn.Linear(3, 3, bias=False) for i in range(modes-1)])
    
    def loss(self, corr_mat):
        #print(corr_mat)
        losses_row = cross_entropy(corr_mat, torch.arange(0, len(corr_mat)).long())
        losses_col = cross_entropy(torch.transpose(corr_mat,0,1), torch.arange(0, len(corr_mat)).long())
        return torch.mean(losses_row) + torch.mean(losses_col)
    
    def compute_corr_dot(self, a_points, b_points):
        return torch.sum(a_points[:,None,:] * b_points[None,:,:], dim=-1)
    
    def compute_corr_dist(self, a_points, b_points):
        return -torch.sqrt(torch.sum((a_points[:,None,:] - b_points[None,:,:])**2, dim=-1))
    
    def compute_corr_func(self, a_points, b_points, b):
        return -torch.sqrt(torch.sum((self.comps[b-1](a_points[:,None,:] - b_points[None,:,:]) + a_points[:,None,:] - b_points[None,:,:])**2, dim=-1))
    
    def forward(self, a, b, a_points, b_points):
        #a_points = self.linears[a](a_points)
        #b_points = self.linears[b](b_points)
        #a_points = self.linears2[a](relu(self.linears[a](a_points)))
        #b_points = self.linears2[b](relu(self.linears[b](b_points)))
        a_points = self.comps[b-1](a_points) #+ a_points
        corr_mat = self.compute_corr_dist(a_points, b_points)
        return corr_mat

In [45]:
model = CompressModel(extra_modes+1)
optim = SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.01)
def train_data(optim, model, ptrs1, ptrs2, a, b, iteration=10):
    avg_loss = 0
    for i in range(iteration):
        optim.zero_grad()
        corr_mat = model(a, b, ptrs1, ptrs2)
        loss = model.loss(corr_mat)
        loss.backward()
        optim.step()
        avg_loss += loss.item() / iteration
    return avg_loss

EPOCH = 1000
for epoch in range(EPOCH):
    for i in range(extra_modes):
        ptrs1 = dataset.points
        ptrs2 = dataset.plane_points[i]
        loss = train_data(optim, model, ptrs1, ptrs2, 0, i+1, iteration=1)
        print(f'{i+1}: {loss}')

1: 14.277408599853516
2: 13.450014114379883
3: 13.608136177062988
1: 14.27114200592041
2: 13.439750671386719
3: 13.597736358642578
1: 14.264841079711914
2: 13.42944622039795
3: 13.587303161621094
1: 14.258506774902344
2: 13.419105529785156
3: 13.576837539672852
1: 14.252137184143066
2: 13.40872859954834
3: 13.566341400146484
1: 14.245733261108398
2: 13.39831256866455
3: 13.55581283569336
1: 14.239294052124023
2: 13.387859344482422
3: 13.545255661010742
1: 14.232818603515625
2: 13.37736988067627
3: 13.534669876098633
1: 14.226308822631836
2: 13.366841316223145
3: 13.524054527282715
1: 14.219764709472656
2: 13.35627555847168
3: 13.51341438293457
1: 14.213183403015137
2: 13.345674514770508
3: 13.502748489379883
1: 14.20656681060791
2: 13.335037231445312
3: 13.492055892944336
1: 14.199913024902344
2: 13.324363708496094
3: 13.481340408325195
1: 14.193222045898438
2: 13.313652038574219
3: 13.470603942871094
1: 14.186495780944824
2: 13.30290412902832
3: 13.459844589233398
1: 14.17973327636718

1: 13.168724060058594
2: 12.069124221801758
3: 12.45132827758789
1: 13.15747356414795
2: 12.062813758850098
3: 12.444202423095703
1: 13.146183013916016
2: 12.056674003601074
3: 12.437091827392578
1: 13.13485050201416
2: 12.05069637298584
3: 12.42999267578125
1: 13.123476028442383
2: 12.04487419128418
3: 12.422906875610352
1: 13.112061500549316
2: 12.039199829101562
3: 12.415836334228516
1: 13.100604057312012
2: 12.033666610717773
3: 12.40877914428711
1: 13.089103698730469
2: 12.028268814086914
3: 12.401737213134766
1: 13.077561378479004
2: 12.022998809814453
3: 12.394708633422852
1: 13.065977096557617
2: 12.017854690551758
3: 12.387696266174316
1: 13.054347038269043
2: 12.012832641601562
3: 12.38070011138916
1: 13.042673110961914
2: 12.007927894592285
3: 12.373720169067383
1: 13.030956268310547
2: 12.00313949584961
3: 12.366756439208984
1: 13.019195556640625
2: 11.998467445373535
3: 12.359807968139648
1: 13.007390975952148
2: 11.993911743164062
3: 12.35287857055664
1: 12.99554061889648

1: 11.977367401123047
2: 11.936134338378906
3: 11.939186096191406
1: 11.975157737731934
2: 11.93610668182373
3: 11.934623718261719
1: 11.972991943359375
2: 11.936127662658691
3: 11.932741165161133
1: 11.97086238861084
2: 11.936113357543945
3: 11.938111305236816
1: 11.968767166137695
2: 11.936128616333008
3: 11.936796188354492
1: 11.966695785522461
2: 11.936107635498047
3: 11.93415641784668
1: 11.964643478393555
2: 11.936132431030273
3: 11.933431625366211
1: 11.962596893310547
2: 11.93610954284668
3: 11.937593460083008
1: 11.961164474487305
2: 11.936127662658691
3: 11.933902740478516
1: 11.967935562133789
2: 11.936110496520996
3: 11.935748100280762
1: 11.965763092041016
2: 11.93613052368164
3: 11.937255859375
1: 11.963611602783203
2: 11.936107635498047
3: 11.933828353881836
1: 11.961511611938477
2: 11.936131477355957
3: 11.93455696105957
1: 11.965156555175781
2: 11.93610954284668
3: 11.938249588012695
1: 11.967765808105469
2: 11.936128616333008
3: 11.933839797973633
1: 11.96456146240234

2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894

1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610

3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.96511

1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.9361

2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894

In [46]:
print(torch.linalg.det(model.comps[0].weight.detach()))
print(model.comps[0].weight)

tensor(0.0033)
Parameter containing:
tensor([[ 0.9099, -0.1592, -0.2409],
        [-0.1626,  0.7203, -0.4191],
        [-0.2383, -0.4208,  0.3773]], requires_grad=True)


In [50]:
points = orig_points[0]
points = model.comps[2](points) 
new_points = [ptrs.clone() for ptrs in orig_points]
new_points[0] = points.detach()
df = pd.DataFrame(np.concatenate(new_points, axis=0), columns=['x', 'y', 'z'])
df['type'] = sum([[f'{i}']*1000 for i in range(extra_modes+1)], [])
fig = px.scatter_3d(df, x='x', y='y', z='z', color='type')
fig.update_traces(marker={'size': 5})
fig.show()

In [48]:
total_loss = 0
ptrs = torch.tensor(pred_points[1])
model.loss(model.compute_corr_dist(ptrs, ptrs))

tensor(3.7238)

In [72]:
def comp_stable_rank(ptrs):
    corr = np.matmul(ptrs.T, ptrs)/len(ptrs)
    sing = np.linalg.svd(corr)[1][0]
    norm = np.linalg.norm(corr)
    return norm**2/sing**2
    
print(comp_stable_rank(pred_points[0]))
print(comp_stable_rank(pred_points[1]))

1.001808826943963
1.0028864922275524


In [88]:
ptrs1 = dataset.points
ptrs2 = dataset.plane_points[2]
loss_total1 = 0
loss_total2 = 0
for i,ptr2 in enumerate(ptrs2):
    loss_total1 += cross_entropy(torch.sum((ptrs1 - ptr2[None,:])**2, dim=-1), torch.tensor(i).long())
print(loss_total1*2 / len(ptrs2))
for i,ptr1 in enumerate(ptrs1):
    loss_total2 += cross_entropy(torch.sum((ptr1[None,:] - ptrs2)**2, dim=-1), torch.tensor(i).long())
print(loss_total2*2 / len(ptrs1))
print((loss_total1+loss_total2)/len(ptrs2))

tensor(18.0287)
tensor(18.1396)
tensor(18.0842)


In [94]:
print(ptrs1.shape, ptrs2.shape)
l1 = model.loss(model.compute_corr_func(ptrs1, ptrs2, 3))
l1

torch.Size([1000, 3]) torch.Size([1000, 3])


tensor(13.9890, grad_fn=<AddBackward0>)

In [273]:
print(dataset.disks[0])

tensor([-0.3848,  0.8510, -0.3575])
