<a href="https://colab.research.google.com/github/aaekay/CovidAID/blob/master/CovidAID.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Enable GPU first

In [None]:
# Start from a clean checkout, then clone the project together with its submodules.
!rm -rf ./CovidAID
!git clone --recursive https://github.com/aaekay/CovidAID.git
# Download the pneumonia archive; it is saved under the last path segment of the URL.
!wget https://sogppg.bl.files.1drv.com/y4m7xCVzOJm2tZO9UttZgEidLVEYUsYqN4zjak7oGSlj9k75JYCOwQNF-NXmkolYeS0oUJ8STXyq5wKG7SHTBXF_uRXefR4FbA0x8-b-lQVHNVqp5Ts72XLkxvG4R6Z_EonfAX-XYeSRpkSgmOVqXGVCCOB-GQAbnxrpi0TdvZB4--kBMLqtLe7ESNTaruxnOpjt7z466AvplCvYc8UtdO86Q
# Give the download a usable name (assumes the working directory is /content — TODO confirm).
!mv /content/y4m7xCVzOJm2tZO9UttZgEidLVEYUsYqN4zjak7oGSlj9k75JYCOwQNF-NXmkolYeS0oUJ8STXyq5wKG7SHTBXF_uRXefR4FbA0x8-b-lQVHNVqp5Ts72XLkxvG4R6Z_EonfAX-XYeSRpkSgmOVqXGVCCOB-GQAbnxrpi0TdvZB4--kBMLqtLe7ESNTaruxnOpjt7z466AvplCvYc8UtdO86Q pneumo.zip
# Unpack the archive into the directory the data-preparation code reads from.
!unzip ./pneumo.zip -d ./CovidAID/chest-xray-pneumonia

In [None]:
# Download the PyTorch 0.3.0.post4 wheel (CUDA 7.5 build, CPython 3.6)
!wget https://download.pytorch.org/whl/cu75/torch-0.3.0.post4-cp36-cp36m-linux_x86_64.whl

In [None]:
!pip install /content/torch-0.3.0.post4-cp36-cp36m-linux_x86_64.whl

In [None]:
# this is for installing requirements
!pip install -r /content/CovidAID/requirements.txt 

In [None]:
!python3 /content/CovidAID/data_tools/prepare_covid_data.py

In [None]:
!python3 /content/CovidAID/data_tools/prepare_data.py --combine_pneumonia

In [None]:
!python3 /content/CovidAID/tools/transfer.py

In [None]:
# Fine-tune with --freeze, batch size 16, saving checkpoints under ./new_models.
# NOTE(review): the data was prepared with --combine_pneumonia and the checkpoint is named
# CovidAID_3_class.pth, but --combine_pneumonia is not passed here (it defaults to False) — confirm.
!python3 ./CovidAID/tools/trainer.py --mode train --freeze --checkpoint ./CovidAID/models/CovidAID_3_class.pth --bs 16 --save ./new_models

### Prepare the COVID-19 dataset split lists

In [None]:
import os
import numpy as np
import pandas as pd
from collections import Counter 


In [None]:
# Root of the cloned covid-chestxray-dataset; image paths in the lists are relative to it.
COVID_DATA_PATH='./CovidAID/covid-chestxray-dataset'
# Metadata table that is filtered below to select the images.
METADATA_CSV = os.path.join(COVID_DATA_PATH, 'metadata.csv')
# Output list files, one per split.
TRAIN_FILE = './CovidAID/data/covid19/train_list.txt'
VAL_FILE = './CovidAID/data/covid19/val_list.txt'
TEST_FILE = './CovidAID/data/covid19/test_list.txt'


In [None]:
# Read the metadata table and keep the COVID-19 X-ray rows taken in an accepted view.
covids = {}
df = pd.read_csv(METADATA_CSV)
is_covid = df['finding'] == 'COVID-19'
is_xray = df['modality'] == 'X-ray'
view_ok = df['view'].isin(['PA', 'AP', 'AP Supine'])
df = df[is_covid & is_xray & view_ok]
print(df)

In [None]:
# Count the X-rays of every patient and order the patients by ascending image count.
patient_ids = Counter(df['patientid'].tolist())
covids = dict(sorted(patient_ids.items(), key=lambda pair: pair[1]))
total_data = sum(covids.values())
print ("Patient-#X-Rays statistics:")
print (covids)
print ("Total Images:", total_data, '\n')

In [None]:
# Patient ids held out for the test and validation splits.
test_patients = {4, 6, 15, 18, 32, 36, 50, 59, 65, 70, 76, 78, 80, 82, 86, 115, 116, 138, 152}
val_patients = {11, 24, 43, 48, 51, 73, 112}


