# CONCERTO architecture

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch as th
plt.rcParams["font.family"] = "Palatino"

import dgl
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam, lr_scheduler
from torch.utils.data import Dataset, DataLoader
import wandb
import warnings

from rdkit import Chem
from rdkit.Chem import MACCSkeys
from rdkit.Chem import AllChem
from rdkit.Chem import DataStructs

In [None]:
def run_a_train_epoch(args, epoch, model, data_loader, mut_loss_criterion, carc_loss_criterion, optimizer):
    '''
    Train ``model`` for one full pass over ``data_loader`` and summarise the epoch.

    For every batch the model output's column 0 is used as the mutagenic logit and column 1
    (or columns 1-5 for the CrossEntropy setup) as the carcinogenic logit(s). Samples without a
    label are removed with the masks stored in the batch, and the two losses are combined as
    ``carc_loss * (1 - args['mut_loss_ratio']) + mut_loss * args['mut_loss_ratio']``.

    Parameters
    ----------
    args : dict
        Run configuration. Reads 'device', 'use_carc_loss', 'use_mut_loss', 'mut_loss_ratio',
        'gradient_clip_norm', 'print_every' and 'num_epochs'.
    epoch : int
        Zero-based epoch index (used only for progress printing).
    model : torch.nn.Module
    data_loader : DataLoader
        Its dataset must be a ``dgl.data.utils.Subset`` (whose ``.dataset`` has
        ``use_carc_prob``) or exactly a ``GraphCancerMolecules`` / ``SelfiesCancerMolecules``.
    mut_loss_criterion, carc_loss_criterion : torch loss modules
        ``carc_loss_criterion`` must be a BCEWithLogitsLoss, MSELoss or CrossEntropyLoss.
    optimizer : torch.optim.Optimizer

    Returns
    -------
    tuple
        (train_loss, train_carc_loss, train_mut_loss,
         mut metric, mut metric name, mut metric2, mut metric2 name,
         carc metric, carc metric name, carc metric2, carc metric2 name).
        Metrics and names are ``np.nan`` for a disabled loss.

    Raises
    ------
    ValueError
        If the dataset type or the carcinogenicity loss type is not supported.
    '''
    model.train()
    train_meter_carc = Meter()
    train_meter_mut = Meter()
    losses = []
    mut_losses = []
    carc_losses = []

    # Whether carcinogenicity probabilities (instead of hard labels) are the BCE targets.
    dataset = data_loader.dataset
    if type(dataset) == dgl.data.utils.Subset:
        use_carc_prob = dataset.dataset.use_carc_prob
    elif type(dataset) in (GraphCancerMolecules, SelfiesCancerMolecules):
        use_carc_prob = dataset.use_carc_prob
    else:
        raise ValueError(f"Unsupported dataset type: {type(dataset).__name__}")

    for batch_id, batch_data in enumerate(data_loader):
        batch_data = to_device(batch_data, args["device"])
        # One forward pass yields both the mutagenic and the carcinogenic logits.
        logits = model(batch_data)

        # Select carcinogenic logits/labels of samples that have a label. The mask and label
        # field depend on the loss function; `carc_logging` is what the metric meter receives.
        if isinstance(carc_loss_criterion, nn.BCEWithLogitsLoss):
            carc_logits = torch.masked_select(logits[:, 1], batch_data['carc_mask'])
            label_key = "carc_prob" if use_carc_prob else "carc_label"
            carc_labels = torch.masked_select(batch_data[label_key], batch_data['carc_mask'])
            carc_logging = torch.masked_select(batch_data["carc_label"], batch_data['carc_mask'])

        elif isinstance(carc_loss_criterion, nn.MSELoss):
            carc_logits = torch.masked_select(logits[:, 1], batch_data['carc_mask_continuous'])
            carc_labels = torch.masked_select(batch_data["carc_continuous"], batch_data['carc_mask_continuous'])
            carc_logging = carc_labels

        elif isinstance(carc_loss_criterion, nn.CrossEntropyLoss):
            # Broadcast the per-sample mask over the 5 class logits in columns 1-5.
            # NOTE(review): logits are masked with 'carc_mask_continuous' but labels with
            # 'carc_mask_multi'; rows only line up if both masks agree -- confirm.
            new_mask = batch_data['carc_mask_continuous'].view(-1, 1).expand(logits[:, 1:6].shape)
            carc_logits = torch.masked_select(logits[:, 1:6], new_mask).view(-1, 5)
            carc_labels = torch.masked_select(batch_data["carc_label_multi"], batch_data['carc_mask_multi'])
            carc_logging = carc_labels

        else:
            raise ValueError(f"Unsupported carcinogenicity loss: {type(carc_loss_criterion).__name__}")

        # Mutagenic logits/labels of samples that have a mutagenicity label.
        mut_logits = torch.masked_select(logits[:, 0], batch_data['mut_mask'])
        mut_labels = torch.masked_select(batch_data["mut_label"], batch_data['mut_mask'])

        # A disabled loss, or a batch with no labels for it, contributes a constant 0.
        if args['use_carc_loss'] and len(carc_logits) > 0:
            carc_loss = carc_loss_criterion(carc_logits, carc_labels).mean()
            train_meter_carc.update(carc_logits.view(-1, 1), carc_logging.view(-1, 1))
        else:
            carc_loss = torch.tensor(0.0)

        if args['use_mut_loss'] and len(mut_logits) > 0:
            mut_loss = mut_loss_criterion(mut_logits, mut_labels).mean()
            train_meter_mut.update(mut_logits.view(-1, 1), mut_labels.view(-1, 1))
        else:
            mut_loss = torch.tensor(0.0)

        # Weighted average of the mutagenic and carcinogenic losses.
        loss = carc_loss * (1 - args['mut_loss_ratio']) + mut_loss * args['mut_loss_ratio']

        optimizer.zero_grad()
        # When both terms are constant zeros there is no autograd graph and backward() would
        # raise, so the parameter update is skipped for such batches.
        if loss.requires_grad:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args["gradient_clip_norm"])
            optimizer.step()

        losses.append(loss.item())
        carc_losses.append(carc_loss.item())
        mut_losses.append(mut_loss.item())

        # Parenthesised on purpose: `batch_id + 1 % n` parses as `batch_id + (1 % n)`.
        if (batch_id + 1) % args['print_every'] == 0:
            print('epoch {:d}/{:d}, batch {:d}/{:d}'.format(
                epoch + 1, args['num_epochs'], batch_id + 1, len(data_loader)))

    # Epoch-level metrics and mean losses, computed once after all batches have been seen.
    if args['use_carc_loss']:
        train_carc_metric, train_carc_metric_name, train_carc_metric2, train_carc_metric_name2 = perform_eval(
            carc_loss_criterion, train_meter_carc)
    else:
        train_carc_metric = np.nan
        train_carc_metric_name = np.nan
        train_carc_metric2 = np.nan
        train_carc_metric_name2 = np.nan

    if args['use_mut_loss']:
        train_mut_metric, train_mut_metric_name, train_mut_metric2, train_mut_metric_name2 = perform_eval(
            mut_loss_criterion, train_meter_mut)
    else:
        train_mut_metric = np.nan
        train_mut_metric_name = np.nan
        train_mut_metric2 = np.nan
        train_mut_metric_name2 = np.nan

    train_loss = np.nanmean(losses)
    train_mut_loss = np.nanmean(mut_losses)
    train_carc_loss = np.nanmean(carc_losses)

    return train_loss, train_carc_loss, train_mut_loss, \
        train_mut_metric, train_mut_metric_name, train_mut_metric2, train_mut_metric_name2, \
        train_carc_metric, train_carc_metric_name, train_carc_metric2, train_carc_metric_name2


