In [None]:
import pandas as pd
from rindti.data import PreTrainDataset
from collections import defaultdict
from torch_geometric.loader import DataLoader
from rindti.models import PfamModel
import random
from pytorch_lightning import Trainer
import torch
from sklearn.manifold import TSNE
import plotly.express as px
import seaborn as sns
import numpy as np

In [None]:
# Instantiate the PreTrainDataset from its pickle file.
# NOTE(review): this is an absolute, machine-specific path — consider reading it
# from a configurable data directory so the notebook runs elsewhere.
ds = PreTrainDataset("/scratch/SCRATCH_NVME/ilya/pretrain_data/pfam_label_none_proc.pkl")
# `fams` maps each family label (`prot.fam`) to the positions in `ds` of the
# entries carrying that label.
fams = defaultdict(list)
for idx, prot in enumerate(ds):
    fams[prot.fam].append(idx)

In [None]:
def get_top_fam_ids(fams, k=5, sample=None):
    """Collect the dataset indices that belong to the ``k`` largest families.

    Args:
        fams: mapping from family label to the list of indices in that family.
        k: number of families to keep, ranked by member count (descending).
        sample: when truthy, return this many indices drawn at random *with
            replacement* (``random.choices``), so repeats are possible. When
            ``None`` or ``0`` every pooled index is returned.

    Returns:
        A list of indices: either the members of the selected families
        concatenated family by family, or a random draw from that pool.
    """
    sizes = pd.Series({name: len(members) for name, members in fams.items()})
    largest = sizes.sort_values(ascending=False).head(k).index
    pooled = [member for name in largest for member in fams[name]]
    if not sample:
        return pooled
    return random.choices(pooled, k=sample)

In [None]:
# Draw 1000 entries from the 5 largest families. get_top_fam_ids samples with
# random.choices (with replacement), so `subset` may contain repeated entries.
subset = ds[get_top_fam_ids(fams, k=5, sample=1000)]

In [None]:
class TestModel(PfamModel):
    """PfamModel variant whose prediction step exposes the encoder output."""

    def predict_step(self, data, *args):
        """Encode one batch and return its embeddings with family and id.

        The embeddings are detached from the autograd graph and moved to the
        CPU; ``fam`` and ``id`` are passed through from ``data`` unchanged.
        """
        embedding = self.encoder(data).detach().cpu()
        return {"embeds": embedding, "fam": data.fam, "id": data.id}

In [None]:
import yaml

# Read the Pfam configuration file.
# NOTE(review): yaml.FullLoader can construct more than plain data types; that
# is acceptable for a trusted project file, but prefer yaml.safe_load if this
# path could ever point at external input.
with open("config/pfam.yaml", "r") as file:
    config = yaml.load(file, Loader=yaml.FullLoader)

# Override the feature/edge settings for this experiment in one step.
# NOTE(review): none of the cells visible in this notebook read `config`
# afterwards — confirm it is still needed.
config.update({"feat_dim": 20, 'edge_type': "none", 'feat_type': "label"})

In [None]:
# Restore TestModel from a saved checkpoint and put it in evaluation mode.
model = TestModel.load_from_checkpoint("tb_logs/pfam/version_67/checkpoints/epoch=62-step=62999.ckpt")
model.eval()
# Keep a handle on the encoder and switch off its `return_nodes` flag.
# NOTE(review): presumably this makes the encoder return one embedding per
# graph rather than per node — confirm against the encoder implementation.
encoder = model.encoder
encoder.return_nodes = False

In [None]:
# Iterate the subset in its stored order (shuffle=False), 128 entries per batch.
dl = DataLoader(subset, batch_size=128, shuffle=False)
# NOTE(review): `gpus=0` presumably forces CPU execution; the `gpus` argument has
# been deprecated in newer PyTorch Lightning releases — confirm it is still
# accepted by the installed version.
trainer = Trainer(gpus=0)
# `prediction` is iterated batch by batch in the next cell, each item being the
# dict returned by TestModel.predict_step.
prediction = trainer.predict(model, dl)

In [None]:
# Stack the per-batch embedding tensors into a single tensor.
embeds = torch.cat([output['embeds'] for output in prediction])

# Flatten the per-batch ids and family labels into two parallel lists that
# follow the same batch order as `embeds`.
batch_id, batch_fam = [], []
for batch in prediction:
    batch_id.extend(batch['id'])
    batch_fam.extend(batch['fam'])

In [None]:
# Group the row positions of `batch_fam` by family label in a single pass.
# Families are listed in order of first appearance, which keeps `fam_idx`
# tied to the row order of the predictions (iterating a `set` of labels gives
# an order unrelated to it) and avoids rescanning every sample once per family.
fam_members = defaultdict(list)
for j, fam in enumerate(batch_fam):
    fam_members[fam].append(j)

# One list of row positions per family; this is the structure the loss
# functions consume.
fam_idx = list(fam_members.values())

In [None]:
def soft_nearest_neighbor_loss(embeds, tau=100, families=None):
    """Per-sample soft nearest neighbour loss over cosine distances.

    Args:
        embeds: 2-D tensor with one embedding per row.
        tau: temperature dividing the cosine distance before exponentiation.
            Defaults to 100, the value that used to be hard-coded.
        families: list of lists of row indices, one inner list per family.
            When ``None`` the notebook-level ``fam_idx`` is used.

    Returns:
        1-D tensor with one loss value per indexed row, concatenated family by
        family in the order given by ``families``.
    """
    if families is None:
        families = fam_idx  # notebook-level grouping
    norm_emb = torch.nn.functional.normalize(embeds)
    # Cosine distance: 0 for identical directions, up to 2 for opposite ones.
    sim = 1 - torch.matmul(norm_emb, norm_emb.t())
    return _get_loss(families, sim, tau)