In [None]:
# Summarise how many patients and images land in each split.
held_out = test_patients | val_patients
print ('#Train patients:', len(set(covids) - held_out))
print ('#Test patients:', len(test_patients))
print ('#Val patients:', len(val_patients))
print ()
print ('#Train data points:', sum(count for pid, count in covids.items() if int(pid) not in held_out))
print ('#Test data points:', sum(count for pid, count in covids.items() if int(pid) in test_patients))
print ('#Val data points:', sum(count for pid, count in covids.items() if int(pid) in val_patients))

In [None]:
# Construct the split lists
train_list = []
test_list = []
val_list = []

In [None]:
# Route every image to the list of the split its patient belongs to.
for _, record in df.iterrows():
    raw_id = record['patientid']
    image_path = os.path.join(record['folder'], record['filename'])
    pid = int(raw_id)

    if pid in test_patients:
        bucket = test_list
    elif pid in val_patients:
        bucket = val_list
    else:
        bucket = train_list
    bucket.append(image_path)

print (len(train_list), len(test_list), len(val_list))

In [None]:
# Write image list in file
def make_img_list(data_file, img_file_list, data_root=None):
    """Write the image paths that exist on disk to ``data_file``, one per line.

    Args:
        data_file: Path of the list file to create (overwritten if present).
        img_file_list: Image paths relative to the dataset root.
        data_root: Directory the relative paths are resolved against.
            Defaults to the module-level ``COVID_DATA_PATH``.

    Paths that do not point to an existing file are reported on stdout and
    left out of the list.
    """
    root = COVID_DATA_PATH if data_root is None else data_root
    with open(data_file, 'w') as f:
        for imgfile in img_file_list:
            # Explicit check instead of assert + bare except: asserts vanish
            # under ``python -O`` and a bare except would also relabel
            # unrelated failures (e.g. write errors) as "NOT FOUND".
            if os.path.isfile(os.path.join(root, imgfile)):
                f.write("%s\n" % imgfile)
            else:
                print ("Image %s NOT FOUND" % imgfile)

In [None]:
make_img_list(TRAIN_FILE, train_list)
make_img_list(VAL_FILE, val_list)
make_img_list(TEST_FILE, test_list)
TRAIN_FILE, train_list

### Prepare combined dataset


In [None]:
"""
Script to prepare combined dataset
Class 0: Normal
Class 1: Bacterial Pneumonia
Class 2: Viral Pneumonia
Class 3: COVID-19
"""
import glob
import os
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--combine_pneumonia", action='store_true', default=False)
# parse_known_args tolerates arguments this script does not define. A notebook
# kernel typically injects its own command-line arguments, which would make
# parse_args() abort the cell.
args, _unknown_args = parser.parse_known_args()

# Locations of the split lists, the COVID image repository, the pneumonia
# dataset and the directory that receives the combined lists.
COVID19_DATA_PATH = "./CovidAID/data/covid19"
COVID19_IMGS_PATH = "./CovidAID/covid-chestxray-dataset"
PNEUMONIDA_DATA_PATH = "./CovidAID/chest-xray-pneumonia"
DATA_PATH = "./CovidAID/data"

# Report (without aborting) every data directory that is missing
for d in [COVID19_DATA_PATH, COVID19_IMGS_PATH, PNEUMONIDA_DATA_PATH, DATA_PATH]:
    if not os.path.isdir(d):
        print ("Directory %s does not exists" % d)

def create_list(split):
    """Write ``<DATA_PATH>/<split>.txt`` with one "<image path> <class id>" line per image.

    Images come from the NORMAL and PNEUMONIA folders of the pneumonia
    dataset and from the COVID list file of the same split. With
    ``args.combine_pneumonia`` set, every pneumonia image is class 1 and
    COVID images are class 2. Otherwise pneumonia images whose path contains
    'bacteria' are class 1, the remaining ones class 2, and COVID images
    class 3.
    """
    assert split in ['train', 'test', 'val']

    pneumonia_root = os.path.join(PNEUMONIDA_DATA_PATH, split)

    # Kaggle pneumonia dataset: normal scans are always class 0
    entries = [(path, 0) for path in glob.glob(os.path.join(pneumonia_root, 'NORMAL', '*'))]

    for path in glob.glob(os.path.join(pneumonia_root, 'PNEUMONIA', '*')):
        if args.combine_pneumonia or 'bacteria' in path:
            label = 1
        else:
            label = 2
        entries.append((path, label))

    # COVID dataset: paths in the list file are relative to the image repository
    covid_file = os.path.join(COVID19_DATA_PATH, '%s_list.txt'%split)
    with open(covid_file, 'r') as cf:
        for line in cf:
            covid_path = os.path.join(COVID19_IMGS_PATH, line.strip())
            label = 2 if args.combine_pneumonia else 3
            entries.append((covid_path, label))

    with open(os.path.join(DATA_PATH, '%s.txt'%split), 'w') as out:
        out.writelines("%s %d\n" % entry for entry in entries)

