In [None]:
import time

import pandas as pd
import torch

from DrugDrugInteraction.data import build_dataset
from DrugDrugInteraction.utils import get_stats, write_summary, write_summary_total

### Create PyTorch files
This creates the `.pt` files in the `data/processed` folder. It may take a couple of minutes.

In [None]:
dataset = "ZhangDDI"  # Name prefix of the raw CSV and processed .pt files: "ZhangDDI" or "ChChMiner"

In [None]:
# Convert each raw CSV split into a processed .pt file.
#
# The second value returned by build_dataset() is deliberately discarded:
# binding it to `dataset` overwrote the dataset *name* string, so every path
# built afterwards (including the ones in the loading cell) used the wrong value.
for split in ("train", "valid", "test"):
    split_df = pd.read_csv(f'data/raw_data/{dataset}_{split}.csv', sep=",")
    processed_data, _ = build_dataset(split_df, "smiles_1", "smiles_2", "label")
    torch.save(processed_data, f"./data/processed/{dataset}_{split}.pt")

### Define hyperparameters

In [None]:
lr = 0.001        # Learning rate for training the model
epochs = 10       # Number of training epochs (read by CGIB_ModelTrainer.train)
beta = 1.0        # Weight of the KL and solvent-prediction loss terms: balances prediction against compression
tau = 1.0         # Temperature hyperparameter for CGIB_cont
embedder = 'CGIB' # Model to train: 'CGIB' or 'CGIB_cont' (compared as a string in summary())

### Load Dataset

In [None]:
print("Loading dataset...")
start = time.time()

# One path template per split; each is filled in with the dataset name.
split_paths = (
    "./data/processed/{}_train.pt",
    "./data/processed/{}_valid.pt",
    "./data/processed/{}_test.pt",
)
train_set, valid_set, test_set = [torch.load(path.format(dataset)) for path in split_paths]

print("Dataset Loaded! ({:.4f} sec)".format(time.time() - start))

### CGIB Model

In [None]:
import torch
import torch.nn as nn
from torch import optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torch.nn.functional as F

from torch_geometric.nn import Set2Set

from DrugDrugInteraction.embedder import embedder
from layers import GINE
from DrugDrugInteraction.utils import create_batch_mask

from torch_scatter import scatter_mean, scatter_add, scatter_std

import time

