In [None]:
import cv2
import math
import matplotlib.pyplot as plt

import numpy as np
import torch
from torch import nn, optim
import torchvision.transforms as transforms
import torchvision.models as models
from torch.hub import load_state_dict_from_url

from torch.utils.data import Dataset
import os

from PIL import Image

from tqdm import tqdm

from sklearn.metrics import confusion_matrix
import pandas as pd
import seaborn as sns
from scipy.ndimage import gaussian_filter1d


In [None]:
# Display the installed PyTorch version as this cell's output.
torch.__version__

In [None]:
# Display whether PyTorch's MPS backend is usable on this machine.
torch.backends.mps.is_available()

In [None]:
# Select the compute device once for the whole notebook: CUDA when present,
# otherwise the MPS backend, otherwise the CPU. The previous version forced
# 'mps' unconditionally, which fails on machines without that backend.
if torch.cuda.is_available():
    device = torch.device('cuda')
elif torch.backends.mps.is_available():
    device = torch.device('mps')
else:
    device = torch.device('cpu')
print('Using device:', device)

In [None]:
# Survey the pixel dimensions (width, height) of every training image and
# show them as a scatter plot.
root = "./chest_xray/chest_xray_ternary/train"
data = []
for catagory in sorted(os.listdir(root)):
    catagory_path = os.path.join(root, catagory)
    # Skip stray files such as '.DS_Store'; only class folders hold images.
    if not os.path.isdir(catagory_path):
        continue
    for image in sorted(os.listdir(catagory_path)):
        if image.startswith('.'):
            continue  # hidden files are not images
        image_path = os.path.join(catagory_path, image)
        # Open each file once and close it promptly; `size` is (width, height).
        with Image.open(image_path) as img:
            data.append(img.size)
x = np.array([size[0] for size in data])
y = np.array([size[1] for size in data])
plt.figure(figsize=(6, 6))
plt.scatter(x, y, alpha=0.3, marker='.')
plt.xlabel('width (px)')
plt.ylabel('height (px)')
plt.title('Training image sizes')
plt.show()


In [None]:
class X_ray(Dataset):
    """Image-folder dataset with one sub-directory per class.

    `root` is scanned once at construction time. Every non-hidden file found
    inside a class sub-directory becomes a sample stored in `self.data` as
    `(image_path, category_name)`.

    Args:
        root: Directory containing one folder per class.
        transform: Callable applied to the RGB PIL image; its result is
            returned as the sample tensor.

    Notes:
        * Directory listings are sorted so sample order (and therefore any
          seeded split of this dataset) does not depend on filesystem order.
        * A class folder whose name is missing from `self.namelabel` raises
          `KeyError` when one of its samples is fetched.
    """

    def __init__(self,root,transform):
        self.root=root
        self.transform = transform
        self.data=[]
        # Folder name -> integer class index.
        self.namelabel={"BACTERIAL":0,"NORMAL":1,"VIRAL":2}

        for catagory in sorted(os.listdir(root)):
            catagory_path=os.path.join(root,catagory)
            # Ignore stray files (e.g. '.DS_Store') next to the class folders.
            if not os.path.isdir(catagory_path):
                continue
            for image in sorted(os.listdir(catagory_path)):
                if image.startswith('.'):
                    continue  # hidden files are not images
                self.data.append((os.path.join(catagory_path,image),catagory))

    def __len__(self):
        """Return the number of samples found under `root`."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load sample `idx` and return `(image_tensor, label_tensor)`.

        Both tensors are moved to the module-level `device` before being
        returned.
        """
        img_loc, catagory = self.data[idx]

        # Decode inside a context manager so the file handle is released
        # immediately; `convert` returns an independent in-memory image.
        with Image.open(img_loc) as img:
            image = img.convert("RGB")
        tensor_image = self.transform(image)

        target = torch.tensor(self.namelabel[catagory])

        return tensor_image.to(device),target.to(device)

In [None]:
# Preprocessing: resize every image to 224x224 and convert it to a tensor.
trnsfrm = transforms.Compose([
    transforms.Resize([224,224]),
    transforms.ToTensor()
])

train_vaild_set=X_ray(
    root='./chest_xray/chest_xray_ternary/train',
    transform=trnsfrm
)

# 70/30 train/validation split. The training size is ceil(0.7 * n) computed
# in integer arithmetic and the validation size is the remainder, so the two
# lengths always add up to the dataset size (random_split requires this).
# The previous `int(n*0.7)+1` / `int(n*0.3)` pair overshoots by one whenever
# n is a multiple of 10.
total_num = len(train_vaild_set)
train_num = (7 * total_num + 9) // 10
valid_num = total_num - train_num


# Fixed generator seed keeps the split identical between runs.
train_set,valid_set=torch.utils.data.random_split(
    train_vaild_set,
    lengths=[train_num,valid_num],
    generator=torch.Generator().manual_seed(0)
)


In [None]:
# Build both loaders with identical settings: mini-batches of 64, reshuffled
# on every pass, loaded in the main process (num_workers=0).
train_loader, valid_loader = (
    torch.utils.data.DataLoader(subset, batch_size=64, shuffle=True, num_workers=0)
    for subset in (train_set, valid_set)
)

