In [1]:
import pandas as pd
import numpy as np
from datasets import Dataset
from transformers import (T5ForConditionalGeneration,
                          T5Tokenizer)
from nltk import sent_tokenize
from sklearn.model_selection import train_test_split

import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.autograd import Variable

# Prefer the second GPU (the device this notebook was written for), but fall
# back to the first GPU or the CPU so every later `.to(device)` call still
# works on machines that do not have two CUDA devices.
if torch.cuda.is_available():
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")

BATCH_SIZE = 8   # examples per dataloader batch
EPOCHS     = 2   # number of calls to train(), i.e. passes over the train split

def encode(examples, tokenizer):
    """Tokenize a batch of texts and build the model inputs and race targets.

    Args:
        examples: Mapping with a 'TEXT' entry (the strings to tokenize) and an
            'ETHNICITY' entry holding integer class ids in {0, 1} (a tensor or
            any sequence `torch.as_tensor` accepts).
        tokenizer: Callable invoked as
            `tokenizer(texts, return_tensors='pt', max_length=512,
            truncation=True, padding=True)`; it must return a mapping with
            'input_ids' and 'attention_mask' tensors.

    Returns:
        A tuple `(model_inputs, race)` where `model_inputs` is a dict with
        'input_ids', 'attention_mask' and 'labels', and `race` is the one-hot
        encoding (2 classes) of 'ETHNICITY'.
    """
    tokenized_inputs = tokenizer(examples['TEXT'],
                                 return_tensors='pt',
                                 max_length=512,
                                 truncation=True,
                                 padding=True)
    input_ids = tokenized_inputs['input_ids']
    attention_mask = tokenized_inputs['attention_mask']

    # The text is its own reconstruction target.  masked_fill is out-of-place,
    # so `labels` gets separate storage from `input_ids`, and padded positions
    # are set to -100, the value Hugging Face seq2seq losses ignore.
    labels = input_ids.masked_fill(attention_mask == 0, -100)

    model_inputs = {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'labels': labels,
    }

    # one_hot only accepts integer (long) tensors; coerce lists / other ints.
    race_ids = torch.as_tensor(examples['ETHNICITY'], dtype=torch.long)
    return model_inputs, nn.functional.one_hot(race_ids, num_classes=2)

class Discriminator(nn.Module):
    """Small MLP that maps an embedding vector to per-class scores in (0, 1).

    Layout: input_dim -> 128 -> 32 -> output_dim.  The first two linear
    layers are followed by LeakyReLU(0.2), the second additionally by
    Dropout(0.3), and the output layer by a Sigmoid.
    """

    def __init__(self, input_dim, output_dim=2):
        super().__init__()

        # Stage 1: project the embedding down to 128 features.
        stage1 = [nn.Linear(input_dim, 128), nn.LeakyReLU(0.2)]
        self.hidden_layer1 = nn.Sequential(*stage1)

        # Stage 2: compress to 32 features, with dropout for regularisation.
        stage2 = [nn.Linear(128, 32), nn.LeakyReLU(0.2), nn.Dropout(0.3)]
        self.hidden_layer2 = nn.Sequential(*stage2)

        # Stage 3: one sigmoid-activated unit per output class.
        stage3 = [nn.Linear(32, output_dim), nn.Sigmoid()]
        self.hidden_layer3 = nn.Sequential(*stage3)

    def forward(self, x, labels=None):
        """Return the sigmoid outputs for `x`; `labels` is accepted but unused."""
        return self.hidden_layer3(self.hidden_layer2(self.hidden_layer1(x)))

def train_generator(batch):
    """Run one optimisation step of the generator against the discriminator.

    Uses the notebook-level `model`, `discriminator` and
    `generator_optimizer`.  The generator is pushed towards embeddings for
    which the discriminator outputs 0.5 for every class, i.e. cannot tell the
    classes apart.

    Args:
        batch: Keyword arguments for `model(...)` (input_ids, attention_mask,
            labels), already on the model's device.

    Returns:
        The scalar generator loss as a tensor detached from the autograd
        graph, so callers can accumulate it without keeping graphs alive.
    """
    model.train()
    generator_optimizer.zero_grad()
    criterion = nn.BCELoss()

    # Pool the encoder output into one vector per example.
    # NOTE(review): this is an unmasked mean over dim 1, so padded positions
    # contribute to the embedding -- confirm that this is intended.
    gen_outputs = model(**batch)
    emb = gen_outputs.encoder_last_hidden_state.mean(dim=1)
    dis_outputs = discriminator(emb)

    # Penalize the generator when the discriminator is able to figure out
    # what the correct race is: the target is "maximally unsure" (0.5).
    # full_like matches the discriminator output's shape, dtype and device.
    target = torch.full_like(dis_outputs, 0.5)
    gen_loss = criterion(dis_outputs, target)

    # backward() also fills the discriminator's .grad buffers, but only the
    # generator optimizer is stepped here.
    gen_loss.backward()
    generator_optimizer.step()

    return gen_loss.detach()

