In [40]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm

In [50]:
class CubeDatasetAnchor(Dataset):
    def __init__(self, extra_modes, num_points, transform_fn=None):
        self.n = num_points
        self.m = extra_modes + 1
        self.points = np.random.uniform(-1,1,size=(self.n, 3))
        self.directions = np.random.uniform(-1,1,size=(extra_modes, 3))
        self.directions /= np.linalg.norm(self.directions, axis=-1)[:, None]
        self.plane_points = [self.project_points(self.points, self.directions[i]) for i in range(extra_modes)]
        self.points = torch.tensor(self.points).float()
        self.plane_points = [torch.tensor(points).float() for points in self.plane_points]
        self.all_points = [self.points] + self.plane_points
        self.transform_fn = transform_fn
    
    def project_points(self, points, direction):
        output_points = []
        for point in points:
            output_points.append(point - np.dot(point, direction) * direction)
        return np.array(output_points)
        
    def __len__(self):
        return 1
    
    def __getitem__(self, idx):
        if self.transform_fn is not None:
            return [self.transform_fn(points) for points in self.all_points]
        return self.all_points

In [51]:
from torch.nn.functional import cross_entropy, relu
class Mapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 128)
        self.linear2 = nn.Linear(128, 128)
        self.linear3 = nn.Linear(128, 128)
        self.linear4 = nn.Linear(128, 128)
    
    def forward(self, X):
        return self.linear3(relu(self.linear2(relu(self.linear(X)))))

    def evaluate(self, X):
        self.eval()
        with torch.no_grad():
            return self.forward(X)

mapper = Mapper()

In [52]:
extra_modes = 3
dataset = CubeDatasetAnchor(extra_modes, 1000, transform_fn=mapper)

In [165]:
from torch.nn.functional import cross_entropy, relu
class ContrastiveModel(nn.Module):
    def __init__(self, modes):
        super().__init__()
        self.linears = nn.ModuleList([nn.Linear(128, 128) for i in range(modes*2-2)])
        self.linears2 = nn.ModuleList([nn.Linear(128, 64) for i in range(modes*2-2)])
        self.linears3 = nn.ModuleList([nn.Linear(64, 64) for i in range(modes*2-2)])
        self.linears4 = nn.ModuleList([nn.Linear(64, 3) for i in range(modes*2-2)])

    def loss(self, corr_mat):
        #print(corr_mat)
        losses_row = cross_entropy(corr_mat, torch.arange(0, len(corr_mat)).long())
        losses_col = cross_entropy(torch.transpose(corr_mat,0,1), torch.arange(0, len(corr_mat)).long())
        return torch.mean(losses_row) + torch.mean(losses_col)
    
    def compute_corr_dot(self, a_points, b_points):
        return torch.sum(a_points[:,None,:] * b_points[None,:,:], dim=-1)
    
    def compute_corr_dist(self, a_points, b_points):
        return -torch.sum((a_points[:,None,:] - b_points[None,:,:])**2, dim=-1)
    
    def compute_corr_metric(self, a_points, b_points, b):
        diff = a_points[:,None,:] - b_points[None,:,:]
        metric = (self.metrics[b-1] + self.metrics[b-1].T)/2
        return -torch.sum(torch.tensordot(diff, metric, dims=([-1],[0])) * diff, dim=-1)
    
    def compute_corr_func(self, a_points, b_points, b):
        return -torch.sum((self.comps[b-1](a_points[:,None,:] - b_points[None,:,:]) + a_points[:,None,:] - b_points[None,:,:])**2, dim=-1)
    
    def normalize(self, points):
        return points / torch.sqrt(torch.mean(torch.sum(points**2, dim=-1), dim=0))
    
    def forward(self, all_points, binding_coeff=1.0):
        a_points = all_points[0]
        e_points = all_points[1:]
        extra_modes = len(e_points)
        ne_points = torch.zeros((extra_modes, len(a_points), 3))
        na_points = torch.zeros((extra_modes, len(a_points), 3))
        for i in range(extra_modes):
            ne_points[i] = self.run_evaluate(i, e_points[i])
            na_points[i] = self.run_evaluate(i+extra_modes, a_points)
        
        individual_losses = 0
        for i in range(extra_modes):
            individual_losses += self.loss(self.compute_corr_dist(na_points[i], ne_points[i]))

        a_points = torch.sum(na_points, dim=0) # perhaps we should normalize it?
        together_losses = 0
        for i in range(extra_modes):
            together_losses += self.loss(self.compute_corr_dist(a_points, ne_points[i]))
        
        individual_losses /= extra_modes
        together_losses /= extra_modes

        return individual_losses + together_losses * binding_coeff

    def pred_points(self, all_points):
        self.eval()
        with torch.no_grad():
            a_points = all_points[0]
            e_points = all_points[1:]
            extra_modes = len(e_points)
            ne_points = [self.run_evaluate(i, e_points[i]).detach().numpy() for i in range(extra_modes)]
            na_points = [self.run_evaluate(i, a_points).detach().numpy() for i in range(extra_modes, extra_modes*2)]
            return [sum(na_points)] + ne_points

    def run_evaluate(self, i, points):
        #return self.linears[i](points)
        #return self.linears2[i](relu(self.linears[i](points)))
        points = self.linears4[i](relu(self.linears3[i](relu(self.linears2[i](relu(self.linears[i](points)))))))
        points = self.normalize(points)
        return points