for split in ['train', 'test', 'val']:
    create_list(split)

### Transfer pretrained CheXNet weights to CovidAID

In [None]:
"""
Code to transfer weights from CheXNet (torch 0.3) to CovidAID
"""

import sys
from covidaid import CovidAID, CheXNet
import torch
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--combine_pneumonia", action='store_true', default=False)
parser.add_argument("--chexnet_model_checkpoint", "--old", type=str, default="./CovidAID/data/CheXNet_model.pth.tar")
parser.add_argument("--covidaid_model_trained_checkpoint", "--new", type=str, default="./CovidAID/models/CovidAID_transfered.pth.tar")
# parse_known_args tolerates arguments this script does not define. A notebook
# kernel typically injects its own command-line arguments, which would make
# parse_args() abort the cell.
args, _unknown_args = parser.parse_known_args()

# Source (CheXNet) checkpoint to read and destination (CovidAID) checkpoint to write
chexnet_model_checkpoint = args.chexnet_model_checkpoint
covidaid_model_trained_checkpoint = args.covidaid_model_trained_checkpoint

model = CovidAID(combine_pneumonia=args.combine_pneumonia)

def load_weights(checkpoint_pth, state_dict=True):
    """Load a checkpoint with ``torch.load``.

    Returns the ``'state_dict'`` entry of the loaded object when
    ``state_dict`` is true, otherwise the loaded object itself.
    """
    checkpoint = torch.load(checkpoint_pth)
    return checkpoint['state_dict'] if state_dict else checkpoint

def get_top_keys(model, depth=0):
    """Return the distinct components found at position ``depth`` of the dot-separated keys of ``model``."""
    components = set()
    for name in model.keys():
        components.add(name.split('.')[depth])
    return components

# Source weights (CheXNet checkpoint) and destination layout (fresh CovidAID state dict)
chexnet_model = load_weights(chexnet_model_checkpoint)
template = model.state_dict()

# Sanity check: both key sets split into 'features' and 'classifier'. The CheXNet
# keys have that component one level deeper than the template keys.
assert get_top_keys(chexnet_model, depth=2) == set({'features', 'classifier'})
assert get_top_keys(template, depth=1) == set({'features', 'classifier'})

# print (chexnet_model.keys())
# print (template.keys())
# print (model.state_dict().keys())

# CheXNet keys are expected to equal the template keys prefixed with 'module.'
c_keys = {k for k in chexnet_model.keys()}
t_keys = {'module.' + k for k in template.keys()}

# The two key sets must match exactly in both directions
assert len(c_keys.difference(t_keys)) == 0
assert len(t_keys.difference(c_keys)) == 0


# Transfer the feature weights
for k, w in template.items():
    chex_key = 'module.' + k

    if k.split('.')[1] == 'classifier':
        # Classifier entries keep the values of the freshly built model
        # Uncomment below to copy trained weights of pneumonia
        # 6th class is pneumonia in CheXNet => Copy it's weights to pneumonia classes
        # for c in [1, 2, 3]:
        #     template[k][c, ...] = chexnet_model[chex_key][6, ...]
        print ('doing nothing for', k)
    else:
        # print (type(template[k]), template[k].size())
        # print (type(chexnet_model[chex_key]), chexnet_model[chex_key].size())
        # Every other entry is replaced by the CheXNet tensor of identical size
        assert chexnet_model[chex_key].size() == template[k].size()
        template[k] = chexnet_model[chex_key]

# Persist the merged state dict as the starting checkpoint for CovidAID
torch.save(template, covidaid_model_trained_checkpoint)


### Train the classifier layer

In [None]:
"""
Trainer for training and testing the networks
"""

import os
import numpy as np
import datetime
import json
import argparse
import glob
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch import optim
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from read_data import ChestXrayDataSet, ChestXrayDataSetTest
from sklearn.metrics import roc_auc_score, confusion_matrix, roc_curve, auc, f1_score
import seaborn as sn
import pandas as pd
from scipy import interp
from itertools import cycle
import matplotlib
import matplotlib.pyplot as plt
from covidaid import CovidAID
from tqdm import tqdm