class CGIB(nn.Module):
    """
    Main module of the CGIB model.

    Both input graphs (called "solute" and "solvent" in the code) are encoded
    by the same GINE network, related through a node-to-node interaction map,
    pooled with Set2Set and scored by a linear predictor. With
    ``bottleneck=True`` the solute node features are additionally mixed with
    noise according to a learned per-node gate before pooling.
    """

    def __init__(self,
                device,
                node_input_dim=133,
                edge_input_dim=14,
                node_hidden_dim=300,
                edge_hidden_dim=300,
                num_step_message_passing=3,
                interaction='dot',
                num_step_set2_set=2,
                num_layer_set2set=1,
                ):
        """
        Build the encoder, compressor, pooling and prediction layers.

        Args:
            device: device on which the gate noise sampled in ``compress`` is placed.
            node_input_dim: size of the raw node features.
            edge_input_dim: size of the raw edge features.
            node_hidden_dim: hidden size of the node embeddings.
            edge_hidden_dim: stored on the instance; not read elsewhere in this class.
            num_step_message_passing: number of message passing steps given to GINE.
            interaction: stored on the instance; not read elsewhere in this class.
            num_step_set2_set: number of Set2Set processing steps.
            num_layer_set2set: number of Set2Set layers.
        """
        super(CGIB, self).__init__()

        self.device = device

        self.node_input_dim = node_input_dim
        self.node_hidden_dim = node_hidden_dim
        self.edge_input_dim = edge_input_dim
        self.edge_hidden_dim = edge_hidden_dim
        self.num_step_message_passing = num_step_message_passing
        self.interaction = interaction

        # Shared encoder applied to both graphs.
        self.gather = GINE(self.node_input_dim, self.edge_input_dim, 
                            self.node_hidden_dim, self.num_step_message_passing,
                            )

        # Scores the concatenation of the two pooled graph embeddings
        # (4 * node_hidden_dim each).
        self.predictor = nn.Linear(8 * self.node_hidden_dim, 1)

        # Produces one gate logit per solute node.
        self.compressor = nn.Sequential(
            nn.Linear(2 * self.node_hidden_dim, self.node_hidden_dim),
            nn.BatchNorm1d(self.node_hidden_dim),
            nn.ReLU(),
            nn.Linear(self.node_hidden_dim, 1)
            )
        
        # Maps the pooled noisy solute embedding to a predicted pooled solvent embedding.
        self.solvent_predictor = nn.Linear(4 * self.node_hidden_dim, 4 * self.node_hidden_dim)
        
        self.mse_loss = torch.nn.MSELoss()

        self.num_step_set2set = num_step_set2_set
        self.num_layer_set2set = num_layer_set2set
        self.set2set = Set2Set(2 * node_hidden_dim, self.num_step_set2set, self.num_layer_set2set)

        self.init_model()
    
    def init_model(self):
        """Xavier-initialise every linear layer's weight and zero its bias."""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                torch.nn.init.xavier_uniform_(m.weight.data)
                if m.bias is not None:
                    m.bias.data.fill_(0.0)
    
    def compress(self, solute_features):
        """
        Sample a relaxed per-node gate for the solute nodes.

        Args:
            solute_features: node features fed to the compressor network.

        Returns:
            Tuple ``(gate_inputs, p)``: the sampled gate values after the
            sigmoid (squeezed), and the raw compressor logits.
        """
        
        p = self.compressor(solute_features)
        temperature = 1.0
        bias = 0.0 + 0.0001  # If bias is 0, we run into problems
        # Uniform sample kept away from 0 and 1 so both logs below stay finite.
        eps = (bias - (1 - bias)) * torch.rand(p.size()) + (1 - bias)
        # log(eps) - log(1 - eps) turns the uniform sample into logistic noise.
        gate_inputs = torch.log(eps) - torch.log(1 - eps)
        gate_inputs = gate_inputs.to(self.device)
        gate_inputs = (gate_inputs + p) / temperature
        gate_inputs = torch.sigmoid(gate_inputs).squeeze()

        return gate_inputs, p
    
    def forward(self, data, bottleneck = False, test = False):
        """
        Run the model on one batch of graph pairs.

        Args:
            data: sequence ``(solute, solvent, solute_len, solvent_len)``. The
                two ``*_len`` entries are sparse matrices whose product
                ``solute_len.t() @ solvent_len`` masks the node-to-node map.
            bottleneck: when True, gate the solute node features with noise
                and return the extra loss terms.
            test: when True, store per-node importance scores in
                ``self.importance``; on the non-bottleneck path the
                predictions are also passed through a sigmoid.

        Returns:
            ``bottleneck=True``: ``(predictions, KL_Loss, solvent_pred_loss, preserve_rate)``.
            Otherwise: ``(predictions, ret_interaction_map)``.
        """
        solute = data[0]
        solvent = data[1]
        solute_len = data[2]
        solvent_len = data[3]
        # node embeddings after interaction phase
        solute_features = self.gather(solute)
        solvent_features = self.gather(solvent)

        # Add normalization
        self.solute_features = F.normalize(solute_features, dim = 1)
        self.solvent_features = F.normalize(solvent_features, dim = 1)

        # Interaction phase
        len_map = torch.sparse.mm(solute_len.t(), solvent_len)

        interaction_map = torch.mm(self.solute_features, self.solvent_features.t())
        # Mask the map once. The masked tensor is never modified in place
        # afterwards, so the returned map can share it instead of recomputing it.
        interaction_map = interaction_map * len_map.to_dense()
        ret_interaction_map = interaction_map

        self.solvent_prime = torch.mm(interaction_map.t(), self.solute_features)
        self.solute_prime = torch.mm(interaction_map, self.solvent_features)

        # Prediction phase
        self.solute_features = torch.cat((self.solute_features, self.solute_prime), dim=1)
        self.solvent_features = torch.cat((self.solvent_features, self.solvent_prime), dim=1)

        if test:

            # Only the raw logits are kept; the sampled gate is discarded.
            _, self.importance = self.compress(self.solute_features)
            self.importance = torch.sigmoid(self.importance)

        if bottleneck:

            lambda_pos, p = self.compress(self.solute_features)
            lambda_pos = lambda_pos.reshape(-1, 1)
            lambda_neg = 1 - lambda_pos

            # Get Stats
            # Fraction of nodes whose gate probability exceeds 0.5.
            preserve_rate = (torch.sigmoid(p) > 0.5).float().mean()

            # Per-graph feature statistics, computed on a detached copy and
            # broadcast back to every node of the graph.
            static_solute_feature = self.solute_features.clone().detach()
            node_feature_mean = scatter_mean(static_solute_feature, solute.batch, dim = 0)[solute.batch]
            node_feature_std = scatter_std(static_solute_feature, solute.batch, dim = 0)[solute.batch]
            # node_feature_std, node_feature_mean = torch.std_mean(static_solute_feature, dim=0)
            
            noisy_node_feature_mean = lambda_pos * self.solute_features + lambda_neg * node_feature_mean
            noisy_node_feature_std = lambda_neg * node_feature_std

            # NOTE(review): the noise is drawn uniformly from [0, 1) (`rand_like`),
            # not from a standard normal - confirm this is intended.
            noisy_node_feature = noisy_node_feature_mean + torch.rand_like(noisy_node_feature_mean) * noisy_node_feature_std
            noisy_solute_subgraphs = self.set2set(noisy_node_feature, solute.batch)

            epsilon = 1e-7

            # The first term is reshaped to one column per graph and broadcasts
            # against the per-graph, per-feature second term.
            KL_tensor = 0.5 * scatter_add(((noisy_node_feature_std ** 2) / (node_feature_std + epsilon) ** 2).mean(dim = 1), solute.batch).reshape(-1, 1) + \
                        scatter_add((((noisy_node_feature_mean - node_feature_mean)/(node_feature_std + epsilon)) ** 2), solute.batch, dim = 0)
            KL_Loss = torch.mean(KL_tensor)
            
            # Predict Solvent
            self.solvent_features_s2s = self.set2set(self.solvent_features, solvent.batch)
            solvent_pred_loss = self.mse_loss(self.solvent_features_s2s, self.solvent_predictor(noisy_solute_subgraphs))

            # Prediction Y
            final_features = torch.cat((noisy_solute_subgraphs, self.solvent_features_s2s), 1)
            predictions = self.predictor(final_features)

            return predictions, KL_Loss, solvent_pred_loss, preserve_rate
        
        else:

            self.solute_features_s2s = self.set2set(self.solute_features, solute.batch)
            self.solvent_features_s2s = self.set2set(self.solvent_features, solvent.batch)

            final_features = torch.cat((self.solute_features_s2s, self.solvent_features_s2s), 1)
            predictions = self.predictor(final_features)

            if test:
                return torch.sigmoid(predictions), ret_interaction_map

            else:
                return predictions, ret_interaction_map

