In [1]:
import torch
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import sys
import os
from AutoEncoderCNN import AE_CNN
from util.random_patient import random_split
from util.ImageFolderWithPaths import ImageFolderWithPaths
from util.loader_info import get_totals

In [2]:
# Pick the compute device: the GPU when CUDA is available, otherwise the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Tensors and modules created without an explicit device now default to `device`.
torch.set_default_device(device)

# Return cached-but-unused CUDA memory to the driver (no effect if CUDA was never initialised).
torch.cuda.empty_cache()

In [3]:
# Root folder the image dataset is read from (images are 225 x 225 x 3).
MAIN_PATH = '/home/u6/njcrutchfield/torch/NN/pill_data/pillQC-main/'

# Extra locations handed to `random_split`; presumably used to group images
# by patient before splitting — TODO confirm against util.random_patient.
PATH1 = '/groups/francescavitali/eb2/NewsubSubImages4/H&E/'
PATH2 = '/groups/francescavitali/eb2/NewsubSubImages4/H&E/S'

# Fractions for the train / validation / test partitions.
SPLIT = [0.8, 0.1, 0.1]

# Each sample is converted to a tensor; ImageFolderWithPaths also yields the file path.
tensor_transform = transforms.ToTensor()
dataset = ImageFolderWithPaths(MAIN_PATH, transform=tensor_transform)

train_set, val_set, test_set = random_split(
    PATH1,
    PATH2,
    dataset,
    split_percent=SPLIT,
    rand_seed=8,
)


Finished getting labels
0
5000
10000
15000
20000
25000
30000
35000
40000
45000
50000
55000
60000
65000
Returning


In [4]:
# Training batches of 4, reshuffled every epoch. The shuffling permutation is drawn
# from this generator, which is created on `device` — presumably so it agrees with
# the default device set via torch.set_default_device above.
train_loader = torch.utils.data.DataLoader(
    train_set,
    batch_size=4,
    shuffle=True,
    generator=torch.Generator(device=device),
)

In [5]:
# Evaluation loader: one image per batch, so each step of the scoring loop handles
# a single image. Shuffling is disabled — evaluation does not need it, and a fixed
# order makes the per-image results reproducible from run to run.
# The generator stays on `device`, matching the training loader.
test_loader = torch.utils.data.DataLoader(dataset = test_set,
                                            batch_size = 1,
                                            shuffle = False,
                                            generator=torch.Generator(device=device))

In [7]:
# Class balance of the test split; presumably maps class label -> number of images.
label_counts = get_totals(test_loader)
print(label_counts)

{0: 1643, 1: 4988}


In [8]:
# Rebuild the autoencoder and restore its best saved weights.
model = AE_CNN().to(device)

