Connect to Google Drive

In [None]:
from google.colab import drive
# Mount Google Drive at /content/drive; every dataset path used below lives
# under '/content/drive/My Drive/...'.
drive.mount('/content/drive')

Import libraries

In [None]:
%matplotlib inline

import os
import shutil
import random
import torch
import torchvision
import numpy as np

from PIL import Image
from matplotlib import pyplot as plt

# Seed every RNG the notebook draws from:
# - torch: weight initialisation and DataLoader shuffling;
# - Python's `random`: test-split sampling and the per-item class choice in
#   ChestXRayDataset;
# - NumPy.
# Seeding only torch, as before, left the split and the sampled batches
# different on every run.
SEED = 0
torch.manual_seed(SEED)
random.seed(SEED)
np.random.seed(SEED)

print('Using PyTorch version', torch.__version__)

**Split the dataset into training and test sets**




In [None]:
class_names = ['0', '1', '2']
root_dir = '/content/drive/My Drive/Colab Notebooks/GAN_Images/Xray'
source_dirs = ['0', '1', '2']


def split_test_set(root_dir, source_dirs, class_names, n_test=111, seed=0):
    """Move a random hold-out of ``n_test`` files per class into ``root_dir/test/<class>``.

    Each ``source_dirs[i]`` under ``root_dir`` is renamed to ``class_names[i]``
    (skipped when the names are equal). Then ``n_test`` files are sampled from
    it and moved to ``root_dir/test/<class_name>``.

    The split runs at most once. It is skipped when ``root_dir/test`` already
    exists or any source directory is missing. The previous guard only checked
    that ``source_dirs[1]`` existed, which stays true after the split (source
    and class names coincide here). Re-running the cell therefore crashed on
    ``os.mkdir`` with FileExistsError.

    Args:
        root_dir: directory holding one sub-directory per class.
        source_dirs: current class directory names, aligned with ``class_names``.
        class_names: target class directory names.
        n_test: number of files to move into the test split per class.
        seed: seed of a private RNG, so the split is reproducible.

    Returns:
        True if the split was performed, False if it was skipped.

    Raises:
        ValueError: if a class has fewer than ``n_test`` files. This is checked
            before anything on disk is changed.
    """
    test_dir = os.path.join(root_dir, 'test')
    if os.path.isdir(test_dir):
        return False
    if not all(os.path.isdir(os.path.join(root_dir, d)) for d in source_dirs):
        return False

    # List (and validate) everything up front so a short class cannot leave a
    # half-created split behind. Sorting makes the sample independent of the
    # arbitrary os.listdir order.
    listings = {}
    for src, dst in zip(source_dirs, class_names):
        files = sorted(os.listdir(os.path.join(root_dir, src)))
        if len(files) < n_test:
            raise ValueError(
                f'{src!r} has {len(files)} files, fewer than n_test={n_test}')
        listings[dst] = files

    os.mkdir(test_dir)
    for src, dst in zip(source_dirs, class_names):
        if src != dst:
            os.rename(os.path.join(root_dir, src), os.path.join(root_dir, dst))

    rng = random.Random(seed)
    for c, files in listings.items():
        os.mkdir(os.path.join(test_dir, c))
        for image in rng.sample(files, n_test):
            shutil.move(os.path.join(root_dir, c, image),
                        os.path.join(test_dir, c, image))
    return True


split_test_set(root_dir, source_dirs, class_names)

In [None]:
class ChestXRayDataset(torch.utils.data.Dataset):
    """Three-class chest X-ray dataset ('normal', 'pneumonia', 'covid').

    Args:
        image_dirs: mapping from each name in ``class_names`` to the directory
            holding that class's image files.
        transform: callable applied to each image after it is opened with PIL
            and converted to RGB.
        class_balanced: if True (the default and the original behaviour),
            ``__getitem__`` ignores which image ``index`` refers to. It picks a
            class uniformly at random and uses ``index`` modulo that class's
            size, so batches are class-balanced whatever the class sizes.
            Every pass is therefore stochastic (images can repeat or be
            skipped), which is unsuitable for evaluation. Pass
            ``class_balanced=False`` to map ``index`` deterministically onto
            the classes concatenated in ``class_names`` order, so each image is
            visited exactly once per pass.

    Items are ``(transform(image), class_index)`` where ``class_index`` is the
    position of the class in ``class_names``.
    """

    def __init__(self, image_dirs, transform, class_balanced=True):
        self.class_names = ['normal', 'pneumonia', 'covid']
        self.image_dirs = image_dirs
        self.transform = transform
        self.class_balanced = class_balanced
        self.images = {}
        for c in self.class_names:
            images = [x for x in os.listdir(image_dirs[c])]
            print(f'Found {len(images)} {c} examples')
            self.images[c] = images

    def __len__(self):
        return sum(len(self.images[c]) for c in self.class_names)

    def _locate(self, index):
        """Return ``(class_name, image_name)`` for ``index`` (see class docstring).

        Raises:
            IndexError: in deterministic mode, when ``index`` is out of range.
        """
        if self.class_balanced:
            class_name = random.choice(self.class_names)
            files = self.images[class_name]
            return class_name, files[index % len(files)]

        total = len(self)
        if index < 0:
            index += total
        if not 0 <= index < total:
            raise IndexError(f'index {index} out of range for {total} images')
        for class_name in self.class_names:
            count = len(self.images[class_name])
            if index < count:
                return class_name, self.images[class_name][index]
            index -= count
        raise IndexError(index)  # unreachable: the range check above covers it

    def __getitem__(self, index):
        class_name, image_name = self._locate(index)
        image_path = os.path.join(self.image_dirs[class_name], image_name)
        # `with` guarantees the file handle is closed once the pixels are copied.
        with Image.open(image_path) as img:
            image = img.convert('RGB')
        return self.transform(image), self.class_names.index(class_name)

