In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import torch
import pandas as pd
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots
import numpy as np
import cv2
from tqdm.notebook import tqdm
from bertviz import head_view, model_view
from pycm import ConfusionMatrix
from train import *

In [4]:
cfg = load_cfg()
cfg.data_root = "/Users/fga/data/tmh"
cfg.batch_size = 1
cfg.mean_embedding = False
cfg.embed_dim = 256

cfg.num_heads = 2
cfg.depth = 12
cfg.depth_token_only = 1

dataset = TMH(cfg=cfg, debug=True)
dataset.prepare_data()
model = CaiT(cfg=cfg)
module = CaitModule(cfg=cfg, model=model)
ckpt = torch.load("/Users/fga/docs/uni/cls-protein-prediction/models/CAIT.ckpt", map_location='cpu')
module.load_state_dict(ckpt["state_dict"])

Processed dataset found, loading pickle.


<All keys matched successfully>

In [5]:
set_names = ["train", "val", "test"]
sets = [dataset.datasets[name] for name in set_names]
colors = px.colors.sequential.gray
fig = go.Figure(data=[
    go.Bar(name=set_names[i], x=[len(sets[i])], y=['Samples'], orientation='h', marker_color=colors[i*len(colors)//len(sets)])
    for i in range(len(sets))
])
# Change the bar mode
fig.update_layout(
    barmode='stack',
    height=300
)
fig.show()

In [6]:
def classes_df_for_dataset(set, name):
    classes = [row["class"] for _, row in tqdm(set)]
    df = pd.DataFrame(classes)
    df.columns=["class"]
    return df.assign(set=name)
all_classes_df = pd.concat([classes_df_for_dataset(set, name) for set, name in zip(sets, set_names)])

  0%|          | 0/4124 [00:00<?, ?it/s]

  0%|          | 0/515 [00:00<?, ?it/s]

  0%|          | 0/515 [00:00<?, ?it/s]

In [31]:
fig = px.histogram(all_classes_df, x="class", facet_col="set", facet_col_spacing=0.05)
fig.update_yaxes(matches=None)
fig.for_each_yaxis(lambda yaxis: yaxis.update(showticklabels=True))
fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
fig.update_layout(
    bargap=0.3,
    width=800,
    xaxis=dict(
        tickmode='array',
        tickvals=list(range(12)),
        ticktext=list(dataset.label_mappings.keys())*3
    )
)
fig.show()

In [57]:
outputs = torch.load("/Users/fga/Downloads/test_outputs-2.pkl", map_location="cpu")
preds = torch.cat([d[f"test_pred"].argmax(dim=1) for d in outputs]).int().numpy()
labels = torch.cat([d[f"test_label"] for d in outputs]).int().numpy()
cm = ConfusionMatrix(actual_vector=labels, predict_vector=preds)
cm_array = cm.to_array().astype(float)

In [60]:
fig = px.imshow(cm_array)
fig.update_layout(
    width=500,
    xaxis=dict(
        title="Predicted",
        tickmode='array',
        tickvals=list(range(4)),
        ticktext=list(dataset.label_mappings.keys())
    ),
    yaxis=dict(
        title="Actual",
        tickmode='array',
        tickvals=list(range(4)),
        ticktext=list(dataset.label_mappings.keys())
    )
)
fig.show()

In [98]:
metrics_df = pd.DataFrame({"acc":cm.ACC, "f1":cm.F1, "mcc":cm.MCC, "tpr":cm.TPR, "ppv":cm.PPV, "label":pd.Series(range(4))}).melt(id_vars=["label"], var_name="metric")
fig = px.bar(metrics_df, x="label", y="value", facet_col="metric", facet_col_spacing=0.04, orientation="v")
fig.update_yaxes(matches=None)
fig.for_each_yaxis(lambda yaxis: yaxis.update(showticklabels=True))
fig.for_each_annotation(lambda a: a.update(text=a.text.split("=")[-1]))
fig.update_traces(width=0.5)
fig.update_layout(
    bargap=0.1,
    xaxis = dict(
        tickmode='array',
        tickvals=list(range(4)),
        ticktext=list(dataset.label_mappings.keys())
    )
)
fig.show()

In [99]:
def ca_for(index, dataset):
    embed, row = dataset[index]
    embed = torch.Tensor(embed).unsqueeze(0)
    y_hat, attn_ws = model(embed)
    pred_correct = (y_hat.argmax() == row["class"]).all().item()

    sa_att = torch.stack(attn_ws[:-cfg.depth_token_only]).detach().squeeze(1).mean(1)
    ca_att = torch.stack(attn_ws[-cfg.depth_token_only:]).detach()[0,0,0,0,:]
    return ca_att, row

In [100]:
testset = sets[2]
aa_collector = {c:[] for c in range(4)}
anno_collector = {c:[] for c in range(4)}
for i in tqdm(range(100)):
    att, row = ca_for(i, testset)
    att = att[:-1]
    seq = row["seq"]
    anno = row["seq_anno"]
    if len(att) != len(seq) or len(att) != len(anno):
        continue
    aa_counter = {aa:[] for aa in set(seq)}
    anno_counter = {aa:[] for aa in set(anno)}
    for i, att_i in enumerate(att):
        aa_counter[seq[i]].append(att)
        anno_counter[anno[i]].append(att)

    aa_counter = {aa:torch.stack(att).mean().item() if len(att) > 0 else 0.0 for aa, att in aa_counter.items()}
    aa_counter = pd.DataFrame(aa_counter.items())
    aa_collector[row["class"]].append(aa_counter)

    anno_counter = {aa:torch.stack(att).mean().item() if len(att) > 0 else 0.0 for aa, att in anno_counter.items()}
    anno_counter = pd.DataFrame(anno_counter.items())
    anno_collector[row["class"]].append(anno_counter)

  0%|          | 0/100 [00:00<?, ?it/s]

In [119]:

px.bar(pd.concat(anno_collector[0]).groupby(0).mean(1))

In [163]:
px.line(ca_for(1000)[0])