In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data.dataloader import DataLoader
from torch.utils.data import Dataset
from torchvision.datasets import ImageFolder
from torchvision.transforms import ToTensor
import torchvision
import os
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix,accuracy_score 
from sklearn.utils import shuffle
import pydicom
from pydicom.data import get_testdata_files
from pydicom import dcmread
import time
import csv
from datetime import date
from pytorch_grad_cam import GradCAM
%matplotlib inline

# Date stamp (date the notebook is run) that is written into every report row.
today = date.today()

In [None]:
# Functions that apply the DICOM rescale (slope/intercept) translation and intensity windowing.
# source: https://www.kaggle.com/code/redwankarimsony/ct-scans-dicom-files-windowing-explained/notebook
from PIL import Image

def window_image(img, window_center,window_width, intercept, slope, rescale=True):
    """Apply the DICOM rescale transform and an intensity window to pixel data.

    The stored values are first mapped with ``img * slope + intercept`` and then
    clipped to ``[window_center - window_width/2, window_center + window_width/2]``.

    Args:
        img: numeric array (or array-like) of stored pixel values; not modified.
        window_center: centre of the intensity window, in rescaled units.
        window_width: full width of the intensity window, in rescaled units.
        intercept: additive term of the rescale transform.
        slope: multiplicative term of the rescale transform.
        rescale: when True, map the clipped values linearly onto [0, 255].

    Returns:
        A new float64 array with the same shape as ``img``.
    """
    # Work in float64 from the start. With an integer array the fractional
    # window bounds were truncated on assignment (so rescaled values could fall
    # below 0), and an unsigned array combined with a negative intercept
    # overflows or raises instead of producing negative values.
    img = np.asarray(img, dtype=np.float64) * slope + intercept
    img_min = window_center - window_width / 2
    img_max = window_center + window_width / 2
    img = np.clip(img, img_min, img_max)
    if rescale:
        img = (img - img_min) / (img_max - img_min) * 255.0
    return img
    
def get_first_of_dicom_field_as_int(x):
    """Return ``x`` converted to int, using its first entry when it is a pydicom MultiValue."""
    # Exact type match (not isinstance): only a plain MultiValue gets indexed.
    is_multi = type(x) is pydicom.multival.MultiValue
    value = x[0] if is_multi else x
    return int(value)
    
def get_windowing(data):
    """Read the windowing and rescale parameters from a DICOM dataset.

    Args:
        data: object indexable by DICOM tag tuples whose elements expose
            ``.value`` (in this notebook, the dataset returned by pydicom).

    Returns:
        ``[window_center, window_width, intercept, slope]``. Centre and width
        are ints (first entry when the field is multi-valued); intercept and
        slope are floats.
    """
    def first_value(x):
        # Same multi-value rule as get_first_of_dicom_field_as_int, without the int cast.
        return x[0] if isinstance(x, pydicom.multival.MultiValue) else x

    window_center = get_first_of_dicom_field_as_int(data[('0028','1050')].value)
    window_width = get_first_of_dicom_field_as_int(data[('0028','1051')].value)
    # Rescale intercept/slope are decimal-valued DICOM fields. Casting them to
    # int truncates fractional values (a slope of 0.5 would become 0 and blank
    # the whole image), so they are kept as floats.
    intercept = float(first_value(data[('0028','1052')].value))
    slope = float(first_value(data[('0028','1053')].value))
    return [window_center, window_width, intercept, slope]

In [None]:
#Load Dataset
# Pick the compute device once; `dev` is read by the model cells.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    dev = "cuda:0"
else:
    dev = "cpu"

start = time.time()

path='LungCT'
# os.listdir returns entries in arbitrary order. Sorting makes the assignment
# of folders to the train/test split the same on every run and machine.
filelist=sorted(os.listdir(path))

# One entry per class, matched in this order against the top-level folder name:
# (letter contained in the folder name, class label,
#  number of folders used for training, number of folders used for testing)
CLASS_SPLITS = [
    ('A', 0, 71, 9),
    ('B', 1, 33, 3),
    ('G', 2, 36, 4),
]


def _sorted_subdirs(parent):
    """Return the sub-directories of ``parent`` as full paths, sorted by name.

    Plain files are skipped so that a stray non-directory entry does not make
    the nested os.listdir calls fail.
    """
    candidates = (os.path.join(parent, name) for name in sorted(os.listdir(parent)))
    return [candidate for candidate in candidates if os.path.isdir(candidate)]


