In [None]:
import pandas as pd
from rindti.data import PreTrainDataset
from collections import defaultdict
from torch_geometric.loader import DataLoader
from rindti.models import PfamModel
import random
from pytorch_lightning import Trainer
import torch
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
import plotly.express as px
import seaborn as sns
import numpy as np
from umap import UMAP
from rindti.losses import SoftNearestNeighborLoss, GeneralisedLiftedStructureLoss

In [None]:
# Load the pre-training fragments and index them by family label.
DATA_PATH = "/scratch/SCRATCH_NVME/ilya/pretrain_data/pfam_fragments_label_none.pkl"
ds = PreTrainDataset(DATA_PATH)

fams = defaultdict(list)  # family label -> list of positions in `ds`
for position, protein in enumerate(ds):
    fams[protein.fam].append(position)

In [None]:
def get_top_fam_ids(fams, k=5, sample=None):
    """Return dataset indices belonging to the ``k`` largest families.

    Args:
        fams: Mapping from family label to the list of dataset indices in that family.
        k: Number of families to keep, ranked by member count (largest first).
            Ties keep the mapping's iteration order.
        sample: If truthy, return a random subset of this many indices drawn
            *without* replacement (capped at the number of available indices),
            so no protein appears twice. If falsy (``None`` or ``0``), return
            every index of the selected families.

    Returns:
        List of dataset indices. Unsampled results are grouped family by family,
        largest family first.
    """
    # sorted() is stable (also with reverse=True), so equally sized families keep
    # their insertion order and the selection is deterministic.
    top_fams = sorted(fams, key=lambda fam: len(fams[fam]), reverse=True)[:k]
    ids = [idx for fam in top_fams for idx in fams[fam]]
    if sample:
        # Sample without replacement: duplicated indices would show up as
        # coincident points in the embedding plot and duplicate rows in exports.
        return random.sample(ids, k=min(sample, len(ids)))
    return ids

In [None]:
# Seed Python's RNG so the sampled subset (and therefore the plot and TSV
# produced below) is the same on every Restart & Run All.
SEED = 42
random.seed(SEED)

# 10,000 sampled indices from the 20 largest families; see get_top_fam_ids
# for how the sampling is performed.
subset = ds[get_top_fam_ids(fams, k=20, sample=10000)]

In [None]:
class TestModel(PfamModel):
    """PfamModel whose prediction step yields encoder outputs together with labels."""

    def predict_step(self, data, *args):
        """Encode one batch and return it alongside its family labels and ids.

        Returns:
            dict with ``embeds`` (encoder output, detached and moved to CPU),
            and ``fam`` / ``id`` copied unchanged from the batch.
        """
        embeddings = self.encoder(data).detach().cpu()
        return {"embeds": embeddings, "fam": data.fam, "id": data.id}

In [None]:
CHECKPOINT_PATH = "./tb_logs/pfam/version_15/checkpoints/epoch=158-step=158999.ckpt"

# nn.Module.eval() returns the module itself, so loading and switching to
# inference mode can be chained.
model = TestModel.load_from_checkpoint(CHECKPOINT_PATH).eval()

# predict_step calls .detach() on the encoder output, so it needs a single tensor;
# presumably return_nodes=False makes the encoder return only the pooled
# embedding — TODO confirm against the encoder implementation.
encoder = model.encoder
encoder.return_nodes = False

In [None]:
# NOTE(review): `gpus=` is deprecated in newer PyTorch Lightning releases in favour
# of `accelerator="gpu", devices=1`; check against the installed version.
trainer = Trainer(gpus=1)
# No shuffling, so repeated runs produce predictions in the same order.
dl = DataLoader(subset, batch_size=64, shuffle=False)
prediction = trainer.predict(model, dataloaders=dl)

In [None]:
# Flatten the per-batch prediction dicts: stack embeddings row-wise and collect
# ids / family labels in the same batch order (presumably one entry per row).
embeds = torch.cat([batch_out["embeds"] for batch_out in prediction])
batch_id = [prot_id for batch_out in prediction for prot_id in batch_out["id"]]
batch_fam = [fam for batch_out in prediction for fam in batch_out["fam"]]

In [None]:
# Project the embeddings to 2-D. t-SNE is stochastic, so fix random_state to make
# the layout (and the exported plot / TSV) reproducible. A small perplexity
# weights very local neighbourhoods.
tsne = TSNE(perplexity=5, random_state=42)
# `embeds` holds detached CPU tensors (see predict_step), so .numpy() is safe.
x = tsne.fit_transform(embeds.numpy())
x = pd.DataFrame(data=x, columns=["x", "y"])

In [None]:
# Attach labels positionally: row i of `x` comes from row i of `embeds`, which was
# built in the same batch order as batch_fam / batch_id.
assert len(x) == len(batch_fam) == len(batch_id), "projected rows and labels are misaligned"
x['fam'] = batch_fam
x['id'] = batch_id

In [None]:
# Interactive scatter of the 2-D projection: one colour per family, protein id on hover.
# Light24 provides 24 distinct colours, enough for the 20 families selected above.
fig = px.scatter(
    x,
    x="x",
    y="y",
    color="fam",
    hover_name="id",
    hover_data=["fam"],
    opacity=0.4,
    width=1000,
    height=800,
    color_discrete_sequence=px.colors.qualitative.Light24,
    # The figure is also saved as standalone HTML, so it must be self-describing.
    title="t-SNE of fragment embeddings, coloured by family",
    labels={"x": "t-SNE 1", "y": "t-SNE 2"},
)
fig.update_traces(marker=dict(size=8, line=dict(width=0, color='black')))
fig.write_html("test.html")
fig.show()

In [None]:
# Save the 2-D coordinates with family and id labels as tab-separated text
# (the DataFrame index is written as the first, unnamed column).
x.to_csv(path_or_buf="fragment_embed_top20.tsv", sep='\t')