In [1]:
import os
import re
import PIL
import sys
import json
import time
import timm
import math
import copy
import torch
import pickle
import logging
import fnmatch
import argparse
import itertools
import torchvision
import numpy as np
%matplotlib inline
import pandas as pd
import seaborn as sns
import albumentations
import torch.nn as nn
from PIL import Image
from pathlib import Path
from copy import deepcopy
import scikitplot as skplt
from sklearn import metrics
import torch.optim as optim
from datetime import datetime
from timm.data.loader import *
from torchvision import models
import matplotlib.pyplot as plt
import torch.nn.functional as F
import torch.utils.data as data
from cutmix.cutmix import CutMix
from torch.autograd import Variable
from tqdm import tqdm, tqdm_notebook
from torch.optim import lr_scheduler
#from pytorch_metric_learning import loss
import torch.utils.model_zoo as model_zoo
from timm.models.layers.activations import *
from timm.utils import accuracy, AverageMeter
%config InlineBackend.figure_format = 'retina'
from cutmix.utils import CutMixCrossEntropyLoss
from collections import OrderedDict, defaultdict
from torch.utils.tensorboard import SummaryWriter
from warmup_scheduler import GradualWarmupScheduler
from torchvision import transforms, models, datasets
from torch.utils.data.sampler import SubsetRandomSampler
from randaugment import RandAugment, ImageNetPolicy, Cutout
from timm.loss import LabelSmoothingCrossEntropy, SoftTargetCrossEntropy
from sklearn.metrics import classification_report, confusion_matrix,accuracy_score, roc_curve, auc, roc_auc_score
#from timm.data import Dataset, DatasetTar, RealLabelsImagenet, create_loader, Mixup, FastCollateMixup, AugMixDataset

In [2]:
torch.backends.cudnn.benchmark = True
def set_seed(seed=42):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True

In [3]:
#data_dir = '/home/linh/Downloads/TB/'
#data_dir = '/home/linh/Downloads/Covid-19/CXR_20201006/data_20201006/'
data_dir = '/home/linh/Downloads/TB_COVID-19/data'

device = torch.device("cuda:0")
batch_size = 64
# batch_size (48 or 50 for EfficientNet-B0, img_size=320, cuda=0 or cuda=1) 
# batch_size (66 or 68 for EfficientNet-B1, img_size=240, cuda=0 or cuda=1)
num_epochs = 450
lr = 0.01
beta = 1
step_size = 100
img_size = 240 #320 #240
test_size = int((256 / 224) * img_size)
mean = [0.485, 0.456, 0.406] 
std = [0.229, 0.224, 0.225]
num_workers = 4
# Define your transforms for the training and testing sets
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomRotation(30),
        transforms.RandomResizedCrop(img_size),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(0.3, 0.3, 0.3),
        RandAugment(),
        ImageNetPolicy(),
        Cutout(size=16),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
        transforms.RandomErasing()
    ]),
    'val': transforms.Compose([
        transforms.Resize(test_size),
        transforms.CenterCrop(img_size),
        transforms.ColorJitter(0.5, 0.5, 0.5),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ]),
    'test': transforms.Compose([
        transforms.Resize(test_size),
        transforms.CenterCrop(img_size),
        transforms.ColorJitter(0.5, 0.5, 0.5),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
}

# Load the datasets with ImageFolder
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x), data_transforms[x]) for x in ['train', 'val', 'test']}
class_names = image_datasets['train'].classes
num_classes = len(class_names)
data_loader = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=batch_size,
                                             shuffle=True, num_workers=num_workers, pin_memory = True)
              for x in ['train', 'val', 'test']}

dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val', 'test']}

print(class_names)
print(dataset_sizes)
print(device)

### we get the class_to_index in the data_Set but what we really need is the cat_to_names  so we will create
_ = image_datasets['val'].class_to_idx
cat_to_name = {_[i]: i for i in list(_.keys())}
print(cat_to_name)
    
# Run this to test the data loader
images, labels = next(iter(data_loader['val']))
images.size()

