In [None]:
import os
import random
from collections import defaultdict

import matplotlib.pyplot as plt
import pandas as pd
import plotly.express as px
import seaborn as sns
import torch
from pytorch_lightning import Trainer
from sklearn.manifold import TSNE
from torch_geometric.loader import DataLoader

from rindti.data import PreTrainDataset
from rindti.models import PfamModel

In [None]:
# Location of the pickled Pfam pre-training data. Set the PFAM_PRETRAIN_DATA
# environment variable to point elsewhere instead of editing this cell.
PRETRAIN_DATA_PATH = os.environ.get(
    "PFAM_PRETRAIN_DATA",
    "/scratch/SCRATCH_NVME/ilya/pretrain_data/pfam_no_orphans.pkl",
)
ds = PreTrainDataset(PRETRAIN_DATA_PATH)
# Map each family label (`prot.fam`) to the dataset indices of its members.
fams = defaultdict(list)
for idx, prot in enumerate(ds):
    fams[prot.fam].append(idx)

In [None]:
def get_top_fam_ids(fams, k=5, sample=None, seed=None):
    """Return dataset indices belonging to the ``k`` largest families.

    Args:
        fams: mapping of family label -> list of dataset indices.
        k: number of families to keep, ranked by member count. Ties keep the
            mapping's iteration order (stable sort).
        sample: if truthy, return a random subset of at most ``sample`` indices,
            drawn WITHOUT replacement so no protein appears twice.
        seed: optional seed for the sampling RNG; ``None`` uses the global
            ``random`` module state.

    Returns:
        list of indices. Without ``sample`` they are grouped family by family,
        largest family first.
    """
    # sorted(..., reverse=True) is stable, so equally sized families keep their order.
    top_fams = sorted(fams, key=lambda fam: len(fams[fam]), reverse=True)[:k]
    ids = [idx for fam in top_fams for idx in fams[fam]]
    if sample:
        rng = random if seed is None else random.Random(seed)
        # random.choices would sample with replacement and duplicate proteins.
        return rng.sample(ids, min(sample, len(ids)))
    return ids

In [None]:
# Keep only the proteins that belong to the 5 most populous families.
top_fam_ids = get_top_fam_ids(fams, k=5)
subset = ds[top_fam_ids]

In [None]:
class TestModel(PfamModel):
    """PfamModel subclass whose ``predict_step`` emits encoder outputs.

    Intended for ``Trainer.predict``: each batch is passed through
    ``self.encoder`` and returned together with the batch's ``fam`` and ``id``.
    """

    def predict_step(self, data, *args):
        """Encode one batch and bundle the result with its metadata.

        Args:
            data: a batch from the prediction DataLoader; must expose ``fam``
                and ``id`` attributes.
            *args: extra positional arguments passed by Lightning
                (presumably batch / dataloader index); ignored.

        Returns:
            dict with ``embeds`` (encoder output, detached and moved to CPU),
            and ``fam`` / ``id`` copied from the batch unchanged.
        """
        encoded = self.encoder(data)
        return {
            "embeds": encoded.detach().cpu(),
            "fam": data.fam,
            "id": data.id,
        }

In [None]:
# Restore trained weights into the prediction-oriented subclass.
CKPT_PATH = 'tb_logs/pfam/version_5/checkpoints/epoch=140-step=140999.ckpt'
model = TestModel.load_from_checkpoint(CKPT_PATH)
encoder = model.encoder
# NOTE(review): presumably makes the encoder return one pooled embedding per
# graph rather than node-level outputs - confirm against rindti's encoder.
encoder.return_nodes = False
model.eval()

In [None]:
# Run inference over the family subset. Order is preserved (no shuffling), which
# later cells rely on when slicing the concatenated embeddings.
dl = DataLoader(subset, batch_size=128, shuffle=False)
# NOTE(review): `gpus=` was deprecated in later PyTorch Lightning releases in
# favour of `accelerator="gpu", devices=1`; update when upgrading.
trainer = Trainer(gpus=1)
prediction = trainer.predict(model, dataloaders=dl)

In [None]:
# Stack the per-batch embeddings into one tensor, and flatten the per-batch
# id / family collections into parallel lists (both in batch order).
embeds = torch.cat([out["embeds"] for out in prediction])
batch_id, batch_fam = [], []
for batch in prediction:
    batch_id.extend(batch["id"])
    batch_fam.extend(batch["fam"])

In [None]:
# Project the embeddings to 2-D. t-SNE is stochastic; a fixed random_state
# keeps the layout identical across Restart & Run All.
tsne = TSNE(random_state=0)
tsne_coords = tsne.fit_transform(embeds.numpy())
x = pd.DataFrame(data=tsne_coords)
assert len(x) == len(batch_id) == len(batch_fam), "embeddings and metadata are misaligned"
x['id'] = batch_id
x['fam'] = batch_fam

In [None]:
# Interactive t-SNE scatter: colour = family, hover shows the protein id.
fig = px.scatter(x, 0, 1, opacity=0.7, color="fam", width=1000, height=800, hover_name="id")
fig.update_layout(
    title="t-SNE of encoder embeddings, coloured by family",
    xaxis_title="t-SNE 1",
    yaxis_title="t-SNE 2",
)
fig.show()

In [None]:
# Number of embeddings used in the pairwise-distance plots (an N x N matrix is
# drawn, so keep this modest). NOTE(review): rows follow DataLoader order, which
# is presumably grouped family by family (see get_top_fam_ids), so this prefix
# may cover only the first one or two families.
N_DIST = 2000
sub_embeds = embeds[:N_DIST]

In [None]:
# Pairwise Euclidean (p=2) distances between the selected embeddings: N x N.
dist = torch.cdist(x1=sub_embeds, x2=sub_embeds, p=2.0)

In [None]:
# Hierarchically clustered view of the distance matrix.
# NOTE(review): clustermap treats each row as a feature vector (default
# metric="euclidean"), so rows are clustered by their distance *profiles*, not
# by `dist` directly. Pass precomputed row/col linkages if direct clustering
# on `dist` is intended.
sns.clustermap(dist.numpy());

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Unclustered heatmap of the same matrix, in embedding (family-grouped) order.
# Uses its own figure name so the plotly `fig` above is not silently replaced.
heat_fig, heat_ax = plt.subplots()
sns.heatmap(dist.numpy(), ax=heat_ax)
heat_ax.set(title="Pairwise embedding distances", xlabel="embedding index", ylabel="embedding index");