In [1]:
import os

# The cells below import the `src` package and read `cache/...` with relative
# paths, so the working directory must be the directory that contains `src`.
# Only step up when `src` is not already visible: an unconditional
# `os.chdir('..')` climbs one more level on every re-run of this cell and
# breaks every later import/path.
# NOTE(review): presumably the notebook lives one level below the project
# root — confirm the layout.
if not os.path.isdir('src'):
    os.chdir('..')

In [2]:
from src.multimodal.model import MultimodalModel
from src.utils import get_lookup, load_pkl, get_labels
from src.text.model.word_embedding import WordEmbedding
from src.text.model.custom_dataset import CustomDataset
from src.mol.preprocess import mapped_smiles_reader, candidate_smiles
from src.mol.mol_dataset import MolDataset
from src.multimodal.dataset import MultimodalDataset

## Text

In [3]:
# One lookup per vocabulary file (words, tags, dependencies, directions).
lookup_word, lookup_tag, lookup_dep, lookup_direction = (
    get_lookup(vocab_file)
    for vocab_file in (
        'cache/fasttext/nguyennb/all_words.txt',
        'cache/fasttext/nguyennb/all_pos.txt',
        'cache/fasttext/nguyennb/all_dep.txt',
        'cache/fasttext/nguyennb/all_direction.txt',
    )
)

# Candidate pickles and the labels derived from them, split by split.
all_candidates_train = load_pkl('cache/pkl/v1/candidates.train.pkl')
y_train = get_labels(all_candidates_train)

all_candidates_test = load_pkl('cache/pkl/v1/candidates.test.pkl')
y_test = get_labels(all_candidates_test)

# SDP pickles: the plain versions and the "mapped" versions that are fed to
# CustomDataset further down.
sdp_train, sdp_test = (
    load_pkl('cache/pkl/v1/sdp.train.pkl'),
    load_pkl('cache/pkl/v1/sdp.test.pkl'),
)
sdp_train_mapped, sdp_test_mapped = (
    load_pkl('cache/pkl/v1/train.mapped.sdp.pkl'),
    load_pkl('cache/pkl/v1/test.mapped.sdp.pkl'),
)

# Word-embedding object handed to the model constructor later on.
we = WordEmbedding(
    fasttext_path='cache/fasttext/nguyennb/fastText_ddi.npz',
    vocab_path='cache/fasttext/nguyennb/all_words.txt',
)

In [4]:
def _prepared_text_dataset(sdp_mapped, labels):
    """Build a ``CustomDataset`` and apply the two preparation calls to it.

    Args:
        sdp_mapped: mapped SDP data passed as the first ``CustomDataset`` argument.
        labels: labels passed as the second ``CustomDataset`` argument.

    Returns:
        The ``CustomDataset`` instance after ``fix_exception()`` and
        ``batch_padding(batch_size=1)`` have been called on it, in that order.
    """
    prepared = CustomDataset(sdp_mapped, labels)
    prepared.fix_exception()
    prepared.batch_padding(batch_size=1)
    return prepared


# Train split first, then test split — same order as the preparation calls
# were originally issued.
dataset_train_text = _prepared_text_dataset(sdp_train_mapped, y_train)
dataset_test_text = _prepared_text_dataset(sdp_test_mapped, y_test)



## Mol

In [5]:
# Candidate pickles are read again here, so this section does not depend on
# the copies loaded in the text section.
all_candidates_train, all_candidates_test = (
    load_pkl('cache/pkl/v1/candidates.train.pkl'),
    load_pkl('cache/pkl/v1/candidates.test.pkl'),
)

# Mapped SMILES data shared by both splits.
mapped_smiles = mapped_smiles_reader('cache/mapped_drugs/all_mapped.txt')

# NOTE(review): `y_train` / `y_test` are rebound here; the text section
# assigned the same names from get_labels(). Re-running the text dataset cell
# after this one would pick up these values instead — confirm both sources
# yield the same labels, or use distinct names.
x_train, y_train = candidate_smiles(all_candidates_train, mapped_smiles)
x_test, y_test = candidate_smiles(all_candidates_test, mapped_smiles)

