In [1]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

In [2]:
class CubeDatasetAnchor(Dataset):
    def __init__(self, extra_modes, num_points, transform_fn=None):
        self.n = num_points
        self.m = extra_modes + 1
        self.points = np.random.uniform(-1,1,size=(self.n, 3))
        self.directions = np.random.uniform(-1,1,size=(extra_modes, 3))
        self.directions /= np.linalg.norm(self.directions, axis=-1)[:, None]
        self.plane_points = [self.project_points(self.points, self.directions[i]) for i in range(extra_modes)]
        self.points = torch.tensor(self.points).float()
        self.plane_points = [torch.tensor(points).float() for points in self.plane_points]
        self.transform_fn = transform_fn
    
    def project_points(self, points, direction):
        output_points = []
        for point in points:
            output_points.append(point - np.dot(point, direction) * direction)
        return np.array(output_points)
        
    def __len__(self):
        return self.m-1
    
    def __getitem__(self, idx):
        e_points = self.plane_points[idx]
        if self.transform_fn is not None:
            return self.transform_fn.evaluate(self.points), self.transform_fn.evaluate(e_points)
        return self.points, e_points
    
    def all_points(self):
        return [self.points] + list(self.plane_points)[:self.m-1]

In [3]:
from torch.nn.functional import cross_entropy, relu
class Mapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 128)
        self.linear2 = nn.Linear(128, 128)
    
    def forward(self, X):
        return self.linear2(relu(self.linear(X)))

    def evaluate(self, X):
        self.eval()
        with torch.no_grad():
            return self.forward(X)

mapper = Mapper()

In [4]:
extra_modes = 3
dataset = CubeDatasetAnchor(extra_modes, 1000, transform_fn=mapper)

In [139]:
from torch.nn.functional import cross_entropy, relu
class ContrastiveModel(nn.Module):
    def __init__(self, modes):
        super().__init__()
        self.linears = nn.ModuleList([nn.Linear(128, 128) for i in range(modes)])
        self.linears2 = nn.ModuleList([nn.Linear(128, 64) for i in range(modes)])
        self.linears3 = nn.ModuleList([nn.Linear(64, 64) for i in range(modes)])
        self.linears4 = nn.ModuleList([nn.Linear(64, 3) for i in range(modes)])
        self.comps = nn.ModuleList([nn.Linear(3, 3) for i in range(modes-1)])
    
    def loss(self, corr_mat):
        #print(corr_mat)
        temp = 0.001
        losses_row = cross_entropy(corr_mat / temp, torch.arange(0, len(corr_mat)).long())
        losses_col = cross_entropy(torch.transpose(corr_mat,0,1) / temp, torch.arange(0, len(corr_mat)).long())
        return torch.mean(losses_row) + torch.mean(losses_col)
    
    def compute_corr_dot(self, a_points, b_points):
        return torch.sum(a_points[:,None,:] * b_points[None,:,:], dim=-1)
    
    def compute_corr_dist(self, a_points, b_points):
        return -torch.sum((a_points[:,None,:] - b_points[None,:,:])**2, dim=-1)
    
    def compute_corr_func(self, a_points, b_points, b):
        return -torch.sum((self.comps[b-1](a_points[:,None,:] - b_points[None,:,:]) + a_points[:,None,:] - b_points[None,:,:])**2, dim=-1)
    
    def normalize(self, points):
        return points / torch.sqrt(torch.mean(torch.sum(points**2, dim=-1), dim=0))
    
    def forward(self, a, b, a_points, b_points):
        #a_points = self.linears[a](a_points)
        #b_points = self.linears[b](b_points)
        #a_points = self.linears2[a](relu(self.linears[a](a_points)))
        #b_points = self.linears2[b](relu(self.linears[b](b_points)))
        a_points = self.linears4[a](relu(self.linears3[a](relu(self.linears2[a](relu(self.linears[a](a_points)))))))
        b_points = self.linears4[b](relu(self.linears3[b](relu(self.linears2[b](relu(self.linears[b](b_points)))))))
        #a_points = self.comps[b-1](a_points) #+ a_points
        a_points = self.normalize(a_points)
        b_points = self.normalize(b_points)
        
        corr_mat = self.compute_corr_dist(a_points, b_points)
        return corr_mat
    
    def run_evaluate(self, i, points):
        #return self.linears[i](points)
        #return self.linears2[i](relu(self.linears[i](points)))
        points = self.linears4[i](relu(self.linears3[i](relu(self.linears2[i](relu(self.linears[i](points)))))))
        points = self.normalize(points)
        return points

