In [96]:
import torch
import pandas as pd
import matplotlib.pyplot as plt
import torch.nn.functional as F

In [11]:
# Load the pretrained ESM-1b checkpoint "esm1b_t33_650M_UR50S" and its token alphabet
# via torch.hub; later runs reuse the local hub cache instead of re-downloading.
model, alphabet = torch.hub.load("facebookresearch/esm:main", "esm1b_t33_650M_UR50S")
# Callable that maps [(label, sequence_str), ...] to a 3-tuple whose last item is the token tensor.
batch_converter = alphabet.get_batch_converter()
# eval() disables dropout so logits are deterministic for a given input; the trailing ";"
# stops Jupyter from rendering the (very long) module repr as the cell output.
model.eval();

Using cache found in /Users/ido/.cache/torch/hub/facebookresearch_esm_main


In [217]:
def mask_seq(seq, indices):
    """Return `seq` as a string with each position listed in `indices` replaced by '<mask>'.

    Args:
        seq: the input sequence; each element (character, for a str) is one residue.
        indices: positions to mask; each must be a valid list index into `seq`
            (an invalid one raises IndexError).

    Returns:
        str: the joined sequence, with masked positions spelled as the literal '<mask>'.
    """
    residues = list(seq)
    for position in indices:
        residues[position] = '<mask>'
    return "".join(residues)

In [151]:
def mask_seq_region(seq, region):
    """Mask the half-open span [region[0], region[1]) of `seq` with '<mask>' tokens.

    Args:
        seq: the input sequence.
        region: indexable whose first two items are the start (inclusive) and
            end (exclusive) positions to mask.

    Returns:
        str: the masked sequence, as produced by `mask_seq`.
    """
    first, stop = region[0], region[1]
    return mask_seq(seq, list(range(first, stop)))