In [201]:
model = ContrastiveModel(extra_modes+1)

In [202]:
from torch.optim import Adam, SGD
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)

optim = Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=0.1)

In [203]:
def train_data(optim, model, all_points):
    optim.zero_grad()
    loss = model(all_points, binding_coeff=0.05)
    loss.backward()
    optim.step()
    return loss.item()

In [204]:
EPOCH = 1000
progress_bar = tqdm(range(EPOCH), desc='Epoch loss') 
for epoch in progress_bar:
    loss_val = train_data(optim, model, dataset[0]) # using the accessor to use transform_fn
    progress_bar.set_postfix({'loss': loss_val})

    

Epoch loss: 100%|██████████| 1000/1000 [01:22<00:00, 12.12it/s, loss=11.6]


In [205]:
orig_points = dataset.all_points

In [206]:
import plotly.express as px
import pandas as pd

In [207]:
df = pd.DataFrame(np.concatenate(orig_points, axis=0), columns=['x', 'y', 'z'])
df['type'] = sum([[f'{i}']*1000 for i in range(extra_modes+1)], [])
fig = px.scatter_3d(df, x='x', y='y', z='z', color='type', opacity=1)
fig.update_traces(marker={'size': 5})
fig.show()

In [208]:
pred_points = model.pred_points(dataset[0])

In [209]:
def plot_points(points):
    df = pd.DataFrame(np.concatenate(points, axis=0), columns=['x', 'y', 'z'])
    df['type'] = sum([[f'{i}']*1000 for i in range(len(points))], [])
    fig = px.scatter_3d(df, x='x', y='y', z='z', color='type', opacity=1)
    fig.update_traces(marker={'size': 5})
    fig.show()

In [210]:
plot_points(pred_points)

In [211]:
na_points = [model.run_evaluate(i, dataset[0][0]).detach().numpy() for i in range(extra_modes, extra_modes*2)]
p_points = na_points[0:1] + pred_points[1:]
plot_points(p_points)

In [200]:
print(torch.linalg.det(model.comps[0].weight.detach()))
print(model.comps[0].weight)

AttributeError: 'ContrastiveModel' object has no attribute 'comps'

In [84]:
points = pred_points[0]
points = model.comps[2](torch.tensor(points)) #+ torch.tensor(points)
points = model.normalize(points)
new_points = [ptrs.copy() for ptrs in pred_points]
new_points[0] = points.detach()
df = pd.DataFrame(np.concatenate(new_points, axis=0), columns=['x', 'y', 'z'])
df['type'] = sum([[f'{i}']*1000 for i in range(extra_modes+1)], [])
fig = px.scatter_3d(df, x='x', y='y', z='z', color='type')
fig.update_traces(marker={'size': 5})
fig.show()

