In [None]:
from collections import defaultdict
from pathlib import Path

import pandas as pd
import plotly.express as px
import torch
from pytorch_lightning import Trainer, seed_everything
from sklearn.cluster import KMeans
from sklearn.manifold import TSNE
from torch_geometric.loader import DataLoader
from torchmetrics.functional import accuracy

from rindti.data import PreTrainDataset
from rindti.models import ProtClassModel

In [None]:
# Fix the global random seed up front so that the stochastic steps further
# down (permutation, clustering, t-SNE) are repeatable on a fresh kernel run.
seed_everything(42)

In [None]:
# Load the pre-training dataset from a pickle file.
# NOTE(review): this is an absolute, machine-specific path; the notebook will
# not run elsewhere without editing it.
ds = PreTrainDataset(
    "/scratch/SCRATCH_NVME/ilya/pretrain_data/pfam_fragments_distance_label_none.pkl"
)

In [None]:
# Group dataset positions by label: fams[label] -> list of indices into `ds`
# whose `.y` equals that label.
# NOTE(review): `.y` is presumably the family identifier -- confirm.
fams = defaultdict(list)
for ind, val in enumerate(ds):
    fams[val.y].append(ind)

In [None]:
# Number of dataset entries per label, as a Series indexed by label.
fams_vc = pd.Series(
    {label: len(members) for label, members in fams.items()}
)

In [None]:
# Collect the dataset indices of the 10 most populated labels (printed in
# ascending order of size, largest last).
subset_index = []
largest_fams = fams_vc.sort_values().tail(10).index
for fam in largest_fams:
    print(fam)
    subset_index.extend(fams[fam])

In [None]:
# Restrict the dataset to the entries selected in `subset_index`.
subset = ds[subset_index]
# Optional: uncomment to work on a random sample of at most 1000 entries instead.
# subset = subset[torch.randperm(len(subset))][:1000]

In [None]:
# Batched loader over the subset; no shuffling is requested, so batches follow
# the order of `subset`. This loader is read as a global by `plot`.
dl = DataLoader(subset, batch_size=128, num_workers=16)

In [None]:
class TestModel(ProtClassModel):
    """ProtClassModel with a helper that scores embeddings against labels."""

    def acc(self, embed, data_y):
        """Return the accuracy of ``self.loss.mlp`` applied to ``embed``.

        Args:
            embed: Embeddings fed to ``self.loss.mlp``.
            data_y: Raw labels; they are passed through
                ``self.loss.label_encoder.transform`` before comparison
                (presumably mapping label names to class ids -- confirm).

        Returns:
            Whatever ``torchmetrics.functional.accuracy`` returns for the
            CPU predictions and the encoded labels.
        """
        scores = self.loss.mlp(embed)
        targets = torch.tensor(self.loss.label_encoder.transform(data_y))
        return accuracy(scores.cpu(), targets)

In [None]:
def _mask_node_features(data):
    """Overwrite ``data.x`` in place with ones of the same shape and dtype."""
    data.x = torch.ones_like(data.x)
    return data


def _keep_index_adjacent_edges(data):
    """Drop, in place, every edge whose endpoint indices differ by more than 1.

    NOTE(review): this presumably leaves only sequence-neighbour (and self)
    edges, assuming node indices follow residue order -- confirm.
    """
    ei = data.edge_index
    data.edge_index = ei[:, (ei[0] - ei[1]).abs() <= 1]
    return data


class _AblationModel(ProtClassModel):
    """Shared prediction logic for the input-ablation variants below.

    Subclasses override :meth:`_ablate_input` to modify the incoming batch
    before it is encoded; encoding, classification and the returned dict are
    identical for every variant.
    """

    def _ablate_input(self, data):
        """Return ``data`` after this variant's modification (none by default)."""
        return data

    def predict_step(self, data, *args):
        """Encode one batch and report embeddings, labels, ids and accuracy.

        Args:
            data: Batch object exposing ``x``, ``edge_index``, ``y`` and ``id``.
                It may be modified in place by the variant's ablation.
            *args: Ignored; accepted for compatibility with the trainer.

        Returns:
            dict with keys ``embeds`` (detached CPU output of the encoder),
            ``fam`` (``data.y`` unchanged), ``id`` (``data.id``) and ``acc``
            (accuracy of ``self.loss.mlp`` on this batch against the labels
            encoded with ``self.loss.label_encoder``).
        """
        data = self._ablate_input(data)
        embed = self.encoder(data)
        pred = self.loss.mlp(embed)
        labels = torch.tensor(self.loss.label_encoder.transform(data.y))
        return dict(
            embeds=embed.detach().cpu(),
            fam=data.y,
            id=data.id,
            acc=accuracy(pred.cpu(), labels),
        )


class NormalModel(_AblationModel):
    """Baseline: the batch is encoded exactly as it comes from the loader."""