def load_case_slices(case_dir):
    """Load every DICOM file found at ``case_dir/<dir>/<dir>/<file>``.

    Each file is windowed with the parameters stored in its own header and
    cropped to rows/columns 75..436.

    Returns:
        A list of arrays, each reshaped to (1, 362, 362).
    """
    slices = []
    for level2_dir in _sorted_subdirs(case_dir):
        for level3_dir in _sorted_subdirs(level2_dir):
            for file_name in sorted(os.listdir(level3_dir)):
                # dcmread is the supported reader; pydicom.read_file is deprecated.
                pict = dcmread(os.path.join(level3_dir, file_name))
                windows = get_windowing(pict)
                img = window_image(pict.pixel_array[75:437, 75:437],
                                   windows[0], windows[1], windows[2], windows[3])
                slices.append(img.reshape(1, 362, 362))
    return slices


train_ct_scan=[]
train_label=[]

test_ct_scan=[]
test_label=[]

# Per-class bookkeeping, keyed by the class letter.
cases_used = {marker: 0 for marker, _, _, _ in CLASS_SPLITS}   # folders consumed so far
train_len = dict.fromkeys(cases_used, 0)                       # training slices loaded
test_len = dict.fromkeys(cases_used, 0)                        # test slices loaded

for file in filelist:
    case_dir = os.path.join(path, file)
    # First matching letter wins, mirroring the original A -> B -> G precedence.
    split = next((entry for entry in CLASS_SPLITS if entry[0] in file), None)
    if split is None or not os.path.isdir(case_dir):
        continue
    marker, label, n_train, n_test = split
    used = cases_used[marker]
    if used >= n_train + n_test:
        # This class already has all the folders it needs.
        continue
    slices = load_case_slices(case_dir)
    if used < n_train:
        train_ct_scan.extend(slices)
        train_label.extend([label] * len(slices))
        train_len[marker] += len(slices)
    else:
        test_ct_scan.extend(slices)
        test_label.extend([label] * len(slices))
        test_len[marker] += len(slices)
    cases_used[marker] += 1

# Flat names kept for anything that reads the original counters.
counta, countb, countg = cases_used['A'], cases_used['B'], cases_used['G']
train_len_a, test_len_a = train_len['A'], test_len['A']
train_len_b, test_len_b = train_len['B'], test_len['B']
train_len_g, test_len_g = train_len['G'], test_len['G']

print(len(train_ct_scan))
print(len(test_ct_scan))

print()
print(train_len_a)
print(test_len_a)
print()
print(train_len_b)
print(test_len_b)
print()
print()
print(train_len_g)
print(test_len_g)
print()


end = time.time()
print(end - start)

# Stack the per-slice arrays into (N, 1, 362, 362) tensors and integer label vectors.
train_ct_scan=torch.from_numpy(np.array(train_ct_scan))
train_label=torch.tensor(np.array(train_label),dtype=torch.long)
test_ct_scan=torch.from_numpy(np.array(test_ct_scan))
test_label=torch.tensor(np.array(test_label),dtype=torch.long)


In [None]:
#Model_1
class imgs(Dataset):
    """Map-style dataset pairing ``data[i]`` with ``target[i]``.

    Args:
        data: indexable collection of samples.
        target: indexable collection of labels; its length is the dataset length.
        transform: optional callable applied to each sample when it is fetched.
    """
    def __init__(self,data,target,transform=None):
        self.data=data
        self.target=target
        self.transform=transform
    def __len__(self):
        return len(self.target)
    def __getitem__(self, index):
        img = self.data[index]
        lbl = self.target[index]
        # The transform used to be stored but silently ignored.
        if self.transform is not None:
            img = self.transform(img)
        return img, lbl
    
class Net(nn.Module):
    """Small CNN: one conv / batch-norm / ReLU / max-pool stage and a linear classifier.

    Sized for inputs of shape (N, 1, 362, 362): the convolution (kernel 7,
    stride 3, padding 1) yields 120x120 maps, pooling halves them to 60x60, so
    ``fc1`` receives 60*60*3 features and produces 3 class logits.
    """
    def __init__(self):
        super().__init__()
        self.conv1=nn.Conv2d(1,3,7,3,1) # 362 -> 120
        self.bn1=nn.BatchNorm2d(3)
        self.relu1=nn.ReLU()
        self.pool=nn.MaxPool2d(2) # 120 -> 60
        self.fc1 = nn.Linear(60*60*3,3)
    def forward(self,x):
        x=self.conv1(x)
        x=self.bn1(x)
        x=self.relu1(x)
        x=self.pool(x)
        # Flatten per sample. view(-1, 60*60*3) would silently fold several
        # samples into one row when the spatial size is not the expected one;
        # flattening keeps the batch dimension so fc1 raises instead.
        x=torch.flatten(x,1)
        x=self.fc1(x)
        return x

