In [1]:
import os
import torch
import esm 
from sae_model import SparseAutoencoder
from esm_wrapper import ESM2Model

In [2]:
from collections import OrderedDict

# One-letter -> three-letter residue codes for the 20 standard amino acids.
# Both sequences below are listed in the same order; that order also defines
# the residue index used by RESTYPE_ORDER.
RESTYPE_1_TO_3 = OrderedDict(zip(
    'ARNDCQEGHILKMFPSTWYV',
    ('ALA', 'ARG', 'ASN', 'ASP', 'CYS', 'GLN', 'GLU', 'GLY', 'HIS', 'ILE',
     'LEU', 'LYS', 'MET', 'PHE', 'PRO', 'SER', 'THR', 'TRP', 'TYR', 'VAL'),
))

# Inverse lookup: three-letter residue code -> one-letter residue code.
RESTYPE_3_TO_1 = dict(zip(RESTYPE_1_TO_3.values(), RESTYPE_1_TO_3.keys()))

# One-letter residue codes in canonical order.
RESTYPES = [*RESTYPE_1_TO_3]

# One-letter residue code -> its position in RESTYPES (its corresponding index).
RESTYPE_ORDER = dict(zip(RESTYPES, range(len(RESTYPES))))

def compute_rigid_alignment(A, B):
    """
    Use the Kabsch algorithm to compute the rigid transform aligning point cloud A onto B.

    Source: https://gist.github.com/bougui505/e392a371f5bab095a3673ea6f4976cc8
    See: https://en.wikipedia.org/wiki/Kabsch_algorithm

    The returned rotation is always proper (det(R) = +1). When the optimal
    orthogonal matrix would be a reflection, the sign of the last singular
    direction is flipped, as the Kabsch algorithm prescribes. Without this
    correction a mirrored structure would be "aligned" by a reflection.

    Args:
        A:
            [N, D] Point Cloud to Align (source)
        B:
            [N, D] Reference Point Cloud (target)

    Returns:
        R:
            [D, D] optimal rotation, such that ``A @ R.T + t`` approximates ``B``
        t:
            [D] optimal translation
    """

    # Center both clouds on their centroids.
    a_mean = A.mean(dim=0)
    b_mean = B.mean(dim=0)
    A_c = A - a_mean
    B_c = B - b_mean

    # Cross-covariance matrix H = A_c^T B_c = U diag(S) V^T.
    H = A_c.T @ B_c
    # torch.linalg.svd returns V^T (not V like the deprecated torch.svd).
    U, _, Vh = torch.linalg.svd(H)
    V = Vh.T

    # Reflection correction: d = sign(det(V U^T)) is +1 for a proper rotation
    # and -1 for a reflection; flipping the last singular direction turns the
    # latter into the closest proper rotation.
    d = torch.sign(torch.det(V @ U.T))
    correction = torch.ones(H.shape[0], dtype=H.dtype, device=H.device)
    correction[-1] = d
    R = V @ torch.diag(correction) @ U.T

    # Translation that maps the rotated source centroid onto the target centroid.
    t = b_mean - R @ a_mean

    return R, t

def parse_pdb(filepath, motif_residues=None):
    """Parse the ATOM records of a PDB file.

    Args:
        filepath: path to a PDB file.
        motif_residues: optional collection of residue sequence numbers. When
            given, one membership flag is recorded per ATOM record.

    Returns:
        seqs: list of one-letter residue codes, one per CA atom.
        atoms: list of atom names, one per ATOM record.
        coords: float tensor of per-atom [x, y, z] coordinates
            (shape [num_atoms, 3]; shape [0] if there are no ATOM records).
        is_motif: list of bools, one per ATOM record; empty when
            ``motif_residues`` is None.

    Raises:
        KeyError: if a residue name is not in RESTYPE_3_TO_1
            (e.g. a non-standard residue).
    """
    seqs, atoms, coords = [], [], []
    is_motif = []
    with open(filepath, 'r') as file:
        for line in file:
            if not line.startswith('ATOM'):
                continue
            # Fixed PDB columns (1-based): atom name 13-16, residue name 18-20,
            # residue sequence number 23-26, x/y/z 31-38/39-46/47-54.
            # The full 4-character name field is used so that names such as
            # CD1 or OG1 are not truncated to CD / OG.
            atom_name = line[12:16].strip()
            restype_1 = RESTYPE_3_TO_1[line[17:20]]
            x = float(line[30:38])
            y = float(line[38:46])
            z = float(line[46:54])
            if motif_residues is not None:
                is_motif.append(int(line[22:26]) in motif_residues)
            if atom_name == 'CA':
                seqs.append(restype_1)
            atoms.append(atom_name)
            coords.append([x, y, z])

    return seqs, atoms, torch.tensor(coords), is_motif

