In [71]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt

In [101]:
class SphereDatasetAnchor(Dataset):
    def __init__(self, extra_modes, num_points, transform_fn=None):
        self.n = num_points
        self.m = extra_modes + 1
        self.points = np.random.uniform(-1,1,size=(self.n, 3))
        mags = np.linalg.norm(self.points, axis=-1)
        self.points /= mags[...,None]
        self.disks = np.array([np.random.uniform(-1,1,size=(3)) for i in range(extra_modes)])
        self.disks = np.array([disk / np.linalg.norm(disk) for disk in self.disks])
        self.disk_points = torch.tensor(np.array([self.project_points(self.points, disk) for disk in self.disks])).float()
        self.points = torch.tensor(self.points).float()
        self.disks = torch.tensor(self.disks).float()
        self.transform_fn = transform_fn
        
    def project_points(self, points, disk):
        output_points = []
        for point in points:
            cosval = np.dot(point, disk)
            arrow = (point*1/np.dot(point, disk) - disk) * cosval # mind the sign of cosval
            arrow /= np.linalg.norm(arrow)
            output_points.append(arrow)
        return np.array(output_points)
    
    def __len__(self):
        return self.m-1
    
    def __getitem__(self, idx):
        e_points = self.disk_points[idx]
        if self.transform_fn is not None:
            return self.transform_fn.evaluate(self.points), self.transform_fn.evaluate(e_points)
        return self.points, e_points
    
    def all_points(self):
        return [self.points] + list(self.disk_points)

In [150]:
from torch.nn.functional import cross_entropy, relu
class Mapper(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 128)
        self.linear2 = nn.Linear(128, 128)
        self.linear3 = nn.Linear(128, 128)
        self.linear4 = nn.Linear(128, 128)
    
    def forward(self, X):
        return self.linear3(relu(self.linear2(relu(self.linear(X)))))

    def evaluate(self, X):
        self.eval()
        with torch.no_grad():
            return self.forward(X)

mapper = Mapper()

In [151]:
modes = 3
dataset = SphereDatasetAnchor(modes, 300, transform_fn=mapper)

In [188]:
from torch.nn.functional import cross_entropy, relu
class ContrastiveModel(nn.Module):
    def __init__(self, modes):
        super().__init__()
        self.linears = nn.ModuleList([nn.Linear(128, 128) for i in range(modes)])
        self.linears2 = nn.ModuleList([nn.Linear(128, 128) for i in range(modes)])
        self.linears3 = nn.ModuleList([nn.Linear(128, 128) for i in range(modes)])
        self.linears4 = nn.ModuleList([nn.Linear(128, 128) for i in range(modes)])
        self.linears5 = nn.ModuleList([nn.Linear(128, 3) for i in range(modes)])
        self.comps = nn.ParameterList([nn.Parameter(torch.zeros(3,3)) for i in range(modes-1)])
        for comp in self.comps:
            nn.init.xavier_uniform_(comp)
    
    def loss(self, corr_mat):
        losses_row = cross_entropy(corr_mat, torch.arange(0, len(corr_mat)).long())
        losses_col = cross_entropy(torch.transpose(corr_mat,0,1), torch.arange(0, len(corr_mat)).long())
        return torch.mean(losses_row) + torch.mean(losses_col)
    
    def compute_corr_dot(self, a_points, b_points):
        return torch.sum(a_points[:,None] * b_points[None,:], dim=-1)
    
    def compute_corr_func(self, a_points, b_points, b):
        return self.comps4[b-1](relu(self.comps3[b-1](relu(self.comps2[b-1](relu(self.comps[b-1](a_points[:,None] * b_points[None,:])))))))[...,0]
        #return self.compute_corr_dot(a_points_mapped, b_points)
    
    def forward(self, a, b, a_points, b_points):
        #a_points = self.linears2[a](relu(self.linears[a](a_points)))
        #b_points = self.linears2[b](relu(self.linears[b](b_points)))
        a_points = self.run_evaluate(a, b, a_points)
        b_points = self.run_evaluate(b, b, b_points)
        
        cov_b_points = torch.cov(b_points.T)
        a_points = a_points @ (cov_b_points + torch.eye(len(cov_b_points)))/10
        a_points = a_points / torch.linalg.norm(a_points, dim=-1, keepdim=True)
        corr_mat = self.compute_corr_dot(a_points, b_points)
        return corr_mat
    
    def run_evaluate(self, i, b, points):
        #points = self.linears2[i](relu(self.linears[i](points)))
        points = self.linears5[i](relu(self.linears4[i](relu(self.linears3[i](relu(self.linears2[i](relu(self.linears[i](points)))))))))
        return points / torch.linalg.norm(points, dim=-1, keepdim=True)
       

In [189]:
model = ContrastiveModel(modes+1)

In [190]:
from torch.optim import Adam
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)
optim = Adam(model.parameters(), lr=0.003)

In [191]:
def train_data(optim, model, ptrs1, ptrs2, a, b, iteration=10):
    avg_loss = 0
    for i in range(iteration):
        optim.zero_grad()
        corr_mat = model(a, b, ptrs1, ptrs2)
        loss = model.loss(corr_mat)
        loss.backward()
        optim.step()
        avg_loss += loss.item() / iteration
    return avg_loss

In [192]:
EPOCH = 2000
for epoch in range(EPOCH):
    for i in range(modes):
        ptrs1, ptrs2 = dataset[i]
        loss = train_data(optim, model, ptrs1, ptrs2, 0, i+1, iteration=1)
        print(f'{i+1}: {loss}')