In [None]:
# VGG16 without pretrained weights, with the final classifier layer replaced
# by a 3-way output. The head is swapped BEFORE moving to the device so the
# new layer is transferred together with the rest of the network.
model = models.vgg16(pretrained=False)
model.classifier[6] = nn.Linear(model.classifier[6].in_features, 3)
model = model.to(device)

In [None]:
def valid(valid_loader,net):
    """Evaluate `net` on every batch of `valid_loader`.

    The network is moved to the module-level `device`, switched to eval mode
    for the duration of the pass (so dropout is disabled), and restored to
    its previous train/eval mode afterwards.

    Args:
        valid_loader: Iterable yielding `(images, targets)` batches.
        net: Model producing one score per class for each image.

    Returns:
        `(lvld, avld)`: mean per-batch cross-entropy loss and overall
        accuracy over all samples. Both are NaN when the loader is empty.
    """
    net=net.to(device)
    loss_fn = nn.CrossEntropyLoss()
    correct = 0
    total = 0
    total_valid_loss = 0.0
    n = 0    # counter for number of minibatches

    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            for img, target in valid_loader:
                # No-op when the dataset already returns tensors on `device`.
                img = img.to(device)
                target = target.to(device)

                outputs = net(img)
                total_valid_loss += loss_fn(outputs,target).item()
                n += 1

                # Accuracy is accumulated per batch; previously this ran once
                # after the loop and only scored the final batch.
                predicted = outputs.argmax(dim=1)
                total += target.size(0)
                correct += (predicted == target).sum().item()
    finally:
        net.train(was_training)

    lvld = total_valid_loss/n if n else float('nan')
    avld = correct/total if total else float('nan')
    return lvld,avld

In [None]:
def Training(train_loader,valid_loader,net,nepochs,lr=0.001,momentum=0.9):
    """Train `net` with SGD and record per-epoch statistics.

    Args:
        train_loader: Iterable yielding `(inputs, labels)` training batches.
        valid_loader: Iterable passed to `valid` after every epoch.
        net: Model to optimise; it is moved to the module-level `device`.
        nepochs: Number of passes over `train_loader`.
        lr: SGD learning rate.
        momentum: SGD momentum.

    Returns:
        Array of shape `(4, nepochs)` whose rows are, in order: mean training
        loss per batch, training accuracy, validation loss, validation
        accuracy.
    """
    loss_fn = nn.CrossEntropyLoss()
    net=net.to(device)
    optimizer = optim.SGD(net.parameters(), lr=lr, momentum=momentum)
    statsrec = np.zeros((4,nepochs))
    for epoch in range(nepochs):
        # '{:d}' is a positional integer field; the former '{d}' was a named
        # field and raised KeyError on the first epoch.
        print("------------epoch:{:d}------------".format(epoch+1))
        net.train()          # make sure dropout is active while training
        correct=0            # number of examples predicted correctly (for accuracy)
        total = 0            # number of examples
        running_loss = 0.0   # accumulated loss (for mean loss)
        n = 0                # number of minibatches
        for inputs, labels in tqdm(train_loader):
            # No-op when the dataset already returns tensors on `device`.
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Zero the parameter gradients
            optimizer.zero_grad()

            # Forward, backward, and update parameters
            outputs = net(inputs)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            # accumulate data for accuracy
            predicted = outputs.detach().argmax(dim=1)
            total += labels.size(0)    # add in the number of labels in this minibatch
            correct += (predicted == labels).sum().item()  # add in the number of correct labels

            n += 1

        # NaN instead of ZeroDivisionError when the training loader is empty.
        ltrn = running_loss/n if n else float('nan')
        atrn = correct/total if total else float('nan')

        lvld, avld=valid(valid_loader,net)

        statsrec[:,epoch] = (ltrn, atrn, lvld, avld)

        print('accurancy of train:{:.1%}, accurancy of validation:{:.1%}'.format(atrn,avld))

    return statsrec

    

In [None]:
# Short 5-epoch run; the returned statistics are stored in Training_statsrec.
nepochs=5
Training_statsrec=Training(train_loader,valid_loader,model,nepochs)

In [None]:
# 50-epoch run on the same `model` object. This call reassigns
# Training_statsrec, so statistics from any earlier run are replaced.
nepochs=50
Training_statsrec=Training(train_loader,valid_loader,model,nepochs)

In [None]:
# Plot Gaussian-smoothed (sigma=2) loss and accuracy curves.
# Rows of Training_statsrec: 0 = train loss, 1 = train accuracy,
# 2 = validation loss, 3 = validation accuracy.
ltrn=gaussian_filter1d(Training_statsrec[0],sigma=2)
lvld=gaussian_filter1d(Training_statsrec[2],sigma=2)
atrn=gaussian_filter1d(Training_statsrec[1],sigma=2)
avld=gaussian_filter1d(Training_statsrec[3],sigma=2)

# Take the x axis from the recorded statistics instead of hard-coding 50,
# so the cell works for any number of epochs.
x=list(range(Training_statsrec.shape[1]))

fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(14,6))

