In [None]:
from collections import defaultdict
from functools import lru_cache
from pathlib import Path

import pandas as pd
import plotly.express as px
import pytorch_lightning as pl
import torch
import torch.nn as nn
import wandb
from captum.attr import IntegratedGradients
from pytorch_lightning import seed_everything

from src.esm_embedder import ESMEmbedder


# Global seed for every RNG Lightning seeds; kept as a named constant so it is easy to change.
SEED = 42
seed_everything(seed=SEED)

In [None]:
# Public W&B API client; used below to list the runs of the project (`api.runs`).
api = wandb.Api()
# A Run object is needed because the attribution loop calls `current_run.use_artifact(...)`.
current_run = wandb.init()

In [None]:
class Model(pl.LightningModule):
    """MLP regression head applied to a fixed-width ESM embedding.

    Architecture: ``Linear(input_dim -> hidden_dim) -> ReLU -> Dropout(0.5) ->
    Linear(hidden_dim -> 1)``.

    The output layer is an eager ``nn.Linear(hidden_dim, 1)``. Its state-dict
    keys (``model.3.weight`` / ``model.3.bias``) and shapes match what an
    ``nn.LazyLinear(1)`` materializes to after the ``hidden_dim``-wide layer, so
    checkpoints written with that variant load unchanged. Unlike the lazy variant,
    the module is fully initialized at construction time.

    Args:
        input_dim: width of the input embedding.
        hidden_dim: width of the hidden layer.
    """

    def __init__(self, input_dim, hidden_dim: int):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x):
        """Map ``(..., input_dim)`` embeddings to ``(...,)`` scalar predictions."""
        # squeeze(-1) drops only the size-1 output feature, whatever the number of
        # leading dims (identical to squeeze(1) for the usual (batch, input_dim) input).
        return self.model(x).squeeze(-1)

In [None]:
# Embedding width of each ESM-2 checkpoint used in this notebook.
input_dims = {
    "esm2_t33_650M_UR50D": 1280,
    "esm2_t30_150M_UR50D": 640,
    "esm2_t12_35M_UR50D": 480,
    "esm2_t6_8M_UR50D": 320,
}
# Reverse lookup from total layer count to checkpoint name. The count is the number
# after "t" in the second "_"-separated field, e.g. "esm2_t33_650M_UR50D" -> 33.
model_names = {int(checkpoint.split("_")[1][1:]): checkpoint for checkpoint in input_dims}

In [None]:
@lru_cache(maxsize=None)
def _baseline_layers(num_layers: int) -> torch.Tensor:
    """Return every layer's mean representation of an empty sequence, cached per model.

    ``ESMEmbedder(num_layers)`` is built only on the first call for a given
    ``num_layers``. The attribution loop calls ``get_baseline`` once per W&B run,
    so without the cache the ESM model would be re-instantiated on every run.
    """
    embedder = ESMEmbedder(num_layers)
    # NOTE(review): assumes run() returns one record per input whose "representations"
    # maps layer index -> (1, tokens, D) tensor; confirm against ESMEmbedder.
    representations = embedder.run([("prot1", "")])[0]["representations"]
    per_layer = [representations[layer].squeeze(0).mean(0) for layer in range(num_layers + 1)]
    return torch.stack(per_layer)


def get_baseline(num_layers: int, current_layer: int) -> torch.Tensor:
    """Integrated-Gradients baseline: the empty-sequence representation at one layer.

    Args:
        num_layers: total transformer layers of the ESM model (passed to ``ESMEmbedder``).
        current_layer: layer index in ``0..num_layers`` whose representation is returned.

    Returns:
        The ``current_layer`` representation of an empty protein, averaged over the
        token axis (1-D, assuming representations are ``(1, tokens, D)``).
    """
    # clone() so a caller mutating the result in place cannot corrupt the cached tensor.
    return _baseline_layers(num_layers)[current_layer].clone()