In [None]:
# Direction in which each supported early-stopping metric improves. It configures the LR
# scheduler so that it reacts to a plateau of the metric rather than to its improvement.
_EARLY_STOPPING_MODES = {
    'roc_auc_score': 'max',
    'pearson_r2': 'max',
    'rmse': 'min',
    'validation_loss': 'min',
}


def _select_val_score(early_stopping_metric, val_carc_metric, val_carc_metric_name, val_carc_metric2,
                      val_carc_metric_name2, val_mut_metric, val_mut_metric_name, val_loss):
    '''
    Pick the validation value named by ``early_stopping_metric`` = (metric, 'carc' | 'mut').

    Raises ValueError if the combination is unsupported, or if the evaluation reported a
    differently named metric in the slot where the requested one is expected.
    '''
    metric, target = early_stopping_metric[0], early_stopping_metric[1]
    if target == 'carc':
        if metric in ('roc_auc_score', 'pearson_r2'):
            score, reported_name = val_carc_metric, val_carc_metric_name
        elif metric == 'rmse':
            score, reported_name = val_carc_metric2, val_carc_metric_name2
        elif metric == 'validation_loss':
            return val_loss
        else:
            raise ValueError(f"Unsupported early stopping metric {metric!r} for 'carc'")
    elif target == 'mut':
        if metric != 'roc_auc_score':
            raise ValueError(f"Unsupported early stopping metric {metric!r} for 'mut'")
        score, reported_name = val_mut_metric, val_mut_metric_name
    else:
        raise ValueError(f"Unsupported early stopping target {target!r}")

    if reported_name != metric:
        raise ValueError(f"Expected validation {target} metric {metric!r}, got {reported_name!r}")
    return score


