In [1]:
import pytorch_lightning as pl
import torch
import torch.nn as nn
from captum.attr import IntegratedGradients, DeepLift, GradientShap, NoiseTunnel, FeatureAblation
import plotly.express as px
import wandb
from src.esm_embedder import ESMEmbedder
from pytorch_lightning import seed_everything
from pathlib import Path
import pandas as pd
import pickle

# Set the global random seed through Lightning's helper so repeated runs of the
# notebook start from the same RNG state (the cell output reports the seed, 42).
seed_everything(42)

Global seed set to 42


42

In [2]:
# W&B public API client; used further down to list the training runs of each project.
api = wandb.Api()
# Start a W&B run for this notebook session; model artifacts are requested through it
# (via `current_run.use_artifact`) in the attribution loop.
current_run = wandb.init()

[34m[1mwandb[0m: Currently logged in as: [33milsenatorov[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [3]:
class Model(pl.LightningModule):
    """Small MLP head: Linear -> ReLU -> Dropout(0.5) -> LazyLinear(1).

    The network is stored under ``self.model`` as an ``nn.Sequential``; the
    attribute name and layer order determine the state-dict keys, so they must
    stay as they are for existing checkpoints to load.
    """

    def __init__(self, input_dim, hidden_dim: int):
        """Build the head.

        Args:
            input_dim: Size of the last dimension of the input features.
            hidden_dim: Width of the hidden layer.
        """
        super().__init__()
        layers = [
            nn.Linear(input_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
            # Output layer infers its input width on first use / state-dict load.
            nn.LazyLinear(1),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run the head and drop the size-1 output dimension (dim 1)."""
        predictions = self.model(x)
        return predictions.squeeze(1)

In [4]:
# Input feature size handed to the regression head for each ESM-2 checkpoint name.
input_dims = {
    "esm2_t33_650M_UR50D": 1280,
    "esm2_t30_150M_UR50D": 640,
    "esm2_t12_35M_UR50D": 480,
    "esm2_t6_8M_UR50D": 320,
}
# Reverse lookup from the layer count embedded in the checkpoint name
# (the number after "t" in the second "_"-separated field) to the full name.
model_names = {int(name.split("_")[1][1:]): name for name in input_dims}

In [5]:
def get_baseline(num_layers: int, current_layer: int):
    """Return the baseline vector for one layer of an ESM embedder.

    A single entry named "prot1" with an empty string (presumably the sequence)
    is pushed through ``ESMEmbedder(num_layers)``. For each layer index in
    ``0..num_layers`` the returned representation is squeezed on dim 0 and
    averaged over dim 0; the per-layer results are stacked and the row for
    ``current_layer`` is returned.

    Args:
        num_layers: Layer count passed to ``ESMEmbedder``; also sets how many
            layer representations (``num_layers + 1``) are read.
        current_layer: Index of the layer whose averaged representation is returned.
    """
    empty_protein = [("prot1", "")]
    first_result = ESMEmbedder(num_layers).run(empty_protein)[0]
    layer_reprs = first_result["representations"]

    per_layer_means = []
    for layer_idx in range(num_layers + 1):
        per_layer_means.append(layer_reprs[layer_idx].squeeze(0).mean(0))

    return torch.stack(per_layer_means)[current_layer]


def get_data(dataset: str, num_layers: int, current_layer: int, num_samples: int = 1000):
    """Load test-set embeddings for the samples with the most extreme target values.

    Files are read from ``/shared/<dataset>/<model name>/test/*.pt``. The target
    value of each sample is parsed from the file stem (the text after the last
    ``|``). Files are sorted by that value and the ``num_samples // 2`` lowest plus
    the remaining highest ones are selected. If the directory holds no more than
    ``num_samples`` files, every file is used exactly once (no duplicates).

    Args:
        dataset: Dataset directory name under ``/shared``.
        num_layers: Layer count of the ESM model; key into ``model_names``.
        current_layer: Layer whose mean representation is returned. Negative
            values index from the end of ``range(num_layers + 1)``.
        num_samples: Maximum number of samples to return; must be positive.

    Returns:
        Tensor with one row per selected sample holding that sample's
        ``mean_representations[current_layer]``.

    Raises:
        ValueError: If ``num_samples`` is not positive or no ``*.pt`` files exist.
        IndexError: If ``current_layer`` is outside ``range(num_layers + 1)``.
    """
    if num_samples < 1:
        raise ValueError(f"num_samples must be positive, got {num_samples}")

    p = Path("/shared") / dataset / model_names[num_layers] / "test"
    assert p.exists(), f"embedding directory not found: {p}"

    # Map file path -> target value encoded at the end of the file name.
    targets = {str(f): float(f.stem.split("|")[-1]) for f in p.glob("*.pt")}
    if not targets:
        raise ValueError(f"no *.pt embedding files found in {p}")
    ordered_paths = pd.Series(targets).sort_values().index.to_list()

    if num_samples >= len(ordered_paths):
        # Fewer files than requested: taking a head and a tail slice would
        # overlap and duplicate samples, so use each file once.
        selected = ordered_paths
    else:
        n_low = num_samples // 2
        n_high = num_samples - n_low
        # Explicit start index avoids the `[-0:]` pitfall (which is the whole list).
        selected = ordered_paths[:n_low] + ordered_paths[len(ordered_paths) - n_high :]

    # Resolve the layer once; range indexing keeps negative-index support and
    # raises IndexError for out-of-range layers.
    layer_key = range(num_layers + 1)[current_layer]

    # NOTE: torch.load unpickles the file; only point this at trusted data.
    embeddings = [torch.load(path)["mean_representations"][layer_key] for path in selected]
    return torch.stack(embeddings)

In [None]:
# Use the GPU when one is available; otherwise fall back to CPU so the cell still runs.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# For every finished training run of each dataset project: rebuild the inputs for
# the run's ESM model/layer, load the trained head from its W&B artifact, compute
# DeepLift attributions against the baseline, and pickle the results per dataset.
for ds_name in ["stability", "fluorescence"]:
    runs = api.runs(f"smtb2023/{ds_name}")

    run_attributions = []

    for model_run in runs:
        if model_run.state != "finished":
            continue
        run_id = model_run.id
        model_name = model_run.config["model_name"]
        model_layer = model_run.config["layer_num"]
        # Layer count is encoded in the checkpoint name, e.g. "esm2_t33_..." -> 33.
        total_num_layers = int(model_name.split("_")[1][1:])

        data = get_data(ds_name, total_num_layers, model_layer)
        baseline = get_baseline(total_num_layers, model_layer)
        # One copy of the baseline per input row.
        baseline = torch.stack([baseline] * data.size(0))

        artifact = current_run.use_artifact(f"smtb2023/{ds_name}/model-{run_id}:v0", type="model")
        # Load from the directory W&B actually downloaded to, rather than assuming
        # a path relative to the current working directory.
        artifact_dir = artifact.download()
        ckpt_path = Path(artifact_dir) / "model.ckpt"
        model = Model.load_from_checkpoint(
            str(ckpt_path), map_location=device, input_dim=input_dims[model_name], hidden_dim=512
        )
        model.eval()
        model = model.to(device)
        data = data.to(device)
        baseline = baseline.to(device)

        # Other Captum methods (IntegratedGradients, NoiseTunnel, GradientShap,
        # FeatureAblation) can be swapped in here; only DeepLift is computed.
        dl = DeepLift(model)
        dl_attr = dl.attribute(data, baselines=baseline)

        run_attributions.append(
            {
                "model_name": model_name,
                "layer": model_layer,
                "run_id": run_id,
                "dl_attr": dl_attr.detach().cpu(),
                "data": data.cpu(),
            }
        )
    with open(f"{ds_name}.pkl", "wb") as f:
        pickle.dump(run_attributions, f)

100%|██████████| 1/1 [00:00<00:00,  7.59it/s]
[34m[1mwandb[0m:   1 of 1 files downloaded.  
               activations. The hooks and attributes will be removed
            after the attribution is finished
100%|██████████| 1/1 [00:00<00:00, 22.41it/s]
[34m[1mwandb[0m:   1 of 1 files downloaded.  
100%|██████████| 1/1 [00:00<00:00, 28.02it/s]
[34m[1mwandb[0m:   1 of 1 files downloaded.  
100%|██████████| 1/1 [00:00<00:00, 28.65it/s]
[34m[1mwandb[0m:   1 of 1 files downloaded.  
100%|██████████| 1/1 [00:00<00:00, 28.56it/s]
[34m[1mwandb[0m:   1 of 1 files downloaded.  
100%|██████████| 1/1 [00:00<00:00, 28.55it/s]
[34m[1mwandb[0m:   1 of 1 files downloaded.  
100%|██████████| 1/1 [00:00<00:00, 26.75it/s]
[34m[1mwandb[0m:   1 of 1 files downloaded.  
100%|██████████| 1/1 [00:00<00:00, 27.46it/s]
[34m[1mwandb[0m:   1 of 1 files downloaded.  
100%|██████████| 1/1 [00:00<00:00, 28.29it/s]
[34m[1mwandb[0m:   1 of 1 files downloaded.  
100%|██████████| 1/1 [00:00<00