In [3]:
import subprocess
# Align every original structure with its reconstruction using TM-align.
# Each TM-align report is written to tm/<name>.txt.
org_dir = 'pdbs_l24_dim4096_k128/org'
rec_dir = 'pdbs_l24_dim4096_k128/rec'
os.makedirs('tm', exist_ok=True)
pdbs = sorted(name for name in os.listdir(org_dir) if name.endswith('.pdb'))
for item_path in pdbs:
    org_filepath = os.path.join(org_dir, item_path)
    rec_filepath = os.path.join(rec_dir, item_path)
    output_filepath = item_path.replace('.pdb', '.txt')
    # Pass arguments as a list (no shell). File names containing spaces or
    # shell metacharacters can then neither break nor inject into the command.
    # check=True surfaces a failed alignment here instead of leaving an
    # incomplete report for parse_tm_file to choke on later.
    with open(os.path.join('tm', output_filepath), 'w') as report:
        subprocess.run(['packages/TMscore/TMalign', org_filepath, rec_filepath],
                       stdout=report, check=True)

In [4]:
def parse_tm_file(filepath):
    """Parse a captured TM-align report into a dict of alignment metrics.

    Args:
        filepath: path to a text file holding TM-align's standard output.

    Returns:
        dict of floats with keys:
            'seqlen': aligned length,
            'rmsd': RMSD over the aligned residues,
            'seqid': sequence identity over the aligned residues,
            'tm': TM-score. TM-align prints one TM-score line per normalisation.
                The line normalised by the length of Chain_2 is used when it is
                present; otherwise the last TM-score line is used.

    Raises:
        ValueError: if any of the four metrics is missing from the report.
    """
    results = {}
    chain2_tm = None
    last_tm = None
    with open(filepath, 'r') as file:
        for line in file:
            if line.startswith('Aligned'):
                # Expected shape:
                # "Aligned length= N, RMSD= X, Seq_ID=n_identical/n_aligned= Y"
                seqlen, rmsd, seqid = line.split(',')
                results['seqlen'] = float(seqlen.split('=')[1])
                results['rmsd'] = float(rmsd.split('=')[1])
                results['seqid'] = float(seqid.split('=')[-1])
            elif line.startswith('TM-score'):
                score = float(line.split('(')[0].split('=')[1])
                if 'Chain_2' in line:
                    chain2_tm = score
                last_tm = score
    tm = chain2_tm if chain2_tm is not None else last_tm
    if tm is not None:
        results['tm'] = tm

    # Explicit check (not `assert`, which is stripped under `python -O`).
    missing = {'seqlen', 'rmsd', 'seqid', 'tm'} - results.keys()
    if missing:
        raise ValueError(f'{filepath}: TM-align report is missing {sorted(missing)}')
    return results

In [5]:
# One row per structure: [id, aligned length, sequence identity, RMSD, TM-score].
# Sorted so that the row order is reproducible across filesystems.
rows = []
for result_txt in sorted(os.listdir('tm')):
    if not result_txt.endswith('.txt'):
        continue  # ignore stray files (hidden files, editor backups, ...)
    output = parse_tm_file(os.path.join('tm', result_txt))
    rows.append([os.path.splitext(result_txt)[0], output['seqlen'], output['seqid'], output['rmsd'], output['tm']])

In [6]:
import pandas as pd