class Trainer:
    def __init__ (self, local_rank=None, checkpoint=None, combine_pneumonia=False):
        """
        Build the CovidAID network on the GPU and optionally load a checkpoint.

        Args:
            local_rank: when not None, distributed mode is requested, which
                currently raises NotImplementedError.
            checkpoint: optional checkpoint path passed to load_model().
            combine_pneumonia: forwarded to CovidAID and stored on the trainer.

        Raises:
            NotImplementedError: if local_rank is given.
        """
        self.distributed = True if local_rank is not None else False
        print ("Distributed training %s" % ('ON' if self.distributed else 'OFF'))
        if self.distributed:
            raise NotImplementedError("Currently distributed training not supported")
            # Unreachable: the raise above always fires first
            self.device = torch.cuda.device('cuda', local_rank)
        else:
            # NOTE(review): self.device is only referenced from commented-out code in this class
            self.device = torch.cuda.device('cuda')

        # Using 2 classes for pneumonia vs 1 class
        self.combine_pneumonia = combine_pneumonia

        # self.net = CovidAID().to(self.device)
        self.net = CovidAID(combine_pneumonia).cuda()
        if self.distributed:
            # Not reachable at present, since distributed mode raises above
            self.net = torch.nn.parallel.DistributedDataParallel(self.net,
                                                            device_ids=[local_rank],
                                                            output_device=local_rank)
        
        # load model
        if checkpoint is not None:
            self.load_model(checkpoint)

    def train(self, TRAIN_IMAGE_LIST, VAL_IMAGE_LIST, NUM_EPOCHS=10, LR=0.001, BATCH_SIZE=64,
                start_epoch=0, logging=True, save_path=None, freeze_feature_layers=True):
        """
        Train the CovidAID network, running the validation set after every epoch.

        Args:
            TRAIN_IMAGE_LIST: image list file given to ChestXrayDataSet for training.
            VAL_IMAGE_LIST: image list file given to ChestXrayDataSet for validation.
            NUM_EPOCHS: exclusive upper bound of the epoch range.
            LR: learning rate handed to the Adam optimizer.
            BATCH_SIZE: number of images per batch for both loaders.
            start_epoch: first epoch index; epochs run over range(start_epoch, NUM_EPOCHS).
            logging: when true, print the per-epoch JSON log line.
            save_path: directory receiving 'train.log' and 'epoch_<n>.pth'; it is
                created if missing. When None, nothing is written to disk.
            freeze_feature_layers: when true, the parameters of
                self.net.densenet121.features are excluded from optimization.
        """
        normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])

        # Train and validation images share the same preprocessing: resize,
        # ten 224x224 crops, tensor conversion and normalization per crop.
        ten_crop_transform = transforms.Compose([
            transforms.Resize(256),
            transforms.TenCrop(224),
            transforms.Lambda
            (lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
            transforms.Lambda
            (lambda crops: torch.stack([normalize(crop) for crop in crops]))
        ])

        train_dataset = ChestXrayDataSet(image_list_file=TRAIN_IMAGE_LIST,
                                        transform=ten_crop_transform,
                                        combine_pneumonia=self.combine_pneumonia)
        if self.distributed:
            sampler = DistributedSampler(train_dataset)
            train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE,
                                    shuffle=False, num_workers=8, pin_memory=True,
                                    sampler=sampler)
        else:
            train_loader = DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE,
                                    shuffle=True, num_workers=8, pin_memory=True)

        val_dataset = ChestXrayDataSet(image_list_file=VAL_IMAGE_LIST,
                                        transform=ten_crop_transform,
                                        combine_pneumonia=self.combine_pneumonia)
        if self.distributed:
            sampler = DistributedSampler(val_dataset)
            val_loader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE,
                                    shuffle=False, num_workers=8, pin_memory=True,
                                    sampler=sampler)
        else:
            val_loader = DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE,
                                    shuffle=True, num_workers=8, pin_memory=True)

        # Freeze heads and create optimizer
        if freeze_feature_layers:
            print ("Freezing feature layers")
            for param in self.net.densenet121.features.parameters():
                param.requires_grad = False

        # Only parameters that still require gradients are optimized
        optimizer = optim.Adam(filter(lambda p: p.requires_grad, self.net.parameters()), lr=LR)

        # Make sure the output directory exists before the first epoch is spent
        if save_path:
            os.makedirs(save_path, exist_ok=True)

        for epoch in range(start_epoch, NUM_EPOCHS):
            # switch to train mode
            self.net.train()
            tot_loss = 0.0
            for i, (inputs, target) in tqdm(enumerate(train_loader), total=len(train_loader)):
                inputs = inputs.cuda()
                target = target.cuda()

                # Shape of input == [BATCH_SIZE, NUM_CROPS=10, CHANNELS=3, HEIGHT=224, WIDTH=224]
                bs, n_crops, c, h, w = inputs.size()
                # Fold the crops into the batch dimension for the forward pass,
                # then average the per-crop predictions of every image
                inputs = torch.autograd.Variable(inputs.view(-1, c, h, w))
                target = torch.autograd.Variable(target)
                preds = self.net(inputs).view(bs, n_crops, -1).mean(dim=1)

                # The loss function is supplied by the dataset object
                loss = train_dataset.loss(preds, target)
                tot_loss += float(loss.data)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # NOTE(review): divides the summed batch losses by the number of
            # samples; whether that is a true mean depends on dataset.loss — confirm
            tot_loss /= len(train_dataset)

            # Clear cache
            torch.cuda.empty_cache()

            # Running on validation set
            self.net.eval()
            val_loss = 0.0
            for i, (inputs, target) in tqdm(enumerate(val_loader), total=len(val_loader)):
                inputs = inputs.cuda()
                target = target.cuda()

                # Shape of input == [BATCH_SIZE, NUM_CROPS=10, CHANNELS=3, HEIGHT=224, WIDTH=224]
                bs, n_crops, c, h, w = inputs.size()
                # volatile=True: no gradient bookkeeping during validation (torch 0.3 API)
                inputs = torch.autograd.Variable(inputs.view(-1, c, h, w), volatile=True)
                target = torch.autograd.Variable(target, volatile=True)

                preds = self.net(inputs).view(bs, n_crops, -1).mean(1)
                loss = val_dataset.loss(preds, target)

                val_loss += float(loss.data)

            val_loss /= len(val_dataset)

            # Clear cache
            torch.cuda.empty_cache()

            # logging statistics
            timestamp = str(datetime.datetime.now()).split('.')[0]
            log = json.dumps({
                'timestamp': timestamp,
                'epoch': epoch+1,
                'train_loss': float('%.5f' % tot_loss),
                'val_loss': float('%.5f' % val_loss),
                'lr': float('%.6f' % LR)
            })
            if logging:
                print (log)

            # Persist the log line and the checkpoint only when a target
            # directory was given; save_path=None is the documented default
            if save_path is not None:
                log_file = os.path.join(save_path, 'train.log')
                with open(log_file, 'a') as f:
                    f.write("{}\n".format(log))

                model_path = os.path.join(save_path, 'epoch_%d.pth'%(epoch+1))
                self.save_model(model_path)

        print ('Finished Training')

    def predict(self, TEST_IMAGE_LIST, BATCH_SIZE=64):
        """
        Predict the task labels corresponding to the input images

        Args:
            TEST_IMAGE_LIST: image list file given to ChestXrayDataSetTest.
            BATCH_SIZE: number of images per batch.

        Returns:
            (gt, pred): numpy arrays holding the concatenated targets and the
            crop-averaged network outputs, in the order the loader produced them.
        """
        normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])

        # Same preprocessing as in training: resize, ten 224x224 crops,
        # tensor conversion and normalization per crop
        test_dataset = ChestXrayDataSetTest(image_list_file=TEST_IMAGE_LIST,
                                        transform=transforms.Compose([
                                            transforms.Resize(256),
                                            transforms.TenCrop(224),
                                            transforms.Lambda
                                            (lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
                                            transforms.Lambda
                                            (lambda crops: torch.stack([normalize(crop) for crop in crops]))
                                        ]),
                                        combine_pneumonia=self.combine_pneumonia)
        if self.distributed:
            sampler = DistributedSampler(test_dataset)
            test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE,
                                shuffle=False, num_workers=8, pin_memory=True,
                                sampler=sampler)
        else:
            # Shuffled, but gt and pred are collected batch by batch so they stay aligned
            test_loader = DataLoader(dataset=test_dataset, batch_size=BATCH_SIZE,
                                shuffle=True, num_workers=8, pin_memory=True)

        # initialize the ground truth and output tensor
        gt = torch.FloatTensor().cuda()
        pred = torch.FloatTensor().cuda()

        # switch to evaluate mode
        self.net.eval()

        for i, (inputs, target) in tqdm(enumerate(test_loader), total=len(test_loader)):
            # inputs = inputs.to(self.device)
            # target = target.to(self.device)
            inputs = inputs.cuda()
            target = target.cuda()
            gt = torch.cat((gt, target), 0)

            # Shape of input == [BATCH_SIZE, NUM_CROPS=10, CHANNELS=3, HEIGHT=224, WIDTH=224]
            bs, n_crops, c, h, w = inputs.size()
            # volatile=True: no gradient bookkeeping during inference (torch 0.3 API)
            inputs = torch.autograd.Variable(inputs.view(-1, c, h, w), volatile=True)

            # Pass through the network and take average prediction from all the crops
            output = self.net(inputs)
            output_mean = output.view(bs, n_crops, -1).mean(1)
            pred = torch.cat((pred, output_mean.data), 0)
        gt = gt.cpu().numpy()
        pred = pred.cpu().numpy()

        return gt, pred

    def evaluate(self, TEST_IMAGE_LIST, BATCH_SIZE=64, cm_path='cm', roc_path='roc'):
        """
        Evaluate on the test set plotting confusion matrix and roc curves
        """
        gt, pred = self.predict(TEST_IMAGE_LIST, BATCH_SIZE=64)
        print (pred)

        # Compute ROC scores
        labels = ['Normal', 'Bacterial', 'Viral', 'COVID-19']
        if self.combine_pneumonia:
            labels = ['Normal', 'Pneumonia', 'COVID-19']
        self.compute_AUC_scores(gt, pred, labels)

        # Plot ROC scores
        self.plot_ROC_curve(gt, pred, labels, roc_path)

        # Treat the max. output as prediction. 
        # Plot Confusion Matrix
        gt = gt.argmax(axis=1)
        pred = pred.argmax(axis=1)
        self.plot_confusion_matrix(gt, pred, labels, cm_path)

    def F1(self, TEST_DIR, out_file, BATCH_SIZE=64):
        """
        Evaluate on multiple test sets and compute F1 scores
        """
        f = open(out_file, 'w')
        for test_file in glob.glob(os.path.join(TEST_DIR, '*.txt')):
            print (test_file)
            gt, pred = self.predict(test_file, BATCH_SIZE)
            # Treat the max. output as prediction. 
            gt = gt.argmax(axis=1)
            pred = pred.argmax(axis=1)
            f1 = f1_score(gt, pred, average='macro')
            f.write('%s %.6f\n' % (test_file, f1))
        f.close()

    def plot_confusion_matrix(self, y_true, y_pred, labels, cm_path):
        """
        Save two confusion-matrix heatmaps and print the raw matrix and accuracy.

        Writes '<cm_path>_norm.png' (normalized values only) and '<cm_path>.png'
        (raw counts with the normalized value in parentheses).

        Args:
            y_true: ground-truth class indices.
            y_pred: predicted class indices.
            labels: class names used as row and column labels.
            cm_path: path prefix of the two output images.
        """
        # normalize='true' is passed to sklearn so each row is scaled by its true-label count
        norm_cm = confusion_matrix(y_true, y_pred, normalize='true')
        norm_df_cm = pd.DataFrame(norm_cm, index=labels, columns=labels)
        plt.figure(figsize = (10,7))
        sn.heatmap(norm_df_cm, annot=True, fmt='.2f', square=True, cmap=plt.cm.Blues)
        plt.xlabel("Predicted")
        plt.ylabel("Ground Truth")
        # NOTE(review): rcParams are updated after drawing, so this presumably only
        # affects text created later — confirm the intended font size
        matplotlib.rcParams.update({'font.size': 14})
        plt.savefig('%s_norm.png' % cm_path, pad_inches = 0, bbox_inches='tight')
        
        cm = confusion_matrix(y_true, y_pred)
        # Finding the annotations
        cm = cm.tolist()
        norm_cm = norm_cm.tolist()
        # One "count (fraction)" string per cell
        annot = [
            [("%d (%.2f)" % (c, nc)) for c, nc in zip(r, nr)]
            for r, nr in zip(cm, norm_cm)
        ]
        # Second figure: colours from the normalized matrix, text from the annotations
        plt.figure(figsize = (10,7))
        sn.heatmap(norm_df_cm, annot=annot, fmt='', cbar=False, square=True, cmap=plt.cm.Blues)
        plt.xlabel("Predicted")
        plt.ylabel("Ground Truth")
        matplotlib.rcParams.update({'font.size': 14})
        plt.savefig('%s.png' % cm_path, pad_inches = 0, bbox_inches='tight')
        print (cm)

        # Fraction of samples whose prediction equals the ground truth
        accuracy = np.sum(y_true == y_pred) / len(y_true)
        print ("Accuracy: %.5f" % accuracy)

    def compute_AUC_scores(self, y_true, y_pred, labels):
        """
        Print the average AUROC followed by the AUROC of every class.

        y_true.shape = [n_samples, n_classes]
        y_pred.shape = [n_samples, n_classes]
        labels.shape = [n_classes]
        """
        average_auroc = roc_auc_score(y_true, y_pred)
        print('The average AUROC is {AUROC_avg:.4f}'.format(AUROC_avg=average_auroc))

        # Walk the class columns of both matrices together with their names
        per_class = zip(y_true.transpose(), y_pred.transpose(), labels)
        for class_truth, class_scores, class_name in per_class:
            class_auroc = roc_auc_score(class_truth, class_scores)
            print('The AUROC of {0:} is {1:.4f}'.format(class_name, class_auroc))

    def plot_ROC_curve(self, y_true, y_pred, labels, roc_path): 
        """
        Plots the ROC curve from prediction scores

        Draws one curve per class plus micro- and macro-averaged curves and
        saves the figure to '<roc_path>.png'.

        y_true.shape = [n_samples, n_classes]
        y_pred.shape = [n_samples, n_classes]
        labels.shape = [n_classes]
        """
        n_classes = len(labels)
        # Compute ROC curve and ROC area for each class
        fpr = dict()
        tpr = dict()
        roc_auc = dict()
        for y, pred, label in zip(y_true.transpose(), y_pred.transpose(), labels):
            fpr[label], tpr[label], _ = roc_curve(y, pred)
            roc_auc[label] = auc(fpr[label], tpr[label])

        # First aggregate all false positive rates
        all_fpr = np.unique(np.concatenate([fpr[label] for label in labels]))

        # Then interpolate all ROC curves at this points
        mean_tpr = np.zeros_like(all_fpr)
        for label in labels:
            mean_tpr += interp(all_fpr, fpr[label], tpr[label])

        # Finally average it and compute AUC
        mean_tpr /= n_classes

        # Compute micro-average ROC curve and ROC area
        # (all class columns flattened into a single binary problem)
        fpr["micro"], tpr["micro"], _ = roc_curve(y_true.ravel(), y_pred.ravel())
        roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

        # Macro average: mean of the per-class curves on the shared FPR grid
        fpr["macro"] = all_fpr
        tpr["macro"] = mean_tpr
        roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

        # Plot all ROC curves
        plt.figure()
        lw = 2
        plt.plot(fpr["micro"], tpr["micro"],
                label='micro-average ROC curve (area = {0:0.3f})'
                    ''.format(roc_auc["micro"]),
                color='deeppink', linestyle=':', linewidth=2)

        plt.plot(fpr["macro"], tpr["macro"],
                label='macro-average ROC curve (area = {0:0.3f})'
                    ''.format(roc_auc["macro"]),
                color='navy', linestyle=':', linewidth=2)

        # One colour per class; the palette depends on the number of labels
        if len(labels) == 4:
            colors = ['green', 'cornflowerblue', 'darkorange', 'darkred']
        else:
            colors = ['green', 'cornflowerblue', 'darkred']
        for label, color in zip(labels, cycle(colors)):
            plt.plot(fpr[label], tpr[label], color=color, lw=lw,
                    label='ROC curve of {0} (area = {1:0.3f})'
                    ''.format(label, roc_auc[label]))

        # Diagonal reference line from (0, 0) to (1, 1)
        plt.plot([0, 1], [0, 1], 'k--', lw=lw)
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('ROC curve')
        plt.legend(loc="lower right")
        matplotlib.rcParams.update({'font.size': 14})
        plt.savefig('%s.png' % roc_path, pad_inches = 0, bbox_inches='tight')

    def save_model(self, checkpoint_path, model=None):
        if model is None: model = self.net
        torch.save(model.state_dict(), checkpoint_path)
    
    def load_model(self, checkpoint_path, model=None):
        print ("Loading model from %s" % checkpoint_path)
        if model is None: model = self.net
        model.load_state_dict(torch.load(checkpoint_path))


