In [1]:

from google.colab import drive
import sys
drive.mount('/content/drive')

# Add necessary directories to sys.path
sys.path.append('/content/drive/My Drive/AbelBioToken-main/AbelBioToken-main/data')
sys.path.append('/content/drive/My Drive/AbelBioToken-main/AbelBioToken-main/model')
sys.path.append('/content/drive/My Drive/AbelBioToken-main/AbelBioToken-main/train')
sys.path.append('/content/drive/My Drive/AbelBioToken-main/AbelBioToken-main/test')


Mounted at /content/drive


In [2]:
from enum import Enum, auto
import os
import csv
import torch
import context
from torch.utils.data import Dataset, DataLoader, dataloader
from torch.nn.utils.rnn import pad_sequence # Utility for padding sequences to the same length.
# These are used to pad sequences to a consistent length across batches.
_max_tok_len = 18
_max_lab_len = 10

# An enumeration for categorizing datasets into training, validation, and testing sets.
class DataCtg(Enum):
    TRAIN = auto()
    VAL = auto()
    TEST = auto()


class AminoDataSet(Dataset):
    def __init__(
        self,
        root_path,
        ctg=DataCtg.TRAIN,
    ):
        super(AminoDataSet, self).__init__()
        self.root_path = root_path
        self.ctg = ctg
        self.label_dict = dict()
        self.labels = list()
        self.tokens = list()

        with open(os.path.join(self.root_path, "target_vocab.csv")) as file: # contains mappings from label tokens (e.g., amino acids) to numerical indices
            data_reader = csv.reader(file)
            for row in data_reader:
                self.label_dict[row[0]] = row[1]

        if self.ctg == DataCtg.TRAIN:
            file_path = os.path.join(self.root_path, "train_data.csv")
        elif self.ctg == DataCtg.VAL:
            file_path = os.path.join(self.root_path, "val_data.csv")
        elif self.ctg == DataCtg.TEST:
            file_path = os.path.join(self.root_path, "test_data.csv")
        else:
            raise ValueError("There only train, validation and test dataset.")

        with open(file_path, mode="r") as in_file:
            data_reader = csv.reader(in_file)
            for row in data_reader:
                self.labels.append(self.label2sequence(row[0]))
                self.tokens.append(self.sentence2token(row[1:]))

                # global _max_lab_len
                # if len(self.labels[-1]) > _max_lab_len:
                #     _max_lab_len = len(self.labels[-1])
                # global _max_tok_len
                # if len(self.tokens[-1]) > _max_tok_len:
                #     _max_tok_len = len(self.tokens[-1])
            self.length = len(self.labels)

# Returns the label and token tuple at the specified index.
    def __getitem__(self, index):
        label = [int(x) for x in self.labels[index]]
        token = [int(x) for x in self.tokens[index]]
        return label, token

    def __len__(self):
        return self.length

    def sentence2token(self, sentence):
        token = list()
        for word in sentence:
            if word == "<SOS>":
                token.append(1400)
            elif word == "<NOS>":
                token.append(1401)
            elif word == "<EOS>":
                token.append(1402)
            else:
                token.append(int(word))

        return token

    def label2sequence(self, label):
        result = list()
        result.append(self.label_dict["<SOS>"])
        for amino in label:
            result.append(self.label_dict[amino])
        result.append(self.label_dict["<EOS>"])

        return result

# Custom function to collate data into batches ( list of (label, token) tuples) and pad sequences to the same length.
def collate_batch(batch):
    labels = [item[0] for item in batch]
    tokens = [item[1] for item in batch]
    global _max_tok_len
    global _max_lab_len

# Checks if each label sequence is shorter than _max_lab_len. If so, extends the label sequence with the padding index 24 (assumed to be the padding token for labels).
# Similarly, pads token sequences with the index 1403 (assumed to be the padding token for tokens).
    for label in labels:
        if len(label) < _max_lab_len:
            label.extend([24] * int(_max_lab_len - len(label)))
    for token in tokens:
        if len(token) < _max_tok_len:
            token.extend([1403] * (_max_tok_len - len(token)))
    return torch.tensor(labels, dtype=torch.int), torch.tensor(tokens, dtype=torch.int)


def get_dataloader(ctg: DataCtg, batch_size=32):

    dataset = AminoDataSet(os.path.join(context.parent_dir, "data"), ctg)
    dataloader = DataLoader(
        dataset, batch_size=batch_size, shuffle=True, collate_fn=collate_batch
    )
    return dataloader, _max_lab_len, _max_tok_len


if __name__ == "__main__":
    dataloader, lab_len, tok_len = get_dataloader(DataCtg.TRAIN, 16)

    for label, token in dataloader:
        print("label: ", label, " token: ", token)

    print(lab_len, tok_len)


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
        [ 0, 17,  5, 18, 12,  1, 24, 24, 24, 24],
        [ 0,  5, 19,  6,  1, 24, 24, 24, 24, 24],
        [ 0, 18, 23,  7, 10,  1, 24, 24, 24, 24],
        [ 0, 10, 16,  8,  4,  1, 24, 24, 24, 24],
        [ 0,  4,  5, 10,  6,  1, 24, 24, 24, 24],
        [ 0,  9, 21,  7, 14,  1, 24, 24, 24, 24],
        [ 0,  2, 10, 20, 13,  1, 24, 24, 24, 24],
        [ 0, 14, 10,  3, 14,  1, 24, 24, 24, 24],
        [ 0, 21,  4, 21,  1, 24, 24, 24, 24, 24],
        [ 0, 15, 16,  1, 24, 24, 24, 24, 24, 24]], dtype=torch.int32)  token:  tensor([[1400,  909, 1057, 1205, 1401, 1025, 1402, 1403, 1403, 1403, 1403, 1403,
         1403, 1403, 1403, 1403, 1403, 1403],
        [1400,  772,  960, 1147, 1401,  927, 1402, 1403, 1403, 1403, 1403, 1403,
         1403, 1403, 1403, 1403, 1403, 1403],
        [1400, 1055, 1243, 1312, 1401, 1210, 1402, 1403, 1403, 1403, 1403, 1403,
         1403, 1403, 1403, 1403, 1403, 1403],
        [1400,  746,  934