In [157]:
model = ContrastiveModel(extra_modes+1)

In [158]:
from torch.optim import Adam, SGD
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

print([name for name, param in model.named_parameters()])
#model.linears.require_grad = False
optim = Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.001)

['linears.0.weight', 'linears.0.bias', 'linears.1.weight', 'linears.1.bias', 'linears.2.weight', 'linears.2.bias', 'linears.3.weight', 'linears.3.bias', 'linears2.0.weight', 'linears2.0.bias', 'linears2.1.weight', 'linears2.1.bias', 'linears2.2.weight', 'linears2.2.bias', 'linears2.3.weight', 'linears2.3.bias', 'linears3.0.weight', 'linears3.0.bias', 'linears3.1.weight', 'linears3.1.bias', 'linears3.2.weight', 'linears3.2.bias', 'linears3.3.weight', 'linears3.3.bias', 'linears4.0.weight', 'linears4.0.bias', 'linears4.1.weight', 'linears4.1.bias', 'linears4.2.weight', 'linears4.2.bias', 'linears4.3.weight', 'linears4.3.bias', 'comps.0.weight', 'comps.0.bias', 'comps.1.weight', 'comps.1.bias', 'comps.2.weight', 'comps.2.bias']


In [159]:
def train_data(optim, model, ptrs1, ptrs2, a, b, iteration=10):
    avg_loss = 0
    for i in range(iteration):
        optim.zero_grad()
        corr_mat = model(a, b, ptrs1, ptrs2)
        loss = model.loss(corr_mat)
        loss.backward()
        optim.step()
        avg_loss += loss.item() / iteration
    return avg_loss

In [160]:
EPOCH = 1000
for epoch in range(EPOCH):
    for i in range(extra_modes):
        ptrs1, ptrs2 = dataset[i]
        loss = train_data(optim, model, ptrs1, ptrs2, 0, i+1, iteration=1)
        print(f'{i+1}: {loss}')

1: 630.6636962890625
2: 412.15606689453125
3: 572.4110717773438
1: 1096.369140625
2: 532.6571044921875
3: 397.05706787109375
1: 514.2940673828125
2: 301.8310241699219
3: 159.5704345703125
1: 453.7767028808594
2: 207.5424041748047
3: 101.52064514160156
1: 189.07687377929688
2: 155.71568298339844
3: 89.25154113769531
1: 211.9785614013672
2: 108.05376434326172
3: 77.354248046875
1: 179.51925659179688
2: 90.47801208496094
3: 50.48902893066406
1: 227.67941284179688
2: 71.27418518066406
3: 36.19190216064453
1: 126.552734375
2: 50.716495513916016
3: 35.336429595947266
1: 139.9535675048828
2: 41.65092849731445
3: 29.909799575805664
1: 127.98350524902344
2: 47.292442321777344
3: 27.455684661865234
1: 114.4742431640625
2: 48.00914764404297
3: 23.319477081298828
1: 89.2027587890625
2: 77.50298309326172
3: 21.952800750732422
1: 72.0858154296875
2: 38.35414123535156
3: 19.881671905517578
1: 59.06883239746094
2: 73.2842025756836
3: 19.098947525024414
1: 63.538597106933594
2: 57.911720275878906
3: 17

