In [1]:
import pandas as pd
import random
from scipy.stats import bernoulli
from sklearn.model_selection import train_test_split

import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.optim import AdamW

from transformers import (T5ForConditionalGeneration,
                          T5Tokenizer)
from datasets import Dataset

device = torch.device("cuda:1")
random.seed(42)

BATCH_SIZE = 16
EPOCHS     = 2

def sample_disparity(col, race, tau):
    race = race.apply(lambda x: 1 if x=='WHITE' else -1)*tau
    temp = col+race
    temp = temp.apply(lambda x: 1 if x>1 else 0 if x<0 else x)
    return bernoulli.rvs(temp)

class Classifier(nn.Module):
    def __init__(self, input_dim, output_dim=1):
        super(Classifier, self).__init__()
        
        self.hidden_layer1 = nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.LeakyReLU(0.2)
        )

        self.hidden_layer2 = nn.Sequential(
            nn.Linear(128, 32),
            nn.LeakyReLU(0.2),
            nn.Dropout(0.3)
        )

        self.hidden_layer3 = nn.Sequential(
            nn.Linear(32, output_dim),
            nn.Sigmoid()
        )

    def forward(self, x, labels=None):
        output = self.hidden_layer1(x)
        output = self.hidden_layer2(output)
        output = self.hidden_layer3(output)
        return output

In [2]:
df = pd.read_csv('data/preprocessed.csv', lineterminator='\n')
df = df.sort_values(by=['SUBJECT_ID','HADM_ID','CHARTDATE'])\
                .groupby(['SUBJECT_ID','HADM_ID'])\
                .head(1)
df = df[['SUBJECT_ID','ETHNICITY','DIAGNOSIS','TEXT','apsiii']].reset_index(drop=True)
df['apsiii_norm'] = (df['apsiii']-min(df['apsiii']))/(max(df['apsiii'])-min(df['apsiii']))
df['actual_treatment'] = bernoulli.rvs(df['apsiii_norm'])
df['given_treatment'] = sample_disparity(df['apsiii_norm'], df['ETHNICITY'], 0.1)

df_train, df_test = train_test_split(df, test_size=0.2)

dataset_train = Dataset.from_pandas(df_train, split='train')
dataset_test  = Dataset.from_pandas(df_test,  split='test')

dataloader_train = DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)
dataloader_test  = DataLoader(dataset_test,  batch_size=BATCH_SIZE, shuffle=True)

In [3]:
def encode(examples, tokenizer):
    inputs = examples['TEXT']  
    labels = nn.functional.one_hot(examples['given_treatment'], 
                                   num_classes=2)
    tokenized_inputs = tokenizer(inputs,
                                 return_tensors='pt',
                                 max_length=512,
                                 truncation=True,
                                 padding=True)
    model_inputs = {}
    model_inputs['input_ids']      = tokenized_inputs['input_ids']
    model_inputs['attention_mask'] = tokenized_inputs['attention_mask']
    return model_inputs, labels.float()


In [4]:
# Load debiased GAN, only keep the encoder
model = T5ForConditionalGeneration.from_pretrained("results/gan/model-0-2000")
model_enc = model.encoder
del model
model_enc.to(device)

tokenizer = T5Tokenizer.from_pretrained("t5-small",
                                            output_scores=True,
                                            output_hidden_states=True,
                                            model_max_length=512)
# Define the models and optimizer
classifier = Classifier(input_dim=512, output_dim=2).to(device)
classifier_optimizer = AdamW(classifier.parameters(), 
                          lr=5e-3)

# Define loss
criterion = nn.BCELoss()

In [5]:
def train_classifier(batch, labels):
    classifier.train()
    classifier_optimizer.zero_grad()
    
    x = model_enc(**batch)                                       # Get embeddings from encoder
    x = x.last_hidden_state.sum(axis=1).detach()                 # Sum across emb_dim, detach

    pred = classifier(x)                                         # Make prediction
    loss = criterion(pred, labels)
    loss.backward()
    classifier_optimizer.step()
    return loss

def eval_classifier(batch, labels):
    classifier.eval()
    with torch.no_grad():
        x = model_enc(**batch)                                       # Get embeddings from encoder
        x = x.last_hidden_state.sum(axis=1).detach()                 # Sum across emb_dim, detach

        pred = classifier(x)                                         # Make prediction
        loss = criterion(pred, labels)
        return loss

In [6]:
def train(ep):
    train_loss, test_loss, steps = 0, 0, 0
    for batch in dataloader_train:
        steps += 1
        batch, labels           = encode(batch, tokenizer)           # Tokenize and obtain labels
        batch['input_ids']      = batch['input_ids'].to(device)      # Send to GPU
        batch['attention_mask'] = batch['attention_mask'].to(device)
        labels                  = labels.to(device)
        train_loss += train_classifier(batch, labels)
        if steps % 500 == 0:
            torch.save(classifier.state_dict(), f"results/gan/classifier-{ep}-{steps}")

    for batch in dataloader_test:
        batch, labels           = encode(batch, tokenizer)           # Tokenize and obtain labels
        batch['input_ids']      = batch['input_ids'].to(device)      # Send to GPU
        batch['attention_mask'] = batch['attention_mask'].to(device)
        labels                  = labels.to(device)
        test_loss += eval_classifier(batch, labels)
    
    train_loss /= len(dataloader_train)
    test_loss  /= len(dataloader_test)

In [7]:
for ep in range(EPOCHS):
    train(ep)
    torch.save(classifier.state_dict(), f"results/gan/classifier-{ep}-final")