def get_data(
    dataset: str,
    num_layers: int,
    current_layer: int,
    num_samples: int = 4,
    root: Path = Path("/shared"),
) -> torch.Tensor:
    """Load mean ESM embeddings of the lowest- and highest-scoring test proteins.

    Embeddings are read from ``<root>/<dataset>/<model name>/test/*.pt``. Each file
    stem ends in ``|<value>`` (e.g. the fluorescence label), which ranks the proteins.

    Args:
        dataset: dataset directory name under ``root`` (e.g. ``"fluorescence"``).
        num_layers: total transformer layers of the ESM model; a key of ``model_names``.
        current_layer: which layer's mean representation to return.
        num_samples: number of proteins to return. ``num_samples // 2`` come from the
            low end of the ranking and the rest from the high end. If the directory
            holds no more than ``num_samples`` files, every file is returned once.
        root: base directory of the embedding store.

    Returns:
        2-D tensor with one ``current_layer`` mean representation per selected protein.

    Raises:
        ValueError: ``num_samples`` is smaller than 1.
        FileNotFoundError: the directory does not exist or contains no ``.pt`` files.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be >= 1, got {num_samples}")
    test_dir = Path(root) / dataset / model_names[num_layers] / "test"
    if not test_dir.is_dir():
        raise FileNotFoundError(f"embedding directory not found: {test_dir}")

    # Rank files by the numeric label encoded after the last "|" of the file stem.
    ranked = sorted(test_dir.glob("*.pt"), key=lambda f: float(f.stem.split("|")[-1]))
    if not ranked:
        raise FileNotFoundError(f"no .pt embedding files in {test_dir}")

    if num_samples >= len(ranked):
        # The low and high ends would overlap; take each file once instead of duplicating.
        selected = ranked
    else:
        n_low = num_samples // 2
        n_high = num_samples - n_low  # >= 1 since num_samples >= 1
        selected = ranked[:n_low] + ranked[-n_high:]

    # Index only the requested layer instead of stacking every layer first.
    # NOTE(review): torch.load unpickles the file; only point this at trusted stores.
    return torch.stack(
        [torch.load(f, map_location="cpu")["mean_representations"][current_layer] for f in selected]
    )

In [None]:
ds_name = "fluorescence"
runs = api.runs(f"smtb2023/{ds_name}")
# Use the GPU when present; otherwise fall back to CPU so the cell still runs.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

run_attributions = []


for model_run in runs:
    # Skip runs that did not finish (presumably no usable final checkpoint).
    if model_run.state != "finished":
        continue
    run_id = model_run.id
    model_name = model_run.config["model_name"]
    model_layer = model_run.config["layer_num"]
    # "esm2_t33_650M_UR50D" -> 33: total number of transformer layers of the ESM model.
    total_num_layers = int(model_name.split("_")[1][1:])
    data = get_data(ds_name, total_num_layers, model_layer)
    # One copy of the empty-sequence baseline per input row.
    baseline = get_baseline(total_num_layers, model_layer)
    baseline = torch.stack([baseline] * data.size(0))

    # Use ds_name instead of a hard-coded project, and load from the directory W&B
    # actually downloaded to rather than assuming "./artifacts/<name>".
    artifact = current_run.use_artifact(f"smtb2023/{ds_name}/model-{run_id}:v0", type="model")
    artifact_dir = Path(artifact.download())
    model = Model.load_from_checkpoint(
        str(artifact_dir / "model.ckpt"),
        map_location=device,
        input_dim=input_dims[model_name],
        hidden_dim=512,
    )
    model.eval()  # disables the Dropout layer during attribution

    ig = IntegratedGradients(model.to(device))
    attributions, delta = ig.attribute(
        data.to(device), baselines=baseline.to(device), return_convergence_delta=True
    )
    run_attributions.append(
        {
            "model_name": model_name,
            "layer": model_layer,
            "run_id": run_id,
            # Averaged over the selected samples: one attribution per embedding feature.
            "attributions": attributions.mean(0).cpu(),
            "delta": delta.cpu(),
            "data": data.cpu(),
        }
    )

In [None]:
# Shape of the Integrated-Gradients convergence delta stored for the first run.
run_attributions[0]["delta"].shape

In [None]:
# Group the per-run attribution vectors by the layer each run probed.
attributions_by_layer = defaultdict(list)
for run in run_attributions:
    attributions_by_layer[run["layer"]].append(run["attributions"])

if not attributions_by_layer:
    raise ValueError("run_attributions is empty: no finished runs were attributed")

# Vectors from different ESM models have different widths and cannot be stacked.
attribution_shapes = {tuple(vec.shape) for vecs in attributions_by_layer.values() for vec in vecs}
if len(attribution_shapes) > 1:
    raise ValueError(
        f"attribution vectors have mixed shapes {sorted(attribution_shapes)}; "
        "filter run_attributions to a single model_name before aggregating"
    )

# Sort so row i of `agg_runs` is layer `agg_layers[i]`; otherwise row order would
# follow the order in which the W&B API returned the runs.
agg_layers = sorted(attributions_by_layer)
# One row per layer: the mean over every run that probed that layer.
agg_runs = torch.stack([torch.stack(attributions_by_layer[layer]).mean(0) for layer in agg_layers])

In [None]:
def normalize(tensor, dim: int):
    """Min-max scale ``tensor`` to ``[0, 1]`` independently along ``dim``.

    Args:
        tensor: input tensor (integer inputs produce a float result).
        dim: dimension over which min and max are taken; each slice along it is
            rescaled on its own.

    Returns:
        Tensor of the same shape. Slices whose values are all equal (zero range)
        map to 0 instead of NaN.
    """
    # keepdim=True so the reduced statistics broadcast back along `dim` for any dim,
    # not only when `dim` happens to be the leading axis.
    low = tensor.amin(dim=dim, keepdim=True)
    high = tensor.amax(dim=dim, keepdim=True)
    span = high - low
    # Avoid 0/0 for constant slices; their numerator is 0, so they become 0.
    span = torch.where(span == 0, torch.ones_like(span), span)
    return (tensor - low) / span

In [None]:
# Mean attribution per row of `agg_runs` (averaged over embedding features), drawn as
# a single-column heatmap. px.imshow only accepts 2-D/3-D arrays, so keep the
# reduced dim; "auto" is the valid aspect value for a non-square layout.
px.imshow(
    agg_runs.mean(1, keepdim=True).detach().cpu().numpy(),
    aspect="auto",
    labels={"y": "row of agg_runs (layer)", "color": "mean attribution"},
)