In [2]:
%cd ~/repo/protein-transfer

/home/t-fli/repo/protein-transfer


In [3]:
%load_ext blackcellmagic

In [3]:
import pandas as pd

In [21]:
df = pd.read_csv("data/proeng/thermo/mixed_split.csv")
print(len(df))

27951


In [19]:
df.head()

Unnamed: 0,sequence,target,set,validation
0,MSGEEEKAADFYVRYYVGHKGKFGHEFLEFEFRPNGSLRYANNSNY...,37.962947,train,
1,MSMGSDFYLRYYVGHKGKFGHEFLEFEFRPDGKLRYANNSNYKNDV...,54.425342,train,
2,MRICFLLLAFLVAETFANELTRCCAGGTRHFKNSNTCSSIKSEGTS...,49.459216,train,
3,MIRVALPTTASAIPRSISTSPGETISKNHEEEVKRVWRKADAVCFD...,42.593131,train,
4,MNGDWSRAFVLSKVKNLYFFVIIDKGFSAILNDPREPVQVGGFFEV...,37.999478,train,


In [4]:
"""Script for running pytorch models"""

from __future__ import annotations

import os
from tqdm import tqdm
from concurrent import futures

import torch
import torch.nn as nn

from sklearn.metrics import ndcg_score
from scipy.stats import spearmanr

from scr.params.sys import RAND_SEED, DEVICE
from scr.params.emb import TRANSFORMER_INFO, CARP_INFO

from scr.preprocess.data_process import split_protrain_loader
from scr.encoding.encoding_classes import get_emb_info, ESMEncoder, CARPEncoder
from scr.model.pytorch_model import LinearRegression
from scr.model.train_test import train, test
from scr.vis.learning_vis import plot_lc
from scr.utils import get_folder_file_names, pickle_save, get_default_output_path


