In [1]:
%cd ~/repo/protein-transfer

/home/t-fli/repo/protein-transfer


In [2]:
%load_ext blackcellmagic

In [28]:
import numpy as np
import pandas as pd

In [22]:
pd.read_csv("data/structure/secondary_structure/casp12.csv").target.dtype.kind 

'O'

In [23]:
pd.read_csv("data/structure/secondary_structure/casp12.csv").target[0]

'[2, 2, 2, 0, 0, 0, 0, 0, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2]'

In [25]:
df = pd.read_csv("data/annotation/scl/balanced.csv")
df.head()

Unnamed: 0,sequence,target,set,validation
0,MEVLEEPAPGPGGADAAERRGLRRLLLSGFQEELRALLVLAGPAFL...,Cell membrane,train,
1,MMKTLSSGNCTLNVPAKNSYRMVVLGASRVGKSSIVSRFLNGRFED...,Cell membrane,train,
2,MAKRTFSNLETFLIFLLVMMSAITVALLSLLFITSGTIENHKDLGG...,Cell membrane,train,
3,MGNCQAGHNLHLCLAHHPPLVCATLILLLLGLSGLGLGSFLLTHRT...,Cell membrane,train,
4,MDPSKQGTLNRVENSVYRTAFKLRSVQTLCQLDLMDSFLIQQVLWR...,Cell membrane,train,


In [6]:
df.validation.unique(), df.set.unique()

(array([nan, True], dtype=object), array(['train', 'test'], dtype=object))

In [7]:
df.target.unique()

array(['Cell membrane', 'Cytoplasm', 'Endoplasmic reticulum',
       'Golgi apparatus', 'Lysosome/Vacuole', 'Mitochondrion', 'Nucleus',
       'Peroxisome', 'Plastid', 'Extracellular'], dtype=object)

In [27]:
df.target.nunique()

10

In [36]:
df2 = pd.read_csv("data/structure/secondary_structure/casp12.csv")
len(np.unique(np.array(df2["target"][0][1:-1].split(", "))))
# df2["target"] = df2["target"].apply(lambda x: np.array(x[1:-1].split(", ")))

3

In [45]:
class DatasetInfo:
    """
    A class returns the information of a dataset
    """

    def __init__(self, dataset_path: str) -> None:
        """
        Args:
        - dataset_path: str, the path for the csv
        """
        self._df = pd.read_csv(dataset_path)

    def get_model_type(self) -> str:
        # pick linear regression if y numerical
        if self._df.target.dtype.kind in "iufc":
            return "LinearRegression"
        else:
            # ss3
            if "[" in self._df.target[0]:
                return "MultiLabelMultiClass"
            # annotation
            else:
                return "LinearClassifier"

    def get_numb_class(self) -> int:
        """
        A function to get number of class
        """
        if self.model_type == "LinearClassifier":
            return self._df.target.nunique()
        elif self.model_type == "MultiLabelMultiClass":
            return len(np.unique(np.array(self._df["target"][0][1:-1].split(", ")))) + 1

    @property
    def model_type(self):
        """Return the pytorch model type"""
        return self.get_model_type()

    @property
    def numb_class(self):
        """Return number of classes for classification"""
        return self.get_numb_class()

In [46]:
"[" in df2["target"][0]

True

In [49]:
DatasetInfo("data/annotation/scl/balanced.csv").model_type, DatasetInfo("data/annotation/scl/balanced.csv").numb_class

('LinearClassifier', 10)

In [50]:
DatasetInfo("data/structure/secondary_structure/casp12.csv").model_type, DatasetInfo("data/structure/secondary_structure/casp12.csv").numb_class

('MultiLabelMultiClass', 4)

In [7]:
"""Pre-processing the dataset"""

from __future__ import annotations

from collections import Sequence, defaultdict

import os
from glob import glob
import tables
import pandas as pd
import numpy as np

from sklearn.preprocessing import LabelEncoder

import torch
from torch.utils.data import Dataset, DataLoader

from scr.utils import pickle_save, pickle_load, replace_ext
from scr.params.sys import RAND_SEED
from scr.params.emb import TRANSFORMER_INFO, CARP_INFO, MAX_SEQ_LEN
from scr.preprocess.seq_loader import SeqLoader
from scr.encoding.encoding_classes import (
    AbstractEncoder,
    ESMEncoder,
    CARPEncoder,
    OnehotEncoder,
)


def get_mut_name(mut_seq: str, parent_seq: str) -> str:
    """
    A function for returning the mutant name

    Args:
    - mut_seq: str, the full mutant sequence
    - parent_seq: str, the full parent sequence

    Returns:
    - str, parent, indel, or mutant name in the format of
        ParentAAMutLocMutAA:ParentAAMutLocMutAA:..., ie. W39W:D40G:G41C:V54Q
    """

    mut_list = []
    if parent_seq == mut_seq:
        return "parent"
    elif len(parent_seq) == len(mut_seq):
        for i, (p, m) in enumerate(zip(list(parent_seq), list(mut_seq))):
            if p != m:
                mut_list.append(f"{p}{i+1}{m}")
        return ":".join(mut_list)
    else:
        return "indel"


class AddMutInfo:
    """A class for appending mutation info for mainly protein engineering tasks"""

    def __init__(self, parent_seq_path: str, csv_path: str):

        """
        Args:
        - parent_seq_path: str, path for the parent sequence
        - csv_path: str, path for the fitness csv file
        """

        # Load the parent sequence from the fasta file
        self._parent_seq = SeqLoader(parent_seq_path=parent_seq_path)

        # load the dataframe
        self._init_df = pd.read_csv(csv_path)

        self._df = self._init_df.copy()
        # add a column with the mutant names
        self._df["mut_name"] = self._init_df["sequence"].apply(
            get_mut_name, parent_seq=self._parent_seq
        )
        # add a column with the number of mutations
        self._df["mut_numb"] = (
            self._df["mut_name"].str.split(":").map(len, na_action="ignore")
        )

        # get the pickle file path
        self._pkl_path = replace_ext(input_path=csv_path, ext=".pkl")

        pickle_save(what2save=self._df, where2save=self._pkl_path)

    @property
    def parent_seq(self) -> str:
        """Return the parent sequence"""
        return self._parent_seq

    @property
    def pkl_path(self) -> str:
        """Return the pkl file path for the processed dataframe"""
        return self._pkl_path

    @property
    def df(self) -> pd.DataFrame:
        """Return the processed dataframe"""
        return self._df


def std_ssdf(
    ssdf_path: str = "data/structure/secondary_structure/tape_ss3.csv",
) -> None:
    """
    A function that standardize secondary structure dataset
    to set up as columns as sequence, target, set, validation
    where set is train or test and add validation as true
    """
    folder_path = os.path.dirname(ssdf_path)

    df = pd.read_csv(ssdf_path)

    # convert the string into numpy array
    # df["ss3"] = df["ss3"].apply(lambda x: np.array(x[1:-1].split(", ")))

    # add validation column
    df["validation"] = df["split"].apply(lambda x: True if x == "valid" else "")
    # now replace valid to train
    df = df.replace("valid", "train")
    # rename all columns
    df.columns = ["sequence", "target", "set", "validation"]
    
    # get all kinds of test sets
    ss_tests = set(df["set"].unique()) - set(["train"])

    for ss_test in ss_tests:
        df.loc[~df["set"].isin(set(ss_tests) - set([ss_test]))].replace(
            ss_test, "test"
        ).to_csv(os.path.join(folder_path, ss_test + ".csv"), index=False)