class MaskedModel(_AblationModel):
    """All node features are replaced by ones; edges are left untouched."""

    def _ablate_input(self, data):
        return _mask_node_features(data)


class ShuffledModel(_AblationModel):
    """Rows of ``data.x`` are randomly permuted; edges are left untouched."""

    def _ablate_input(self, data):
        # The permutation spans all rows of data.x at once.
        # NOTE(review): for a batch of several graphs this moves features
        # across graph boundaries, not only within one graph -- confirm that
        # this is the intended ablation.
        data.x = data.x[torch.randperm(data.x.size(0))]
        return data


class SequenceModel(_AblationModel):
    """Only edges between index-adjacent nodes are kept; features untouched."""

    def _ablate_input(self, data):
        return _keep_index_adjacent_edges(data)


class NothingModel(_AblationModel):
    """Both ablations: index-adjacent edges only and all-ones node features."""

    def _ablate_input(self, data):
        return _mask_node_features(_keep_index_adjacent_edges(data))

In [None]:
def plot(modelname: str):
    """Embed the evaluation subset with one model variant and visualise it.

    Loads the checkpoint into the model class selected by ``modelname``, runs
    prediction over the module-level dataloader ``dl``, clusters the resulting
    embeddings with KMeans (10 clusters) and prints the cluster purity under
    the label "Accuracy": the fraction of samples that fall into the most
    common cluster of their own family. A 2-D t-SNE projection coloured by
    family is then written to ``figs/<modelname>.png`` and displayed.

    The per-batch ``acc`` value returned by ``predict_step`` is not used here.

    Args:
        modelname: One of ``"masked"``, ``"shuffled"``, ``"normal"``,
            ``"sequence"`` or ``"nothing"``. Also used as the stem of the
            saved figure's file name.

    Raises:
        KeyError: If ``modelname`` is not one of the known variants.
    """
    Model = {
        "masked": MaskedModel,
        "shuffled": ShuffledModel,
        "normal": NormalModel,
        "sequence": SequenceModel,
        "nothing": NothingModel,
    }[modelname]
    model = Model.load_from_checkpoint(
        "./tb_logs/class/version_1/checkpoints/epoch=359-step=669599.ckpt"
    )
    model.eval()
    # NOTE(review): presumably switches the encoder to one embedding per graph
    # instead of one per node -- confirm against the encoder implementation.
    model.encoder.return_nodes = False

    trainer = Trainer(devices=1)
    prediction = trainer.predict(model, dataloaders=[dl])

    # Flatten the per-batch outputs into one embedding matrix plus aligned
    # lists of ids and family labels.
    embeds = torch.cat([batch["embeds"] for batch in prediction])
    batch_id = []
    batch_fam = []
    for batch in prediction:
        batch_id += batch["id"]
        batch_fam += batch["fam"]

    km = KMeans(n_clusters=10).fit(embeds)
    df = pd.DataFrame({"km": km.labels_, "fam": batch_fam})
    # Size of the most common cluster within each family. Aggregating to one
    # scalar per family keeps the result a plain Series (and `acc` a scalar)
    # even when every family's majority cluster carries the same label.
    majority_sizes = df.groupby("fam")["km"].agg(
        lambda clusters: clusters.value_counts().iloc[0]
    )
    acc = majority_sizes.sum() / df.shape[0]
    print("Accuracy: " + str(acc))

    tsne = TSNE(perplexity=30)
    coords = pd.DataFrame(data=tsne.fit_transform(embeds), columns=["x", "y"])
    coords["fam"] = batch_fam
    coords["id"] = batch_id

    # Scatter plot of the t-SNE projection for the selected variant.
    fig = px.scatter(
        coords,
        x="x",
        y="y",
        opacity=0.4,
        width=1000,
        height=1000,
        color="fam",
        hover_name="id",
        hover_data=["fam"],
    )
    fig.update_traces(marker=dict(size=8, line=dict(width=0, color="black")))
    fig.update(layout_showlegend=False)
    fig.update_layout(margin=dict(t=5, b=5, l=5, r=5))
    # Make sure the output directory exists so the export cannot fail after
    # all the expensive work above has already been done.
    Path("figs").mkdir(parents=True, exist_ok=True)
    fig.write_image(f"figs/{modelname}.png", scale=3.0)
    fig.show()

In [None]:
# Baseline: inputs are passed to the encoder unmodified.
plot("normal")

In [None]:
# Node feature rows randomly permuted; edges unchanged.
plot("shuffled")

In [None]:
# All node features replaced by ones; edges unchanged.
plot("masked")

In [None]:
# Only edges between nodes whose indices differ by at most 1 are kept.
plot("sequence")

In [None]:
# Both ablations combined: index-adjacent edges only and all-ones features.
plot("nothing")