In [1]:
%cd ~/repo/protein-transfer

/home/t-fli/repo/protein-transfer


In [2]:
%load_ext blackcellmagic

In [15]:
def test_unpack():
    return ("Test", "test1", *["test2", "test3"])

In [17]:
test_unpack(), test_unpack()[3]

(('Test', 'test1', 'test2', 'test3'), 'test3')

In [3]:
"""Pre-processing the dataset"""

from __future__ import annotations

from collections import defaultdict

import os
from glob import glob
import tables
import pandas as pd
import numpy as np

from sklearn.preprocessing import LabelEncoder

import torch
from torch.utils.data import Dataset, DataLoader

from scr.utils import pickle_save, pickle_load, replace_ext
from scr.params.sys import RAND_SEED
from scr.params.emb import TRANSFORMER_INFO, CARP_INFO, MAX_SEQ_LEN
from scr.preprocess.seq_loader import SeqLoader
from scr.encoding.encoding_classes import (
    AbstractEncoder,
    ESMEncoder,
    CARPEncoder,
    OnehotEncoder,
)


def get_mut_name(mut_seq: str, parent_seq: str) -> str:
    """
    A function for returning the mutant name

    Args:
    - mut_seq: str, the full mutant sequence
    - parent_seq: str, the full parent sequence

    Returns:
    - str, parent, indel, or mutant name in the format of
        ParentAAMutLocMutAA:ParentAAMutLocMutAA:..., ie. W39W:D40G:G41C:V54Q
    """

    mut_list = []
    if parent_seq == mut_seq:
        return "parent"
    elif len(parent_seq) == len(mut_seq):
        for i, (p, m) in enumerate(zip(list(parent_seq), list(mut_seq))):
            if p != m:
                mut_list.append(f"{p}{i+1}{m}")
        return ":".join(mut_list)
    else:
        return "indel"


class AddMutInfo:
    """A class for appending mutation info for mainly protein engineering tasks"""

    def __init__(self, parent_seq_path: str, csv_path: str):

        """
        Args:
        - parent_seq_path: str, path for the parent sequence
        - csv_path: str, path for the fitness csv file
        """

        # Load the parent sequence from the fasta file
        self._parent_seq = SeqLoader(parent_seq_path=parent_seq_path)

        # load the dataframe
        self._init_df = pd.read_csv(csv_path)

        self._df = self._init_df.copy()
        # add a column with the mutant names
        self._df["mut_name"] = self._init_df["sequence"].apply(
            get_mut_name, parent_seq=self._parent_seq
        )
        # add a column with the number of mutations
        self._df["mut_numb"] = (
            self._df["mut_name"].str.split(":").map(len, na_action="ignore")
        )

        # get the pickle file path
        self._pkl_path = replace_ext(input_path=csv_path, ext=".pkl")

        pickle_save(what2save=self._df, where2save=self._pkl_path)

    @property
    def parent_seq(self) -> str:
        """Return the parent sequence"""
        return self._parent_seq

    @property
    def pkl_path(self) -> str:
        """Return the pkl file path for the processed dataframe"""
        return self._pkl_path

    @property
    def df(self) -> pd.DataFrame:
        """Return the processed dataframe"""
        return self._df


def std_split_ssdf(
    ssdf_path: str = "data/structure/secondary_structure/tape_ss3.csv",
    split_test: bool = True,
) -> None:
    """
    A function that standardize secondary structure dataset
    to set up as columns as sequence, target, set, validation
    where set is train or test and add validation as true
    """
    folder_path = os.path.dirname(ssdf_path)

    df = pd.read_csv(ssdf_path)

    # convert the string into numpy array
    # df["ss3"] = df["ss3"].apply(lambda x: np.array(x[1:-1].split(", ")))

    # add validation column
    df["validation"] = df["split"].apply(lambda x: True if x == "valid" else "")
    # now replace valid to train
    df = df.replace("valid", "train")
    # rename all columns
    df.columns = ["sequence", "target", "set", "validation"]

    if split_test:
        # get all kinds of test sets
        ss_tests = set(df["set"].unique()) - set(["train"])

        for ss_test in ss_tests:
            df.loc[~df["set"].isin(set(ss_tests) - set([ss_test]))].replace(
                ss_test, "test"
            ).to_csv(os.path.join(folder_path, ss_test + ".csv"), index=False)

    else:
        df.to_csv(f"{os.path.splitext(ssdf_path)[0]}_processed.csv", index=False)


def split_df_sets(df: pd.DataFrame) -> dict[pd.DataFrame]:
    """
    Return split dataframe for training, validation, and testing

    Args:
    - df: pd.DataFrame, input dataframe

    Returns:
    - a dict of dataframes for train, val, test (or ss3 tasks)
    """

    assert "set" in df.columns, f"set is not a column in the dataframe"
    assert "validation" in df.columns, f"validation is not a column in the dataframe"

    # init split df dict output
    df_dict = {}

    df_dict["train"] = df.loc[(df["set"] == "train") & (df["validation"] != True)].reset_index(drop=True)
    df_dict["val"] = df.loc[(df["set"] == "train") & (df["validation"] == True)].reset_index(drop=True)

    test_tasks = set(df["set"].unique()) - set(["train"])

    for test_task in test_tasks:
        df_dict[test_task] = df.loc[(df["set"] == test_task)].reset_index(drop=True)

    return df_dict


class DatasetInfo:
    """
    A class returns the information of a dataset
    """

    def __init__(self, dataset_path: str) -> None:
        """
        Args:
        - dataset_path: str, the path for the csv
        """
        self._df = pd.read_csv(dataset_path)

        assert "set" in self._df.columns, f"set is not a column in {dataset_path}"
        assert (
            "validation" in self._df.columns
        ), f"validation is not a column in {dataset_path}"

    def get_model_type(self) -> str:
        # pick linear regression if y numerical
        if self._df.target.dtype.kind in "iufc":
            return "LinearRegression"
        else:
            # ss3
            if "[" in self._df.target[0]:
                return "MultiLabelMultiClass"
            # annotation
            else:
                return "LinearClassifier"

    def get_numb_class(self) -> int:
        """
        A function to get number of class
        """
        # annotation class number
        if self.model_type == "LinearClassifier":
            return self._df.target.nunique()
        # ss3 or ss8 secondary structure states plus padding
        elif self.model_type == "MultiLabelMultiClass":
            return len(np.unique(np.array(self._df["target"][0][1:-1].split(", "))))

    @property
    def model_type(self) -> str:
        """Return the pytorch model type"""
        return self.get_model_type()

    @property
    def numb_class(self) -> int:
        """Return number of classes for classification"""
        return self.get_numb_class()

    @property
    def subset_list(self) -> list[str]:
        """Return a list of subset"""
        subset_list = list(self._df["set"].unique())
        subset_list.insert(1, "val")
        return subset_list