3: 10.676568031311035
1: 12.54478645324707
2: 11.034065246582031
3: 10.65600299835205
1: 11.787774085998535
2: 10.798065185546875
3: 10.631498336791992
1: 11.810724258422852
2: 10.992764472961426
3: 10.628213882446289
1: 11.67636489868164
2: 11.269723892211914
3: 10.589910507202148
1: 11.385099411010742
2: 11.272497177124023
3: 10.579986572265625
1: 11.349615097045898
2: 10.798456192016602
3: 10.557125091552734
1: 11.200475692749023
2: 10.839143753051758
3: 10.537088394165039
1: 11.057025909423828
2: 10.816146850585938
3: 10.51436996459961
1: 11.245673179626465
2: 10.629165649414062
3: 10.492350578308105
1: 11.480438232421875
2: 10.632210731506348
3: 10.472566604614258
1: 11.655155181884766
2: 10.826786994934082
3: 10.44986343383789
1: 11.309870719909668
2: 11.150659561157227
3: 10.429466247558594
1: 11.28469467163086
2: 10.972352981567383
3: 10.404772758483887
1: 10.822227478027344
2: 10.638628005981445
3: 10.385431289672852
1: 11.320100784301758
2: 10.807075500488281
3: 10.3636417388

1: 8.969810485839844
2: 9.105850219726562
3: 8.942852020263672
1: 8.956487655639648
2: 9.095120429992676
3: 8.934226989746094
1: 8.945266723632812
2: 9.079839706420898
3: 8.926576614379883
1: 8.933452606201172
2: 9.06615924835205
3: 8.919088363647461
1: 8.921414375305176
2: 9.056219100952148
3: 8.910835266113281
1: 8.909138679504395
2: 9.04737663269043
3: 8.901315689086914
1: 8.896722793579102
2: 9.040016174316406
3: 8.891502380371094
1: 8.884344100952148
2: 9.03225040435791
3: 8.882294654846191
1: 8.872538566589355
2: 9.024053573608398
3: 8.873758316040039
1: 8.860651969909668
2: 9.015776634216309
3: 8.865446090698242
1: 8.848684310913086
2: 9.008403778076172
3: 8.857061386108398
1: 8.836570739746094
2: 9.001083374023438
3: 8.84887981414795
1: 8.825045585632324
2: 8.993306159973145
3: 8.840686798095703
1: 8.813129425048828
2: 8.986698150634766
3: 8.832210540771484
1: 8.801689147949219
2: 8.97989273071289
3: 8.82363510131836
1: 8.790456771850586
2: 8.973584175109863
3: 8.81549739837646

3: 8.290075302124023
1: 7.935004234313965
2: 8.4478120803833
3: 8.287542343139648
1: 7.92596435546875
2: 8.446435928344727
3: 8.285162925720215
1: 7.917375564575195
2: 8.632171630859375
3: 8.28754997253418
1: 7.912778377532959
2: 8.544015884399414
3: 8.313726425170898
1: 7.932675361633301
2: 8.501341819763184
3: 8.402602195739746
1: 7.941024303436279
2: 8.43372917175293
3: 8.4674654006958
1: 7.925039291381836
2: 8.535152435302734
3: 8.508131980895996
1: 7.944756984710693
2: 8.681282043457031
3: 8.4970703125
1: 7.940046787261963
2: 8.82981014251709
3: 8.545721054077148
1: 7.943335056304932
2: 8.936132431030273
3: 8.65864372253418
1: 7.963836669921875
2: 9.03266716003418
3: 9.020163536071777
1: 8.111146926879883
2: 9.196479797363281
3: 9.552319526672363
1: 8.253117561340332
2: 9.516407012939453
3: 10.389493942260742
1: 8.213691711425781
2: 9.958890914916992
3: 10.537105560302734
1: 7.9678497314453125
2: 9.95388412475586
3: 10.560332298278809
1: 7.9626641273498535
2: 9.300939559936523
3: 