# ***Data Loaders and Normalization (GAN Images)***

In [None]:
# Training preprocessing:
# - resize to 224x224;
# - random horizontal flip for augmentation (p=0.5 is the torchvision default);
# - convert to a tensor;
# - normalise with the ImageNet channel mean/std.
train_transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize((224, 224)),
        torchvision.transforms.RandomHorizontalFlip(p=0.5),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

In [None]:
# Evaluation preprocessing: the training pipeline without the random flip, so
# test images are processed deterministically.
test_transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize((224, 224)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ]
)

In [None]:
# Training images: one sub-directory per class under `root_dir`
# (0 = normal, 1 = pneumonia, 2 = covid). Building the paths from `root_dir`
# keeps them in sync with the split cell instead of repeating the Drive path.
train_dirs = {
    'normal': os.path.join(root_dir, '0'),
    'pneumonia': os.path.join(root_dir, '1'),
    'covid': os.path.join(root_dir, '2'),
}
train_dataset = ChestXRayDataset(train_dirs, train_transform)

In [None]:

# Held-out test images, moved into `root_dir/test/<class>` by the split cell
# (0 = normal, 1 = pneumonia, 2 = covid).
test_dirs = {
    'normal': os.path.join(root_dir, 'test', '0'),
    'pneumonia': os.path.join(root_dir, 'test', '1'),
    'covid': os.path.join(root_dir, 'test', '2'),
}
test_dataset = ChestXRayDataset(test_dirs, test_transform)

In [None]:
batch_size = 5

# Both loaders reshuffle every epoch. In its default mode ChestXRayDataset also
# draws a random class per item, so batch contents are stochastic either way.
dl_train = torch.utils.data.DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
dl_test = torch.utils.data.DataLoader(test_dataset, shuffle=True, batch_size=batch_size)

print('Num of training batches', len(dl_train))
print('Num of test batches', len(dl_test))

In [None]:
class_names = train_dataset.class_names


def show_images(images, labels, preds):
    """Plot a batch of normalised images in one row with true and predicted labels.

    Each image is un-normalised with the mean/std used by the transforms and
    clipped to [0, 1]. The x-label is the true class. The y-label is the
    predicted class, green when it matches and red otherwise.

    Args:
        images: iterable of CHW image tensors.
        labels: true class indices (into the global ``class_names``), aligned
            with ``images``.
        preds: predicted class indices, aligned with ``images``.
    """
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    # At least 5 columns keeps the original layout for batches of up to 5.
    # Larger batches get one column per image; the fixed 1x5 grid made
    # plt.subplot raise ValueError on the 6th image.
    ncols = max(len(images), 5)

    plt.figure(figsize=(12, 12))
    for i, image in enumerate(images):
        plt.subplot(1, ncols, i + 1, xticks=[], yticks=[])
        # detach/cpu so tensors that require grad or live on a GPU also work.
        image = image.detach().cpu().numpy().transpose(1, 2, 0)
        image = np.clip(image * std + mean, 0., 1.)
        plt.imshow(image)

        true_idx = int(labels[i])
        pred_idx = int(preds[i])
        col = 'green' if pred_idx == true_idx else 'red'
        plt.xlabel(f'{class_names[true_idx]}')
        plt.ylabel(f'{class_names[pred_idx]}', color=col)
    plt.tight_layout()
    plt.show()

In [None]:
# Sanity check: show one training batch. The labels are also passed as the
# "predictions", so every caption is green.
images, labels = next(iter(dl_train))
show_images(images, labels, labels)

In [None]:

# Sanity check: show one test batch. The labels are also passed as the
# "predictions", so every caption is green.
images, labels = next(iter(dl_test))
show_images(images, labels, labels)

## **Residual Network (ResNet-18)**

In [None]:

# Load ResNet-18 with pretrained weights and print its layer structure.
# NOTE(review): `pretrained=True` is deprecated in newer torchvision releases in
# favour of `weights=`; confirm against the installed version.
resnet18 = torchvision.models.resnet18(pretrained=True)
print(resnet18)