class TaskProcess:
    """A class for handling different downstream tasks"""

    def __init__(self, data_folder: str = "data/"):
        """
        Args:
        - data_folder: str, a folder path with all the tasks as subfolders where
            all the subfolders have datasets as the subsubfolders, ie

            {data_folder}/
                proeng/
                    aav/
                        one_vs_many.csv
                        two_vs_many.csv
                        P03135.fasta
                    thermo/
                        mixed.csv
        """

        if data_folder[-1] == "/":
            self._data_folder = data_folder
        else:
            self._data_folder = data_folder + "/"

        # sumamarize all files i nthe data folder
        self._sum_file_df = self.sum_files()

    def sum_files(self) -> pd.DataFrame:
        """
        Summarize all files in the data folder

        Returns:
        - A dataframe with "task", "dataset", "split",
            "csv_path", "fasta_path", "pkl_path" as columns, ie.
            (proeng, gb1, low_vs_high, data/proeng/gb1/low_vs_high.csv,
            data/proeng/gb1/5LDE_1.fasta)
            note that csv_path is the list of lmdb files for the structure task
        """
        dataset_folders = glob(f"{self._data_folder}*/*")
        # need a list of tuples in the order of:
        # (task, dataset, split, csv_path, fasta_path)
        list_for_df = []
        for dataset_folder in dataset_folders:
            _, task, dataset = dataset_folder.split("/")

            csv_paths = glob(f"{dataset_folder}/*.csv")

            if task == "structure":
                csv_paths = set(csv_paths) - set(
                    glob(f"{dataset_folder}/tape_ss3*.csv")
                )

            fasta_paths = glob(f"{dataset_folder}/*.fasta")
            pkl_paths = glob(f"{dataset_folder}/*.pkl")

            assert len(csv_paths) >= 1, "Less than one csv"
            assert len(fasta_paths) <= 1, "More than one fasta"

            for csv_path in csv_paths:
                # if parent seq fasta exists
                if len(fasta_paths) == 1:
                    fasta_path = fasta_paths[0]

                    # if no existing pkl file, generate and save
                    if len(pkl_paths) == 0:
                        print(f"Adding mutation info to {csv_path}...")
                        pkl_path = AddMutInfo(
                            parent_seq_path=fasta_path, csv_path=csv_path
                        ).pkl_path
                    # pkl file exits
                    else:
                        pkl_path = replace_ext(input_path=csv_path, ext=".pkl")
                # no parent fasta no pkl file
                else:
                    fasta_path = ""
                    pkl_path = ""

                df_dict = split_df_sets(pd.read_csv(csv_path))

                list_for_df.append(
                    tuple(
                        [
                            task,
                            dataset,
                            os.path.basename(os.path.splitext(csv_path)[0]),
                            len(df_dict["train"]),
                            len(df_dict["val"]),
                            len(df_dict["test"]),
                            csv_path,
                            fasta_path,
                            pkl_path,
                        ]
                    )
                )

        return pd.DataFrame(
            list_for_df,
            columns=[
                "task",
                "dataset",
                "split",
                "train_numb",
                "val_numb",
                "test_numb",
                "csv_path",
                "fasta_path",
                "pkl_path",
            ],
        )

    @property
    def sum_file_df(self) -> pd.DataFrame:
        """A summary table for all files in the data folder"""
        return self._sum_file_df