def training_loop(model, args, mut_loss_criterion, carc_loss_criterion, train_loader, val_loader, note=''):
    '''
    Train ``model`` with early stopping and return it with the stopper's checkpoint loaded.

    Each epoch runs ``run_a_train_epoch`` on ``train_loader`` and ``run_an_eval_epoch`` on
    ``val_loader``. The validation value selected by ``args['early_stopping_metric']``
    (a (metric, 'carc' | 'mut') pair) drives the early stopper and a ReduceLROnPlateau scheduler.
    Progress is printed and, when ``args['use_wandb']`` is set, logged to wandb with ``note``
    appended to every key.

    Raises ValueError for an unsupported early-stopping metric.
    '''
    metric = args['early_stopping_metric'][0]
    if metric not in _EARLY_STOPPING_MODES:
        raise ValueError(f"Unsupported early stopping metric {metric!r}")

    stopper = construct_stopper(args)

    optimizer = Adam(model.parameters(), lr=args['lr'], weight_decay=args['network_weight_decay'])
    # The scheduler must follow the metric's direction: in a fixed 'min' mode a rising ROC-AUC
    # counts as "no improvement" and the learning rate would decay while the model improves.
    scheduler = lr_scheduler.ReduceLROnPlateau(
        optimizer, _EARLY_STOPPING_MODES[metric], factor=args['lr_decay_factor'])

    for epoch in range(args['num_epochs']):
        # run a training epoch
        train_loss, train_carc_loss, train_mut_loss,\
        train_mut_metric, train_mut_metric_name, train_mut_metric2, train_mut_metric_name2, \
        train_carc_metric, train_carc_metric_name, train_carc_metric2, train_carc_metric_name2 = run_a_train_epoch(
            args, epoch, model, train_loader, mut_loss_criterion, carc_loss_criterion, optimizer)

        # Validation and early stop
        val_carc_metric, val_carc_metric_name, val_carc_metric2, val_carc_metric_name2,\
        val_mut_metric, val_mut_metric_name, val_mut_metric2, val_mut_metric_name2, \
        val_loss, val_carc_loss, val_mut_loss, \
        performance_df = run_an_eval_epoch(
            args, model, val_loader, mut_loss_criterion, carc_loss_criterion
        )

        val_score = _select_val_score(
            args['early_stopping_metric'], val_carc_metric, val_carc_metric_name, val_carc_metric2,
            val_carc_metric_name2, val_mut_metric, val_mut_metric_name, val_loss)

        scheduler.step(val_score)
        early_stop = stopper.step(val_score, model)

        print(f"Training: epoch   {epoch + 1:d}/{args['num_epochs']:d}, "
              f"training loss     {train_loss:.3f}, "
              f"mut_{train_mut_metric_name} {train_mut_metric:.3f} "
              f"mut_{train_mut_metric_name2} {train_mut_metric2:.3f} "
              f"carc_{train_carc_metric_name} {train_carc_metric:.3f} "
              f"carc_{train_carc_metric_name2} {train_carc_metric2:.3f} "
              )
        print(f"Validation: epoch {epoch + 1:d}/{args['num_epochs']:d}, "
              f"validation loss   {val_loss:.3f}, "
              f"mut_{val_mut_metric_name} {val_mut_metric:.3f} "
              f"mut_{val_mut_metric_name2} {val_mut_metric2:.3f} "
              f"carc_{val_carc_metric_name} {val_carc_metric:.3f} "
              f"carc_{val_carc_metric_name2} {val_carc_metric2:.3f} \n"
              )

        if args["use_wandb"]:
            wandb.log({
                f"epoch{note}": epoch + 1,
                f"training_carcinogenic_loss{note}": train_carc_loss,
                f"training_mutagenic_loss{note}": train_mut_loss,
                f"training_loss{note}": train_loss,
                f"training_mut_{train_mut_metric_name}{note}": train_mut_metric,
                f"training_mut_{train_mut_metric_name2}{note}": train_mut_metric2,
                f"training_carc_{train_carc_metric_name}{note}": train_carc_metric,
                f"training_carc_{train_carc_metric_name2}{note}": train_carc_metric2,

                f"validation_loss{note}": val_loss,
                f"validation_carc_loss{note}": val_carc_loss,
                f"validation_mut_loss{note}": val_mut_loss,
                f"validation_carc_{val_carc_metric_name}{note}": val_carc_metric,
                f"validation_carc_{val_carc_metric_name2}{note}": val_carc_metric2,
                f"validation_mut_{val_mut_metric_name}{note}": val_mut_metric,
                f"validation_mut_{val_mut_metric_name2}{note}": val_mut_metric2,
            })

        if early_stop:
            break
    # Restore the weights saved by the early stopper.
    stopper.load_checkpoint(model)
    return model