# Command-line entry point: train, test or compute F1 scores depending on --mode
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int) # For distributed processing
    parser.add_argument("--mode", choices=['train', 'test', 'f1'], required=True)
    parser.add_argument("--checkpoint", type=str, required=True)
    parser.add_argument("--combine_pneumonia", action='store_true', default=False)
    parser.add_argument("--save", type=str)
    parser.add_argument("--start", type=int, default=0)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--freeze", action='store_true', default=False)
    parser.add_argument("--bs", type=int, default=8)
    parser.add_argument("--cm_path", type=str, default='plots/cm')
    parser.add_argument("--roc_path", type=str, default='plots/roc')

    # parser.add_argment("--torch_version", "--tv", choices=["0.3", "new"], default="0.3")
    args = parser.parse_args()

    # Data locations are relative to the current working directory
    TRAIN_IMAGE_LIST = './data/train.txt'
    VAL_IMAGE_LIST = './data/val.txt'
    TEST_IMAGE_LIST = './data/test.txt'
    TEST_DIR = './data/samples'

    # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
    if args.local_rank is not None:
        torch.distributed.init_process_group(backend='nccl')

    trainer = Trainer(local_rank=args.local_rank, checkpoint=args.checkpoint, combine_pneumonia=args.combine_pneumonia)
    if args.mode == 'test':
        # NOTE(review): --bs is not forwarded here, so evaluate() uses its own default batch size
        trainer.evaluate(TEST_IMAGE_LIST, cm_path=args.cm_path, roc_path=args.roc_path)
    elif args.mode == 'train':
        # --save is mandatory for training: it is the directory for logs and checkpoints
        assert args.save is not None
        # The number of epochs is fixed at 300 and not configurable from the command line
        trainer.train(TRAIN_IMAGE_LIST, VAL_IMAGE_LIST, BATCH_SIZE=args.bs, NUM_EPOCHS=300, LR=args.lr,
                        start_epoch=args.start, save_path=args.save, freeze_feature_layers=args.freeze)
    else:
        # Remaining choice is 'f1'; the report path is hard-coded
        trainer.F1(TEST_DIR, 'models/samples.txt')

