In [1]:
# Install some dependencies if it is necessary
# !pip install torch_geometric rdkit torch

In [2]:
from datetime import datetime

import time
import argparse
import sys
import torch
from torch import optim
from torch import nn
import torch.nn.functional as F
from sklearn import metrics
import pandas as pd
import numpy as np
from torch.nn.modules.container import ModuleList
from torch_geometric.nn import (
    GATConv,
    SAGPooling,
    LayerNorm,
    global_mean_pool,
    max_pool_neighbor_x,
    global_add_pool,
)

import itertools
from collections import defaultdict
from operator import neg
import random
import math

import csv

from torch.utils.data import Dataset, DataLoader
from torch_geometric.data import Data, Batch
from rdkit import Chem

### Common configuration

Common configuration: model going to be used for the prediction or to be saved in the training process

In [1]:
# Directory configuration
data_dir = "data"
model_dir = "models"
model_name = "case28"

# Change mode to "prediction" to ignore training and perform prediction on test set.
mode = "train"


model_acc_file = f"{model_dir}/acc/{model_name}.pth"
model_roc_file = f"{model_dir}/roc/{model_name}.pth"
model_prc_file = f"{model_dir}/prc/{model_name}.pth"


### Tunning parameters

Modified the tunning parameters below to train on different cases

In [4]:
####### Tunning parameters #######

# Number of epochs
n_epochs = 300

# SagPooling ratio & min score. 
# Set sp_ratio to None to disable ratio in SagPooling
sp_ratio = None
sp_min_score = None

# Enable using gpu
use_cuda = True

# Use activation function for CoAttention Layer
use_activation_fn = False

# Use ComplEx instead of RESCAL
use_ComplEx = True

# Use improved CoAttention Layer
# Could be "original" || "improved" || "multihead"
co_attention_method = "multihead"

# Use Explicit Valence
# refers to the number of chemical bonds explicitly represented for an atom 
# within a molecule in a chemical structure.
use_explicit_valence = False

# Number of GAT layers
num_GAT_layers = 4

# Number of GAT multiheads
num_GAT_multiheads = 2

#################################

### Data preprocessing

In [5]:
# Data directory
data_dir = "data"

df_drugs_smiles = pd.read_csv(f'{data_dir}/drug_smiles.csv')

DRUG_TO_INDX_DICT = {drug_id: indx for indx, drug_id in enumerate(df_drugs_smiles['drug_id'])}

# drug_id_mol_graph_tup (list of tuple): Contains drug information of all drugs as a list of tuples.
# each tuple has the following elements:
# - id: drug ID provided in drug_smiles.csv
# - mol_graph: rdkit Mol object extracted from Smiles string
# note: smiles object only represent the connectivity of atoms in a molecule in a text form.
#       where Rdkit Mol object capture a complete graph-based representation of molecule.
#       making it easier for feature extraction (atomic features, bond features, ...)
drug_id_mol_graph_tup = [(id, Chem.MolFromSmiles(smiles.strip())) for id, smiles in zip(df_drugs_smiles['drug_id'], df_drugs_smiles['smiles'])]

# ATOM_MAX_NUM (int): Max number of atoms in every drugs in drug_id_mol_graph_tup list.
ATOM_MAX_NUM = np.max([m[1].GetNumAtoms() for m in drug_id_mol_graph_tup])

# AVAILABLE_ATOM_SYMBOLS (list of str): Contains all symbols of all atoms of each drugs in drug_id_mol_graph_tup list.
# atoms have symbols like Ag, Na, Br, ...
AVAILABLE_ATOM_SYMBOLS = list({a.GetSymbol() for a in itertools.chain.from_iterable(m[1].GetAtoms() for m in drug_id_mol_graph_tup)})

# AVAILABLE_ATOM_DEGREES (list of int): Contains all degree of all atoms of all drugs in drug_id_mol_graph_tup list
# an atom degree is the total number of connections between atoms (Simply it's degree of nodes).
AVAILABLE_ATOM_DEGREES = list({a.GetDegree() for a in itertools.chain.from_iterable(m[1].GetAtoms() for m in drug_id_mol_graph_tup)})

# AVAILABLE_ATOM_TOTAL_HS (list of int): Contains all Hydrogens attached to the atoms of all drugs in drug_id_mol_graph_tup list
AVAILABLE_ATOM_TOTAL_HS = list({a.GetTotalNumHs() for a in itertools.chain.from_iterable(m[1].GetAtoms() for m in drug_id_mol_graph_tup)})

# max_valence (int): A maximum number of Implicit Valence in all atoms of all drugs. Minimum = 9.
# implicit valance: describes the "wants" of complete bonds. For eg. C=0 wants to complete 4 bonds,
#                   because C can form 4 bonds with other atoms or molecule
#                   Tthen Implicit Valance = 2.
max_valence = max(a.GetImplicitValence() for a in itertools.chain.from_iterable(m[1].GetAtoms() for m in drug_id_mol_graph_tup))
max_valence = max(max_valence, 9)

# AVAILABLE_ATOM_VALENCE (NumPy array of int): generates a NumPy array of integers, starting from 0 and ending at value. 
# the result is an array containing all integers from 0 up to max_valence (but not including max_valence).
AVAILABLE_ATOM_VALENCE = np.arange(max_valence + 1)

# MAX_ATOM_FC (int): maximum absolute formal charge of all atoms in a list of molecular structures.
# formal Charge: Provide a way to estimate ad track electron distribution within a molecule.
MAX_ATOM_FC = abs(np.max([a.GetFormalCharge() for a in itertools.chain.from_iterable(m[1].GetAtoms() for m in drug_id_mol_graph_tup)]))
MAX_ATOM_FC = MAX_ATOM_FC if MAX_ATOM_FC else 0