# Device string chosen in the data-loading cell.
device=dev
torch.manual_seed(4712)

trainset=imgs(train_ct_scan,train_label)
testset=imgs(test_ct_scan,test_label)
trainload=DataLoader(trainset,4,shuffle=True)
testload=DataLoader(testset,4,shuffle=True)

modelname="Model_1"
saveepoch=0
model=Net().to(device)
optimizer = optim.Adadelta(model.parameters(), lr=0.001)
loss_function = nn.CrossEntropyLoss()
final_train_accuracy=0
final_test_accuracy=0
tg=[] #test labels
pd=[] #predict labels

# Start below any reachable accuracy so the first epoch always records
# max_matrix / savetime, which the report row at the end requires.
max_accuracy=-1.0

wholestart=time.time()
for epoch in range(40):
    start=time.time()
    tg=[]
    pd=[]
    train_correct=0
    test_correct=0

    # ---- training pass ----
    model.train()
    for data,target in trainload:
        data,target=data.to(device),target.to(device)
        optimizer.zero_grad()
        output=model(data.float())
        loss=loss_function(output,target)
        loss.backward()
        optimizer.step()
        _,prediction=torch.max(output.detach(),1)
        train_correct+=int((prediction==target).sum())
    train_accuracy=train_correct/len(train_label)

    # ---- evaluation pass ----
    model.eval()
    # No gradients are needed here; skipping autograd saves memory and time.
    with torch.no_grad():
        for data,target in testload:
            data,target=data.to(device),target.to(device)
            output=model(data.float())
            _,prediction=torch.max(output,1)
            test_correct+=int((prediction==target).sum())
            pd.extend(prediction.tolist())
            tg.extend(target.tolist())
    test_accuracy=test_correct/len(test_label)
    end=time.time()
    runtime=end-start

    # Keep the epoch with the best test accuracy seen so far.
    if test_accuracy>max_accuracy:
        max_accuracy=test_accuracy
        # Fixed label list: the matrix is always 3x3, even if a class is absent.
        max_matrix=confusion_matrix(tg,pd,labels=[0,1,2])
        saveepoch=epoch
        savetime=time.time()-wholestart
        # NOTE(review): the ResNet cell saves to the same 'net.pt' and will overwrite this file.
        torch.save(model,'net.pt')
    print('epoch '+str(epoch+1)+' train acurracy: '+str(train_accuracy)+' test accuracy: '+str(test_accuracy)+' run-time: '+str(runtime))
    final_train_accuracy=train_accuracy
    final_test_accuracy=test_accuracy

print()
print("Accuracy: ",max_accuracy)
print("Confusion matrix: ")
print(max_matrix)
note='selected'

# Report row: run metadata, then each confusion-matrix row followed by 'N/A'
# (unused 4th class column), then an all-'N/A' 4th class row, then time and note.
to_csv=[modelname,str(today),saveepoch+1,len(train_ct_scan),len(test_ct_scan),max_accuracy]
for matrix_row in max_matrix:
    to_csv.extend([matrix_row[0],matrix_row[1],matrix_row[2],'N/A'])
to_csv.extend(['N/A']*4)
to_csv.extend([savetime,note])
#print(to_csv)

In [None]:
#ResNet50
import torchvision.models as models
class ResNet(nn.Module):
    """ResNet-50 with its input convolution and final classifier replaced.

    Args:
        in_channels: number of channels of the input images (1 = grayscale).
        num_classes: number of output logits.
    """
    def __init__(self, in_channels=1, num_classes=3):
        super().__init__()
        self.model = models.resnet50(pretrained=True)
        # New stem so the network accepts `in_channels` inputs (the argument was
        # previously ignored and 1 was hard-coded). This layer is freshly
        # initialised; the pretrained first-layer filters are not reused.
        self.model.conv1 = nn.Conv2d(in_channels, 64, kernel_size=7, stride=2, padding=3, bias=False)
        num_ftrs = self.model.fc.in_features
        self.model.fc = nn.Linear(num_ftrs, num_classes)

    def forward(self, x):
        """Return the class logits for a batch ``x``."""
        return self.model(x)

