In [2]:
from torchvision import models
from torchvision.models import ResNet18_Weights
import torch.nn as nn
from torch.nn.modules.loss import BCEWithLogitsLoss
import torch.optim as optim
import torch
import torchvision
from torchvision import transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import numpy as np
import os
import cv2
import random
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, roc_auc_score, accuracy_score
import seaborn as sns
import matplotlib.pyplot as plt
from sklearn.preprocessing import label_binarize
import numpy as np
import pandas as pd
from make_dataset import DataGen
import csv 
import re

In [37]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"Training on device {device}.")


model = models.resnet18(pretrained=False)  
num_features = model.fc.in_features
model.fc = nn.Linear(num_features, 1)  


model_checkpoint_path = 'finetune_model_epoch49.pth'
model.load_state_dict(torch.load(model_checkpoint_path, map_location=device))


model = model.to(device)
model.eval()

Training on device cuda.




ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

In [18]:
test_df = list(pd.read_csv('testset_values_50epoch.csv'))

In [33]:
patient_numbers = [filename.split('_')[0] for filename in test_df]
patient_number_set = set(patient_numbers)

In [41]:

img_labels = pd.read_csv(r'C:\Users\collaborations\Downloads\per_scan_data.csv')



def find_label(img_name, img_labels):
   
    name_column = 'scan_name'

    found_row = img_labels[img_labels[name_column] == img_name]
    
    return found_row['status'].item()

labels = []
for i in range(len(test_df)):
    img_name = test_df[i]
 
    label = find_label(img_name, img_labels)
    if label == True:
        labels.append(1)
    else:
        labels.append(0)



In [42]:
len(labels)


4300

In [45]:
all_slices_path = r'\\fsmresfiles.fsm.northwestern.edu\fsmresfiles\Ophthalmology\Mirza_Images\AMD\dAMD_GA\all_slices_3'

for patient in patient_number_set:
    indices = [index for index, value in enumerate(patient_numbers) if value == patient]
    patient_X = [test_df[index] for index in indices]
    patient_y = [labels[index] for index in indices]

    test_dataset = DataGen(patient_X, patient_y, all_slices_path, image_height=1536, image_width=500)

    test_dataloader = DataLoader(test_dataset, batch_size=32)

    TP_count = 0
    TN_count = 0
    FP_count = 0
    FN_count = 0

    with torch.no_grad():
        for i, (inputs, ground) in enumerate(test_dataloader):
            inputs, ground = inputs.float().to(device), ground.to(device)
            inputs = inputs.permute(0, 3, 1, 2).to(device)
            outputs = model(inputs).squeeze()

            preds = outputs.cpu().detach().numpy()
            pred_binary = np.where(preds >= 0.5, 1, 0)

            ground = ground.float()
       
            tn, fp, fn, tp = confusion_matrix(ground.cpu().numpy(), pred_binary, labels=[0,1]).ravel()
            TP_count += tp
            TN_count += tn
            FP_count += fp
            FN_count += fn
    print('for patient number:', patient)
    print("******* TP: ", TP_count, " TN: ", TN_count, " FP: ", FP_count, " FN: ", FN_count, " *******")

for patient number: 321
******* TP:  54  TN:  82  FP:  11  FN:  56  *******
for patient number: 047
******* TP:  27  TN:  152  FP:  15  FN:  8  *******
for patient number: 064
******* TP:  542  TN:  1599  FP:  9  FN:  160  *******
for patient number: 190
******* TP:  23  TN:  143  FP:  7  FN:  21  *******
for patient number: 326
******* TP:  9  TN:  275  FP:  8  FN:  17  *******
for patient number: 341
******* TP:  48  TN:  50  FP:  0  FN:  14  *******
for patient number: 578
******* TP:  207  TN:  131  FP:  14  FN:  57  *******
for patient number: 345
******* TP:  21  TN:  481  FP:  11  FN:  48  *******