# MAX_RADICAL_ELC (int): calculates the maximum absolute number of radical electrons found across all atoms in a list of molecular structures.
# radical electrons: this is important feature, play importan participation in reactions. 
#                    this is the key in many chemical.
MAX_RADICAL_ELC = abs(np.max([a.GetNumRadicalElectrons() for a in itertools.chain.from_iterable(m[1].GetAtoms() for m in drug_id_mol_graph_tup)]))
MAX_RADICAL_ELC = MAX_RADICAL_ELC if MAX_RADICAL_ELC else 0


def one_of_k_encoding_unk(x, allowable_set):
    """
    Get all matched elements in allowable set.

    Args:
        x (any): element to check.
        allowable_set (list of any): List of any elements are allowed.

    Returns:
        list of any: Return list of any elements in allowable_set equal to x. 
            If x not in allowable_set then x = last element of the allowable set.
    """
    if x not in allowable_set:
        # This would be Unknown
        x = allowable_set[-1]
    return list(map(lambda s: x == s, allowable_set))

def atom_features(atom,
                explicit_H=True,
                use_chirality=False,
                use_explicit_valence=False):
    """
    Get all features of an atom.

    Args:
        atom (Atom): atom to be performed features extractation.
        explicit_H (boolean): include Explicit Hydrogens. Default = True.
        use_chirality (boolean): include Chirality feature. Default = False.
        use_explicit_valence (boolean): include Explicit Valence feature. Default = True.

    Returns:
        NumPy array: Represents an one hot encoding feature vector of an atom.
    """

    results = one_of_k_encoding_unk(
        atom.GetSymbol(),
        ['C','N','O', 'S','F','Si','P', 'Cl','Br','Mg','Na','Ca','Fe','As','Al','I','B','V','K','Tl',
            'Yb','Sb','Sn','Ag','Pd','Co','Se','Ti','Zn','H', 'Li','Ge','Cu','Au','Ni','Cd','In',
            'Mn','Zr','Cr','Pt','Hg','Pb','Unknown'
        ]) + [atom.GetDegree()/10, atom.GetImplicitValence(), 
                atom.GetFormalCharge(), atom.GetNumRadicalElectrons()] + \
                one_of_k_encoding_unk(atom.GetHybridization(), [
                Chem.rdchem.HybridizationType.SP, Chem.rdchem.HybridizationType.SP2,
                Chem.rdchem.HybridizationType.SP3, Chem.rdchem.HybridizationType.
                                    SP3D, Chem.rdchem.HybridizationType.SP3D2
                ]) + [atom.GetIsAromatic()]
    # in case of explicit hydrogen(QM8, QM9), avoid calling `GetTotalNumHs`
    if explicit_H:
        results = results + [atom.GetTotalNumHs()]

    # chilarity: chilarity molecule can have very different properties and biological effects depending on
    # their orientation. This influence how the molecule interact with biological systems or other chemical entities.
    if use_chirality:
        try:
            results = results + one_of_k_encoding_unk(
            atom.GetProp('_CIPCode'),
            ['R', 'S']) + [atom.HasProp('_ChiralityPossible')]
        except:
            results = results + [False, False
                            ] + [atom.HasProp('_ChiralityPossible')]

    if use_explicit_valence:
        results = results + [atom.GetExplicitValence()]

    results = np.array(results).astype(np.float32)

    return torch.from_numpy(results)

def get_mol_edge_list_and_feat_mtx(mol_graph):
    """
    Get molecule edge list and features matrix

    Args:
        mol_graph (Mol): rdkit mol object of a drug.

    Returns:
        tuple: edge list and features matrix
    """
    # features (list of tuple): Contains drug id and features of all atom as a list of tuples.
    # each tuple has the following elements:
    # - index: index of an atom.
    # - features (Numpy array): features of an atom
    features = [(atom.GetIdx(), atom_features(atom)) for atom in mol_graph.GetAtoms()]

    # to make sure that the feature matrix is aligned according to the idx of the atom
    features.sort()
    _, features = zip(*features)
    features = torch.stack(features)

    # GetBonds(): retrive all bonds in the Mol object. Each bond represent connections between 2 atoms.
    # for each bond, retrieve the indices of 2 connected by a bond to form a tuple (start_atom_index, end_atom_index)
    # representing edge between 2 atoms.
    # convert to a Torch Tensor (2D tensor) of a matrix.
    # each row represent an edge, the first column and second column being the indices of the bonded atoms.
    edge_list = torch.LongTensor([(b.GetBeginAtomIdx(), b.GetEndAtomIdx()) for b in mol_graph.GetBonds()])
    undirected_edge_list = torch.cat([edge_list, edge_list[:, [1, 0]]], dim=0) if len(edge_list) else edge_list
    
    return undirected_edge_list.T, features


# MOL_EDGE_LIST_FEAT_MTX (dict): each field in the dict is (drug_id: (edge_list, feature_matrix))
# edge_list (shape: (num_edges, 2)): in the mol object, retrive all bonds. "bonds" means the connection between atoms.
#           list of edges tuple(atom1 index, atom2 index).
# undirected_edge_list (shape: (2, 2 * num_edges)): since molecular bonds are typically undirected, we make each directed edge bidirectional.
#                      first row ontains the start nodes of each edge, second row contains the end nodes of each edge.
# undirected_edge_list.T (shape (2*num_edges, 2)): 2 columns, first col is start node(atom1) and second one is end node(atom2).

