# Test Model Saving
This notebook is designed to make sure saving and loading models works correctly

In [17]:
import argparse                                                                                     
import logging                                                                                      
import numpy as np   
import pickle
import random
import sys

import torch

# Add whistl modules to the path
sys.path.append('../whistl')
import classifier
import dataset
import model
import plot_util
import util

In [2]:
# Tell pytorch to use the gpu
device = torch.device('cuda')

In [3]:
# Set up logging
logging.basicConfig(level=logging.ERROR)                                                         
logger = logging.getLogger(__name__)

In [4]:
# Ensure the models train deterministically
seed = 42

np.random.seed(seed)                                                                       
random.seed(seed)                                                                          
torch.manual_seed(seed)                                                                    
if torch.backends.cudnn.enabled:                                                                
    torch.backends.cudnn.deterministic = True                                                   
    torch.backends.cudnn.benchmark = False

In [5]:
# Select a classifier architecture
label_to_encoding = {'tb': 1, 'healthy': 0}                                                 
net = model.ThreeLayerNet

In [6]:
# Split train and test data
train_dirs, tune_dirs = util.train_tune_split('../data/', 2)

In [7]:
# Initialize arguments to use in training the models
map_file = '../data/sample_classifications.pkl'
gene_file = '../data/intersection_genes.csv'
num_epochs = 1500
loss_scaling_factor = 1

## Train a three layer neural network with IRM

In [8]:
irm_results = classifier.train_with_irm(net, map_file, train_dirs, tune_dirs, gene_file, num_epochs, 
                                        loss_scaling_factor, label_to_encoding, device, logger, '../logs/irm.pkl', 5)

HBox(children=(IntProgress(value=0, max=1500), HTML(value='')))




## Load the model and ensure the weights saved properly

In [9]:
trained_net = torch.load('../logs/irm.pkl')

In [10]:
sample_to_label = util.parse_map_file(map_file)

tune_dataset = dataset.ExpressionDataset(tune_dirs, sample_to_label, label_to_encoding, gene_file)
tune_loader = torch.utils.data.DataLoader(tune_dataset, batch_size=16, num_workers=4, pin_memory=True)

### Test trained network

In [11]:
tune_loss = 0
tune_correct = 0

for tune_batch in tune_loader:
    expression, labels, ids = tune_batch                                            
    tune_expression = expression.to(device)                                         
    tune_labels = labels.to(device).double()                                        

    loss_function = torch.nn.BCEWithLogitsLoss()                                          

    tune_preds = trained_net(tune_expression)                                        
    loss = loss_function(tune_preds, tune_labels)
    
    tune_loss += float(loss)
    tune_correct += util.count_correct(tune_preds, tune_labels)
    
avg_loss = tune_loss / len(tune_dataset)
tune_acc = tune_correct / len(tune_dataset)

In [12]:
print('Trained network tune accuracy: {}'.format(tune_acc))
print('Trained network tune loss: {}'.format(avg_loss))

Trained network tune accuracy: 1.0
Trained network tune loss: 5.13992141030165e-19


### Test untrained network

In [13]:
input_size = tune_dataset[0][0].shape[0]
untrained_net = model.ThreeLayerNet(input_size).double().to(device)

In [14]:
tune_loss = 0
tune_correct = 0

for tune_batch in tune_loader:
    expression, labels, ids = tune_batch                                            
    tune_expression = expression.to(device)                                         
    tune_labels = labels.to(device).double()                                        

    loss_function = torch.nn.BCEWithLogitsLoss()                                          

    tune_preds = untrained_net(tune_expression)                                        
    loss = loss_function(tune_preds, tune_labels)
    
    tune_loss += float(loss)
    tune_correct += util.count_correct(tune_preds, tune_labels)
    
avg_loss = tune_loss / len(tune_dataset)
tune_acc = tune_correct / len(tune_dataset)

In [15]:
print('Untrained network tune accuracy: {}'.format(tune_acc))
print('Untrained network tune loss: {}'.format(avg_loss))

Untrained network tune accuracy: 0.0
Untrained network tune loss: 0.0527816792333729


## Save results to a file to keep track of genes and samples used

In [18]:
with open('../logs/model_saving_test_results.pkl', 'wb') as out_file:
    pickle.dump(irm_results, out_file)

## Conclusion
The model saving functions in `classifier.py` work, and the trained network outperforms an untrained one.