In [1]:
from pathlib import Path
import hydra
from omegaconf import OmegaConf
import torch

from a3d.model.t5 import TinyT5

### Load model, prepare inputs

In [2]:
rundir = Path("checkpoints/tiny_t5/")
config = rundir / ".hydra/config.yaml"
checkpoint = rundir / "checkpoints/last.ckpt"

dict_conf = OmegaConf.load(config)
model: TinyT5 = hydra.utils.instantiate(dict_conf.model).eval()
weights = torch.load(checkpoint, map_location="cpu")["state_dict"]
model.load_state_dict(weights)
tokenizer = model._tokenizer

#### Generate with prompt

In [3]:
epitope = "LDSFKEELDKYFKNH"
ag_encoding = tokenizer([epitope], return_tensors="pt", padding="longest").input_ids
prompt = tokenizer(["=QVQLVQ"], return_tensors="pt").input_ids
if torch.cuda.is_available():
    ag_encoding = ag_encoding.to(device="cuda")
    prompt = prompt.to(device="cuda")
    model = model.to(device="cuda")
model.eval()
print(prompt)
with torch.no_grad():
    torch.manual_seed(42)
    outputs = model._t5.generate(
        inputs=ag_encoding,
        decoder_input_ids=prompt,
        return_dict_in_generate=True,
        use_cache = True,
        max_new_tokens = 250,
        min_length = 100,
        output_attentions = True,
        num_return_sequences = 1,
        do_sample = True,
        num_beams = 5,
        temperature = 1.2,
        top_p = 0.95,

    )
    decoded = tokenizer.batch_decode(outputs.sequences, skip_special_tokens=True)

for ab in decoded:
    print("\n".join(ab[1:].split("-")))

tensor([[ 5, 20, 24, 20, 16, 24, 20,  1]])
QVQLVQGAEVKKPGASVKVSCKASGYTFTSYIIHWVRQAPGQGLEWMGWINPNSGGTSYAQKFQGRVTMTRDTSTSTAYMELSSLRSDDTAVYYCAREGSPFYYGSGYYYYYYYWGQGTLVTVSS
EIVMTQSPATLSLSPGERATLSCRASQSVSSSYLAWYQQKPGQAPRLLIYGASSRATGIPDRFSGSGSGTDFTLTISRLEPEDFAVYYCQQYGSSPWTFGQGTKVEIK


#### BertViz

In [4]:
from bertviz import model_view

def visualize_attentions(epitope: str = "EMFAMKTKAALAIWC", heavy_chain: str = "QVQLVQSGAEVKKPGSSVKVSCKVS", light_chain: str = "", subsequence: str = ""):
    target = "=" + heavy_chain + ("" if not light_chain else "-" + light_chain)
    start_idx = 0
    end_idx = len(target)
    if subsequence:
        start_idx = target.find(subsequence)
        if start_idx == -1:
            raise ValueError(f"Subsequence {subsequence} not found in antibody")
        end_idx = start_idx + len(subsequence)
    ag_encoding = tokenizer([epitope], return_tensors="pt").input_ids
    prompt = tokenizer([target], return_tensors="pt").input_ids
    if torch.cuda.is_available():
        ag_encoding = ag_encoding.to(device="cuda")
        prompt = prompt.to(device="cuda")
    with torch.no_grad():
        outputs = model._t5.forward(input_ids=ag_encoding, decoder_input_ids=prompt, return_dict=True, output_attentions=True)
    decoder_text = tokenizer.convert_ids_to_tokens(prompt[0])
    encoder_text = tokenizer.convert_ids_to_tokens(ag_encoding[0])

    # slice
    model_view(
        cross_attention=[x[:, :, start_idx : end_idx] for x in outputs.cross_attentions],
        encoder_tokens= encoder_text,
        decoder_tokens = decoder_text[start_idx : end_idx],
        display_mode="dark"
    )

In [5]:
visualize_attentions(
    epitope=epitope,
    heavy_chain = "QVQLVQGAEVKKPGASVKVSCKASGYTFTSYIIHWVRQAPGQGLEWMGWINPNSGGTSYAQKFQGRVTMTRDTSTSTAYMELSSLRSDDTAVYYCAREGSPFYFDYWGQGTLVTVSS",
    light_chain="DIQMTQSPSSLSASVGDRVTITCRASQSISSYLNWYQQKPGKAPKLLIYYASNLQSGVPSRFSGSGSGTDFTLTISSLQPEDFATYYCQQSYSTPPTFGQGTKVEIK",
    subsequence="EGSPFYFDY",
)

<IPython.core.display.Javascript object>

In [6]:
visualize_attentions(
    epitope=epitope,
    heavy_chain = "QVQLVQGAEVKKPGASVKVSCKASGYTFTSYIIHWVRQAPGQGLEWMGWINPNSGGTSYAQKFQGRVTMTRDTSTSTAYMELSSLRSDDTAVYYCAREGSPFYFDYWGQGTLVTVSS",
    light_chain="DIQMTQSPSSLSASVGDRVTITCRASQSISSYLNWYQQKPGKAPKLLIYYASNLQSGVPSRFSGSGSGTDFTLTISSLQPEDFATYYCQQSYSTPPTFGQGTKVEIK",
    subsequence="WGQGTLVTVSS",
)

<IPython.core.display.Javascript object>