# CONCERTO architecture

In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch as th
# Render all figure text in the Palatino font family.
plt.rcParams.update({"font.family": "Palatino"})

In [None]:
# Supported (metric_name, task) pairs for args['early_stopping_metric'].
# Each maps to the ReduceLROnPlateau mode for that metric:
# - AUC and Pearson r^2 are "higher is better" ('max').
# - RMSE and loss are "lower is better" ('min').
_EARLY_STOPPING_MODES = {
    ('roc_auc_score', 'carc'): 'max',
    ('pearson_r2', 'carc'): 'max',
    ('rmse', 'carc'): 'min',
    ('validation_loss', 'carc'): 'min',
    ('roc_auc_score', 'mut'): 'max',
}


def _early_stopping_key(args):
    '''
    Return ``(metric_name, task)`` from ``args['early_stopping_metric']``.

    Only the first two entries are used. Raises ValueError if the pair is not a
    key of ``_EARLY_STOPPING_MODES``.
    '''
    spec = args['early_stopping_metric']
    key = (spec[0], spec[1])
    if key not in _EARLY_STOPPING_MODES:
        raise ValueError(
            f"Unsupported early_stopping_metric {key!r}; "
            f"expected one of {list(_EARLY_STOPPING_MODES)}")
    return key


def _select_val_score(key, val_loss,
                      val_carc_metric, val_carc_metric_name,
                      val_carc_metric2, val_carc_metric_name2,
                      val_mut_metric, val_mut_metric_name):
    '''
    Pick the validation value monitored by the LR scheduler and early stopper.

    ``key`` must be a validated ``(metric_name, task)`` pair.

    Raises ValueError if the evaluation epoch reported the chosen metric under
    a different name than requested. For example, 'roc_auc_score' may be
    requested while the carc metric came back as 'pearson_r2'.
    '''
    metric_name, task = key
    if metric_name == 'validation_loss':
        # This is the total validation loss, not val_carc_loss.
        return val_loss
    if task == 'mut':
        value, reported_name = val_mut_metric, val_mut_metric_name
    elif metric_name == 'rmse':
        # RMSE is reported as the secondary carc metric.
        value, reported_name = val_carc_metric2, val_carc_metric_name2
    else:
        value, reported_name = val_carc_metric, val_carc_metric_name
    if reported_name != metric_name:
        raise ValueError(
            f"early_stopping_metric {key!r} requested, but the evaluation epoch "
            f"reported the {task} metric as {reported_name!r}")
    return value