In [44]:
class CompressModel(nn.Module):
    def __init__(self, modes):
        super().__init__()
        self.comps = nn.ModuleList([nn.Linear(3, 3, bias=False) for i in range(modes-1)])
    
    def loss(self, corr_mat):
        #print(corr_mat)
        losses_row = cross_entropy(corr_mat, torch.arange(0, len(corr_mat)).long())
        losses_col = cross_entropy(torch.transpose(corr_mat,0,1), torch.arange(0, len(corr_mat)).long())
        return torch.mean(losses_row) + torch.mean(losses_col)
    
    def compute_corr_dot(self, a_points, b_points):
        return torch.sum(a_points[:,None,:] * b_points[None,:,:], dim=-1)
    
    def compute_corr_dist(self, a_points, b_points):
        return -torch.sqrt(torch.sum((a_points[:,None,:] - b_points[None,:,:])**2, dim=-1))
    
    def compute_corr_func(self, a_points, b_points, b):
        return -torch.sqrt(torch.sum((self.comps[b-1](a_points[:,None,:] - b_points[None,:,:]) + a_points[:,None,:] - b_points[None,:,:])**2, dim=-1))
    
    def forward(self, a, b, a_points, b_points):
        #a_points = self.linears[a](a_points)
        #b_points = self.linears[b](b_points)
        #a_points = self.linears2[a](relu(self.linears[a](a_points)))
        #b_points = self.linears2[b](relu(self.linears[b](b_points)))
        a_points = self.comps[b-1](a_points) #+ a_points
        corr_mat = self.compute_corr_dist(a_points, b_points)
        return corr_mat

In [45]:
model = CompressModel(extra_modes+1)
optim = SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=0.01)
def train_data(optim, model, ptrs1, ptrs2, a, b, iteration=10):
    avg_loss = 0
    for i in range(iteration):
        optim.zero_grad()
        corr_mat = model(a, b, ptrs1, ptrs2)
        loss = model.loss(corr_mat)
        loss.backward()
        optim.step()
        avg_loss += loss.item() / iteration
    return avg_loss

EPOCH = 1000
for epoch in range(EPOCH):
    for i in range(extra_modes):
        ptrs1 = dataset.points
        ptrs2 = dataset.plane_points[i]
        loss = train_data(optim, model, ptrs1, ptrs2, 0, i+1, iteration=1)
        print(f'{i+1}: {loss}')

1: 14.277408599853516
2: 13.450014114379883
3: 13.608136177062988
1: 14.27114200592041
2: 13.439750671386719
3: 13.597736358642578
1: 14.264841079711914
2: 13.42944622039795
3: 13.587303161621094
1: 14.258506774902344
2: 13.419105529785156
3: 13.576837539672852
1: 14.252137184143066
2: 13.40872859954834
3: 13.566341400146484
1: 14.245733261108398
2: 13.39831256866455
3: 13.55581283569336
1: 14.239294052124023
2: 13.387859344482422
3: 13.545255661010742
1: 14.232818603515625
2: 13.37736988067627
3: 13.534669876098633
1: 14.226308822631836
2: 13.366841316223145
3: 13.524054527282715
1: 14.219764709472656
2: 13.35627555847168
3: 13.51341438293457
1: 14.213183403015137
2: 13.345674514770508
3: 13.502748489379883
1: 14.20656681060791
2: 13.335037231445312
3: 13.492055892944336
1: 14.199913024902344
2: 13.324363708496094
3: 13.481340408325195
1: 14.193222045898438
2: 13.313652038574219
3: 13.470603942871094
1: 14.186495780944824
2: 13.30290412902832
3: 13.459844589233398
1: 14.17973327636718

1: 13.168724060058594
2: 12.069124221801758
3: 12.45132827758789
1: 13.15747356414795
2: 12.062813758850098
3: 12.444202423095703
1: 13.146183013916016
2: 12.056674003601074
3: 12.437091827392578
1: 13.13485050201416
2: 12.05069637298584
3: 12.42999267578125
1: 13.123476028442383
2: 12.04487419128418
3: 12.422906875610352
1: 13.112061500549316
2: 12.039199829101562
3: 12.415836334228516
1: 13.100604057312012
2: 12.033666610717773
3: 12.40877914428711
1: 13.089103698730469
2: 12.028268814086914
3: 12.401737213134766
1: 13.077561378479004
2: 12.022998809814453
3: 12.394708633422852
1: 13.065977096557617
2: 12.017854690551758
3: 12.387696266174316
1: 13.054347038269043
2: 12.012832641601562
3: 12.38070011138916
1: 13.042673110961914
2: 12.007927894592285
3: 12.373720169067383
1: 13.030956268310547
2: 12.00313949584961
3: 12.366756439208984
1: 13.019195556640625
2: 11.998467445373535
3: 12.359807968139648
1: 13.007390975952148
2: 11.993911743164062
3: 12.35287857055664
1: 12.99554061889648