In [None]:

# Swap the pretrained classification head for an untrained 3-way linear layer
# (normal / pneumonia / covid). The input width is read from the existing
# head, which is 512 for ResNet-18.
resnet18.fc = torch.nn.Linear(resnet18.fc.in_features, 3)
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(resnet18.parameters(), lr=3e-5)

In [None]:
def show_preds(model=None):
    """Show a model's predictions for one batch drawn from ``dl_test``.

    Args:
        model: network to evaluate. Defaults to the global ``resnet18`` (the
            previous hard-wired behaviour). Passing it explicitly lets other
            sections preview their own network.

    Leaves the model in eval mode; training code switches it back itself.
    """
    if model is None:
        model = resnet18
    model.eval()
    images, labels = next(iter(dl_test))
    # Inference only: no_grad avoids building an autograd graph for the batch.
    with torch.no_grad():
        outputs = model(images)
    _, preds = torch.max(outputs, 1)
    show_images(images, labels, preds)

In [None]:
# Preview predictions before fine-tuning: the new 3-class head is still untrained here.
show_preds()


In [None]:
def train(epochs, target_acc=0.95, eval_every=20):
    """Fine-tune the global ``resnet18`` on ``dl_train``, validating periodically.

    Every ``eval_every`` training steps (starting at step 0) the model is
    evaluated on all of ``dl_test``. The mean validation loss and the accuracy
    are printed and a prediction preview is shown. Training stops as soon as
    the accuracy exceeds ``target_acc``.

    Relies on the notebook globals ``resnet18``, ``optimizer``, ``loss_fn``,
    ``dl_train``, ``dl_test``, ``test_dataset`` and ``show_preds``.

    Args:
        epochs: maximum number of passes over ``dl_train``.
        target_acc: accuracy threshold for early stopping (previously a
            hard-coded 0.95).
        eval_every: validation interval in training steps (previously 20).

    Returns:
        None.
    """
    print('Starting training..')
    for e in range(epochs):
        print('=' * 20)
        print(f'Starting epoch {e + 1}/{epochs}')
        print('=' * 20)

        train_loss = 0.0
        n_steps = 0
        resnet18.train()

        for train_step, (images, labels) in enumerate(dl_train):
            optimizer.zero_grad()
            outputs = resnet18(images)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            n_steps += 1

            if train_step % eval_every == 0:
                print('Evaluating at step', train_step)
                resnet18.eval()
                val_loss, correct, n_val_batches = 0.0, 0, 0
                # Evaluation only. Without no_grad, autograd recorded the whole
                # test pass, wasting memory and time.
                with torch.no_grad():
                    for val_images, val_labels in dl_test:
                        val_outputs = resnet18(val_images)
                        val_loss += loss_fn(val_outputs, val_labels).item()
                        _, val_preds = torch.max(val_outputs, 1)
                        correct += (val_preds == val_labels).sum().item()
                        n_val_batches += 1
                    val_loss /= max(n_val_batches, 1)
                    acc = correct / len(test_dataset)
                    print(f'Val loss: {val_loss:.4f}, Acc: {acc:.4f}')
                    show_preds()
                resnet18.train()

                if acc > target_acc:
                    print('Performance condition satisfied')
                    print(f'Training loss: {train_loss / n_steps:.4f}')
                    # Stop training altogether. The previous `break` only ended
                    # the current epoch, and training resumed with the next one.
                    return

        # max(..., 1) guards against an empty dl_train (previously a NameError).
        train_loss /= max(n_steps, 1)
        print(f'Training loss: {train_loss:.4f}')

In [None]:
# Fine-tune ResNet-18 for up to 3 epochs. train() returns no value, so hist is None.
hist=train(epochs=3)


In [None]:
def show_preds(model=None):
    """Show a model's predictions for one batch drawn from ``dl_test``.

    Args:
        model: network to evaluate. Defaults to the global ``resnet18`` (the
            previous hard-wired behaviour).

    Leaves the model in eval mode.
    """
    if model is None:
        model = resnet18
    model.eval()
    images, labels = next(iter(dl_test))
    # Inference only: no_grad avoids building an autograd graph for the batch.
    with torch.no_grad():
        outputs = model(images)
    _, preds = torch.max(outputs, 1)
    show_images(images, labels, preds)
   

In [None]:
# Preview ResNet-18 predictions on a test batch after fine-tuning.
show_preds()


# **Confusion Matrix**

In [None]:
nb_classes = 3

# Confusion matrix of ResNet-18 over one pass of dl_test: rows = true class,
# columns = predicted class.
# NOTE: in its default mode ChestXRayDataset draws a random class per item, so
# these counts vary between runs and need not match the test files one-to-one.
confusion_matrix1 = torch.zeros(nb_classes, nb_classes)
resnet18.eval()
with torch.no_grad():
    for images, labels in dl_test:
        outputs = resnet18(images)
        _, preds = torch.max(outputs, 1)
        pairs = zip(labels.view(-1).tolist(), preds.view(-1).tolist())
        for t, p in pairs:
            confusion_matrix1[t, p] += 1