In [None]:
def train_val_test_mode(args, mut_loss_criterion, carc_loss_criterion, held_out_test_carc_loss_criterion):
    '''
    Build data and model, train, and evaluate on the validation, test and held-out test sets.

    If ``args['mut_pre_training']`` is false, a single ``training_loop`` run is performed.
    Otherwise ``mutagenicity_pre_training`` is repeated ``args['num_mut_pre_training_loop']``
    times. After every cycle except the last, an intermediate evaluation runs with
    ``note=f"_{j}"`` and ``save_data=False``.

    Returns the summary dict of the final ``end_of_training_evaluation``. When
    ``args['use_wandb']`` is set, that dict is also logged to wandb.
    '''
    train_loader, val_loader, test_loader, held_out_test_data_loader, data_feats = load_data(args)

    model = get_model(args, data_feats)
    model.to(args['device'])

    # Called for its side effects; the returned run object is not needed here.
    set_up_wandb(model, args, name=args['run'], group=args['group_name'])
    print(data_feats)
    print(model)

    if not args['mut_pre_training']:
        model = training_loop(model, args, mut_loss_criterion, carc_loss_criterion, train_loader, val_loader)

    else:
        for j in range(args['num_mut_pre_training_loop']):
            # perform mutagenicity pre-training cycle loop
            model = mutagenicity_pre_training(
                args, train_loader, val_loader, model, mut_loss_criterion, carc_loss_criterion)

            # Intermediate evaluation between cycles; the final one happens below.
            if j < args['num_mut_pre_training_loop'] - 1:
                end_of_training_evaluation(
                    model, args, mut_loss_criterion, carc_loss_criterion, held_out_test_carc_loss_criterion,
                    val_loader, test_loader, held_out_test_data_loader, note=f"_{j}", save_data=False
                )

    summary_dict = end_of_training_evaluation(
        model, args, mut_loss_criterion, carc_loss_criterion, held_out_test_carc_loss_criterion,
        val_loader, test_loader, held_out_test_data_loader
    )

    # Only log when wandb is enabled, matching the guard used in training_loop.
    if args["use_wandb"]:
        wandb.log(summary_dict)
    return summary_dict

In [None]:
def load_data(args):
    '''
    Create the data loaders used for training and evaluation.

    The main dataset is split into train/validation/test parts. The separate held-out test
    dataset gets its own loader, and all four loaders share the main dataset's collate function.

    Returns
    -------
    tuple
        (train_loader, val_loader, test_loader, held_out_test_data_loader, data_feats)
    '''
    full_data, held_out_data = get_datasets(args)
    train_set, val_set, test_set = split_data(full_data)

    # Feature description and batching function both come from the main dataset.
    data_feats = full_data.get_data_feats()
    collate = full_data.get_collate_fn()

    loaders = [
        construct_data_loader(subset, args, collate)
        for subset in (train_set, val_set, test_set, held_out_data)
    ]
    return (*loaders, data_feats)

In [None]:
# Early-stopping (metric, target) for the carcinogenicity phase, keyed by args['train_carc_loss_fnc'].
_CARC_PHASE_EARLY_STOPPING = {
    'MSE': ('rmse', 'carc'),
    'BCE': ('roc_auc_score', 'carc'),
    'CE': ('validation_loss', 'carc'),
}


