In [None]:
import numpy as np
import pandas as pd
from Bio import SeqIO
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from Bio.SeqUtils.ProtParam import ProteinAnalysis
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import torch
import esm

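# Loads precomputed CSVs of per-sequence descriptions and ProtParam-style features (originally generated from FASTA files) and analyzes ESM-2 outputs for the highest- and lowest-ranked sequences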

In [None]:
# Load `seq_df_first_2000.csv`, drop incomplete rows, rank by GRAVY (descending)
# and keep only rows whose Length is below 500.
# NOTE(review): the next cell reassigns `df`, so this result is discarded on Run All.
df = (
    pd.read_csv("seq_df_first_2000.csv", index_col=0)
    .dropna()
    .sort_values(by="GRAVY", ascending=False)
    .loc[lambda frame: frame["Length"] < 500]
)

In [None]:
# Load the larger feature table (its file name suggests rows are already limited to
# Length < 500), drop incomplete rows and rank by Instability Index, highest first.
# This replaces the `df` from the previous cell, so every downstream "top"/"low" group
# is ranked by Instability Index, not by GRAVY as the variable names suggest.
df = (
    pd.read_csv("seq_df_first_20000_less_500.csv", index_col=0)
    .dropna()
    .sort_values(by="Instability Index", ascending=False)
)


In [None]:
# Preview only the first rows (highest-ranked) instead of rendering the whole frame.
df.head()

### Exploratory Data Analysis

In [None]:
# Histogram of GRAVY values across all loaded sequences.
feature = "GRAVY"
plt.hist(df[feature], bins=200)
plt.xlabel(feature)
plt.ylabel("Frequency")
plt.title(f"Distribution of {feature} values")

In [None]:
# Histogram of aromaticity across all loaded sequences.
feature = "Aromaticity"
plt.hist(df[feature], bins=50)
plt.xlabel(feature)
plt.ylabel("Frequency")
plt.title(f"Distribution of {feature} values")

In [None]:
# Histogram of molecular weight; axis and title use the short label "MW".
feature, label = "Molecular Weight", "MW"
plt.hist(df[feature], bins=100)
plt.xlabel(label)
plt.ylabel("Frequency")
plt.title(f"Distribution of {label} values")

In [None]:
# Histogram of charge at pH 7.0 across all loaded sequences.
feature = "Charge at pH:7.0"
plt.hist(df[feature], bins=200)
plt.xlabel(feature)  # was mislabelled "Aromaticity" (copy-paste from the cell above)
plt.ylabel("Frequency")
plt.title(f"Distribution of {feature} values")

### Analyze ESM-2 Outputs (Attention Weights, Contacts, Representations)

In [None]:
# Load the pretrained ESM-2 checkpoint `esm2_t6_8M_UR50D` and its alphabet (the token
# vocabulary). `get_features` reads `alphabet` and `batch_converter` as globals, so this
# cell must run before it is called.
model, alphabet = esm.pretrained.esm2_t6_8M_UR50D()
# Turns (label, sequence) pairs into a padded batch of token ids.
batch_converter = alphabet.get_batch_converter()
model.eval()  # disables dropout for deterministic results

In [None]:
def get_features(data, model, plot_contacts=True, repr_layer=6):
    """Run an ESM model on a batch of sequences and return its output dict.

    Uses the notebook-level ``batch_converter`` and ``alphabet`` created in the
    model-loading cell, so they must come from the same checkpoint as ``model``.

    Parameters
    ----------
    data : list of (str, str)
        ``(label, sequence)`` pairs, the input format of ``batch_converter``.
    model : torch.nn.Module
        The ESM model (switched to eval mode by the loading cell).
    plot_contacts : bool, default True
        If True, show one predicted contact map per sequence (original behavior).
    repr_layer : int, default 6
        Layer whose per-token representations are requested; 6 is the layer this
        notebook uses with ``esm2_t6_8M_UR50D``.

    Returns
    -------
    dict
        The model output (keys used later in this notebook: ``"logits"``,
        ``"representations"``, ``"attentions"``, ``"contacts"``) plus
        ``"mean_representations"``: one row per input sequence holding the
        ``repr_layer`` representation averaged over that sequence's residue
        tokens (BOS, EOS and padding excluded).
    """
    _labels, _strs, batch_tokens = batch_converter(data)
    # Non-padding token count per sequence; this includes the BOS and EOS tokens.
    batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)

    # Inference only: skip autograd bookkeeping (runs on CPU here).
    with torch.no_grad():
        results = model(batch_tokens, repr_layers=[repr_layer], return_contacts=True)
    token_representations = results["representations"][repr_layer]

    # Per-sequence embedding = mean over residue tokens. Token 0 is BOS, so residues are
    # tokens 1 .. tokens_len - 2 (tokens_len - 1 is EOS). These were previously computed
    # and then thrown away; they are now returned to the caller.
    results["mean_representations"] = torch.stack([
        token_representations[i, 1 : tokens_len - 1].mean(0)
        for i, tokens_len in enumerate(batch_lens)
    ])

    if plot_contacts:
        # Unsupervised contact predictions derived from the self-attention maps.
        for (_, seq), attention_contacts in zip(data, results["contacts"]):
            # ESM strips BOS/EOS from the contact map, so it is residue x residue; crop to
            # this sequence's residues rather than its token count (which adds BOS/EOS).
            n_res = len(seq)
            plt.matshow(attention_contacts[:n_res, :n_res])
            plt.title(seq[:10] + "...")
            plt.colorbar()
            plt.show()
    return results

In [None]:
# Take the 10 highest- and 10 lowest-ranked rows of `df` (ranked by whichever column the
# load cell sorted on — Instability Index as written, despite the "gravy" names) and turn
# them into the (label, sequence) pairs expected by the ESM batch converter.
top_gravy = df.head(10)
low_gravy = df.tail(10)

top_gravy_X = list(zip(top_gravy["Name"], top_gravy["Sequence"]))
low_gravy_X = list(zip(low_gravy["Name"], low_gravy["Sequence"]))
top_gravy_X, low_gravy_X

In [None]:
# Run the model on both groups, top group first; each call also draws contact maps.
top_gravy_results, low_gravy_results = (
    get_features(batch, model) for batch in (top_gravy_X, low_gravy_X)
)

In [None]:
# Shallow copies of the two output dicts under short names, then list their keys.
res = dict(top_gravy_results)
res_2 = dict(low_gravy_results)
print(res.keys(), res_2.keys(), sep="\n")

In [None]:
# Logits for one sequence of the top_gravy batch (res). Presumably shaped
# (tokens, vocabulary) per sequence — TODO confirm. aspect="auto" stretches the narrow
# vocabulary axis so the image is not drawn as a thin strip.
seq_idx = 3  # batch index of the sequence to show
plt.imshow(res['logits'][seq_idx], aspect="auto")
plt.xlabel("Vocabulary index")
plt.ylabel("Token position")
plt.title(f"Logits — top_gravy sequence {seq_idx}");

In [None]:
# Size of the attention tensor; later cells index it as
# (batch, layer, head, query token, key token).
res['attentions'].size()

In [None]:
# A few attention heads per layer for one sequence of the top_gravy batch (res).
# The tensor is indexed (batch, layer, head, query token, key token); the batch size is
# the number of sequences passed to get_features (not a fixed 3). The token axes
# presumably include BOS/EOS and padding — TODO confirm.
attn = res['attentions']

sequence_idx = 0  # which sequence in the batch to show
heads_to_plot = [0, 1, 2, 3, 4]  # subset of heads to show
num_layers = attn.shape[1]

for layer in range(num_layers):
    fig, axes = plt.subplots(1, len(heads_to_plot), figsize=(15, 5))
    fig.suptitle(f"Sequence {sequence_idx}, Layer {layer}")

    for ax, head in zip(axes, heads_to_plot):
        ax.imshow(attn[sequence_idx, layer, head].cpu().numpy(), cmap='viridis')
        ax.set_title(f"Head {head}")
        ax.axis('off')

    plt.tight_layout()
    plt.show()

In [None]:
# Every attention head for one sequence of the top_gravy batch (res), one figure per layer.
# Shape: [batch, layers, heads, seq_len, seq_len]
attn = res['attentions']
sequence_idx = 0  # pick which sequence to visualize
num_layers = attn.shape[1]
num_heads = attn.shape[2]

# Size the grid from the head count (4 x 5 for the 20 heads of this checkpoint) instead of
# hard-coding 4 x 5: more heads would overflow it (IndexError), and fewer heads would
# leave empty framed panels.
ncols = 5
nrows = (num_heads + ncols - 1) // ncols

for layer in range(num_layers):
    fig, axes = plt.subplots(nrows, ncols, figsize=(4 * ncols, 4 * nrows), squeeze=False)
    fig.suptitle(f"Attention Heads — Sequence {sequence_idx}, Layer {layer}", fontsize=16)
    axes = axes.flatten()

    for head in range(num_heads):
        ax = axes[head]
        ax.imshow(attn[sequence_idx, layer, head].cpu().numpy(), cmap='viridis')
        ax.set_title(f"Head {head}")
        ax.axis('off')
    for ax in axes[num_heads:]:
        ax.set_visible(False)  # hide grid cells with no head

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])  # Leave space for the title
    plt.show()


In [None]:
# Attention averaged over all sequences of the top_gravy batch (res), one figure per layer.
# NOTE(review): sequences are padded to the batch's longest one, so the mean mixes real
# and padded positions for shorter sequences.
attn = res['attentions'].mean(dim=0)  # -> [layers, heads, seq_len, seq_len]

num_layers = attn.shape[0]
num_heads = attn.shape[1]

# Grid sized from the head count (4 x 5 for the 20 heads of this checkpoint) rather than
# hard-coded, so more heads cannot overflow it and fewer leave no empty panels.
ncols = 5
nrows = (num_heads + ncols - 1) // ncols

for layer in range(num_layers):
    fig, axes = plt.subplots(nrows, ncols, figsize=(4 * ncols, 4 * nrows), squeeze=False)
    # Name the group so these figures can be told apart from the low_gravy ones.
    fig.suptitle(f"Mean Attention Across Batch (top_gravy) — Layer {layer}", fontsize=16)
    axes = axes.flatten()

    for head in range(num_heads):
        ax = axes[head]
        ax.imshow(attn[layer, head].cpu().numpy(), cmap='viridis')
        ax.set_title(f"Head {head}")
        ax.axis('off')
    for ax in axes[num_heads:]:
        ax.set_visible(False)  # hide grid cells with no head

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])  # leave room for the suptitle
    plt.show()


In [None]:
# Attention averaged over all sequences of the low_gravy batch (res_2), one figure per layer.
# NOTE(review): sequences are padded to the batch's longest one, so the mean mixes real
# and padded positions for shorter sequences.
attn = res_2['attentions'].mean(dim=0)  # -> [layers, heads, seq_len, seq_len]

num_layers = attn.shape[0]
num_heads = attn.shape[1]

# Grid sized from the head count (4 x 5 for the 20 heads of this checkpoint) rather than
# hard-coded, so more heads cannot overflow it and fewer leave no empty panels.
ncols = 5
nrows = (num_heads + ncols - 1) // ncols

for layer in range(num_layers):
    fig, axes = plt.subplots(nrows, ncols, figsize=(4 * ncols, 4 * nrows), squeeze=False)
    # Name the group so these figures can be told apart from the top_gravy ones.
    fig.suptitle(f"Mean Attention Across Batch (low_gravy) — Layer {layer}", fontsize=16)
    axes = axes.flatten()

    for head in range(num_heads):
        ax = axes[head]
        ax.imshow(attn[layer, head].cpu().numpy(), cmap='viridis')
        ax.set_title(f"Head {head}")
        ax.axis('off')
    for ax in axes[num_heads:]:
        ax.set_visible(False)  # hide grid cells with no head

    plt.tight_layout(rect=[0, 0.03, 1, 0.95])  # leave room for the suptitle
    plt.show()


In [None]:
# Batch-averaged attention of the top_gravy group (res), summarized per (layer, head):
# Frobenius norm (torch.norm default), min and max of that head's matrix, each rounded to 4 dp.
attn = res["attentions"].mean(dim=0)  # shape: [layers, heads, seq_len, seq_len]

summary = [
    {
        "layer": layer,
        "head": head,
        "norm": round(torch.norm(attn[layer, head]).item(), 4),
        "min": round(attn[layer, head].min().item(), 4),
        "max": round(attn[layer, head].max().item(), 4),
    }
    for layer in range(attn.shape[0])
    for head in range(attn.shape[1])
]

# One row per (layer, head); built here, not displayed by this cell.
df_1 = pd.DataFrame(summary)


In [None]:
# Batch-averaged attention of the low_gravy group (res_2), summarized per (layer, head):
# Frobenius norm (torch.norm default), min and max of that head's matrix, each rounded to 4 dp.
attn = res_2["attentions"].mean(dim=0)  # shape: [layers, heads, seq_len, seq_len]

summary = [
    {
        "layer": layer,
        "head": head,
        "norm": round(torch.norm(attn[layer, head]).item(), 4),
        "min": round(attn[layer, head].min().item(), 4),
        "max": round(attn[layer, head].max().item(), 4),
    }
    for layer in range(attn.shape[0])
    for head in range(attn.shape[1])
]

# One row per (layer, head); built here, not displayed by this cell.
df_2 = pd.DataFrame(summary)


In [None]:
# Per-(layer, head) difference of the attention stats, top_gravy minus low_gravy, largest
# norm difference first. Index by (layer, head) before subtracting: subtracting the raw
# frames also subtracted those ID columns, turning them all into 0 and hiding which head
# each row describes.
stat_diff = df_1.set_index(["layer", "head"]) - df_2.set_index(["layer", "head"])
stat_diff.sort_values("norm", ascending=False).head(20)

In [None]:
# Size of the batched contact-map tensor for the top_gravy group (batch first).
res['contacts'].size()

In [None]:
f, axarr = plt.subplots(1, 2, figsize=(20, 10))

# Label for the ranking that defined the two groups. It must name the column the load
# cell sorted `df` by ("Instability Index" as written), not "Molecular Weight".
criterion = "Instability Index"

# Mean contact map per group; the square root compresses the range so weaker contacts
# stay visible. NOTE(review): each batch is padded to its longest sequence, so padded
# positions of shorter sequences are included in the mean.
mean_contact_1 = np.mean(res['contacts'].numpy(), axis=0) ** 0.5
mean_contact_2 = np.mean(res_2['contacts'].numpy(), axis=0) ** 0.5

# Use the global min/max for consistent color scale
v_max = max(np.max(mean_contact_1), np.max(mean_contact_2))
v_min = min(np.min(mean_contact_1), np.min(mean_contact_2))

# First 50 x 50 positions of each map
im0 = axarr[0].imshow(mean_contact_1[:50, :50], cmap='viridis', vmin=v_min, vmax=v_max)
axarr[0].set_title(f"Mean Contact - Top {criterion}")

im1 = axarr[1].imshow(mean_contact_2[:50, :50], cmap='viridis', vmin=v_min, vmax=v_max)
axarr[1].set_title(f"Mean Contact - Lowest {criterion}")

# Add colorbars to each subplot
f.colorbar(im0, ax=axarr[0])
f.colorbar(im1, ax=axarr[1])

plt.tight_layout()
plt.show()

In [None]:
f, axarr = plt.subplots(2, 2, figsize=(20, 10))
# Label for the ranking that defined the two groups; it must match the column the load
# cell sorted `df` by ("Instability Index"), not "Molecular_weight".
criterion = "Instability Index"
window = 200  # positions shown per panel, at the start and at the end of each map

# Mean contact map per group (sqrt compresses the range so weak contacts stay visible).
mean_contact_1 = np.mean(res['contacts'].numpy(), axis=0) ** 0.5
mean_contact_2 = np.mean(res_2['contacts'].numpy(), axis=0) ** 0.5

# Shared color limits so all four panels are directly comparable.
v_max = max(np.max(mean_contact_1), np.max(mean_contact_2))
v_min = min(np.min(mean_contact_1), np.min(mean_contact_2))

# End panels show the last `window` positions. The old fixed slice [300:] was empty
# whenever a padded map had 300 or fewer positions. NOTE(review): the end of a padded
# map belongs to the batch's longest sequences; shorter ones contribute padding there.
panels = [
    (axarr[0, 0], mean_contact_1[:window, :window], f"Sequence Start Mean Contact - Top {criterion}"),
    (axarr[0, 1], mean_contact_2[:window, :window], f"Sequence Start Mean Contact - Lowest {criterion}"),
    (axarr[1, 0], mean_contact_1[-window:, -window:], f"Sequence End Mean Contact - Top {criterion}"),
    (axarr[1, 1], mean_contact_2[-window:, -window:], f"Sequence End Mean Contact - Lowest {criterion}"),
]
for ax, matrix, title in panels:
    ax.imshow(matrix, cmap='viridis', vmin=v_min, vmax=v_max)
    ax.set_title(title)

plt.tight_layout()
plt.show()

In [None]:
# Unscaled mean contact map of the top_gravy batch (res).
plt.imshow(res['contacts'].numpy().mean(axis=0))

In [None]:
# Unscaled mean contact map of the low_gravy batch (res_2).
plt.imshow(res_2['contacts'].numpy().mean(axis=0))

In [None]:
# Layer-6 per-token representations of the top_gravy batch: print the tensor size, then
# show the first sequence's matrix (presumably tokens x features) as an image.
layer6_reprs = res['representations'][6]
print(layer6_reprs.size())
plt.imshow(layer6_reprs[0])

In [None]:
# Summarize the ESM output dict as tensor shapes instead of dumping every tensor's values
# (the attention tensor alone is 5-D).
def _describe_outputs(value):
    """Return `value` with every tensor replaced by its shape, recursing into containers."""
    if isinstance(value, torch.Tensor):
        return tuple(value.shape)
    if isinstance(value, dict):
        return {key: _describe_outputs(item) for key, item in value.items()}
    if isinstance(value, (list, tuple)):
        return [_describe_outputs(item) for item in value]
    return value


_describe_outputs(res)