print(confusion_matrix1)

In [None]:
import itertools
import numpy as np
import matplotlib.pyplot as plt

def precision_recall_f1(cm, positive_class=1):
    """Return ``(precision, recall, f1)`` for ``positive_class`` from a count matrix.

    ``cm[i, j]`` counts samples of true class ``i`` predicted as class ``j``.
    A ratio with a zero denominator is returned as 0.0 instead of nan.
    """
    cm = np.asarray(cm, dtype=float)
    tp = cm[positive_class, positive_class]
    predicted = cm[:, positive_class].sum()
    actual = cm[positive_class, :].sum()
    precision = tp / predicted if predicted else 0.0
    recall = tp / actual if actual else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=None, positive_class=1):
    """Plot a confusion matrix annotated with precision / recall / F1 of one class.

    Args:
        cm: square count matrix (torch tensor or array-like); rows = true
            class, columns = predicted class.
        classes: tick labels, one per class.
        normalize: if True, divide each row by its total before plotting.
        title: plot title.
        cmap: matplotlib colormap. None means ``plt.cm.Blues`` (the previous
            default).
        positive_class: class index the summary statistics refer to
            (default 1, as previously hard-coded).

    The statistics are computed from ``cm`` itself, before any normalisation.
    Previously they were read from the global ``confusion_matrix1``, so other
    matrices were annotated with the wrong numbers.
    """
    if cmap is None:
        cmap = plt.cm.Blues
    # Convert to NumPy first: torch tensors have no .astype(), so
    # normalize=True used to crash on them.
    cm = np.asarray(cm, dtype=float)
    precision, recall, f1_score = precision_recall_f1(cm, positive_class)
    stats_text = "\n\n\nPrecision={:0.3f}\nRecall={:0.3f}\nF1 Score={:0.3f}".format(precision, recall, f1_score)

    if normalize:
        row_totals = cm.sum(axis=1, keepdims=True)
        # Rows without samples stay 0 instead of becoming nan.
        cm = np.divide(cm, row_totals, out=np.zeros_like(cm), where=row_totals != 0)
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')
    print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    fmt = '.2f'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt), horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label' + stats_text)

In [None]:
# Request a 10x10-inch figure for the confusion-matrix plot.
# NOTE(review): with the inline backend this (empty) figure is presumably shown
# and closed when this cell ends. The size then likely does not carry over to
# the plot drawn in the next cell; confirm, and create the figure in the same
# cell as plot_confusion_matrix if so.
plt.figure(figsize=(10,10))


In [None]:
# ResNet-18 confusion matrix, labelled with the dataset's class names.
plot_confusion_matrix(confusion_matrix1, classes=train_dataset.class_names)

In [None]:

# Precision, recall and F1 of class index 1 from the ResNet-18 confusion matrix
# (rows = true class, columns = predicted). A zero denominator yields nan/inf.
precision = confusion_matrix1[1, 1] / confusion_matrix1[:, 1].sum()
recall = confusion_matrix1[1, 1] / confusion_matrix1[1, :].sum()
f1_score = 2 * precision * recall / (precision + recall)


In [None]:
# Precision of class index 1, as computed in the previous cell.
print(precision)

# **Residual Network with 34 Layers (ResNet-34)**

In [None]:

# Load ResNet-34 with pretrained weights and print its layer structure.
# NOTE(review): `pretrained=True` is deprecated in newer torchvision releases in
# favour of `weights=`; confirm against the installed version.
resnet34 = torchvision.models.resnet34(pretrained=True)
print(resnet34)

In [None]:

# Swap the pretrained head for an untrained 3-way linear layer. The input width
# is read from the existing head, which is 512 for ResNet-34. The optimizer now
# tracks resnet34's parameters.
resnet34.fc = torch.nn.Linear(resnet34.fc.in_features, 3)
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(resnet34.parameters(), lr=3e-5)

