In [1]:
import os

# Move to the project root so that the `src` package and the relative
# `cache/...` paths used in the cells below resolve.
# Assumes the notebook lives one level below the project root.
# The guard keeps this cell idempotent: a bare `os.chdir('..')` climbs one
# directory higher on every re-run, which silently breaks later cells.
if not os.path.isdir('src'):
    os.chdir('..')

In [2]:
from src.multimodal.model import MultimodalModel
from src.utils import get_lookup, load_pkl, get_labels
from src.text.model.word_embedding import WordEmbedding
from src.text.model.custom_dataset import CustomDataset
from src.mol.preprocess import mapped_smiles_reader, candidate_smiles
from src.mol.mol_dataset import MolDataset
from src.multimodal.dataset import MultimodalDataset

## Text

In [40]:
# One lookup per feature list file (words, POS tags, dependencies, directions).
lookup_word, lookup_tag, lookup_dep, lookup_direction = (
    get_lookup(list_file)
    for list_file in (
        'cache/fasttext/nguyennb/all_words.txt',
        'cache/fasttext/nguyennb/all_pos.txt',
        'cache/fasttext/nguyennb/all_dep.txt',
        'cache/fasttext/nguyennb/all_direction.txt',
    )
)

# Cached pickles: candidates, SDPs and mapped SDPs, each for train and test.
(
    all_candidates_train,
    all_candidates_test,
    sdp_train,
    sdp_test,
    sdp_train_mapped,
    sdp_test_mapped,
) = (
    load_pkl(pkl_file)
    for pkl_file in (
        'cache/pkl/v1/candidates.train.pkl',
        'cache/pkl/v1/candidates.test.pkl',
        'cache/pkl/v1/sdp.train.pkl',
        'cache/pkl/v1/sdp.test.pkl',
        'cache/pkl/v1/train.mapped.sdp.pkl',
        'cache/pkl/v1/test.mapped.sdp.pkl',
    )
)

# Word embedding built from the fastText matrix and its vocabulary file.
we = WordEmbedding(
    vocab_path='cache/fasttext/nguyennb/all_words.txt',
    fasttext_path='cache/fasttext/nguyennb/fastText_ddi.npz',
)

# Labels are derived from the candidate lists of each split.
y_train, y_test = (
    get_labels(candidates)
    for candidates in (all_candidates_train, all_candidates_test)
)

In [41]:
def _prepare_text_dataset(sdp_mapped, labels):
    """Create a CustomDataset and apply the same preparation steps to it.

    Runs, in order: `fix_exception()`,
    `batch_padding(batch_size=128, min_batch_size=3)` and `squeeze()`.
    The methods are called for their effect on the dataset object, which
    is then returned.
    """
    dataset = CustomDataset(sdp_mapped, labels)
    dataset.fix_exception()
    dataset.batch_padding(batch_size=128, min_batch_size=3)
    dataset.squeeze()
    return dataset


# Train and test splits go through the identical pipeline.
dataset_train_text = _prepare_text_dataset(sdp_train_mapped, y_train)
dataset_test_text = _prepare_text_dataset(sdp_test_mapped, y_test)



In [42]:
from torch.utils.data import DataLoader

# Plain torch loaders over the text datasets, one per split, built in
# train-then-test order with the same batch size.
train_loader_text, test_loader_text = (
    DataLoader(text_dataset, batch_size=128)
    for text_dataset in (dataset_train_text, dataset_test_text)
)

## Mol

In [6]:
# Reload the candidate lists so this cell can run without the Text section.
all_candidates_train = load_pkl('cache/pkl/v1/candidates.train.pkl')
all_candidates_test = load_pkl('cache/pkl/v1/candidates.test.pkl')
mapped_smiles = mapped_smiles_reader('cache/mapped_drugs/all_mapped.txt')