class Run_Pytorch:
    def __init__(
        self,
        dataset_path: str,
        encoder_name: str,
        reset_param: bool = False,
        resample_param: bool = False,
        embed_batch_size: int = 128,
        flatten_emb: bool | str = False,
        embed_folder: str | None = None,
        seq_start_idx: bool | int = False,
        seq_end_idx: bool | int = False,
        loader_batch_size: int = 64,
        worker_seed: int = RAND_SEED,
        if_encode_all: bool = True,
        learning_rate: float = 1e-4,
        lr_decay: float = 0.1,
        epochs: int = 100,
        early_stop: bool = True,
        tolerance: int = 10,
        min_epoch: int = 5,
        device: torch.device | str = DEVICE,
        all_plot_folder: str = "results/learning_curves",
        all_result_folder: str = "results/train_val_test",
        **encoder_params,
    ) -> None:

        self._dataset_path = dataset_path
        self._encoder_name = encoder_name
        self._reset_param = reset_param
        self._resample_param = resample_param
        self._embed_batch_size = embed_batch_size
        self._flatten_emb = flatten_emb

        self._learning_rate = learning_rate
        self._lr_decay = lr_decay
        self._epochs = epochs
        self._early_stop = early_stop
        self._tolerance = tolerance
        self._min_epoch = min_epoch
        self._device = device
        self._all_plot_folder = all_plot_folder
        self._all_result_folder = all_result_folder
        self._encoder_params = encoder_params

        """
        A function for running pytorch model

        Args:
        - dataset_path: str, full path to the dataset, in pkl or panda readable format
            columns include: sequence, target, set, validation, mut_name (optional), mut_numb (optional)
        - encoder_name: str, the name of the encoder
        
        - embed_batch_size: int, set to 0 to encode all in a single batch
        - flatten_emb: bool or str, if and how (one of ["max", "mean"]) to flatten the embedding
        - embed_folder: str = None, path to presaved embedding
        - seq_start_idx: bool | int = False, the index for the start of the sequence
        - seq_end_idx: bool | int = False, the index for the end of the sequence
        - loader_batch_size: int, the batch size for train, val, and test dataloader
        - worker_seed: int, the seed for dataloader
        - learning_rate: float
        - lr_decay: float, factor by which to decay LR on plateau
        - epochs: int, number of epochs to train for
        - device: torch.device or str
        - all_plot_folder: str, the parent folder path for saving all the learning curves
        - all_result_folder: str = "results/train_val_test", the parent folder for all results
        - encoder_params: kwarg, additional parameters for encoding

        Returns:
        - result_dict: dict, with the keys and dict values
            "losses": {"train_losses": np.ndarray, "val_losses": np.ndarray}
            "train": {"mse": float, 
                    "pred": np.ndarray,
                    "true": np.ndarray,
                    "ndcg": float,
                    "rho": SpearmanrResults(correlation=float, pvalue=float)}
            "val":   {"mse": float, 
                    "pred": np.ndarray,
                    "true": np.ndarray,
                    "ndcg": float,
                    "rho": SpearmanrResults(correlation=float, pvalue=float)}
            "test":  {"mse": float, 
                    "pred": np.ndarray,
                    "true": np.ndarray,
                    "ndcg": float,
                    "rho": SpearmanrResults(correlation=float, pvalue=float)}

        """

        self._train_loader, self._val_loader, self._test_loader = split_protrain_loader(
            dataset_path=self._dataset_path,
            encoder_name=self._encoder_name,
            reset_param=self._reset_param,
            resample_param=self._resample_param,
            embed_batch_size=self._embed_batch_size,
            flatten_emb=self._flatten_emb,
            embed_folder=embed_folder,
            seq_start_idx=seq_start_idx,
            seq_end_idx=seq_end_idx,
            subset_list=["train", "val", "test"],
            loader_batch_size=loader_batch_size,
            worker_seed=worker_seed,
            if_encode_all=if_encode_all,
            **encoder_params,
        )

        encoder_name, encoder_class, total_emb_layer = get_emb_info(encoder_name)

        if encoder_class == ESMEncoder:
            self._encoder_info_dict = TRANSFORMER_INFO
        elif encoder_class == CARPEncoder:
            self._encoder_info_dict = CARP_INFO

        future_path = {}
        # add the thredpool max_workers=None
        with futures.ProcessPoolExecutor() as pool:
            # for each layer train the model and save the model
            for embed_layer in tqdm(range(total_emb_layer)):
                pool.submit(
                    self.run_pytorch_layer,
                    embed_layer
                )

    def run_pytorch_layer(self, embed_layer):
        model = LinearRegression(
            input_dim=self._encoder_info_dict[self._encoder_name][0], output_dim=1
        )
        model.to(self._device, non_blocking=True)

        criterion = nn.MSELoss()
        criterion.to(self._device, non_blocking=True)

        train_losses, val_losses = train(
            model=model,
            criterion=criterion,
            train_loader=self._train_loader,
            val_loader=self._val_loader,
            encoder_name=self._encoder_name,
            embed_layer=embed_layer,
            reset_param=self._reset_param,
            resample_param=self._resample_param,
            embed_batch_size=self._embed_batch_size,
            flatten_emb=self._flatten_emb,
            device=self._device,
            learning_rate=self._learning_rate,
            lr_decay=self._lr_decay,
            epochs=self._epochs,
            early_stop=self._early_stop,
            tolerance=self._tolerance,
            min_epoch=self._min_epoch,
            **self._encoder_params,
        )

        # record the losses
        result_dict = {
            "losses": {"train_losses": train_losses, "val_losses": val_losses}
        }

        plot_lc(
            train_losses=train_losses,
            val_losses=val_losses,
            dataset_path=self._dataset_path,
            encoder_name=self._encoder_name,
            embed_layer=embed_layer,
            flatten_emb=self._flatten_emb,
            all_plot_folder=self._all_plot_folder,
        )

        # now test the model with the test data
        for subset, loader in zip(
            ["train", "val", "test"],
            [self._train_loader, self._val_loader, self._test_loader],
        ):
            mse, pred, true = test(
                model=model, loader=loader, device=self._device, criterion=criterion
            )

            result_dict[subset] = {
                "mse": mse,
                "pred": pred,
                "true": true,
                "ndcg": ndcg_score(true[None, :], pred[None, :]),
                "rho": spearmanr(true, pred),
            }

        dataset_subfolder, file_name = get_folder_file_names(
            parent_folder=get_default_output_path(self._all_result_folder),
            dataset_path=self._dataset_path,
            encoder_name=self._encoder_name,
            embed_layer=embed_layer,
            flatten_emb=self._flatten_emb,
        )

        print(f"Saving results for {file_name} to: {dataset_subfolder}...")
        pickle_save(
            what2save=result_dict,
            where2save=os.path.join(dataset_subfolder, file_name + ".pkl"),
        )