2: 8.278910636901855
3: 8.148865699768066
1: 7.594040870666504
2: 8.27920150756836
3: 8.144258499145508
1: 7.590404510498047
2: 8.280137062072754
3: 8.143556594848633
1: 7.585416793823242
2: 8.279435157775879
3: 8.141670227050781
1: 7.582821846008301
2: 8.279279708862305
3: 8.140525817871094
1: 7.579427719116211
2: 8.278063774108887
3: 8.14085578918457
1: 7.576716899871826
2: 8.277889251708984
3: 8.138050079345703
1: 7.574756622314453
2: 8.276908874511719
3: 8.13758659362793
1: 7.572294235229492
2: 8.276485443115234
3: 8.135157585144043
1: 7.569991111755371
2: 8.276726722717285
3: 8.134054183959961
1: 7.566959857940674
2: 8.27719783782959
3: 8.13182544708252
1: 7.564699172973633
2: 8.276949882507324
3: 8.131150245666504
1: 7.561854839324951
2: 8.276376724243164
3: 8.130363464355469
1: 7.5591864585876465
2: 8.276357650756836
3: 8.12930679321289
1: 7.556426048278809
2: 8.276169776916504
3: 8.128293991088867
1: 7.553848743438721
2: 8.276269912719727
3: 8.127132415771484
1: 7.5513744354248

2: 8.236543655395508
3: 8.077152252197266
1: 7.406059741973877
2: 8.236577987670898
3: 8.081259727478027
1: 7.403961181640625
2: 8.235928535461426
3: 8.091266632080078
1: 7.398800373077393
2: 8.242363929748535
3: 8.12611198425293
1: 7.396831512451172
2: 8.271501541137695
3: 8.155701637268066
1: 7.422224998474121
2: 8.3270902633667
3: 8.176384925842285
1: 7.560251712799072
2: 8.411243438720703
3: 8.117752075195312
1: 7.865346908569336
2: 8.412302017211914
3: 8.088462829589844
1: 8.253708839416504
2: 8.306257247924805
3: 8.598350524902344
1: 8.50002670288086
2: 8.285616874694824
3: 10.721328735351562
1: 7.794326305389404
2: 9.17220687866211
3: 15.237604141235352
1: 7.5603718757629395
2: 12.573959350585938
3: 14.303938865661621
1: 8.849836349487305
2: 13.853636741638184
3: 11.181450843811035
1: 9.696146011352539
2: 10.325828552246094
3: 8.376825332641602
1: 7.837735176086426
2: 8.209589004516602
3: 8.172552108764648
1: 8.967081069946289
2: 8.379138946533203
3: 8.534158706665039
1: 9.58827

2: 8.175089836120605
3: 8.066755294799805
1: 7.437629699707031
2: 8.170036315917969
3: 8.072303771972656
1: 7.427252769470215
2: 8.171928405761719
3: 8.060346603393555
1: 7.422570705413818
2: 8.1771240234375
3: 8.060075759887695
1: 7.418139457702637
2: 8.171716690063477
3: 8.064970970153809
1: 7.4146952629089355
2: 8.170336723327637
3: 8.06117057800293
1: 7.410425186157227
2: 8.170873641967773
3: 8.061481475830078
1: 7.4071125984191895
2: 8.172292709350586
3: 8.05955696105957
1: 7.402799606323242
2: 8.172468185424805
3: 8.060751914978027
1: 7.398726463317871
2: 8.170646667480469
3: 8.060277938842773
1: 7.3961286544799805
2: 8.171636581420898
3: 8.058220863342285
1: 7.393611907958984
2: 8.171985626220703
3: 8.057352066040039
1: 7.391540050506592
2: 8.172189712524414
3: 8.05544662475586
1: 7.389389991760254
2: 8.171791076660156
3: 8.05504035949707
1: 7.38682222366333
2: 8.17306137084961
3: 8.052350997924805
1: 7.384485244750977
2: 8.174317359924316
3: 8.051393508911133
1: 7.3818912506103

