In [1]:
import torch
import numpy as np
import random

from torch_geometric.datasets import MoleculeNet
from torch_geometric.data import DataLoader

from torch.optim.lr_scheduler import ReduceLROnPlateau

from misc.training_functions import train_model, test_model
from misc.plotting_functions import plot_loss
from misc.custom_dataset import CustomDataset
from model.model import MPNN

import pandas as pd

In [None]:
# Load the HIV dataset (PyG MoleculeNet, cached under ./data/).
dataset = MoleculeNet(root="./data/", name="HIV")


def split_train_val_test(molecules, train_folds=3, val_folds=1, num_folds=5):
    """Split ``molecules`` into contiguous train/validation/test slices (60/20/20 by default).

    The fold size is ``len(molecules) // num_folds``; train gets ``train_folds`` folds,
    validation gets ``val_folds`` folds, and test takes every remaining item so nothing
    is dropped by the integer division. Input order is preserved (no shuffling).

    Args:
        molecules: list to split.
        train_folds: number of folds assigned to the training slice.
        val_folds: number of folds assigned to the validation slice.
        num_folds: number of folds the list is divided into.

    Returns:
        Tuple ``(train, val, test)`` of lists.
    """
    fold_size = len(molecules) // num_folds
    train_end = train_folds * fold_size
    val_end = train_end + val_folds * fold_size
    return molecules[:train_end], molecules[train_end:val_end], molecules[val_end:]


# Partition the molecules by label in a single pass; a truthy label means HIV positive.
pos_ds, neg_ds = [], []
for molecule in dataset:
    if molecule.y.item():
        pos_ds.append(molecule)
    else:
        neg_ds.append(molecule)

number_molecules_with_hiv, number_molecules_without_hiv = len(pos_ds), len(neg_ds)
# Both classes are required downstream (class weights divide by each class's count).
assert pos_ds and neg_ds, "expected both HIV-positive and HIV-negative molecules"

# Split each class separately so every split keeps roughly the same positive:negative ratio.
train_pos_molecules, val_pos_molecules, test_pos_molecules = split_train_val_test(pos_ds)
train_neg_molecules, val_neg_molecules, test_neg_molecules = split_train_val_test(neg_ds)

print('Train Number of Molecules HIV Positive:', len(train_pos_molecules))
print('Validation Number of Molecules HIV Positive:', len(val_pos_molecules))
print('Test Number of Molecules HIV Positive:', len(test_pos_molecules), '\n')

print('Train Number of Molecules HIV Negative:', len(train_neg_molecules))
print('Validation Number of Molecules HIV Negative:', len(val_neg_molecules))
print('Test Number of Molecules HIV Negative:', len(test_neg_molecules))

In [3]:
SEED = 42  # single seed shared by torch, numpy and random in the cells below

In [None]:
# Re-seed so the resampling and shuffling below are reproducible on re-run.
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# Rebalance the training set toward the positive class: repeat every positive molecule
# NUM_OVERSAMPLING times and keep a random 1/UNDER_SAMPLING subset of the negatives.
NUM_OVERSAMPLING = 2  # must be >= 1 (1 = no oversampling)
UNDER_SAMPLING = 4
BATCH_SIZE = 32
assert NUM_OVERSAMPLING >= 1, "NUM_OVERSAMPLING must be >= 1"

undersampled_data = random.sample(train_neg_molecules, len(train_neg_molecules) // UNDER_SAMPLING)
oversampled_data = train_pos_molecules * NUM_OVERSAMPLING

over_len, under_len = len(oversampled_data), len(undersampled_data)

combined_data = oversampled_data + undersampled_data
train_ds = CustomDataset(combined_data)
val_ds = val_pos_molecules + val_neg_molecules
test_ds = test_pos_molecules + test_neg_molecules


def worker_init_fn(worker_id):
    """Seed Python's and NumPy's RNGs inside a DataLoader worker process.

    Inside a worker, ``torch.initial_seed()`` is the loader's base seed plus
    ``worker_id``, so each worker gets a distinct but reproducible seed (the base seed
    comes from the torch RNG seeded above). Seeding every worker with the same constant
    would make all workers draw identical random numbers. Only invoked when
    ``num_workers > 0``.

    Args:
        worker_id: index of the worker (unused; already folded into the torch seed).
    """
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


# NOTE(review): torch_geometric.data.DataLoader is deprecated in PyG 2.x in favour of
# torch_geometric.loader.DataLoader; consider switching the import.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, worker_init_fn=worker_init_fn)
# val_ds/test_ds are positives followed by negatives; shuffling presumably avoids
# single-class batches in the (not visible) metric code — confirm before disabling.
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=True)

print('Number of Steps per Epoch:', len(train_loader))

In [None]:
# Inverse-frequency class weights over the resampled training set.
total_train = over_len + under_len
w1 = round(total_train / over_len, 3)   # positive (oversampled) class
w2 = round(total_train / under_len, 3)  # negative (undersampled) class
print('Oversampled Weights:', w1)
print('Undersampled Weights:', w2)

# Index 0 -> negative (label 0), index 1 -> positive (label 1); presumably consumed by
# train_model as per-class loss weights — confirm there.
weights = torch.FloatTensor([w2, w1])

In [None]:
# Re-seed right before building the model so weight initialisation is reproducible.
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

# Checkpoint location handed to train_model and reloaded below with torch.load.
model_path = './model/model_weights.pth'

# MPNN hyperparameters. node_dim/edge_dim presumably match the per-atom / per-bond
# feature widths of the MoleculeNet graphs — TODO confirm against the dataset.
mpnn_hparams = dict(
    node_dim=9,
    edge_dim=3,
    output_dim=2,  # two outputs, presumably one score per class (negative, positive)
    node_embedding_dim=18,
    edge_embedding_dim=6,
    edge_num_layers=2,
    edge_hidden_dim=15,
    num_propagation_steps=6,
)
mpnn = MPNN(**mpnn_hparams)

pytorch_total_params = sum(param.numel() for param in mpnn.parameters())
print('Number of Parameters:', pytorch_total_params)

optimizer = torch.optim.Adam(mpnn.parameters(), lr=1e-4)

# Divide the LR by 10 whenever the metric passed to scheduler.step() (presumably the
# validation loss inside train_model) stops decreasing, never going below 1e-6.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, min_lr=1e-6)

train_model(
    mpnn,
    optimizer,
    scheduler=scheduler,
    train_loader=train_loader,
    val_loader=val_loader,
    model_path=model_path,
    weights=weights,
    epochs=200,
    patience=25,
    threshold=1e-4,
)

In [None]:
# Restore the weights stored at model_path (presumably the best checkpoint saved by
# train_model) and evaluate on the held-out test split.
# map_location="cpu" lets a checkpoint saved from a GPU run load on a CPU-only machine;
# load_state_dict then copies the tensors onto whatever device mpnn's parameters use.
# NOTE(review): torch.load unpickles the file — only load checkpoints you produced yourself.
mpnn.load_state_dict(torch.load(model_path, map_location="cpu"))

test_model(mpnn, test_loader)

In [None]:
# Plot the loss curves logged during training (CSV presumably written by train_model).
history_path = './history/history.csv'
df = pd.read_csv(history_path)
plot_loss(df)