['COVID-19', 'NORMAL', 'NOT NORMAL & NO LUNG OPACITY', 'PNEUMONIA', 'TB']
{'train': 24722, 'val': 3236, 'test': 3238}
cuda:0
{0: 'COVID-19', 1: 'NORMAL', 2: 'NOT NORMAL & NO LUNG OPACITY', 3: 'PNEUMONIA', 4: 'TB'}


torch.Size([64, 3, 240, 240])

In [4]:
def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    """
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    if normalize:
        cm = cm.astype('float') / cm.suam(axis=1)[:, np.newaxis]

    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, cm[i, j],
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    
def plot_roc_curve(fpr, tpr):
    plt.plot(fpr, tpr, color='darkorange',
         lw=2, label='ROC curve (area = %0.4f)' % roc_auc)
    plt.plot([0, 1], [0, 1], color='darkblue', linestyle='--')
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver Operating Characteristic (ROC) Curve')
    plt.legend(loc="lower right")  
    plt.show()
    
def plt_roc(test_y, probas_y, plot_micro=False, plot_macro=False):
    assert isinstance(test_y, list) and isinstance(probas_y, list), 'the type of input must be list'
    
    skplt.metrics.plot_roc(test_y, probas_y, plot_micro=plot_micro,plot_macro=plot_macro, figsize=(10, 8))
    #plt.savefig(add_prefix(args.prefix, 'roc_auc_curve.png'))
    plt.show()
    #plt.close()

In [5]:
"""def showimage(data_loader, number_images, cat_to_name):
    dataiter = iter(data_loader)
    images, labels = dataiter.next()
    images = images.numpy() # convert images to numpy for display
    # plot the images in the batch, along with the corresponding labels
    fig = plt.figure(figsize=(number_images, 4))
    for idx in np.arange(number_images):
        ax = fig.add_subplot(4, number_images/2, idx+1, xticks=[], yticks=[])
        img = np.transpose(images[idx])
        plt.imshow(img)
        ax.set_title(cat_to_name[labels.tolist()[idx]])
        
