This notebook computes ESM-1b embedding-space activations for every possible single-amino-acid substitution in TP53.

In [2]:
import torch
import esm
from tqdm import tqdm
import pandas as pd

ESM-1b has 33 transformer layers and a 1280-dimensional embedding space.

In [3]:
# Load ESM-1b (33 layers, 1280-dim embeddings) and the matching tokenizer.
model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()
# Inference only: eval() disables dropout so repeated runs give identical embeddings.
model.eval()
batch_converter = alphabet.get_batch_converter()

In [4]:
# The 20 standard amino acids as one-letter codes, in this fixed order:
# A Ala, R Arg, N Asn, D Asp, C Cys, Q Gln, E Glu, G Gly, H His, I Ile,
# L Leu, K Lys, M Met, F Phe, P Pro, S Ser, T Thr, W Trp, Y Tyr, V Val
amino_acids = list("ARNDCQEGHILKMFPSTWYV")

In [5]:
# Wild-type TP53 protein sequence, one-letter codes (393 residues, matching the
# 393-row per-residue embeddings below). Presumably human p53 (UniProt P04637) -- TODO confirm source/version.
tp53_sequence = "MEEPQSDPSVEPPLSQETFSDLWKLLPENNVLSPLPSQAMDDLMLSPDDIEQWFTEDPGPDEAPRMPEAAPPVAPAPAAPTPAAPAPAPSWPLSSSVPSQKTYQGSYGFRLGFLHSGTAKSVTCTYSPALNKMFCQLAKTCPVQLWVDSTPPPGTRVRAMAIYKQSQHMTEVVRRCPHHERCSDSDGLAPPQHLIRVEGNLRVEYLDDRNTFRHSVVVPYEPPEVGSDCTTIHYNYMCNSSCMGGMNRRPILTIITLEDSSGNLLGRNSFEVRVCACPGRDRRTEEENLRKKGEPHHELPPGSTKRALPNNTSSSPQPKKKPLDGEYFTLQIRGRERFEMFRELNEALELKDAQAGKEPGGSRAHSSHLKSKKGQSTSRHKKLMFKTEGPDSD"

In [6]:
def mutate_sequence(sequence, position, new_residue):
    """Return ``sequence`` with the residue at a 1-based ``position`` replaced.

    Args:
        sequence (str): Original amino acid sequence.
        position (int): 1-based index of the residue to replace.
        new_residue (str): Replacement amino acid (single letter).

    Returns:
        str: The mutated sequence, same length as ``sequence``.

    Raises:
        ValueError: If ``position`` is outside ``1..len(sequence)`` (checked
            first) or ``new_residue`` is not exactly one character.
    """
    idx = position - 1  # 0-based index of the residue being replaced
    if not 0 <= idx < len(sequence):
        raise ValueError("Position out of range.")
    if len(new_residue) != 1:
        raise ValueError("New residue must be a single character.")
    head, tail = sequence[:idx], sequence[idx + 1:]
    return head + new_residue + tail

In [7]:
def get_embedding(data, layer=33, esm_model=None, converter=None):
    """Return per-residue ESM representations for the first sequence in ``data``.

    Args:
        data: List of ``(label, sequence)`` pairs, as accepted by the ESM batch
            converter, e.g. ``[("wt", tp53_sequence)]``. Only the FIRST pair's
            embedding is returned; any other pairs are computed and discarded.
        layer: Index of the layer whose representations are returned
            (33 is the last layer of ESM-1b).
        esm_model: Model to run; defaults to the notebook-level ``model``.
        converter: Batch converter; defaults to the notebook-level
            ``batch_converter``.

    Returns:
        Tensor of shape (sequence_length, embedding_dim) -- (L, 1280) for
        ESM-1b -- with the [CLS]/[EOS] positions and any batch padding removed.

    Note:
        Put the model in eval mode (``model.eval()``) beforehand so dropout is
        off and repeated calls give the same embedding.
    """
    # Look up the notebook globals only when no explicit object was passed.
    esm_model = model if esm_model is None else esm_model
    converter = batch_converter if converter is None else converter

    _labels, _strs, batch_tokens = converter(data)

    with torch.no_grad():  # inference only; no autograd graph needed
        results = esm_model(batch_tokens, repr_layers=[layer], return_contacts=False)
    token_representations = results["representations"][layer]

    # Row 0 is laid out as [CLS], residues..., [EOS], followed by <pad> tokens when a
    # longer sequence shares the batch. Count row 0's non-padding tokens so the slice
    # stops right before its own [EOS]; for a single sequence this equals [1:-1].
    n_tokens = batch_tokens.shape[1]
    conv_alphabet = getattr(converter, "alphabet", None)
    if conv_alphabet is not None:
        n_tokens = int((batch_tokens[0] != conv_alphabet.padding_idx).sum())
    return token_representations[0, 1:n_tokens - 1]