def mutagenicity_pre_training(args, train_loader, val_loader, model, mut_loss_criterion, carc_loss_criterion, note=''):
    '''
    Run one mutagenicity pre-training phase followed by one carcinogenicity training phase.

    Phase 1 trains on the mutagenicity part of ``train_loader`` with only the mutagenicity loss,
    early-stopping on ('roc_auc_score', 'mut'). Phase 2 trains on the carcinogenicity part with
    only the carcinogenicity loss. It early-stops on the metric implied by
    ``args['train_carc_loss_fnc']``: 'MSE' -> rmse, 'BCE' -> roc_auc_score,
    'CE' -> validation loss. Both phases use ``training_loop`` and validate on ``val_loader``.

    Side effects: ``args`` is mutated in place. On return, ``use_carc_loss`` and
    ``use_mut_loss`` are both True and ``early_stopping_metric`` holds the
    carcinogenicity-phase metric.

    Raises
    ------
    ValueError
        If ``args['train_carc_loss_fnc']`` is not 'MSE', 'BCE' or 'CE'. This is checked before
        any training, so a bad config does not waste a full mutagenicity run.
    '''
    carc_loss_fnc = args['train_carc_loss_fnc']
    if carc_loss_fnc not in _CARC_PHASE_EARLY_STOPPING:
        raise ValueError(f"Unsupported train_carc_loss_fnc {carc_loss_fnc!r}; expected one of "
                         f"{sorted(_CARC_PHASE_EARLY_STOPPING)}")

    train_mut_loader, train_carc_loader = split_loader_into_carc_and_mut(train_loader, args)

    # Phase 1: mutagenicity loss only, trained on mutagenicity data.
    args['use_carc_loss'] = False
    args['use_mut_loss'] = True
    args['early_stopping_metric'] = ('roc_auc_score', 'mut')
    model = training_loop(
        model, args, mut_loss_criterion, carc_loss_criterion, train_mut_loader, val_loader, note=note)

    # Phase 2: carcinogenicity loss only, trained on carcinogenicity data.
    args['use_carc_loss'] = True
    args['use_mut_loss'] = False
    args['early_stopping_metric'] = _CARC_PHASE_EARLY_STOPPING[carc_loss_fnc]
    model = training_loop(
        model, args, mut_loss_criterion, carc_loss_criterion, train_carc_loader, val_loader, note=note)

    # Re-enable both losses for whatever runs next (the caller may repeat this cycle K times).
    args['use_carc_loss'] = True
    args['use_mut_loss'] = True

    return model

