In [1]:
import gymnasium as gym
import numpy as np
import pandas as pd
from gymnasium import spaces

from stable_baselines3.common.vec_env import DummyVecEnv
from stable_baselines3 import PPO

from fitness_functions import fitness_ESM, fitness_ESM_DMS
from callbacks import TQDMCallback

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class ProteinEnv(gym.Env):
    """
    State: amino acid sequence (string or int array)
    Action: mutate position i to amino acid j
    """
    metadata = {"render.modes": ["human"]}

    def __init__(self, seq, fitness_fn, DMS_path):
        ''' Requires the wild-type aa sequence (string), 
                fitness_fn (defined in fitness_functions.py),
            and DMS dataset (path to csv)
        '''
        super().__init__()
        self.amino_acids = "ACDEFGHIKLMNPQRSTVWY"
        self.aa_to_idx = {aa: i for i, aa in enumerate(self.amino_acids)}
        self.idx_to_aa = {i: aa for aa, i in self.aa_to_idx.items()}

        self.L = len(seq)
        self.fitness_fn = fitness_fn
        self.DMS = pd.read_csv(DMS_path)
        
        # convert sequence string â†’ array of indices
        self.initial_seq = np.array([self.aa_to_idx[a] for a in seq], dtype=np.int32)

        # action = choose a position to mutate, and choose an aa to mutate to
        self.action_space = spaces.Discrete(self.L * 20)

        # observation = vector of length L with values in [0,19]
        self.observation_space = spaces.MultiDiscrete([20] * self.L)

        self.state = None
    
    def idxs_to_letters(self, seq):
        ''' convert string of indexes to string of aa letters '''
        return ''.join([self.idx_to_aa[i] for i in seq])

    def _decode_action(self, action):
        pos = action // 20
        aa_idx = action % 20
        return pos, aa_idx

    def reset(self, *, seed=None, options=None):
        super().reset(seed=seed)
        self.state = self.initial_seq.copy()  # back to wild-type
        obs = self.state.copy()
        return obs, {}

    def step(self, action):
        import pdb;pdb.set_trace()
        pos, aa_idx = self._decode_action(action)

        # Apply mutation
        new_state = self.state.copy()
        new_state[pos] = aa_idx

        # Reward from fitness function
        reward = self.fitness_fn(self.idxs_to_letters(new_state), self.DMS)

        # You can choose episode termination rule:
        # e.g., fixed length episode of mutations
        terminated = False
        truncated = False

        self.state = new_state
        import pdb;pdb.set_trace()
        return new_state.copy(), reward, terminated, truncated, {}

    def render(self):
        seq_str = "".join(self.idx_to_aa[i] for i in self.state)
        print(seq_str)


In [15]:
import torch, esm
import numpy as np

# Load the pretrained ESM2 150M model
esm150_model, esm150_alphabet = esm.pretrained.esm2_t30_150M_UR50D()
esm150_model.eval()

batch_converter = esm150_alphabet.get_batch_converter()
mask_idx = esm150_alphabet.mask_idx
def esm_pseudo_log_likelihood(seq):
    """
    Computes the pseudo log-likelihood of an amino-acid sequence using ESM2 (masked LM). 
    Returns a float.
    """
    data = [("protein", seq)]
    import pdb;pdb.set_trace()
    _, _, tokens = batch_converter(data)     # shape [1, L]
    tokens = tokens[0]                       # shape [L] (technically 2 more then length bc of BOS/EOS tokens)
    L = tokens.size(0)

    # Generate all masked sequences (L-2 internal positions)
    masked_tokens = tokens.repeat(L-2, 1)
    positions = torch.arange(1, L-1)
    masked_tokens[torch.arange(L-2), positions] = mask_idx  # mask each pos

    # Add batch dimension
    masked_tokens = masked_tokens.unsqueeze(1)  # [L-2, 1, L]

    with torch.no_grad():
        import pdb;pdb.set_trace()
        logits = esm150_model(masked_tokens)["logits"]  # [L-2, 1, L, vocab]

    log_probs = []
    for i, pos in enumerate(positions):
        true_token = tokens[pos]
        log_prob_i = torch.log_softmax(logits[i, 0, pos], dim=-1)[true_token]
        log_probs.append(log_prob_i)

    return float(torch.stack(log_probs).sum().item())