ax_loss.set_title('Loss')
ax_loss.set_xlabel('epochs')
ax_loss.plot(x, ltrn,label='ltrn')
ax_loss.plot(x, lvld,label='lvld')
ax_loss.legend()

ax_acc.set_title('Accuracy')
ax_acc.set_xlabel('epochs')
ax_acc.plot(x, atrn,label='atrn')
ax_acc.plot(x, avld,label='avld')
ax_acc.legend()

plt.show()

In [None]:
def get_all_pred(model,loader,categories_names):
    """Run `model` over `loader` and draw the confusion matrix as a heatmap.

    Args:
        model: Classifier producing one score per class for each image.
        loader: Iterable yielding `(images, labels)` batches.
        categories_names: Class names used as row/column labels; the name at
            position `i` is used for integer label `i`.

    The model is evaluated in eval mode (dropout disabled) and restored to
    its previous train/eval mode afterwards. Rows of the matrix are true
    classes and columns are predicted classes.
    """
    # Feed batches on whatever device the model lives on, and collect results
    # as plain Python ints so CPU and accelerator tensors are never mixed.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device('cpu')

    all_preds = []
    all_targets = []

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for images, label in tqdm(loader):
                preds = model(images.to(model_device))
                all_preds.extend(preds.argmax(dim=1).cpu().tolist())
                all_targets.extend(label.cpu().tolist())
    finally:
        model.train(was_training)

    # Passing `labels` keeps the matrix len(categories_names) square even if
    # some class never appears among the targets or the predictions.
    class_ids = list(range(len(categories_names)))
    cm = confusion_matrix(all_targets, all_preds, labels=class_ids)
    conf_matrix = pd.DataFrame(data=cm, columns=categories_names,
                               index=categories_names)

    fig, ax = plt.subplots(figsize=(6,6))
    sns.heatmap(conf_matrix, annot=True, fmt="d", cmap='gray', ax=ax)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')

    plt.show()

In [None]:
# Class names used to label the confusion-matrix rows and columns.
categories_names=["BACTERIAL","NORMAL","VIRAL"]

In [None]:
# Confusion matrix of `model` on the training loader.
get_all_pred(model,train_loader,categories_names)

In [None]:
# Confusion matrix of `model` on the validation loader.
get_all_pred(model,valid_loader,categories_names)

In [None]:
class testSet(Dataset):
    """Image-folder dataset for the held-out test images.

    `root` is scanned once at construction time. Every non-hidden file found
    inside a class sub-directory becomes a sample stored in `self.data` as
    `(image_path, category_name)`. Samples are returned without being moved
    to another device.

    Args:
        root: Directory containing one folder per class.
        transform: Callable applied to the RGB PIL image; its result is
            returned as the sample tensor.

    Notes:
        A class folder whose name is missing from `self.namelabel` raises
        `KeyError` when one of its samples is fetched.
    """

    def __init__(self,root,transform):
        self.root=root
        self.transform = transform
        self.data=[]
        # Folder name -> integer class index.
        self.namelabel={"BACTERIAL":0,"NORMAL":1,"VIRAL":2}

        for catagory in sorted(os.listdir(root)):
            catagory_path=os.path.join(root,catagory)
            # Ignore stray files (e.g. '.DS_Store') next to the class folders.
            if not os.path.isdir(catagory_path):
                continue
            for image in sorted(os.listdir(catagory_path)):
                if image.startswith('.'):
                    continue  # hidden files are not images
                self.data.append((os.path.join(catagory_path,image),catagory))

    def __len__(self):
        """Return the number of samples found under `root`."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load sample `idx` and return `(image_tensor, label_tensor)`."""
        img_loc, catagory = self.data[idx]

        # Decode inside a context manager so the file handle is released
        # immediately; `convert` returns an independent in-memory image.
        with Image.open(img_loc) as img:
            image = img.convert("RGB")
        tensor_image = self.transform(image)
        label = torch.tensor(self.namelabel[catagory])
        return tensor_image, label

# Preprocessing for the test images: resize to 224x224, then convert to a tensor.
trnsfrm = transforms.Compose(
    [transforms.Resize([224,224]), transforms.ToTensor()]
)

# Held-out test images, served in batches of 64 from the main process.
test_set = testSet(root='./chest_xray/chest_xray_ternary/test', transform=trnsfrm)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=64, num_workers=0)

In [None]:
# Persist the trained weights, then rebuild the same architecture and load
# the weights into a fresh copy used for testing.
torch.save(model.state_dict(),'net.pth')
test_model=models.vgg16(pretrained=False)
test_model.classifier[6] = nn.Linear(test_model.classifier[6].in_features,3)
# Inference only: disable dropout on the reloaded copy.
test_model.eval()
# map_location='cpu' lets the checkpoint load even when the device it was
# saved from is unavailable; test_model itself is never moved off the CPU.
test_model.load_state_dict(torch.load('net.pth', map_location='cpu'))

In [None]:
# Confusion matrix of the reloaded `test_model` on the test loader.
get_all_pred(test_model,test_loader,categories_names)

In [None]:
# Print the layer-by-layer structure of `model`.
print(model)