# Transfer Sepsis to TB
This notebook looks at whether a model trained on sepsis data learns anything about how to classify tuberculosis data. Both the sepsis and TB samples are a combination of whole blood microarrays and RNA-seq.

In [1]:
import argparse                                                                                     
import logging                                                                                      
import numpy as np 
import pickle
import random
import sys

import torch

# Add whistl modules to the path
sys.path.append('../whistl')
import classifier
import dataset
import model
import plot_util
import util

In [2]:
# Tell pytorch to use the gpu
device = torch.device('cuda')

In [3]:
# Set up logging
logging.basicConfig(level=logging.ERROR)                                                         
logger = logging.getLogger(__name__)

In [4]:
# Ensure the models train deterministically
seed = 42

np.random.seed(seed)                                                                       
random.seed(seed)                                                                          
torch.manual_seed(seed)                                                                    
if torch.backends.cudnn.enabled:                                                                
    torch.backends.cudnn.deterministic = True                                                   
    torch.backends.cudnn.benchmark = False

## Train models on sepsis using ERM and IRM

In [5]:
# Initialize arguments to use in training the models
map_file = '../data/sample_classifications.pkl'
gene_file = '../data/intersection_genes.csv'
num_epochs = 1500
loss_scaling_factor = 1

In [6]:
# Select a classifier architecture
label_to_encoding = {'sepsis': 1, 'healthy': 0}
sample_to_label = util.parse_map_file(map_file)

net = model.ThreeLayerNet

In [7]:
# Split train and test data

data_dirs = util.get_data_dirs('../data')                                                      
train_dirs, test_dirs = util.extract_test_dirs(data_dirs, 'tb', sample_to_label)                                               

train_dirs, tune_dirs = util.train_tune_split(train_dirs, 2)

In [8]:
erm_results = classifier.train_with_erm(net, map_file, train_dirs, tune_dirs, gene_file, num_epochs, 
                                        label_to_encoding, device, logger, '../logs/sepsis_erm.pkl', 20)

HBox(children=(IntProgress(value=0, max=1500), HTML(value='')))




In [9]:
irm_results = classifier.train_with_irm(net, map_file, train_dirs, tune_dirs, gene_file, num_epochs, 
                                        loss_scaling_factor, label_to_encoding, device, logger, '../logs/sepsis_irm.pkl', 100)

HBox(children=(IntProgress(value=0, max=1500), HTML(value='')))




### Print tuning set accuracy at the end of training for ERM and IRM models

In [22]:
print(erm_results['tune_acc'][-5:])
print(irm_results['tune_acc'][-5:])

[1.0, 1.0, 1.0, 1.0, 1.0]
[0.9459459459459459, 0.9459459459459459, 0.8918918918918919, 0.9459459459459459, 0.9054054054054054]


## Run models on tuberculosis samples

In [10]:
# Set the disease encoding
label_to_encoding = {'tb': 1, 'healthy': 0}                                                 

In [11]:
# Load tuberculosis data
tb_dataset = dataset.ExpressionDataset(test_dirs, sample_to_label, label_to_encoding, gene_file)
tb_loader = torch.utils.data.DataLoader(tb_dataset, batch_size=16, num_workers=4, pin_memory=True)

## Evaluate erm and irm models on tb set

In [12]:
erm_net = torch.load('../logs/sepsis_erm.pkl')
irm_net = torch.load('../logs/sepsis_irm.pkl')

In [25]:
erm_loss = 0
erm_correct = 0
irm_loss = 0
irm_correct = 0

erm_tb_preds= 0
irm_tb_preds = 0


for tb_batch in tb_loader:
    expression, labels, ids = tb_batch                                            
    tune_expression = expression.to(device)                                         
    tune_labels = labels.to(device).double()                                        

    loss_function = torch.nn.BCEWithLogitsLoss()                                          

    erm_preds = erm_net(tune_expression)
    loss = loss_function(erm_preds, tune_labels)
    erm_loss += float(loss)
    erm_correct += util.count_correct(erm_preds, tune_labels)
    
    irm_preds = irm_net(tune_expression)      
    loss = loss_function(irm_preds, tune_labels)
    irm_loss += float(loss)
    irm_correct += util.count_correct(irm_preds, tune_labels)
    
    # Count how many times the model predicted something other than zero
    erm_tb_preds += np.sum(erm_preds.cpu().detach().numpy() >= 0)
    irm_tb_preds += np.sum(irm_preds.cpu().detach().numpy() >= 0)
    
avg_erm_loss = erm_loss / len(tb_dataset)
erm_acc = erm_correct / len(tb_dataset)
avg_irm_loss = irm_loss / len(tb_dataset)
irm_acc = irm_correct / len(tb_dataset)

In [14]:
print(erm_tb_preds)
print(irm_tb_preds)
print(len(tb_dataset))

808
520
907


In [18]:
print('ERM trained network tb accuracy: {}'.format(erm_acc))
print('ERM trained network tb loss: {}'.format(avg_erm_loss))
print('IRM trained network tb accuracy: {}'.format(irm_acc))
print('IRM trained network tb loss: {}'.format(avg_irm_loss))

ERM trained network tb accuracy: 0.7375964718853363
ERM trained network tb loss: 0.06897230209944204
IRM trained network tb accuracy: 0.6008820286659317
IRM trained network tb loss: 0.04761558565200637


### Find percent healthy
Our base rate is the percentage of healthy samples in the test data

In [23]:
total_healthy = 0
for test_batch in tb_loader:
    expression, labels, ids = test_batch
    for label in labels:
        if label == label_to_encoding['healthy']:
            total_healthy += 1

print('Prercent healthy = {}'.format(total_healthy / len(tb_dataset)))

Prercent healthy = 0.23042998897464168


## Save Results

In [17]:
with open('../logs/sepsis_erm_results.pkl', 'wb') as out_file:
    pickle.dump(erm_results, out_file)
with open('../logs/sepsis_irm_results.pkl', 'wb') as out_file:
    pickle.dump(irm_results, out_file)

## Conclusion

While they don't perform as well on the transferred domain as the source domain, both the erm and irm trained networks have better than random (and better than the frequency of the more common class). Interestingly, the ERM trained model predicted tuberculosis 8/9 of the time, while the IRM trained model predicted tb about 5/9 of the time. The true frequency of tb in the dataset is 77%.

These results are evidence that domain adaptation works to some extent on gene expression data even between diseases.