In [8]:
# Start the variant list with the unmutated sequence, labelled "wt", so it is row 0.
data = [
    ("wt", tp53_sequence),
]

In [9]:
# Enumerate every single-residue substitution: at each 1-based position, swap the
# wild-type residue for each other amino acid, labelled e.g. "M1A".
# `data` is rebuilt from scratch (wild type first) rather than appended to, so
# re-running this cell cannot duplicate variants.
data = [("wt", tp53_sequence)] + [
    (f"{wt_aa}{i}{aa}", mutate_sequence(tp53_sequence, i, aa))
    for i, wt_aa in enumerate(tp53_sequence, start=1)
    for aa in amino_acids
    if aa != wt_aa
]
# Wild type + (len(amino_acids) - 1) substitutions per position; holds as long as
# every residue of tp53_sequence is one of `amino_acids`.
assert len(data) == 1 + (len(amino_acids) - 1) * len(tp53_sequence)

# Run the model to get layer-33 embeddings

In [46]:
# Embed the variants one sequence at a time (batching gave no speed-up here).
# Results are appended inside the loop, so anything finished before an
# interruption is still available in embedding_list.
embedding_list = []
for variant in tqdm(data, desc="Running ESM Embeddings"):
    embedding_list.append(get_embedding([variant]))

Running ESM Embeddings: 100%|█████████████████████████████████████████████████████████████████████████████████████████| 7468/7468 [1:20:59<00:00,  1.54it/s]


In [48]:
# Expect one embedding per (label, sequence) pair; fail loudly otherwise
# (e.g. when an interrupted run left the list partially filled).
assert len(embedding_list) == len(data), (len(embedding_list), len(data))
len(embedding_list)

7468

In [49]:
import pickle

# # Save to a file just in case
# with open('embedding_list.pkl', 'wb') as f:
#     pickle.dump(embedding_list, f)

In [68]:
# One row per variant: its label from `data` and its per-residue embedding.
# zip() silently drops rows when the two lists differ in length, so check first.
assert len(data) == len(embedding_list), (len(data), len(embedding_list))
df = pd.DataFrame(
    zip([label for label, _ in data], embedding_list),
    columns=["name", "embedding"],
)

In [69]:
# Preview the first rows (label + per-residue embedding).
df.head(n=5)

Unnamed: 0,name,embedding
0,wt,"[[tensor(0.1848), tensor(-0.0135), tensor(0.18..."
1,M1A,"[[tensor(0.1435), tensor(-0.0347), tensor(0.05..."
2,M1R,"[[tensor(0.2029), tensor(0.0074), tensor(0.206..."
3,M1N,"[[tensor(0.1319), tensor(0.0299), tensor(0.125..."
4,M1D,"[[tensor(0.1527), tensor(0.0441), tensor(0.073..."


In [73]:
# NOTE(review): the embedding cells are torch tensors, which to_csv writes out as text.
# torch abbreviates large tensors with "..." when converting them to a string, so this
# TSV most likely cannot be used to recover the embeddings -- confirm before relying
# on it. The all_mut_layer_33.pkl dump a few cells below keeps the tensors intact.
df.to_csv(path_or_buf="all_mut_layer_33.tsv", sep="\t", index=None)

In [74]:
# Read the TSV back. NOTE(review): "embedding" comes back as plain text, not tensors,
# and df_new is not used by any of the cells shown here.
df_new = pd.read_csv("all_mut_layer_33.tsv", delimiter="\t")

In [77]:
# Lossless dump of the full frame (labels + per-residue embedding tensors).
with open('all_mut_layer_33.pkl', mode='wb') as pkl_file:
    pickle.dump(df, pkl_file)