# Molecule datasets, train before test.
dataset_train_mol = MolDataset(x_train, y_train)
dataset_test_mol = MolDataset(x_test, y_test)

Converting SMILES to PyG: 100%|██████████| 27792/27792 [00:30<00:00, 925.46it/s] 
Converting SMILES to PyG: 100%|██████████| 5716/5716 [00:06<00:00, 904.91it/s] 


## Training

In [7]:
# Combine the text and molecule training datasets into a single dataset.
dataset_train = MultimodalDataset(
    dataset_train_text,
    dataset_train_mol,
)

In [19]:
import torch
import torch.nn
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool

from src.seed import MANUAL_SEED
from src.text.model.text_model import TextModel
from src.mol.gcn import GCN

class MultimodalModel(torch.nn.Module):
    """Classifier that fuses one text branch with two molecule-graph branches.

    A ``TextModel`` processes the text input and two independent ``GCN``
    instances process the two molecule inputs. The three outputs are
    concatenated along dim 1, projected to ``target_class`` scores by a
    bias-free linear layer, and normalised with softmax.

    Args:
        we: word-embedding object forwarded unchanged to ``TextModel``.
        dropout_rate: forwarded to ``TextModel`` and to both ``GCN`` branches.
        word_embedding_size, tag_number, tag_embedding_size, position_number,
        position_embedding_size, direction_number, direction_embedding_size,
        edge_number, edge_embedding_size, token_embedding_size,
        dep_embedding_size, conv1_length, conv2_length, conv3_length:
            forwarded unchanged to ``TextModel``.
        conv1_out_channels, conv2_out_channels, conv3_out_channels:
            forwarded to ``TextModel``; their sum is also used as the text
            part of the classifier's input width.
        target_class: number of output classes; forwarded to ``TextModel``
            and used as the classifier's output width.
        num_node_features: forwarded to both ``GCN`` branches.
        hidden_channels: forwarded to both ``GCN`` branches; counted twice in
            the classifier's input width (once per molecule branch).
        device: stored on ``self.device`` and forwarded to both ``GCN``
            branches. This class does not itself move any module to it.
    """

    def __init__(self,
                 we,
                 dropout_rate: float = 0.5,
                 word_embedding_size: int = 200,
                 tag_number: int = 51,
                 tag_embedding_size: int = 50,
                 position_number: int = 4,
                 position_embedding_size: int = 50,
                 direction_number: int = 3,
                 direction_embedding_size: int = 50,
                 edge_number: int = 46,
                 edge_embedding_size: int = 200,
                 token_embedding_size: int = 500,
                 dep_embedding_size: int = 500,
                 conv1_out_channels: int = 256,
                 conv2_out_channels: int = 256,
                 conv3_out_channels: int = 256,
                 conv1_length: int = 1,
                 conv2_length: int = 2,
                 conv3_length: int = 3,
                 target_class: int = 5,
                 num_node_features: int = 4,
                 hidden_channels: int = 256,
                 device: str = 'cpu'):
        # Zero-argument super() keeps working when this notebook cell is
        # re-executed and the class object is replaced.
        super().__init__()
        # Seeds torch's global RNG so that the layers created below are
        # initialised reproducibly.
        torch.manual_seed(MANUAL_SEED)
        self.device = device

        self.text_model = TextModel(we=we,
                                    dropout_rate=dropout_rate,
                                    word_embedding_size=word_embedding_size,
                                    tag_number=tag_number,
                                    tag_embedding_size=tag_embedding_size,
                                    position_number=position_number,
                                    position_embedding_size=position_embedding_size,
                                    direction_number=direction_number,
                                    direction_embedding_size=direction_embedding_size,
                                    edge_number=edge_number,
                                    edge_embedding_size=edge_embedding_size,
                                    token_embedding_size=token_embedding_size,
                                    dep_embedding_size=dep_embedding_size,
                                    conv1_out_channels=conv1_out_channels,
                                    conv2_out_channels=conv2_out_channels,
                                    conv3_out_channels=conv3_out_channels,
                                    conv1_length=conv1_length,
                                    conv2_length=conv2_length,
                                    conv3_length=conv3_length,
                                    target_class=target_class)

        # One graph encoder per molecule; the two branches do not share
        # weights.
        self.gcn1 = GCN(num_node_features=num_node_features,
                        hidden_channels=hidden_channels,
                        dropout_rate=dropout_rate,
                        device=device)

        self.gcn2 = GCN(num_node_features=num_node_features,
                        hidden_channels=hidden_channels,
                        dropout_rate=dropout_rate,
                        device=device)

        # NOTE(review): this input width assumes TextModel emits
        # conv1+conv2+conv3 channels per sample and each GCN emits
        # hidden_channels per sample — confirm against TextModel and GCN
        # (TextModel also receives target_class, so it may already classify).
        fused_width = (conv1_out_channels
                       + conv2_out_channels
                       + conv3_out_channels
                       + 2 * hidden_channels)
        self.dense_to_tag = torch.nn.Linear(in_features=fused_width,
                                            out_features=target_class,
                                            bias=False)

    def forward(self, text_x, mol_x1, mol_x2):
        """Run the three branches, fuse their outputs and classify.

        Args:
            text_x: input passed to the text branch.
            mol_x1: input passed to the first molecule branch.
            mol_x2: input passed to the second molecule branch.

        Returns:
            Tensor of softmax-normalised class scores; the last dimension has
            ``target_class`` entries that sum to 1.
        """
        text_x = self.text_model(text_x)
        mol_x1 = self.gcn1(mol_x1)
        mol_x2 = self.gcn2(mol_x2)

        # Fuse the three branch outputs along dim 1.
        x = torch.cat((text_x, mol_x1, mol_x2), dim=1)

        # Classifier
        x = self.dense_to_tag(x)
        # `self.softmax` was never created in __init__, so the previous
        # `self.softmax(x)` raised AttributeError on every call. The linear
        # layer puts the classes on the last dimension, hence dim=-1.
        x = F.softmax(x, dim=-1)

        return x