2: 8.171308517456055
3: 8.049009323120117
1: 7.270345687866211
2: 8.1871337890625
3: 8.05550765991211
1: 7.266366958618164
2: 8.194761276245117
3: 8.073419570922852
1: 7.271938323974609
2: 8.214860916137695
3: 8.055038452148438
1: 7.29920768737793
2: 8.20248031616211
3: 8.037260055541992
1: 7.381079196929932
2: 8.187138557434082
3: 8.025468826293945
1: 7.54777717590332
2: 8.166879653930664
3: 8.148006439208984
1: 7.723735809326172
2: 8.23861312866211
3: 8.675353050231934
1: 7.946261405944824
2: 8.694263458251953
3: 9.798641204833984
1: 7.659635543823242
2: 9.520071029663086
3: 11.63113784790039
1: 7.306172847747803
2: 10.887863159179688
3: 11.643516540527344
1: 7.397597312927246
2: 10.61791706085205
3: 10.679488182067871
1: 7.922835350036621
2: 9.401552200317383
3: 8.498798370361328
1: 7.934199333190918
2: 8.302581787109375
3: 8.048625946044922
1: 7.363613128662109
2: 8.240419387817383
3: 8.040346145629883
1: 7.9775071144104
2: 8.271581649780273
3: 8.162745475769043
1: 8.66851997375488

In [161]:
orig_points = dataset.all_points()

In [162]:
import plotly.express as px
import pandas as pd

In [163]:
df = pd.DataFrame(np.concatenate(orig_points, axis=0), columns=['x', 'y', 'z'])
df['type'] = sum([[f'{i}']*1000 for i in range(extra_modes+1)], [])
fig = px.scatter_3d(df, x='x', y='y', z='z', color='type', opacity=1)
fig.update_traces(marker={'size': 5})
fig.show()

In [164]:
pred_points = [model.run_evaluate(i, mapper(points)).detach().numpy() for i, points in enumerate(orig_points)]

In [165]:
df = pd.DataFrame(np.concatenate(pred_points, axis=0), columns=['x', 'y', 'z'])
df['type'] = sum([[f'{i}']*1000 for i in range(extra_modes+1)], [])
fig = px.scatter_3d(df, x='x', y='y', z='z', color='type')
fig.update_traces(marker={'size': 5})
fig.show()

In [166]:
o_ptrs0 = torch.tensor(orig_points[0])
o_ptrs1 = torch.tensor(orig_points[0])

model.loss(model.compute_corr_dist(o_ptrs0, o_ptrs1))


To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).


To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).



tensor(0.0505)

In [15]:
print(torch.linalg.det(model.comps[0].weight.detach()))
print(model.comps[0].weight)

tensor(0.0375)
Parameter containing:
tensor([[-0.1340, -0.0468, -0.2098],
        [ 0.1773, -0.4875,  0.0951],
        [-0.1969, -0.4759,  0.0661]], requires_grad=True)


In [16]:
points = pred_points[0]
points = model.comps[2](torch.tensor(points)) #+ torch.tensor(points)
new_points = [ptrs.copy() for ptrs in pred_points]
new_points[0] = points.detach()
df = pd.DataFrame(np.concatenate(new_points, axis=0), columns=['x', 'y', 'z'])
df['type'] = sum([[f'{i}']*1000 for i in range(extra_modes+1)], [])
fig = px.scatter_3d(df, x='x', y='y', z='z', color='type')
fig.update_traces(marker={'size': 5})
fig.show()

In [44]:
class CompressModel(nn.Module):
    def __init__(self, modes):
        super().__init__()
        self.comps = nn.ModuleList([nn.Linear(3, 3, bias=False) for i in range(modes-1)])
    
    def loss(self, corr_mat):
        #print(corr_mat)
        losses_row = cross_entropy(corr_mat, torch.arange(0, len(corr_mat)).long())
        losses_col = cross_entropy(torch.transpose(corr_mat,0,1), torch.arange(0, len(corr_mat)).long())
        return torch.mean(losses_row) + torch.mean(losses_col)
    
    def compute_corr_dot(self, a_points, b_points):
        return torch.sum(a_points[:,None,:] * b_points[None,:,:], dim=-1)
    
    def compute_corr_dist(self, a_points, b_points):
        return -torch.sqrt(torch.sum((a_points[:,None,:] - b_points[None,:,:])**2, dim=-1))
    
    def compute_corr_func(self, a_points, b_points, b):
        return -torch.sqrt(torch.sum((self.comps[b-1](a_points[:,None,:] - b_points[None,:,:]) + a_points[:,None,:] - b_points[None,:,:])**2, dim=-1))
    
    def forward(self, a, b, a_points, b_points):
        #a_points = self.linears[a](a_points)
        #b_points = self.linears[b](b_points)
        #a_points = self.linears2[a](relu(self.linears[a](a_points)))
        #b_points = self.linears2[b](relu(self.linears[b](b_points)))
        a_points = self.comps[b-1](a_points) #+ a_points
        corr_mat = self.compute_corr_dist(a_points, b_points)
        return corr_mat