In [2]:
class DatasetClass(Dataset):
    '''
    Merged mutagenicity (Ames, Hansen 2009 file) and carcinogenicity (CPDB) table.

    Construction only stores options. Call :meth:`load_data` to read both sources from
    'Data/concerto data/', merge them on canonical SMILES and store the table in ``self.df``.

    Parameters
    ----------
    out_feats : int
        Stored as ``self.out_feats``; not read by this class's methods.
    drop_ionic : bool
        Drop SMILES containing '.', i.e. multi-fragment entries such as salts.
    min_carbon_count : int
        If non-zero, keep only molecules with at least this many carbon atoms.
    fraction_of_data : float
        If < 1, keep only this leading fraction of the shuffled table.
    use_carc_prob : bool
        Stored as ``self.use_carc_prob``; not read by this class's methods.
    carc_percentile_to_drop : float
        If > 0, TD50 values at or below this percentile are clipped to the percentile value.
    '''

    def __init__(self, out_feats=1, drop_ionic=True, min_carbon_count=0, fraction_of_data=1, use_carc_prob=False, carc_percentile_to_drop=0):
        self.out_feats = out_feats
        self.drop_ionic = drop_ionic
        self.min_carbon_count = min_carbon_count
        self.fraction_of_data = fraction_of_data
        self.use_carc_prob = use_carc_prob
        self.carc_percentile_to_drop = carc_percentile_to_drop

    @classmethod
    def load_carc_cpdb(cls):
        '''
        Load the aggregated CPDB carcinogenicity table.

        SMILES are canonicalised with :meth:`smiles_standardize` and rows that fail to parse are
        dropped. 'td50_log_harmonic' is renamed to 'td50' and a 'source' column ('cpdb') is
        added. Asserts that the canonical SMILES are unique.
        '''
        cpdb = pd.read_csv(
            'Data/concerto data/cpdb_aggergated.csv',
            usecols=['smiles', 'td50_log_harmonic', 'cas', 'carc_class', 'carc_class_multi']
        )
        cpdb['smiles'] = cls.smiles_standardize(cpdb['smiles'].values)
        # Filter + rename in one chain so later column assignments act on a real copy.
        cpdb = cpdb[cpdb['smiles'].notnull()].rename(columns={'td50_log_harmonic': 'td50'})
        assert cpdb['smiles'].duplicated().sum() == 0, \
            cpdb[cpdb['smiles'].duplicated(keep=False)].sort_values('smiles').to_string()
        return cpdb.assign(source='cpdb')

    @classmethod
    def load_mut_hansen(cls):
        '''
        Load the Ames mutagenicity set from 'hansen_2009_ames.smi' (tab-separated smiles, cas, class).

        SMILES are canonicalised. Unparseable rows and the 'cas' column are dropped, 'class' is
        renamed to 'mut_class', exact (smiles, mut_class) duplicates are removed and a 'source'
        column ('hansen') is added. Asserts that no SMILES is left with conflicting labels.
        '''
        ames_df = pd.read_csv('Data/concerto data/hansen_2009_ames.smi', sep='\t', names=['smiles', 'cas', 'class'])
        ames_df['smiles'] = cls.smiles_standardize(ames_df['smiles'].values)
        ames_df = (
            ames_df[ames_df['smiles'].notnull()]
            .rename(columns={'class': 'mut_class'})
            .drop(columns='cas')
        )
        ames_df = ames_df[~ames_df.duplicated(['smiles', 'mut_class'])]
        assert ames_df['smiles'].duplicated().sum() == 0, \
            ames_df[ames_df['smiles'].duplicated(keep=False)].sort_values('smiles').to_string()
        return ames_df.assign(source='hansen')

    @classmethod
    def smiles_standardize(cls, smiles):
        '''
        Return canonical, non-isomeric SMILES for each input string.

        Missing entries and strings RDKit cannot parse map to ``None``.
        '''
        new_smiles = []
        for smile in smiles:
            mol = None if pd.isnull(smile) else Chem.MolFromSmiles(smile)
            if mol is None:
                new_smiles.append(None)
            else:
                new_smiles.append(Chem.MolToSmiles(mol, canonical=True, isomericSmiles=False, allBondsExplicit=False))
        return new_smiles

    @classmethod
    def count_num_carbons(cls, smile):
        '''Count atoms whose element symbol is 'C' in the molecule parsed from ``smile``.'''
        mol = Chem.MolFromSmiles(smile)
        return sum(1 for atom in mol.GetAtoms() if atom.GetSymbol().upper() == 'C')

    def load_data(self):
        '''
        Load, merge, filter and shuffle both data sources and store the result in ``self.df``.

        Mutagenicity and carcinogenicity rows are outer-joined on SMILES, and 'source' records
        which dataset(s) each row came from (e.g. "hansen,cpdb"). Then, in order:
        - multi-fragment SMILES and low-carbon molecules are optionally dropped;
        - rows are shuffled with a fixed seed;
        - 'beta_standardized' / 'beta_normalized' targets are derived from TD50;
        - the table is optionally truncated to ``fraction_of_data``.

        Note: reseeds NumPy's global RNG (``np.random.seed(1337)``). Returns None.
        '''
        # The two primary training sources; they must not overlap within themselves.
        mut = self.load_mut_hansen()
        carc = self.load_carc_cpdb()

        # Extra sources (e.g. a CCRIS carcinogenicity loader) can be appended here. Rows already
        # present in the primary data should be removed before appending.
        mut_datasets = [mut]
        carc_datasets = [carc]

        mut = pd.concat(mut_datasets) if mut_datasets else pd.DataFrame(columns=mut.columns)
        carc = pd.concat(carc_datasets) if carc_datasets else pd.DataFrame(columns=carc.columns)

        # Resolve SMILES that appear in more than one source.
        num_mut_dups = mut['smiles'].duplicated().sum()
        if num_mut_dups != 0:
            warnings.warn(f"duplicated samples {num_mut_dups} contained in mutagenicity data "
                          f"from {list(mut['source'].unique())}\n"
                          f" {mut[mut['smiles'].duplicated(keep=False)].sort_values('smiles').head().to_string()}")
            # Sort (and keep the result) so that, per SMILES, the row with the highest mut_class
            # comes first and survives the de-duplication below.
            mut = mut.sort_values(['smiles', 'mut_class'], ascending=False)
            mut = mut[~mut.duplicated('smiles')]

        num_carc_dups = carc['smiles'].duplicated().sum()
        if num_carc_dups != 0:
            warnings.warn(f"duplicated samples {num_carc_dups} contained in carcinogenicity data "
                          f"from {list(carc['source'].unique())}\n"
                          f" {carc[carc['smiles'].duplicated(keep=False)].sort_values('smiles').head().to_string()}")
            carc = carc[~carc.duplicated('smiles')]

        # merge the mut and carc datasets
        df = pd.merge(mut, carc, how='outer', on='smiles')
        df['source'] = df['source_x'].fillna('') + ',' + df['source_y'].fillna('')
        df = df.drop(columns=['source_x', 'source_y'])

        # Guarantee the label columns exist even if a source is missing.
        for column in ['td50', 'carc_class', 'mut_class', 'carc_class_multi']:
            if column not in df.columns:
                df[column] = np.nan

        if self.drop_ionic:
            # '.' separates disconnected fragments in SMILES; match it literally, not as a regex.
            df = df[~df['smiles'].str.contains('.', regex=False)]

        df = df.assign(carbon_count=[self.count_num_carbons(x) for x in df['smiles'].values])
        if self.min_carbon_count:
            df = df[df['carbon_count'] >= self.min_carbon_count]

        # Shuffle the data with a fixed seed.
        index = np.arange(len(df))
        np.random.seed(1337)
        np.random.shuffle(index)
        df = df.iloc[index].reset_index(drop=True)

        if self.carc_percentile_to_drop > 0:
            # Clip the lower tail of TD50 at the given percentile.
            lowest_td_50_val = np.nanpercentile(df['td50'].values, self.carc_percentile_to_drop)
            df.loc[df['td50'] <= lowest_td_50_val, 'td50'] = lowest_td_50_val

        # Log of Betas that were fitted using a cox regression model
        betas = np.log(np.log(2) / df["td50"].values)

        # Z-score standardisation and min-max normalisation to [0, 1] (NaNs ignored).
        df['beta_standardized'] = (betas - np.nanmean(betas)) / np.nanstd(betas)
        beta_normalized = betas - np.nanmin(betas)
        df['beta_normalized'] = beta_normalized / np.nanmax(beta_normalized)

        if self.fraction_of_data < 1:
            df = df.iloc[:int(len(df) * self.fraction_of_data)]

        self.df = df
    