In [40]:
race_lst, actual_lst, pred_lst, apsiii_lst = [], [], [], []
for idx, batch in enumerate(dataloader_test):
    print(f"Predicting: Batch {idx}")
    
    # Obtain actual treatment and race features
    actual_treatment        = nn.functional.one_hot(batch['actual_treatment'], 
                                                    num_classes=2).float()
    race                    = torch.tensor(list(map(lambda x: 1*(x=='BLACK'), 
                                                    batch['ETHNICITY'])))
    batch, labels           = encode(batch, tokenizer)           # Tokenize and obtain labels
    batch['input_ids']      = batch['input_ids'].to(device)      # Send to GPU
    batch['attention_mask'] = batch['attention_mask'].to(device)
    labels                  = labels.to(device)
    
    x = model_enc(**batch)                                       # Get embeddings from encoder
    x = x.last_hidden_state.sum(axis=1).detach()                 # Sum across emb_dim, detach
    pred = classifier(x).cpu().detach()                          # Make prediction
    
    race_lst.append(race)
    actual_lst.append(actual_treatment)
    pred_lst.append(pred)


Predicting: Batch 0
Predicting: Batch 1
Predicting: Batch 2
Predicting: Batch 3
Predicting: Batch 4
Predicting: Batch 5
Predicting: Batch 6
Predicting: Batch 7
Predicting: Batch 8
Predicting: Batch 9
Predicting: Batch 10
Predicting: Batch 11
Predicting: Batch 12
Predicting: Batch 13
Predicting: Batch 14
Predicting: Batch 15
Predicting: Batch 16
Predicting: Batch 17
Predicting: Batch 18
Predicting: Batch 19
Predicting: Batch 20
Predicting: Batch 21
Predicting: Batch 22
Predicting: Batch 23
Predicting: Batch 24
Predicting: Batch 25
Predicting: Batch 26
Predicting: Batch 27
Predicting: Batch 28
Predicting: Batch 29
Predicting: Batch 30
Predicting: Batch 31
Predicting: Batch 32
Predicting: Batch 33
Predicting: Batch 34
Predicting: Batch 35
Predicting: Batch 36
Predicting: Batch 37
Predicting: Batch 38
Predicting: Batch 39
Predicting: Batch 40
Predicting: Batch 41
Predicting: Batch 42
Predicting: Batch 43
Predicting: Batch 44
Predicting: Batch 45
Predicting: Batch 46
Predicting: Batch 47
Pr

In [60]:
race_lst   = torch.concat(race_lst)
actual_lst = torch.concat(actual_lst)
pred_lst   = torch.concat(pred_lst) 
apsiii_lst = torch.concat(apsiii_lst)

idx_0 = torch.where(race_lst==0)[0]
idx_1 = torch.where(race_lst==1)[0]


In [73]:
from sklearn.metrics import roc_auc_score

In [125]:
torch.mean(1.0*(actual_lst[idx_0].argmax(dim=1)==pred_lst[idx_0].argmax(dim=1)))

tensor(0.7847)

In [126]:
torch.mean(1.0*(actual_lst[idx_1].argmax(dim=1)==pred_lst[idx_1].argmax(dim=1)))

tensor(0.7581)

In [84]:
roc_auc_score(y_true  = actual_lst[idx_0].numpy(),
              y_score = pred_lst[idx_0].numpy())

0.4956999550579042

In [85]:
roc_auc_score(y_true  = actual_lst[idx_1].numpy(),
              y_score = pred_lst[idx_1].numpy())

0.5349819342123786

In [93]:
import numpy as np

In [96]:
actual_lst

tensor([[1., 0.],
        [1., 0.],
        [1., 0.],
        ...,
        [1., 0.],
        [1., 0.],
        [1., 0.]])

In [101]:
df_analysis = pd.DataFrame({'race':race_lst.numpy(),
              'actual':np.argmax(actual_lst.numpy(), axis=1),
              'pred_not_prescribe':pred_lst[:,0].numpy(),
              'pred_prescribe':pred_lst[:,1].numpy(),
              'apsiii':apsiii_lst.numpy()})


In [104]:
df_analysis['prescribe'] = 1*(df_analysis['pred_prescribe'] > df_analysis['pred_not_prescribe'])

In [107]:
df_analysis['correct'] = 1*(df_analysis['prescribe']==df_analysis['actual'])

In [131]:
for i in range(8):
    if i<=5:
        temp = df_analysis.loc[(df_analysis.apsiii >= 20*i)&\
                            (df_analysis.apsiii < 20*(i+1))]\
                                .groupby('race')\
                                .aggregate({'correct':'mean', 'prescribe':'mean', 'apsiii':'count'})
    else:
        temp = df_analysis.loc[(df_analysis.apsiii >= 20*i)]\
                                .groupby('race')\
                                .aggregate({'correct':'mean', 'prescribe':'mean', 'apsiii':'count'})
    print(f"Accuracy (Black): {round(temp['correct'][1],3)} ({temp['apsiii'][1]} patients), Accuracy (White): {round(temp['correct'][0],3)} ({temp['apsiii'][0]} patients)")


Accuracy (Black): 0.789 (95 patients), Accuracy (White): 0.788 (688 patients)
Accuracy (Black): 0.729 (376 patients), Accuracy (White): 0.782 (2580 patients)
Accuracy (Black): 0.78 (246 patients), Accuracy (White): 0.786 (1966 patients)
Accuracy (Black): 0.787 (75 patients), Accuracy (White): 0.794 (617 patients)
Accuracy (Black): 0.759 (29 patients), Accuracy (White): 0.768 (198 patients)
Accuracy (Black): 0.833 (12 patients), Accuracy (White): 0.778 (63 patients)
Accuracy (Black): 0.5 (2 patients), Accuracy (White): 0.783 (23 patients)
Accuracy (Black): 0.5 (2 patients), Accuracy (White): 0.778 (9 patients)
