# Solve Selection Bias by Multitasking

- Train a multitask model on the biased data to predict both risk and censoring.
- Use the censoring-prediction head to identify censored units, and evaluate the risk prediction only on the units predicted to be uncensored.


In [None]:
# set working directory
from random import SystemRandom
import pandas as pd
import numpy as np
import xgboost as xgb
# from hyperopt import hp, fmin, tpe, STATUS_OK, Trials
from sklearn.model_selection import cross_val_score
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import RandomizedSearchCV, GridSearchCV
from sklearn.linear_model import LogisticRegression
from sklearn.preprocessing import RobustScaler
  
import os
import pickle
from sklearn.model_selection import train_test_split
import wandb

import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import GridSearchCV
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score

# from torch._C import float32
import argparse
from asyncio.log import logger
import os, math
import logging
import torch
import numpy as np
import json

import torch.nn as nn
import torch

import pickle
import json
import numpy as np
import random
from torch.utils.data import Dataset, DataLoader, random_split, WeightedRandomSampler
import torch
import os
import torch.nn as nn
import torch.optim as optim
from sklearn.datasets import load_iris
from sklearn.datasets import load_breast_cancer

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

import torch.nn as nn
import torch
import math
import pandas as pd
import random

# Importing matplotlib and seaborn
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler

import sys
import warnings

import numpy as np
import pandas as pd

from IPython.display import HTML, display
# import tabulate

# import utils
from utils import *

if not sys.warnoptions:
    warnings.simplefilter("ignore")

%matplotlib inline

In [None]:
# ---- Experiment constants ----
MINI_BATCH = 64      # training mini-batch size passed to get_loaders
MINI_BATCH2 = 256    # not referenced in the cells shown here
EPOCHS = 1000        # epoch budget passed to train_model (early stopping presumably ends runs sooner)
LOAD = None          # explicit experiment ID; None means a random ID is drawn below
SEED = 42
REPEAT = 10          # number of independent training runs per data setting

# ---- Device selection ----
_cuda_ok = torch.cuda.is_available()
device = torch.device("cuda" if _cuda_ok else "cpu")
print("CUDA is available! Using GPU." if _cuda_ok else "CUDA is not available. Using CPU.")

# ---- Reproducibility: seed torch (CPU), NumPy, torch CUDA and Python's random, in that order ----
for _seed_fn in (torch.manual_seed, np.random.seed, torch.cuda.manual_seed, random.seed):
    _seed_fn(SEED)

# Hyper-parameter grid handed to grid_search_MLP by study_effect when a search is needed.
param_grid = {
    'drop_rate': [0.1],
    'hidden_sizes': [[50], [100], [100, 100]],
    'head_sizes': [[50], [100]],
    'lr': [0.0001, 0.0005],
}

############################################
# Weights & Biases tracking is disabled; to enable it use
# wandb.init(project='SeletionBML', entity="jmdvinodjmd") instead.
wandb.init(mode="disabled")
wandb.run.name = 'SB'
makedirs('./results/')
# Reuse LOAD as the experiment ID when given, otherwise draw a random integer ID.
experimentID = int(SystemRandom().random()*100000) if LOAD is None else LOAD

# Single checkpoint file, overwritten by every training run.
ckpt_path = './results/checkpoints/MT_model.ckpt'
makedirs('./results/checkpoints/')

# Per-experiment log file.
log_path = "./results/logs/" + "exp_MT_" + str(experimentID) + ".log"
makedirs("./results/logs/")
logger = get_logger(logpath=log_path, filepath="exp_MT_" + str(experimentID) + ".log", displaying=False)
logger.info("Experiment " + str(experimentID))
############################################
############################################