MOL_EDGE_LIST_FEAT_MTX = {drug_id: get_mol_edge_list_and_feat_mtx(mol) 
                                for drug_id, mol in drug_id_mol_graph_tup}
MOL_EDGE_LIST_FEAT_MTX = {drug_id: mol for drug_id, mol in MOL_EDGE_LIST_FEAT_MTX.items() if mol is not None}

# TOTAL_ATOM_FEATS (int): total features of an atom.
TOTAL_ATOM_FEATS = (next(iter(MOL_EDGE_LIST_FEAT_MTX.values()))[1].shape[-1])


##### DDI statistics and counting #######
df_all_pos_ddi = pd.read_csv(f'{data_dir}/ddis.csv')
all_pos_tup = [(h, t, r) for h, t, r in zip(df_all_pos_ddi['d1'], df_all_pos_ddi['d2'], df_all_pos_ddi['type'])]

# ALL_DRUG_IDS (Numpy array of int): list of all drug ids
ALL_DRUG_IDS, _ = zip(*drug_id_mol_graph_tup)
ALL_DRUG_IDS = np.array(list(set(ALL_DRUG_IDS)))

# ALL_TRUE_H_WITH_TR (dict): all drug 1 with drug 2 and their relationship
ALL_TRUE_H_WITH_TR = defaultdict(list)

# ALL_TRUE_T_WITH_HR (dict): all drug 2 with drug 1 and their relationship
ALL_TRUE_T_WITH_HR = defaultdict(list)

# FREQ_REL (dict): calculate number of relationships appeared.
FREQ_REL = defaultdict(int)

# ALL_H_WITH_R (dict): mark the dict relationships with head drug.
ALL_H_WITH_R = defaultdict(dict)

# ALL_T_WITH_R (dict): mark the dict relationships with tail drug.
ALL_T_WITH_R = defaultdict(dict)

# ALL_TAIL_PER_HEAD (dict): freaquent of relationships / (length of list drug tail in relationships "r")
ALL_TAIL_PER_HEAD = {}

# ALL_HEAD_PER_TAIL (dict): freaquent of relationships / (length of list drug head in relationships "r")
ALL_HEAD_PER_TAIL = {}

for h, t, r in all_pos_tup:
    ALL_TRUE_H_WITH_TR[(t, r)].append(h)
    ALL_TRUE_T_WITH_HR[(h, r)].append(t)
    FREQ_REL[r] += 1.0
    ALL_H_WITH_R[r][h] = 1
    ALL_T_WITH_R[r][t] = 1

for t, r in ALL_TRUE_H_WITH_TR:
    ALL_TRUE_H_WITH_TR[(t, r)] = np.array(list(set(ALL_TRUE_H_WITH_TR[(t, r)])))
for h, r in ALL_TRUE_T_WITH_HR:
    ALL_TRUE_T_WITH_HR[(h, r)] = np.array(list(set(ALL_TRUE_T_WITH_HR[(h, r)])))

for r in FREQ_REL:
    ALL_H_WITH_R[r] = np.array(list(ALL_H_WITH_R[r].keys()))
    ALL_T_WITH_R[r] = np.array(list(ALL_T_WITH_R[r].keys()))
    ALL_HEAD_PER_TAIL[r] = FREQ_REL[r] / len(ALL_T_WITH_R[r])
    ALL_TAIL_PER_HEAD[r] = FREQ_REL[r] / len(ALL_H_WITH_R[r])

#######    ****** ###############

class DrugDataset(Dataset):
    def __init__(self, tri_list, ratio=1.0,  neg_ent=1, disjoint_split=True, shuffle=True):
        ''''disjoint_split: Consider whether entities should appear in one and only one split of the dataset
        ''' 
        self.neg_ent = neg_ent
        self.tri_list = []
        self.ratio = ratio

        for h, t, r, *_ in tri_list:
            if ((h in MOL_EDGE_LIST_FEAT_MTX) and (t in MOL_EDGE_LIST_FEAT_MTX)):
                self.tri_list.append((h, t, r))

        if disjoint_split:
            d1, d2, *_ = zip(*self.tri_list)
            self.drug_ids = np.array(list(set(d1 + d2)))
        else:
            self.drug_ids = ALL_DRUG_IDS

        self.drug_ids = np.array([id for id in self.drug_ids if id in MOL_EDGE_LIST_FEAT_MTX])
        
        if shuffle:
            random.shuffle(self.tri_list)
        limit = math.ceil(len(self.tri_list) * ratio)
        self.tri_list = self.tri_list[:limit]

    def __len__(self):
        return len(self.tri_list)
    
    def __getitem__(self, index):
        return self.tri_list[index]

    def collate_fn(self, batch):
        """
        This function is for createing positive and negative sampling.
        """
        pos_rels = []
        pos_h_samples = []
        pos_t_samples = []
        neg_rels = []
        neg_h_samples = []
        neg_t_samples = []

        for h, t, r in batch:
            pos_rels.append(r)
            h_data = self.__create_graph_data(h)
            t_data = self.__create_graph_data(t)
            pos_h_samples.append(h_data)
            pos_t_samples.append(t_data)

            neg_heads, neg_tails = self.__normal_batch(h, t, r, self.neg_ent)

            for neg_h in neg_heads:
                neg_rels.append(r)
                neg_h_samples.append(self.__create_graph_data(neg_h))
                neg_t_samples.append(t_data)

            for neg_t in neg_tails:
                neg_rels.append(r)
                neg_h_samples.append(h_data)
                neg_t_samples.append(self.__create_graph_data(neg_t))

        pos_h_samples = Batch.from_data_list(pos_h_samples)
        pos_t_samples = Batch.from_data_list(pos_t_samples)
        pos_rels = torch.LongTensor(pos_rels)
        pos_tri = (pos_h_samples, pos_t_samples, pos_rels)

        neg_h_samples = Batch.from_data_list(neg_h_samples)
        neg_t_samples = Batch.from_data_list(neg_t_samples)
        neg_rels = torch.LongTensor(neg_rels)
        neg_tri = (neg_h_samples, neg_t_samples, neg_rels)

        return pos_tri, neg_tri
            
    def __create_graph_data(self, id):
        edge_index = MOL_EDGE_LIST_FEAT_MTX[id][0]
        features = MOL_EDGE_LIST_FEAT_MTX[id][1]

        return Data(x=features, edge_index=edge_index)

    def __corrupt_ent(self, other_ent, r, other_ent_with_r_dict, max_num=1):
        corrupted_ents = []
        current_size = 0
        while current_size < max_num:
            candidates = np.random.choice(self.drug_ids, (max_num - current_size) * 2)
            mask = np.isin(candidates, other_ent_with_r_dict[(other_ent, r)], assume_unique=True, invert=True)
            corrupted_ents.append(candidates[mask])
            current_size += len(corrupted_ents[-1])
        
        if corrupted_ents != []:
            corrupted_ents = np.concatenate(corrupted_ents)

        return np.asarray(corrupted_ents[:max_num])
        
    def __corrupt_head(self, t, r, n=1):
        return self.__corrupt_ent(t, r, ALL_TRUE_H_WITH_TR, n)

    def __corrupt_tail(self, h, r, n=1):
        return self.__corrupt_ent(h, r, ALL_TRUE_T_WITH_HR, n)
    
    def __normal_batch(self, h, t, r, neg_size):
        neg_size_h = 0
        neg_size_t = 0
        prob = ALL_TAIL_PER_HEAD[r] / (ALL_TAIL_PER_HEAD[r] + ALL_HEAD_PER_TAIL[r])
        for i in range(neg_size):
            if random.random() < prob:
                neg_size_h += 1
            else:
                neg_size_t +=1
        
        return (self.__corrupt_head(t, r, neg_size_h),
                self.__corrupt_tail(h, r, neg_size_t))  


