In [7]:
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import numpy as np

In [21]:
def one_hot_encode_smiles(smiles, charset, max_length=120):
    char_to_int = dict((c, i) for i, c in enumerate(charset))
    integer_encoded = [char_to_int[char] for char in smiles]
    if len(integer_encoded) > max_length:
        integer_encoded = integer_encoded[:max_length]
    else:
        integer_encoded = integer_encoded + [0] * (max_length - len(integer_encoded))
    onehot_encoded = np.zeros((max_length, len(charset)), dtype=np.float32)
    for i, val in enumerate(integer_encoded):
        onehot_encoded[i, val] = 1.0

    return onehot_encoded


class SMILESDataset(Dataset):
    def __init__(self, smiles_list, charset):
        self.smiles_list = smiles_list
        self.charset = charset

    def __len__(self):
        return len(self.smiles_list)

    def __getitem__(self, idx):
        smiles = self.smiles_list[idx]
        encoded_smiles = one_hot_encode_smiles(smiles, self.charset)
        return torch.FloatTensor(encoded_smiles), smiles

- __init__: インスタンス生成時に実行されるメソッド
- __len__: データセットの長さを返す。len()を呼び出した際の挙動に対応。
- __getitem__: 指定されたインデックスのデータを返します。[0]などでアクセスした際の挙動に対応。

In [22]:
df = pd.read_csv("250k_rndm_zinc_drugs_clean_3.csv")
df["smiles"] = df["smiles"].str.rstrip("\n")
charset = set("".join(df["smiles"].values.tolist()))

dataset = SMILESDataset(df["smiles"].values.tolist(), charset)

In [35]:
dataset[0]

(tensor([[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [1., 0., 0.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.],
         [1., 0., 0.,  ..., 0., 0., 0.]]),
 'CC(C)(C)c1ccc2occ(CC(=O)Nc3ccccc3F)c2c1')

In [37]:
len(dataset), len(df)

(249455, 249455)

In [24]:
test_ratio = 0.2

test_size = int(test_ratio * len(dataset))
train_size = len(dataset) - test_size
train_dataset, test_dataset = random_split(dataset, [train_size, test_size])

In [25]:
len(train_dataset), len(test_dataset)

(199564, 49891)

In [26]:
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=250, shuffle=True)

In [34]:
for batch_idx, (data, smiles_list) in enumerate(train_loader):
    print(f"Batch {batch_idx}: data shape {data.shape}")
    if batch_idx == 0:
        print(smiles_list)
        break

Batch 0: data shape torch.Size([250, 120, 34])
('C[C@@H]1CCCCN1C(=O)c1cc(C(F)(F)F)nn1C', 'Cc1nc(-c2cc(S(=O)(=O)Nc3ccc(Cl)cc3)c(C)s2)sc1C', 'Cc1ccncc1NC(=O)N[C@H]1CCC[C@H]1CNC(=O)OC(C)(C)C', 'CC[C@H](C)NC(=O)N1CCc2ccc([N+](=O)[O-])cc2C1', 'C[C@@H](C(=O)[O-])N1C(=O)/C(=C/c2ccc(N(C)C)cc2)SC1=S', 'Cc1nc(CNC(=O)c2cc(NC(=O)C(C)(C)C)ccc2F)oc1C', 'O=S(=O)(c1ccccc1C(F)(F)F)N1CCN(c2cccc(Cl)c2)CC1', 'CC(C)CN1C[C@H](C(=O)N2CCn3c(C(N)=O)cnc3C2)CC1=O', 'Cc1ccc(C)c(NC(=O)CCCOc2ccc(Cl)cc2)c1', 'Cc1ccc(CNC(=O)NC2CC[NH+](C)CC2)c(OC[C@@H]2CCOC2)c1', 'O=C(CCS(=O)(=O)c1ccc(Cl)cc1)Nc1cc(Cl)cc(Cl)c1', 'C[C@H](N[C@H](C)C(=O)N(C)C)c1cc(F)ccc1O', 'CCCCn1c(N)c(N(Cc2ccccc2)C(=O)c2ccc(Cl)c(Cl)c2)c(=O)[nH]c1=O', 'CCc1nnc(NS(=O)(=O)c2ccc3ccccc3c2)s1', 'CC(=O)N1CCC[C@@H]2[C@H]1[C@@H](c1ccc(F)cc1)CN2C(=O)c1cc(C(C)C)n[nH]1', 'O=C(NCC(F)(F)F)C(=O)NCC1(O)CCCC1', 'CC[C@H]1CC[C@H](c2noc([C@](C)(N)C(F)(F)F)n2)C1', 'CCOC(=O)c1csc(NC(=O)/C=C(/C)c2ccccc2OC)n1', 'C[C@@H]1COC[C@@H](C)N1C(=O)c1cc(-c2ccc(Cl)s2)on1', 'COC(=O)[C@@H]