In [45]:
model = CompressModel(extra_modes+1)
optim = SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.01)
def train_data(optim, model, ptrs1, ptrs2, a, b, iteration=10):
    avg_loss = 0
    for i in range(iteration):
        optim.zero_grad()
        corr_mat = model(a, b, ptrs1, ptrs2)
        loss = model.loss(corr_mat)
        loss.backward()
        optim.step()
        avg_loss += loss.item() / iteration
    return avg_loss

EPOCH = 1000
for epoch in range(EPOCH):
    for i in range(extra_modes):
        ptrs1 = dataset.points
        ptrs2 = dataset.plane_points[i]
        loss = train_data(optim, model, ptrs1, ptrs2, 0, i+1, iteration=1)
        print(f'{i+1}: {loss}')

1: 14.277408599853516
2: 13.450014114379883
3: 13.608136177062988
1: 14.27114200592041
2: 13.439750671386719
3: 13.597736358642578
1: 14.264841079711914
2: 13.42944622039795
3: 13.587303161621094
1: 14.258506774902344
2: 13.419105529785156
3: 13.576837539672852
1: 14.252137184143066
2: 13.40872859954834
3: 13.566341400146484
1: 14.245733261108398
2: 13.39831256866455
3: 13.55581283569336
1: 14.239294052124023
2: 13.387859344482422
3: 13.545255661010742
1: 14.232818603515625
2: 13.37736988067627
3: 13.534669876098633
1: 14.226308822631836
2: 13.366841316223145
3: 13.524054527282715
1: 14.219764709472656
2: 13.35627555847168
3: 13.51341438293457
1: 14.213183403015137
2: 13.345674514770508
3: 13.502748489379883
1: 14.20656681060791
2: 13.335037231445312
3: 13.492055892944336
1: 14.199913024902344
2: 13.324363708496094
3: 13.481340408325195
1: 14.193222045898438
2: 13.313652038574219
3: 13.470603942871094
1: 14.186495780944824
2: 13.30290412902832
3: 13.459844589233398
1: 14.17973327636718

1: 13.168724060058594
2: 12.069124221801758
3: 12.45132827758789
1: 13.15747356414795
2: 12.062813758850098
3: 12.444202423095703
1: 13.146183013916016
2: 12.056674003601074
3: 12.437091827392578
1: 13.13485050201416
2: 12.05069637298584
3: 12.42999267578125
1: 13.123476028442383
2: 12.04487419128418
3: 12.422906875610352
1: 13.112061500549316
2: 12.039199829101562
3: 12.415836334228516
1: 13.100604057312012
2: 12.033666610717773
3: 12.40877914428711
1: 13.089103698730469
2: 12.028268814086914
3: 12.401737213134766
1: 13.077561378479004
2: 12.022998809814453
3: 12.394708633422852
1: 13.065977096557617
2: 12.017854690551758
3: 12.387696266174316
1: 13.054347038269043
2: 12.012832641601562
3: 12.38070011138916
1: 13.042673110961914
2: 12.007927894592285
3: 12.373720169067383
1: 13.030956268310547
2: 12.00313949584961
3: 12.366756439208984
1: 13.019195556640625
2: 11.998467445373535
3: 12.359807968139648
1: 13.007390975952148
2: 11.993911743164062
3: 12.35287857055664
1: 12.99554061889648