class DrugDataLoader(DataLoader):
    # Each epoch will generate random batches of data with the specified batch_size
    def __init__(self, data, **kwargs):
        super().__init__(data, collate_fn=data.collate_fn, **kwargs)



  return undirected_edge_list.T, features


### Parameters

In [6]:
# Total atom features
n_atom_feats = TOTAL_ATOM_FEATS
# Not use
n_atom_hid = 64
# Total interactions information in the Interaction_information.csv
rel_total = 86
lr = 1e-2
weight_decay = 5e-4
neg_samples = 1
# Represents the number of samples (or graph instances) loaded in each batch during the training process.
batch_size = 1024
data_size_ratio = 1

# Knowledge graph embeding dimension
kge_dim = 64

device = "cuda" if torch.cuda.is_available() and use_cuda else "cpu"

print(device)
print(f"Epochs: {n_epochs}")
print(f"Total of atom features: {TOTAL_ATOM_FEATS}")

cuda
Epochs: 300
Total of atom features: 55


### Load the data

In [None]:
df_ddi_train = pd.read_csv(f"{data_dir}/ddi_training.csv")
df_ddi_val = pd.read_csv(f"{data_dir}/ddi_validation.csv")
df_ddi_test = pd.read_csv(f"{data_dir}/ddi_test.csv")


train_tup = [
    (h, t, r)
    for h, t, r in zip(df_ddi_train["d1"], df_ddi_train["d2"], df_ddi_train["type"])
]
val_tup = [
    (h, t, r) for h, t, r in zip(df_ddi_val["d1"], df_ddi_val["d2"], df_ddi_val["type"])
]
test_tup = [
    (h, t, r)
    for h, t, r in zip(df_ddi_test["d1"], df_ddi_test["d2"], df_ddi_test["type"])
]

train_data = DrugDataset(train_tup, ratio=data_size_ratio, neg_ent=neg_samples)
val_data = DrugDataset(val_tup, ratio=data_size_ratio, disjoint_split=False)
test_data = DrugDataset(test_tup, disjoint_split=False)

print(
    f"Training with {len(train_data)} samples, validating with {len(val_data)}, and testing with {len(test_data)}"
)

train_data_loader = DrugDataLoader(train_data, batch_size=batch_size, shuffle=True)
val_data_loader = DrugDataLoader(val_data, batch_size=batch_size * 3)
test_data_loader = DrugDataLoader(test_data, batch_size=batch_size * 3)


In [7]:
def print_tunning_parameters():
    """
    Print tunning parameters
    """
    print()
    print("####### Tunning parameters #######")
    print()
    
    print("n_epochs =", n_epochs)
    print("use_cuda =", use_cuda)
    print()
    print("num_GAT_layers = ", num_GAT_layers)
    print("num_GAT_multiheads = ", num_GAT_multiheads)
    print()
    print("sp_ratio =", sp_ratio)
    print("sp_min_score =", sp_min_score)
    print()
    print("use_explicit_valence =", use_explicit_valence)
    print()
    print("use_activation_fn =", use_activation_fn)
    print()
    print("use_ComplEx =", use_ComplEx)
    print()
    print("co_attention_method =", co_attention_method)
    
    print()
    print("#################################")
    print()


### Layers

This section contains custom layers can be used in model. `CoAttentionLayer` and `RESCAL` are from the original paper. Our improvment layers are `MultiheadCoAttentionLayer` and `ComplEx`