In [None]:
class CGIB_ModelTrainer(embedder):
    """
    Trainer for the CGIB model, built on the project's `embedder` base class.

    The base class is expected to provide `self.device`, `self.train_loader`,
    `self.evaluate`, `self.writer`, `self.val_roc_score` and the `best_test_*`
    attributes used below. The hyperparameters `lr`, `epochs` and `beta` are
    read from the notebook's global scope.
    """

    def __init__(self, train_df, valid_df, test_df, repeat, fold):
        """Initialise the base trainer, then create the model, optimizer and LR scheduler."""
        embedder.__init__(self, train_df, valid_df, test_df, repeat, fold)

        self.model = CGIB(device = self.device).to(self.device)
        # `lr` is the notebook-level learning rate (the original referenced an
        # undefined name, `lry`, and raised NameError).
        self.optimizer = optim.Adam(params = self.model.parameters(), lr = lr)
        # mode='max': the scheduler is stepped with the validation ROC score.
        self.scheduler = ReduceLROnPlateau(self.optimizer, mode='max', verbose=True)
        
    def train(self):
        """
        Train for `epochs` epochs, evaluating after each one.

        Each batch contributes two BCE terms (a plain forward pass and a
        bottleneck forward pass) plus the `beta`-weighted KL and
        solvent-prediction losses.

        Returns:
            Tuple ``(best_test_roc, best_test_ap, best_test_f1, best_test_acc)``.
        """

        loss_function_BCE = nn.BCEWithLogitsLoss(reduction='none')

        for epoch in range(1, epochs + 1):
            self.model.train()
            self.train_loss = 0
            preserve = 0
            num_batches = 0

            start = time.time()

            for samples in self.train_loader:
                self.optimizer.zero_grad()
                masks = create_batch_mask(samples)

                # Move the batch to the device once and reuse it for both passes.
                inputs = [samples[0].to(self.device), samples[1].to(self.device), masks[0].to(self.device), masks[1].to(self.device)]
                targets = samples[2].reshape(-1, 1).to(self.device).float()

                outputs, _ = self.model(inputs)
                loss = loss_function_BCE(outputs, targets).mean()

                # Information Bottleneck
                outputs, KL_Loss, solvent_pred_loss, preserve_rate = self.model(inputs, bottleneck = True)
                loss += loss_function_BCE(outputs, targets).mean()
                loss += beta * KL_Loss
                loss += beta * solvent_pred_loss

                loss.backward()
                self.optimizer.step()
                # Accumulate a detached value so the running total does not keep
                # every batch's autograd history alive.
                self.train_loss += loss.detach()
                preserve += preserve_rate
                num_batches += 1
            
            self.epoch_time = time.time() - start

            self.model.eval()
            self.evaluate(epoch)

            self.scheduler.step(self.val_roc_score)

            # Write Statistics
            # Average over the number of batches. The original divided by the
            # last zero-based batch index, which is one too small.
            self.writer.add_scalar("stats/preservation", preserve / max(num_batches, 1), epoch)

        self.evaluate(epoch, final = True)
        self.writer.close()
        
        return self.best_test_roc, self.best_test_ap, self.best_test_f1, self.best_test_acc