In [None]:
def train(epochs, target_acc=0.95, eval_every=20):
    """Fine-tune the global ``resnet34`` on ``dl_train``, validating periodically.

    Every ``eval_every`` training steps (starting at step 0) the model is
    evaluated on all of ``dl_test``. The mean validation loss and the accuracy
    are printed and a preview of resnet34's predictions is shown. Training
    stops as soon as the accuracy exceeds ``target_acc``.

    Relies on the notebook globals ``resnet34``, ``optimizer``, ``loss_fn``,
    ``dl_train``, ``dl_test``, ``test_dataset`` and ``show_images``.

    Args:
        epochs: maximum number of passes over ``dl_train``.
        target_acc: accuracy threshold for early stopping (previously a
            hard-coded 0.95).
        eval_every: validation interval in training steps (previously 20).

    Returns:
        None.
    """
    print('Starting training..')
    for e in range(epochs):
        print('=' * 20)
        print(f'Starting epoch {e + 1}/{epochs}')
        print('=' * 20)

        train_loss = 0.0
        n_steps = 0
        resnet34.train()

        for train_step, (images, labels) in enumerate(dl_train):
            optimizer.zero_grad()
            outputs = resnet34(images)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            n_steps += 1

            if train_step % eval_every == 0:
                print('Evaluating at step', train_step)
                resnet34.eval()
                val_loss, correct, n_val_batches = 0.0, 0, 0
                # Evaluation only. Without no_grad, autograd recorded the whole
                # test pass, wasting memory and time.
                with torch.no_grad():
                    for val_images, val_labels in dl_test:
                        val_outputs = resnet34(val_images)
                        val_loss += loss_fn(val_outputs, val_labels).item()
                        _, val_preds = torch.max(val_outputs, 1)
                        correct += (val_preds == val_labels).sum().item()
                        n_val_batches += 1
                    val_loss /= max(n_val_batches, 1)
                    acc = correct / len(test_dataset)
                    print(f'Val loss: {val_loss:.4f}, Acc: {acc:.4f}')

                    # Preview *this* network's predictions. The global
                    # show_preds() evaluates resnet18, so calling it here
                    # displayed the wrong model.
                    preview_images, preview_labels = next(iter(dl_test))
                    _, preview_preds = torch.max(resnet34(preview_images), 1)
                    show_images(preview_images, preview_labels, preview_preds)
                resnet34.train()

                if acc > target_acc:
                    print('Performance condition satisfied')
                    print(f'Training loss: {train_loss / n_steps:.4f}')
                    # Stop training altogether. The previous `break` only ended
                    # the current epoch, and training resumed with the next one.
                    return

        # max(..., 1) guards against an empty dl_train (previously a NameError).
        train_loss /= max(n_steps, 1)
        print(f'Training loss: {train_loss:.4f}')

In [None]:
# Fine-tune ResNet-34 for up to 3 epochs using the train() just redefined for it.
# train() returns no value, so hist is None.
hist=train(epochs=3)


In [None]:
nb_classes = 3

# Confusion matrix of ResNet-34 over one pass of dl_test: rows = true class,
# columns = predicted class.
# NOTE: in its default mode ChestXRayDataset draws a random class per item, so
# these counts vary between runs and need not match the test files one-to-one.
confusion_matrix2 = torch.zeros(nb_classes, nb_classes)
resnet34.eval()
with torch.no_grad():
    for images, labels in dl_test:
        outputs = resnet34(images)
        _, preds = torch.max(outputs, 1)
        pairs = zip(labels.view(-1).tolist(), preds.view(-1).tolist())
        for t, p in pairs:
            confusion_matrix2[t, p] += 1

print(confusion_matrix2)

In [None]:
import itertools
import numpy as np
import matplotlib.pyplot as plt

def precision_recall_f1(cm, positive_class=1):
    """Return ``(precision, recall, f1)`` for ``positive_class`` from a count matrix.

    ``cm[i, j]`` counts samples of true class ``i`` predicted as class ``j``.
    A ratio with a zero denominator is returned as 0.0 instead of nan.
    """
    cm = np.asarray(cm, dtype=float)
    tp = cm[positive_class, positive_class]
    predicted = cm[:, positive_class].sum()
    actual = cm[positive_class, :].sum()
    precision = tp / predicted if predicted else 0.0
    recall = tp / actual if actual else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


def plot_confusion_matrix2(cm, classes, normalize=False, title='Confusion matrix', cmap=None, positive_class=1):
    """Plot a confusion matrix annotated with precision / recall / F1 of one class.

    Args:
        cm: square count matrix (torch tensor or array-like); rows = true
            class, columns = predicted class.
        classes: tick labels, one per class.
        normalize: if True, divide each row by its total before plotting.
        title: plot title.
        cmap: matplotlib colormap. None means ``plt.cm.Blues`` (the previous
            default).
        positive_class: class index the summary statistics refer to
            (default 1, as previously hard-coded).

    The statistics are computed from ``cm`` itself, before any normalisation,
    instead of from the global ``confusion_matrix2``.
    """
    if cmap is None:
        cmap = plt.cm.Blues
    # Convert to NumPy first: torch tensors have no .astype(), so
    # normalize=True used to crash on them.
    cm = np.asarray(cm, dtype=float)
    precision, recall, f1_score = precision_recall_f1(cm, positive_class)
    stats_text = "\n\n\nPrecision={:0.3f}\nRecall={:0.3f}\nF1 Score={:0.3f}".format(precision, recall, f1_score)

    if normalize:
        row_totals = cm.sum(axis=1, keepdims=True)
        # Rows without samples stay 0 instead of becoming nan.
        cm = np.divide(cm, row_totals, out=np.zeros_like(cm), where=row_totals != 0)
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')
    print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    fmt = '.2f'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt), horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label' + stats_text)

