In [1]:
import os

# Move the working directory up one level so that the `src` package imported
# below and the relative `cache/...` paths resolve (presumably the notebook
# lives one directory below the repository root -- confirm for your layout).
# The guard keeps this cell idempotent: an unconditional chdir('..') climbs one
# more level every time the cell is re-run, silently breaking every later path.
if not os.path.isdir('src'):
    os.chdir('..')

In [2]:
from src.multimodal.model import MultimodalModel
from src.utils import get_lookup, load_pkl, get_labels
from src.text.model.word_embedding import WordEmbedding
from src.text.model.custom_dataset import CustomDataset
from src.mol.preprocess import mapped_smiles_reader, candidate_smiles
from src.mol.mol_dataset import MolDataset
from src.multimodal.dataset import MultimodalDataset

## Text

In [4]:
# Lookup tables, one per vocabulary file, built in the order the names appear.
lookup_word, lookup_tag, lookup_dep, lookup_direction = map(
    get_lookup,
    (
        'cache/fasttext/nguyennb/all_words.txt',
        'cache/fasttext/nguyennb/all_pos.txt',
        'cache/fasttext/nguyennb/all_dep.txt',
        'cache/fasttext/nguyennb/all_direction.txt',
    ),
)

# Cached pickles: candidates, raw SDPs and mapped SDPs, each as a
# (train, test) pair. NOTE(review): load_pkl presumably unpickles these files;
# only load pickles from a trusted source.
all_candidates_train, all_candidates_test = (
    load_pkl(pkl_path)
    for pkl_path in ('cache/pkl/v1/candidates.train.pkl',
                     'cache/pkl/v1/candidates.test.pkl')
)
sdp_train, sdp_test = (
    load_pkl(pkl_path)
    for pkl_path in ('cache/pkl/v1/sdp.train.pkl',
                     'cache/pkl/v1/sdp.test.pkl')
)
sdp_train_mapped, sdp_test_mapped = (
    load_pkl(pkl_path)
    for pkl_path in ('cache/pkl/v1/train.mapped.sdp.pkl',
                     'cache/pkl/v1/test.mapped.sdp.pkl')
)

# Word embedding built from the fastText archive and the word vocabulary file.
we = WordEmbedding(
    vocab_path='cache/fasttext/nguyennb/all_words.txt',
    fasttext_path='cache/fasttext/nguyennb/fastText_ddi.npz',
)

# Labels derived from the candidate lists, train first and then test.
y_train, y_test = get_labels(all_candidates_train), get_labels(all_candidates_test)

In [41]:
def _prepare_text_dataset(mapped_sdp, labels):
    """Wrap mapped SDP data and labels in a CustomDataset and prepare it.

    Applies, in this order: ``fix_exception()``,
    ``batch_padding(batch_size=128, min_batch_size=3)`` and ``squeeze()``.
    The return values of those calls are ignored; the dataset object itself
    is returned.
    """
    prepared = CustomDataset(mapped_sdp, labels)
    prepared.fix_exception()
    prepared.batch_padding(batch_size=128, min_batch_size=3)
    prepared.squeeze()
    return prepared


# The train split is fully prepared before the test split is touched.
dataset_train_text = _prepare_text_dataset(sdp_train_mapped, y_train)
dataset_test_text = _prepare_text_dataset(sdp_test_mapped, y_test)



In [42]:
from torch.utils.data import DataLoader

# One loader per text split, both with batches of 128 and default ordering.
train_loader_text, test_loader_text = (
    DataLoader(text_dataset, batch_size=128)
    for text_dataset in (dataset_train_text, dataset_test_text)
)

## Mol

In [6]:
# Candidate lists (the same pickle files the Text section reads) and the
# mapped SMILES table.
all_candidates_train = load_pkl('cache/pkl/v1/candidates.train.pkl')
all_candidates_test = load_pkl('cache/pkl/v1/candidates.test.pkl')
mapped_smiles = mapped_smiles_reader('cache/mapped_drugs/all_mapped.txt')

# NOTE(review): this rebinds y_train / y_test, which the Text section also
# assigns from get_labels(); any cell executed after this one sees the values
# returned by candidate_smiles instead.
x_train, y_train = candidate_smiles(all_candidates_train, mapped_smiles)
x_test, y_test = candidate_smiles(all_candidates_test, mapped_smiles)

# One MolDataset per `element` value for each split (presumably element
# selects the first / second molecule of a candidate pair -- TODO confirm).
dataset_train_mol1 = MolDataset(x_train, element=1)
dataset_train_mol2 = MolDataset(x_train, element=2)
dataset_test_mol1 = MolDataset(x_test, element=1)
# Bug fix: this was assigned to dataset_test_mol1, which overwrote the
# element=1 test dataset and left dataset_test_mol2 undefined.
dataset_test_mol2 = MolDataset(x_test, element=2)

Converting SMILES to PyG: 100%|██████████| 27792/27792 [00:15<00:00, 1800.89it/s]
Converting SMILES to PyG: 100%|██████████| 27792/27792 [00:16<00:00, 1688.12it/s]
Converting SMILES to PyG: 100%|██████████| 5716/5716 [00:03<00:00, 1886.37it/s]
Converting SMILES to PyG: 100%|██████████| 5716/5716 [00:03<00:00, 1606.04it/s]


In [7]:
# Note: this import rebinds the name `DataLoader`, which the Text section
# imported from torch.utils.data, to the torch_geometric loader.
from torch_geometric.loader import DataLoader

# Unshuffled loaders over the two training molecule datasets, 128 per batch.
train_loader_mol1, train_loader_mol2 = (
    DataLoader(mol_dataset, batch_size=128, shuffle=False)
    for mol_dataset in (dataset_train_mol1, dataset_train_mol2)
)

## Training

In [5]:
# Build the multimodal model, passing it the WordEmbedding instance `we`
# created in the Text section.
model = MultimodalModel(we)

In [6]:
# Total number of elements across everything model.parameters() yields.
total_params = sum(map(lambda tensor: tensor.numel(), model.parameters()))
print(total_params)

5997788