In [79]:
# Shape of the first row's per-residue embedding: (sequence length, embedding dim).
df["embedding"].iat[0].shape

torch.Size([393, 1280])

## Mean activation over residue positions

In [None]:
# Average each variant's per-residue embedding over residue positions (dim 0),
# leaving one vector per variant.
df['embedding_mean'] = df['embedding'].map(lambda per_residue: per_residue.mean(dim=0))

In [84]:
# Persist just the label and the position-averaged embedding.
with open('all_mut_layer_33_pos_mean.pkl', mode='wb') as pos_mean_file:
    pickle.dump(df[["name", "embedding_mean"]], pos_mean_file)

In [11]:
import pickle

# NOTE(review): this rebinds `data` -- earlier the list of (label, sequence) pairs
# fed to the model -- to a DataFrame with "name" and "embedding_mean" columns.
# A later cell that iterates `data` expecting sequences will see column names instead.
with open('all_mut_layer_33_pos_mean.pkl', mode='rb') as pos_mean_file:
    data = pickle.load(pos_mean_file)

In [17]:
# Finding: the per-dimension variance of the position-averaged embedding across all
# variants is very low (max about 1e-4, see output).

# Each `embedding_mean` entry is already a 1-D tensor (a mean over the residue axis),
# so it can be stacked directly; no element-by-element unpack/re-stack is needed.
tensor_lists = data['embedding_mean'].tolist()

# One 2-D tensor: rows = variants, cols = embedding dimension.
all_embeddings = torch.stack(tensor_lists)

# Variance of each embedding dimension across variants (dim=0).
variance = torch.var(all_embeddings, dim=0, unbiased=True)  # unbiased=True matches pandas default

variance.max()

tensor(0.0001)


## Activation difference (mutant minus wild type) at the mutated position

In [94]:
# Preview again, now including the position-mean column.
df.iloc[:5]

Unnamed: 0,name,embedding,embedding_mean
0,wt,"[[tensor(0.1848), tensor(-0.0135), tensor(0.18...","[tensor(-0.0231), tensor(0.1752), tensor(-0.02..."
1,M1A,"[[tensor(0.1435), tensor(-0.0347), tensor(0.05...","[tensor(-0.0299), tensor(0.1748), tensor(-0.02..."
2,M1R,"[[tensor(0.2029), tensor(0.0074), tensor(0.206...","[tensor(-0.0287), tensor(0.1731), tensor(-0.02..."
3,M1N,"[[tensor(0.1319), tensor(0.0299), tensor(0.125...","[tensor(-0.0255), tensor(0.1742), tensor(-0.02..."
4,M1D,"[[tensor(0.1527), tensor(0.0441), tensor(0.073...","[tensor(-0.0298), tensor(0.1734), tensor(-0.02..."


In [91]:
# Per-residue embedding of the unmutated sequence; mutants are compared against it.
# The wild type was put first when the variant list was built -- check, don't assume.
assert df["name"].iat[0] == "wt", "expected the wild-type row first"
wild_type = df["embedding"].iat[0]

In [93]:
# Expect (sequence length, embedding dim).
wild_type.size()

torch.Size([393, 1280])

In [108]:
# Row function for df.apply(axis=1): embedding change at the mutated residue only.
def subtract_from_wild_type(row, reference=None):
    """Return the mutant-minus-wild-type embedding at the mutated position.

    Args:
        row: Row with a "name" label -- either "wt" or wild-type letter, 1-based
            position, new letter (e.g. "M1A", "D393Y") -- and an "embedding"
            tensor indexed by residue along its first dimension.
        reference: Wild-type per-residue embedding to subtract. Defaults to the
            notebook-level ``wild_type``, so ``df.apply(subtract_from_wild_type,
            axis=1)`` works unchanged.

    Returns:
        0 for the "wt" row; otherwise the difference vector at the mutated residue.
    """
    if row["name"] == "wt":
        return 0
    ref = wild_type if reference is None else reference
    # Strip the one-letter residue codes on either side of the position number.
    residue_idx = int(row["name"][1:-1]) - 1
    # Index before subtracting: only the mutated residue is needed, so there is no
    # reason to build the full (length x dim) difference for every row.
    return row["embedding"][residue_idx] - ref[residue_idx]