1: 11.407588005065918
2: 11.407648086547852
3: 11.407527923583984
1: 11.406124114990234
2: 11.400733947753906
3: 11.399676322937012
1: 11.349061012268066
2: 11.226985931396484
3: 11.29177474975586
1: 11.08752727508545
2: 10.708202362060547
3: 11.052816390991211
1: 10.681585311889648
2: 10.641124725341797
3: 10.821950912475586
1: 10.51395034790039
2: 10.880460739135742
3: 11.27234172821045
1: 10.78799057006836
2: 10.583988189697266
3: 10.614498138427734
1: 10.593138694763184
2: 10.916351318359375
3: 10.654699325561523
1: 10.539761543273926
2: 10.505266189575195
3: 10.592628479003906
1: 10.611528396606445
2: 10.534530639648438
3: 10.557348251342773
1: 10.435039520263672
2: 10.392190933227539
3: 10.536432266235352
1: 10.516067504882812
2: 10.377138137817383
3: 10.53541374206543
1: 10.314519882202148
2: 10.383537292480469
3: 10.486799240112305
1: 10.264172554016113
2: 10.426309585571289
3: 10.449847221374512
1: 10.428665161132812
2: 10.367196083068848
3: 10.391636848449707
1: 10.2380771636

In [199]:
orig_points = dataset.all_points()

In [200]:
import plotly.express as px
import pandas as pd

In [201]:
df = pd.DataFrame(np.concatenate(orig_points, axis=0), columns=['x', 'y', 'z'])
df['type'] = sum([[f'{i}']*300 for i in range(modes+1)], [])
fig = px.scatter_3d(df, x='x', y='y', z='z', color='type')
fig.update_traces(marker={'size': 5})
fig.show()

In [202]:
pred_points = [model.run_evaluate(i, i, mapper(points)).detach().numpy() for i, points in enumerate(orig_points)]

In [203]:
df = pd.DataFrame(np.concatenate(pred_points, axis=0), columns=['x', 'y', 'z'])
df['type'] = sum([[f'{i}']*300 for i in range(modes+1)], [])
fig = px.scatter_3d(df, x='x', y='y', z='z', color='type')
fig.update_traces(marker={'size': 5})
fig.show()

In [204]:
model.loss(model.compute_corr_dot(orig_points[1], orig_points[1]))

tensor(9.8895)

In [205]:
a = (model.comps[0]+model.comps[0].T)/2
print(a)
torch.linalg.eig(a)

tensor([[-0.3131,  0.1132,  0.3825],
        [ 0.1132, -0.1749,  0.4432],
        [ 0.3825,  0.4432,  0.5491]], grad_fn=<DivBackward0>)


torch.return_types.linalg_eig(
eigenvalues=tensor([ 0.8889+0.j, -0.4790+0.j, -0.3488+0.j], grad_fn=<LinalgEigBackward0>),
eigenvectors=tensor([[ 0.3122+0.j,  0.8078+0.j,  0.4999+0.j],
        [ 0.3935+0.j,  0.3690+0.j, -0.8420+0.j],
        [ 0.8647+0.j, -0.4596+0.j,  0.2027+0.j]], grad_fn=<LinalgEigBackward0>))

In [209]:
b=1
points = torch.tensor(pred_points[0])
b_points = torch.tensor(pred_points[b])
cov_b_points = torch.cov(b_points.T)
print(cov_b_points)
points = points @ cov_b_points
#points = torch.tensordot(points, (model.comps[b-1] + model.comps[b-1].T)/2, dims=1)
points /= torch.linalg.norm(points, dim=-1, keepdim=True)
new_points = [ptrs.copy() for ptrs in pred_points]
new_points[0] = points.detach()
df = pd.DataFrame(np.concatenate(new_points, axis=0), columns=['x', 'y', 'z'])
df['type'] = sum([[f'{i}']*300 for i in range(modes+1)], [])
fig = px.scatter_3d(df, x='x', y='y', z='z', color='type')
fig.update_traces(marker={'size': 5})
fig.show()

tensor([[ 0.3256,  0.1009, -0.0043],
        [ 0.1009,  0.4323, -0.0123],
        [-0.0043, -0.0123,  0.2279]])


In [56]:
def comp_stable_rank(ptrs):
    corr = np.matmul(ptrs.T, ptrs)/len(ptrs)
    sing = np.linalg.svd(corr)[1][0]
    norm = np.linalg.norm(corr)
    return norm**2/sing**2
    
print(comp_stable_rank(pred_points[0]))
print(comp_stable_rank(pred_points[1]))

1.0441717074138692
1.594033906400954


In [57]:
ptrs1 = dataset.points
ptrs2 = dataset.disk_points[1]
loss_total = 0
for i,ptr2 in enumerate(ptrs2):
    loss_total += cross_entropy(torch.sum(ptrs1 * ptr2[None,:], dim=-1), torch.tensor(i).long())
for i,ptr1 in enumerate(ptrs1):
    loss_total += cross_entropy(torch.sum(ptr1[None,:] * ptrs2, dim=-1), torch.tensor(i).long())
print(loss_total/len(ptrs2))

tensor(10.1642)


In [58]:
l1 = cross_entropy(torch.sum(ptrs1[:,None]*ptrs2[None,:], dim=-1), torch.arange(len(ptrs1)).long())
l2 = cross_entropy(torch.sum(ptrs2[:,None]*ptrs1[None,:], dim=-1), torch.arange(len(ptrs1)).long())
l1+l2

tensor(10.1642)

In [62]:
model.loss(model.compute_corr_dot(ptrs2, ptrs2))

tensor(12.2896)