def train_discriminator(batch, race):
    """Run one optimisation step of the discriminator on generator embeddings.

    Uses the notebook-level `model`, `discriminator` and
    `discriminator_optimizer`.  Only the discriminator is updated.

    Args:
        batch: Keyword arguments for `model(...)` (input_ids, attention_mask,
            labels), already on the model's device.
        race: One-hot race targets with the same shape as the discriminator
            output; cast to float for BCELoss.

    Returns:
        The scalar discriminator loss as a tensor detached from the autograd
        graph, so callers can accumulate it without keeping graphs alive.
    """
    discriminator.train()
    discriminator_optimizer.zero_grad()
    criterion = nn.BCELoss()

    # The generator is not updated in this step, so its forward pass runs
    # without autograd: no activations are stored and no stale gradients are
    # left on the generator's parameters.
    with torch.no_grad():
        gen_outputs = model(**batch)
        emb = gen_outputs.encoder_last_hidden_state.mean(dim=1)
    dis_outputs = discriminator(emb)

    # Penalize the discriminator for failing to figure out the race
    dis_loss = criterion(dis_outputs, race.float())
    dis_loss.backward()
    discriminator_optimizer.step()

    return dis_loss.detach()

def eval_step(batch, race):
    """Compute generator and discriminator losses for one batch, no updates.

    Puts the notebook-level `model` and `discriminator` into eval mode and
    evaluates both objectives with autograd disabled.

    Args:
        batch: Keyword arguments for `model(...)`.
        race: One-hot race targets; cast to float for BCELoss.

    Returns:
        `(gen_loss, dis_loss)` as scalar tensors.
    """
    model.eval()
    discriminator.eval()

    bce = nn.BCELoss()
    with torch.no_grad():
        # Mean-pool the encoder states and score them with the discriminator.
        pooled = model(**batch).encoder_last_hidden_state.mean(dim=1)
        predictions = discriminator(pooled)

        # Generator objective: discriminator output should sit at 0.5.
        n_examples = predictions.shape[0]
        uniform_target = torch.full((n_examples, 2), 0.5).to(device)
        gen_loss = bce(predictions, uniform_target)

        # Discriminator objective: recover the true race.
        dis_loss = bce(predictions, race.float())

    return gen_loss, dis_loss

def _batch_to_device(raw_batch):
    """Tokenize a raw dataloader batch and move all of its tensors to `device`."""
    inputs, race = encode(raw_batch, tokenizer)
    inputs = {key: tensor.to(device) for key, tensor in inputs.items()}
    return inputs, race.to(device)


def train(ep):
    """Run one epoch of adversarial training followed by one evaluation pass.

    Uses the notebook-level `dataloader_train`, `dataloader_test`, `model`,
    `tokenizer` and `device`.  For every training batch the generator is
    updated first, then the discriminator.  Mean losses for both splits are
    printed at the end.

    Args:
        ep: Epoch index; only used in the checkpoint directory names.
    """
    # Plain Python floats: summing the loss tensors themselves would keep
    # each step's autograd history reachable for the whole epoch.
    g_loss_train_total, d_loss_train_total = 0.0, 0.0
    g_loss_test_total,  d_loss_test_total  = 0.0, 0.0

    for steps, batch in enumerate(dataloader_train):
        print(f"Training: Batch {steps} of {len(dataloader_train)}")
        batch, race = _batch_to_device(batch)

        # Periodic snapshot of the generator, written before this step's
        # update (so step 0 holds the weights the epoch started from).
        if steps % 500 == 0:
            model.save_pretrained(f'results/gan/model-{ep}-{steps}')

        g_loss_train_total += float(train_generator(batch))
        d_loss_train_total += float(train_discriminator(batch, race))

    for steps, batch in enumerate(dataloader_test):
        print(f"Evaluation: Batch {steps} of {len(dataloader_test)}")
        batch, race = _batch_to_device(batch)

        g_loss_test, d_loss_test = eval_step(batch, race)
        g_loss_test_total += float(g_loss_test)
        d_loss_test_total += float(d_loss_test)

    # max(..., 1) keeps an empty dataloader from raising ZeroDivisionError;
    # the corresponding mean is then reported as 0.0.
    n_train = max(len(dataloader_train), 1)
    n_test  = max(len(dataloader_test), 1)
    g_loss_train_total /= n_train
    d_loss_train_total /= n_train
    g_loss_test_total  /= n_test
    d_loss_test_total  /= n_test

    print(f"Train Loss: {g_loss_train_total} (Gen), {d_loss_train_total} (Disc)")
    print(f"Test Loss:  {g_loss_test_total} (Gen), {d_loss_test_total} (Disc)")
    