In [None]:
def experiment(data, params, repeat=1):
    """Train and evaluate the multitask (risk + censoring) network ``repeat`` times.

    Each repetition trains a fresh 'Multitasking' model with early stopping and records:
      * AUROC on the biased validation split and on the biased test split,
      * AUROC on the unbiased test split, plus the 'C-Test AUROC' value that
        ``evaluate_model`` returns in 'Test' mode,
      * AUROC of the model restricted to test units whose predicted censoring score
        does not exceed the threshold chosen on the validation split.

    Args:
        data: list ``[X_train, y_train, s_train, X_val, y_val, s_val, X_test, y_test, s_test]``.
            ``s_*`` are censoring indicators (presumably 1 = censored, since
            ``s_test.sum()`` is stored as 'Actual Sensored' — confirm upstream).
        params: hyper-parameter dict forwarded to ``create_model``.
        repeat: number of independent training runs.

    Returns:
        dict mapping repetition index -> dict of metrics. When the predicted-uncensored
        test subset is empty, equals the whole test set, or contains a single outcome
        class, 'Test AUROC-Multitasking' and the censoring counts are omitted for that
        repetition (the AUROC is undefined or the split is degenerate).
    """
    [X_train, y_train, s_train, X_val, y_val, s_val, X_test, y_test, s_test] = data

    # loader preparation; validation and test are served as one full-size batch
    loader_train, input_size = get_loaders([X_train, y_train, s_train], batch_size=MINI_BATCH, is_train=True, device=device)
    loader_val, _ = get_loaders([X_val, y_val, s_val], batch_size=X_val.shape[0], is_train=False, device=device)
    loader_test, _ = get_loaders([X_test, y_test, s_test], batch_size=X_test.shape[0], is_train=False, device=device)

    # Loop-invariant inputs for the censoring-prediction step.
    X_test_tensor = torch.tensor(X_test, dtype=torch.float).to(device)
    n_test = X_test.shape[0]

    # repeating experiment for a given number of times
    results_risk = {}
    for i in range(repeat):
        logger.info('Repeating: ' + str(i+1))
        #############################
        # train a fresh model for this repetition
        model_risk, optimizer, criterion = create_model('Multitasking', params, input_size, output_size=1, device=device)
        early_stopping = EarlyStopping(patience=10, path=ckpt_path, verbose=True, logger=logger)
        logger.info(model_risk)
        wandb.watch(model_risk)
        model_risk = train_model(model_risk, 'Multitasking', loader_train, loader_val, optimizer, criterion, early_stopping, logger, epochs=EPOCHS, plot=False, wandb=wandb)

        # evaluate; running the test loader in 'Val' mode gives the value reported as the biased Test AUROC
        auroc_vb, best_threshold, _ = evaluate_model('Val', loader_val, model_risk, 'Multitasking', criterion, logger, -1, device, wandb)
        auroc_tb, _, _ = evaluate_model('Val', loader_test, model_risk, 'Multitasking', criterion, logger, -1, device, wandb)
        auroc_tu, _, _ = evaluate_model('Test', loader_test, model_risk, 'Multitasking', criterion, logger, -1, device, wandb)

        logger.info('Biased VAL AUROC:' + str(auroc_vb['Val AUROC']) + ' biased Test AUROC:' + str(auroc_tb['Val AUROC']) + ' unbiased Test AUROC:' + str(auroc_tu['Test AUROC']))
        results_risk[i] = {'VAL AUROC-B':auroc_vb['Val AUROC'], 'Test AUROC-B':auroc_tb['Val AUROC'], 'Test AUROC-U':auroc_tu['Test AUROC'], 'C-Test AUROC-U':auroc_tu['C-Test AUROC']}

        ############################
        # Identify censored units and evaluate risk only on the predicted-uncensored ones
        logger.info('Best threshold:'+ str(best_threshold))
        # Pure inference: make sure dropout is off and no autograd graph is built.
        model_risk.eval()
        with torch.no_grad():
            _, preds_sensoring, _ = model_risk(X_test_tensor)
        threshold = torch.as_tensor(best_threshold, dtype=torch.float, device=device)
        sensored_units = (preds_sensoring > threshold).cpu().numpy().astype(int)
        uncensored = sensored_units.squeeze() == 0   # boolean row mask, computed once
        n_censored = sensored_units.sum()
        y_uncensored = y_test[uncensored]

        # AUROC on the subset is meaningless when no/all units are flagged censored,
        # or when the remaining units contain only one outcome class.
        if (n_censored == 0) or (n_censored == n_test) or (y_uncensored.sum() == 0) or (y_uncensored.sum() == y_uncensored.shape[0]):
            logger.info('Skipping Multitasking Test AUROC-U: degenerate split (predicted censored='
                        + str(n_censored) + '/' + str(n_test) + ', positives in uncensored subset='
                        + str(y_uncensored.sum()) + ').')
            continue

        logger.info('sizes:'+str(X_test[uncensored].shape)+str(X_test.shape)+'---------------')
        logger.info('sizes:'+str(y_uncensored.shape)+str(y_uncensored.sum())+'---------------')
        # Built only when actually evaluated (the guard above may skip it).
        loader_utest, _ = get_loaders([X_test[uncensored], y_uncensored, s_test[uncensored]], batch_size=y_test.shape[0], is_train=False, device=device)
        auroc, _, _ = evaluate_model('TB-EU:Multitasking', loader_utest, model_risk, 'Multitasking', criterion, logger, -1, device, wandb)

        logger.info('Multitasking Test AUROC-U:' + str(auroc['TB-EU:Multitasking AUROC']) + '. Unensored/Total:' + str(X_test[uncensored].shape[0])+ '/'+str(n_test)\
                    +'. Actual Unensored/Total:' + str(X_test[s_test.squeeze()==0].shape[0])+ '/'+str(n_test))
        results_risk[i].update({'Test AUROC-Multitasking':auroc['TB-EU:Multitasking AUROC'], 'Predicted Sensored':n_censored, 'Actual Sensored':s_test.sum()})

    return results_risk