1: 11.977367401123047
2: 11.936134338378906
3: 11.939186096191406
1: 11.975157737731934
2: 11.93610668182373
3: 11.934623718261719
1: 11.972991943359375
2: 11.936127662658691
3: 11.932741165161133
1: 11.97086238861084
2: 11.936113357543945
3: 11.938111305236816
1: 11.968767166137695
2: 11.936128616333008
3: 11.936796188354492
1: 11.966695785522461
2: 11.936107635498047
3: 11.93415641784668
1: 11.964643478393555
2: 11.936132431030273
3: 11.933431625366211
1: 11.962596893310547
2: 11.93610954284668
3: 11.937593460083008
1: 11.961164474487305
2: 11.936127662658691
3: 11.933902740478516
1: 11.967935562133789
2: 11.936110496520996
3: 11.935748100280762
1: 11.965763092041016
2: 11.93613052368164
3: 11.937255859375
1: 11.963611602783203
2: 11.936107635498047
3: 11.933828353881836
1: 11.961511611938477
2: 11.936131477355957
3: 11.93455696105957
1: 11.965156555175781
2: 11.93610954284668
3: 11.938249588012695
1: 11.967765808105469
2: 11.936128616333008
3: 11.933839797973633
1: 11.96456146240234

2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894

1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610

3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.96511

1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.9361

2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894

In [46]:
print(torch.linalg.det(model.comps[0].weight.detach()))
print(model.comps[0].weight)

tensor(0.0033)
Parameter containing:
tensor([[ 0.9099, -0.1592, -0.2409],
        [-0.1626,  0.7203, -0.4191],
        [-0.2383, -0.4208,  0.3773]], requires_grad=True)


In [50]:
points = orig_points[0]
points = model.comps[2](points) 
new_points = [ptrs.clone() for ptrs in orig_points]
new_points[0] = points.detach()
df = pd.DataFrame(np.concatenate(new_points, axis=0), columns=['x', 'y', 'z'])
df['type'] = sum([[f'{i}']*1000 for i in range(extra_modes+1)], [])
fig = px.scatter_3d(df, x='x', y='y', z='z', color='type')
fig.update_traces(marker={'size': 5})
fig.show()

In [48]:
total_loss = 0
ptrs = torch.tensor(pred_points[1])
model.loss(model.compute_corr_dist(ptrs, ptrs))

tensor(3.7238)

In [72]:
def comp_stable_rank(ptrs):
    corr = np.matmul(ptrs.T, ptrs)/len(ptrs)
    sing = np.linalg.svd(corr)[1][0]
    norm = np.linalg.norm(corr)
    return norm**2/sing**2
    
print(comp_stable_rank(pred_points[0]))
print(comp_stable_rank(pred_points[1]))

1.001808826943963
1.0028864922275524


In [88]:
ptrs1 = dataset.points
ptrs2 = dataset.plane_points[2]
loss_total1 = 0
loss_total2 = 0
for i,ptr2 in enumerate(ptrs2):
    loss_total1 += cross_entropy(torch.sum((ptrs1 - ptr2[None,:])**2, dim=-1), torch.tensor(i).long())
print(loss_total1*2 / len(ptrs2))
for i,ptr1 in enumerate(ptrs1):
    loss_total2 += cross_entropy(torch.sum((ptr1[None,:] - ptrs2)**2, dim=-1), torch.tensor(i).long())
print(loss_total2*2 / len(ptrs1))
print((loss_total1+loss_total2)/len(ptrs2))

tensor(18.0287)
tensor(18.1396)
tensor(18.0842)


In [94]:
print(ptrs1.shape, ptrs2.shape)
l1 = model.loss(model.compute_corr_func(ptrs1, ptrs2, 3))
l1

torch.Size([1000, 3]) torch.Size([1000, 3])


tensor(13.9890, grad_fn=<AddBackward0>)

In [273]:
print(dataset.disks[0])

tensor([-0.3848,  0.8510, -0.3575])