In [97]:
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Filter a 1-D logits vector with top-k and/or nucleus (top-p) filtering.

    Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751).

    Args:
        logits: 1-D tensor of logits, shape (vocabulary size,). It is NOT modified;
            filtering is applied to a copy.
        top_k: if > 0, keep only the `top_k` largest logits (values tied with the
            k-th largest are kept too). Values above the vocabulary size are clamped.
        top_p: if > 0.0, keep the smallest prefix of tokens (sorted by probability,
            descending) whose cumulative probability exceeds `top_p`; the token that
            crosses the threshold, and always the most likely token, are kept.
        filter_value: value written into removed positions (default -inf, i.e.
            zero probability after softmax).

    Returns:
        A new tensor shaped like `logits` with removed positions set to `filter_value`.
    """
    assert logits.dim() == 1  # batch size 1 for now - could be updated for more but the code would be less clear
    # Work on a copy so the caller's tensor is never mutated by the masked assignments below.
    logits = logits.clone()
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a logit strictly below the k-th largest logit.
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the mask right by one so the first token that crosses the threshold is kept,
        # and never remove the single most likely token.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False

        # Map the removal mask from sorted order back to vocabulary positions.
        indices_to_remove = sorted_indices[sorted_indices_to_remove]
        logits[indices_to_remove] = filter_value
    return logits

In [198]:
# Target sequence and the half-open [start, end) window of 0-based residue positions to resample.
original_seq = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"
region = (3, 6)
mask_indices = list(range(*region))

In [206]:
# Sampling configuration read as globals by the sampling functions.
temperature = 1.0  # logits are divided by this before filtering; must be > 0
top_k = 0          # 0 disables top-k filtering
top_p = 0.9        # nucleus threshold; 0.0 disables top-p filtering
SEED = 0           # torch RNG seed so sampled candidates are reproducible on Restart & Run All

# Fail fast on settings that would make the sampling produce inf/nan probabilities.
assert temperature > 0, "temperature must be positive"
assert top_k >= 0 and 0.0 <= top_p <= 1.0, "top_k must be >= 0 and top_p in [0, 1]"
torch.manual_seed(SEED);

In [251]:
# TODO: Add MSA (multiple sequence alignment) support

In [247]:
def sample_candidate(seq, indices, allowed_aas="ACDEFGHIKLMNPQRSTVWY"):
    """Resample the residues at `indices` of `seq` from a single ESM forward pass.

    Every position in `indices` is masked at once and the model runs once. Each masked
    position is then sampled independently from its temperature-scaled, top-k/top-p
    filtered distribution. Reads the notebook globals `model`, `batch_converter`,
    `alphabet`, `temperature`, `top_k` and `top_p`.

    Args:
        seq: protein sequence as a string of one-letter residue codes.
        indices: 0-based, non-negative positions in `seq` to resample.
        allowed_aas: residues that may be sampled (default: the 20 canonical amino
            acids). Restricting sampling to single-character vocabulary tokens keeps
            the candidate the same length as `seq`; special tokens such as '<mask>'
            or '<eos>' would otherwise be spliced in and shift every later index.

    Returns:
        str: copy of `seq` with the positions in `indices` replaced by sampled residues.
    """
    print(f"original seq: {seq}")
    candidate = seq  # str is immutable, so no explicit copy is needed

    # Mask every target position so none of them sees its own true residue.
    masked = mask_seq(candidate, indices)
    _, _, batch_tokens = batch_converter([("protein1", masked)])

    # Only the logits are needed; skip representation/contact computation.
    with torch.no_grad():
        logits = model(batch_tokens)['logits']

    # The batch converter prepends a BOS/<cls> token when alphabet.prepend_bos is set,
    # so residue i of `seq` is at token position i + offset.
    offset = int(alphabet.prepend_bos)
    # KeyError here means `allowed_aas` contains a token that is not in the vocabulary.
    allowed_idx = torch.tensor([alphabet.tok_to_idx[aa] for aa in allowed_aas])

    for idx in indices:
        row = logits[0, idx + offset, :]
        # Tokens outside `allowed_aas` get -inf, i.e. zero probability after softmax.
        temp_logits = torch.full_like(row, -float('Inf'))
        temp_logits[allowed_idx] = row[allowed_idx] / temperature
        # nucleus sample wrt logits
        filtered_logits = top_k_top_p_filtering(temp_logits, top_k=top_k, top_p=top_p)
        probabilities = F.softmax(filtered_logits, dim=-1)
        next_token = torch.multinomial(probabilities, 1)
        next_aa = alphabet.all_toks[int(next_token)]
        print(f"predicted: {next_aa} real is: {seq[idx]}")

        # update candidate with new aa (always a single character, so indices stay aligned)
        candidate = candidate[:idx] + next_aa + candidate[idx + 1:]

    return candidate

In [253]:
def sample_candidate_conditional(seq, indices, allowed_aas="ACDEFGHIKLMNPQRSTVWY"):
    """Resample the residues at `indices` of `seq` one at a time, in the order given.

    At step i, positions indices[i:] are masked while indices[:i] already hold their
    freshly sampled residues, so each new residue is conditioned on earlier samples.
    Runs one model forward pass per index. Reads the notebook globals `model`,
    `batch_converter`, `alphabet`, `temperature`, `top_k` and `top_p`.

    Args:
        seq: protein sequence as a string of one-letter residue codes.
        indices: sliceable sequence of 0-based, non-negative positions in `seq` to resample.
        allowed_aas: residues that may be sampled (default: the 20 canonical amino
            acids). Restricting sampling to single-character vocabulary tokens keeps
            the candidate the same length as `seq`; special tokens such as '<mask>'
            or '<eos>' would otherwise be spliced in and shift every later index.

    Returns:
        str: copy of `seq` with the positions in `indices` replaced by sampled residues.
    """
    print(f"original seq: {seq}")
    candidate = seq  # str is immutable, so no explicit copy is needed

    # The batch converter prepends a BOS/<cls> token when alphabet.prepend_bos is set,
    # so residue i of `seq` is at token position i + offset.
    offset = int(alphabet.prepend_bos)
    # KeyError here means `allowed_aas` contains a token that is not in the vocabulary.
    allowed_idx = torch.tensor([alphabet.tok_to_idx[aa] for aa in allowed_aas])

    # conditionally generate masked aa's
    for i, idx in enumerate(indices):
        # Mask this position and every later target; earlier targets keep their samples.
        masked = mask_seq(candidate, indices[i:])
        print(f"candidate: {candidate}")

        _, _, batch_tokens = batch_converter([("protein1", masked)])

        # Only the logits are needed; skip representation/contact computation.
        with torch.no_grad():
            logits = model(batch_tokens)['logits']

        row = logits[0, idx + offset, :]
        # Tokens outside `allowed_aas` get -inf, i.e. zero probability after softmax.
        temp_logits = torch.full_like(row, -float('Inf'))
        temp_logits[allowed_idx] = row[allowed_idx] / temperature
        # nucleus sample wrt logits
        filtered_logits = top_k_top_p_filtering(temp_logits, top_k=top_k, top_p=top_p)
        probabilities = F.softmax(filtered_logits, dim=-1)
        next_token = torch.multinomial(probabilities, 1)
        next_aa = alphabet.all_toks[int(next_token)]
        print(f"predicted: {next_aa} real is: {seq[idx]}")

        # update candidate with new aa (always a single character, so indices stay aligned)
        candidate = candidate[:idx] + next_aa + candidate[idx + 1:]

    return candidate