<a href="https://colab.research.google.com/github/davidkubanek/Thesis/blob/main/model_experiments.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# CONCERTO architecture

In [1]:
# for running in colab
!pip install dgl
!pip install rdkit
!pip install torch_geometric
!pip install wandb

Collecting dgl
  Downloading dgl-1.1.2-cp310-cp310-manylinux1_x86_64.whl (6.0 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.0/6.0 MB[0m [31m47.1 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: dgl
Successfully installed dgl-1.1.2
Collecting rdkit
  Downloading rdkit-2023.3.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (29.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m29.7/29.7 MB[0m [31m64.9 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: rdkit
Successfully installed rdkit-2023.3.2
Collecting torch_geometric
  Downloading torch_geometric-2.3.1.tar.gz (661 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m661.6/661.6 kB[0m [31m10.8 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Building wheels for collected packages: tor

In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
if not torch.cuda.is_available():
  plt.rcParams["font.family"] = "Palatino"

import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, lr_scheduler
from torch.utils.data import Dataset, DataLoader
from torch_geometric.loader import DataLoader
import wandb
import warnings
from tqdm import tqdm

# check if cuda is available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

cuda


In [3]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


# Load Dataset

In [79]:
from torch_geometric.loader import DataLoader

def prepare_datalist(matrix_df, args, graph_fp=True, grover_fp=None):
    '''
    Convert matrix dataframe to a data_list with pytorch geometric graph data, fingerprints and labels.
    Inputs:
        matrix_df: dataframe of SMILES, assays and bioactivity labels
        args: arguments
        graph_fp: if True, includes graph embedding fingerprints into data_list
        grover_fp: if True, includes GROVER graph transformer embedding fingerprints into data_list
    Outputs:
        data_list: list of data objects
    '''
    # only use subset of data (assays and data points)
    assay_list = args['assay_list']
    num_assays = args['num_assays']
    assay_start = args['assay_start']
    num_data_points = args['num_data_points']

    # get binary target labels
    y = matrix_df[assay_list[assay_start:assay_start+num_assays]].values[:num_data_points]

    # get SMILES strings
    data = matrix_df['SMILES'].values[:num_data_points]

    if graph_fp is True: # add graph fingerprint
        GraphDataset = GraphDatasetClass()
        # create pytorch geometric graph data list
        data_list = GraphDataset.create_pytorch_geometric_graph_data_list_from_smiles_and_labels(data, y)
    else: # create simple data_list without graph fingerprint
      data_list = []
      for label in y:
          # construct Pytorch Geometric data object and append to data list
          data_list.append(Data(y = label.reshape(1, -1)))

    # add fingerprint data to each graph
    for i, smile in tqdm(enumerate(data), desc='Adding fingerprints...', total=len(data)):
        fp = convert_smile_to_fp_bit_string(smile)
        data_list[i].fp = fp


    # add grover fingerprint to each graph
    if grover_fp is not None:
        for i, gfp in tqdm(enumerate(grover_fp['fps'][:args['num_data_points']]), desc='Adding grover embedding...', total=len(data)):
          data_list[i].grover_fp = torch.tensor(gfp)

    print(f'Example of a graph data object: {data_list[0]}')

    return data_list

def prepare_splits(data_list, args):

    data_list = data_list[:args['num_data_points']]

    data_splits = {}
    # split into train and test
    train_dataset = [d.to(args['device']) for d in data_list[:int(len(data_list)*0.8)]]
    data_splits['test'] = [d.to(args['device']) for d in data_list[int(len(data_list)*0.8):]]

    # split into train and validation
    data_splits['val'] = train_dataset[:int(len(train_dataset)*0.25)]
    data_splits['train'] = train_dataset[int(len(train_dataset)*0.25):]

    print(f'Number of training graphs:', len(data_splits['train']))
    print(f'Number of validation graphs:', len(data_splits['val']))
    print(f'Number of test graphs:', len(data_splits['test']))
    print(f'Example of a graph data object: {data_list[0]}')

    return data_splits


def prepare_dataloader(data_splits, args):
    '''
    Get dataloader dictionary from data_list with desired batch_size
    '''
    # create data loaders
    dataloader = {}
    dataloader['train'] = DataLoader(data_splits['train'], batch_size=args['batch_size'], shuffle=True)
    dataloader['val'] = DataLoader(data_splits['val'], batch_size=args['batch_size'], shuffle=False)
    dataloader['test'] = DataLoader(data_splits['test'], batch_size=args['batch_size'], shuffle=False)

    return dataloader

def analyze_dataset(dataset, args):
    '''
    Analyze the distribution of positive classes in the dataset
    '''
    positive = []
    for i in range(len(dataset)):
        positive.append(dataset[i].y[0].sum().item())


    num_assays = args['num_assays']
    # make histogram of the number of positive
    plt.figure(figsize=(7, 4))
    # define bins
    bins = np.linspace(0, num_assays, num_assays+1)-0.5
    plt.hist(positive, bins=bins, alpha=0.5, label='train')
    num_assays = args['num_assays']
    plt.xlabel(f'# of positive hits in target vector (out of {num_assays})')
    plt.ylabel('Number of data points')
    plt.title('Histogram of positive class distribution')
    plt.show()

    # for i in range(num_assays+1):
    #     print(f'Number of data points with {i} positive targets: ', (np.array(positive) == i).sum(), f'({(np.array(positive) == i).sum()/len(positive)*100:.2f}%)')

def data_explore(dataloader):
    '''
    Explore the data
    '''
    # check proportion of positive and negative samples across each assay
    pos = torch.zeros(args['num_assays'])
    for data in dataloader:  # Iterate in batches over the training dataset
        # print('inputs:')
        # print(' x:', data.x.shape, '| y:',data.y.shape, '| fp:',data.fp.shape, '| grover:', data.grover_fp.shape)
        pos += data.y.sum(axis=0)
        #  print(data.y.sum(axis=0))
    print('# positive samples:', pos)
    print(torch.round((pos/len(dataloader.dataset)*100),decimals=2),'% are positive')



In [6]:
 import pickle
import os

#directory = '/home/ubuntu/Thesis/Thesis MSc/PubChem Data/'
directory = '/content/drive/MyDrive/Thesis/Data/'

# Specify the path where you saved the dictionary
load_path = directory + 'final/datalist_no_out.pkl'

# Load the dictionary using pickle
with open(load_path, 'rb') as f:
    data_list = pickle.load(f)


# load the assay groups
with open(directory + 'info/cell_based_high_hr.txt', 'r') as file:
    lines = file.read().splitlines()
cell_based_high_hr = list(map(str, lines))
with open(directory + 'info/cell_based_med_hr.txt', 'r') as file:
    lines = file.read().splitlines()
cell_based_med_hr = list(map(str, lines))
with open(directory + 'info/cell_based_low_hr.txt', 'r') as file:
    lines = file.read().splitlines()
cell_based_low_hr = list(map(str, lines))
with open(directory + 'info/non_cell_based_high_hr.txt', 'r') as file:
    lines = file.read().splitlines()
non_cell_based_high_hr = list(map(str, lines))
with open(directory + 'info/non_cell_based_med_hr.txt', 'r') as file:
    lines = file.read().splitlines()
non_cell_based_med_hr = list(map(str, lines))
with open(directory + 'info/non_cell_based_low_hr.txt', 'r') as file:
    lines = file.read().splitlines()
non_cell_based_low_hr = list(map(str, lines))
# load assay order
with open(directory + 'info/assay_order.txt', 'r') as f:
    assay_order = [line.strip() for line in f.readlines()]

args = {}
args['assay_order'] = assay_order


# Setup

Define hyperparams to sweep

# Models
### GCN and GCN_FP
- GCN: graph embedding followed by a final classification layer
- GCN_FP: graph + fingerprints embedding followed by a final classification layer

In [8]:
from torch.nn import Linear
import torch.nn.functional as F
from torch_geometric.nn import GCNConv
from torch_geometric.nn import global_mean_pool


class GCN(torch.nn.Module):
    '''
    Define a Graph Convolutional Network (GCN) model architecture.
    Can include 'graph' only or 'graph + fingerprints' embedding before final classification layer.
    '''
    def __init__(self, args):
        super(GCN, self).__init__()
        torch.manual_seed(12345)

        num_node_features = args['num_node_features']
        hidden_channels = args['hidden_channels']
        num_classes = args['num_assays']
        if args['model'] == 'GCN_FP':
            fp_dim = args['fp_dim']
        else:
            fp_dim = 0

        self.conv1 = GCNConv(num_node_features, hidden_channels)
        self.conv2 = GCNConv(hidden_channels, hidden_channels)
        self.conv3 = GCNConv(hidden_channels, hidden_channels)

        self.lin = Linear(hidden_channels + fp_dim, num_classes)

    def forward(self, x, edge_index, batch, fp=None):
        # 1. Obtain node embeddings
        x = self.conv1(x, edge_index)
        x = x.relu()
        x = self.conv2(x, edge_index)
        x = x.relu()
        x = self.conv3(x, edge_index)

        # 2. Readout layer
        x = global_mean_pool(x, batch)  # [batch_size, hidden_channels]

        # if also using fingerprints
        if fp is not None:
            # reshape fp to batch_size x fp_dim
            fp = fp.reshape(x.shape[0], -1)
            # concatenate graph node embeddings with fingerprint
            # print('BEFORE CONCAT x:',x.shape, 'fp:', fp.shape)
            x = torch.cat([x, fp], dim=1)
            # print('AFTER CONCAT x:',x.shape)

        # 3. Apply a final classifier
        x = F.dropout(x, p=0.1, training=self.training)
        x = self.lin(x)

        return x


### FP, GROVER and GROVER_FP
- FP: fingerprints embedding followed by a final classification layer
- GROVER: graph transformer embedding followed by a final classification layer
- GROVER_FP: graph transformer + fingerprints embedding followed by a final classification layer

In [9]:

class LinearBlock(nn.Module):
	""" basic block in an MLP, with dropout and batch norm """

	def __init__(self, in_feats, out_feats, dropout=0.1):
		super(LinearBlock, self).__init__()
		self.linear = nn.Linear(in_feats, out_feats)
		self.bn = nn.BatchNorm1d(out_feats)
		self.dropout = nn.Dropout(dropout)

	def forward(self, x):
		# ReLU activation, batch norm, dropout on layer
		return self.bn(self.dropout(F.relu(self.linear(x))))

def construct_mlp(in_dim, out_dim, hidden_dim, hidden_layers, dropout=0.1):
	"""
	Constructs an MLP with specified dimensions.
		- total number of layers = hidden_layers + 1 (the + 1 is for the output linear)
		- no activation/batch norm/dropout on output layer
	"""

	assert hidden_layers >= 1, hidden_layers
	mlp_list = []
	mlp_list.append(LinearBlock(in_dim,hidden_dim,dropout=dropout))
	for i in range(hidden_layers-1):
		mlp_list.append(LinearBlock(hidden_dim,hidden_dim,dropout=dropout))

	# no activation/batch norm/dropout on output layer
	mlp_list.append(nn.Linear(hidden_dim,out_dim))
	mlp = nn.Sequential(*mlp_list)
	return mlp

class MLP(nn.Module):
	'''
	MLP with optional Grover fingerprints.
	Customizable number of layers, hidden dimensions, and dropout.
	'''
	def __init__(self, args):

		super(MLP, self).__init__()

		self.model_type = args['model']
		self.fp_dim = args['fp_dim'] # can be 0
		self.grover_fp_dim = args['grover_fp_dim'] # can be 0
		self.hidden_dim = args['hidden_channels']
		self.output_dim = args['num_assays']
		self.num_layers = args['num_layers']
		self.dropout = args['dropout']

		assert self.model_type in ['FP','GROVER','GROVER_FP'], f'model type not supported: {self.model_type}'

		if self.model_type == 'FP':
			self.grover_fp_dim = 0
		elif self.model_type == 'GROVER':
			self.fp_dim = 0

		self.ff_layers = construct_mlp(
			self.fp_dim + self.grover_fp_dim,
			self.output_dim,
			self.hidden_dim,
			self.num_layers,
			self.dropout
		)

	def forward(self, data):


		if self.model_type == 'FP': # only fp is used
			fingerprints = data.fp
			# reshape fp to batch_size x fp_dim
			fingerprints = fingerprints.reshape(int(fingerprints.shape[0]/self.fp_dim), -1)

			output = self.ff_layers(fingerprints)

		elif self.model_type == 'GROVER': # only grover is used
			# reshape grover_fp to batch_size x grover_fp_dim
			grover_fp = data.grover_fp
			grover_fp = grover_fp.reshape(int(grover_fp.shape[0]/self.grover_fp_dim), -1)

			output = self.ff_layers(grover_fp)

		elif self.model_type == 'GROVER_FP': #grover and fp are concatenated
			fingerprints = data.fp
			# reshape fp to batch_size x fp_dim
			fingerprints = fingerprints.reshape(int(fingerprints.shape[0]/self.fp_dim), -1)
			# reshape grover_fp to batch_size x grover_fp_dim
			grover_fp = data.grover_fp
			grover_fp = grover_fp.reshape(int(grover_fp.shape[0]/self.grover_fp_dim), -1)

			output = self.ff_layers(torch.cat([fingerprints, grover_fp], dim=1))


		return output

# Training

In [62]:
from sklearn.metrics import confusion_matrix, precision_score, recall_score, f1_score, roc_auc_score
import time


class TrainManager:

    def __init__(self, dataloader, args, model=None):

        self.args = args
        self.num_assays = args['num_assays']
        self.num_node_features = args['num_node_features']
        self.hidden_channels = args['hidden_channels']

        if not model:
            # initialize model depending on model type
            if args['model'] in ['GCN','GCN_FP']:
                self.model = GCN(args)
            elif args['model'] in ['FP','GROVER','GROVER_FP']:
              self.model = MLP(args)
        else:
            self.model = model

        self.model.to(args['device'])
        print("Model is on device:", next(self.model.parameters()).device)
        total_params = sum(p.numel() for p in self.model.parameters() if p.requires_grad)
        print(f'Total number of parameters: {total_params}')

        self.dataloader = dataloader

        self.optimizer = torch.optim.Adam(self.model.parameters(), lr=args['lr'])
        self.criterion = nn.BCEWithLogitsLoss()

        self.curr_epoch = 0

        # logging
        self.eval_metrics = {}
        self.eval_metrics['loss'] = []
        self.eval_metrics['acc_train'] = []
        self.eval_metrics['acc_test'] = []
        self.eval_metrics['auc_train'] = []
        self.eval_metrics['auc_test'] = []
        self.eval_metrics['precision_train'] = []
        self.eval_metrics['precision_test'] = []
        self.eval_metrics['recall_train'] = []
        self.eval_metrics['recall_test'] = []
        self.eval_metrics['f1_train'] = []
        self.eval_metrics['f1_test'] = []


    def train(self, epochs=100, log=False, wb_log=False):
        '''
        Train the model for a given number of epochs.
        '''

        self.wb_log = wb_log

        epoch_times = []

        for epoch in range(epochs):

            self.model.train()
            cum_loss = 0
            start_time = time.time()

            # Iterate in batches over the training dataset
            for data in tqdm(self.dataloader['train'], desc=f'Epoch [{self.curr_epoch}/{epochs}]', total=int(len(self.dataloader['train'].dataset)/self.args['batch_size'])):

                # forward pass based on model type
                if self.args['model'] == 'GCN':
                    out = self.model(data.x, data.edge_index, data.batch)
                elif self.args['model'] == 'GCN_FP':
                    out = self.model(data.x, data.edge_index, data.batch, fp=data.fp)
                elif self.args['model'] in ['FP','GROVER','GROVER_FP']:
                    out = self.model(data)

                # data.y = data.y.unsqueeze(1)
                # print('data.y:',data.y.shape)
                loss = self.criterion(out, data.y[:,args['assays_idx']])  # Compute the loss. (sigmoid inherent in loss)
                loss.backward()  # Derive gradients.
                self.optimizer.step()  # Update parameters based on gradients.
                self.optimizer.zero_grad()  # Clear gradients.
                cum_loss += loss.item()

            self.eval_metrics['loss'].append(cum_loss/len(self.dataloader['train']))
            if wb_log is True:
                wandb.log({'epoch': self.curr_epoch, "loss": cum_loss/len(self.dataloader['train'])})

            epoch_time = time.time() - start_time
            epoch_times.append(epoch_time)

            if log:
                # evaluate
                acc_train, auc_train, precision_train, recall_train, f1_train = self.eval(self.dataloader['train'])
                acc_test, auc_test, precision_test, recall_test, f1_test = self.eval(self.dataloader['val'])


                self.eval_metrics['acc_train'].append(acc_train)
                self.eval_metrics['acc_test'].append(acc_test)
                self.eval_metrics['auc_train'].append(auc_train)
                self.eval_metrics['auc_test'].append(auc_test)
                self.eval_metrics['precision_train'].append(precision_train)
                self.eval_metrics['precision_test'].append(precision_test)
                self.eval_metrics['recall_train'].append(recall_train)
                self.eval_metrics['recall_test'].append(recall_test)
                self.eval_metrics['f1_train'].append(f1_train)
                self.eval_metrics['f1_test'].append(f1_test)

                if wb_log is True:
                    wandb.log({'epoch': self.curr_epoch, "AUC train": auc_train, "AUC test": auc_test, "F1 train": f1_train, "F1 test": f1_test, "Precision train": precision_train, "Precision test": precision_test, "Recall train": recall_train, "Recall test": recall_test})


                if epoch % 10 == 0:
                    print(f'Epoch: {self.curr_epoch}, Loss: {loss.item():.4f}, Train AUC: {auc_train:.4f}, Test AUC: {auc_test:.4f}')
                    print(f'                        Train F1: {f1_train:.4f}, Test F1: {f1_test:.4f}')

            self.curr_epoch += 1



        self.avg_epoch_time = np.mean(epoch_times)
        if wb_log is True:
            wandb.log({'epoch': self.curr_epoch, "avg epoch time": self.avg_epoch_time})

    def eval(self, loader):
        '''
        Evaluate the model on a given dataset (train/val/test).
        '''
        start_time = time.time()

        self.model.eval()

        # print("Model is on device for eval:", next(exp.model.parameters()).device)

        correct = 0

        gts = []
        preds = []
        with torch.no_grad():
            for data in loader:  # Iterate in batches over the training/test dataset.

                data = data.to(self.args['device'])

                # forward pass based on model type
                if self.args['model'] == 'GCN':
                    out = self.model(data.x, data.edge_index, data.batch)
                elif self.args['model'] == 'GCN_FP':
                    out = self.model(data.x, data.edge_index, data.batch, fp=data.fp)
                elif self.args['model'] in ['FP','GROVER','GROVER_FP']:
                    out = self.model(data)

                # convert out to binary
                pred = torch.round(torch.sigmoid(out))
                preds.append(torch.round(torch.sigmoid(out)).tolist())
                gts.append(data.y[:,args['assays_idx']].tolist())
                # print('pred:', pred)
                # print('data.y:', data.y)
                # print('data.y eval:',data.y.shape)
                # data.y = data.y.unsqueeze(1)
                correct += int((pred == data.y[:,args['assays_idx']]).sum())  # Check against ground-truth labels.




        preds = [b[i] for b in preds for i in range(len(b))]
        gts = [b[i] for b in gts for i in range(len(b))]

        auc = roc_auc_score(gts, preds)
        # Calculate macro-averaged precision, recall, and F1 Score
        precision = precision_score(gts, preds, average='macro', zero_division=0)
        recall = recall_score(gts, preds, average='macro', zero_division=0)
        f1 = f1_score(gts, preds, average='macro', zero_division=0)


        acc = correct / (len(loader.dataset) * self.args['num_assays']) # Derive ratio of correct predictions.

        self.eval_time = time.time() - start_time

        if self.wb_log is True:
            wandb.log({'epoch': self.curr_epoch, "eval time": self.eval_time})

        return acc, auc, precision, recall, f1



    def analyze(self):
        '''
        Plot the model performance.
        '''

        # plot side by side
        fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(10, 5))
        ax1.plot(self.eval_metrics['loss'])
        ax1.set_xlabel('Epoch')
        ax1.set_ylabel('Loss')
        ax1.set_title('Losses')

        ax2.plot(self.eval_metrics['auc_train'], label='train')
        ax2.plot(self.eval_metrics['auc_test'], label='test')
        ax2.set_xlabel('Epoch')
        ax2.set_ylabel('AUC')
        ax2.set_title('Area Under Curve')
        ax2.legend()
        # make main title for the whole plot
        if args['model'] in ['GCN', 'GCN_FP']:
            plt.suptitle(f'Model: {self.args["model"]} | Node feats: {self.args["num_node_features"]}, Hidden dim: {self.args["hidden_channels"]}, Dropout: {self.args["dropout"]}, Num data points: {self.args["num_data_points"]}, Num assays: {self.args["num_assays"]}, Num epochs: {self.curr_epoch}')
        elif args['model'] in ['FP', 'GROVER', 'GROVER_FP']:
            plt.suptitle(f'Model: {self.args["model"]} | Num layers: {self.args["num_layers"]}, Hidden dim: {self.args["hidden_channels"]}, Dropout: {self.args["dropout"]}, Num data points: {self.args["num_data_points"]}, Num assays: {self.args["num_assays"]}, Num epochs: {self.curr_epoch}')
        plt.show()



    def save_model(self, folder, filename, save_weights=True, save_logs=True):
        print('saving experiment...')

        filename += f'_{self.curr_epoch}e'
        if save_weights:
            torch.save(self.model.state_dict(), os.path.join(folder, filename+'.pt'))

        #if save_logs:

    def load_model(self, folder, filename):
        print('loading model...')
        self.model.load_state_dict(torch.load(os.path.join(folder, filename+'.pt')))


# Experiments

In [81]:
args = {}
args['device'] = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# data parameters
args['num_data_points'] = 324191 # all=324191 # number of data points to use
#args['assay_list'] = cell_based_high_hr #for all: matrix_df.columns.values[1:]
args['num_assays'] = 5 # number of assays to use (i.e., no. of output classes)
args['assay_start'] = 0 # which assay to start from
args['assay_order'] = assay_order
args['num_node_features'] = 79 # number of node features in graph representation
args['grover_fp_dim'] = 5000 #grover_fp['fps'][0].shape[0] # None  # dim of grover fingerprints
args['fp_dim'] = 2215 # dim of fingerprints


# training parameters
args['model'] = 'GROVER_FP' # 'GCN', 'GCN_FP', 'FP', 'GROVER', 'GROVER_FP'
args['num_layers'] = 5 # number of layers in MLP
args['hidden_channels'] = 128 # 64
args['dropout'] = 0.2
args['batch_size'] = 256
args['num_epochs'] = 100
args['lr'] = 0.01
#args['gradient_clip_norm'] = 1.0
#args['network_weight_decay'] = 0.0001
#args['lr_decay_factor'] = 0.5

# check batch size -> to include examples of classes
# dropout maybe higher


In [None]:
cell_based_high_hr

['2797', '743397', '1979', '602248', '624127', '1910', '2796']

In [84]:
# find indeces of assays in assay_list in assay_order
# return list of indeces
def find_assay_indeces(assay_list, assay_order):
    indeces = []
    for assay in assay_list:
        indeces.append(assay_order.index(assay))
    return indeces

args['assay_list'] = ['2797']
args['num_assays'] = 1
args['assays_idx'] = find_assay_indeces(args['assay_list'], assay_order)


In [83]:
data_splits = prepare_splits(data_list, args)

Number of training graphs: 194514
Number of validation graphs: 64838
Number of test graphs: 64839
Example of a graph data object: Data(x=[24, 79], edge_index=[2, 52], edge_attr=[52, 10], y=[1, 271], fp=[2215], grover_fp=[5000])


### Sweeps

In [90]:
sweep_config = {
    'method': 'bayes',
    'metric': {'goal': 'maximize', 'name': 'AUC test'},
    }
parameters_dict = {
    'batch_size': {
        'values': [128, 256, 512, 1014]
        },
    'dropout': {
          'values': [0.3, 0.5]
        },
    }

parameters_dict.update({
    'num_data_points': {
        'value': args['num_data_points']},
    'num_epochs': {
        'value': args['num_epochs']},
    'num_layers': {
        'value': args['num_layers']},
    'hidden_channels': {
        'value': args['hidden_channels']},
    'lr': {
        'value': args['lr']}
    })
sweep_config['parameters'] = parameters_dict

In [91]:
sweep_config

{'method': 'bayes',
 'metric': {'goal': 'maximize', 'name': 'AUC test'},
 'parameters': {'batch_size': {'values': [128, 256, 512, 1014]},
  'dropout': {'values': [0.3, 0.5]},
  'num_data_points': {'value': 324191},
  'num_epochs': {'value': 100},
  'num_layers': {'value': 5},
  'hidden_channels': {'value': 128},
  'lr': {'value': 0.01}}}

In [92]:
sweep_id = wandb.sweep(sweep_config, project="GDL_molecular_activity_prediction")

Create sweep with ID: kk90ey6b
Sweep URL: https://wandb.ai/davidkubanek/GDL_molecular_activity_prediction/sweeps/kk90ey6b


In [93]:
def run_sweep(data_splits, args):
    # Create a custom run name dynamically
    run_name = f"{args['model']}"

    with wandb.init(config=args):
        # If called by wandb.agent, as below,
        # this config will be set by Sweep Controller

        # with wandb.init(config=wandb.config) as run:
        # If called by wandb.agent, as below,
        # this config will be set by Sweep Controller
        # config = wandb.config

        args['batch_size'] = wandb.config.batch_size
        args['dropout'] = wandb.config.dropout

        # create dataset from data_list
        dataloader = prepare_dataloader(data_splits, args)

        # train model
        exp = TrainManager(dataloader, args)
        exp.train(epochs=10, log=True, wb_log=True)


In [94]:
dataloader = prepare_dataloader(data_splits, args)

In [95]:
dataloader['train'].batch_size

1014

In [None]:
args['model'] = 'GROVER_FP'
# run the sweep
wandb.agent(sweep_id, run_sweep(data_splits, args), count=4)

Model is on device: cuda:0
Total number of parameters: 991105


Epoch [0/10]: 760it [00:19, 39.48it/s]                         


Epoch: 0, Loss: 0.1947, Train AUC: 0.5000, Test AUC: 0.5000
                        Train F1: 0.4860, Test F1: 0.4852


Epoch [1/10]: 760it [00:18, 41.97it/s]                         
Epoch [2/10]: 760it [00:18, 41.97it/s]                         
Epoch [3/10]: 760it [00:18, 41.78it/s]                         
Epoch [4/10]:   2%|▏         | 13/759 [00:00<00:18, 39.76it/s]

### Single run

In [None]:
args['assay_list'] = ['1979']
args['num_assays'] = 1
args['assays_idx'] = find_assay_indeces(args['assay_list'], assay_order)

args['model'] = 'GROVER_FP'
args['dropout'] = 0.2
args['batch_size'] = 256
args['hidden_channels'] = 256
args['lr'] = 0.01
# Create a custom run name dynamically
run_name = f"{args['model']}_b{args['batch_size']}_d{args['dropout']}_hdim{args['hidden_channels']}_ass{args['assay_list'][0]}_noout"
run = wandb.init(
    name=run_name,
    # Set the project where this run will be logged
    project="GDL_molecular_activity_prediction",
    # Track hyperparameters and run metadata
    config={
        'num_data_points': args['num_data_points'],
        'assays': 'cell_based_high_hr',
        'num_assays': args['num_assays'],

        'model': args['model'],
        'num_layers': args['num_layers'],
        'hidden_channels': args['hidden_channels'],
        'dropout': args['dropout'],
        'batch_size': args['batch_size'],
        'num_epochs': args['num_epochs'],
        'lr': args['lr'],
    })

0,1
AUC test,▁▁▁▂▁▁▂▁▁▁▂▁▃▂▂▁▁▁▁▂▂▁▂▁▂▁▁▂▃▁▂█▂▃▁▂▂▅▁▁
AUC train,▁▁▁▁▁▂▂▂▁▂▂▁▂▂▁▁▁▁▁▁▂▁▁▁▁▁▂▁▃▁▂▇▂▃▂▃▂█▁▁
F1 test,▁▁▁▁▂▂▂▂▁▂▂▁▃▂▂▁▁▁▁▂▂▁▂▁▂▁▁▂▄▁▃█▂▃▁▂▂▆▁▁
F1 train,▁▁▁▁▁▂▂▂▂▂▂▁▂▂▂▁▁▁▁▂▂▁▁▁▂▁▂▁▃▁▂▇▂▃▂▃▂█▁▁
Precision test,▁▁▂▅▂▂▃▂▂▂▃▁▅▃▃█▁▁▁▂▂▃▃▃▃▃▂▅▃▃▃▅▂▃▂▄▃▃▂▂
Precision train,▁▁▃▄▃▃▃▃▂▃▄▃▄▃▃▁▃▄█▃▃▅▃▃▄▄▃▅▅▄▃▇▃▄▃▅▄▇▂▂
Recall test,▁▁▁▂▁▁▂▁▁▁▂▁▃▂▂▁▁▁▁▂▂▁▂▁▂▁▁▂▃▁▂█▂▃▁▂▂▅▁▁
Recall train,▁▁▁▁▁▂▂▂▁▂▂▁▂▂▁▁▁▁▁▁▂▁▁▁▁▁▂▁▃▁▂▇▂▃▂▃▂█▁▁
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
eval time,▆▆▃▆▃▆▃▁█▁▆▁█▁▁▆▁▆▁▆▆▃▇▁▆▃▆█▁▆▁▆▁▁▆▁▆▁▆▁

0,1
AUC test,0.49973
AUC train,0.50079
F1 test,0.48799
F1 train,0.49006
Precision test,0.49362
Precision train,0.51639
Recall test,0.49973
Recall train,0.50079
epoch,56.0
eval time,5.67551


In [None]:
# create dataset from data_list
dataloader = prepare_dataloader(data_list, args)

Number of training graphs: 194514
Number of validation graphs: 64838
Number of test graphs: 64839
Example of a graph data object: Data(x=[24, 79], edge_index=[2, 52], edge_attr=[52, 10], y=[1, 271], fp=[2215], grover_fp=[5000])


In [None]:
# train model
exp = TrainManager(dataloader, args)
exp.train(epochs=100, log=True, wb_log=True)

Model is on device: cuda:0


Epoch [0/100]: 760it [00:17, 43.22it/s]                         


Epoch: 0, Loss: 0.1432, Train AUC: 0.5000, Test AUC: 0.5000
                        Train F1: 0.4869, Test F1: 0.4870


Epoch [1/100]: 760it [00:17, 42.59it/s]                         
Epoch [2/100]: 760it [00:18, 41.87it/s]                         
Epoch [3/100]: 760it [00:17, 42.70it/s]                         
Epoch [4/100]: 760it [00:17, 42.51it/s]                         
Epoch [5/100]: 760it [00:17, 42.41it/s]                         
Epoch [6/100]: 760it [00:17, 42.23it/s]                         
Epoch [7/100]: 760it [00:18, 41.69it/s]                         
Epoch [8/100]: 760it [00:18, 41.62it/s]                         
Epoch [9/100]: 760it [00:18, 41.34it/s]                         
Epoch [10/100]: 760it [00:18, 41.25it/s]                         


Epoch: 10, Loss: 0.1985, Train AUC: 0.5008, Test AUC: 0.4998
                        Train F1: 0.4902, Test F1: 0.4887


Epoch [11/100]: 760it [00:17, 43.05it/s]                         
Epoch [12/100]: 760it [00:17, 43.46it/s]                         
Epoch [13/100]: 760it [00:17, 42.82it/s]                         
Epoch [14/100]: 760it [00:17, 42.87it/s]                         
Epoch [15/100]: 760it [00:17, 42.67it/s]                         
Epoch [16/100]: 760it [00:17, 42.31it/s]                         
Epoch [17/100]: 760it [00:17, 42.64it/s]                         
Epoch [18/100]: 760it [00:17, 42.75it/s]                         
Epoch [19/100]: 760it [00:17, 42.88it/s]                         
Epoch [20/100]: 760it [00:17, 42.81it/s]                         


Epoch: 20, Loss: 0.2938, Train AUC: 0.5005, Test AUC: 0.4998
                        Train F1: 0.4900, Test F1: 0.4886


Epoch [21/100]: 760it [00:18, 42.05it/s]                         
Epoch [22/100]: 760it [00:17, 42.36it/s]                         
Epoch [23/100]: 760it [00:18, 41.18it/s]                         
Epoch [24/100]: 760it [00:18, 40.33it/s]                         
Epoch [25/100]: 760it [00:18, 40.32it/s]
Epoch [26/100]: 760it [00:18, 40.52it/s]                         
Epoch [27/100]: 760it [00:18, 40.26it/s]                         
Epoch [28/100]: 760it [00:18, 40.45it/s]                         
Epoch [29/100]: 760it [00:18, 40.52it/s]                         
Epoch [30/100]: 760it [00:18, 40.37it/s]                         


Epoch: 30, Loss: 0.1760, Train AUC: 0.5002, Test AUC: 0.5000
                        Train F1: 0.4885, Test F1: 0.4884


Epoch [31/100]: 760it [00:18, 40.28it/s]                         
Epoch [32/100]: 760it [00:18, 40.61it/s]                         
Epoch [33/100]: 760it [00:18, 40.76it/s]                         
Epoch [34/100]: 760it [00:18, 40.37it/s]
Epoch [35/100]: 760it [00:18, 40.52it/s]                         
Epoch [36/100]: 760it [00:18, 40.64it/s]                         
Epoch [37/100]: 760it [00:18, 40.26it/s]                         
Epoch [38/100]: 760it [00:19, 38.82it/s]                         
Epoch [39/100]: 760it [00:18, 40.16it/s]                         
Epoch [40/100]: 760it [00:18, 40.27it/s]                         


Epoch: 40, Loss: 0.2425, Train AUC: 0.4999, Test AUC: 0.4999
                        Train F1: 0.4869, Test F1: 0.4870


Epoch [41/100]: 760it [00:18, 40.00it/s]                         
Epoch [42/100]: 760it [00:19, 39.77it/s]                         
Epoch [43/100]: 760it [00:18, 40.48it/s]
Epoch [44/100]: 760it [00:18, 40.37it/s]                         
Epoch [45/100]: 760it [00:18, 41.47it/s]
Epoch [46/100]: 760it [00:18, 41.60it/s]                         
Epoch [47/100]: 760it [00:18, 40.58it/s]
Epoch [48/100]: 760it [00:19, 39.78it/s]                         
Epoch [49/100]: 760it [00:19, 39.82it/s]                         
Epoch [50/100]: 760it [00:19, 39.99it/s]                         


Epoch: 50, Loss: 0.1454, Train AUC: 0.5000, Test AUC: 0.4994
                        Train F1: 0.4880, Test F1: 0.4872


Epoch [51/100]: 760it [00:19, 39.75it/s]                         
Epoch [52/100]: 760it [00:19, 39.61it/s]
Epoch [53/100]: 760it [00:19, 39.63it/s]                         
Epoch [54/100]: 760it [00:19, 39.57it/s]                         
Epoch [55/100]: 760it [00:19, 39.67it/s]                         
Epoch [56/100]: 760it [00:19, 40.00it/s]                         
Epoch [57/100]: 760it [00:18, 40.57it/s]
Epoch [58/100]: 760it [00:18, 41.03it/s]                         
Epoch [59/100]: 760it [00:17, 42.51it/s]                         
Epoch [60/100]: 760it [00:18, 41.62it/s]


Epoch: 60, Loss: 0.1880, Train AUC: 0.5009, Test AUC: 0.5009
                        Train F1: 0.4909, Test F1: 0.4909


Epoch [61/100]: 760it [00:20, 37.85it/s]
Epoch [62/100]: 760it [00:19, 39.37it/s]                         
Epoch [63/100]: 760it [00:19, 39.21it/s]                         
Epoch [64/100]: 760it [00:19, 39.67it/s]                         
Epoch [65/100]: 760it [00:19, 39.60it/s]                         
Epoch [66/100]: 760it [00:18, 40.07it/s]                         
Epoch [67/100]: 760it [00:18, 40.43it/s]                         
Epoch [68/100]: 760it [00:19, 39.48it/s]                         
Epoch [69/100]: 760it [00:18, 40.08it/s]                         
Epoch [70/100]: 760it [00:19, 39.99it/s]                         


Epoch: 70, Loss: 0.1858, Train AUC: 0.4999, Test AUC: 0.5000
                        Train F1: 0.4870, Test F1: 0.4873


Epoch [71/100]: 760it [00:18, 40.14it/s]
Epoch [72/100]: 760it [00:18, 41.59it/s]                         
Epoch [73/100]: 760it [00:18, 41.66it/s]                         
Epoch [74/100]: 760it [00:19, 39.49it/s]                         
Epoch [75/100]: 760it [00:18, 41.65it/s]
Epoch [76/100]: 760it [00:18, 41.45it/s]                         
Epoch [77/100]: 760it [00:18, 40.61it/s]                         
Epoch [78/100]: 760it [00:18, 41.29it/s]                         
Epoch [79/100]: 760it [00:18, 41.24it/s]                         
Epoch [80/100]: 760it [00:18, 41.56it/s]                         


Epoch: 80, Loss: 0.1888, Train AUC: 0.4994, Test AUC: 0.4997
                        Train F1: 0.4873, Test F1: 0.4882


Epoch [81/100]: 760it [00:18, 41.66it/s]                         
Epoch [82/100]: 760it [00:18, 41.39it/s]                         
Epoch [83/100]: 760it [00:18, 41.47it/s]                         
Epoch [84/100]: 760it [00:18, 41.10it/s]                         
Epoch [85/100]: 760it [00:18, 41.65it/s]                         
Epoch [86/100]: 760it [00:18, 41.27it/s]                         
Epoch [87/100]: 760it [00:18, 41.87it/s]                         
Epoch [88/100]: 760it [00:18, 41.70it/s]                         
Epoch [89/100]: 760it [00:18, 41.72it/s]                         
Epoch [90/100]: 760it [00:18, 40.70it/s]                         


Epoch: 90, Loss: 0.1568, Train AUC: 0.5001, Test AUC: 0.5001
                        Train F1: 0.4873, Test F1: 0.4873


Epoch [91/100]: 760it [00:18, 41.46it/s]
Epoch [92/100]: 760it [00:18, 41.86it/s]                         
Epoch [93/100]: 760it [00:18, 41.70it/s]                         
Epoch [94/100]: 760it [00:18, 41.69it/s]                         
Epoch [95/100]: 760it [00:19, 39.57it/s]
Epoch [96/100]: 760it [00:18, 40.49it/s]
Epoch [97/100]: 760it [00:17, 42.81it/s]                         
Epoch [98/100]: 760it [00:17, 42.66it/s]                         
Epoch [99/100]: 760it [00:17, 42.46it/s]                         


In [None]:
exp.analyze()

In [None]:
wandb.finish()