#### to show some  images
showimage(data_loader['test'], 10, cat_to_name)"""

"def showimage(data_loader, number_images, cat_to_name):\n    dataiter = iter(data_loader)\n    images, labels = dataiter.next()\n    images = images.numpy() # convert images to numpy for display\n    # plot the images in the batch, along with the corresponding labels\n    fig = plt.figure(figsize=(number_images, 4))\n    for idx in np.arange(number_images):\n        ax = fig.add_subplot(4, number_images/2, idx+1, xticks=[], yticks=[])\n        img = np.transpose(images[idx])\n        plt.imshow(img)\n        ax.set_title(cat_to_name[labels.tolist()[idx]])\n        \n#### to show some  images\nshowimage(data_loader['test'], 10, cat_to_name)"

In [6]:
#model = timm.create_model('efficientnet_b0', pretrained=True, drop_rate=0.2)
#model = timm.create_model('tf_efficientnet_b0_ap', pretrained=True, drop_rate=0.2)
#model = timm.create_model('tf_efficientnet_b0_ns', pretrained=True, drop_rate=0.2)

model = timm.create_model('efficientnet_b1', pretrained=True, drop_rate=0.2)
#model = timm.create_model('tf_efficientnet_b1_ap', pretrained=True, drop_rate=0.2)
#model = timm.create_model('tf_efficientnet_b1_ns', pretrained=True, drop_rate=0.2)
#model.fc #show fully connected layer for ResNet family
model.classifier #show the classifier layer (fully connected layer) for EfficientNets

Linear(in_features=1280, out_features=1000, bias=True)

In [7]:
# Create classifier
for param in model.parameters():
    param.requires_grad = True
# define `classifier` for ResNet
# Otherwise, define `fc` for EfficientNet family 
#because the definition of the full connection/classifier of 2 CNN families is differnt
fc = nn.Sequential(OrderedDict([('fc1', nn.Linear(2048, 1000, bias=True)),
							     ('BN1', nn.BatchNorm2d(1000, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)),
								 ('dropout1', nn.Dropout(0.7)),
                                 ('fc2', nn.Linear(1000, 512)),
								 ('BN2', nn.BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)),
								 ('swish1', Swish()),
								 ('dropout2', nn.Dropout(0.5)),
								 ('fc3', nn.Linear(512, 128)),
								 ('BN3', nn.BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)),
							     ('swish2', Swish()),
								 ('fc4', nn.Linear(128, num_classes)),
								 ('output', nn.Softmax(dim=1))
							 ]))
# connect base model (EfficientNet_B0) with modified classifier layer
model.fc = fc
criterion = LabelSmoothingCrossEntropy()
#criterion = CutMixCrossEntropyLoss(True)
#criterion = nn.CrossEntropyLoss()
#optimizer = Nadam(model.parameters(), lr=0.001)
#optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001)
optimizer = optim.SGD(model.parameters(), lr=lr, momentum=0.9, nesterov=True, weight_decay=0.0001)
scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=0.1)
#lr = lambda x: (((1 + math.cos(x * math.pi / num_epochs)) / 2) ** 1) * 0.9
#scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr)
scheduler_warmup = GradualWarmupScheduler(optimizer, multiplier=1, total_epoch=5, after_scheduler=scheduler)
#show our model architechture and send to GPU
model.to(device)

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
count = count_parameters(model)
print("The number of parameters of the model is:", count)

The number of parameters of the model is: 10425285


In [8]:
#https://github.com/clovaai/CutMix-PyTorch
def rand_bbox(size, lam):
    W = size[2]
    H = size[3]
    cut_rat = np.sqrt(1. - lam)
    cut_w = np.int(W * cut_rat)
    cut_h = np.int(H * cut_rat)

    # uniform
    cx = np.random.randint(W)
    cy = np.random.randint(H)

    bbx1 = np.clip(cx - cut_w // 2, 0, W)
    bby1 = np.clip(cy - cut_h // 2, 0, H)
    bbx2 = np.clip(cx + cut_w // 2, 0, W)
    bby2 = np.clip(cy + cut_h // 2, 0, H)

    return bbx1, bby1, bbx2, bby2

In [9]:
def train_model(model, criterion, optimizer, scheduler, num_epochs=200, checkpoint = None):
    since = time.time()

    if checkpoint is None:
        best_model_wts = copy.deepcopy(model.state_dict())
        best_loss = math.inf
        best_acc = 0.
    else:
        print(f'Val loss: {checkpoint["best_val_loss"]}, Val accuracy: {checkpoint["best_val_accuracy"]}')
        model.load_state_dict(checkpoint['model_state_dict'])
        best_model_wts = copy.deepcopy(model.state_dict())
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        best_loss = checkpoint['best_val_loss']
        best_acc = checkpoint['best_val_accuracy']
   
    # Tensorboard summary
    writer = SummaryWriter()
    start_time_per_epoch = time.time()
    
    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch + 1, num_epochs)) #(epoch, num_epochs -1)
        print('-' * 20)

        # Each epoch has a training and validation phase
        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()  # Set model to training mode
            else:
                model.eval()   # Set model to evaluate mode

            running_loss = 0.0
            running_corrects = 0

            # Iterate over data.
            for i, (inputs, labels) in enumerate(data_loader[phase]):
                inputs = inputs.to(device)
                labels = labels.to(device)
                
                r = np.random.rand(1)
                if r < 0.5: #cutmix_prob=0.5
                # generate mixed sample
                    lam = np.random.beta(beta, beta)
                    rand_index = torch.randperm(inputs.size()[0]).to(device)
                    target_a = labels
                    target_b = labels[rand_index]
                    bbx1, bby1, bbx2, bby2 = rand_bbox(inputs.size(), lam)
                    inputs[:, :, bbx1:bbx2, bby1:bby2] = inputs[rand_index, :, bbx1:bbx2, bby1:bby2]
                # adjust lambda to exactly match pixel ratio
                    lam = 1 - ((bbx2 - bbx1) * (bby2 - bby1) / (inputs.size()[-1] * inputs.size()[-2]))
 
                # zero the parameter gradients
                optimizer.zero_grad()
                
                if i % 1000 == 999:
                    print('[%d, %d] loss: %.8f' % 
                          (epoch + 1, i, running_loss / (i * inputs.size(0))))

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    
                    if r < 0.5:
                        loss = criterion(outputs, target_a) * lam + criterion(outputs, target_b) * (1. - lam)
                    else:
                        loss = criterion(outputs, labels)
                    
                    #loss = criterion(outputs, labels)
                    _, preds = torch.max(outputs, 1)
                    
                    # backward + optimize only if in training phase
                    if phase == 'train':                
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)
            
            if phase == 'train':                
                scheduler_warmup.step()
                
            epoch_loss = running_loss / dataset_sizes[phase]
            epoch_acc = running_corrects.double() / dataset_sizes[phase]

            print('{} Loss: {:.8f} Acc: {:.8f}'.format(
                phase, epoch_loss, epoch_acc))
            
            # Record training loss and accuracy for each phase
            if phase == 'train':
                writer.add_scalar('Train/Loss', epoch_loss, epoch)
                writer.add_scalar('Train/Accuracy', epoch_acc, epoch)
                writer.flush()
            else:
                writer.add_scalar('Valid/Loss', epoch_loss, epoch)
                writer.add_scalar('Valid/Accuracy', epoch_acc, epoch)
                writer.flush()
            # deep copy the model
            
            if phase == 'val' and epoch_acc > best_acc:
                print(f'New best model found!')
                print(f'New record ACC: {epoch_acc}, previous record acc: {best_acc}')
                best_loss = epoch_loss
                best_acc = epoch_acc
                best_model_wts = copy.deepcopy(model.state_dict())
                torch.save({'model_state_dict': model.state_dict(),
                            'optimizer_state_dict': optimizer.state_dict(),
                            'best_val_loss': best_loss,
                            'best_val_accuracy': best_acc,
                            'scheduler_state_dict' : scheduler.state_dict(),
                            }, 
                            CHECK_POINT_PATH
                            )
                print(f'New record acc is SAVED: {epoch_acc}')
                
        end_time_per_epoch = (time.time() - start_time_per_epoch)
        print('Time for training the last epoch: {:.0f}m {:.0f}s'.format(
        end_time_per_epoch // 60, end_time_per_epoch % 60))
        
    time_elapsed = time.time() - since
    print('Total training time complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val Acc: {:.8f} Best val loss: {:.8f}'.format(best_acc, best_loss))

    # load best model weights
    model.load_state_dict(best_model_wts)
    return model, best_loss, best_acc

In [10]:
%load_ext tensorboard
%tensorboard --logdir runs/

Reusing TensorBoard on port 6006 (pid 4848), started 23:02:02 ago. (Use '!kill 4848' to kill it.)

In [None]:
%time
#CHECK_POINT_PATH = '/home/linh/Downloads/TB/weights/EfficientNet_B0_320.pth'
#CHECK_POINT_PATH = '/home/linh/Downloads/TB/weights/EfficientNet_B0_AP_320.pth'
#CHECK_POINT_PATH = '/home/linh/Downloads/TB/weights/EfficientNet_B0_NS_320.pth'

#CHECK_POINT_PATH = '/home/linh/Downloads/TB/weights/EfficientNet_B1_240.pth'
#CHECK_POINT_PATH = '/home/linh/Downloads/TB/weights/EfficientNet_B1_AP_240.pth'
#CHECK_POINT_PATH = '/home/linh/Downloads/TB/weights/EfficientNet_B1_NS_240.pth'

#CHECK_POINT_PATH = '/home/linh/Downloads/Covid-19_CXR/weights/EfficientNet_B0_320.pth'
#CHECK_POINT_PATH = '/home/linh/Downloads/Covid-19_CXR/weights/EfficientNet_B0_AP_320.pth'
#CHECK_POINT_PATH = '/home/linh/Downloads/Covid-19_CXR/weights/EfficientNet_B0_NS_320.pth'

#CHECK_POINT_PATH = '/home/linh/Downloads/TB_COVID-19/weights/EfficientNet_B0_320.pth'
#CHECK_POINT_PATH = '/home/linh/Linh/Downloads/TB_COVID-19/weights/EfficientNet_B0_AP_320.pth'
#CHECK_POINT_PATH = '/home/linh/Downloads/TB_COVID-19/weights/EfficientNet_B0_NS_320.pth'

CHECK_POINT_PATH = '/home/linh/Downloads/TB_COVID-19/weights/EfficientNet_B1_240.pth'
#CHECK_POINT_PATH = '/home/linh/Downloads/TB_COVID-19/weights/EfficientNet_B1_AP_240.pth'
#CHECK_POINT_PATH = '/home/linh/Downloads/TB_COVID-19/weights/EfficientNet_B1_NS_240.pth'

try:
    checkpoint = torch.load(CHECK_POINT_PATH)
    print("checkpoint loaded")
except:
    checkpoint = None
    print("checkpoint not found")
if checkpoint == None:
    CHECK_POINT_PATH = CHECK_POINT_PATH
model, best_val_loss, best_val_acc = train_model(model,
                                                 criterion,
                                                 optimizer,
                                                 scheduler,
                                                 num_epochs = num_epochs,
                                                 checkpoint = torch.load(CHECK_POINT_PATH)
                                                 ) 
                                                
torch.save({'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'best_val_loss': best_val_loss,
            'best_val_accuracy': best_val_acc,
            'scheduler_state_dict': scheduler.state_dict(),
            }, CHECK_POINT_PATH)

CPU times: user 2 µs, sys: 0 ns, total: 2 µs
Wall time: 5.96 µs
checkpoint not found
Epoch 1/450
--------------------
train Loss: 7.56732683 Acc: 0.00016180
val Loss: 7.42466513 Acc: 0.00000000
Time for training the last epoch: 3m 36s
Epoch 2/450
--------------------
train Loss: 2.40965650 Acc: 0.49603592
val Loss: 2.02245504 Acc: 0.58930779
New best model found!
New record ACC: 0.5893077873918418, previous record acc: 0.0
New record acc is SAVED: 0.5893077873918418
Time for training the last epoch: 7m 3s
Epoch 3/450
--------------------
train Loss: 1.99657417 Acc: 0.58733112
val Loss: 1.83619068 Acc: 0.67027194
New best model found!
New record ACC: 0.6702719406674906, previous record acc: 0.5893077873918418
New record acc is SAVED: 0.6702719406674906
Time for training the last epoch: 10m 30s
Epoch 4/450
--------------------
train Loss: 1.91002737 Acc: 0.62871127
val Loss: 1.79213964 Acc: 0.67398022
New best model found!
New record ACC: 0.6739802224969097, previous record acc: 0.670271

In [None]:
%time
def compute_validate_meter(model, val_loader): # best_model_path,
    
    since = time.time()

    #CHECK_POINT_PATH = '/home/linh/Downloads/TB/weights/EfficientNet_B0_320.pth'
    #CHECK_POINT_PATH = '/home/linh/Downloads/TB/weights/EfficientNet_B0_AP_320.pth'
    #CHECK_POINT_PATH = '/home/linh/Downloads/TB/weights/EfficientNet_B0_NS_320.pth'
    
    #CHECK_POINT_PATH = '/home/linh/Downloads/TB/weights/EfficientNet_B1_240.pth'
    #CHECK_POINT_PATH = '/home/linh/Downloads/TB/weights/EfficientNet_B1_AP_240.pth'
    #CHECK_POINT_PATH = '/home/linh/Downloads/TB/weights/EfficientNet_B1_NS_240.pth'


    #CHECK_POINT_PATH = '/home/linh/Downloads/Covid-19_CXR/weights/EfficientNet_B0_320.pth'
    #CHECK_POINT_PATH = '/home/linh/Downloads/Covid-19_CXR/weights/EfficientNet_B0_AP_320.pth'
    #CHECK_POINT_PATH = '/home/linh/Downloads/Covid-19_CXR/weights/EfficientNet_B0_NS_320.pth'

    #CHECK_POINT_PATH = '/home/linh/Downloads/TB_COVID-19/weights/EfficientNet_B0_320.pth'
    #CHECK_POINT_PATH = '/home/linh/Linh/Downloads/TB_COVID-19/weights/EfficientNet_B0_AP_320.pth'
    #CHECK_POINT_PATH = '/home/linh/Downloads/TB_COVID-19/weights/EfficientNet_B0_NS_320.pth'

    #CHECK_POINT_PATH = '/home/linh/Downloads/TB_COVID-19/weights/EfficientNet_B1_240.pth'
    CHECK_POINT_PATH = '/home/linh/Downloads/TB_COVID-19/weights/EfficientNet_B1_AP_240.pth'
    #CHECK_POINT_PATH = '/home/linh/Downloads/TB_COVID-19/weights/EfficientNet_B1_NS_240.pth'

    try:
        checkpoint = torch.load(CHECK_POINT_PATH)
        print("checkpoint loaded")
    except:
        checkpoint = None
        print("checkpoint not found")

    def load_model(best_model_path):                                
        model.load_state_dict(checkpoint['model_state_dict'])
        best_model_wts = copy.deepcopy(model.state_dict())
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        best_loss = checkpoint['best_val_loss']
        best_acc = checkpoint['best_val_accuracy']
    load_model(CHECK_POINT_PATH)
    model.to(device)
    model.eval()
    pred_y = list()
    test_y = list()
    probas_y = list()
    with torch.no_grad():
        for data, target in val_loader:
            data, target = data.to(device), target.to(device)
            data, target = Variable(data), Variable(target)
            output = model(data)
            probas_y.extend(output.data.cpu().numpy().tolist())
            pred_y.extend(output.data.cpu().max(1, keepdim=True)[1].numpy().flatten().tolist())
            test_y.extend(target.data.cpu().numpy().flatten().tolist())
        # compute the confusion matrix
        confusion = confusion_matrix(test_y, pred_y)
        # plot the confusion matrix
        #plot_labels = ['NORMAL', 'PNEUNOMIA','TUBERCULOSIS']
        plot_labels = ['COVID-19', 'NORMAL', 'NOT NORMAL NOT OPACITY', 'PNEUMONIA', 'TUBERCULOSIS']
        plot_confusion_matrix(confusion, plot_labels)
        #plot_confusion_matrix(confusion, classes=val_loader.dataset.classes,title='Confusion matrix')
        # print Recall, Precision, F1-score, Accuracy
        report = classification_report(test_y, pred_y, digits=4)
        print(report)
        plt_roc(test_y, probas_y)
        
    time_elapsed = time.time() - since

    print('Inference completes in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)
count = count_parameters(model)
print(count)

#best_model_path = '/home/linh/Downloads/TB/weights/EfficientNet_B1_240.pth'

compute_validate_meter(model, data_loader['test']) #best_model_path,

In [None]:
%time
#CHECK_POINT_PATH = '/media/linh/Linh/TB/weights/EfficientNet_B0_320.pth'
#CHECK_POINT_PATH = '/media/linh/Linh/TB/weights/EfficientNet_B0_AP_320.pth'
#CHECK_POINT_PATH = '/media/linh/Linh/TB/weights/EfficientNet_B0_NS_320.pth'


#CHECK_POINT_PATH = '/media/linh/Linh/Covid-19_20201007_CXR_CT/weights_CXR/EfficientNet_B0_320.pth'
#CHECK_POINT_PATH = '/media/linh/Linh/Covid-19_20201007_CXR_CT/weights_CXR/EfficientNet_B0_AP_320.pth'
#CHECK_POINT_PATH = '/media/linh/Linh/Covid-19_20201007_CXR_CT/EfficientNet_B0_NS_320.pth'

#CHECK_POINT_PATH = '/media/linh/Linh/TB_COVID-19/weights/EfficientNet_B0_320.pth'
CHECK_POINT_PATH = '/home/linh/Downloads/TB_COVID-19/weights/EfficientNet_B1_AP_240.pth'
#CHECK_POINT_PATH = '/media/linh/Linh/TB_COVID-19/weights/EfficientNet_B0_NS_320.pth'


try:
    checkpoint = torch.load(CHECK_POINT_PATH)
    print("checkpoint loaded")
except:
    checkpoint = None
    print("checkpoint not found")

def load_model(path):                                
    model.load_state_dict(checkpoint['model_state_dict'])
    best_model_wts = copy.deepcopy(model.state_dict())
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    best_loss = checkpoint['best_val_loss']
    best_acc = checkpoint['best_val_accuracy']
load_model(CHECK_POINT_PATH) 
model.to(device)
model.eval()
since = time.time()
y_true = []
y_predict = []
with torch.no_grad():
    for i, data in enumerate(data_loader['test']):
        images, labels = data
        N = images.size(0)
        images = Variable(images).to(device)
        outputs = model(images)
        prediction = outputs.max(1, keepdim=True)[1]
        y_true.extend(labels.cpu().numpy())
        y_predict.extend(np.squeeze(prediction.cpu().numpy().T))   
        
time_elapsed = time.time() - since

print('Inference time is {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

# compute the confusion matrix
confusion_mtx = confusion_matrix(y_true, y_predict)
# plot the confusion matrix
#plot_labels = ['NORMAL', 'TUBERCULOSIS']
plot_labels = ['NORMAL', 'PNEUNOMIA','TUBERCULOSIS']
#plot_labels = ['COVID-19', 'NORMAL', 'PNEUMONIA']
#plot_labels = ['COVID-19','NORMAL', 'PNEUMONIA', 'TUBERCULOSIS']
#plot_labels = ['COVID-19','NORMAL', 'NOT NORMAL & NO LUNG OPACITY', 'PNEUMONIA', 'TUBERCULOSIS']

plot_confusion_matrix(confusion_mtx, plot_labels)
report = classification_report(y_true, y_predict, digits=4)
print(report)

In [None]:
#https://github.com/arpanmangal/CovidAID/blob/master/tools/trainer.py
def plot_confusion_matrix(y_true, y_pred, class_names, cm_path):
    norm_cm = confusion_matrix(y_true, y_pred, normalize='true')
    norm_df_cm = pd.DataFrame(norm_cm, index=class_names, columns=class_names)
    plt.figure(figsize = (10,7))
    sns.heatmap(norm_df_cm, annot=True, fmt='.2f', square=True, cmap=plt.cm.Blues)
    plt.xlabel("Predicted")
    plt.ylabel("Ground Truth")
    plt.rcParams.update({'font.size': 14})
    plt.savefig('%s_norm.png' % cm_path, pad_inches = 0, bbox_inches='tight')
        
    cm = confusion_matrix(y_true, y_pred)
    # Finding the annotations
    cm = cm.tolist()
    norm_cm = norm_cm.tolist()
    annot = [[("%d (%.2f)" % (c, nc)) for c, nc in zip(r, nr)] for r, nr in zip(cm, norm_cm)]
    plt.figure(figsize = (10,7))
    sns.heatmap(norm_df_cm, annot=annot, fmt='', cbar=False, square=True, cmap=plt.cm.Blues)
    plt.xlabel("Predicted")
    plt.ylabel("Ground Truth")
    plt.rcParams.update({'font.size': 14})
    plt.savefig('%s.png' % cm_path, pad_inches = 0, bbox_inches='tight')
    print (cm)

    accuracy = np.sum(y_true == y_pred) / len(y_true)
    print ("Accuracy: %.5f" % accuracy)

def compute_AUC_scores(y_true, y_pred, class_names):
    AUROC_avg = roc_auc_score(y_true, y_pred)
    print('The average AUROC is {AUROC_avg:.4f}'.format(AUROC_avg=AUROC_avg))
    for y, pred, class_name in zip(y_true.transpose(), y_pred.transpose(), class_names):
        print('The AUROC of {0:} is {1:.4f}'.format(class_name, roc_auc_score(y, pred)))

def plot_ROC_curve(y_true, y_pred, class_names, roc_path): 
    n_classes = len(class_names)
    # Compute ROC curve and ROC area for each class
    fpr = dict()
    tpr = dict()
    roc_auc = dict()
    for y, pred, class_name in zip(y_true.transpose(), y_pred.transpose(), class_names):
        fpr[class_name], tpr[class_name], _ = roc_curve(y, pred)
        roc_auc[class_name] = auc(fpr[class_name], tpr[class_name])

    # First aggregate all false positive rates
    all_fpr = np.unique(np.concatenate([fpr[class_name] for class_name in class_names]))

    # Then interpolate all ROC curves at this points
    mean_tpr = np.zeros_like(all_fpr)
    for label in class_names:
        mean_tpr += interp(all_fpr, fpr[class_name], tpr[class_name])

    # Finally average it and compute AUC
    mean_tpr /= n_classes

    # Compute micro-average ROC curve and ROC area
    fpr["micro"], tpr["micro"], _ = roc_curve(y_true.ravel(), y_pred.ravel())
    roc_auc["micro"] = auc(fpr["micro"], tpr["micro"])

    fpr["macro"] = all_fpr
    tpr["macro"] = mean_tpr
    roc_auc["macro"] = auc(fpr["macro"], tpr["macro"])

    # Plot all ROC curves
    plt.figure()
    lw = 2
    plt.plot(fpr["micro"], tpr["micro"],
                label='micro-average ROC curve (area = {0:0.3f})'
                    ''.format(roc_auc["micro"]),
                color='deeppink', linestyle=':', linewidth=2)

    plt.plot(fpr["macro"], tpr["macro"],
                label='macro-average ROC curve (area = {0:0.3f})'
                    ''.format(roc_auc["macro"]),
                color='navy', linestyle=':', linewidth=2)
    if len(class_names) == 5:
        colors = ['green', 'cornflowerblue', 'darkorange', 'darkred', 'purple']
    elif len(class_names) == 4:
        colors = ['green', 'cornflowerblue', 'darkorange', 'darkred']
    elif len(class_names) == 3:
        colors = ['green', 'cornflowerblue', 'darkred']
    else:
        colors = ['green', 'cornflowerblue']
    for label, color in zip(class_names, cycle(colors)):
        plt.plot(fpr[label], tpr[label], color=color, lw=lw,
                label='ROC curve of {0} (area = {1:0.3f})'
                    ''.format(label, roc_auc[label]))

    plt.plot([0, 1], [0, 1], 'k--', lw=lw)
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('ROC curve')
    plt.legend(loc="lower right")
    plt.rcParams.update({'font.size': 14})
    plt.savefig('%s.png' % roc_path, pad_inches = 0, bbox_inches='tight')

In [None]:
cm_path=data_dir
plot_confusion_matrix(y_true, y_predict, class_names, cm_path)

In [None]:
roc_auc = roc_auc_score(y_true, y_predict, average='macro', sample_weight=None, max_fpr=None, multi_class='ovr', labels=None) #'ovo', 'ovr'
print('ROC curve (area = %0.4f)' % roc_auc)
fpr, tpr, thresholds = roc_curve(y_true, y_predict)
plot_roc_curve(fpr, tpr)

In [None]:
def get_preds(model,device=None,tta=3):
    device=device
    model.to(device)
    preds = np.zeros(len(data_loader['test']))
    for tta_id in range(tta):
        test_preds = []
        with torch.no_grad():
            for xb in data_loader['test']:
                xb = xb.to(device)
                out = model(xb)
                out = torch.sigmoid(out)
                test_preds.extend(out.cpu().numpy())
            preds += np.array(test_preds).reshape(-1)
        print(f'TTA {tta_id}')
    preds /= tta
    return preds
preds = get_preds(model,tta=25)  

In [None]:
subm = pd.read_csv("/home/linh/Downloads/TB/SampleSubmission.csv")
subm.LABEL = preds
subm.to_csv('/home/linh/Downloads/TB/submission.csv',index=False)

In [None]:
import tqdm
def test(model, data_loader, device):
    model.eval()
    targets, predicts = list(), list()
    with torch.no_grad():
        for fields, target in tqdm.tqdm(data_loader, smoothing=0, mininterval=1.0):
            fields, target = fields.to(device), target.to(device)
            y = model(fields)
            targets.extend(target.tolist())
            predicts.extend(y.tolist())
    return roc_auc_score(y_true, y_predict, average='macro', sample_weight=None, max_fpr=None, multi_class='ovo', labels=None)

test(model, data_loader['test'], device)
