In [None]:
import pandas as pd
from rindti.data import PreTrainDataset
from collections import defaultdict
from torch_geometric.loader import DataLoader
from rindti.models import PfamModel
import random
from pytorch_lightning import Trainer
import torch
from sklearn.manifold import TSNE
import plotly.express as px
import seaborn as sns
import numpy as np
from umap import UMAP

In [None]:
# Load the preprocessed pretraining dataset (presumably a pickled Pfam dataset, judging by the file name).
# NOTE(review): absolute, machine-specific path — consider a configurable data directory.
ds = PreTrainDataset("/scratch/SCRATCH_NVME/ilya/pretrain_data/pfam_label_none_proc.pkl")
# Index the dataset by family: fams[family_label] -> list of positions in `ds`.
fams = defaultdict(list)
for idx, prot in enumerate(ds):
    fams[prot.fam].append(idx)

In [None]:
def get_top_fam_ids(fams, k=5, sample=None):
    """Collect the dataset indices that belong to the ``k`` largest families.

    Args:
        fams: Mapping from family label to the list of dataset indices in
            that family.
        k: Number of families to keep, ranked by member count (largest first).
        sample: If truthy, return a random subsample of at most this many
            indices, drawn *without* replacement so no index is repeated.
            If ``None`` (or ``0``), every index of the selected families is
            returned.

    Returns:
        A list of dataset indices. Without ``sample`` the indices are grouped
        family by family, largest family first; with ``sample`` they are in
        random order. When ``sample`` exceeds the number of available indices
        the result is capped at that number instead of padding with
        duplicates.
    """
    # Explicit dtype keeps the Series well-typed even when `fams` is empty.
    fam_sizes = pd.Series({fam: len(ids) for fam, ids in fams.items()}, dtype="int64")
    top_fams = fam_sizes.sort_values(ascending=False).head(k).index
    res = [idx for fam in top_fams for idx in fams[fam]]
    if sample:
        # random.sample draws without replacement; random.choices would
        # return duplicate indices and fail on an empty pool.
        return random.sample(res, min(sample, len(res)))
    return res

In [None]:
# Seed the stdlib RNG used by get_top_fam_ids so the random subsample (and
# everything computed from it) is identical on every Restart & Run All.
random.seed(42)
# Take 20,000 randomly chosen proteins from the 10 largest families.
subset = ds[get_top_fam_ids(fams, k=10, sample=20000)]

In [None]:
class TestModel(PfamModel):
    """PfamModel variant whose prediction step exposes the encoder output."""

    def predict_step(self, data, *args):
        """Encode one batch and return its embeddings with family and id.

        Args:
            data: Batch passed to ``self.encoder``; must expose ``fam`` and
                ``id`` attributes.
            *args: Extra positional arguments supplied by the trainer; ignored.

        Returns:
            Dict with ``embeds`` (encoder output, detached and moved to CPU),
            ``fam`` and ``id`` (taken unchanged from ``data``).
        """
        encoded = self.encoder(data)
        cpu_embeds = encoded.detach().cpu()
        return {"embeds": cpu_embeds, "fam": data.fam, "id": data.id}

In [None]:
# NOTE(review): this import belongs in the import cell at the top of the notebook.
import yaml
# Read the model configuration and override the feature/edge settings.
# NOTE(review): yaml.load with FullLoader can construct arbitrary Python objects;
# yaml.safe_load is sufficient if the file holds only plain mappings/scalars — confirm.
with open("config/pfam.yaml", "r") as file:
    config = yaml.load(file, yaml.FullLoader)
# NOTE(review): `config` is not referenced by any other visible cell (the model is
# loaded from a checkpoint) — confirm whether this cell is still needed.
config["feat_dim"] = 20
config['edge_type'] = "none"
config['feat_type'] = "label"

In [None]:
# Restore the trained model from a saved checkpoint (version_83, epoch 266 per the path).
model = TestModel.load_from_checkpoint("tb_logs/pfam/version_83/checkpoints/epoch=266-step=266999.ckpt")
# Put the module in evaluation mode for inference.
model.eval()
encoder = model.encoder
# NOTE(review): presumably makes the encoder return one pooled embedding per
# graph instead of per-node features — confirm against the rindti encoder.
encoder.return_nodes = False

In [None]:
# Iterate the subset in fixed order (shuffle=False) in batches of 128.
dl = DataLoader(subset, batch_size=128, shuffle=False)
# NOTE(review): `gpus=` was deprecated in PyTorch Lightning 1.7 and removed in 2.0;
# newer versions expect Trainer(accelerator="gpu", devices=1) — confirm installed version.
trainer = Trainer(gpus=1)
# Run TestModel.predict_step over every batch; the next cell treats the result
# as an iterable of per-batch dicts.
prediction = trainer.predict(model, dl)

In [None]:
# Stack the per-batch embedding tensors into one tensor.
embeds = torch.cat([batch_out['embeds'] for batch_out in prediction])

# Flatten the per-batch id / family sequences into two flat lists whose order
# matches the concatenation order of `embeds`.
batch_id, batch_fam = [], []
for batch in prediction:
    batch_id.extend(batch['id'])
    batch_fam.extend(batch['fam'])

In [None]:
# Group positions in `batch_fam` by family label, then keep only the groups:
# fam_idx becomes a list of index lists, one per family, in first-seen order.
positions_by_fam = defaultdict(list)
for i, fam in enumerate(batch_fam):
    positions_by_fam[fam].append(i)
fam_idx = list(positions_by_fam.values())

In [None]:
# NOTE(review): leftover exploratory call — prints the full UMAP docstring and
# adds nothing to the analysis; consider removing this cell before sharing.
help(UMAP)

In [None]:
# Project the embeddings to 2-D with UMAP. A fixed random_state makes the
# layout reproducible between runs (umap-learn then runs single-threaded).
reducer = UMAP(n_neighbors=50, min_dist=0.05, random_state=42)
x = pd.DataFrame(data=reducer.fit_transform(embeds))

# Columns 0 and 1 of `x` hold the two UMAP coordinates; each point is coloured
# by the family label at the same position in `batch_fam`.
fig = px.scatter(x, x=0, y=1, opacity=0.5, width=1000, height=800, color=batch_fam)
fig.update_traces(marker=dict(size=8, line=dict(width=0.1, color='black')))
# Label the figure so it can be read on its own when the notebook is skimmed.
fig.update_layout(
    title="UMAP projection of encoder embeddings",
    xaxis_title="UMAP 1",
    yaxis_title="UMAP 2",
    legend_title_text="Family",
)

fig.show()