class ProtranDataset(Dataset):

    """A dataset class for processing protein transfer data"""

    def __init__(
        self,
        dataset_path: str,
        subset: str,
        encoder_name: str,
        reset_param: bool = False,
        resample_param: bool = False,
        embed_batch_size: int = 0,
        flatten_emb: bool | str = False,
        embed_folder: str = None,
        embed_layer: int | None = None,
        seq_start_idx: bool | int = False,
        seq_end_idx: bool | int = False,
        if_encode_all: bool = True,
        **encoder_params,
    ):

        """
        Args:
        - dataset_path: str, full path to the dataset, in pkl or panda readable format, ie
            "data/proeng/gb1/low_vs_high.csv"
            columns include: sequence, target, set, validation,
            mut_name (optional), mut_numb (optional)
        - subset: str, train, val, test
        - encoder_name: str, the name of the encoder
        - reset_param: bool = False, if update the full model to xavier_uniform_
        - resample_param: bool = False, if update the full model to xavier_normal_
        - embed_batch_size: int, set to 0 to encode all in a single batch
        - flatten_emb: bool or str, if and how (one of ["max", "mean"]) to flatten the embedding
        - embed_folder: str = None, path to presaved embedding folder, ie
            "embeddings/proeng/gb1/low_vs_high"
            for which then can add the subset to be, ie
            "embeddings/proeng/gb1/low_vs_high/esm1_t6_43M_UR50S/mean/test/embedding.h5"
        - seq_start_idx: bool | int = False, the index for the start of the sequence
        - seq_end_idx: bool | int = False, the index for the end of the sequence
        - if_encode_all: bool = True, if encode full dataset all layers on the fly
        - encoder_params: kwarg, additional parameters for encoding
        """

        # with additional info mut_name, mut_numb
        if os.path.splitext(dataset_path)[-1] in [".pkl", ".PKL", ""]:
            self._df = pickle_load(dataset_path)
            self._add_mut_info = True
        # without such info
        else:
            self._df = pd.read_csv(dataset_path)
            self._ds_info = DatasetInfo(dataset_path)
            self._model_type = self._ds_info.model_type
            self._numb_class = self._ds_info.numb_class

            self._add_mut_info = False

        assert "set" in self._df.columns, f"set is not a column in {dataset_path}"
        assert (
            "validation" in self._df.columns
        ), f"validation is not a column in {dataset_path}"

        self._df_dict = split_df_sets(self._df)

        assert subset in list(
            self._df_dict.keys()
        ), "split can only be 'train', 'val', 'test' or 'cb513', 'ts115', 'casp12'"
        self._subset = subset

        # print(f"{self._subset} for ProtranDataset...")
        # display(self._df_dict[self._subset].head())

        self._subdf_len = len(self._df_dict[self._subset])

        # not specified seq start will be from 0
        if seq_start_idx == False:
            self._seq_start_idx = 0
        else:
            self._seq_start_idx = int(seq_start_idx)
        # not specified seq end will be the full sequence length
        if seq_end_idx == False:
            self._seq_end_idx = -1
            self._max_seq_len = self._df.sequence.str.len().max()
        else:
            self._seq_end_idx = int(seq_end_idx)
            self._max_seq_len = self._seq_end_idx - self._seq_start_idx

        # get unencoded string of input sequence
        # will need to convert data type
        self.sequence = self._get_column_value("sequence")

        self.if_encode_all = if_encode_all
        self._embed_folder = embed_folder

        self._encoder_name = encoder_name
        self._flatten_emb = flatten_emb

        # get the encoder class
        if self._encoder_name in TRANSFORMER_INFO.keys():
            encoder_class = ESMEncoder
        elif self._encoder_name in CARP_INFO.keys():
            encoder_class = CARPEncoder
        else:
            self._encoder_name == "onehot"
            encoder_class = OnehotEncoder
            encoder_params["max_seq_len"] = self._max_seq_len

        # get the encoder
        self._encoder = encoder_class(
            encoder_name=self._encoder_name,
            reset_param=reset_param,
            resample_param=resample_param,
            **encoder_params,
        )
        self._total_emb_layer = self._encoder.total_emb_layer
        self._embed_layer = embed_layer

        # encode all and load in memory
        if self.if_encode_all or (
            self._embed_folder is None and self._embed_layer is None
        ):
            print("Encoding all...")
            # encode the sequences without the mut_name
            # init an empty dict with empty list to append emb
            encoded_dict = defaultdict(list)

            # use the encoder generator for batch emb
            # assume no labels included
            for encoded_batch_dict in self._encoder.encode(
                mut_seqs=self.sequence,
                batch_size=embed_batch_size,
                flatten_emb=self._flatten_emb,
            ):

                for layer, emb in encoded_batch_dict.items():
                    encoded_dict[layer].append(emb)

            # assign each layer as its own variable
            for layer, emb in encoded_dict.items():
                setattr(self, "layer" + str(layer), np.vstack(emb))

        # load full one layer embedding
        if self._embed_folder is not None and self._embed_layer is not None:
            print(f"Load {self._embed_layer} from {self._embed_folder}...")

            emb_table = tables.open_file(
                os.path.join(
                    self._embed_folder,
                    self._encoder_name,
                    self._flatten_emb,
                    self._subset,
                    "embedding.h5",
                )
            )

            emb_table.flush()

            setattr(
                self,
                "layer" + str(self._embed_layer),
                getattr(emb_table.root, "layer" + str(self._embed_layer))[:],
            )

            emb_table.close()
        # get and format the fitness or secondary structure values
        # can be numbers or string
        # will need to convert data type
        # make 1D tensor 2D
        self.y = np.expand_dims(self._get_column_value("target"), 1)

        # add mut_name and mut_numb for relevant proeng datasets
        if self._add_mut_info:
            self.mut_name = self._get_column_value("mut_name")
            self.mut_numb = self._get_column_value("mut_numb")
        else:
            self.mut_name = [""] * self._subdf_len
            self.mut_numb = [np.nan] * self._subdf_len

    def __len__(self):
        """Return the length of the selected subset of the dataframe"""
        return self._subdf_len

    def __getitem__(self, idx: int):

        """
        Return the item in the order of
        target (y), sequence, mut_name (optional), mut_numb (optional),
        embedding per layer upto the max number of layer for the encoder

        Args:
        - idx: int
        """
        if self.if_encode_all and self._embed_folder is None:

            return (
                self.y[idx],
                self.sequence[idx],
                self.mut_name[idx],
                self.mut_numb[idx],
                *(
                    getattr(self, "layer" + str(layer))[idx]
                    for layer in range(self._total_emb_layer)
                ),
            )
        elif self._embed_folder is not None:
            # load the .h5 file with the embeddings
            """
            gb1_emb = tables.open_file("embeddings/proeng/gb1/low_vs_high/esm1_t6_43M_UR50S/mean/test/embedding.h5")
            gb1_emb.flush()
            gb1_emb.root.layer0[0:5]
            """
            # return all
            if self._embed_layer is None:

                emb_table = tables.open_file(
                    os.path.join(
                        self._embed_folder,
                        self._encoder_name,
                        self._flatten_emb,
                        self._subset,
                        "embedding.h5",
                    )
                )

                emb_table.flush()

                layer_embs = [
                    getattr(emb_table.root, "layer" + str(layer))[idx]
                    for layer in range(self._total_emb_layer)
                ]

                emb_table.close()

                return (
                    self.y[idx],
                    self.sequence[idx],
                    self.mut_name[idx],
                    self.mut_numb[idx],
                    layer_embs,
                )
            # only pick particular embeding layer
            else:

                return (
                    self.y[idx],
                    self.sequence[idx],
                    self.mut_name[idx],
                    self.mut_numb[idx],
                    # getattr(emb_table.root, "layer" + str(self._embed_layer))[idx]
                    getattr(self, "layer" + str(self._embed_layer))[idx],
                )
        else:
            return (
                self.y[idx],
                self.sequence[idx],
                self.mut_name[idx],
                self.mut_numb[idx],
            )

    def _get_column_value(self, column_name: str) -> np.ndarray:
        """
        Check and return the column values of the selected dataframe subset

        Args:
        - column_name: str, the name of the dataframe column
        """
        if column_name in self._df.columns:

            y = self._df_dict[self._subset][column_name]

            if column_name == "sequence":

                return (
                    # self._df_dict[self._subset]["sequence"]
                    y.astype(str)
                    .str[self._seq_start_idx : self._seq_end_idx]
                    .apply(
                        lambda x: x[: int(MAX_SEQ_LEN // 2)]
                        + x[-int(MAX_SEQ_LEN // 2) :]
                        if len(x) > MAX_SEQ_LEN
                        else x
                    )
                    .values
                )
            elif column_name == "target" and self._model_type == "LinearClassifier":
                print("Converting classes into int...")
                le = LabelEncoder()
                return le.fit_transform(y.values.flatten())
            elif column_name == "target" and self._model_type == "MultiLabelMultiClass":
                print("Converting ss3/ss8 into np.array and pad -1...")
                """le = LabelEncoder()
                print(le.fit_transform(np.concatenate(y.apply(lambda x: np.array(x[1:-1].split(", ")).astype("int")))).shape)
                return le.fit_transform(np.concatenate(y.apply(lambda x: np.array(x[1:-1].split(", ")).astype("int"))))"""
                np_y = (
                    y.apply(lambda x: np.array(x[1:-1].split(", ")).astype("int"))
                )
                print(f"y shape in dataset is {np.stack([np.pad(i,pad_width=(0, self._max_seq_len - len(i)),constant_values=-1,) for i in np_y]).shape}")
                return np.stack(
                    [
                        np.pad(
                            i,
                            pad_width=(0, self._max_seq_len - len(i)),
                            constant_values=-1,
                        )
                        for i in np_y
                    ]
                )
            else:
                return y.values

    @property
    def df_full(self) -> pd.DataFrame:
        """Return the full loaded dataset"""
        return self._df

    @property
    def df_train(self) -> pd.DataFrame:
        """Return the dataset for training only"""
        return self._df_dict["train"]

    @property
    def df_val(self) -> pd.DataFrame:
        """Return the dataset for validation only"""
        return self._df_dict["val"]

    @property
    def df_dict(self) -> pd.DataFrame:
        """Return the dict with different dataframe split"""
        return self._df_dict

    @property
    def max_seq_len(self) -> int:
        """Longest sequence length"""
        return self._max_seq_len


def split_protrain_loader(
    dataset_path: str,
    encoder_name: str,
    reset_param: bool = False,
    resample_param: bool = False,
    embed_batch_size: int = 128,
    flatten_emb: bool | str = False,
    embed_folder: str | None = None,
    embed_layer: int | None = None,
    seq_start_idx: bool | int = False,
    seq_end_idx: bool | int = False,
    subset_list: list[str] = ["train", "val", "test"],
    loader_batch_size: int = 64,
    worker_seed: int = RAND_SEED,
    if_encode_all: bool = True,
    **encoder_params,
):

    """
    A function encode and load the data from a path

    Args:
    - dataset_path: str, full path to the dataset, in pkl or panda readable format, ie
        "data/proeng/gb1/low_vs_high.csv"
        columns include: sequence, target, set, validation,
        mut_name (optional), mut_numb (optional)
    - encoder_name: str, the name of the encoder
    - reset_param: bool = False, if update the full model to xavier_uniform_
    - resample_param: bool = False, if update the full model to xavier_normal_
    - embed_batch_size: int, set to 0 to encode all in a single batch
    - flatten_emb: bool or str, if and how (one of ["max", "mean"]) to flatten the embedding
    - embed_folder: str = None, path to presaved embedding
    - seq_start_idx: bool | int = False, the index for the start of the sequence
    - seq_end_idx: bool | int = False, the index for the end of the sequence
    - subset_list: list of str, train, val, test
    - loader_batch_size: int, the batch size for train, val, and test dataloader
    - worker_seed: int, the seed for dataloader
    - if_encode_all: bool = True, if encode full dataset all layers on the fly
    - encoder_params: kwarg, additional parameters for encoding
    """

    assert set(subset_list) <= set(
        ["train", "val", "test", "cb513", "ts115", "casp12"]
    ), "subset_list can only contain 'train', 'val', 'test', or 'cb513', 'ts115', 'casp12'"

    # specify no shuffling for validation and test
    if_shuffle_list = [True if subset == "train" else False for subset in subset_list]
    """print(f"{subset_list} in split_protrain_loader")
    print(f"{if_shuffle_list} in if_shuffle_list")
    print(f"zip {zip(subset_list, if_shuffle_list)}")

    loader_dict = {}

    for subset, if_shuffle in zip(subset_list, if_shuffle_list):
        print(subset, if_shuffle)
        loader_dict[subset] = DataLoader(
            dataset=ProtranDataset(
                dataset_path=dataset_path,
                subset=subset,
                encoder_name=encoder_name,
                reset_param=reset_param,
                resample_param=resample_param,
                embed_batch_size=embed_batch_size,
                flatten_emb=flatten_emb,
                embed_folder=embed_folder,
                embed_layer=embed_layer,
                seq_start_idx=seq_start_idx,
                seq_end_idx=seq_end_idx,
                if_encode_all=if_encode_all,
                **encoder_params,
            ),
            batch_size=loader_batch_size,
            shuffle=if_shuffle,
            worker_init_fn=worker_seed,
        )

    return loader_dict"""

    return {
        subset: DataLoader(
            dataset=ProtranDataset(
                dataset_path=dataset_path,
                subset=subset,
                encoder_name=encoder_name,
                reset_param=reset_param,
                resample_param=resample_param,
                embed_batch_size=embed_batch_size,
                flatten_emb=flatten_emb,
                embed_folder=embed_folder,
                embed_layer=embed_layer,
                seq_start_idx=seq_start_idx,
                seq_end_idx=seq_end_idx,
                if_encode_all=if_encode_all,
                **encoder_params,
            ),
            batch_size=loader_batch_size,
            shuffle=if_shuffle,
            worker_init_fn=worker_seed,
        ) for subset, if_shuffle in zip(subset_list, if_shuffle_list) 
    }

In [53]:
std_split_ssdf(split_test=False)

In [54]:
pd.read_csv("data/structure/secondary_structure/tape_ss3_processed.csv")

Unnamed: 0,sequence,target,set,validation
0,AETVESCLAKSHTENSFTNVXKDDKTLDRYANYEGCLWNATGVVVC...,"[2, 2, 2, 0, 0, 0, 0, 0, 2, 2, 2, 2, 1, 1, 1, ...",train,
1,ASQEISKSIYTCNDNQVXEVIYVNTEAGNAYAIISQVNEXIPXRLX...,"[2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, ...",train,
2,XGSSHHHHHHSSGRENLYFQGXNISEINGFEVTGFVVRTTNADEXN...,"[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...",train,
3,SNADYNRLSVPGNVIGKGGNAVVYEDAEDATKVLKMFTTSQSNEEV...,"[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, ...",train,
4,DTVGRPLPHLAAAMQASGEAVYCDDIPRYENELFLRLVTSTRAHAK...,"[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, ...",train,
...,...,...,...,...
11492,GLPVPSPPGTLLPGQSPDEAFARNSVVFLVPGAEYNWKNVVIRKPV...,"[2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 2, 2, 2, 2, ...",casp12,
11493,XNKYLFELPYERSEPGWTIRSYFDLXYNENRFLDAVENIVNKESYI...,"[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...",casp12,
11494,GHXASGPWKLTASKTHIXKSADVEKLADELHXPSLPEXXFGDNVLR...,"[2, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 1, 1, 2, 2, ...",casp12,
11495,SMWAFSELPMPLLINLIVSLLGFVATVTLIPAFRGHFIAARLCGQD...,"[2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, ...",casp12,


In [55]:
pd.read_csv("data/structure/secondary_structure/tape_ss3_processed.csv").validation.unique()

array([nan, True], dtype=object)

In [56]:
pd.read_csv("data/structure/secondary_structure/tape_ss3_processed.csv").set.unique()

array(['train', 'cb513', 'ts115', 'casp12'], dtype=object)

In [57]:
df_dict = split_df_sets(pd.read_csv("data/structure/secondary_structure/tape_ss3_processed.csv"))
df_dict.keys()

dict_keys(['train', 'val', 'casp12', 'cb513', 'ts115'])

In [58]:
df_dict["ts115"]

Unnamed: 0,sequence,target,set,validation
0,GGGLAPAEVPKGDRTAGSPPRTISPPPCQGPIEIKETFKYINTVVS...,"[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...",ts115,
1,GSLGKRPDAVSQAVSSLQGLSPEQADLVAKLKNGHLSERVLAANKL...,"[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...",ts115,
2,GEITSVSTACQQLEVFSRVLRTSLATILDGGEENLEKNLPEFAKMV...,"[2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, ...",ts115,
3,GPAAAPALDTLPAPTSLVLSQVTSSSIRLSWTPAPRHPLKYLIVWR...,"[2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 1, ...",ts115,
4,GSAFSNQTYPTIEPKPFLYVVGRKKMMDAQYKCYDRMQQLPAYQGE...,"[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...",ts115,
...,...,...,...,...
110,SSLLSRLTQSNQSKDKIIAALAKRNVYKSFAGLYDSKGKNDNTGYD...,"[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, ...",ts115,
111,GYLDTSFYKQCRSILHVVMDLDRDGIFARDPSKLPDYRMIISHPMW...,"[2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 0, 0, ...",ts115,
112,GKKEFLKHEYSPGHWSIDYTRAGTSIAVITVRNKYHYSVILNPTDC...,"[2, 2, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...",ts115,
113,SNAXKYFQIDELTLNAXLRITTIESLTPEQRLELIKAHLLNIKTPS...,"[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, ...",ts115,


In [59]:
df_dict["casp12"]

Unnamed: 0,sequence,target,set,validation
0,MHHHHHHENLYFQSHQGPEVTLITANAEGIEGGKTTIKSRSVDVGV...,"[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...",casp12,
1,MSAETVNNYDYSDWYENAAPTKAPVEVIPPCDPTADEGLFHICIAA...,"[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...",casp12,
2,ETGQVAASPSINVALKAAFPSPPYLVELLETAASDNTTIYYSLLDR...,"[2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, ...",casp12,
3,MAHHHHHHMAISPRDEQNRSVDLWFAYKVPKLTKDADSDSASGYEY...,"[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, ...",casp12,
4,MHHHHHHENLYFQTSIRTEPTYTLYATFDNIGGLKARSPVSIGGVV...,"[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...",casp12,
5,MGHHHHHHGGSENLYFQGNEDILKASATQSAVAGTYQIQVNSLATS...,"[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...",casp12,
6,SRGPSNGQSVLENSVQVKETSPRRVSVDPQTGEFVVFDRTLGDVYH...,"[2, 2, 2, 2, 2, 2, 0, 0, 0, 0, 0, 0, 2, 2, 1, ...",casp12,
7,MGAEEEDTAILYPFTISGNDRNGNFTINFKGTPNSTNNGCIGYSYN...,"[2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, ...",casp12,
8,KMAGSIVISKEVRVPVSTSQFDYLVSRIGDQFHSSDMWIKDEVYLP...,"[2, 2, 2, 1, 1, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, ...",casp12,
9,MGASSGSNISASNGSSSPTTIVASNPVDLNAFDRLNVVDPAVGKFR...,"[2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, ...",casp12,


In [60]:
df_dict["cb513"]

Unnamed: 0,sequence,target,set,validation
0,MNIFEMLRIDEGLRLKIYKDTEGYYTIGIGHLLTKSPSLNAAKSEL...,"[2, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 2, 1, 1, ...",cb513,
1,FKIETTPESRYLAQIGDSVSLTCSTTGCESPFFSWRTQIDSPLNGK...,"[2, 1, 1, 1, 1, 1, 2, 2, 2, 1, 1, 1, 1, 1, 2, ...",cb513,
2,KQKVINVKEVRLSPTIEEHDFNTKLRNARKFLEKGDKVKATIRFKG...,"[2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 2, 2, 2, ...",cb513,
3,SPLPITPVNATCAIRHPCHGNLMNQIKNQLAQLNGSANALFISYYT...,"[2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 1, 0, 0, 0, 2, ...",cb513,
4,CDAFVGTWKLVSSENFDDYMKEVGVGFATRKVAGMAKPNMIISVNG...,"[2, 0, 0, 0, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, ...",cb513,
...,...,...,...,...
508,LKAGPLLSSEKLIAIGASTGGTEAIRHVLQPLPLSSPAVIITQHMP...,"[2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 1, 1, 1, 1, ...",cb513,
509,HGGEAHMVPMDKTLKEFGADVQWDDYAQLFTLIKDGAYVKVKPGAQ...,"[2, 2, 2, 2, 2, 2, 1, 1, 1, 0, 0, 0, 0, 0, 0, ...",cb513,
510,AAEEKTEFDVILKAAGANKVAVIKAVRGATGLGLKEAKDLVESAPA...,"[2, 2, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 2, ...",cb513,
511,TKGLVLGIYSKEKEEDEPQFTSAGENFNKLVSGKLREILNISGPPL...,"[2, 2, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 2, ...",cb513,


In [4]:
"""PyTorch models"""

from __future__ import annotations

import torch
from torch import nn


class LinearRegression(nn.Module):
    """Linear regression"""

    def __init__(self, input_dim: int, output_dim: int):
        super(LinearRegression, self).__init__()
        self.linear = nn.Linear(input_dim, output_dim)

    def forward(self, x):
        return self.linear(x)

    @property
    def model_name(self) -> str:
        """Return the name of the model"""
        return "LinearRegression"


class LinearClassifier(nn.Module):
    """Linear classifier"""

    def __init__(self, input_dim: int, numb_class: int):
        """
      Args:
      - input_dim: int, 
      - numb_class: int, the number of classes
      """
        super(LinearClassifier, self).__init__()
        self.linear = nn.Linear(input_dim, numb_class)

    def forward(self, x: torch.tensor) -> torch.tensor:
        return self.linear(x)

    @property
    def model_name(self) -> str:
        """Return the name of the model"""
        return "LinearClassifier"

class MultiLabelMultiClass(nn.Module):
    """Multi label multi class"""

    def __init__(self, input_dim: int, numb_class: int) -> None:
        super(MultiLabelMultiClass, self).__init__()
        self.linear = nn.Linear(input_dim, numb_class)
        # self.softmax = nn.Softmax(dim=2)

    def forward(self, x: torch.tensor) -> torch.tensor:
        print("x in forward in model")
        print(x.dtype, x.shape)
        return self.linear(x)
        # return self.softmax(self.linear(x))

    @property
    def model_name(self) -> str:
        """Return the name of the model"""
        return "MultiLabelMultiClass"

In [10]:
np.array([[0, 2, 1, -1], [0, 2, -1, -1]])

array([[ 0,  2,  1, -1],
       [ 0,  2, -1, -1]])

In [68]:
np.delete(np.array([[0, 2, 1, -1], [0, 2, -1, -1]]).flatten(), np.where(np.array([[0, 2, 1, -1], [0, 2, -1, -1]]).flatten() == -1))

array([0, 2, 1, 0, 2])

In [None]:
t_x = torch.from_numpy()np.array([[[0, 1, 0 , 0, 1, ], [0, 1, 0, 0]]])

In [21]:
t = torch.from_numpy(np.array([[0, 2, 1, -1], [0, 2, -1, -1]]))
t.shape

torch.Size([2, 4])

In [19]:
t.flatten()

tensor([ 0,  2,  1, -1,  0,  2, -1, -1])

In [20]:
[t.flatten()!= -1]

[tensor([ True,  True,  True, False,  True,  True, False, False])]

In [74]:
t.flatten()[t.flatten()!= -1]

tensor([0, 2, 1, 0, 2])

In [20]:
"""A script with model training and testing details assuming"""

from __future__ import annotations

import random
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm

from sklearn.metrics import mean_squared_error, log_loss, accuracy_score, roc_auc_score
from sklearn.preprocessing import LabelEncoder

from scr.params.sys import RAND_SEED, DEVICE

# seed everything
random.seed(RAND_SEED)
np.random.seed(RAND_SEED)
torch.manual_seed(RAND_SEED)
torch.cuda.manual_seed(RAND_SEED)
torch.cuda.manual_seed_all(RAND_SEED)
torch.backends.cudnn.deterministic = True


def get_x_y(
    model_name, device, batch, embed_layer: int,
):
    """
    A function process x and y from the loader

    Args:
    -

    Returns:
    - x
    - y
    """

    # for each batch: y, sequence, mut_name, mut_numb, [layer0, ...]
    # or: y, seuqence, mut_name, mut_numb, layer0, ...

    embs = batch[4]
    if torch.is_tensor(embs):
        x = embs
    else:
        x = embs[embed_layer]
    y = batch[0]
    print("x, y and y.flatten() shape in get_x_y")
    print(x.shape, y.shape, y.flatten().shape)

    # ss3 / ss8 type
    if model_name == "MultiLabelMultiClass":
        y = torch.squeeze(y)
        # x = torch.cat([seq[seq.abs().sum(dim=1).bool(), :] for seq in x]).to(torch.float32)
        # non_mask = [y.flatten()!= -1]
        # print(non_mask[0].shape)

        # print(f"x[0].shape is {x[0].shape}")
        # print(f"x[0][non_mask, :] {x[0][non_mask, :]}")
        print(f"x[0], y[0] shape {x[0].shape} {y[0].shape}")
        print(f"x[0][y[0]!= -1, :] {x[0][y[0]!= -1, :]}")
        x = torch.cat([seq[ss!= -1, :] for (seq, ss) in zip(x, y)]).to(torch.float32)
        # torch.cat([torch.from_numpy(x)[torch.from_numpy(x).abs().sum(dim=1).bool(), :] for x in X])
        # non_empty_mask = x0_tensor.abs().sum(dim=1).bool()
        # x0_non_empty = x0_tensor[non_empty_mask, :]

        # le = LabelEncoder()
        # y = le.fit_transform(np.concatenate(y.apply(lambda x: np.array(x[1:-1].split(", ")).astype("int"))))
        # y = y.to(torch.float32)
        # concat and remove the padding
        y = y.flatten()[y.flatten()!= -1]
        print(f"y is now {y.shape}")

        print("after concat x and y in get_x_y")
        print(x.shape, y.shape)
        print(x.dtype, y.dtype)

    return x.to(device, non_blocking=True), y.to(device, non_blocking=True)


def run_epoch(
    model: nn.Module,
    loader: DataLoader,
    encoder_name: str,
    embed_layer: int,
    reset_param: bool = False,
    resample_param: bool = False,
    embed_batch_size: int = 0,
    flatten_emb: bool | str = False,
    # if_encode_all: bool = True,
    device: torch.device | str = DEVICE,
    criterion: nn.Module | None = None,
    optimizer: torch.optim.Optimizer | None = None,
    **encoder_params,
) -> float:

    """
    Runs one epoch.
    
    Args:
    - model: nn.Module, already moved to device
    - loader: torch.utils.data.DataLoader
    - device: torch.device or str
    - criterion: optional nn.Module, loss function, already moved to device
    - optimizer: optional torch.optim.Optimizer, must also provide criterion,
        only provided for training

    Returns: 
    - float, average loss over batches
    """
    if optimizer is not None:
        assert criterion is not None
        model.train()
        is_train = True
    else:
        model.eval()
        is_train = False

    model_name = model.model_name
    cum_loss = 0.0

    with torch.set_grad_enabled(is_train):
        # if not if_encode_all:
        # for each batch: y, sequence, mut_name, mut_numb, [layer0, ...]

        for batch in loader:

            x, y = get_x_y(model_name, device, batch, embed_layer)

            outputs = model(x)
            print(x.shape, y.shape, outputs.shape)
            print(x.dtype, y.dtype, outputs.dtype)

            if criterion is not None:
                if model_name == "LinearRegression":
                    loss = criterion(outputs, y.float())
                elif model_name == "LinearClassifier":
                    loss = criterion(outputs, y.squeeze())
                elif model_name == "MultiLabelMultiClass":
                    detached_y = y.detach().cpu().numpy()
                    binary_y = torch.from_numpy((np.arange(detached_y.max() + 1) == detached_y[..., None]))
                    print(binary_y.shape)
                    loss = criterion(
                        outputs,
                        binary_y.to(torch.float32).to(device, non_blocking=True),
                    )
                    """
                    loss = criterion(
                        outputs,
                        (torch.arange(y.max() + 1) == y[..., None]).to(torch.float32),
                    )
                    """

                if optimizer is not None:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()

                cum_loss += loss.item()

    return cum_loss / len(loader)


def train(
    model: nn.Module,
    criterion: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    encoder_name: str,
    embed_layer: int,
    reset_param: bool = False,
    resample_param: bool = False,
    embed_batch_size: int = 0,
    flatten_emb: bool | str = False,
    device: torch.device | str = DEVICE,
    learning_rate: float = 1e-4,
    lr_decay: float = 0.1,
    epochs: int = 100,
    early_stop: bool = True,
    tolerance: int = 10,
    min_epoch: int = 5,
    **encoder_params,
) -> tuple[np.ndarray, np.ndarray]:

    """
    Args:
    - model: nn.Module, already moved to device
    - train_loader: torch.utils.data.DataLoader, 
    - val_loader: torch.utils.data.DataLoader, 
    - criterion: nn.Module, loss function, already moved to device
    - device: torch.device or str
    - learning_rate: float
    - lr_decay: float, factor by which to decay LR on plateau
    - epochs: int, number of epochs to train for
    - early_stop: bool = True,

    Returns: 
    - tuple of np.ndarray, (train_losses, val_losses)
        train/val_losses: np.ndarray, shape [epochs], entries are average loss
        over batches for that epoch
    """

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=lr_decay
    )

    train_losses = np.zeros(epochs)
    val_losses = np.zeros(epochs)

    # init for early stopping
    counter = 0
    min_val_loss = np.Inf

    for epoch in tqdm(range(epochs)):

        train_losses[epoch] = run_epoch(
            model=model,
            loader=train_loader,
            encoder_name=encoder_name,
            embed_layer=embed_layer,
            reset_param=reset_param,
            resample_param=resample_param,
            embed_batch_size=embed_batch_size,
            flatten_emb=flatten_emb,
            device=device,
            criterion=criterion,
            optimizer=optimizer,
            **encoder_params,
        )

        val_loss = run_epoch(
            model=model,
            loader=val_loader,
            encoder_name=encoder_name,
            embed_layer=embed_layer,
            reset_param=reset_param,
            resample_param=resample_param,
            embed_batch_size=embed_batch_size,
            flatten_emb=flatten_emb,
            device=device,
            criterion=criterion,
            optimizer=None,
            **encoder_params,
        )
        val_losses[epoch] = val_loss

        scheduler.step(val_loss)

        if early_stop:
            # when val loss decrease, reset min loss and counter
            if val_loss < min_val_loss:
                min_val_loss = val_loss
                counter = 0
            else:
                counter += 1

            if epoch > min_epoch and counter == tolerance:
                break

    return train_losses, val_losses


def test(
    model: nn.Module,
    loader: DataLoader,
    embed_layer: int,
    criterion: nn.Module | None,
    device: torch.device | str = DEVICE,
    print_every: int = 1000,
) -> tuple[float, np.ndarray, np.ndarray]:
    """
    Runs one epoch of testing, returning predictions and labels.
    
    Args:
    - model: nn.Module, already moved to device
    - device: torch.device or str
    - loader: torch.utils.data.DataLoader
    - criterion: optional nn.Module, loss function, already moved to device
    - print_every: int, how often (number of batches) to print avg loss
    
    Returns: tuple (avg_loss, preds, labels)
    - avg_loss: float, average loss per training example 
    - preds: np.ndarray, shape [num_examples, ...], predictions over dataset
    - labels: np.ndarray, shape [num_examples, ...], dataset labels
    """
    model.eval()
    msg = "[{step:5d}] loss: {loss:.3f}"

    model_name = model.model_name
    cum_loss = 0.0

    pred_probs = []
    pred_classes = []
    labels = []

    with torch.no_grad():

        for i, batch in enumerate(tqdm(loader)):
            # for each batch: y, sequence, mut_name, mut_numb, [layer0, ...]
            x, y = get_x_y(model_name, device, batch, embed_layer)

            # forward + backward + optimize
            outputs = model(x)

            # append results
            labels.append(y.detach().cpu().squeeze().numpy())

            # append class
            if model_name == "LinearClassifier":
                pred_classes.append(
                    outputs.detach()
                    .cpu()
                    .data.max(1, keepdim=True)[1]
                    .squeeze()
                    .numpy()
                )
            elif model_name == "MultiLabelMultiClass":
                pred_classes.append(
                    torch.argmax(outputs, dim=2).detach().cpu().squeeze().numpy()
                )

            pred_probs.append(outputs.detach().cpu().squeeze().numpy())

            if criterion is not None:
                if model_name == "LinearRegression":
                    loss = criterion(outputs, y)
                elif model_name == "LinearClassifier":
                    loss = criterion(outputs, y.squeeze())
                elif model_name == "MultiLabelMultiClass":
                    detached_y = y.detach().cpu().numpy()
                    binary_y = torch.from_numpy((np.arange(detached_y.max() + 1) == detached_y[..., None]))
                    loss = criterion(
                        outputs,
                        binary_y.to(torch.float32).to(device, non_blocking=True),
                    )
                cum_loss += loss.item()

                if ((i + 1) % print_every == 0) or (i + 1 == len(loader)):
                    tqdm.write(msg.format(step=i + 1, loss=cum_loss / len(loader)))

    avg_loss = cum_loss / len(loader)

    if pred_classes == []:
        pred_classes_conc = pred_classes
    else:
        pred_classes_conc = np.concatenate(pred_classes)

    return (
        avg_loss,
        np.concatenate(pred_probs),
        pred_classes_conc,
        np.concatenate(labels),
    )

In [21]:
"""Script for running pytorch models"""

from __future__ import annotations

import os
from tqdm import tqdm
from concurrent import futures

import numpy as np

import torch
import torch.nn as nn

from sklearn.metrics import ndcg_score, accuracy_score, roc_auc_score

from scipy.stats import spearmanr

from scr.params.aa import AA_NUMB
from scr.params.sys import RAND_SEED, DEVICE
from scr.params.emb import TRANSFORMER_INFO, CARP_INFO, MAX_SEQ_LEN

# from scr.preprocess.data_process import split_protrain_loader, DatasetInfo
from scr.encoding.encoding_classes import (
    OnehotEncoder,
    ESMEncoder,
    CARPEncoder,
    get_emb_info,
)
from scr.model.pytorch_model import (
    LinearRegression,
    LinearClassifier,
    # MultiLabelMultiClass,
)

# from scr.model.train_test import train, test
from scr.vis.learning_vis import plot_lc
from scr.utils import get_folder_file_names, pickle_save, get_default_output_path


class Run_Pytorch:
    def __init__(
        self,
        dataset_path: str,
        encoder_name: str,
        reset_param: bool = False,
        resample_param: bool = False,
        embed_batch_size: int = 128,
        flatten_emb: bool | str = False,
        embed_folder: str | None = None,
        seq_start_idx: bool | int = False,
        seq_end_idx: bool | int = False,
        loader_batch_size: int = 64,
        worker_seed: int = RAND_SEED,
        if_encode_all: bool = True,
        if_multiprocess: bool = False,
        learning_rate: float = 1e-4,
        lr_decay: float = 0.1,
        epochs: int = 100,
        early_stop: bool = True,
        tolerance: int = 10,
        min_epoch: int = 5,
        device: torch.device | str = DEVICE,
        all_plot_folder: str = "results/learning_curves",
        all_result_folder: str = "results/train_val_test",
        **encoder_params,
    ) -> None:

        """
        A function for running pytorch model

        Args:
        - dataset_path: str, full path to the dataset, in pkl or panda readable format
            columns include: sequence, target, set, validation, mut_name (optional), mut_numb (optional)
        - encoder_name: str, the name of the encoder
        
        - embed_batch_size: int, set to 0 to encode all in a single batch
        - flatten_emb: bool or str, if and how (one of ["max", "mean"]) to flatten the embedding
        - embed_folder: str = None, path to presaved embedding
        - seq_start_idx: bool | int = False, the index for the start of the sequence
        - seq_end_idx: bool | int = False, the index for the end of the sequence
        - loader_batch_size: int, the batch size for train, val, and test dataloader
        - worker_seed: int, the seed for dataloader
        - learning_rate: float
        - lr_decay: float, factor by which to decay LR on plateau
        - epochs: int, number of epochs to train for
        - device: torch.device or str
        - all_plot_folder: str, the parent folder path for saving all the learning curves
        - all_result_folder: str = "results/train_val_test", the parent folder for all results
        - encoder_params: kwarg, additional parameters for encoding

        Returns:
        - result_dict: dict, with the keys and dict values
            "losses": {"train_losses": np.ndarray, "val_losses": np.ndarray}
            "train": {"mse": float, 
                    "pred": np.ndarray,
                    "true": np.ndarray,
                    "ndcg": float,
                    "rho": SpearmanrResults(correlation=float, pvalue=float)}
            "val":   {"mse": float, 
                    "pred": np.ndarray,
                    "true": np.ndarray,
                    "ndcg": float,
                    "rho": SpearmanrResults(correlation=float, pvalue=float)}
            "test":  {"mse": float, 
                    "pred": np.ndarray,
                    "true": np.ndarray,
                    "ndcg": float,
                    "rho": SpearmanrResults(correlation=float, pvalue=float)}

        """

        self._dataset_path = dataset_path
        self._encoder_name = encoder_name

        if self._encoder_name not in (
            list(TRANSFORMER_INFO.keys()) + list(CARP_INFO.keys())
        ):
            self._encoder_name = "onehot"

        self._reset_param = reset_param
        self._resample_param = resample_param
        self._embed_batch_size = embed_batch_size
        self._flatten_emb = flatten_emb

        self._learning_rate = learning_rate
        self._lr_decay = lr_decay
        self._epochs = epochs
        self._early_stop = early_stop
        self._tolerance = tolerance
        self._min_epoch = min_epoch
        self._device = device
        self._all_plot_folder = all_plot_folder
        self._all_result_folder = all_result_folder
        self._encoder_params = encoder_params

        self._ds_info = DatasetInfo(self._dataset_path)
        self._model_type = self._ds_info.model_type
        self._numb_class = self._ds_info.numb_class
        self._subset_list = self._ds_info.subset_list
        
        print(f"This dataset includes subsets: {self._subset_list}...")

        # self._train_loader, self._val_loader, self._test_loader 
        self._loader_dict = split_protrain_loader(
            dataset_path=self._dataset_path,
            encoder_name=self._encoder_name,
            reset_param=self._reset_param,
            resample_param=self._resample_param,
            embed_batch_size=self._embed_batch_size,
            flatten_emb=self._flatten_emb,
            embed_folder=embed_folder,
            seq_start_idx=seq_start_idx,
            seq_end_idx=seq_end_idx,
            subset_list=self._subset_list,
            loader_batch_size=loader_batch_size,
            worker_seed=worker_seed,
            if_encode_all=if_encode_all,
            **encoder_params,
        )

        encoder_name, encoder_class, total_emb_layer = get_emb_info(encoder_name)

        if encoder_class == ESMEncoder:
            self._encoder_info_dict = TRANSFORMER_INFO
        elif encoder_class == CARPEncoder:
            self._encoder_info_dict = CARP_INFO
        elif encoder_class == OnehotEncoder:
            # TODO aultoto
            if self._flatten_emb == False:
                self._encoder_info_dict = {"onehot": (AA_NUMB,)}
            else:
                self._encoder_info_dict = {"onehot": (MAX_SEQ_LEN * 22,)}

        if if_multiprocess:
            print("Running different emb layer in parallel...")
            # add the thredpool max_workers=None
            with futures.ProcessPoolExecutor(max_workers=os.cpu_count() - 1) as pool:
                # for each layer train the model and save the model
                for embed_layer in tqdm(range(total_emb_layer)):
                    pool.submit(self.run_pytorch_layer, embed_layer)

        else:
            for embed_layer in range(total_emb_layer):
                print(f"Running pytorch model for layer {embed_layer}")
                self.run_pytorch_layer(embed_layer)

    def run_pytorch_layer(self, embed_layer):

        # init model based on datasets
        if self._model_type == "LinearRegression":
            model = LinearRegression(
                input_dim=self._encoder_info_dict[self._encoder_name][0], output_dim=1
            )
            criterion = nn.MSELoss()

        elif self._model_type == "LinearClassifier":
            model = LinearClassifier(
                input_dim=self._encoder_info_dict[self._encoder_name][0],
                numb_class=self._numb_class,
            )
            criterion = nn.CrossEntropyLoss()

        elif self._model_type == "MultiLabelMultiClass":
            model = MultiLabelMultiClass(
                input_dim=self._encoder_info_dict[self._encoder_name][0],
                numb_class=self._numb_class,
            )
            criterion = nn.CrossEntropyLoss()
            # criterion = nn.BCELoss()

        model_name = model.model_name
        model.to(self._device, non_blocking=True)
        criterion.to(self._device, non_blocking=True)
        # print("in layer before train")

        train_losses, val_losses = train(
            model=model,
            criterion=criterion,
            train_loader=self._loader_dict["train"],
            val_loader=self._loader_dict["val"],
            encoder_name=self._encoder_name,
            embed_layer=embed_layer,
            reset_param=self._reset_param,
            resample_param=self._resample_param,
            embed_batch_size=self._embed_batch_size,
            flatten_emb=self._flatten_emb,
            device=self._device,
            learning_rate=self._learning_rate,
            lr_decay=self._lr_decay,
            epochs=self._epochs,
            early_stop=self._early_stop,
            tolerance=self._tolerance,
            min_epoch=self._min_epoch,
            **self._encoder_params,
        )

        # record the losses
        result_dict = {
            "losses": {"train_losses": train_losses, "val_losses": val_losses}
        }

        if self._flatten_emb == False:
            flatten_emb_name = "noflatten"

        plot_lc(
            train_losses=train_losses,
            val_losses=val_losses,
            dataset_path=self._dataset_path,
            encoder_name=self._encoder_name,
            embed_layer=embed_layer,
            flatten_emb=flatten_emb_name,
            all_plot_folder=get_default_output_path(self._all_plot_folder),
        )

        # now test the model with the test data
        for subset, loader_key in zip(self._subset_list, self._loader_dict.keys()):

            loss, pred, cls, true = test(
                model=model,
                loader=self._loader_dict[loader_key],
                embed_layer=embed_layer,
                device=self._device,
                criterion=criterion,
            )

            if model_name == "LinearRegression":
                result_dict[subset] = {
                    "mse": loss,
                    "pred": pred,
                    "true": true,
                    "ndcg": ndcg_score(true[None, :], pred[None, :]),
                    "rho": spearmanr(true, pred),
                }

            elif model_name == "LinearClassifier":
                result_dict[subset] = {
                    "cross-entropy": loss,
                    "pred": pred,
                    "true": true,
                    "acc": accuracy_score(true, cls),
                    "rocauc": roc_auc_score(
                        true,
                        nn.Softmax(dim=1)(torch.from_numpy(pred)).numpy(),
                        multi_class="ovr",
                    )
                    # "rocauc": eval_rocauc(true, pred),
                }

            elif model_name == "MultiLabelMultiClass":
                scaled_pred = nn.Softmax(dim=2)(torch.from_numpy(pred)).numpy()
                result_dict[subset] = {
                    "bceloss": loss,
                    "pred": pred,
                    "true": true,
                    "acc": accuracy_score(true.flatten(), cls.flatten()),
                    "rocauc": roc_auc_score(
                        true.flatten(),
                        scaled_pred.reshape(-1, scaled_pred.shape[-1]),
                        multi_class="ovr",
                    )
                }

        dataset_subfolder, file_name = get_folder_file_names(
            parent_folder=get_default_output_path(self._all_result_folder),
            dataset_path=self._dataset_path,
            encoder_name=self._encoder_name,
            embed_layer=embed_layer,
            flatten_emb=flatten_emb_name,
        )

        print(f"Saving results for {file_name} to: {dataset_subfolder}...")
        pickle_save(
            what2save=result_dict,
            where2save=os.path.join(dataset_subfolder, file_name + ".pkl"),
        )

In [22]:
Run_Pytorch(
    dataset_path="data/structure/secondary_structure/tape_ss3_processed.csv",
    encoder_name="",
    reset_param = False,
    resample_param = False,
    embed_batch_size = 4,
    flatten_emb= False,
    embed_folder=None,
    seq_start_idx= False,
    seq_end_idx = False,
    loader_batch_size = 4,
    worker_seed = RAND_SEED,
    if_encode_all = True,
    learning_rate = 1e-4,
    lr_decay = 0.1,
    epochs = 1,
    early_stop = True,
    tolerance = 10,
    min_epoch = 5,
    device = DEVICE,
    all_plot_folder = "test/learning_curves",
    all_result_folder = "test/pytorch",
    # **encoder_params,
    )

This dataset includes subsets: ['train', 'val', 'cb513', 'ts115', 'casp12']...


  9%|▊         | 188/2170 [00:00<00:01, 1876.70it/s]

Generating onehot upto 0 layer embedding ...
Encoding all...


100%|██████████| 2170/2170 [00:01<00:00, 1554.61it/s]


Converting ss3/ss8 into np.array and pad -1...
y shape in dataset is (8678, 1632)


 34%|███▎      | 182/543 [00:00<00:00, 1815.28it/s]

Generating onehot upto 0 layer embedding ...
Encoding all...


100%|██████████| 543/543 [00:00<00:00, 1227.34it/s]


Converting ss3/ss8 into np.array and pad -1...
y shape in dataset is (2170, 1632)


100%|██████████| 129/129 [00:00<00:00, 1829.00it/s]

Generating onehot upto 0 layer embedding ...
Encoding all...
Converting ss3/ss8 into np.array and pad -1...
y shape in dataset is (513, 1632)



100%|██████████| 29/29 [00:00<00:00, 1702.64it/s]

Generating onehot upto 0 layer embedding ...
Encoding all...
Converting ss3/ss8 into np.array and pad -1...
y shape in dataset is (115, 1632)



100%|██████████| 6/6 [00:00<00:00, 1688.87it/s]
  0%|          | 0/1 [00:00<?, ?it/s]

Generating onehot upto 0 layer embedding ...
Encoding all...
Converting ss3/ss8 into np.array and pad -1...
y shape in dataset is (21, 1632)
Running pytorch model for layer 0
x, y and y.flatten() shape in get_x_y
torch.Size([4, 1632, 23]) torch.Size([4, 1, 1632]) torch.Size([6528])
x[0], y[0] shape torch.Size([1632, 23]) torch.Size([1632])
x[0][y[0]!= -1, :] tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [1., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]], dtype=torch.float64)
y is now torch.Size([885])
after concat x and y in get_x_y
torch.Size([885, 23]) torch.Size([885])
torch.float32 torch.int64
x in forward in model
torch.float32 torch.Size([885, 23])
torch.Size([885, 23]) torch.Size([885]) torch.Size([885, 3])
torch.float32 torch.int64 torch.float32
torch.Size([885, 3])





RuntimeError: 1D target tensor expected, multi-target not supported