In [7]:
# Column order matches the rows built from the TM-align reports:
# id, aligned length, sequence identity, RMSD, TM-score.
SC_COLUMNS = ['id', 'seqlen', 'seqid', 'scrmsd', 'sctm']
info = pd.DataFrame(data=rows, columns=SC_COLUMNS)

In [133]:
# Persist the self-consistency table without the pandas index column.
info.to_csv(path_or_buf='sc.csv', index=False)

In [8]:
# ESM-2 embedding width (the 650M checkpoint below) and SAE hidden width.
D_MODEL = 1280
D_HIDDEN = 4096
# Fall back to CPU so the notebook still runs on a machine without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

esm2_weight = os.path.join('weights', 'esm2_t33_650M_UR50D.pt')
sae_weight = 'esm2_plm1280_l24_sae4096_100Kseqs.pt'
alphabet = esm.data.Alphabet.from_architecture("ESM-1b")

In [9]:
# ESM-2 (33 layers, 20 heads); embed_dim must equal D_MODEL so the SAE fits its activations.
esm2_model = ESM2Model(num_layers=33, embed_dim=D_MODEL, attention_heads=20,
                       alphabet=alphabet, token_dropout=False)
esm2_model.load_esm_ckpt(esm2_weight)
esm2_model.eval()
esm2_model = esm2_model.to(device)

sae_model = SparseAutoencoder(D_MODEL, D_HIDDEN)
# map_location='cpu': the checkpoint may have been saved on a different device.
# weights_only=True: a state_dict needs no arbitrary unpickling. This also
# avoids recent PyTorch's torch.load warning about weights_only.
sae_model.load_state_dict(torch.load(sae_weight, map_location='cpu', weights_only=True))
sae_model.eval()
sae_model = sae_model.to(device)

  model_data = torch.load(esm_pretrained)["model"]
  sae_model.load_state_dict(torch.load(sae_weight))


In [306]:
# Example AlphaFold structure file name.
# NOTE(review): the steering loops below reuse `pdb_file` as their loop
# variable, so this value is overwritten when they run.
pdb_file = 'AF-A0A1Y1VHB5-F1-model_v4.pdb'

In [335]:
# (SAE latent index, one-letter amino-acid code) pairs: latents labelled as
# tracking the identity of a single amino acid, listed alphabetically by code.
AA_identity_latents = [
    (3267, 'A'),
    # (3812, 'C'),  # disabled in the original experiment; reason not recorded
    (2830, 'D'),
    (2152, 'E'),
    (252, 'F'),
    # NOTE(review): previously labelled 'D', which duplicated latent 2830 while
    # 'G' was missing. Its alphabetical position indicates 'G'; confirm
    # against the latent analysis.
    (3830, 'G'),
    (743, 'H'),
    (3978, 'I'),
    (3073, 'K'),
    (1497, 'L'),
    (444, 'M'),
    (21, 'N'),
    (1386, 'P'),
    (1266, 'Q'),
    (3569, 'R'),
    (1473, 'S'),
    (220, 'T'),
    (3383, 'V'),
    (2685, 'W'),
    (3481, 'Y')
]