In [8]:
class CoAttentionLayer(nn.Module):
    """
    
    """
    def __init__(self, n_features, use_activation_fn=True):
        super().__init__()
        self.n_features = n_features
        self.w_q = nn.Parameter(torch.zeros(n_features, n_features // 2))
        self.w_k = nn.Parameter(torch.zeros(n_features, n_features // 2))
        self.bias = nn.Parameter(torch.zeros(n_features // 2))
        self.a = nn.Parameter(torch.zeros(n_features // 2))
        self.use_activation_fn = use_activation_fn

        nn.init.xavier_uniform_(self.w_q)
        nn.init.xavier_uniform_(self.w_k)
        nn.init.xavier_uniform_(self.bias.view(*self.bias.shape, -1))
        nn.init.xavier_uniform_(self.a.view(*self.a.shape, -1))

    def forward(self, receiver, attendant):
        # receiver.shape  = (1024, 4, 64)
        # attendant.shape = (1024, 4, 64)
        keys = receiver @ self.w_k
        queries = attendant @ self.w_q
        # values = receiver @ self.w_v
        values = receiver

        # queries.shape = (1024, 4, 32)
        # keys.shape = (1024, 4, 32)
        e_activations = queries.unsqueeze(-3) + keys.unsqueeze(-2) + self.bias
        if self.use_activation_fn:
            e_scores = torch.tanh(e_activations) @ self.a
        else:
            e_scores = e_activations @ self.a
        attentions = e_scores

        return attentions


class MultiheadCoAttentionLayer(nn.Module):
    """
    Using multi-head co-attention
    """
    def __init__(self, n_features, use_activation_fn=True, dropout=0.1, n_heads=2):
        super().__init__()
        self.n_features = n_features
        self.n_heads = n_heads

        
        self.W_q = nn.ParameterList([torch.zeros(self.n_features, self.n_features // n_heads) for _ in range(n_heads)])
        self.W_k = nn.ParameterList([torch.zeros(self.n_features, self.n_features // n_heads) for _ in range(n_heads)])
        
        self.a = nn.Parameter(torch.zeros(self.n_features))
        self.bias = nn.ParameterList([torch.zeros(self.n_features // n_heads) for _ in range(n_heads)])
        
        self.use_activation_fn = use_activation_fn

        self.dropout = nn.Dropout(dropout)

        for i in range(n_heads):
            nn.init.xavier_uniform_(self.W_q[i])
            nn.init.xavier_uniform_(self.W_k[i])
            nn.init.xavier_uniform_(self.bias[i].view(*self.bias[i].shape, -1))
        
        nn.init.xavier_uniform_(self.a.view(*self.a.shape, -1))

    def forward(self, receiver, attendant):
        # receiver.shape  = (1024, 4, 64)
        # attendant.shape = (1024, 4, 64)

        # Compute attention score for each head
        head_outputs = []
        for i in range(self.n_heads):
            keys = receiver @ self.W_k[i]
            queries = attendant @ self.W_q[i]
            # print("keys.shape = ", keys.shape)
            e_activations = queries.unsqueeze(-3) + keys.unsqueeze(-2) + self.bias[i]
            # print("e_activations.shape = ", e_activations.shape)
            head_outputs.append(e_activations)

        # Average the outputs from all heads
        # e_activations_concat.shape = (1024, 4, 4, 32)
        e_activations_concat = torch.cat(head_outputs, dim=-1)
        # print("e_activations_avg.shape = ", e_activations_avg.shape)
        
        if self.use_activation_fn:
            e_scores = torch.tanh(e_activations_concat) @ self.a
        else:
            e_scores = e_activations_concat @ self.a

        # attentions.shape = (1024, 4, 4)
        attentions = e_scores

        return attentions

class RESCAL(nn.Module):
    """
    RESCAL is like the Dismults but for matrices.
    """
    def __init__(self, n_rels, n_features):
        """
        n_rels: number of relations = 86
        n_features: kge_dim = 64
        """
        super().__init__()
        self.n_rels = n_rels
        self.n_features = n_features
        # Embedding layer
        self.rel_emb = nn.Embedding(self.n_rels, n_features * n_features)
        #  Initializes the embedding weights with the Xavier uniform distribution, which helps maintain the scale of gradients during training
        nn.init.xavier_uniform_(self.rel_emb.weight)

    def forward(self, heads, tails, rels, alpha_scores):
        rels = self.rel_emb(rels)
        rels = F.normalize(rels, dim=-1)
        heads = F.normalize(heads, dim=-1)
        tails = F.normalize(tails, dim=-1)
        # print(rels.shape)
        # Convert shape (1024, 4096) to (1024, 64, 64) for dot product
        rels = rels.view(-1, self.n_features, self.n_features)
        # print(rels.shape)
        # (1024, 4, 64) @ (1024, 64, 64) = (1024, 4, 64) @ (1024, 64, 4) = (1024, 4, 4)
        scores = heads @ rels @ tails.transpose(-2, -1)

        # alpha_scores.shape = (1024, 4, 4)
        # scores.shape = (1024, 4, 4)
        if alpha_scores is not None:
            scores = alpha_scores * scores
        # print(scores.shape)
        
        # sum the last 2 dimensions
        scores = scores.sum(dim=(-2, -1))
        
        # print(scores.shape)
        # Shape(1024,)
        return scores

    def __repr__(self):
        return f"{self.__class__.__name__}({self.n_rels}, {self.rel_emb.weight.shape})"



class ComplEx(nn.Module):
    """
    ComplEx method but for matrices
    """
    def __init__(self, n_rels, n_features):
        super().__init__()
        self.n_rels = n_rels
        self.n_features = n_features
    
        # Relation embeddings are also complex
        self.rel_real = nn.Embedding(self.n_rels, (self.n_features // 2) * (self.n_features // 2))
        self.rel_imag = nn.Embedding(self.n_rels, (self.n_features // 2) * (self.n_features // 2))
        
        # Initialize embeddings
        nn.init.xavier_uniform_(self.rel_real.weight)
        nn.init.xavier_uniform_(self.rel_imag.weight)

    def forward(self, heads, tails, rels, alpha_scores=None):
        # Preprocess
        # heads = F.normalize(heads, dim=-1)
        # tails = F.normalize(tails, dim=-1)
        
        r_real, r_imag = self.rel_real(rels), self.rel_imag(rels)
        r_real = F.normalize(r_real, dim=-1)
        r_imag = F.normalize(r_imag, dim=-1)
        # print(r_real.shape)
        r_real = r_real.view(-1, self.n_features // 2, self.n_features // 2)
        r_imag = r_imag.view(-1, self.n_features // 2, self.n_features // 2)
        # print(r_real.shape)
        # Split heads and tails to imaginary parts
        h_real, h_imag = heads[..., :self.n_features // 2], heads[..., self.n_features // 2:]
        t_real, t_imag = tails[..., :self.n_features // 2], heads[..., self.n_features // 2:]

        h_real, h_imag = F.normalize(h_real, dim=-1), F.normalize(h_imag, dim=-1)
        t_real, t_imag = F.normalize(t_real, dim=-1), F.normalize(t_imag, dim=-1)

        # ComplEx scoring functionn
        first_part_score = h_real @ r_real @ t_real.transpose(-2, -1)
        second_part_score = h_real @ r_imag @ t_imag.transpose(-2, -1)
        third_part_score = h_imag @ r_real @ t_imag.transpose(-2, -1)
        fourth_part_score = h_imag @ r_imag @ t_real.transpose(-2, -1)

        scores = first_part_score + second_part_score + third_part_score + fourth_part_score
        
        # If alpha_scores is provided, apply it
        if alpha_scores is not None:
            scores = alpha_scores * scores

        scores = scores.sum(dim=(-2, -1))
        
        return scores

    def __repr__(self):
        return f"{self.__class__.__name__}({self.n_rels}, {self.rel_real.weight.shape}, {self.rel_imag.weight.shape})"


### Model

In [9]:
class SSI_DDI(nn.Module):
    def __init__(
        self,
        in_features,
        hidd_dim,
        kge_dim,
        rel_total,
        heads_out_feat_params,
        blocks_params,
        sp_ratio,
        use_activation_fn,
        use_ComplEx,
        sp_min_score,
        co_attention_method,
    ):
        """
        blocks_params: list of number layers for multi-head attentions
        """
        super().__init__()
        self.in_features = in_features
        # not using this one
        self.hidd_dim = hidd_dim
        self.rel_total = rel_total
        self.kge_dim = kge_dim
        self.n_blocks = len(blocks_params)

        self.initial_norm = LayerNorm(self.in_features)
        self.blocks = []
        self.use_activation_fn = use_activation_fn
        self.use_ComplEx = use_ComplEx
        # Layer normalization list
        self.net_norms = ModuleList()
        for i, (head_out_feats, n_heads) in enumerate(
            zip(heads_out_feat_params, blocks_params)
        ):
            block = SSI_DDI_Block(
                n_heads, in_features, head_out_feats, final_out_feats=self.hidd_dim, sp_ratio=sp_ratio, sp_min_score=sp_min_score
            )
            self.add_module(f"block{i}", block)
            self.blocks.append(block)
            self.net_norms.append(LayerNorm(head_out_feats * n_heads))
            in_features = head_out_feats * n_heads

        if co_attention_method == "multihead":
            self.co_attention = MultiheadCoAttentionLayer(self.kge_dim, self.use_activation_fn)
        elif co_attention_method == "improved":
            self.co_attention = CoAttentionLayerImproved(self.kge_dim, self.use_activation_fn)
        else:
            self.co_attention = CoAttentionLayer(self.kge_dim, self.use_activation_fn)
            
        if self.use_ComplEx:
            self.KGE = ComplEx(self.rel_total, self.kge_dim)
        else:
            self.KGE = RESCAL(self.rel_total, self.kge_dim)

    def forward(self, triples):
        h_data, t_data, rels = triples

        h_data.x = self.initial_norm(h_data.x, h_data.batch)
        t_data.x = self.initial_norm(t_data.x, t_data.batch)

        repr_h = []
        repr_t = []

        for i, block in enumerate(self.blocks):
            out1, out2 = block(h_data), block(t_data)

            h_data = out1[0]
            t_data = out2[0]
            r_h = out1[1]
            r_t = out2[1]

            repr_h.append(r_h)
            repr_t.append(r_t)

            h_data.x = F.elu(self.net_norms[i](h_data.x, h_data.batch))
            t_data.x = F.elu(self.net_norms[i](t_data.x, t_data.batch))

        repr_h = torch.stack(repr_h, dim=-2)
        repr_t = torch.stack(repr_t, dim=-2)

        kge_heads = repr_h
        kge_tails = repr_t

        attentions = self.co_attention(kge_heads, kge_tails)
        # attentions = None
        scores = self.KGE(kge_heads, kge_tails, rels, attentions)

        return scores


class SSI_DDI_Block(nn.Module):
    def __init__(self, n_heads, in_features, head_out_feats, final_out_feats, sp_ratio, sp_min_score):
        """
        n_heades: number of multi-head attentions = 2
        in_features: number of features = 55 . For explicit valence use, number of features = 56.
        head_out_feats: number of out features. For 4 layers: [32, 32, 32, 32]
        sp_ratio: SAGPooling ratio
        """
        super().__init__()
        self.n_heads = n_heads
        self.in_features = in_features
        self.out_features = head_out_feats
        self.conv = GATConv(in_features, head_out_feats, n_heads)
        # SAGPooling: Ranks nodes based on self-attention scores

        if sp_ratio is None and sp_min_score is None:
            self.readout = SAGPooling(n_heads * head_out_feats, min_score=-1)
        else:
            if sp_ratio is not None:
                self.readout = SAGPooling(n_heads * head_out_feats, min_score=sp_min_score, ratio=sp_ratio)
            else:
                self.readout = SAGPooling(n_heads * head_out_feats, min_score=sp_min_score)

    def forward(self, data):
        data.x = self.conv(data.x, data.edge_index)
        # Call SAGPooling here
        # If min_score = -1 so nodes will not be filtered out, basically redudant for using the SAGPooling.
        att_x, att_edge_index, att_edge_attr, att_batch, att_perm, att_scores = (
            self.readout(data.x, data.edge_index, batch=data.batch)
        )
        # Aggregates the pooled node features (att_x) across the graph to obtain a global representation
        global_graph_emb = global_add_pool(att_x, att_batch)

        # data = max_pool_neighbor_x(data)
        return data, global_graph_emb


### Loss function

In [10]:
class SigmoidLoss(nn.Module):
    def __init__(self, adv_temperature=None):
        super().__init__()
        self.adv_temperature = adv_temperature

    def forward(self, p_scores, n_scores):
        if self.adv_temperature:
            weights = F.softmax(self.adv_temperature * n_scores, dim=-1).detach()
            n_scores = weights * n_scores
        p_loss = -F.logsigmoid(p_scores).mean()
        n_loss = -F.logsigmoid(-n_scores).mean()

        return (p_loss + n_loss) / 2, p_loss, n_loss


### Helper functions

This section contains compute functions for the results or metrics

In [12]:
def do_compute(model, batch, device, training=True):
    """
    *batch: (pos_tri, neg_tri)
    *pos/neg_tri: (batch_h, batch_t, batch_r)
    """
    probas_pred, ground_truth = [], []
    pos_tri, neg_tri = batch

    pos_tri = [tensor.to(device=device) for tensor in pos_tri]
    p_score = model(pos_tri)
    probas_pred.append(torch.sigmoid(p_score.detach()).cpu())
    ground_truth.append(np.ones(len(p_score)))

    neg_tri = [tensor.to(device=device) for tensor in neg_tri]
    n_score = model(neg_tri)
    probas_pred.append(torch.sigmoid(n_score.detach()).cpu())
    ground_truth.append(np.zeros(len(n_score)))

    probas_pred = np.concatenate(probas_pred)
    ground_truth = np.concatenate(ground_truth)

    return p_score, n_score, probas_pred, ground_truth


def do_compute_metrics(probas_pred, target):

    pred = (probas_pred >= 0.5).astype(np.int64)

    acc = metrics.accuracy_score(target, pred)
    auc_roc = metrics.roc_auc_score(target, probas_pred)
    f1_score = metrics.f1_score(target, pred)

    p, r, t = metrics.precision_recall_curve(target, probas_pred)
    auc_prc = metrics.auc(r, p)

    return acc, auc_roc, auc_prc



def export_metrics(train_metrics, val_metrics, epoch):
    train_metrics_dir = "train_metrics"
    metrics_file = f"{train_metrics_dir}/{model_name}.csv"
    train_loss, train_acc, train_auc_roc, train_auc_prc = train_metrics
    val_loss, val_acc, val_auc_roc, val_auc_prc = val_metrics

    data = [epoch, train_loss, train_acc, train_auc_roc, train_auc_prc, val_loss, val_acc, val_auc_roc, val_auc_prc]
    header = ["epoch", "train_loss", "train_acc", "train_auc_roc", "train_auc_prc", "val_loss", "val_acc", "val_auc_roc", "val_auc_prc"]
    
    if epoch == 1:
        with open(metrics_file, 'w', newline='') as file:
            writer = csv.writer(file)
            # Write the header
            writer.writerow(header)
            # Write the data rows
            writer.writerow(data)
    else:
        with open(metrics_file, 'a', newline='') as file:
            writer = csv.writer(file)
            # Write the data to the file
            writer.writerow(data)
    

def save_model(best, current, met_type):
    model_file = model_prc_file
    if met_type == "acc":
        model_file = model_acc_file
    elif met_type == "roc":
        model_file = model_roc_file
        
    if best < current:
        print(f"Saving model {met_type}")
        best = current
        torch.save(model, model_file)
    return best
    

### Training process

In [15]:
def train(
    model,
    train_data_loader,
    val_data_loader,
    loss_fn,
    optimizer,
    n_epochs,
    device,
    scheduler=None,
):
    print("Starting training at:", datetime.today())
    print("Device:", device)
    print_tunning_parameters()
    best_val_prc = 0
    best_val_acc = 0
    best_val_roc = 0
    for i in range(1, n_epochs + 1):
        start = time.time()
        train_loss = 0
        train_loss_pos = 0
        train_loss_neg = 0
        val_loss = 0
        val_loss_pos = 0
        val_loss_neg = 0
        train_probas_pred = []
        train_ground_truth = []
        val_probas_pred = []
        val_ground_truth = []

        for batch in train_data_loader:
            # print(len(batch))
            model.train()
            p_score, n_score, probas_pred, ground_truth = do_compute(model, batch, device)
            train_probas_pred.append(probas_pred)
            train_ground_truth.append(ground_truth)
            loss, loss_p, loss_n = loss_fn(p_score, n_score)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            train_loss += loss.item() * len(p_score)
        train_loss /= len(train_data)

        with torch.no_grad():
            train_probas_pred = np.concatenate(train_probas_pred)
            train_ground_truth = np.concatenate(train_ground_truth)

            train_acc, train_auc_roc, train_auc_prc = do_compute_metrics(
                train_probas_pred, train_ground_truth
            )

            for batch in val_data_loader:
                model.eval()
                p_score, n_score, probas_pred, ground_truth = do_compute(model, batch, device)
                val_probas_pred.append(probas_pred)
                val_ground_truth.append(ground_truth)
                loss, loss_p, loss_n = loss_fn(p_score, n_score)
                val_loss += loss.item() * len(p_score)

            val_loss /= len(val_data)
            val_probas_pred = np.concatenate(val_probas_pred)
            val_ground_truth = np.concatenate(val_ground_truth)
            val_acc, val_auc_roc, val_auc_prc = do_compute_metrics(
                val_probas_pred, val_ground_truth
            )
            
            # Save model if this is the best so far
            best_val_prc = save_model(best_val_prc, val_auc_prc, "prc")
            best_val_acc = save_model(best_val_acc, val_acc, "acc")
            best_val_roc = save_model(best_val_roc, val_auc_roc, "roc")

        if scheduler:
            # print('scheduling')
            scheduler.step()

        # Exporting metrics for later plots
        train_metrics = (train_loss, train_acc, train_auc_roc, train_auc_prc)
        val_metrics = (val_loss, val_acc, val_auc_roc, val_auc_prc)
        export_metrics(train_metrics, val_metrics, i)
        
        print(
            f"Epoch: {i} ({time.time() - start:.4f}s), train_loss: {train_loss:.4f}, val_loss: {val_loss:.4f},"
            f" train_acc: {train_acc:.4f}, val_acc:{val_acc:.4f}"
        )
        print(
            f"\t\ttrain_roc: {train_auc_roc:.4f}, val_roc: {val_auc_roc:.4f}, train_auprc: {train_auc_prc:.4f}, val_auprc: {val_auc_prc:.4f}"
        )

    return model


def predict(model, test_data_loader, device):
    print('Starting predicting at', datetime.today())
    print('Device', device)

    test_probas_pred = []
    test_ground_truth = []

    # Switch to evaluation mode
    model.eval()

    with torch.no_grad():  # No need to calculate gradients during testing
        for batch in test_data_loader:
            # Get predictions and ground truth for the batch
            p_score, n_score, probas_pred, ground_truth = do_compute(model, batch, device, training=False)

            # Append the predictions and ground truths
            test_probas_pred.append(probas_pred)
            test_ground_truth.append(ground_truth)


    # Concatenate the results for the entire test dataset
    test_probas_pred = np.concatenate(test_probas_pred)
    test_ground_truth = np.concatenate(test_ground_truth)

    # Calculate the metrics for the test dataset
    test_acc, test_auc_roc, test_auc_prc = do_compute_metrics(test_probas_pred, test_ground_truth)

    print(f'Test Accuracy: {test_acc:.4f}')
    print(f'Test ROC AUC: {test_auc_roc:.4f}')
    print(f'Test PRC AUC: {test_auc_prc:.4f}')

In [17]:
heads_out_feat_params = []
block_params = []

for _ in range(num_GAT_layers):
    heads_out_feat_params.append(kge_dim // 2)
    block_params.append(num_GAT_multiheads)

if mode == "train":
    model = SSI_DDI(
        n_atom_feats,
        n_atom_hid,
        kge_dim,
        rel_total,
        heads_out_feat_params=heads_out_feat_params,
        blocks_params=block_params,
        sp_ratio=sp_ratio,
        use_activation_fn=use_activation_fn,
        use_ComplEx=use_ComplEx,
        sp_min_score=sp_min_score,
        co_attention_method=co_attention_method,
    )
    loss = SigmoidLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.LambdaLR(optimizer, lambda epoch: 0.96 ** (epoch))
    print(model)
    model.to(device=device)

    # Train
    train(
        model,
        train_data_loader,
        val_data_loader,
        loss,
        optimizer,
        n_epochs,
        device,
        scheduler,
    )

SSI_DDI(
  (initial_norm): LayerNorm(55, affine=True, mode=graph)
  (net_norms): ModuleList(
    (0-3): 4 x LayerNorm(64, affine=True, mode=graph)
  )
  (block0): SSI_DDI_Block(
    (conv): GATConv(55, 32, heads=2)
    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1.0)
  )
  (block1): SSI_DDI_Block(
    (conv): GATConv(64, 32, heads=2)
    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1.0)
  )
  (block2): SSI_DDI_Block(
    (conv): GATConv(64, 32, heads=2)
    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1.0)
  )
  (block3): SSI_DDI_Block(
    (conv): GATConv(64, 32, heads=2)
    (readout): SAGPooling(GraphConv, 64, min_score=-1, multiplier=1.0)
  )
  (co_attention): MultiheadCoAttentionLayer(
    (W_q): ParameterList(
        (0): Parameter containing: [torch.float32 of size 64x32]
        (1): Parameter containing: [torch.float32 of size 64x32]
    )
    (W_k): ParameterList(
        (0): Parameter containing: [torch.float32 of size 64

### Predict results

In [None]:
# Predict
model = torch.load(model_prc_file)
print(model)
model.to(device=device)
predict(model, test_data_loader, device)