In [22]:
import torch
from torch.utils.data import Dataset, DataLoader
import pickle
import re

In [21]:
# Load the ADAGRASIB SMILES strings, one entry per line of the text file.
# Only the line terminator is stripped, so a final line that lacks a trailing
# newline is kept instead of being silently dropped.
with open('data/ADAGRASIB_SMILES.txt', 'r') as file:
    smi_list = [line.rstrip('\n') for line in file]

In [33]:
class MyDataset(Dataset):
    """Dataset of SMILES strings encoded as padded integer token sequences.

    On construction the file at ``path`` is read, every SMILES string is split
    into tokens, a vocabulary is built from those tokens, and each string is
    encoded as ``<START> tokens... <END>`` and right-padded with ``<PAD>`` to
    the length of the longest encoded sequence.

    Args:
        path: Path to a ``.txt`` file (one SMILES per line) or a
            ``.pkl``/``.pickle`` file holding an iterable of SMILES strings.
        max_len: Exclusive upper bound on the character length of the raw
            SMILES strings that are kept (``len(smiles) < max_len``).

    Attributes set by ``process()``:
        vocab: Mapping from token string to integer id.
        inv_vocab: Mapping from integer id back to token string.
        token_list: List of equal-length lists of token ids.
    """

    # Id of the '<PAD>' token; get_vocab() reserves ids 0/1/2 for
    # '<START>'/'<END>'/'<PAD>' before any data token is added.
    PAD_ID = 2

    # SMILES tokenisation pattern, compiled once for the class instead of on
    # every tokenizer() call. Written as a raw string so the regex escapes are
    # not interpreted (and warned about) as Python string escapes.
    _TOKEN_RE = re.compile(
        r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|/|_|:|~|@|\?|>|\*|\$|%[0-9]{2}|[0-9])"
    )

    def __init__(self, path, max_len):
        self.path = path
        self.max_len = max_len
        self.token_list = self.process()

    def extract(self):
        """Read the raw SMILES strings from ``self.path``.

        Returns:
            A list of entries whose length is strictly below ``self.max_len``.

        Raises:
            ValueError: If the file extension is not .txt, .pkl or .pickle.
        """
        lowered = self.path.lower()
        if lowered.endswith('.txt'):
            with open(self.path, 'r') as file:
                # Strip before measuring so the newline does not count towards
                # the length limit (keeps the cutoff identical to the pickle
                # branch), and skip blank lines.
                stripped = (line.strip() for line in file)
                return [smi for smi in stripped if smi and len(smi) < self.max_len]
        if lowered.endswith(('.pickle', '.pkl')):
            # NOTE: pickle.load can execute arbitrary code; only load files
            # from a trusted source.
            with open(self.path, 'rb') as file:
                data = pickle.load(file)
            return [x for x in data if len(x) < self.max_len]
        raise ValueError("Unsupported file format. Only .txt, .pkl and .pickle files are supported.")

    def tokenizer(self, smile):
        """Split a SMILES string into a list of token strings.

        Bracketed groups such as ``[NH4+]``, the two-letter atoms ``Br`` and
        ``Cl`` and two-digit ``%NN`` groups are each kept as a single token.

        Raises:
            AssertionError: If the tokens do not join back to ``smile``, i.e.
                the string contains characters the pattern does not match.
        """
        tokens = self._TOKEN_RE.findall(smile)
        assert smile == ''.join(tokens), ("{} could not be joined".format(smile))
        return tokens

    def get_vocab(self, smi_list):
        """Build a token -> id mapping from tokenised SMILES.

        The special tokens take ids 0, 1 and 2; every other token receives the
        next free id in order of first appearance.

        Args:
            smi_list: Iterable of token sequences.
        """
        vocab = {'<START>': 0, '<END>': 1, '<PAD>': self.PAD_ID}
        for smi in smi_list:
            for char in smi:
                if char not in vocab:
                    vocab[char] = len(vocab)
        return vocab

    def encode(self, smi, vocab):
        """Map one token sequence to ids, wrapped in '<START>' and '<END>'.

        Raises:
            KeyError: If a token is missing from ``vocab``.
        """
        return [vocab['<START>']] + [vocab[char] for char in smi] + [vocab['<END>']]

    def pad(self, smi, max_len):
        """Right-pad an id list with the pad id up to ``max_len`` entries.

        A list that already has ``max_len`` or more entries is returned with
        nothing appended.
        """
        return smi + [self.PAD_ID] * (max_len - len(smi))

    def process(self):
        """Run extract -> tokenize -> build vocab -> encode -> pad.

        Sets ``self.vocab`` and ``self.inv_vocab`` as a side effect.

        Returns:
            A list of id lists, all padded to the longest encoded length. The
            list is empty when no SMILES passes the length filter.
        """
        smi_list = [self.tokenizer(s) for s in self.extract()]

        self.vocab = self.get_vocab(smi_list)
        # Kept on the instance so ids can be mapped back to tokens later.
        self.inv_vocab = {v: k for k, v in self.vocab.items()}

        token_list = [self.encode(s, self.vocab) for s in smi_list]
        # default=0 keeps an empty dataset from raising inside max().
        max_token_len = max(map(len, token_list), default=0)
        return [self.pad(t, max_token_len) for t in token_list]

    def __len__(self):
        """Number of encoded SMILES sequences."""
        return len(self.token_list)

    def __getitem__(self, idx):
        """Return the padded id sequence at ``idx`` as a ``torch.long`` tensor."""
        return torch.tensor(self.token_list[idx], dtype=torch.long)


In [38]:
# Build the dataset from the pickled ChEMBL24 training SMILES, keeping only
# strings shorter than 20 characters.
dataset = MyDataset(
    path='data/chembl24_canon_train.pickle',
    max_len=20,
)