In [2]:
# Read the data, and load into dataloaders
print('Start!')

# Fix the sources of randomness used below (train/test split, discriminator
# initialisation, dataloader shuffling) so a fresh run reproduces the same setup.
SEED = 42
torch.manual_seed(SEED)

df = pd.read_csv('data/preprocessed.csv', lineterminator='\n')
# Keep one row per (SUBJECT_ID, HADM_ID): the first one after sorting by CHARTDATE.
df = df.sort_values(by=['SUBJECT_ID','HADM_ID','CHARTDATE'])\
                .groupby(['SUBJECT_ID','HADM_ID'])\
                .head(1).reset_index(drop=True)

# .copy() gives an independent frame, so the column assignment below does not
# write into a slice of the original (SettingWithCopyWarning).
df = df[['TEXT','ETHNICITY']].copy()
# Binary target: 1 for 'BLACK', 0 for every other value.
df['ETHNICITY'] = 1*(df['ETHNICITY']=='BLACK')
df_train, df_test = train_test_split(df, test_size=0.2, random_state=SEED)

dataset_train = Dataset.from_pandas(df_train, split='train')
dataset_test  = Dataset.from_pandas(df_test,  split='test')

# Only the training split needs shuffling; the evaluation losses are averaged
# over the whole test split, so batch order does not matter there.
dataloader_train = DataLoader(dataset_train, batch_size=BATCH_SIZE, shuffle=True)
dataloader_test  = DataLoader(dataset_test,  batch_size=BATCH_SIZE, shuffle=False)

# Define the models
print('Loading models')
model = T5ForConditionalGeneration.from_pretrained("t5-small").to(device)

# The discriminator consumes mean-pooled encoder states, so its input width
# is the generator's hidden size (512 for t5-small) rather than a literal.
discriminator = Discriminator(model.config.d_model).to(device)

# NOTE(review): the generator weights come from the stock "t5-small" while the
# tokenizer comes from a local checkpoint directory -- confirm this mix is
# intended.  The output_scores / output_hidden_states flags that used to be
# passed here are model options, not tokenizer options, and were dropped.
tokenizer = T5Tokenizer.from_pretrained("results/generator/checkpoint-33500",
                                            model_max_length = 512)
# Define the optimizers
discriminator_optimizer = AdamW(discriminator.parameters(), 
                          lr=5e-3)
generator_optimizer = AdamW(model.parameters(), 
                      lr=5e-5)


Start!
Loading models


In [3]:
# Adversarial training driver: each epoch runs train() (training pass plus
# evaluation pass) and then snapshots the generator weights for that epoch.
print('Training')
for ep in range(0, EPOCHS):
    train(ep)
    # End-of-epoch checkpoint, alongside the periodic ones written by train().
    model.save_pretrained(f'results/gan/model-{ep}-final')

Training
Training: Batch 0 of 3485
Training: Batch 1 of 3485
Training: Batch 2 of 3485
Training: Batch 3 of 3485
Training: Batch 4 of 3485
Training: Batch 5 of 3485
Training: Batch 6 of 3485
Training: Batch 7 of 3485
Training: Batch 8 of 3485
Training: Batch 9 of 3485
Training: Batch 10 of 3485
Training: Batch 11 of 3485
Training: Batch 12 of 3485
Training: Batch 13 of 3485
Training: Batch 14 of 3485
Training: Batch 15 of 3485
Training: Batch 16 of 3485
Training: Batch 17 of 3485
Training: Batch 18 of 3485
Training: Batch 19 of 3485
Training: Batch 20 of 3485
Training: Batch 21 of 3485
Training: Batch 22 of 3485
Training: Batch 23 of 3485
Training: Batch 24 of 3485
Training: Batch 25 of 3485
Training: Batch 26 of 3485
Training: Batch 27 of 3485
Training: Batch 28 of 3485
Training: Batch 29 of 3485
Training: Batch 30 of 3485
Training: Batch 31 of 3485
Training: Batch 32 of 3485
Training: Batch 33 of 3485
Training: Batch 34 of 3485
Training: Batch 35 of 3485
Training: Batch 36 of 3485
Tr

KeyboardInterrupt: 