# map_location lets a checkpoint saved on a GPU load on a CPU-only machine;
# weights_only=True restricts unpickling to tensors/containers (no arbitrary objects).
state_dict = torch.load('./models/model_gs.pth', map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.eval()  # inference only from here on

# Only the encoder half is needed: its output is the classifier's input.
encoder = model.encoder

In [9]:
# Show the encoder's layer stack. Its final Linear layer emits 28314 features,
# which is the input width the NeuralNet classifier below is built for.
encoder

Sequential(
  (0): Conv2d(3, 64, kernel_size=(8, 8), stride=(2, 2), padding=(1, 1))
  (1): ReLU()
  (2): Conv2d(64, 32, kernel_size=(8, 8), stride=(2, 2), padding=(1, 1))
  (3): ReLU()
  (4): Conv2d(32, 26, kernel_size=(8, 8), stride=(2, 2), padding=(1, 1))
  (5): ReLU()
  (6): Flatten(start_dim=1, end_dim=-1)
  (7): Linear(in_features=28314, out_features=28314, bias=True)
)

In [10]:
class NeuralNet(torch.nn.Module):
    """Fully connected binary classifier applied to the flattened encoder output.

    The network is a stack of Linear layers separated by ReLU, ending in a single
    logit; ``forward`` passes that logit through a sigmoid, so outputs lie in (0, 1).

    Args:
        in_features: width of the input vector. The default, 28314, is the
            ``out_features`` of the encoder's final Linear layer.
        hidden_sizes: widths of the hidden Linear layers, in order.

    Linear layers sit at the even indices of ``_feed_forward`` (0, 2, ..., 10 for
    the defaults) with ReLU in between, so state-dict keys for the default
    configuration are ``_feed_forward.{0,2,4,6,8,10}.{weight,bias}``.
    """

    def __init__(self, in_features=28314, hidden_sizes=(16384, 4096, 1024, 256, 64)):
        super().__init__()

        widths = [in_features, *hidden_sizes, 1]
        n_linear = len(widths) - 1
        layers = []
        for i, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(torch.nn.Linear(n_in, n_out))
            # ReLU between Linear layers, but not after the final logit layer.
            if i < n_linear - 1:
                layers.append(torch.nn.ReLU())

        self._feed_forward = torch.nn.Sequential(*layers)
        self._sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        """Map x of shape (..., in_features) to probabilities of shape (..., 1)."""
        output = self._feed_forward(x)
        return self._sigmoid(output)

In [11]:
# Build the classifier head and restore its trained weights for inference.
# NOTE(review): `nn` is the conventional alias for torch.nn; the name is kept because
# later cells use it, but something like `classifier` would read more clearly.
# map_location lets a GPU-saved checkpoint load on a CPU-only machine;
# weights_only=True restricts unpickling to tensors/containers.
nn = NeuralNet().to(device)
nn.eval()
nn.load_state_dict(torch.load('./ClassifierModels/big_model.pth', map_location=device, weights_only=True))

<All keys matched successfully>

In [12]:
# Score every test image: encode it, run the classifier, and threshold the sigmoid
# probability at 0.5. `ans` collects one (true_label, predicted_label, (path,))
# record per image; the path stays wrapped in a 1-tuple, the same shape the default
# collate gives `fname` for batch_size=1, so later `ans[i][2][0]` lookups still work.
ans = []
total_samples = 0
total_correct = 0
nn.to(device)
nn.eval()
encoder.eval()
with torch.no_grad():
    for image, label, fname in test_loader:
        image = image.to(device)
        label = label.to(device).view(-1)

        encoded = encoder(image)
        # nn(...) is sigmoid(_feed_forward(...)); flatten to one probability per image.
        probs = nn(encoded).view(-1)
        # probability < 0.5 -> class 0, otherwise class 1 (works for any batch size).
        preds = 1 - (probs < 0.5).long()

        total_samples += label.numel()
        total_correct += (preds == label).sum().item()
        for true_label, pred, path in zip(label.tolist(), preds.tolist(), fname):
            ans.append((true_label, pred, (path,)))

# Divide by the number of images, not the number of batches.
accuracy = total_correct / total_samples * 100 if total_samples else float('nan')
print(f'Accuracy: {accuracy:.2f}%')

  return func(*args, **kwargs)


Accuracy: 28.35%


In [13]:
# Per-patient tally built from the per-image records in `ans`:
#   d[patient] = [true label, number of images predicted 1, number of images]
# The true label is taken from the patient's first image encountered.
d = {}
for true_label, pred, paths in ans:
    # NOTE(review): the patient id is element 7 of path.split('/'); this depends on
    # the absolute directory layout of the data — confirm it if the files move.
    patient = paths[0].split('/')[7]
    tally = d.setdefault(patient, [true_label, 0, 0])
    tally[1] += pred
    tally[2] += 1
        

In [14]:
# Patient-level accuracy as a function of the decision threshold: a patient is
# called positive (1) when MORE than `thresh` of their images were predicted 1.
# NOTE(review): this compares a raw count, so patients with more test images cross
# a given threshold more easily — a fraction (ones / images) may be a fairer rule.
max_thresh = 55
n_patients = len(d)
for thresh in range(1, max_thresh):
    correct = sum(
        int(ones > thresh) == goal
        for goal, ones, _n_images in d.values()
    )
    print(f'Accuracy with {thresh=}: {correct/n_patients*100:.2f}%')

Accuracy with thresh=1: 62.50%
Accuracy with thresh=2: 62.50%
Accuracy with thresh=3: 62.50%
Accuracy with thresh=4: 62.50%
Accuracy with thresh=5: 62.50%
Accuracy with thresh=6: 50.00%
Accuracy with thresh=7: 50.00%
Accuracy with thresh=8: 50.00%
Accuracy with thresh=9: 50.00%
Accuracy with thresh=10: 50.00%
Accuracy with thresh=11: 50.00%
Accuracy with thresh=12: 50.00%
Accuracy with thresh=13: 50.00%
Accuracy with thresh=14: 50.00%
Accuracy with thresh=15: 62.50%
Accuracy with thresh=16: 62.50%
Accuracy with thresh=17: 62.50%
Accuracy with thresh=18: 62.50%
Accuracy with thresh=19: 62.50%
Accuracy with thresh=20: 62.50%
Accuracy with thresh=21: 62.50%
Accuracy with thresh=22: 62.50%
Accuracy with thresh=23: 62.50%
Accuracy with thresh=24: 62.50%
Accuracy with thresh=25: 62.50%
Accuracy with thresh=26: 62.50%
Accuracy with thresh=27: 62.50%
Accuracy with thresh=28: 62.50%
Accuracy with thresh=29: 62.50%
Accuracy with thresh=30: 62.50%
Accuracy with thresh=31: 75.00%
Accuracy with thr