In [None]:
def _hashed_fp_to_array(fp):
    '''Convert an RDKit hashed fingerprint into a NumPy array (RDKit resizes the buffer).'''
    arr = np.zeros((0,), dtype=np.int8)
    DataStructs.ConvertToNumpyArray(fp, arr)
    return arr


def convert_smile_to_fp_bit_string(self, smile):
    '''
    Featurise one SMILES string as a concatenated fingerprint tensor.

    Always included, in this order:
    - the RDKit topological fingerprint (``self.fp_nbits`` bits);
    - the MACCS keys;
    - a hashed Morgan fingerprint with radius 2 and ``self.fp_nbits`` bins.

    Hashed atom-pair and topological-torsion fingerprints are appended when
    ``self.atom_pairs_fingerprints`` / ``self.torsion_fingerprints`` are truthy.

    ``self`` only needs those three attributes. The function is presumably meant to be attached
    to a featurising dataset class -- TODO confirm.

    Returns
    -------
    torch.Tensor
        1-D float32 tensor.

    Raises
    ------
    ValueError
        If RDKit cannot parse ``smile``.
    '''
    mol = Chem.MolFromSmiles(smile)
    if mol is None:
        # Without this check RDKit fails later with an opaque argument error.
        raise ValueError(f"RDKit could not parse SMILES: {smile!r}")

    # Bit-vector fingerprints are serialised to '0'/'1' strings and converted to int8.
    rdkit_fp = Chem.RDKFingerprint(mol, fpSize=self.fp_nbits)
    maccs_fp = MACCSkeys.GenMACCSKeys(mol)
    bits = np.array(list(rdkit_fp.ToBitString() + maccs_fp.ToBitString())).astype(np.int8)

    hashed = [_hashed_fp_to_array(AllChem.GetHashedMorganFingerprint(mol, 2, nBits=self.fp_nbits))]
    if self.atom_pairs_fingerprints:
        hashed.append(_hashed_fp_to_array(
            AllChem.GetHashedAtomPairFingerprint(mol, nBits=self.fp_nbits)))
    if self.torsion_fingerprints:
        hashed.append(_hashed_fp_to_array(
            AllChem.GetHashedTopologicalTorsionFingerprint(mol, nBits=self.fp_nbits)))

    return torch.tensor(np.concatenate([bits] + hashed)).to(torch.float32)