# Run command for distributed
# python -m torch.distributed.launch --nproc_per_node=2 --nnodes=2 --node_rank=0 --master_addr="192.168.1.1" --master_port=1234 OUR_TRAINING_SCRIPT.py (--arg1 --arg2 --arg3 and all other arguments of our training script)
# python -m torch.distributed.launch --nproc_per_node=2 trainer.py --mode train

### Inference

In [None]:
"""
Script to process set of images and output predictions
"""
from RISE.visualize import visualize
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from PIL import Image
import os
import glob
import argparse
from covidaid import CovidAID
from tqdm import tqdm
import termtables as tt



class CovidDataLoader(Dataset):
    """
    Dataset over every entry of an image directory, yielding (image, file name).
    """
    def __init__(self, image_dir, transform=None):
        """
        Args:
            image_dir: path to image directory.
            transform: optional transform to be applied on a sample.
        """
        # Sorted so indices and the printed table are stable across runs
        # (glob returns entries in arbitrary order)
        self.image_names = sorted(glob.glob(os.path.join(image_dir, '*')))
        self.transform = transform

    def __getitem__(self, index):
        """
        Args:
            index: the index of item
        Returns:
            image (converted to RGB, transformed if a transform was given)
            and the file name of the image without its directory
        """
        image_name = self.image_names[index]
        image = Image.open(image_name).convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        # basename handles the platform's path separator, unlike split('/')
        return image, os.path.basename(image_name)

    def __len__(self):
        """Return the number of entries found in the image directory."""
        return len(self.image_names)