def run_pytorch(
    dataset_path: str,
    encoder_name: str,
    reset_param: bool = False,
    resample_param: bool = False,
    embed_batch_size: int = 128,
    flatten_emb: bool | str = False,
    embed_folder: str | None = None,
    seq_start_idx: bool | int = False,
    seq_end_idx: bool | int = False,
    loader_batch_size: int = 64,
    worker_seed: int = RAND_SEED,
    if_encode_all: bool = True,
    learning_rate: float = 1e-4,
    lr_decay: float = 0.1,
    epochs: int = 100,
    early_stop: bool = True,
    tolerance: int = 10,
    min_epoch: int = 5,
    device: torch.device | str = DEVICE,
    all_plot_folder: str = "results/learning_curves",
    all_result_folder: str = "results/train_val_test",
    **encoder_params,
) -> dict:

    """
    A function for running pytorch model

    Args:
    - dataset_path: str, full path to the dataset, in pkl or panda readable format
        columns include: sequence, target, set, validation, mut_name (optional), mut_numb (optional)
    - encoder_name: str, the name of the encoder
    
    - embed_batch_size: int, set to 0 to encode all in a single batch
    - flatten_emb: bool or str, if and how (one of ["max", "mean"]) to flatten the embedding
    - embed_folder: str = None, path to presaved embedding
    - seq_start_idx: bool | int = False, the index for the start of the sequence
    - seq_end_idx: bool | int = False, the index for the end of the sequence
    - loader_batch_size: int, the batch size for train, val, and test dataloader
    - worker_seed: int, the seed for dataloader
    - learning_rate: float
    - lr_decay: float, factor by which to decay LR on plateau
    - epochs: int, number of epochs to train for
    - device: torch.device or str
    - all_plot_folder: str, the parent folder path for saving all the learning curves
    - all_result_folder: str = "results/train_val_test", the parent folder for all results
    - encoder_params: kwarg, additional parameters for encoding

    Returns:
    - result_dict: dict, with the keys and dict values
        "losses": {"train_losses": np.ndarray, "val_losses": np.ndarray}
        "train": {"mse": float, 
                  "pred": np.ndarray,
                  "true": np.ndarray,
                  "ndcg": float,
                  "rho": SpearmanrResults(correlation=float, pvalue=float)}
        "val":   {"mse": float, 
                  "pred": np.ndarray,
                  "true": np.ndarray,
                  "ndcg": float,
                  "rho": SpearmanrResults(correlation=float, pvalue=float)}
        "test":  {"mse": float, 
                  "pred": np.ndarray,
                  "true": np.ndarray,
                  "ndcg": float,
                  "rho": SpearmanrResults(correlation=float, pvalue=float)}

    """

    train_loader, val_loader, test_loader = split_protrain_loader(
        dataset_path=dataset_path,
        encoder_name=encoder_name,
        reset_param=reset_param,
        resample_param=resample_param,
        embed_batch_size=embed_batch_size,
        flatten_emb=flatten_emb,
        embed_folder=embed_folder,
        seq_start_idx=seq_start_idx,
        seq_end_idx=seq_end_idx,
        subset_list=["train", "val", "test"],
        loader_batch_size=loader_batch_size,
        worker_seed=worker_seed,
        if_encode_all=if_encode_all,
        **encoder_params,
    )

    encoder_name, encoder_class, total_emb_layer = get_emb_info(encoder_name)

    if encoder_class == ESMEncoder:
        encoder_info_dict = TRANSFORMER_INFO
    elif encoder_class == CARPEncoder:
        encoder_info_dict = CARP_INFO

    # for each layer train the model and save the model
    for embed_layer in range(total_emb_layer):

        pass

In [5]:
Run_Pytorch(
    dataset_path="data/proeng/thermo/mixed_split.csv",
    encoder_name="esm1b_t33_650M_UR50S",
    reset_param = False,
    resample_param = False,
    embed_batch_size = 128,
    flatten_emb= "mean",
    embed_folder= "embeddings/proeng/thermo/mixed_split",
    seq_start_idx= False,
    seq_end_idx = False,
    loader_batch_size = 64,
    worker_seed = RAND_SEED,
    if_encode_all = False,
    learning_rate = 1e-4,
    lr_decay = 0.1,
    epochs = 100,
    early_stop = True,
    tolerance = 10,
    min_epoch = 5,
    device = DEVICE,
    all_plot_folder = "test/learning_curves",
    all_result_folder = "test/train_val_test",
    # **encoder_params,
    )

Generating esm1b_t33_650M_UR50S upto 33 layer embedding ...


Using cache found in /home/t-fli/.cache/torch/hub/facebookresearch_esm_main


Generating esm1b_t33_650M_UR50S upto 33 layer embedding ...


Using cache found in /home/t-fli/.cache/torch/hub/facebookresearch_esm_main


Generating esm1b_t33_650M_UR50S upto 33 layer embedding ...


Using cache found in /home/t-fli/.cache/torch/hub/facebookresearch_esm_main
100%|██████████| 34/34 [00:02<00:00, 15.72it/s]


Running model for layer 0




Running model for layer 1




Running model for layer 2




Running model for layer 3




Running model for layer 4




Running model for layer 5




Running model for layer 6




Running model for layer 7




Running model for layer 8




Running model for layer 9




Running model for layer 10




Running model for layer 11




Running model for layer 12




Running model for layer 13
Running model for layer 14Running model for layer 15





KeyboardInterrupt: 