def _get_fam_loss(expsim, idx):
    """Loss of every member of one family.

    Args:
        expsim: square matrix of exponentiated negative distances whose
            diagonal has been zeroed.
        idx: row indices of the family members.

    Returns:
        1-D tensor of length ``len(idx)``. A family with a single member has
        no positives, so its numerator is 0 and its loss is ``inf``.
    """
    pos_idxt = torch.as_tensor(idx, dtype=torch.long, device=expsim.device)
    # Similarities between members of the same family (numerator).
    pos = expsim[pos_idxt[:, None], pos_idxt]
    # Similarities between each member and every sample (denominator).
    batch = expsim[:, pos_idxt]
    return -torch.log(pos.sum(dim=0) / batch.sum(dim=0))


def _inverted_eye(bsize):
    """Return a ``bsize x bsize`` matrix of ones with a zero diagonal."""
    return 1.0 - torch.eye(bsize)


def _get_loss(fam_idx, sim, tau):
    """Exponentiate the distances, drop self-pairs and gather family losses.

    Self-pairs are removed by masking the diagonal rather than subtracting the
    identity: subtraction only yields zero when ``exp(-sim[i, i] / tau)`` is
    exactly 1, which floating point rounding or a zero-norm embedding breaks.
    """
    mask = _inverted_eye(sim.size(0)).to(device=sim.device, dtype=sim.dtype)
    expsim = torch.exp(-sim / tau) * mask
    return torch.cat([_get_fam_loss(expsim, idx) for idx in fam_idx])

In [None]:
# Index universe of the 50-point synthetic example. Kept so that any cell still
# reading `all_idx` keeps working; the loss below derives the universe from the
# size of its input instead.
all_idx = set(range(50))


def generalised_lifted_structure_loss(embeds, families=None, margin=0.2):
    """Per-sample generalised lifted structure loss on Euclidean distances.

    For every member of a family the loss is
    ``relu(logsumexp(d_pos) + logsumexp(margin - d_neg)) ** 2`` where ``d_pos``
    are its distances to the members of its own family (itself included) and
    ``d_neg`` its distances to all remaining rows of ``embeds``.

    Args:
        embeds: 2-D tensor with one embedding per row; any number of rows.
        families: list of lists of row indices, one inner list per family.
            When ``None`` the notebook-level ``fam_idx`` is used.
        margin: margin subtracted from the negative distances. Defaults to
            0.2, the value that used to be hard-coded.

    Returns:
        1-D tensor with one loss value per indexed row, concatenated family by
        family in the order given by ``families``.
    """
    if families is None:
        families = fam_idx  # notebook-level grouping
    n_samples = embeds.size(0)
    # The pairwise distances do not depend on the family, so compute them once.
    dist = torch.cdist(embeds, embeds)
    losses = []
    for idx in families:
        pos_idxt = torch.as_tensor(idx, dtype=torch.long, device=embeds.device)
        # Negatives are all rows of this input that are not in the family.
        is_neg = torch.ones(n_samples, dtype=torch.bool, device=embeds.device)
        is_neg[pos_idxt] = False
        pos = dist[pos_idxt[:, None], pos_idxt]
        neg = dist[is_neg][:, pos_idxt]
        pos_loss = torch.logsumexp(pos, dim=0)
        neg_loss = torch.logsumexp(margin - neg, dim=0)
        losses.append(torch.relu(pos_loss + neg_loss) ** 2)
    return torch.cat(losses)

In [None]:
# Synthetic sanity check: two groups of 25 five-dimensional points, shifted by
# +0.1 and -0.1 respectively, then L2-normalised row-wise.
# NOTE(review): no random seed is set, so the points, the t-SNE layout and the
# loss values change on every run.
a = torch.randn((25, 5)) + 0.1
b = torch.randn((25, 5)) - 0.1
embeds = torch.cat([a,b]).type(torch.float32)
embeds = torch.nn.functional.normalize(embeds)

# The loss reads the notebook-level `fam_idx`. Bind it to the synthetic grouping
# before the call; otherwise a fresh top-to-bottom run would still hold the
# grouping built from the real predictions, whose indices do not fit 50 rows.
fam_idx = [list(range(25)), list(range(25, 50))]

# 2-D projection of the embeddings, used only for plotting.
tsne = TSNE()
x = tsne.fit_transform(embeds)
x = pd.DataFrame(data=x)

# One loss value per point, in row order because the families are contiguous.
losses = generalised_lifted_structure_loss(embeds)

# Marker shape encodes the group, colour encodes the per-point loss.
fig = px.scatter(x, 0, 1, opacity=0.7, symbol=["a"] * 25 + ['b'] * 25, width=1000, height=800, color=losses.tolist(), symbol_sequence=['circle', 'cross'])
fig.update_traces(marker=dict(size=15, line=dict(width=0.5, color='black')))

fig.show()

In [None]:
# Family grouping for a 50-row example: rows 0-24 form one family and rows
# 25-49 the other.
# NOTE(review): this rebinds the notebook-level `fam_idx` that the loss functions
# read, replacing any grouping built from real predictions.
fam_idx = [list(range(25)), list(range(25, 50))]