In [None]:
# ResNet-34 confusion matrix, labelled with the dataset's class names.
plot_confusion_matrix2(confusion_matrix2, classes=train_dataset.class_names)

In [None]:

# Load ResNet-50 with pretrained weights and print its layer structure.
# NOTE(review): `pretrained=True` is deprecated in newer torchvision releases in
# favour of `weights=`; confirm against the installed version.
resnet50 = torchvision.models.resnet50(pretrained=True)
print(resnet50)

In [None]:

# Swap the pretrained head for an untrained 3-way linear layer. The input width
# is read from the existing head, which is 2048 for ResNet-50. The optimizer now
# tracks resnet50's parameters.
resnet50.fc = torch.nn.Linear(resnet50.fc.in_features, 3)
loss_fn = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(resnet50.parameters(), lr=3e-5)

In [None]:
def train(epochs, target_acc=0.96, eval_every=20):
    """Fine-tune the global ``resnet50`` on ``dl_train``, validating periodically.

    Every ``eval_every`` training steps (starting at step 0) the model is
    evaluated on all of ``dl_test``. The mean validation loss and the accuracy
    are printed and a preview of resnet50's predictions is shown. Training
    stops as soon as the accuracy exceeds ``target_acc``.

    Relies on the notebook globals ``resnet50``, ``optimizer``, ``loss_fn``,
    ``dl_train``, ``dl_test``, ``test_dataset`` and ``show_images``.

    Args:
        epochs: maximum number of passes over ``dl_train``.
        target_acc: accuracy threshold for early stopping (previously a
            hard-coded 0.96).
        eval_every: validation interval in training steps (previously 20).

    Returns:
        None.
    """
    print('Starting training..')
    for e in range(epochs):
        print('=' * 20)
        print(f'Starting epoch {e + 1}/{epochs}')
        print('=' * 20)

        train_loss = 0.0
        n_steps = 0
        resnet50.train()

        for train_step, (images, labels) in enumerate(dl_train):
            optimizer.zero_grad()
            outputs = resnet50(images)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            n_steps += 1

            if train_step % eval_every == 0:
                print('Evaluating at step', train_step)
                resnet50.eval()
                val_loss, correct, n_val_batches = 0.0, 0, 0
                # Evaluation only. Without no_grad, autograd recorded the whole
                # test pass, wasting memory and time.
                with torch.no_grad():
                    for val_images, val_labels in dl_test:
                        val_outputs = resnet50(val_images)
                        val_loss += loss_fn(val_outputs, val_labels).item()
                        _, val_preds = torch.max(val_outputs, 1)
                        correct += (val_preds == val_labels).sum().item()
                        n_val_batches += 1
                    val_loss /= max(n_val_batches, 1)
                    acc = correct / len(test_dataset)
                    print(f'Val loss: {val_loss:.4f}, Acc: {acc:.4f}')

                    # Preview *this* network's predictions. The global
                    # show_preds() evaluates resnet18, so calling it here
                    # displayed the wrong model.
                    preview_images, preview_labels = next(iter(dl_test))
                    _, preview_preds = torch.max(resnet50(preview_images), 1)
                    show_images(preview_images, preview_labels, preview_preds)
                resnet50.train()

                if acc > target_acc:
                    print('Performance condition satisfied')
                    print(f'Training loss: {train_loss / n_steps:.4f}')
                    # Stop training altogether. The previous `break` only ended
                    # the current epoch, and training resumed with the next one.
                    return

        # max(..., 1) guards against an empty dl_train (previously a NameError).
        train_loss /= max(n_steps, 1)
        print(f'Training loss: {train_loss:.4f}')

In [None]:
# Fine-tune ResNet-50 for up to 3 epochs using the train() just redefined for it.
# train() returns no value, so hist is None.
hist=train(epochs=3)


In [None]:
nb_classes = 3

# Confusion matrix of ResNet-50 over one pass of dl_test: rows = true class,
# columns = predicted class.
# NOTE: in its default mode ChestXRayDataset draws a random class per item, so
# these counts vary between runs and need not match the test files one-to-one.
confusion_matrix3 = torch.zeros(nb_classes, nb_classes)
resnet50.eval()
with torch.no_grad():
    for images, labels in dl_test:
        outputs = resnet50(images)
        _, preds = torch.max(outputs, 1)
        pairs = zip(labels.view(-1).tolist(), preds.view(-1).tolist())
        for t, p in pairs:
            confusion_matrix3[t, p] += 1

print(confusion_matrix3)

In [None]:
import itertools
import numpy as np
import matplotlib.pyplot as plt

def precision_recall_f1(cm, positive_class=1):
    """Return ``(precision, recall, f1)`` for ``positive_class`` from a count matrix.

    ``cm[i, j]`` counts samples of true class ``i`` predicted as class ``j``.
    A ratio with a zero denominator is returned as 0.0 instead of nan.
    """
    cm = np.asarray(cm, dtype=float)
    tp = cm[positive_class, positive_class]
    predicted = cm[:, positive_class].sum()
    actual = cm[positive_class, :].sum()
    precision = tp / predicted if predicted else 0.0
    recall = tp / actual if actual else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1