# NOTE: this rebinds `y_train` / `y_test`, which the Text section also defines
# via get_labels(); whichever cell ran last determines their value.
x_train, y_train = candidate_smiles(all_candidates_train, mapped_smiles)
x_test, y_test = candidate_smiles(all_candidates_test, mapped_smiles)

# One dataset per value of `element` (1 and 2) for each split.
# Presumably `element` selects the first / second drug of a pair - confirm
# against MolDataset.
dataset_train_mol1 = MolDataset(x_train, element=1)
dataset_train_mol2 = MolDataset(x_train, element=2)
dataset_test_mol1 = MolDataset(x_test, element=1)
# Fixed: this line used to assign to `dataset_test_mol1` again, discarding the
# element=1 test dataset and leaving `dataset_test_mol2` undefined.
dataset_test_mol2 = MolDataset(x_test, element=2)

Converting SMILES to PyG: 100%|██████████| 27792/27792 [00:15<00:00, 1800.89it/s]
Converting SMILES to PyG: 100%|██████████| 27792/27792 [00:16<00:00, 1688.12it/s]
Converting SMILES to PyG: 100%|██████████| 5716/5716 [00:03<00:00, 1886.37it/s]
Converting SMILES to PyG: 100%|██████████| 5716/5716 [00:03<00:00, 1606.04it/s]


In [7]:
# NOTE: this rebinds the name `DataLoader` to the PyTorch Geometric loader.
from torch_geometric.loader import DataLoader

# shuffle=False is kept explicit: the training cell zips these loaders with
# the text loader, so batches must come out in dataset order.
train_loader_mol1, train_loader_mol2 = (
    DataLoader(mol_dataset, batch_size=128, shuffle=False)
    for mol_dataset in (dataset_train_mol1, dataset_train_mol2)
)

## Training

In [8]:
# Build the multimodal model from the WordEmbedding `we` created in the
# Text section; that cell must have been run before this one.
model = MultimodalModel(we)

In [46]:
# Smoke test: push only the first aligned batch of the three loaders through
# the model and print the result. Nothing happens if any loader is empty.
first_batches = next(
    zip(train_loader_text, train_loader_mol1, train_loader_mol2),
    None,
)
if first_batches is not None:
    (text_batch, batch_label), mol1_batch, mol2_batch = first_batches
    print(model(text_batch, mol1_batch, mol2_batch))

tensor([[2.0715e-01, 2.1703e-01, 1.9849e-01, 1.8952e-01, 1.8782e-01],
        [2.0374e-01, 1.9408e-01, 1.9999e-01, 2.1135e-01, 1.9084e-01],
        [1.9544e-01, 2.0012e-01, 1.9868e-01, 2.1950e-01, 1.8626e-01],
        [1.9030e-01, 2.0838e-01, 1.9673e-01, 2.1009e-01, 1.9450e-01],
        [1.9599e-01, 2.2345e-01, 1.8425e-01, 2.1735e-01, 1.7896e-01],
        [1.9067e-01, 1.9614e-01, 2.0502e-01, 2.1093e-01, 1.9724e-01],
        [1.7989e-01, 2.0981e-01, 1.9814e-01, 2.1353e-01, 1.9862e-01],
        [1.8959e-01, 2.2152e-01, 1.8271e-01, 2.0894e-01, 1.9724e-01],
        [1.8666e-01, 2.0617e-01, 2.0112e-01, 2.0627e-01, 1.9978e-01],
        [1.9199e-01, 2.3025e-01, 1.8386e-01, 2.0527e-01, 1.8863e-01],
        [2.0334e-01, 2.3519e-01, 1.8438e-01, 1.9825e-01, 1.7884e-01],
        [2.1304e-01, 2.3446e-01, 1.9090e-01, 1.7903e-01, 1.8256e-01],
        [1.9728e-01, 2.5404e-01, 1.8426e-01, 1.9598e-01, 1.6843e-01],
        [1.9924e-01, 2.6227e-01, 1.7634e-01, 1.9521e-01, 1.6693e-01],
        [1.8753e-01,