1: 11.977367401123047
2: 11.936134338378906
3: 11.939186096191406
1: 11.975157737731934
2: 11.93610668182373
3: 11.934623718261719
1: 11.972991943359375
2: 11.936127662658691
3: 11.932741165161133
1: 11.97086238861084
2: 11.936113357543945
3: 11.938111305236816
1: 11.968767166137695
2: 11.936128616333008
3: 11.936796188354492
1: 11.966695785522461
2: 11.936107635498047
3: 11.93415641784668
1: 11.964643478393555
2: 11.936132431030273
3: 11.933431625366211
1: 11.962596893310547
2: 11.93610954284668
3: 11.937593460083008
1: 11.961164474487305
2: 11.936127662658691
3: 11.933902740478516
1: 11.967935562133789
2: 11.936110496520996
3: 11.935748100280762
1: 11.965763092041016
2: 11.93613052368164
3: 11.937255859375
1: 11.963611602783203
2: 11.936107635498047
3: 11.933828353881836
1: 11.961511611938477
2: 11.936131477355957
3: 11.93455696105957
1: 11.965156555175781
2: 11.93610954284668
3: 11.938249588012695
1: 11.967765808105469
2: 11.936128616333008
3: 11.933839797973633
1: 11.96456146240234

2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894

1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610

3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.96511

1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.9361

2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894012451172
1: 11.965118408203125
2: 11.93613052368164
3: 11.93589973449707
1: 11.965171813964844
2: 11.93610954284668
3: 11.935894

In [46]:
print(torch.linalg.det(model.comps[0].weight.detach()))
print(model.comps[0].weight)

tensor(0.0033)
Parameter containing:
tensor([[ 0.9099, -0.1592, -0.2409],
        [-0.1626,  0.7203, -0.4191],
        [-0.2383, -0.4208,  0.3773]], requires_grad=True)


In [50]:
points = orig_points[0]
points = model.comps[2](points) 
new_points = [ptrs.clone() for ptrs in orig_points]
new_points[0] = points.detach()
df = pd.DataFrame(np.concatenate(new_points, axis=0), columns=['x', 'y', 'z'])
df['type'] = sum([[f'{i}']*1000 for i in range(extra_modes+1)], [])
fig = px.scatter_3d(df, x='x', y='y', z='z', color='type')
fig.update_traces(marker={'size': 5})
fig.show()

In [48]:
total_loss = 0
ptrs = torch.tensor(pred_points[1])
model.loss(model.compute_corr_dist(ptrs, ptrs))

tensor(3.7238)

In [72]:
def comp_stable_rank(ptrs):
    corr = np.matmul(ptrs.T, ptrs)/len(ptrs)
    sing = np.linalg.svd(corr)[1][0]
    norm = np.linalg.norm(corr)
    return norm**2/sing**2
    
print(comp_stable_rank(pred_points[0]))
print(comp_stable_rank(pred_points[1]))

1.001808826943963
1.0028864922275524


In [88]:
ptrs1 = dataset.points
ptrs2 = dataset.plane_points[2]
loss_total1 = 0
loss_total2 = 0
for i,ptr2 in enumerate(ptrs2):
    loss_total1 += cross_entropy(torch.sum((ptrs1 - ptr2[None,:])**2, dim=-1), torch.tensor(i).long())
print(loss_total1*2 / len(ptrs2))
for i,ptr1 in enumerate(ptrs1):
    loss_total2 += cross_entropy(torch.sum((ptr1[None,:] - ptrs2)**2, dim=-1), torch.tensor(i).long())
print(loss_total2*2 / len(ptrs1))
print((loss_total1+loss_total2)/len(ptrs2))

tensor(18.0287)
tensor(18.1396)
tensor(18.0842)


In [94]:
print(ptrs1.shape, ptrs2.shape)
l1 = model.loss(model.compute_corr_func(ptrs1, ptrs2, 3))
l1

torch.Size([1000, 3]) torch.Size([1000, 3])


tensor(13.9890, grad_fn=<AddBackward0>)

In [273]:
print(dataset.disks[0])

tensor([-0.3848,  0.8510, -0.3575])