def study_effect(data_name, file_name, results_file, r, c, n, search_param=False, hyperparams_file='best_hyperparams.json'):
    '''
    Study the effect of dataset size, risk rate and censoring rate on the multitask model.

    For every (size, censoring rate, risk rate) combination:
      1. the matching split is loaded with ``get_data_dict(file_name, [ri], [ci], [ni])``;
      2. hyper-parameters are read from ``hyperparams_file``; if the entry is missing
         (or ``search_param`` is True) a grid search is run and the result is written back;
      3. ``experiment`` is run ``REPEAT`` times with those hyper-parameters;
      4. all results gathered so far are saved with ``dict_to_file(results_file, ...)``.

    Args:
        data_name: dataset label used inside the hyper-parameter cache key.
        file_name: dataset file passed to ``get_data_dict``.
        results_file: output target passed to ``dict_to_file``.
        r: iterable of risk rates.
        c: iterable of censoring rates.
        n: iterable of dataset sizes.
        search_param: force a fresh hyper-parameter search even when cached.
        hyperparams_file: JSON cache of best hyper-parameters; a missing file is
            treated as an empty cache.

    Returns:
        dict keyed by '<n>R<r>C<c>' holding the per-repetition results of ``experiment``.
    '''
    logger.info('\n\n-------------N:'+str(n)+'--Risk Rate:' + str(r)+'--Censoring Rate:' + str(c)+'-------------------------.')

    results_sizes = {}
    for ni in n:
        for ci in c:
            for ri in r:
                setting = str(ni)+'R'+str(ri)+'C'+str(ci)
                hp_key = 'Multitasking-' + data_name + setting

                # load data dictionary
                data_dict = get_data_dict(file_name, [ri], [ci], [ni])

                logger.info('-----Running for Size:'+str(ni)+'--Risk Rate:' + str(ri)+'--Censoring Rate:' + str(ci)+'\n-----------')
                data = list(data_dict[setting])
                # unpacking also checks that exactly 9 arrays were provided
                [X_train, y_train, s_train, X_val, y_val, s_val, X_test, y_test, s_test] = data

                ##############################
                # Re-read the cache every time: earlier iterations may have added entries.
                try:
                    with open(hyperparams_file, 'r') as json_file:
                        best_hyperparams = json.load(json_file)
                except FileNotFoundError:
                    # First run: nothing cached yet, so a search is needed.
                    best_hyperparams = {}

                if search_param or hp_key not in best_hyperparams:
                    # hyperparameter tuning
                    logger.info('Finding best hyperparams.')
                    loader_train_br, input_size = get_loaders([X_train, y_train, s_train], batch_size=MINI_BATCH, is_train=True, device=device)
                    loader_val_br, _ = get_loaders([X_val, y_val, s_val], batch_size=y_val.shape[0], is_train=False, device=device)
                    best_param, _, search_results = grid_search_MLP('Multitasking', loader_train_br, loader_val_br, input_size, ckpt_path, param_grid, EPOCHS, logger, wandb, device)
                    logger.info('Hyperparam tuning for Multitasking network:')
                    logger.info(search_results)

                    best_hyperparams[hp_key] = {'best_param': best_param}
                    # save best params
                    with open(hyperparams_file, 'w') as json_file:
                        json.dump(best_hyperparams, json_file)
                else:
                    logger.info('Accessing the existing best hyperparams.')
                    best_param = best_hyperparams[hp_key]['best_param']

                ################################
                # run experiments and repeat for given number of times
                results = experiment(data, best_param, repeat=REPEAT)
                logger.info('\n\nBest params for Multitasking network:\n' + str(best_param))
                logger.info(results)
                results_sizes[setting] = results

                # save results after every setting so partial progress survives a crash
                dict_to_file(results_file, results_sizes)
                ################################

    logger.info('\n\n------------------- Experiments ended-------------------.\n'+str(results_sizes)+'\n------------------------------------------------\n\n')

    return results_sizes

## Synthetic

In [None]:
# Synthetic data: sweep risk rate x censoring rate x sample size.
results_sizes = study_effect(
    'synthetic',
    'selection_bias_data.pkl',
    'results_MTNet',
    r=[.05, .1, .2, .3, .4],
    c=[.05, .1, .2, .3, .4],
    n=[1000, 2000, 3000, 4000, 5000],
    search_param=False,
)

## Diabetes 

In [None]:
# Diabetes data: sweep risk rate x censoring rate x sample size (largest size first).
results_sizes = study_effect(
    'diabetes',
    'diabetes_bias_data.pkl',
    'results_MTNet-diabetes',
    r=[.05, .1, .2, .3, .4],
    c=[.05, .1, .2, .3, .4],
    n=[25000, 10000, 5000, 2000, 1000],
    search_param=False,
)

## Covid 

In [None]:
# Covid data: only the two highest censoring rates are swept.
results_sizes = study_effect(
    'covid',
    'covid_bias_data.pkl',
    'results_MTNet-covid',
    r=[.05, .1, .2, .3, .4],
    c=[.3, .4],
    n=[15000, 10000, 5000, 2000, 1000],
    search_param=False,
)