### Training

In [None]:
def summary(train_df, valid_df, test_df, repeat = 0, fold = 0):
    """
    Train the model selected by the notebook-level `embedder` variable.

    Args:
        train_df: training split passed to the trainer.
        valid_df: validation split passed to the trainer.
        test_df: test split passed to the trainer.
        repeat: repeat index forwarded to the trainer.
        fold: fold index forwarded to the trainer.

    Returns:
        Tuple ``([best_roc, best_ap, best_f1, best_acc], config_str,
        best_config_roc, best_config_f1)`` taken from the trainer.

    Raises:
        ValueError: if `embedder` is not the string 'CGIB' or 'CGIB_cont'.
    """

    # The trainer gets its own local name. Assigning to `embedder` inside this
    # function made it a local variable, so the comparison below raised
    # UnboundLocalError before any model was built.
    #
    # NOTE(review): the model cell's `from DrugDrugInteraction.embedder import embedder`
    # rebinds the same global name to a class. When that import runs after the
    # hyperparameter cell, `embedder` is no longer a string and the ValueError
    # below fires - one of the two names needs renaming.
    # NOTE(review): the trainers are imported from `models`, not taken from the
    # CGIB_ModelTrainer class defined in this notebook - confirm that is intended.
    if embedder == 'CGIB':
        from models import CGIB_ModelTrainer
        trainer = CGIB_ModelTrainer(train_df, valid_df, test_df, repeat, fold)

    elif embedder == 'CGIB_cont':
        from models import CGIB_cont_ModelTrainer
        trainer = CGIB_cont_ModelTrainer(train_df, valid_df, test_df, repeat, fold)

    else:
        raise ValueError(
            "embedder must be the string 'CGIB' or 'CGIB_cont', got {!r}".format(embedder)
        )

    best_roc, best_ap, best_f1, best_acc = trainer.train()

    return [best_roc, best_ap, best_f1, best_acc], trainer.config_str, trainer.best_config_roc, trainer.best_config_f1


In [None]:
# One history list per metric; a single run contributes one entry to each.
best_rocs, best_aps, best_f1s, best_accs = [], [], [], []

# Train once; the two best-config values returned by summary() are not used here.
stats, config_str, _, _ = summary(train_set, valid_set, test_set)

# stats is ordered [roc, ap, f1, acc].
for history, position in ((best_rocs, 0), (best_aps, 1), (best_f1s, 2), (best_accs, 3)):
    history.append(stats[position])

# NOTE(review): `embedder` is assigned a model-name string in the hyperparameter
# cell, but the model cell later imports a class under the same name - confirm
# which value reaches the two write_* calls below.
write_summary(dataset, embedder, config_str, stats)

# Mean / std of each metric over the collected runs.
roc_mean, roc_std = get_stats(best_rocs)
ap_mean, ap_std = get_stats(best_aps)
f1_mean, f1_std = get_stats(best_f1s)
accs_mean, accs_std = get_stats(best_accs)

aggregate = [roc_mean, roc_std, ap_mean, ap_std, f1_mean, f1_std, accs_mean, accs_std]
write_summary_total(dataset, embedder, config_str, aggregate)