def training_loop(model, args, mut_loss_criterion, carc_loss_criterion, train_loader, val_loader, note=''):
    '''
    Train ``model`` on the carc and/or mut data with LR scheduling and early stopping.

    Each epoch does the following:
    - runs one training pass and one evaluation pass;
    - selects the validation score named by ``args['early_stopping_metric']``,
      a ``(metric_name, task)`` pair;
    - steps the ReduceLROnPlateau scheduler and the early stopper on that score;
    - prints progress and, if ``args['use_wandb']`` is set, logs to wandb.

    Args:
        model: the network to train.
        args: config dict. This function reads 'lr', 'network_weight_decay',
            'lr_decay_factor', 'num_epochs', 'early_stopping_metric' and
            'use_wandb'. construct_stopper and the epoch runners may read more.
        mut_loss_criterion, carc_loss_criterion: loss functions passed through
            to run_a_train_epoch / run_an_eval_epoch.
        train_loader, val_loader: data loaders for the training and validation sets.
        note: suffix appended to every wandb metric key.

    Returns:
        ``model``, after ``stopper.load_checkpoint(model)`` has been called
        (presumably restoring the best checkpoint; see construct_stopper).

    Raises:
        ValueError: if ``args['early_stopping_metric']`` is not a supported
            ``(metric_name, task)`` pair (checked before any training), or if
            the evaluation epoch reports the metric under a different name.
    '''
    # Validate the early-stopping config up front so a typo fails before training.
    es_key = _early_stopping_key(args)
    stopper = construct_stopper(args)

    optimizer = Adam(model.parameters(), lr=args['lr'], weight_decay=args['network_weight_decay'])
    # The scheduler mode must follow the metric's direction. With a hard-coded
    # 'min', a rising AUC / r^2 counts as "no improvement", and the LR would be
    # cut while the model is still getting better.
    scheduler = lr_scheduler.ReduceLROnPlateau(
        optimizer, _EARLY_STOPPING_MODES[es_key], factor=args['lr_decay_factor'])

    for epoch in range(args['num_epochs']):
        # run a training epoch
        train_loss, train_carc_loss, train_mut_loss,\
        train_mut_metric, train_mut_metric_name, train_mut_metric2, train_mut_metric_name2, \
        train_carc_metric, train_carc_metric_name, train_carc_metric2, train_carc_metric_name2 = run_a_train_epoch(
            args, epoch, model, train_loader, mut_loss_criterion, carc_loss_criterion, optimizer)

        # Validation and early stop
        val_carc_metric, val_carc_metric_name, val_carc_metric2, val_carc_metric_name2,\
        val_mut_metric, val_mut_metric_name, val_mut_metric2, val_mut_metric_name2, \
        val_loss, val_carc_loss, val_mut_loss, \
        performance_df = run_an_eval_epoch(
            args, model, val_loader, mut_loss_criterion, carc_loss_criterion
        )

        val_score = _select_val_score(
            es_key, val_loss,
            val_carc_metric, val_carc_metric_name,
            val_carc_metric2, val_carc_metric_name2,
            val_mut_metric, val_mut_metric_name)

        scheduler.step(val_score)
        early_stop = stopper.step(val_score, model)

        print(f"Training: epoch   {epoch + 1:d}/{args['num_epochs']:d}, "
              f"training loss     {train_loss:.3f}, "
              f"mut_{train_mut_metric_name} {train_mut_metric:.3f} "
              f"mut_{train_mut_metric_name2} {train_mut_metric2:.3f} "
              f"carc_{train_carc_metric_name} {train_carc_metric:.3f} "
              f"carc_{train_carc_metric_name2} {train_carc_metric2:.3f} "
              )
        print(f"Validation: epoch {epoch + 1:d}/{args['num_epochs']:d}, "
              f"validation loss   {val_loss:.3f}, "
              f"mut_{val_mut_metric_name} {val_mut_metric:.3f} "
              f"mut_{val_mut_metric_name2} {val_mut_metric2:.3f} "
              f"carc_{val_carc_metric_name} {val_carc_metric:.3f} "
              f"carc_{val_carc_metric_name2} {val_carc_metric2:.3f} \n"
              )

        if args["use_wandb"]:
            wandb.log({
                f"epoch{note}": epoch + 1,
                f"training_carcinogenic_loss{note}": train_carc_loss,
                f"training_mutagenic_loss{note}": train_mut_loss,
                f"training_loss{note}": train_loss,
                f"training_mut_{train_mut_metric_name}{note}": train_mut_metric,
                f"training_mut_{train_mut_metric_name2}{note}": train_mut_metric2,
                f"training_carc_{train_carc_metric_name}{note}": train_carc_metric,
                f"training_carc_{train_carc_metric_name2}{note}": train_carc_metric2,

                f"validation_loss{note}": val_loss,
                f"validation_carc_loss{note}": val_carc_loss,
                f"validation_mut_loss{note}": val_mut_loss,
                f"validation_carc_{val_carc_metric_name}{note}": val_carc_metric,
                f"validation_carc_{val_carc_metric_name2}{note}": val_carc_metric2,
                f"validation_mut_{val_mut_metric_name}{note}": val_mut_metric,
                f"validation_mut_{val_mut_metric_name2}{note}": val_mut_metric2,
            })

        if early_stop:
            break
    stopper.load_checkpoint(model)
    return model