SyntaxError: unmatched ')' (1460969149.py, line 74)

In [20]:
# Instantiate the model with the word-embedding object; every other
# constructor argument keeps its default value.
model = MultimodalModel(we)

In [15]:
# Smoke test: run the model on the first training sample.
first_sample_inputs = dataset_train[0][0]
first_text_input = first_sample_inputs[0]
first_mol_pair = first_sample_inputs[1]
model(first_text_input, first_mol_pair[0], first_mol_pair[1])

AttributeError: 'MultimodalModel' object has no attribute 'dense_to_tag'

In [19]:
# Display the molecule part of the first training sample's inputs (the two
# entries that are passed to the model as mol_x1 and mol_x2).
dataset_train[0][0][1]

[Data(x=[101, 4], edge_index=[2, 212], edge_attr=[212, 2], mol=<rdkit.Chem.rdchem.Mol object at 0x7feeafdf2e30>, smiles='C[C@H](C(=O)N)NC(=O)[C@@H]1CCCN1C(=O)[C@H](CCCCNC(C)C)NC(=O)[C@H](CC(C)C)NC(=O)[C@@H](CC(=O)N)NC(=O)[C@H](CC2=CC=C(C=C2)O)N(C)C(=O)[C@H](CO)NC(=O)[C@@H](CC3=CN=CC=C3)NC(=O)[C@@H](CC4=CC=C(C=C4)Cl)NC(=O)[C@@H](CC5=CC6=CC=CC=C6C=C5)NC(=O)C'),
 Data(x=[21, 4], edge_index=[2, 48], edge_attr=[48, 2], mol=<rdkit.Chem.rdchem.Mol object at 0x7feeafdf2ff0>, smiles='C[C@]12CC[C@H]3[C@H]([C@@H]1CC[C@@H]2O)CCC4=CC(=O)CC[C@]34C')]