In [16]:
esm_pseudo_log_likelihood('AGRILKYRTTT')

> [32m/tmp/ipykernel_962296/345124178.py[39m([92m17[39m)[36mesm_pseudo_log_likelihood[39m[34m()[39m
[32m     15[39m     data = [([33m"protein"[39m, seq)]
[32m     16[39m     [38;5;28;01mimport[39;00m pdb;pdb.set_trace()
[32m---> 17[39m     _, _, tokens = batch_converter(data)     [38;5;66;03m# shape [1, L][39;00m
[32m     18[39m     tokens = tokens[[32m0[39m]                       [38;5;66;03m# shape [L] (technically 2 more then length bc of BOS/EOS tokens)[39;00m
[32m     19[39m     L = tokens.size([32m0[39m)



> [32m/tmp/ipykernel_962296/345124178.py[39m([92m18[39m)[36mesm_pseudo_log_likelihood[39m[34m()[39m
[32m     16[39m     [38;5;28;01mimport[39;00m pdb;pdb.set_trace()
[32m     17[39m     _, _, tokens = batch_converter(data)     [38;5;66;03m# shape [1, L][39;00m
[32m---> 18[39m     tokens = tokens[[32m0[39m]                       [38;5;66;03m# shape [L] (technically 2 more then length bc of BOS/EOS tokens)[39;00m
[32m     19[39m     L = tokens.size([32m0[39m)
[32m     20[39m 

> [32m/tmp/ipykernel_962296/345124178.py[39m([92m19[39m)[36mesm_pseudo_log_likelihood[39m[34m()[39m
[32m     17[39m     _, _, tokens = batch_converter(data)     [38;5;66;03m# shape [1, L][39;00m
[32m     18[39m     tokens = tokens[[32m0[39m]                       [38;5;66;03m# shape [L] (technically 2 more then length bc of BOS/EOS tokens)[39;00m
[32m---> 19[39m     L = tokens.size([32m0[39m)
[32m     20[39m 
[32m     21[39m     [38;5;66;03m# Generate all maske

In [13]:
esm_pseudo_log_likelihood('MAADGYLPDWLEDTLSEGIRQWWKLKPGPPPPKPAERHKDDSRGLVLPGYKYLGPFNGLDKGEPVNEADAAALEHDKAYDRQLDSGDNPYLKYNHADAEFQERLKEDTSFGGNLGRAVFQAKKRVLEPLGLVEEPVKTAPGKKRPVEHSPVEPDSSSGTGKAGQQPARKRLNFGQTGDADSVPDPQPLGQPPAAPSGLGTNTMATGSGAPMADNNEGADGVGNSSGNWHCDSTWMGDRVITTSTRTWALPTYNNHLYKQISSQSGASNDNHYFGYSTPWGYFDFNRFHCHFSPRDWQRLINNNWGFRPKRLNFKLFNIQVKEVTQNDGTTTIANNLTSTVQVFTDSEYQLPYVLGSAHQGCLPPFPADVFMVPQYGYLTLNNGSQAVGRSSFYCLEYFPSQMLRTGNNFTFSYTFEDVPFHSSYAHSQSLDRLMNPLIDQYLYYLSRTNTPSGTTTQSRLQFSQAGASDIRDQSRNWLPGPCYRQQRVSKTSADNNNSEYSWTGATKYHLNGRDSLVNPGPAMASHKDDEEKFFPQSGVLIFGKQGSEKTNVDIEKVMITDEEEIRTTNPVATEQYGSVSTNLQRGNRQAATADVNTQGVLPGMVWQDRDVYLQGPIWAKIPHTDGHFHPSPLMGGFGLKHPPPQILIKNTPVPANPSTTFSAAKFASFITQYSTGQVSVEIEWELQKENSKRWNPEIQYTSNYNKSVNVDFTVDTNGVYSEPRPIGTRYLTRNL')

Training PPO:   0%|          | 0/1 [15:33<?, ?it/s]
Training PPO:   0%|          | 0/1 [07:11<?, ?it/s]


KeyboardInterrupt: 

In [12]:
wt

'MAADGYLPDWLEDTLSEGIRQWWKLKPGPPPPKPAERHKDDSRGLVLPGYKYLGPFNGLDKGEPVNEADAAALEHDKAYDRQLDSGDNPYLKYNHADAEFQERLKEDTSFGGNLGRAVFQAKKRVLEPLGLVEEPVKTAPGKKRPVEHSPVEPDSSSGTGKAGQQPARKRLNFGQTGDADSVPDPQPLGQPPAAPSGLGTNTMATGSGAPMADNNEGADGVGNSSGNWHCDSTWMGDRVITTSTRTWALPTYNNHLYKQISSQSGASNDNHYFGYSTPWGYFDFNRFHCHFSPRDWQRLINNNWGFRPKRLNFKLFNIQVKEVTQNDGTTTIANNLTSTVQVFTDSEYQLPYVLGSAHQGCLPPFPADVFMVPQYGYLTLNNGSQAVGRSSFYCLEYFPSQMLRTGNNFTFSYTFEDVPFHSSYAHSQSLDRLMNPLIDQYLYYLSRTNTPSGTTTQSRLQFSQAGASDIRDQSRNWLPGPCYRQQRVSKTSADNNNSEYSWTGATKYHLNGRDSLVNPGPAMASHKDDEEKFFPQSGVLIFGKQGSEKTNVDIEKVMITDEEEIRTTNPVATEQYGSVSTNLQRGNRQAATADVNTQGVLPGMVWQDRDVYLQGPIWAKIPHTDGHFHPSPLMGGFGLKHPPPQILIKNTPVPANPSTTFSAAKFASFITQYSTGQVSVEIEWELQKENSKRWNPEIQYTSNYNKSVNVDFTVDTNGVYSEPRPIGTRYLTRNL'

In [9]:
with open('aav_wt.txt', 'r') as file:
    wt = file.readline().strip()

def make_env():
    # Provide your own initial sequence + fitness_fn
    return ProteinEnv(wt, fitness_ESM, 'aav_dms.csv')

vec_env = DummyVecEnv([make_env])

model = PPO(
    policy="MlpPolicy",
    env=vec_env,
    learning_rate=3e-4,
    n_steps=2048,
    batch_size=64,
    gae_lambda=0.95,
    gamma=0.99,
    n_epochs=10,
    clip_range=0.2,
    verbose=1,
)
total_timesteps = 1
tqdm_cb = TQDMCallback(total_timesteps=total_timesteps, algo='PPO')
model.learn(total_timesteps=total_timesteps, callback=tqdm_cb)
model.save("ppo_pretraining")

Using cpu device



[A

> [32m/tmp/ipykernel_960888/2807009226.py[39m([92m50[39m)[36mstep[39m[34m()[39m
[32m     48[39m     [38;5;28;01mdef[39;00m step(self, action):
[32m     49[39m         [38;5;28;01mimport[39;00m pdb;pdb.set_trace()
[32m---> 50[39m         pos, aa_idx = self._decode_action(action)
[32m     51[39m 
[32m     52[39m         [38;5;66;03m# Apply mutation[39;00m

10102
> [32m/tmp/ipykernel_960888/2807009226.py[39m([92m53[39m)[36mstep[39m[34m()[39m
[32m     51[39m 
[32m     52[39m         [38;5;66;03m# Apply mutation[39;00m
[32m---> 53[39m         new_state = self.state.copy()
[32m     54[39m         new_state[pos] = aa_idx
[32m     55[39m 

(np.int64(505), np.int64(2))
array([10,  0,  0,  2,  5, 19,  9, 12,  2, 18,  9,  3,  2, 16,  9, 15,  3,
        5,  7, 14, 13, 18, 18,  8,  9,  8, 12,  5, 12, 12, 12, 12,  8, 12,
        0,  3, 14,  6,  8,  2,  2, 15, 14,  5,  9, 17,  9, 12,  5, 19,  8,
       19,  9,  5, 12,  4, 11,  5,  9,  2,  8,  5,  3, 12, 17,