In [None]:
class BaseTrainer:
  """Train, validate and bootstrap-evaluate a model, pickling every artefact.

  The wrapped ``model`` must provide ``train_model(criterion, optimizer,
  dataloader)`` and ``evaluate_model(criterion, dataloader)``, each returning
  ``(labels, preds, loss)``, and must be callable as
  ``model(image_input, tab_input)``. The metric/plot helpers used below
  (``CorrAndRSquared``, ``MSE``, ``MAE``, ``MAPE``, ``PlotLoss``, ``PlotR2``,
  ``PlotPearson``, ``PlotMSE``, ``PlotMAE``, ``PlotMAPE``) and the ``pickle``,
  ``torch`` and ``np`` modules are expected to be defined elsewhere in the
  notebook.
  """

  # Directory that receives every pickled artefact. Kept as a class attribute
  # so subclasses that do not call ``BaseTrainer.__init__`` still resolve it;
  # pass ``pickle_dir`` to the constructor to redirect a single trainer.
  pickle_dir = '/content/drive/MyDrive/BINF_4008_Final_Project/Pickles/'

  def __init__(self, model, criterion, optimizer, train_dataloader, val_dataloader, test_dataloader, male_dataloader, female_dataloader, epochs, model_name, device, pickle_dir=None):
    """Store the model, loss, optimizer, data loaders and run settings.

    Args:
      model: model to train and evaluate (see the class docstring for the
        interface it must expose).
      criterion: loss passed to ``model.train_model`` / ``model.evaluate_model``.
      optimizer: optimizer passed to ``model.train_model``.
      train_dataloader: loader used for the training pass of every epoch.
      val_dataloader: loader used for the validation pass of every epoch.
      test_dataloader: loader bootstrapped under the ``'test'`` split name.
      male_dataloader: loader bootstrapped under the ``'male_test'`` split name.
      female_dataloader: loader bootstrapped under the ``'female_test'`` split
        name.
      epochs: number of epochs run by ``train_and_validate``.
      model_name: prefix of every pickle file name; also passed to the metric
        and plot helpers.
      device: device that evaluation batches are moved to.
      pickle_dir: optional output directory overriding the class default.
    """
    self.model = model
    self.criterion = criterion
    self.optimizer = optimizer
    self.train_dataloader = train_dataloader
    self.val_dataloader = val_dataloader
    self.test_dataloader = test_dataloader
    self.male_dataloader = male_dataloader
    self.female_dataloader = female_dataloader
    self.epochs = epochs
    self.model_name = model_name
    self.device = device
    if pickle_dir is not None:
      self.pickle_dir = pickle_dir

  def _pickle_path(self, suffix):
    """Return the path of ``<model_name><suffix>`` inside ``self.pickle_dir``."""
    import os  # local import: the notebook's import cell is outside this class
    return os.path.join(self.pickle_dir, self.model_name + suffix)

  def _dump_pickle(self, suffix, obj):
    """Pickle ``obj`` to ``self._pickle_path(suffix)``, truncating any existing file."""
    with open(self._pickle_path(suffix), 'wb') as f:
      pickle.dump(obj, f)

  @staticmethod
  def _summarize_bootstrap(values, lower=2.5, upper=97.5):
    """Return ``[lower percentile, mean, upper percentile]`` of ``values``."""
    return [np.percentile(values, lower), np.mean(values), np.percentile(values, upper)]

  def train_and_validate(self):
    """Train for ``self.epochs`` epochs, track metrics, then plot and pickle.

    Each epoch runs one training pass and one validation pass, prints both
    losses, and records R2, Pearson, MSE, MAE and MAPE for each split. After
    the last epoch, the final-epoch metrics are printed (the R2 helper is
    called with ``plot=True``), the per-epoch histories are passed to the plot
    helpers, and results, model, losses, labels and predictions are pickled
    under ``self.pickle_dir``.

    Returns:
      dict with ``train_``/``val_``-prefixed ``r2``, ``mse``, ``mae`` and
      ``mape`` computed on the final epoch's labels and predictions.

    Raises:
      ValueError: if ``self.epochs`` is less than 1 (there would be no final
        epoch to report).
    """
    if self.epochs < 1:
      # Fail clearly here instead of with an UnboundLocalError after the loop.
      raise ValueError(f'epochs must be at least 1 to train, got {self.epochs}')

    train_losses = []
    val_losses = []

    # Per-epoch metric histories, plotted once training finishes.
    train_r2s = []
    train_pearsons = []
    train_mses = []
    train_maes = []
    train_mapes = []

    val_r2s = []
    val_pearsons = []
    val_mses = []
    val_maes = []
    val_mapes = []

    calc_train_r2 = CorrAndRSquared(self.model_name, 'train')
    calc_val_r2 = CorrAndRSquared(self.model_name, 'val')

    calc_train_mae = MAE(self.model_name, 'train')
    calc_val_mae = MAE(self.model_name, 'val')

    calc_train_mape = MAPE(self.model_name, 'train')
    calc_val_mape = MAPE(self.model_name, 'val')

    calc_train_mse = MSE(self.model_name, 'train')
    calc_val_mse = MSE(self.model_name, 'val')

    calc_plot_loss = PlotLoss(self.model_name)
    calc_r2_epochs = PlotR2(self.model_name)
    calc_pearson_epochs = PlotPearson(self.model_name)
    calc_mse_epochs = PlotMSE(self.model_name)
    calc_mae_epochs = PlotMAE(self.model_name)
    calc_mape_epochs = PlotMAPE(self.model_name)

    for epoch in range(self.epochs):

      print(f'Epoch {epoch + 1}')
      print('---------------------')

      train_labels, train_preds, train_loss = self.model.train_model(self.criterion, self.optimizer, self.train_dataloader)
      val_labels, val_preds, val_loss = self.model.evaluate_model(self.criterion, self.val_dataloader)

      train_losses.append(train_loss)
      val_losses.append(val_loss)

      print()
      print(f'Training Loss: {train_loss:.4f}')
      print(f'Validation Loss: {val_loss:.4f}')
      print()

      train_r2, train_pearson = calc_train_r2(train_labels, train_preds, plot=False)
      val_r2, val_pearson = calc_val_r2(val_labels, val_preds, plot=False)

      train_mse = calc_train_mse(train_labels, train_preds)
      # Validation metrics use the 'val' calculators (this line previously
      # used calc_train_mse, unlike every other validation metric).
      val_mse = calc_val_mse(val_labels, val_preds)

      train_mae = calc_train_mae(train_labels, train_preds)
      val_mae = calc_val_mae(val_labels, val_preds)

      train_mape = calc_train_mape(train_labels, train_preds)
      val_mape = calc_val_mape(val_labels, val_preds)

      train_r2s.append(train_r2)
      val_r2s.append(val_r2)

      # The Pearson output is indexed with [0]; presumably a (statistic,
      # p-value) pair -- confirm against CorrAndRSquared.
      train_pearsons.append(train_pearson[0])
      val_pearsons.append(val_pearson[0])

      train_mses.append(train_mse)
      val_mses.append(val_mse)

      train_maes.append(train_mae)
      val_maes.append(val_mae)

      train_mapes.append(train_mape)
      val_mapes.append(val_mape)

    # Final-epoch summary: same inputs as the last loop iteration, but the R2
    # helper is now asked to plot.
    train_r2, train_pearson = calc_train_r2(train_labels, train_preds, plot=True)
    val_r2, val_pearson = calc_val_r2(val_labels, val_preds, plot=True)

    train_mse = calc_train_mse(train_labels, train_preds)
    val_mse = calc_val_mse(val_labels, val_preds)

    train_mae = calc_train_mae(train_labels, train_preds)
    val_mae = calc_val_mae(val_labels, val_preds)

    train_mape = calc_train_mape(train_labels, train_preds)
    val_mape = calc_val_mape(val_labels, val_preds)

    print(f'Train R2: {train_r2:.4f}')
    print(f'Train Pearson: {train_pearson[0]:.4f}')
    print(f'Train MSE: {train_mse:.4f}')
    print(f'Train MAE: {train_mae:.4f}')
    print(f'Train MAPE: {train_mape:.4f}%')
    print()
    print(f'Validation R2: {val_r2:.4f}')
    print(f'Validation Pearson: {val_pearson[0]:.4f}')
    print(f'Validation MSE: {val_mse:.4f}')
    print(f'Validation MAE: {val_mae:.4f}')
    print(f'Validation MAPE: {val_mape:.4f}%')

    calc_plot_loss(train_losses, val_losses)
    calc_r2_epochs(train_r2s, val_r2s)
    calc_pearson_epochs(train_pearsons, val_pearsons)
    calc_mse_epochs(train_mses, val_mses)
    calc_mae_epochs(train_maes, val_maes)
    calc_mape_epochs(train_mapes, val_mapes)

    results = {
        'train_r2': train_r2,
        'train_mse': train_mse,
        'train_mae': train_mae,
        'train_mape': train_mape,
        'val_r2': val_r2,
        'val_mse': val_mse,
        'val_mae': val_mae,
        'val_mape': val_mape
    }

    self._dump_pickle('_training_results.pkl', results)
    # NOTE(review): this pickles the whole model object, which ties the file
    # to the current class definition; torch.save(model.state_dict(), ...) is
    # the usual, more portable alternative.
    self._dump_pickle('_model.pkl', self.model)
    self._dump_pickle('_train_losses.pkl', train_losses)
    self._dump_pickle('_val_losses.pkl', val_losses)
    self._dump_pickle('_train_labels.pkl', train_labels)
    self._dump_pickle('_train_preds.pkl', train_preds)
    self._dump_pickle('_val_labels.pkl', val_labels)
    self._dump_pickle('_val_preds.pkl', val_preds)

    return results

  def bootstrap_test_set(self):
    """Bootstrap-evaluate the full, male and female test loaders.

    Returns:
      dict mapping ``'test'``, ``'male'`` and ``'female'`` to
      ``[summary_results, full_results]`` as returned by
      ``_bootstrap_test_set``.
    """
    test_results, test_full_results = self._bootstrap_test_set(self.test_dataloader, 'test')
    male_test_results, male_full_test_results = self._bootstrap_test_set(self.male_dataloader, 'male_test')
    female_test_results, female_full_test_results = self._bootstrap_test_set(self.female_dataloader, 'female_test')

    return {
        'test': [test_results, test_full_results],
        'male': [male_test_results, male_full_test_results],
        'female': [female_test_results, female_full_test_results]
    }

  def _bootstrap_test_set(self, dataloader, split, n_bootstrap=200, seed=6):
    """Predict on ``dataloader`` once, then bootstrap the regression metrics.

    The model is put in eval mode (and left there) and run without gradients.
    The R2 helper is called once on the full prediction set with
    ``plot=True``. Then ``n_bootstrap`` resamples of the same size are drawn
    with replacement from a ``RandomState(seed)``, so results are reproducible
    for a given seed. For each resample, R2, Pearson, MSE, MAE and MAPE are
    recorded. Predictions, labels, the summary and the per-resample values are
    pickled with ``split`` in their file names.

    Args:
      dataloader: yields indexable batches ``(image, tabular, labels, ...)``.
      split: split name passed to the metric helpers and used in file names.
      n_bootstrap: number of bootstrap resamples (default 200).
      seed: seed for the resampling RNG (default 6).

    Returns:
      ``(results, full_results)``. ``results`` maps each metric name to
      ``[2.5th percentile, mean, 97.5th percentile]`` over the resamples.
      ``full_results`` maps each metric name to the list of per-resample
      values.
    """
    label_batches = []
    pred_batches = []

    self.model.eval()

    with torch.no_grad():
      for instance in dataloader:
        image_input = instance[0].to(torch.float32).to(self.device)
        # reshape(-1, 1) assumes a single tabular feature per sample -- TODO confirm.
        tab_input = instance[1].to(torch.float32).to(self.device).reshape(-1, 1)
        labels = instance[2].to(torch.float32).to(self.device)

        outputs = self.model(image_input, tab_input)

        label_batches.append(labels)
        pred_batches.append(outputs)

    # Concatenate once (repeated torch.cat onto a growing tensor is quadratic);
    # an empty loader still yields an empty tensor as before.
    all_labels = (torch.cat(label_batches, dim=0) if label_batches else torch.empty(0)).cpu().numpy()
    all_preds = (torch.cat(pred_batches, dim=0) if pred_batches else torch.empty(0)).cpu().numpy()

    rng = np.random.RandomState(seed=seed)
    idx = np.arange(len(all_labels))

    calc_test_r2 = CorrAndRSquared(self.model_name, split)
    calc_test_mse = MSE(self.model_name, split)
    calc_test_mae = MAE(self.model_name, split)
    calc_test_mape = MAPE(self.model_name, split)

    _ = calc_test_r2(all_labels, all_preds, plot=True)

    full_results = {'r2': [], 'pearson': [], 'mse': [], 'mae': [], 'mape': []}

    for _ in range(n_bootstrap):
      sample = rng.choice(idx, size=idx.shape[0], replace=True)
      sample_labels = all_labels[sample]
      sample_preds = all_preds[sample]

      # One R2/Pearson evaluation per resample (previously computed twice).
      r2, pearson = calc_test_r2(sample_labels, sample_preds, plot=False)
      full_results['r2'].append(r2)
      full_results['pearson'].append(pearson[0])
      full_results['mse'].append(calc_test_mse(sample_labels, sample_preds))
      full_results['mae'].append(calc_test_mae(sample_labels, sample_preds))
      full_results['mape'].append(calc_test_mape(sample_labels, sample_preds))

    results = {metric: self._summarize_bootstrap(values) for metric, values in full_results.items()}

    self._dump_pickle(f'_{split}_preds.pkl', all_preds)
    self._dump_pickle(f'_{split}_labels.pkl', all_labels)
    self._dump_pickle(f'_{split}_results.pkl', results)
    self._dump_pickle(f'_full_{split}_results.pkl', full_results)

    return results, full_results