In [337]:
import time 
from IPython.display import clear_output
# For each amino-acid-identity latent, clamp it (at every position) to a fraction
# of the sequence's peak SAE activation, then decode the steered layer-24
# representation back into a sequence. The SAE reconstruction error is added
# back, so only the steered latent changes the representation.
pdb_files = sorted(os.listdir('pdbs_l24_dim4096_k128/org'))[:6]  # sorted: reproducible subset
for dim_idx, aa in AA_identity_latents:
    for pdb_file in pdb_files:
        print(f'### Steering SAE dimension = {dim_idx} Corresponds to {aa} ###')
        seqs, _, _, _ = parse_pdb(os.path.join('pdbs_l24_dim4096_k128/org', pdb_file))
        sequence = ''.join(seqs)
        print('Original Sequence:     ', sequence)
        with torch.no_grad():
            # Activations and their SAE code do not depend on the multiplier,
            # so they are computed once per sequence.
            _, embed = esm2_model.get_layer_activations(sequence, 24)
            acts, mu, std = sae_model.encode(embed[0])
            error = embed - sae_model.decode(acts, mu, std)
            max_act = acts.max()
            for mult_by in [0.6, 0.8, 1.0]:
                # Clone so that every multiplier starts from the unsteered code.
                steered_acts = acts.clone()
                steered_acts[:, dim_idx] = max_act * mult_by
                steered_acts_dec = sae_model.decode(steered_acts, mu, std)
                steered = esm2_model.get_sequence(error + steered_acts_dec, 24)
                # NOTE(review): 1:-1 presumably drops BOS/EOS and 4:24 presumably
                # selects the 20 standard amino-acid tokens of the ESM alphabet.
                steered_seq = torch.argmax(steered[0, 1:-1, 4:24], dim=-1)
                steered_seq = ''.join(alphabet.all_toks[i + 4] for i in steered_seq)
                print(f'{mult_by}X Steered Sequence: ', steered_seq)
        time.sleep(1)
        clear_output(wait=True)

### Steering SAE dimension = 3481 Corresponds to Y ###
Original Sequence:      MAVVLVGVVRRLSDTWWIWGAVVSMVFLIFFVMIAPVFIFPLLNKYTVLEDPRITQPILSLARANGIPARKVYQMDASRQTARMSANVSGYGQIMRITLNDNLLRRGSPEEIQSVMGHEMGHYVLHHIAKDILYFSAVIVISLRCGGHSNGPWLAGENAGRCAELGIPRCFPWSSFWVRSLPSYTRHSLTPTSAPTNTKPICMA
0.6X Steered Sequence:  MYYYLYYYYRYYPYYWWYYYYYYSYYYLYYYYYYYPIYYYPYYNYYYYYYYYYIYYYIYYYAYYNYYYYYYYYYMYMSYYTYYYSANYSGYGPYKRITLYDYLYYYYYYEEYYYIYGYEYGYYYYYHYYYYIYYYYYYIYYYYYCYYYSYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYSYYSYYYYYYYYYYYYYYYYYYYYY
0.8X Steered Sequence:  YYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYPYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYTYYYYYYYYYYYYYYYYYYYDYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYY
1.0X Steered Sequence:  YYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYYY


In [341]:
import time 
from IPython.display import clear_output
def distance(s1, s2):
    """Count mismatching positions between two sequences (Hamming-style).

    Positions are compared pairwise. If the lengths differ, each unmatched
    trailing position of the longer sequence also counts as a mismatch; a plain
    ``zip`` would silently ignore those positions and under-report the distance.
    """
    mismatches = sum(a1 != a2 for a1, a2 in zip(s1, s2))
    return mismatches + abs(len(s1) - len(s2))
# Control experiment: steer arbitrary SAE latents in the same way as the
# amino-acid latents and report how many residues of the decoded sequence change.
pdb_files = sorted(os.listdir('pdbs_l24_dim4096_k128/org'))[:6]  # sorted: reproducible subset
for dim_idx in [3, 5, 7, 9]:
    for pdb_file in pdb_files:
        print(f'### Steering Other SAE dimensions = {dim_idx}  ###')
        seqs, _, _, _ = parse_pdb(os.path.join('pdbs_l24_dim4096_k128/org', pdb_file))
        sequence = ''.join(seqs)
        print('Original Sequence:                 ', sequence)
        with torch.no_grad():
            # Activations and their SAE code do not depend on the multiplier,
            # so they are computed once per sequence.
            _, embed = esm2_model.get_layer_activations(sequence, 24)
            acts, mu, std = sae_model.encode(embed[0])
            # Adding the reconstruction error back means only the steered latent
            # changes the representation.
            error = embed - sae_model.decode(acts, mu, std)
            max_act = acts.max()
            for mult_by in [0.6, 0.8, 1.0]:
                # Clone so that every multiplier starts from the unsteered code.
                steered_acts = acts.clone()
                steered_acts[:, dim_idx] = max_act * mult_by
                steered_acts_dec = sae_model.decode(steered_acts, mu, std)
                steered = esm2_model.get_sequence(error + steered_acts_dec, 24)
                # NOTE(review): 1:-1 presumably drops BOS/EOS and 4:24 presumably
                # selects the 20 standard amino-acid tokens of the ESM alphabet.
                steered_seq = torch.argmax(steered[0, 1:-1, 4:24], dim=-1)
                steered_seq = ''.join(alphabet.all_toks[i + 4] for i in steered_seq)
                print(f'{mult_by}X Steered Sequence: [Distance={distance(sequence, steered_seq)}]', steered_seq)
        time.sleep(1)
        clear_output(wait=True)