In [3]:
# Build the merged mutagenicity + carcinogenicity table. The default options are spelled out
# so they are easy to tweak; the resulting DataFrame is stored on `dataset.df`.
dataset = DatasetClass(
    out_feats=1,
    drop_ionic=True,
    min_carbon_count=0,
    fraction_of_data=1,
    use_carc_prob=False,
    carc_percentile_to_drop=0,
)
dataset.load_data()

[09:11:05] SMILES Parse Error: syntax error while parsing: NNC(=O)CNC(=O)\C=N\#N
[09:11:05] SMILES Parse Error: Failed parsing SMILES 'NNC(=O)CNC(=O)\C=N\#N' for input: 'NNC(=O)CNC(=O)\C=N\#N'
[09:11:05] SMILES Parse Error: syntax error while parsing: O=C1NC(=O)\C(=N/#N)\C=N1
[09:11:05] SMILES Parse Error: Failed parsing SMILES 'O=C1NC(=O)\C(=N/#N)\C=N1' for input: 'O=C1NC(=O)\C(=N/#N)\C=N1'
[09:11:05] SMILES Parse Error: syntax error while parsing: NC(=O)CNC(=O)\C=N\#N
[09:11:05] SMILES Parse Error: Failed parsing SMILES 'NC(=O)CNC(=O)\C=N\#N' for input: 'NC(=O)CNC(=O)\C=N\#N'
[09:11:05] SMILES Parse Error: syntax error while parsing: CCCCN(CC(O)C1=C\C(=N/#N)\C(=O)C=C1)N=O
[09:11:05] SMILES Parse Error: Failed parsing SMILES 'CCCCN(CC(O)C1=C\C(=N/#N)\C(=O)C=C1)N=O' for input: 'CCCCN(CC(O)C1=C\C(=N/#N)\C(=O)C=C1)N=O'
[09:11:05] SMILES Parse Error: syntax error while parsing: NC(COC(=O)\C=N/#N)C(=O)O
[09:11:05] SMILES Parse Error: Failed parsing SMILES 'NC(COC(=O)\C=N/#N)C(=O)O' for inp

In [7]:
# Show a random sample of 10 rows. The fixed random_state makes re-running this cell display
# the same rows.
dataset.df.sample(10, random_state=0)


Unnamed: 0,smiles,mut_class,cas,td50,carc_class_multi,carc_class,source,carbon_count,beta_standardized,beta_normalized
3299,CN1CCc2cc3c(c4c2C1Cc1ccccc1-4)OCO3,1.0,,,,,"hansen,",18,,
2897,CCOP(=O)(Oc1ccc([N+](=O)[O-])cc1)c1ccccc1,1.0,,,,,"hansen,",14,,
977,O=C=Nc1cccc2c(N=C=O)cccc12,0.0,,,,,"hansen,",12,,
3231,Cc1c(N=O)cccc1[N+](=O)[O-],0.0,,,,,"hansen,",7,,
1117,O=C1OC(O)C(C(Cl)Br)=C1Cl,1.0,,,,,"hansen,",5,,
4442,O=[N+]([O-])c1cc(N(CCO)CCO)ccc1NCCO,1.0,33229-34-4,10.932522,1.0,0.0,"hansen,cpdb",12,-0.397473,0.138117
68,NC(CS)C(=O)NCC(=O)O,1.0,,,,,"hansen,",5,,
5493,Cc1ccc(N=Nc2c(O)ccc3ccccc23)c(C)c1,1.0,,,,,"hansen,",18,,
2331,CC1SCC(C(=O)NC(Cc2c[nH]cn2)C(=O)N2CCCC2C(N)=O)...,0.0,,,,,"hansen,",17,,
3599,COc1nc2cccc(CBr)c2nc1OC,1.0,,,,,"hansen,",11,,