class DatasetInfo:
    """
    A class returns the information of a dataset
    """

    def __init__(self, dataset_path: str) -> None:
        """
        Args:
        - dataset_path: str, the path for the csv
        """
        self._df = pd.read_csv(dataset_path)

    def get_model_type(self) -> str:
        # pick linear regression if y numerical
        if self._df.target.dtype.kind in "iufc":
            return "LinearRegression"
        else:
            # ss3
            if "[" in self._df.target[0]:
                return "MultiLabelMultiClass"
            # annotation
            else:
                return "LinearClassifier"

    def get_numb_class(self) -> int:
        """
        A function to get number of class
        """
        # annotation class number
        if self.model_type == "LinearClassifier":
            return self._df.target.nunique()
        # ss3 or ss8 secondary structure states plus padding
        elif self.model_type == "MultiLabelMultiClass":
            return len(np.unique(np.array(self._df["target"][0][1:-1].split(", ")))) + 1

    @property
    def model_type(self) -> str:
        """Return the pytorch model type"""
        return self.get_model_type()

    @property
    def numb_class(self) -> int:
        """Return number of classes for classification"""
        return self.get_numb_class()


class TaskProcess:
    """A class for handling different downstream tasks"""

    def __init__(self, data_folder: str = "data/"):
        """
        Args:
        - data_folder: str, a folder path with all the tasks as subfolders where
            all the subfolders have datasets as the subsubfolders, ie

            {data_folder}/
                proeng/
                    aav/
                        one_vs_many.csv
                        two_vs_many.csv
                        P03135.fasta
                    thermo/
                        mixed.csv
        """

        if data_folder[-1] == "/":
            self._data_folder = data_folder
        else:
            self._data_folder = data_folder + "/"

        # sumamarize all files i nthe data folder
        self._sum_file_df = self.sum_files()

    def sum_files(self) -> pd.DataFrame:
        """
        Summarize all files in the data folder

        Returns:
        - A dataframe with "task", "dataset", "split",
            "csv_path", "fasta_path", "pkl_path" as columns, ie.
            (proeng, gb1, low_vs_high, data/proeng/gb1/low_vs_high.csv,
            data/proeng/gb1/5LDE_1.fasta)
            note that csv_path is the list of lmdb files for the structure task
        """
        dataset_folders = glob(f"{self._data_folder}*/*")
        # need a list of tuples in the order of:
        # (task, dataset, split, csv_path, fasta_path)
        list_for_df = []
        for dataset_folder in dataset_folders:
            _, task, dataset = dataset_folder.split("/")
            if task == "structure":
                structure_file_list = [
                    file_path
                    for file_path in glob(f"{dataset_folder}/*.*")
                    if os.path.basename(os.path.splitext(file_path)[0]).split("_")[-1]
                    in ["train", "valid", "cb513"]
                ]
                list_for_df.append(
                    tuple([task, dataset, "cb513", structure_file_list, "", ""])
                )
            else:
                csv_paths = glob(f"{dataset_folder}/*.csv")
                fasta_paths = glob(f"{dataset_folder}/*.fasta")
                pkl_paths = glob(f"{dataset_folder}/*.pkl")

                assert len(csv_paths) >= 1, "Less than one csv"
                assert len(fasta_paths) <= 1, "More than one fasta"

                for csv_path in csv_paths:
                    # if parent seq fasta exists
                    if len(fasta_paths) == 1:
                        fasta_path = fasta_paths[0]

                        # if no existing pkl file, generate and save
                        if len(pkl_paths) == 0:
                            print(f"Adding mutation info to {csv_path}...")
                            pkl_path = AddMutInfo(
                                parent_seq_path=fasta_path, csv_path=csv_path
                            ).pkl_path
                        # pkl file exits
                        else:
                            pkl_path = replace_ext(input_path=csv_path, ext=".pkl")
                    # no parent fasta no pkl file
                    else:
                        fasta_path = ""
                        pkl_path = ""

                    list_for_df.append(
                        tuple(
                            [
                                task,
                                dataset,
                                os.path.basename(os.path.splitext(csv_path)[0]),
                                csv_path,
                                fasta_path,
                                pkl_path,
                            ]
                        )
                    )

        return pd.DataFrame(
            list_for_df,
            columns=["task", "dataset", "split", "csv_path", "fasta_path", "pkl_path"],
        )

    @property
    def sum_file_df(self) -> pd.DataFrame:
        """A summary table for all files in the data folder"""
        return self._sum_file_df