In [None]:
#ResNet
# Defined here as well so this cell does not depend on the Model_1 cell having
# been run first: `device` mirrors the data-loading cell's `dev`, and `note`
# keeps the value the Model_1 cell assigns.
device=dev
note='selected'

trainset=imgs(train_ct_scan,train_label)
testset=imgs(test_ct_scan,test_label)
trainload=DataLoader(trainset,4,shuffle=True)
testload=DataLoader(testset,4,shuffle=True)

if torch.cuda.is_available():
    torch.cuda.empty_cache()
modelname="ResNet"
saveepoch=0

model=ResNet().to(device)
optimizer = optim.Adadelta(model.parameters(), lr=0.001)
loss_function = nn.CrossEntropyLoss()
final_train_accuracy=0
final_test_accuracy=0
tg=[] #test labels
pd=[] #predict labels

# Start below any reachable accuracy so the first finished epoch is always
# recorded, and reset the result holders so values left over from another
# model's run can never be reported as this model's.
max_accuracy=-1.0
max_matrix=None
savetime=None
wholestart=time.time()
try:
    for epoch in range(13):
        start=time.time()
        tg=[]
        pd=[]
        train_correct=0
        test_correct=0

        # ---- training pass ----
        model.train()
        for data,target in trainload:
            data,target=data.to(device),target.to(device)
            optimizer.zero_grad()
            output=model(data.float())
            loss=loss_function(output,target)
            loss.backward()
            optimizer.step()
            _,prediction=torch.max(output.detach(),1)
            train_correct+=int((prediction==target).sum())
        train_accuracy=train_correct/len(train_label)

        # ---- evaluation pass ----
        model.eval()
        # No gradients are needed here; skipping autograd saves GPU memory.
        with torch.no_grad():
            for data,target in testload:
                data,target=data.to(device),target.to(device)
                output=model(data.float())
                _,prediction=torch.max(output,1)
                test_correct+=int((prediction==target).sum())
                pd.extend(prediction.tolist())
                tg.extend(target.tolist())
        test_accuracy=test_correct/len(test_label)
        end=time.time()
        runtime=end-start

        # Keep the epoch with the best test accuracy seen so far.
        if test_accuracy>max_accuracy:
            saveepoch=epoch
            max_accuracy=test_accuracy
            # Fixed label list: the matrix is always 3x3, even if a class is absent.
            max_matrix=confusion_matrix(tg,pd,labels=[0,1,2])
            savetime=time.time()-wholestart
            # NOTE(review): same file name as the Model_1 cell; this overwrites its checkpoint.
            torch.save(model,'net.pt')
        print('epoch '+str(epoch+1)+' train acurracy: '+str(train_accuracy)+' test accuracy: '+str(test_accuracy)+' run-time: '+str(runtime))
        final_train_accuracy=train_accuracy
        final_test_accuracy=test_accuracy
except KeyboardInterrupt:
    # Manual stop: fall through and report the best epoch finished so far.
    print('Interrupted')
except RuntimeError as err:
    # Only an out-of-memory failure is expected here. Any other RuntimeError
    # (shape mismatch, dtype error, ...) is a real bug and must stay visible.
    if 'out of memory' not in str(err):
        raise
    print('GPU not enough')

if max_matrix is None:
    # Nothing finished, so there is no result; fail clearly instead of
    # reporting stale values or hitting a NameError below.
    raise RuntimeError('No epoch finished for '+modelname+'; there is no result to report.')

print()
print("Accuracy: ",max_accuracy)
print("Confusion matrix: ")
print(max_matrix)

# Report row: run metadata, then each confusion-matrix row followed by 'N/A'
# (unused 4th class column), then an all-'N/A' 4th class row, then time and note.
to_csv=[modelname,str(today),saveepoch+1,len(train_ct_scan),len(test_ct_scan),max_accuracy]
for matrix_row in max_matrix:
    to_csv.extend([matrix_row[0],matrix_row[1],matrix_row[2],'N/A'])
to_csv.extend(['N/A']*4)
to_csv.extend([savetime,note])

#print(to_csv)

In [None]:
#save to report

# Append the current `to_csv` row to report.csv; newline='' leaves line
# endings to the csv module.
with open('report.csv', 'a',newline='') as report:
    csv.writer(report).writerow(to_csv)