# Input to Output Contribution

## Dependencies

In [1]:
# ! pip install amplify textualheatmap captum

In [None]:
import torch
from captum.attr import IntegratedGradients, configure_interpretable_embedding_layer
from textualheatmap import TextualHeatmap


from utils import load_pickle_dataset, load_from_hf, load_from_mila

## Arguments

In [3]:
# Model
# `source` selects the loader branch in the model-setup cell: "hf" or "mila".
source = "mila"
model_name = "AMPLIFY350M"  # NOTE(review): not read by any cell visible in this notebook
model_path = "../outputs/MILA_PLM_350M_UR100P/checkpoint/pytorch_model.pt"
tokenizer_path = None  # only passed to load_from_hf (source == "hf")
config_path = "../outputs/MILA_PLM_350M_UR100P/checkpoint/config.yaml"  # only passed to load_from_mila
device = "cuda"  # the model is moved here; also used as the autocast device type
compile = False  # NOTE(review): shadows the builtin `compile` and is not read by any visible cell
fp16 = True  # forwarded to load_from_hf and toggles float16 autocast in the attribution cell

# Dataset
data_name = "CASP14"  # NOTE(review): not read by any cell visible in this notebook
data_path = "../uniref/mila/casp14.pickle"
n_proteins = 5  # forwarded to load_pickle_dataset
max_length = 512 # AMPLIFY was trained with a context length of 512

# Integrated Gradient
num_steps = 32  # passed as `n_steps` to IntegratedGradients.attribute
batch_size = 32  # passed as `internal_batch_size` to IntegratedGradients.attribute

# Log
output_file = "../outputs/AMPLIFY_Attribution_CASP14.csv"  # NOTE(review): nothing visible writes to this path

## Attribution

In [None]:
# Get model and tokenizer, then wire up Integrated Gradients for the chosen source.
# Every branch defines the same names for the cells below: model, tokenizer, ig,
# interpretable_embedding and the special-token ids (bos_id, mask_id, eos_id).
if source == "hf":
    model, tokenizer = load_from_hf(model_path, tokenizer_path, fp16=fp16)
    # The forward function receives embeddings, not token ids, and must return logits.
    ig = IntegratedGradients(lambda src: model(inputs_embeds=src).logits)
    # Captum wraps the named embedding layer; the returned handle converts ids to embeddings.
    interpretable_embedding = configure_interpretable_embedding_layer(model, "esm.embeddings.word_embeddings")
    bos_id, mask_id, eos_id = tokenizer.cls_token_id, tokenizer.mask_token_id, tokenizer.eos_token_id
elif source == "mila":
    model, tokenizer = load_from_mila(model_path, config_path)
    # Here the embeddings are passed positionally, in place of the token ids.
    ig = IntegratedGradients(lambda src: model(src).logits)
    interpretable_embedding = configure_interpretable_embedding_layer(model, "encoder")
    bos_id, mask_id, eos_id = tokenizer.bos_token_id, tokenizer.mask_token_id, tokenizer.eos_token_id
else:
    # f-string so the offending value is actually interpolated into the message.
    raise ValueError(f"Only 'hf' and 'mila' sources are supported, not {source}.")
model.to(device)

In [5]:
# Load dataset
# The helper returns three collections that the attribution cell iterates in
# parallel (one label, one sequence and one distance matrix per protein).
# NOTE(review): the log below shows sequences longer than 510 being skipped —
# presumably max_length minus the two special tokens; confirm in utils.
labels, proteins, dist_matrices = load_pickle_dataset(data_path, n_proteins, max_length)

Skipped T1033_6vr4_A because sequence length is longer than 510
Skipped T1033_6vr4_B because sequence length is longer than 510
Skipped T1061_7zqb_b because sequence length is longer than 510
Skipped T1061_7zqb_d because sequence length is longer than 510
Skipped T1061_7zqb_c because sequence length is longer than 510
Skipped T1061_7zn2_d because sequence length is longer than 510
Skipped T1061_7zn2_c because sequence length is longer than 510
Skipped T1061_7zn2_b because sequence length is longer than 510
Skipped T1061_7zqp_d because sequence length is longer than 510
Skipped T1061_7zqp_c because sequence length is longer than 510
Skipped T1061_7zhj_b because sequence length is longer than 510
Skipped T1061_7zqp_b because sequence length is longer than 510
Skipped T1061_7zhj_d because sequence length is longer than 510
Skipped T1061_7zhj_c because sequence length is longer than 510


In [6]:
# TF32 and FP16
# The TF32 switches are global backend flags, so set them once before entering autocast.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True

# torch.autocast expects a bare device type ("cuda", "cpu"); going through
# torch.device also accepts indexed strings such as "cuda:0".
with torch.autocast(device_type=torch.device(device).type, dtype=torch.float16, enabled=fp16):

    # Iterate over the first n_proteins of the dataset
    # (dist_matrix is unpacked but not used in this cell).
    for label, protein, dist_matrix in zip(labels, proteins, dist_matrices):

        # Tokenize the protein and decode to get the special tokens
        x = torch.as_tensor(tokenizer.encode(protein)).to(torch.long)
        # From here on `protein` is the list of decoded tokens, aligned with x.
        protein = tokenizer.decode(x, skip_special_tokens=False).split()

        # Mask all positions: row i is the full sequence with position i replaced by the mask token
        masked_x = torch.where(torch.eye(x.size(0), dtype=torch.bool), mask_id, x.repeat(x.size(0), 1))
        masked_x = masked_x.to(device)

        # Get the reference sequence (<BOS> <MASK> <MASK> ... <MASK> <EOS>)
        reference = torch.as_tensor([bos_id] + [mask_id] * (x.size(0) - 2) + [eos_id], dtype=torch.long)
        reference = reference.unsqueeze(0).to(device)

        # Compute the embeddings
        x_embeddings = interpretable_embedding.indices_to_embeddings(masked_x)
        reference_embeddings = interpretable_embedding.indices_to_embeddings(reference)

        # Iterate over the protein (skipping the first and last token) and compute the integrated gradients
        attributions = []
        for k in range(1, x.size(0) - 1):
            # Attribute the logit of the true token at the masked position k.
            # Captum documents tuple targets as plain integers, hence int(x[k]).
            att = ig.attribute(
                inputs=x_embeddings[k].unsqueeze(0),
                baselines=reference_embeddings,
                target=(k, int(x[k])),
                n_steps=num_steps,
                internal_batch_size=batch_size,
            )
            # Collapse the last (embedding) dimension to one score per input position.
            att = att.detach().cpu().sum(dim=2).squeeze(0)
            # Drop the scores of the first and last position so `heat` lines up with the displayed tokens.
            attributions.append({"token": protein[k], "heat": [v.item() for v in att[1:-1]]})

        # One interactive heatmap per protein
        heatmap = TextualHeatmap(facet_titles=[label], width=900)
        heatmap.set_data([attributions])

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>