# Command-line entry point: score every image of a directory and print a table
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument("--img_dir", type=str, required=True)
    parser.add_argument("--checkpoint", type=str, required=True)
    parser.add_argument("--combine_pneumonia", action='store_true', default=False)
    parser.add_argument("--visualize_dir", type=str, default=None)
    args = parser.parse_args()

    # Load the model
    model = CovidAID(args.combine_pneumonia).cuda()
    model.load_state_dict(torch.load(args.checkpoint))

    # Load the data
    normalize = transforms.Normalize([0.485, 0.456, 0.406],
                                     [0.229, 0.224, 0.225])
    
    # Preprocessing: resize, ten 224x224 crops, tensor conversion and
    # normalization per crop
    test_dataset = CovidDataLoader(image_dir=args.img_dir,
            transform=transforms.Compose([
                transforms.Resize(256),
                transforms.TenCrop(224),
                transforms.Lambda
                (lambda crops: torch.stack([transforms.ToTensor()(crop) for crop in crops])),
                transforms.Lambda
                (lambda crops: torch.stack([normalize(crop) for crop in crops]))
            ])
    )
    test_loader = DataLoader(dataset=test_dataset, batch_size=64,
                    shuffle=False, num_workers=8, pin_memory=True)

    # initialize the output tensor
    pred = torch.FloatTensor().cuda()
    pred_names = []

    # switch to evaluate mode
    model.eval()

    for i, (inputs, names) in tqdm(enumerate(test_loader), total=len(test_loader)):
        inputs = inputs.cuda()

        # Shape of input == [BATCH_SIZE, NUM_CROPS=10, CHANNELS=3, HEIGHT=224, WIDTH=224]
        bs, n_crops, c, h, w = inputs.size()
        # volatile=True: no gradient bookkeeping during inference (torch 0.3 API)
        inputs = torch.autograd.Variable(inputs.view(-1, c, h, w), volatile=True)

        # Pass through the network and take average prediction from all the crops
        output = model(inputs)
        output_mean = output.view(bs, n_crops, -1).mean(1)
        pred = torch.cat((pred, output_mean.data), 0)
        pred_names += names

    pred = pred.cpu().numpy()

    # Every prediction row must have a matching file name
    assert len(pred) == len(pred_names)

    # One table row per image: the file name followed by each score as a percentage
    scores = []
    for p, n in zip(pred, pred_names):
        p = ["%.1f %%" % (i * 100) for i in p]
        scores.append([n] + p)

    # Column headers follow the class layout selected by --combine_pneumonia
    header=['Name', 'Normal', 'Bacterial', 'Viral', 'COVID-19']
    alignment="c"*5
    if args.combine_pneumonia:
        header = ['Name', 'Normal', 'Pneumonia', 'COVID-19']
        alignment = "c"*4

    string = tt.to_string(
        scores,
        header=header,
        style=tt.styles.ascii_thin_double,
        padding=(0, 1),
        alignment=alignment
    )

    print (string)
    
    # RISE Visualization
    if args.visualize_dir:
        visualize(model,args.img_dir,args.visualize_dir,CovidDataLoader)
        print("Visualizations generated at "+str(args.visualize_dir))