def plot_confusion_matrix3(cm, classes, normalize=False, title='Confusion matrix', cmap=None, positive_class=1):
    """Plot a confusion matrix annotated with precision / recall / F1 of one class.

    Args:
        cm: square count matrix (torch tensor or array-like); rows = true
            class, columns = predicted class.
        classes: tick labels, one per class.
        normalize: if True, divide each row by its total before plotting.
        title: plot title.
        cmap: matplotlib colormap. None means ``plt.cm.Blues`` (the previous
            default).
        positive_class: class index the summary statistics refer to
            (default 1, as previously hard-coded).

    The statistics are computed from ``cm`` itself, before any normalisation,
    instead of from the global ``confusion_matrix3``.
    """
    if cmap is None:
        cmap = plt.cm.Blues
    # Convert to NumPy first: torch tensors have no .astype(), so
    # normalize=True used to crash on them.
    cm = np.asarray(cm, dtype=float)
    precision, recall, f1_score = precision_recall_f1(cm, positive_class)
    stats_text = "\n\n\nPrecision={:0.3f}\nRecall={:0.3f}\nF1 Score={:0.3f}".format(precision, recall, f1_score)

    if normalize:
        row_totals = cm.sum(axis=1, keepdims=True)
        # Rows without samples stay 0 instead of becoming nan.
        cm = np.divide(cm, row_totals, out=np.zeros_like(cm), where=row_totals != 0)
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')
    print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    fmt = '.2f'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt), horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.tight_layout()
    plt.ylabel('True label')
    plt.xlabel('Predicted label' + stats_text)

In [None]:
# ResNet-50 confusion matrix, labelled with the dataset's class names.
plot_confusion_matrix3(confusion_matrix3, classes=train_dataset.class_names)

In [None]:
# The fine-tuned ResNet-50 is the model exported in the next section.
model=resnet50

# **Save model**

In [None]:

# Save the fine-tuned weights (state_dict only) to Google Drive.
# NOTE(review): `model` is torchvision's resnet50, not ResNeXt50_32x4d, so the
# file name looks misleading; confirm before loading it into another architecture.
torch.save(model.state_dict(),"/content/drive/My Drive/COVID-ResNext50_32x4d.pth")

### **Grad-CAM Visualization**

In [None]:
import os
import PIL
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.models as models
from torchvision.utils import make_grid, save_image

from utils import visualize_cam, Normalize
from gradcam import GradCAM, GradCAMpp
import cv2

In [None]:

def read_rgb(path):
    """Read an image file with OpenCV and return it as an HxWx3 RGB uint8 array.

    cv2.imread returns channels in BGR order, and returns None (instead of
    raising) when the file cannot be read. The classifier was trained on PIL
    images converted to 'RGB' (see ChestXRayDataset), so convert the channel
    order and fail loudly on unreadable files.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``path``.
    """
    bgr = cv2.imread(path, cv2.IMREAD_COLOR)
    if bgr is None:
        raise FileNotFoundError(f'Could not read image {path!r}')
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)


# Despite their names, these variables hold decoded RGB image arrays.
img_name1 = read_rgb('1.jpg')
img_name2 = read_rgb('2.png')
img_name3 = read_rgb('3.jpg')

In [None]:
# Image 1: HWC uint8 array -> 1x3xHxW float tensor in [0, 1] on the GPU,
# resized to 224x224. A copy normalised with the training mean/std is made for
# the network.
normalizer = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
torch_img = torch.from_numpy(np.asarray(img_name1)).permute(2, 0, 1).unsqueeze(0).float().div(255).cuda()
# F.upsample is deprecated; F.interpolate is its drop-in replacement.
torch_img = F.interpolate(torch_img, size=(224, 224), mode='bilinear', align_corners=False)
normed_torch_img = normalizer(torch_img)

In [None]:
# Image 2: HWC uint8 array -> 1x3xHxW float tensor in [0, 1] on the GPU,
# resized to 224x224. A copy normalised with the training mean/std is made for
# the network.
normalizer = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
torch_img2 = torch.from_numpy(np.asarray(img_name2)).permute(2, 0, 1).unsqueeze(0).float().div(255).cuda()
# F.upsample is deprecated; F.interpolate is its drop-in replacement.
torch_img2 = F.interpolate(torch_img2, size=(224, 224), mode='bilinear', align_corners=False)
normed_torch_img2 = normalizer(torch_img2)

In [None]:
# Image 3: HWC uint8 array -> 1x3xHxW float tensor in [0, 1] on the GPU,
# resized to 224x224. A copy normalised with the training mean/std is made for
# the network.
normalizer = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
# Bug fix: this cell read img_name2, so "image 3" was a copy of image 2.
torch_img3 = torch.from_numpy(np.asarray(img_name3)).permute(2, 0, 1).unsqueeze(0).float().div(255).cuda()
# F.upsample is deprecated; F.interpolate is its drop-in replacement.
torch_img3 = F.interpolate(torch_img3, size=(224, 224), mode='bilinear', align_corners=False)
normed_torch_img3 = normalizer(torch_img3)