class ProtranDataset(Dataset):

    """A dataset class for processing protein transfer data"""

    def __init__(
        self,
        dataset_path: str,
        subset: str,
        encoder_name: str,
        reset_param: bool = False,
        resample_param: bool = False,
        embed_batch_size: int = 0,
        flatten_emb: bool | str = False,
        embed_folder: str = None,
        seq_start_idx: bool | int = False,
        seq_end_idx: bool | int = False,
        if_encode_all: bool = True,
        **encoder_params,
    ):

        """
        Args:
        - dataset_path: str, full path to the dataset, in pkl or panda readable format, ie
            "data/proeng/gb1/low_vs_high.csv"
            columns include: sequence, target, set, validation,
            mut_name (optional), mut_numb (optional)
        - subset: str, train, val, test
        - encoder_name: str, the name of the encoder
        - reset_param: bool = False, if update the full model to xavier_uniform_
        - resample_param: bool = False, if update the full model to xavier_normal_
        - embed_batch_size: int, set to 0 to encode all in a single batch
        - flatten_emb: bool or str, if and how (one of ["max", "mean"]) to flatten the embedding
        - embed_folder: str = None, path to presaved embedding folder, ie
            "embeddings/proeng/gb1/low_vs_high"
            for which then can add the subset to be, ie
            "embeddings/proeng/gb1/low_vs_high/esm1_t6_43M_UR50S/mean/test/embedding.h5"
        - seq_start_idx: bool | int = False, the index for the start of the sequence
        - seq_end_idx: bool | int = False, the index for the end of the sequence
        - if_encode_all: bool = True, if encode full dataset all layers on the fly
        - encoder_params: kwarg, additional parameters for encoding
        """

        # with additional info mut_name, mut_numb
        if os.path.splitext(dataset_path)[-1] in [".pkl", ".PKL", ""]:
            self._df = pickle_load(dataset_path)
            self._add_mut_info = True
        # without such info
        else:
            self._df = pd.read_csv(dataset_path)
            self._ds_info = DatasetInfo(dataset_path)
            self._model_type = self._ds_info.model_type
            self._numb_class = self._ds_info.numb_class

            self._add_mut_info = False

        assert "set" in self._df.columns, f"set is not a column in {dataset_path}"
        assert (
            "validation" in self._df.columns
        ), f"validation is not a column in {dataset_path}"

        self._df_train = self._df.loc[
            (self._df["set"] == "train") & (self._df["validation"] != True)
        ]
        self._df_val = self._df.loc[
            (self._df["set"] == "train") & (self._df["validation"] == True)
        ]
        self._df_test = self._df.loc[(self._df["set"] == "test")]

        self._df_dict = {
            "train": self._df_train,
            "val": self._df_val,
            "test": self._df_test,
        }

        assert subset in list(
            self._df_dict.keys()
        ), "split can only be 'train', 'val', or 'test'"
        self._subset = subset

        self._subdf_len = len(self._df_dict[self._subset])

        # not specified seq start will be from 0
        if seq_start_idx == False:
            self._seq_start_idx = 0
        else:
            self._seq_start_idx = int(seq_start_idx)
        # not specified seq end will be the full sequence length
        if seq_end_idx == False:
            self._seq_end_idx = -1
            self._max_seq_len = self._df.sequence.str.len().max()
        else:
            self._seq_end_idx = int(seq_end_idx)
            self._max_seq_len = self._seq_end_idx - self._seq_start_idx

        # get unencoded string of input sequence
        # will need to convert data type
        self.sequence = self._get_column_value("sequence")

        self.if_encode_all = if_encode_all
        self._embed_folder = embed_folder

        self._encoder_name = encoder_name
        self._flatten_emb = flatten_emb

        # get the encoder class
        if self._encoder_name in TRANSFORMER_INFO.keys():
            encoder_class = ESMEncoder
        elif self._encoder_name in CARP_INFO.keys():
            encoder_class = CARPEncoder
        else:
            encoder_class = OnehotEncoder
            encoder_params["max_seq_len"] = self._max_seq_len

        # get the encoder
        self._encoder = encoder_class(
            encoder_name=self._encoder_name,
            reset_param=reset_param,
            resample_param=resample_param,
            **encoder_params,
        )
        self._total_emb_layer = self._encoder.total_emb_layer

        # encode all and load in memory
        if self.if_encode_all and self._embed_folder is None:
            # encode the sequences without the mut_name
            # init an empty dict with empty list to append emb
            encoded_dict = defaultdict(list)

            # use the encoder generator for batch emb
            # assume no labels included
            for encoded_batch_dict in self._encoder.encode(
                mut_seqs=self.sequence,
                batch_size=embed_batch_size,
                flatten_emb=self._flatten_emb,
            ):

                for layer, emb in encoded_batch_dict.items():
                    encoded_dict[layer].append(emb)

            # assign each layer as its own variable
            for layer, emb in encoded_dict.items():
                setattr(
                    self,
                    "layer" + str(layer),
                    np.vstack(emb)
                    # torch.tensor(np.vstack(emb), dtype=torch.float32),
                )

        # get and format the fitness or secondary structure values
        # can be numbers or string
        # will need to convert data type
        # make 1D tensor 2D
        self.y = np.expand_dims(self._get_column_value("target"), 1)

        # add mut_name and mut_numb for relevant proeng datasets
        if self._add_mut_info:
            self.mut_name = self._get_column_value("mut_name")
            self.mut_numb = self._get_column_value("mut_numb")
        else:
            self.mut_name = [""] * self._subdf_len
            self.mut_numb = [np.nan] * self._subdf_len

    def __len__(self):
        """Return the length of the selected subset of the dataframe"""
        return self._subdf_len

    def __getitem__(self, idx: int):

        """
        Return the item in the order of
        target (y), sequence, mut_name (optional), mut_numb (optional),
        embedding per layer upto the max number of layer for the encoder

        Args:
        - idx: int
        """
        if self.if_encode_all and self._embed_folder is None:
            return (
                self.y[idx],
                self.sequence[idx],
                self.mut_name[idx],
                self.mut_numb[idx],
                *(
                    getattr(self, "layer" + str(layer))[idx]
                    for layer in range(self._total_emb_layer)
                ),
            )
        elif self._embed_folder is not None:
            # load the .h5 file with the embeddings
            """
            gb1_emb = tables.open_file("embeddings/proeng/gb1/low_vs_high/esm1_t6_43M_UR50S/mean/test/embedding.h5")
            gb1_emb.flush()
            gb1_emb.root.layer0[0:5]
            """
            emb_table = tables.open_file(
                os.path.join(
                    self._embed_folder,
                    self._encoder_name,
                    self._flatten_emb,
                    self._subset,
                    "embedding.h5",
                )
            )

            emb_table.flush()

            layer_embs = [
                getattr(emb_table.root, "layer" + str(layer))[idx]
                for layer in range(self._total_emb_layer)
            ]

            emb_table.close()

            return (
                self.y[idx],
                self.sequence[idx],
                self.mut_name[idx],
                self.mut_numb[idx],
                layer_embs,
            )
        else:
            return (
                self.y[idx],
                self.sequence[idx],
                self.mut_name[idx],
                self.mut_numb[idx],
            )

    def _get_column_value(self, column_name: str) -> np.ndarray:
        """
        Check and return the column values of the selected dataframe subset

        Args:
        - column_name: str, the name of the dataframe column
        """
        if column_name in self._df.columns:
            y = self._df_dict[self._subset][column_name]

            if column_name == "sequence":
                return (
                    # self._df_dict[self._subset]["sequence"]
                    y
                    .astype(str)
                    .str[self._seq_start_idx : self._seq_end_idx]
                    .apply(
                        lambda x: x[: int(MAX_SEQ_LEN // 2)]
                        + x[-int(MAX_SEQ_LEN // 2) :]
                        if len(x) > MAX_SEQ_LEN
                        else x
                    )
                    .values
                )
            elif column_name == "target":
                if self._model_type == "LinearClassifier":
                    print("Converting classes into int...")
                    le = LabelEncoder()
                    return le.fit_transform(y.values.flatten())
                elif self._model_type == "MultiLabelMultiClass":
                    print("Converting ss3/ss8 into np.array...")
                    return y.apply(lambda x: np.array(x[1:-1].split(", "))).values
            else:
                return y.values

    @property
    def df_full(self) -> pd.DataFrame:
        """Return the full loaded dataset"""
        return self._df

    @property
    def df_train(self) -> pd.DataFrame:
        """Return the dataset for training only"""
        return self._df_train

    @property
    def df_val(self) -> pd.DataFrame:
        """Return the dataset for validation only"""
        return self._df_val

    @property
    def df_test(self) -> pd.DataFrame:
        """Return the dataset for training only"""
        return self._df_test

    @property
    def max_seq_len(self) -> int:
        """Longest sequence length"""
        return self._max_seq_len


def split_protrain_loader(
    dataset_path: str,
    encoder_name: str,
    reset_param: bool = False,
    resample_param: bool = False,
    embed_batch_size: int = 128,
    flatten_emb: bool | str = False,
    embed_folder: str | None = None,
    seq_start_idx: bool | int = False,
    seq_end_idx: bool | int = False,
    subset_list: list[str] = ["train", "val", "test"],
    loader_batch_size: int = 64,
    worker_seed: int = RAND_SEED,
    if_encode_all: bool = True,
    **encoder_params,
):

    """
    A function encode and load the data from a path

    Args:
    - dataset_path: str, full path to the dataset, in pkl or panda readable format, ie
        "data/proeng/gb1/low_vs_high.csv"
        columns include: sequence, target, set, validation,
        mut_name (optional), mut_numb (optional)
    - encoder_name: str, the name of the encoder
    - reset_param: bool = False, if update the full model to xavier_uniform_
    - resample_param: bool = False, if update the full model to xavier_normal_
    - embed_batch_size: int, set to 0 to encode all in a single batch
    - flatten_emb: bool or str, if and how (one of ["max", "mean"]) to flatten the embedding
    - embed_folder: str = None, path to presaved embedding
    - seq_start_idx: bool | int = False, the index for the start of the sequence
    - seq_end_idx: bool | int = False, the index for the end of the sequence
    - subset_list: list of str, train, val, test
    - loader_batch_size: int, the batch size for train, val, and test dataloader
    - worker_seed: int, the seed for dataloader
    - if_encode_all: bool = True, if encode full dataset all layers on the fly
    - encoder_params: kwarg, additional parameters for encoding
    """

    assert set(subset_list) <= set(
        ["train", "val", "test"]
    ), "subset_list can only contain terms with in be 'train', 'val', or 'test'"

    # specify no shuffling for validation and test
    if_shuffle_list = [True if subset == "train" else False for subset in subset_list]

    return (
        DataLoader(
            dataset=ProtranDataset(
                dataset_path=dataset_path,
                subset=subset,
                encoder_name=encoder_name,
                reset_param=reset_param,
                resample_param=resample_param,
                embed_batch_size=embed_batch_size,
                flatten_emb=flatten_emb,
                embed_folder=embed_folder,
                seq_start_idx=seq_start_idx,
                seq_end_idx=seq_end_idx,
                if_encode_all=if_encode_all,
                **encoder_params,
            ),
            batch_size=loader_batch_size,
            shuffle=if_shuffle,
            worker_init_fn=worker_seed,
        )
        for subset, if_shuffle in zip(subset_list, if_shuffle_list)
    )

In [4]:
"""A script with model training and testing details assuming"""

from __future__ import annotations

import random
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm

from sklearn.metrics import mean_squared_error, log_loss, accuracy_score, roc_auc_score
from sklearn.preprocessing import LabelEncoder

from scr.params.sys import RAND_SEED, DEVICE

# seed everything
random.seed(RAND_SEED)
np.random.seed(RAND_SEED)
torch.manual_seed(RAND_SEED)
torch.cuda.manual_seed(RAND_SEED)
torch.cuda.manual_seed_all(RAND_SEED)
torch.backends.cudnn.deterministic = True


def get_x_y(
    model, device, batch, embed_layer: int,
):
    """
    A function process x and y from the loader

    Args:
    -

    Returns:
    - x
    - y
    """
    print(batch)
    # for each batch: y, sequence, mut_name, mut_numb, [layer0, ...]
    x = batch[4][embed_layer]
    y = batch[0]

    """
    # process y depends on model type
    # annotation classification
    if model.model_name == "LinearClassifier":
        le = LabelEncoder()
        y = le.fit_transform(y.flatten())
    """
    
    # ss3 / ss8 type
    if model.model_name == "MultiLabelMultiClass":
        # convert the y into np.arrays with -1 padding to the same length
        y = np.stack(
            [
                np.pad(
                    i,
                    pad_width=(0, x.shape[1] - len(i)),
                    constant_values=-1,
                )
                for i in y
            ]
        )

    return x.to(device, non_blocking=True), y.to(device, non_blocking=True)


def run_epoch(
    model: nn.Module,
    loader: DataLoader,
    encoder_name: str,
    embed_layer: int,
    reset_param: bool = False,
    resample_param: bool = False,
    embed_batch_size: int = 0,
    flatten_emb: bool | str = False,
    # if_encode_all: bool = True,
    device: torch.device | str = DEVICE,
    criterion: nn.Module | None = None,
    optimizer: torch.optim.Optimizer | None = None,
    **encoder_params,
) -> float:

    """
    Runs one epoch.
    
    Args:
    - model: nn.Module, already moved to device
    - loader: torch.utils.data.DataLoader
    - device: torch.device or str
    - criterion: optional nn.Module, loss function, already moved to device
    - optimizer: optional torch.optim.Optimizer, must also provide criterion,
        only provided for training

    Returns: 
    - float, average loss over batches
    """
    if optimizer is not None:
        assert criterion is not None
        model.train()
        is_train = True
    else:
        model.eval()
        is_train = False

    cum_loss = 0.0

    with torch.set_grad_enabled(is_train):
        # if not if_encode_all:
        # for each batch: y, sequence, mut_name, mut_numb, [layer0, ...]
        for batch in loader:
            x, y = get_x_y(model, device, batch, embed_layer)
            print(x, y)

            """
            x = batch[4][embed_layer]
            y = batch[0]

            # process y depends on model type
            # annotation classification
            if model.model_name == "LinearClassifier":
                le = LabelEncoder()
                y = le.fit_transform(y.flatten())
            # ss3 / ss8 type
            elif model.model_name == "MultiLabelMultiClass":
                # convert the y into np.arrays with -1 padding to the same length
                y = np.stack(
                    [
                        np.pad(
                            np.array(i[1:-1].split(", ")),
                            pad_width=(0, x.shape[1] - len(i)),
                            constant_values=-1,
                        )
                        for i in y
                    ]
                )

            x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
            """

            outputs = model(x)

            if criterion is not None:
                loss = criterion(outputs, y.float())

                if optimizer is not None:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                cum_loss += loss.item()

    return cum_loss / len(loader)


def train(
    model: nn.Module,
    criterion: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    encoder_name: str,
    embed_layer: int,
    reset_param: bool = False,
    resample_param: bool = False,
    embed_batch_size: int = 0,
    flatten_emb: bool | str = False,
    device: torch.device | str = DEVICE,
    learning_rate: float = 1e-4,
    lr_decay: float = 0.1,
    epochs: int = 100,
    early_stop: bool = True,
    tolerance: int = 10,
    min_epoch: int = 5,
    **encoder_params,
) -> tuple[np.ndarray, np.ndarray]:

    """
    Args:
    - model: nn.Module, already moved to device
    - train_loader: torch.utils.data.DataLoader, 
    - val_loader: torch.utils.data.DataLoader, 
    - criterion: nn.Module, loss function, already moved to device
    - device: torch.device or str
    - learning_rate: float
    - lr_decay: float, factor by which to decay LR on plateau
    - epochs: int, number of epochs to train for
    - early_stop: bool = True,

    Returns: 
    - tuple of np.ndarray, (train_losses, val_losses)
        train/val_losses: np.ndarray, shape [epochs], entries are average loss
        over batches for that epoch
    """

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=lr_decay
    )

    train_losses = np.zeros(epochs)
    val_losses = np.zeros(epochs)

    # init for early stopping
    counter = 0
    min_val_loss = np.Inf

    for epoch in tqdm(range(epochs)):

        train_losses[epoch] = run_epoch(
            model=model,
            loader=train_loader,
            encoder_name=encoder_name,
            embed_layer=embed_layer,
            reset_param=reset_param,
            resample_param=resample_param,
            embed_batch_size=embed_batch_size,
            flatten_emb=flatten_emb,
            device=device,
            criterion=criterion,
            optimizer=optimizer,
            **encoder_params,
        )

        val_loss = run_epoch(
            model=model,
            loader=val_loader,
            encoder_name=encoder_name,
            embed_layer=embed_layer,
            reset_param=reset_param,
            resample_param=resample_param,
            embed_batch_size=embed_batch_size,
            flatten_emb=flatten_emb,
            device=device,
            criterion=criterion,
            optimizer=None,
            **encoder_params,
        )
        val_losses[epoch] = val_loss

        scheduler.step(val_loss)

        if early_stop:
            # when val loss decrease, reset min loss and counter
            if val_loss < min_val_loss:
                min_val_loss = val_loss
                counter = 0
            else:
                counter += 1

            if epoch > min_epoch and counter == tolerance:
                break

    return train_losses, val_losses


def test(
    model: nn.Module,
    loader: DataLoader,
    embed_layer: int,
    criterion: nn.Module | None,
    device: torch.device | str = DEVICE,
    print_every: int = 1000,
) -> tuple[float, np.ndarray, np.ndarray]:
    """
    Runs one epoch of testing, returning predictions and labels.
    
    Args:
    - model: nn.Module, already moved to device
    - device: torch.device or str
    - loader: torch.utils.data.DataLoader
    - criterion: optional nn.Module, loss function, already moved to device
    - print_every: int, how often (number of batches) to print avg loss
    
    Returns: tuple (avg_loss, preds, labels)
    - avg_loss: float, average loss per training example 
    - preds: np.ndarray, shape [num_examples, ...], predictions over dataset
    - labels: np.ndarray, shape [num_examples, ...], dataset labels
    """
    model.eval()
    msg = "[{step:5d}] loss: {loss:.3f}"

    cum_loss = 0.0

    pred_probs = []
    pred_classes = []
    labels = []

    with torch.no_grad():

        for i, batch in enumerate(tqdm(loader)):
            # for each batch: y, sequence, mut_name, mut_numb, [layer0, ...]
            x, y = get_x_y(model, device, batch, embed_layer)

            """
            x = batch[4][embed_layer]
            y = batch[0]

            x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
            """

            # forward + backward + optimize
            outputs = model(x)
            print(outputs)

            # append results
            labels.extend(y.detach().cpu().squeeze().numpy())

            # append class
            if model.model_name == "LinearClassifier":
                pred_classes.extend(
                    outputs.detach()
                    .cpu()
                    .data.max(1, keepdim=True)[1]
                    .squeeze()
                    .numpy()
                )
                pred_probs.append(outputs.detach().cpu().squeeze().numpy())
            else:
                pred_probs.extend(outputs.detach().cpu().squeeze().numpy())

            if criterion is not None:
                loss = criterion(outputs, y)
                cum_loss += loss.item()

                if ((i + 1) % print_every == 0) or (i + 1 == len(loader)):
                    tqdm.write(msg.format(step=i + 1, loss=cum_loss / len(loader)))

    avg_loss = cum_loss / len(loader)
    return avg_loss, pred_probs, pred_classes, labels

In [5]:
"""Script for running pytorch models"""

from __future__ import annotations

import os
from tqdm import tqdm
from concurrent import futures

import torch
import torch.nn as nn

from sklearn.metrics import ndcg_score
from scipy.stats import spearmanr

from scr.params.sys import RAND_SEED, DEVICE
from scr.params.emb import TRANSFORMER_INFO, CARP_INFO

# from scr.preprocess.data_process import split_protrain_loader, DatasetInfo
from scr.preprocess.data_process import DatasetInfo
from scr.encoding.encoding_classes import get_emb_info, ESMEncoder, CARPEncoder
from scr.model.pytorch_model import LinearRegression, LinearClassifier
# from scr.model.train_test import train, test
from scr.vis.learning_vis import plot_lc
from scr.utils import get_folder_file_names, pickle_save, get_default_output_path


class Run_Pytorch:
    def __init__(
        self,
        dataset_path: str,
        encoder_name: str,
        reset_param: bool = False,
        resample_param: bool = False,
        embed_batch_size: int = 128,
        flatten_emb: bool | str = False,
        embed_folder: str | None = None,
        seq_start_idx: bool | int = False,
        seq_end_idx: bool | int = False,
        loader_batch_size: int = 64,
        worker_seed: int = RAND_SEED,
        if_encode_all: bool = True,
        if_multiprocess: bool = False,
        learning_rate: float = 1e-4,
        lr_decay: float = 0.1,
        epochs: int = 100,
        early_stop: bool = True,
        tolerance: int = 10,
        min_epoch: int = 5,
        device: torch.device | str = DEVICE,
        all_plot_folder: str = "results/learning_curves",
        all_result_folder: str = "results/train_val_test",
        **encoder_params,
    ) -> None:

        """
        A function for running pytorch model

        Args:
        - dataset_path: str, full path to the dataset, in pkl or panda readable format
            columns include: sequence, target, set, validation, mut_name (optional), mut_numb (optional)
        - encoder_name: str, the name of the encoder
        
        - embed_batch_size: int, set to 0 to encode all in a single batch
        - flatten_emb: bool or str, if and how (one of ["max", "mean"]) to flatten the embedding
        - embed_folder: str = None, path to presaved embedding
        - seq_start_idx: bool | int = False, the index for the start of the sequence
        - seq_end_idx: bool | int = False, the index for the end of the sequence
        - loader_batch_size: int, the batch size for train, val, and test dataloader
        - worker_seed: int, the seed for dataloader
        - learning_rate: float
        - lr_decay: float, factor by which to decay LR on plateau
        - epochs: int, number of epochs to train for
        - device: torch.device or str
        - all_plot_folder: str, the parent folder path for saving all the learning curves
        - all_result_folder: str = "results/train_val_test", the parent folder for all results
        - encoder_params: kwarg, additional parameters for encoding

        Returns:
        - result_dict: dict, with the keys and dict values
            "losses": {"train_losses": np.ndarray, "val_losses": np.ndarray}
            "train": {"mse": float, 
                    "pred": np.ndarray,
                    "true": np.ndarray,
                    "ndcg": float,
                    "rho": SpearmanrResults(correlation=float, pvalue=float)}
            "val":   {"mse": float, 
                    "pred": np.ndarray,
                    "true": np.ndarray,
                    "ndcg": float,
                    "rho": SpearmanrResults(correlation=float, pvalue=float)}
            "test":  {"mse": float, 
                    "pred": np.ndarray,
                    "true": np.ndarray,
                    "ndcg": float,
                    "rho": SpearmanrResults(correlation=float, pvalue=float)}

        """

        self._dataset_path = dataset_path
        self._encoder_name = encoder_name
        self._reset_param = reset_param
        self._resample_param = resample_param
        self._embed_batch_size = embed_batch_size
        self._flatten_emb = flatten_emb

        self._learning_rate = learning_rate
        self._lr_decay = lr_decay
        self._epochs = epochs
        self._early_stop = early_stop
        self._tolerance = tolerance
        self._min_epoch = min_epoch
        self._device = device
        self._all_plot_folder = all_plot_folder
        self._all_result_folder = all_result_folder
        self._encoder_params = encoder_params

        self._ds_info = DatasetInfo(self._dataset_path)
        self._model_type = self._ds_info.model_type
        self._numb_class = self._ds_info.numb_class

        self._train_loader, self._val_loader, self._test_loader = split_protrain_loader(
            dataset_path=self._dataset_path,
            encoder_name=self._encoder_name,
            reset_param=self._reset_param,
            resample_param=self._resample_param,
            embed_batch_size=self._embed_batch_size,
            flatten_emb=self._flatten_emb,
            embed_folder=embed_folder,
            seq_start_idx=seq_start_idx,
            seq_end_idx=seq_end_idx,
            subset_list=["train", "val", "test"],
            loader_batch_size=loader_batch_size,
            worker_seed=worker_seed,
            if_encode_all=if_encode_all,
            **encoder_params,
        )

        encoder_name, encoder_class, total_emb_layer = get_emb_info(encoder_name)

        if encoder_class == ESMEncoder:
            self._encoder_info_dict = TRANSFORMER_INFO
        elif encoder_class == CARPEncoder:
            self._encoder_info_dict = CARP_INFO

        if if_multiprocess:
            print("Running different emb layer in parallel...")
            # add the thredpool max_workers=None
            with futures.ProcessPoolExecutor(max_workers=os.cpu_count() - 1) as pool:
                # for each layer train the model and save the model
                for embed_layer in tqdm(range(total_emb_layer)):
                    pool.submit(self.run_pytorch_layer, embed_layer)

        else:
            for embed_layer in range(total_emb_layer):
                print(f"Running pytorch model for layer {embed_layer}")
                self.run_pytorch_layer(embed_layer)

    def run_pytorch_layer(self, embed_layer):

        # init model based on datasets
        if self._model_type == "LinearRegression":
            model = LinearRegression(
                input_dim=self._encoder_info_dict[self._encoder_name][0], output_dim=1
            )
            criterion = nn.MSELoss()

        elif self._model_type == "LinearClassifier":
            model = LinearClassifier(
                input_dim=self._encoder_info_dict[self._encoder_name][0], 
                numb_class=self._numb_class
            )
            criterion = nn.CrossEntropyLoss()
        
        model.to(self._device, non_blocking=True)

        criterion.to(self._device, non_blocking=True)

        train_losses, val_losses = train(
            model=model,
            criterion=criterion,
            train_loader=self._train_loader,
            val_loader=self._val_loader,
            encoder_name=self._encoder_name,
            embed_layer=embed_layer,
            reset_param=self._reset_param,
            resample_param=self._resample_param,
            embed_batch_size=self._embed_batch_size,
            flatten_emb=self._flatten_emb,
            device=self._device,
            learning_rate=self._learning_rate,
            lr_decay=self._lr_decay,
            epochs=self._epochs,
            early_stop=self._early_stop,
            tolerance=self._tolerance,
            min_epoch=self._min_epoch,
            **self._encoder_params,
        )

        # record the losses
        result_dict = {
            "losses": {"train_losses": train_losses, "val_losses": val_losses}
        }

        plot_lc(
            train_losses=train_losses,
            val_losses=val_losses,
            dataset_path=self._dataset_path,
            encoder_name=self._encoder_name,
            embed_layer=embed_layer,
            flatten_emb=self._flatten_emb,
            all_plot_folder=self._all_plot_folder,
        )

        # now test the model with the test data
        for subset, loader in zip(
            ["train", "val", "test"],
            [self._train_loader, self._val_loader, self._test_loader],
        ):
            loss, pred, cls, true = test(
                model=model, loader=loader, device=self._device, criterion=criterion
            )

            if model.model_name == "LinearRegression":
                result_dict[subset] = {
                    "mse": loss,
                    "pred": pred,
                    "true": true,
                    "ndcg": ndcg_score(true[None, :], pred[None, :]),
                    "rho": spearmanr(true, pred),
                }

            elif model.model_name == "LinearClassifier":
                result_dict[subset] = {
                    "cross-entropy": loss,
                    "pred": pred,
                    "true": true,
                    "acc": accuracy_score(true, cls),
                    "rocauc": roc_auc_score(true, pred, multi_class="ovo"),
                }            

        dataset_subfolder, file_name = get_folder_file_names(
            parent_folder=get_default_output_path(self._all_result_folder),
            dataset_path=self._dataset_path,
            encoder_name=self._encoder_name,
            embed_layer=embed_layer,
            flatten_emb=self._flatten_emb,
        )

        print(f"Saving results for {file_name} to: {dataset_subfolder}...")
        pickle_save(
            what2save=result_dict,
            where2save=os.path.join(dataset_subfolder, file_name + ".pkl"),
        )

In [8]:
Run_Pytorch(
        dataset_path="data/annotation/scl/balanced.csv",
        encoder_name="esm1_t6_43M_UR50S",
        reset_param = False,
        resample_param = False,
        embed_batch_size = 128,
        flatten_emb = "mean",
        embed_folder= "embeddings/annotation/scl/balanced",
        seq_start_idx = False,
        seq_end_idx = False,
        loader_batch_size = 64,
        worker_seed= RAND_SEED,
        if_encode_all = False,
        learning_rate = 1e-4,
        lr_decay = 0.1,
        epochs = 2,
        early_stop = True,
        tolerance = 10,
        min_epoch = 5,
        device = DEVICE,
        all_plot_folder = "results/learning_curves",
        all_result_folder = "results/train_val_test",
        # **encoder_params
    )

Generating esm1_t6_43M_UR50S upto 6 layer embedding ...


Using cache found in /home/t-fli/.cache/torch/hub/facebookresearch_esm_main


Converting classes into int...
Generating esm1_t6_43M_UR50S upto 6 layer embedding ...


Using cache found in /home/t-fli/.cache/torch/hub/facebookresearch_esm_main


Converting classes into int...
Generating esm1_t6_43M_UR50S upto 6 layer embedding ...


Using cache found in /home/t-fli/.cache/torch/hub/facebookresearch_esm_main
  0%|          | 0/2 [00:00<?, ?it/s]

Converting classes into int...
Running pytorch model for layer 0
[tensor([[4],
        [9],
        [1],
        [9],
        [1],
        [7],
        [5],
        [7],
        [8],
        [1],
        [1],
        [7],
        [3],
        [1],
        [1],
        [5],
        [3],
        [7],
        [7],
        [7],
        [6],
        [7],
        [7],
        [7],
        [9],
        [6],
        [1],
        [0],
        [9],
        [2],
        [3],
        [9],
        [7],
        [7],
        [0],
        [6],
        [1],
        [1],
        [9],
        [7],
        [6],
        [1],
        [7],
        [0],
        [1],
        [6],
        [7],
        [3],
        [3],
        [6],
        [7],
        [3],
        [3],
        [1],
        [9],
        [2],
        [3],
        [1],
        [2],
        [7],
        [6],
        [0],
        [3],
        [9]]), ('MSAIFNFQSLLTVILLLICTCAYIRSLAPSLLDRNKTGLLGIFWKCARIGERKSPYVAVCCIVMAFSILFI', 'ASYKVKLITPEGAVEFDCPDDVY




RuntimeError: 1D target tensor expected, multi-target not supported

In [9]:
val_ds = ProtranDataset(
                dataset_path="data/annotation/scl/balanced.csv",
                subset="",
                encoder_name="",
                reset_param=False,
                resample_param=False,
                embed_batch_size=64,
                flatten_emb="flatten",
                embed_path=None,
                seq_start_idx=0,
                seq_end_idx=False,
                # **encoder_params,
            )

  7%|▋         | 2/27 [00:00<00:01, 13.12it/s]

Generating onehot upto 0 layer embedding ...


100%|██████████| 27/27 [00:01<00:00, 13.96it/s]


In [14]:
gb1_val = ProtranDataset(
                dataset_path="data/proeng/gb1/sampled.csv",
                subset="val",
                encoder_name="",
                reset_param=False,
                resample_param=False,
                embed_batch_size=64,
                flatten_emb="flatten",
                embed_path=None,
                seq_start_idx=0,
                seq_end_idx=False,
                # **encoder_params,
            )

100%|██████████| 11/11 [00:00<00:00, 177.99it/s]

Generating onehot upto 0 layer embedding ...





In [72]:
clf = LogisticRegression(random_state=0, C=0.1, multi_class="multinomial", max_iter=1000).fit(val_ds.layer0, y)
pred_y = clf.predict(val_ds.layer0)
pred_prob = clf.predict_proba(val_ds.layer0)

In [74]:
from sklearn.metrics import log_loss, accuracy_score, roc_auc_score

In [76]:
accuracy_score(y, pred_y), log_loss(y, pred_prob), roc_auc_score(y, pred_prob, multi_class="ovo")

(1.0, 0.097391642777984, 1.0)

In [3]:
"""Script for run sklearn models"""

from __future__ import annotations

import os
import random
import numpy as np

from sklearn.linear_model import Ridge, LogisticRegression
from sklearn.metrics import mean_squared_error, log_loss, accuracy_score, roc_auc_score
from sklearn.preprocessing import StandardScaler, LabelEncoder
from scipy.stats import spearmanr

from scr.utils import get_folder_file_names, pickle_save, ndcg_scale
from scr.params.emb import TRANSFORMER_INFO, CARP_INFO, MAX_SEQ_LEN
from scr.params.sys import RAND_SEED, SKLEARN_ALPHAS
from scr.encoding.encoding_classes import ESMEncoder, CARPEncoder, OnehotEncoder
from scr.preprocess.data_process import ProtranDataset

# seed
random.seed(RAND_SEED)
np.random.seed(RAND_SEED)


class RunSK:
    """A class for running sklearn models"""

    def __init__(
        self,
        dataset_path: str,
        encoder_name: str,
        reset_param: bool = False,
        resample_param: bool = False,
        embed_batch_size: int = 128,
        flatten_emb: bool | str = False,
        embed_path: str | None = None,
        seq_start_idx: bool | int = False,
        seq_end_idx: bool | int = False,
        alphas: np.ndarray | int = SKLEARN_ALPHAS,
        sklearn_state: int = RAND_SEED,
        sklearn_params: dict | None = None,
        all_result_folder: str = "results/sklearn",
        **encoder_params,
    ) -> None:

        """
        Args:
        - dataset_path: str, full path to the dataset, in pkl or panda readable format
            columns include: sequence, target, set, validation,
            mut_name (optional), mut_numb (optional)
        - encoder_name: str, the name of the encoder
        - reset_param: bool = False, if update the full model to xavier_uniform_
        - resample_param: bool = False, if update the full model to xavier_normal_
        - embed_batch_size: int, set to 0 to encode all in a single batch
        - flatten_emb: bool or str, if and how (one of ["max", "mean"]) to flatten the embedding
        - embed_path: str = None, path to presaved embedding
        - seq_start_idx: bool | int = False, the index for the start of the sequence
        - seq_end_idx: bool | int = False, the index for the end of the sequence
        - alphas: np.ndarray, arrays of alphas to be tested
        - sklearn_state: int = RAND_SEED, seed the ridge or logistic regression
        - sklearn_params: dict | None = None, other ridge or logistic regression args
        - all_result_folder: str = "results/train_val_test", the parent folder for all results
        - encoder_params: kwarg, additional parameters for encoding
        """

        self.dataset_path = dataset_path
        self.encoder_name = encoder_name
        self.reset_param = reset_param
        self.resample_param = resample_param
        self.flatten_emb = flatten_emb

        if not isinstance(alphas, np.ndarray):
            alphas = np.array([alphas])
        self.alphas = alphas

        self.sklearn_state = sklearn_state
        self.sklearn_params = sklearn_params
        self.all_result_folder = all_result_folder

        if self.reset_param and "-rand" not in self.all_result_folder:
            self.all_result_folder = f"{self.all_result_folder}-rand"

        if self.resample_param and "-stat" not in self.all_result_folder:
            self.all_result_folder = f"{self.all_result_folder}-stat"

        # loader has ALL embedding layers
        self.train_ds, self.val_ds, self.test_ds = (
            ProtranDataset(
                dataset_path=dataset_path,
                subset=subset,
                encoder_name=encoder_name,
                reset_param=reset_param,
                resample_param=resample_param,
                embed_batch_size=embed_batch_size,
                flatten_emb=flatten_emb,
                embed_path=embed_path,
                seq_start_idx=seq_start_idx,
                seq_end_idx=seq_end_idx,
                **encoder_params,
            )
            for subset in ["train", "val", "test"]
        )

        # pick ridge regression if y numerical
        if self.val_ds.y.dtype.kind in "iufc":
            self.sklearn_model = Ridge

        # pick logistic regression if y is categorical
        else:
            le = LabelEncoder()
            self.train_ds.y, self.val_ds.y, self.test_ds.y = [
                le.fit_transform(y.flatten())
                for y in [self.train_ds.y, self.val_ds.y, self.test_ds.y]
            ]
            self.sklearn_model = LogisticRegression
            # convert alpha to C
            self.alphas = 1 / self.alphas
            # add other params
            if self.sklearn_params is None:
                self.sklearn_params["multi_class"] = "multinomial"
                self.sklearn_params["max_iter"] = 1000

        all_sklearn_results = {}

        # TODO for easier total_emb_layer
        if self.encoder_name in TRANSFORMER_INFO.keys():
            total_emb_layer = TRANSFORMER_INFO[encoder_name][1] + 1
        elif self.encoder_name in CARP_INFO.keys():
            total_emb_layer = CARP_INFO[encoder_name][1]
        else:
            # for onehot
            self.encoder_name = "onehot"
            total_emb_layer = 1

        for layer in range(total_emb_layer):
            all_sklearn_results[layer] = self.run_sklearn_layer(embed_layer=layer,)

        self._all_sklearn_results = all_sklearn_results

    def sk_test(
        self, model: sklearn.linear_model, ds: ProtranDataset, embed_layer: int
    ):
        """
        A function for testing sklearn models for a specific layer of embeddings

        Args:
        - model: sklearn.linear_model, trained model
        - ds: ProtranDataset, train, val, or test dataset
        - embed_layer: int, specific layer of the embedding

        Returns:
        - np.concatenate(pred): np.ndarray, 1D predicted fitness values
        - np.concatenate(true): np.ndarry, 1D true fitness values
        - 
        """

        if self.sklearn_model == Ridge:
            pred_prob = None
        else:
            pred_prob = model.predict_proba(
                getattr(ds, "layer" + str(embed_layer)).cpu().numpy()
            ).squeeze()

        return (
            model.predict(
                getattr(ds, "layer" + str(embed_layer)).cpu().numpy()
            ).squeeze(),
            ds.y.squeeze(),
            pred_prob,
        )

    def pick_model(
        self, embed_layer: int,
    ):
        """
        A function for picking the best model for given alaphs, meaning
        lower train_mse and higher test_ndcg
        NOTE: alphas tuning is NOT currently optimal

        Args:
        - embed_layer: int, specific layer of the embedding

        Returns:
        - sklearn.linear_model, the model with the best alpha
        """

        # init values for comparison
        if self.sklearn_model == Ridge:
            best_mse = np.Inf
            best_ndcg = -1
            best_rho = -1
        else:
            best_loss = np.Inf
            best_acc = 0
            best_auc = 0

        best_model = None

        # loop through all alphas
        for alpha in self.alphas:

            # init model for each alpha
            if self.sklearn_params is None:
                self.sklearn_params = {}
            model = self.sklearn_model(
                alpha=alpha, random_state=self.sklearn_state, **self.sklearn_params
            )

            # fit the model for a given layer of embedding
            fitness_scaler = StandardScaler()
            model.fit(
                getattr(self.train_ds, "layer" + str(embed_layer)).cpu().numpy(),
                fitness_scaler.fit_transform(self.train_ds.y),
            )

            # eval the model with train and test
            train_pred, train_true, train_prob = self.sk_test(
                model, self.train_ds, embed_layer=embed_layer
            )
            val_pred, val_true, val_prob = self.sk_test(
                model, self.val_ds, embed_layer=embed_layer
            )

            if self.sklearn_model == Ridge:
                # calc the metrics
                train_mse = mean_squared_error(train_true, train_pred)
                val_ndcg = ndcg_scale(val_true, val_pred)
                val_rho = spearmanr(val_true, val_pred)[0]

                # update the model if it has lower train_mse and higher val_ndcg
                if train_mse < best_mse and val_ndcg > best_ndcg:
                    best_model = model
                    best_mse = train_mse
                    best_ndcg = val_ndcg
                    best_rho = val_rho

            else:
                # calc the metrics
                train_loss = log_loss(train_true, train_prob)
                val_acc = accuracy_score(val_true, val_pred)
                val_auc = roc_auc_score(val_true, val_prob, multi_class="ovo")

                # update the model if it has lower log_loss and higher val_auc
                if train_loss < best_loss and val_auc > best_auc:
                    best_loss = train_loss
                    best_acc = val_acc
                    best_auc = val_auc

        print(f"best model is {best_model}")
        return best_model

    def run_sklearn_layer(
        self, embed_layer: int,
    ):

        """
        A function for running ridge or logistics regression for a given layer of embedding

        Args:
        - embed_layer: int, specific layer of the embedding

        Returns:
        - dict, with the keys and dict values
            "train": {"mse": float,
                    "pred": np.ndarray,
                    "true": np.ndarray,
                    "ndcg": float,
                    "rho": SpearmanrResults(correlation=float, pvalue=float)}
            "val":   {"mse": float,
                    "pred": np.ndarray,
                    "true": np.ndarray,
                    "ndcg": float,
                    "rho": SpearmanrResults(correlation=float, pvalue=float)}
            "test":  {"mse": float,
                    "pred": np.ndarray,
                    "true": np.ndarray,
                    "ndcg": float,
                    "rho": SpearmanrResults(correlation=float, pvalue=float)}
        """

        # train and get the best alpha
        best_model = self.pick_model(embed_layer=embed_layer,)

        # init dict for resulted outputs
        result_dict = {}

        # now test the model with the test data
        for subset, ds in zip(
            ["train", "val", "test"], [self.train_ds, self.val_ds, self.test_ds],
        ):
            pred, true, prob = self.sk_test(best_model, ds, embed_layer=embed_layer)

            if self.sklearn_model == Ridge:
                result_dict[subset] = {
                    "mse": mean_squared_error(true, pred),
                    "pred": pred,
                    "true": true,
                    "ndcg": ndcg_scale(true, pred),
                    "rho": spearmanr(true, pred),
                }

            else:
                result_dict[subset] = {
                    "log": log_loss(true, prob),
                    "pred": pred,
                    "prob": prob,
                    "true": true,
                    "acc": accuracy_score(true, pred),
                    "rocauc": roc_auc_score(true, prob, multi_class="ovo"),
                }

        dataset_subfolder, file_name = get_folder_file_names(
            parent_folder=self.all_result_folder,
            dataset_path=self.dataset_path,
            encoder_name=self.encoder_name,
            embed_layer=embed_layer,
            flatten_emb=self.flatten_emb,
        )

        print(f"Saving results for {file_name} to: {dataset_subfolder}...")
        pickle_save(
            what2save=result_dict,
            where2save=os.path.join(dataset_subfolder, file_name + ".pkl"),
        )

        return result_dict

    @property
    def all_sklearn_results(self):
        """
        Returns:
        - dict, with the keys and dict values
            "layer#": {
                        "train": {"mse": float,
                                "pred": np.ndarray,
                                "true": np.ndarray,
                                "ndcg": float,
                                "rho": SpearmanrResults(correlation=float, pvalue=float)}
                        "val":   {"mse": float,
                                "pred": np.ndarray,
                                "true": np.ndarray,
                                "ndcg": float,
                                "rho": SpearmanrResults(correlation=float, pvalue=float)}
                        "test":  {"mse": float,
                                "pred": np.ndarray,
                                "true": np.ndarray,
                                "ndcg": float,
                                "rho": SpearmanrResults(correlation=float, pvalue=float)}
                        }
        """
        return self._all_sklearn_results

In [15]:
scl = RunSK(
        dataset_path="data/annotation/scl/balanced.csv",
        encoder_name="esm1_t6_43M_UR50S",
        reset_param=False,
        resample_param = False,
        embed_batch_size = 64,
        flatten_emb = "flatten",
        embed_path = None,
        seq_start_idx = False,
        seq_end_idx = False,
        alphas = SKLEARN_ALPHAS,
        sklearn_state = RAND_SEED,
        sklearn_params = None,
        all_result_folder = "results/sklearn-test",
        # **encoder_params,
    )

Generating esm1_t6_43M_UR50S upto 6 layer embedding ...


Using cache found in /home/t-fli/.cache/torch/hub/facebookresearch_esm_main
 50%|████▉     | 74/149 [23:26<23:58, 19.18s/it]

: 

: 