In [1]:
%cd ~/repo/protein-transfer

/home/t-fli/repo/protein-transfer


In [2]:
%load_ext blackcellmagic

In [9]:
import os
import numpy as np
import pandas as pd

import torch
from torch.utils.data import Dataset, DataLoader

In [5]:
from scr.utils import pickle_load
from scr.preprocess.data_process import TaskProcess

In [4]:
# Show the table of task / dataset / split entries with their csv, fasta and pkl paths
TaskProcess().sum_file_df

Unnamed: 0,task,dataset,split,csv_path,fasta_path,pkl_path
0,annotation,scl,balanced,data/annotation/scl/balanced.csv,,
1,structure,secondary_structure,cb513,[data/structure/secondary_structure/secondary_...,,
2,proeng,gb1,two_vs_rest,data/proeng/gb1/two_vs_rest.csv,data/proeng/gb1/5LDE_1.fasta,data/proeng/gb1/two_vs_rest.pkl
3,proeng,gb1,low_vs_high,data/proeng/gb1/low_vs_high.csv,data/proeng/gb1/5LDE_1.fasta,data/proeng/gb1/low_vs_high.pkl
4,proeng,aav,one_vs_many,data/proeng/aav/one_vs_many.csv,data/proeng/aav/P03135.fasta,data/proeng/aav/one_vs_many.pkl
5,proeng,aav,two_vs_many,data/proeng/aav/two_vs_many.csv,data/proeng/aav/P03135.fasta,data/proeng/aav/two_vs_many.pkl
6,proeng,thermo,mixed,data/proeng/thermo/mixed.csv,,


In [3]:
from scr.preprocess.data_process import ProtranDataset

In [4]:
# Build the "val" subset from the GB1 two_vs_rest pickle,
# then display how many items it holds together with its first item
dataset_path="data/proeng/gb1/two_vs_rest.pkl"
val_dataset = ProtranDataset(dataset_path=dataset_path, subset="val")
len(val_dataset), val_dataset[0]

(43,
 ('MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGVNGEWTYDDATKTFTVTELEVLFQGPLDPNSMATYEVLCEVARKLGTDDREVVLFLLNVFIPQPTLAQLIGALRALKEEGRLTFPLLAECLFRAGRRDLLRDLLHLDPRFLERHLAGTMSYFSPYQLTVLHVDGELCARDIRSLIFLSKDTIGSRSTPQTFLHWVYCMENLDLLGPTDVDALMSMLRSLSRVDLQRQVQTLMGLHLSGPSHSQHYRHTPLEHHHHHH',
  1.80292279874,
  'D40N',
  1))

In [None]:
# Scratch check of the split sizes, done directly on the loaded table
df = pickle_load("data/proeng/gb1/two_vs_rest.pkl")

# "train" rows are divided by the validation flag, "test" rows are kept as they are
df_train = df[df["set"].eq("train") & df["validation"].ne(True)]
df_val = df[df["set"].eq("train") & df["validation"].eq(True)]
df_test = df[df["set"].eq("test")]

# number of rows in train, val, test and the full table
tuple(map(len, (df_train, df_val, df_test, df)))

In [74]:
class ProtranDataset(Dataset):

    """
    A dataset class for processing protein transfer data

    Loads one table, splits it into train / val / test with the "set" and
    "validation" columns, and serves the rows of the requested subset as
    (sequence, target, mut_name, mut_numb) tuples.
    """

    def __init__(self, dataset_path: str, subset: str):

        """
        Args:
        - dataset_path: str, full path to the dataset, in pkl or panda readable format
            columns include: sequence, target, set, validation,
            mut_name (optional), mut_numb (optional)
        - subset: str, train, val, test

        Raises:
        - AssertionError: if the set or validation column is missing,
            or if subset is not one of train, val, test
        - KeyError: if the sequence or target column is missing
        """

        # kept for error messages only
        self._dataset_path = dataset_path

        # a pkl (or extension-less) file may carry mut_name and mut_numb;
        # the extension is compared case-insensitively
        # NOTE(review): pickle_load presumably unpickles the file; unpickling
        # can execute arbitrary code, so only load trusted files
        if os.path.splitext(dataset_path)[-1].lower() in (".pkl", ""):
            self._df = pickle_load(dataset_path)
            self._add_mut_info = True
        # anything else is read as csv and the mutation columns are not used
        else:
            self._df = pd.read_csv(dataset_path)
            self._add_mut_info = False

        assert "set" in self._df.columns, f"set is not a column in {dataset_path}"
        assert (
            "validation" in self._df.columns
        ), f"validation is not a column in {dataset_path}"

        # `== True` / `!= True` are elementwise on purpose: every train row
        # whose validation value is not True (False, NaN, ...) stays in train
        in_train_set = self._df["set"] == "train"
        self._df_train = self._df.loc[in_train_set & (self._df["validation"] != True)]
        self._df_val = self._df.loc[in_train_set & (self._df["validation"] == True)]
        self._df_test = self._df.loc[self._df["set"] == "test"]

        self._df_dict = {
            "train": self._df_train,
            "val": self._df_val,
            "test": self._df_test,
        }

        assert subset in list(
            self._df_dict.keys()
        ), f"subset can only be 'train', 'val', or 'test', got {subset!r}"
        self._subset = subset

        self._subdf_len = len(self._df_dict[self._subset])

        # get unencoded string of input sequence
        # will need to convert data type
        # self.x = torch.tensor(x,dtype=torch.float32)
        self.x = self._get_column_value("sequence")

        # get and format the fitness or secondary structure values
        # can be numbers or string
        # will need to convert data type
        self.y = self._get_column_value("target")

        # optional columns: fall back to placeholders of the right length
        # instead of leaving None behind when a column is absent
        self.mut_name = self._get_optional_column_value("mut_name", "")
        self.mut_numb = self._get_optional_column_value("mut_numb", np.nan)

    def __len__(self):
        """Return the length of the selected subset of the dataframe"""
        return self._subdf_len

    def __getitem__(self, idx: int):

        """
        Return the item in the order of
        sequence (x), target (y), mut_name (optional), mut_numb (optional)

        mut_name and mut_numb are "" and nan when the mutation info is not loaded
        """

        return self.x[idx], self.y[idx], self.mut_name[idx], self.mut_numb[idx]

    def _get_column_value(self, column_name: str) -> np.ndarray:
        """
        Check and return the column values of the selected dataframe subset

        Raises:
        - KeyError: if column_name is not a column of the loaded dataframe,
            so a missing column fails here rather than later as a None lookup
        """
        if column_name not in self._df.columns:
            raise KeyError(
                f"{column_name} is not a column in {self._dataset_path}"
            )

        return self._df_dict[self._subset][column_name].values

    def _get_optional_column_value(self, column_name: str, fill_value):
        """
        Return the values of an optional column of the selected subset, or
        a list of fill_value with the subset length when the mutation info
        is not loaded or the column is absent
        """
        if self._add_mut_info and column_name in self._df.columns:
            return self._get_column_value(column_name)

        return [fill_value] * self._subdf_len

    @property
    def df_full(self) -> pd.DataFrame:
        """Return the full loaded dataset"""
        return self._df

    @property
    def df_train(self) -> pd.DataFrame:
        """Return the dataset for training only"""
        return self._df_train

    @property
    def df_val(self) -> pd.DataFrame:
        """Return the dataset for validation only"""
        return self._df_val

    @property
    def df_test(self) -> pd.DataFrame:
        """Return the dataset for testing only"""
        return self._df_test

In [64]:
# Rebuild the "val" subset of the GB1 two_vs_rest pickle; the next cells inspect it
dataset_path="data/proeng/gb1/two_vs_rest.pkl"
val_dataset = ProtranDataset(dataset_path=dataset_path, subset="val")

In [65]:
# Number of items in the subset and the first item obtained by indexing the dataset
len(val_dataset), val_dataset[0]

(43,
 ('MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGVNGEWTYDDATKTFTVTELEVLFQGPLDPNSMATYEVLCEVARKLGTDDREVVLFLLNVFIPQPTLAQLIGALRALKEEGRLTFPLLAECLFRAGRRDLLRDLLHLDPRFLERHLAGTMSYFSPYQLTVLHVDGELCARDIRSLIFLSKDTIGSRSTPQTFLHWVYCMENLDLLGPTDVDALMSMLRSLSRVDLQRQVQTLMGLHLSGPSHSQHYRHTPLEHHHHHH',
  1.80292279874,
  'D40N',
  1))

In [66]:
# Read the first record through the x / y / mut_name / mut_numb attributes
# instead of indexing the dataset; the recorded output matches val_dataset[0]
val_dataset.x[0], val_dataset.y[0], val_dataset.mut_name[0], val_dataset.mut_numb[0]

('MQYKLILNGKTLKGETTTEAVDAATAEKVFKQYANDNGVNGEWTYDDATKTFTVTELEVLFQGPLDPNSMATYEVLCEVARKLGTDDREVVLFLLNVFIPQPTLAQLIGALRALKEEGRLTFPLLAECLFRAGRRDLLRDLLHLDPRFLERHLAGTMSYFSPYQLTVLHVDGELCARDIRSLIFLSKDTIGSRSTPQTFLHWVYCMENLDLLGPTDVDALMSMLRSLSRVDLQRQVQTLMGLHLSGPSHSQHYRHTPLEHHHHHH',
 1.80292279874,
 'D40N',
 1)

In [67]:
# Same check on a CSV dataset; the recorded output shows '' and nan
# in the mut_name / mut_numb slots of the first item
dataset_path = "data/annotation/scl/balanced.csv"
val_dataset = ProtranDataset(dataset_path=dataset_path, subset="val")
len(val_dataset), val_dataset[0]


(1678,
 ('MATPSAAFEALMNGVTSWDVPEDAVPCELLLIGEASFPVMVNDMGQVLIAASSYGRGRLVVVSHEDYLVEAQLTPFLLNAVGWLCSSPGAPIGVHPSLAPLAKILEGSGVDAKVEPEVKDSLGVYCIDAYNETMTEKLVKFMKCGGGLLIGGQAWDWANQGEDERVLFTFPGNLVTSVAGIYFTDNKGDTSFFKVSKKMPKIPVLVSCEDDLSDDREELLHGISELDISNSDCFPSQLLVHGALAFPLGLDSYHGCVIAAARYGRGRVVVTGHKVLFTVGKLGPFLLNAVRWLDGGRRGKVVVQTELRTLSGLLAVGGIDTSIEPNLTSDASVYCFEPVSEVGVKELQEFVAEGGGLFVGAQAWWWAFKNPGVSPLARFPGNLLLNPFGISITSQSLNPGPFRTPKAGIRTYHFRSTLAEFQVIMGRKRGNVEKGWLAKLGPDGAAFLQIPAEEIPAYMSVHRLLRKLLSRYRLPVATRENPVINDCCRGAMLSLATGLAHSGSDLSLLVPEIEDMYSSPYLRPSESPITVEVNCTNPGTRYCWMSTGLYIPGRQIIEVSLPEAAASADLKIQIGCHTDDLTRASKLFRGPLVINRCCLDKPTKSITCLWGGLLYIIVPQNSKLGSVPVTVKGAVHAPYYKLGETTLEEWKRRIQENPGPWGELATDNIILTVPTANLRTLENPEPLLRLWDEVMQAVARLGAEPFPLRLPQRIVADVQISVGWMHAGYPIMCHLESVQELINEKLIRTKGLWGPVHELGRNQQRQEWEFPPHTTEATCNLWCVYVHETVLGIPRSRANIALWPPVREKRVRIYLSKGPNVKNWNAWTALETYLQLQEAFGWEPFIRLFTEYRNQTNLPTENVDKMNLWVKMFSHQVQKNLAPFFEAWAWPIQKEVATSLAYLPEWKENIMKLYLLTQMPH',
  'Cell membrane',
  '',
  nan))

In [68]:
# Read the first record of the CSV-backed subset through the
# x / y / mut_name / mut_numb attributes
val_dataset.x[0], val_dataset.y[0], val_dataset.mut_name[0], val_dataset.mut_numb[0]

('MATPSAAFEALMNGVTSWDVPEDAVPCELLLIGEASFPVMVNDMGQVLIAASSYGRGRLVVVSHEDYLVEAQLTPFLLNAVGWLCSSPGAPIGVHPSLAPLAKILEGSGVDAKVEPEVKDSLGVYCIDAYNETMTEKLVKFMKCGGGLLIGGQAWDWANQGEDERVLFTFPGNLVTSVAGIYFTDNKGDTSFFKVSKKMPKIPVLVSCEDDLSDDREELLHGISELDISNSDCFPSQLLVHGALAFPLGLDSYHGCVIAAARYGRGRVVVTGHKVLFTVGKLGPFLLNAVRWLDGGRRGKVVVQTELRTLSGLLAVGGIDTSIEPNLTSDASVYCFEPVSEVGVKELQEFVAEGGGLFVGAQAWWWAFKNPGVSPLARFPGNLLLNPFGISITSQSLNPGPFRTPKAGIRTYHFRSTLAEFQVIMGRKRGNVEKGWLAKLGPDGAAFLQIPAEEIPAYMSVHRLLRKLLSRYRLPVATRENPVINDCCRGAMLSLATGLAHSGSDLSLLVPEIEDMYSSPYLRPSESPITVEVNCTNPGTRYCWMSTGLYIPGRQIIEVSLPEAAASADLKIQIGCHTDDLTRASKLFRGPLVINRCCLDKPTKSITCLWGGLLYIIVPQNSKLGSVPVTVKGAVHAPYYKLGETTLEEWKRRIQENPGPWGELATDNIILTVPTANLRTLENPEPLLRLWDEVMQAVARLGAEPFPLRLPQRIVADVQISVGWMHAGYPIMCHLESVQELINEKLIRTKGLWGPVHELGRNQQRQEWEFPPHTTEATCNLWCVYVHETVLGIPRSRANIALWPPVREKRVRIYLSKGPNVKNWNAWTALETYLQLQEAFGWEPFIRLFTEYRNQTNLPTENVDKMNLWVKMFSHQVQKNLAPFFEAWAWPIQKEVATSLAYLPEWKENIMKLYLLTQMPH',
 'Cell membrane',
 '',
 nan)

In [73]:
# NOTE(review): this is the same expression as the previous cell, yet the recorded
# output differs (another sequence, target 'neutral'), so val_dataset was rebound
# by a cell that is not present here -- confirm which dataset this cell inspects
val_dataset.x[0], val_dataset.y[0], val_dataset.mut_name[0], val_dataset.mut_numb[0]

('MSFTLTNKNVIFVAGLGGIGLDTSKELLKRDLKNLVILDRIENPAAIAELKAINPKVTVTFYPYDVTVPIAETTKLLKTIFAQLKTVDVLINGAGILDDHQIERTIAVNYTGLVNTTTAILDFWDKRKGGPGGIICNIGSVTGFNAIYQVPVYSGTKAAVVNFTSSLAKLAPITGVTAYTVNPGITRTTLVHKFNSWLDVEPQVAEKLLAHPTQPSLACAENFVKAIELNQNGAIWKLDLGTLEAIQWTKHWDSGI',
 'neutral',
 '',
 nan)