In [None]:
from torchvision import transforms

In [None]:
import copy

# Explain the fine-tuned COVID classifier (`model`, the trained ResNet-50)
# rather than a freshly downloaded ImageNet ResNet-50, whose saliency maps say
# nothing about the 3-class X-ray task. A deep copy is moved to the GPU so the
# CPU model used by the earlier evaluation cells is left untouched.
resnet = copy.deepcopy(model)
resnet.eval()
resnet.cuda()

cam_dict = dict()
resnet_model_dict = dict(type='resnet50', arch=resnet, layer_name='layer4', input_size=(224, 224))
resnet_gradcam = GradCAM(resnet_model_dict, True)
resnet_gradcampp = GradCAMpp(resnet_model_dict, True)
cam_dict['resnet'] = [resnet_gradcam, resnet_gradcampp]



In [None]:
# Grad-CAM and Grad-CAM++ for image 1. Each registered model contributes one
# row [input, heatmap, heatmap++, overlay, overlay++] to a 5-column grid.
images = []
for gradcam, gradcam_pp in cam_dict.values():
    mask, _ = gradcam(normed_torch_img)
    mask_pp, _ = gradcam_pp(normed_torch_img)
    mask, mask_pp = mask.cpu(), mask_pp.cpu()

    heatmap, result = visualize_cam(mask, torch_img)
    heatmap_pp, result_pp = visualize_cam(mask_pp, torch_img)

    images.append(torch.stack(
        [torch_img.squeeze().cpu(), heatmap, heatmap_pp, result, result_pp], 0))

images = make_grid(torch.cat(images, 0), nrow=5)

In [None]:
# Grad-CAM and Grad-CAM++ for image 2. Each registered model contributes one
# row [input, heatmap, heatmap++, overlay, overlay++] to a 5-column grid.
images2 = []
for gradcam, gradcam_pp in cam_dict.values():
    mask, _ = gradcam(normed_torch_img2)
    mask_pp, _ = gradcam_pp(normed_torch_img2)
    mask, mask_pp = mask.cpu(), mask_pp.cpu()

    heatmap, result = visualize_cam(mask, torch_img2)
    heatmap_pp, result_pp = visualize_cam(mask_pp, torch_img2)

    images2.append(torch.stack(
        [torch_img2.squeeze().cpu(), heatmap, heatmap_pp, result, result_pp], 0))

images2 = make_grid(torch.cat(images2, 0), nrow=5)

In [None]:
# Show the Grad-CAM grid for image 2. make_grid returns a 3xHxW tensor, so it
# is permuted to HxWx3 for imshow. Indexing images2[2] displayed only the third
# colour channel as a false-colour map.
plt.figure(figsize=(50,50))
plt.imshow(images2.permute(1, 2, 0).detach().cpu().numpy())

In [None]:
# Output location for saved Grad-CAM grids, e.g. save_image(images, output_path).
output_dir = 'outputs'
os.makedirs(output_dir, exist_ok=True)
output_name = "img_name"
# Bug fix: the path was built from output_dir twice ('outputs/outputs').
output_path = os.path.join(output_dir, output_name)

In [None]:
import matplotlib.pyplot as plt


In [None]:
# Grad-CAM and Grad-CAM++ for image 3. Each registered model contributes one
# row [input, heatmap, heatmap++, overlay, overlay++] to a 5-column grid.
images3 = []
for gradcam, gradcam_pp in cam_dict.values():
    mask, _ = gradcam(normed_torch_img3)
    mask = mask.cpu()
    heatmap, result = visualize_cam(mask, torch_img3)

    mask_pp, _ = gradcam_pp(normed_torch_img3)
    mask_pp = mask_pp.cpu()
    heatmap_pp, result_pp = visualize_cam(mask_pp, torch_img3)

    # Bug fix: the input tile came from torch_img2 rather than torch_img3, the
    # image these masks were computed on.
    images3.append(torch.stack([torch_img3.squeeze().cpu(), heatmap, heatmap_pp, result, result_pp], 0))

images3 = make_grid(torch.cat(images3, 0), nrow=5)

In [None]:
# Show the Grad-CAM grid for image 3. make_grid returns a 3xHxW tensor, so it
# is permuted to HxWx3 for imshow. Indexing images3[2] displayed only the third
# colour channel as a false-colour map.
plt.figure(figsize=(50,50))
plt.imshow(images3.permute(1, 2, 0).detach().cpu().numpy())

In [None]:
# Show the Grad-CAM grid for image 1, permuted from 3xHxW to HxWx3 for imshow.
# Indexing images[2] displayed only the third colour channel.
plt.imshow(images.permute(1, 2, 0).detach().cpu().numpy())