# Row-wise apply (axis="columns" is the same as axis=1): each row arrives as a Series.
df['embedding_diff'] = df.apply(subtract_from_wild_type, axis="columns")

In [109]:
# Display the frame; pandas truncates long frames to their first and last rows.
df

Unnamed: 0,name,embedding,embedding_mean,embedding_diff
0,wt,"[[tensor(0.1848), tensor(-0.0135), tensor(0.18...","[tensor(-0.0231), tensor(0.1752), tensor(-0.02...",0
1,M1A,"[[tensor(0.1435), tensor(-0.0347), tensor(0.05...","[tensor(-0.0299), tensor(0.1748), tensor(-0.02...","[tensor(-0.0413), tensor(-0.0212), tensor(-0.1..."
2,M1R,"[[tensor(0.2029), tensor(0.0074), tensor(0.206...","[tensor(-0.0287), tensor(0.1731), tensor(-0.02...","[tensor(0.0181), tensor(0.0210), tensor(0.0262..."
3,M1N,"[[tensor(0.1319), tensor(0.0299), tensor(0.125...","[tensor(-0.0255), tensor(0.1742), tensor(-0.02...","[tensor(-0.0529), tensor(0.0434), tensor(-0.05..."
4,M1D,"[[tensor(0.1527), tensor(0.0441), tensor(0.073...","[tensor(-0.0298), tensor(0.1734), tensor(-0.02...","[tensor(-0.0320), tensor(0.0577), tensor(-0.10..."
...,...,...,...,...
7463,D393S,"[[tensor(0.1874), tensor(-0.0092), tensor(0.18...","[tensor(-0.0249), tensor(0.1759), tensor(-0.02...","[tensor(-0.2499), tensor(-0.0110), tensor(0.03..."
7464,D393T,"[[tensor(0.1870), tensor(-0.0067), tensor(0.18...","[tensor(-0.0231), tensor(0.1754), tensor(-0.02...","[tensor(-0.0735), tensor(0.0936), tensor(-0.06..."
7465,D393W,"[[tensor(0.1946), tensor(-0.0108), tensor(0.18...","[tensor(-0.0224), tensor(0.1722), tensor(-0.02...","[tensor(-0.0679), tensor(-0.1721), tensor(-0.0..."
7466,D393Y,"[[tensor(0.1956), tensor(-0.0135), tensor(0.17...","[tensor(-0.0220), tensor(0.1710), tensor(-0.02...","[tensor(-0.0593), tensor(-0.1182), tensor(-0.2..."


In [115]:
# Spot-check the last variant (a substitution at the final residue): its stored
# `embedding_diff` should equal the full-sequence difference at that last position.
last_variant = df.iloc[-1]
last_diff = (last_variant["embedding"] - wild_type)[-1]
assert torch.equal(last_diff, last_variant["embedding_diff"])
last_diff

tensor([-0.0046,  0.0367, -0.1130,  ..., -0.2177, -0.1049,  0.0407])

In [116]:
# Persist the label and the mutated-position embedding difference.
with open('all_mut_layer_33_diff.pkl', mode='wb') as diff_file:
    pickle.dump(df[["name", "embedding_diff"]], diff_file)

# Get layer-0 embeddings

In [10]:
# Embed every variant again, keeping layer 0 this time (presumably the token
# embeddings before the first transformer block -- confirm against the ESM docs).
# Rebuild the (label, sequence) list first: the pickle-loading cell above rebinds
# `data` to a DataFrame, and iterating a DataFrame yields its column names rather
# than sequences. This restores `data` to the same list the layer-33 run used.
data = [("wt", tp53_sequence)] + [
    (f"{wt_aa}{pos}{aa}", mutate_sequence(tp53_sequence, pos, aa))
    for pos, wt_aa in enumerate(tp53_sequence, start=1)
    for aa in amino_acids
    if aa != wt_aa
]

# Batching doesn't really help with time, so run one sequence at a time. Appending
# inside the loop keeps finished results if the run is interrupted.
embedding_list = []
for seq in tqdm(data, desc="Running ESM Embeddings"):
    embedding = get_embedding([seq], layer = 0)
    embedding_list.append(embedding)

Running ESM Embeddings:   4%|███▎                                                                                       | 272/7468 [07:23<3:15:29,  1.63s/it]


KeyboardInterrupt: 