### Steering Other SAE dimensions = 9  ###
Original Sequence:                  MAVVLVGVVRRLSDTWWIWGAVVSMVFLIFFVMIAPVFIFPLLNKYTVLEDPRITQPILSLARANGIPARKVYQMDASRQTARMSANVSGYGQIMRITLNDNLLRRGSPEEIQSVMGHEMGHYVLHHIAKDILYFSAVIVISLRCGGHSNGPWLAGENAGRCAELGIPRCFPWSSFWVRSLPSYTRHSLTPTSAPTNTKPICMA
0.6X Steered Sequence: [Distance=7] MAVVLVGVVRRLPDTWWIWGAVVSMVFLIFFVMIAPVFIFPLFNKYTPLEDPRITQPILSLARANGIPARKVYQMDASRQTARMSANVSGLGQTMRITLNDNLLRRGSPEEIQSVMGHEMGHYVLHHIAKGILYFSAVIVISLRCGGHSNGPWLAGENAGRCAELGIPRCFPWSSFWVRSLPSPTRHSLTPTSAPTNTKPICMA
0.8X Steered Sequence: [Distance=7] MAVVLVGVVRRLPDTWWIWGAVVSMVFLIFFVMIAPVFIFPLFNKYTPLEDPRITQPILSLARANGIPARKVYQMDASRQTARMSANVSGLGQTMRITLNDNLLRRGSPEEIQSVMGHEMGHYVLHHIAKGILYFSAVIVISLRCGGHSNGPWLAGENAGRCAELGIPRCFPWSSFWVRSLPSPTRHSLTPTSAPTNTKPICMA
1.0X Steered Sequence: [Distance=8] MAVVLVGVVRRLPDRWWIWGAVVSMVFLIFFVMIAPVFIFPLFNKYTPLEDPRITQPILSLARANGIPARKVYQMDASRQTARMSANVSGLGQSMRITLNDNLLRRGSPEEIQSVMGHEMGHYVLHHIAKGILYFSAVIVISLRCGGHSNGPWLAGENAGRCAELGIPRCFPWSSFWVRSLPSPTRHSLTPTSAPTNT

In [325]:
# NOTE(review): displays `steered_seq` left over from the last iteration of a
# steering loop (hidden kernel state). This cell's execution count (In [325])
# predates the steering cells (In [337], In [341]), so the output shown came
# from an earlier run.
steered_seq

'SDLSYFSSYLSSTLASIVASSKVFSTSRLSSSSKLTLSSTEYTSLRSSSESSSASVYLVRKSDSSSSSLYALKRLSSSSSSSEERVRNELLALLSLLSSSRSLSLSDYASVKSSDSTSSSSLLLSSSSSSSLSSLSSSTSKRTRSSLSFSLSSSILSDISSSLSSSSKSSSSLASRSLKLSNVALSSSSSASLSSLSSVSSASSTSSTRSSASVSKSFSSSTSTSTYRSSEFLSSLVDSRSDSRTDSSSLSCTLLALALSSSSSSSTSTSTLSSVVISDDSTYDSLFLLLLKKLLSSDSSKRSTLSEVKSTLYLFLSASSFSSNAV'

In [326]:
# Length of the `sequence` left over from the last steering-loop iteration
# (hidden kernel state; see the execution-count note above).
len(sequence)

326

In [328]:
# Number of positions at which the leftover steered sequence differs from the
# leftover original sequence.
distance(steered_seq, sequence)

148