In [1]:
import torch
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import sys
import os
from AutoEncoderCNN import AE_CNN
from GridSearch import GridSearch

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
#device = torch.device('cpu')
torch.set_default_device(device)

In [3]:

torch.cuda.empty_cache()

In [4]:
PATH = '/groups/francescavitali/eb2/NewsubSubImages4/H&E' # /groups/francescavitali/eb2/subImages_slide299/H&E
BATCH_SIZE = 4 
SPLIT = [55767, 6971, 6971]

tensor_transform = transforms.ToTensor()

dataset = datasets.ImageFolder(PATH, 
                               transform = tensor_transform) #loads the images

train_set, val_set, test_set = torch.utils.data.random_split(dataset,
                                                           SPLIT,# 80%, 10%, 10%
                                                           generator=torch.Generator(device=device))

loader = torch.utils.data.DataLoader(dataset = train_set,
                                            batch_size = BATCH_SIZE,
                                            shuffle = True,
                                            generator=torch.Generator(device=device))

In [5]:
len(dataset.targets)

69709

In [6]:
model = AE_CNN().to(device)

model.load_state_dict(torch.load('./models/model_gs.pth')) # loading best model state
#model.load_state_dict(torch.load('./models/Copy Models/model_gs_3-28-2024.pth'))

# setting the encoder and decoder for visualization
encoder = model.encoder

In [7]:
count = 0
with torch.no_grad():
    for(img, goal) in loader: # goal will be a tensor of len == batch_size
        if count == 1:
            break
        img = img.to(device)
        print(f'Img Shape: {img.shape}')
        encoded_img = encoder(img)
        print(f'Encoded Shape: {encoded_img.shape}')
        flattened = encoded_img.flatten(start_dim = 1)
        print(f'Flattened: {flattened.shape}')
        print(f'Goal: {goal.unsqueeze(1)}')
        print()
        count += 1

Img Shape: torch.Size([4, 3, 299, 299])
Encoded Shape: torch.Size([4, 28314])
Flattened: torch.Size([4, 28314])
Goal: tensor([[0],
        [1],
        [1],
        [0]], device='cuda:0')



  return func(*args, **kwargs)


In [8]:
encoder

Sequential(
  (0): Conv2d(3, 64, kernel_size=(8, 8), stride=(2, 2), padding=(1, 1))
  (1): ReLU()
  (2): Conv2d(64, 32, kernel_size=(8, 8), stride=(2, 2), padding=(1, 1))
  (3): ReLU()
  (4): Conv2d(32, 26, kernel_size=(8, 8), stride=(2, 2), padding=(1, 1))
  (5): ReLU()
  (6): Flatten(start_dim=1, end_dim=-1)
  (7): Linear(in_features=28314, out_features=28314, bias=True)
)

In [9]:
class NeuralNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        
        self._feed_forward = torch.nn.Sequential(
            torch.nn.Linear(28314, 16384),
            torch.nn.ReLU(),
            torch.nn.Linear(16384, 4096),
            torch.nn.ReLU(),
            torch.nn.Linear(4096, 1024),
            torch.nn.ReLU(),
            torch.nn.Linear(1024, 256),
            torch.nn.ReLU(),
            torch.nn.Linear(256, 64),
            torch.nn.ReLU(),
            torch.nn.Linear(64, 1),
        )
        self._sigmoid = torch.nn.Sigmoid()

        
    
    def forward(self,  x):
        output = self._feed_forward(flattened)
        return self._sigmoid(output)
    

In [10]:
nn = NeuralNet().to(device)

In [11]:
lr = 0.0000001 
weight_decay = 1e-5
EPOCHS = 30

verbose = 1

In [12]:
val_loader = torch.utils.data.DataLoader(dataset = val_set,
                                            batch_size = 1,
                                            shuffle = True,
                                            generator=torch.Generator(device=device))
test_loader = torch.utils.data.DataLoader(dataset = test_set,
                                            batch_size = 1,
                                            shuffle = True,
                                            generator=torch.Generator(device=device))

In [13]:
c = 0
for (image_val, label_val) in val_loader:
    if(label_val == 1):
        c += 1
print(f'Val loader P(1): {c/len(val_loader)}')

c = 0
for (image_test, label_test) in test_loader:
    if(label_test == 1):
        c += 1
print(f'Test loader P(1): {c/len(test_loader)}')

Val loader P(1): 0.38975756706354897
Test loader P(1): 0.3854540238129393


In [None]:
optimizer = torch.optim.Adam(nn.parameters(), lr = lr, weight_decay = weight_decay)
loss_function = torch.nn.MSELoss()

loss_arr = []
acc_arr = []
min_loss = None
min_acc = 0
outputs = []
early_stop = False
early_stop_depth = 20
encoder.eval()

for epoch in range(EPOCHS):

    total_correct = 0
    total_samples = 0
    
    if early_stop:
        if verbose != 0:
            print(f'\n\n------EARLY STOP {min_loss}------\n\n')
        break

    count = 0

    nn.train()
    nn.to(device)
    for (image, label) in loader:
        image = image.to(device)
        label = label.to(device)
        
        x = encoder(image) # pretrained compressed image
        
        output = nn(x) # new model output
    
        loss = loss_function(output, goal.unsqueeze(1).float())
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        # UI
        if verbose == 2:
            sys.stdout.write('\r')
            sys.stdout.write("Epoch: {} [{:{}}] {:.1f}% | Loss: {}".format(epoch+1, "="*count, 
                                                                       len(loader)-1, 
                                                                       (100/(len(loader)-1)*count), 
                                                                       loss.item()))
            sys.stdout.flush()

        count += 1
        
    loss_arr.append(loss.item())
    
    # Get the accuracy of the current state
    nn.eval()
    total_correct= 0
    for (image_val, label_val) in val_loader:
        image_val = image_val.to(device)
        label_val = label_val.to(device)

        with torch.no_grad():

            x2 = encoder(image_val)
            o = nn._sigmoid(nn._feed_forward(x2))
            val_outputs = torch.round(o)
            total_samples += 1
            
            total_correct += (val_outputs == label_val).item()
        
    

    accuracy = total_correct/len(val_loader)*100
    if verbose != 0:
        print(f'\nEpoch: {epoch + 1} | Loss: {loss.item():.4f} | Val Accuracy: {accuracy:.2f}%', end='\n'*2)


Epoch: 1 | Loss: 0.0000 | Val Accuracy: 60.45%


Epoch: 2 | Loss: 0.0000 | Val Accuracy: 60.55%


Epoch: 3 | Loss: 0.0000 | Val Accuracy: 60.59%


Epoch: 4 | Loss: 0.0000 | Val Accuracy: 60.61%


Epoch: 5 | Loss: 0.0000 | Val Accuracy: 60.59%


Epoch: 6 | Loss: 0.0000 | Val Accuracy: 60.65%


Epoch: 7 | Loss: 0.0000 | Val Accuracy: 60.65%


Epoch: 8 | Loss: 0.0000 | Val Accuracy: 60.67%


Epoch: 9 | Loss: 0.0000 | Val Accuracy: 60.64%


Epoch: 10 | Loss: 0.0000 | Val Accuracy: 60.64%


Epoch: 11 | Loss: 0.0000 | Val Accuracy: 60.62%


Epoch: 12 | Loss: 0.0000 | Val Accuracy: 60.64%


Epoch: 13 | Loss: 0.0000 | Val Accuracy: 60.61%


Epoch: 14 | Loss: 0.0000 | Val Accuracy: 60.64%


Epoch: 15 | Loss: 0.0000 | Val Accuracy: 60.65%



In [17]:
torch.save(nn.state_dict(), f'./ClassifierModels/class_model_gs.pth')

In [19]:
print("saved")

saved


In [11]:
nn = NeuralNet().to(device)
nn.load_state_dict(torch.load('./ClassifierModels/class_model_gs.pth'))

<All keys matched successfully>

In [22]:
#device = torch.device('cpu')
ans = []
total_samples = 0
total_correct = 0
nn.to(device)
nn.eval()
for (image, label) in test_loader:
    nn.eval()
    image = image.to(device)
    label = label.to(device)
    
    with torch.no_grad():
        encoded = encoder(image)
        not_rounded = nn._feed_forward(encoded)
        
        outputs = nn._sigmoid(not_rounded)
        if outputs < 0.5: # 0.779
            outputs = 0
        else:
            outputs = 1

        total_samples += 1
        total_correct += (outputs == label).item()
        ans.append((label.item(), outputs))

        
print(f'Accuracy: {total_correct/len(test_loader)*100:.2f}%')


Accuracy: 61.24%


In [23]:
total0, total1, correct0, correct1 = 0, 0, 0, 0
for i in range(len(ans)):
    if ans[i][0] == 0:
        total0 += 1
        if ans[i][1] == 0:
            correct0 += 1
    else:
        total1 += 1
        if ans[i][1] == 1:
            correct1 += 1

print(total0 + total1 == len(val_loader))

True


In [27]:
print(f'0 stats: {total0=} | {correct0=} | {100*(correct0/total0):.2f}%')
print(f'1 stats: {total1=} | {correct1=} | {100*(correct1/total1):.2f}%')
print(f'Total stats: {100*((correct0 + correct1)/(total0 + total1)):.2f}%')
print(f'Model guessed 0: {correct0 + total1- correct1}, getting {correct0} correct')
print(f'Model guessed 1: {correct1 + total0- correct0}, getting {correct1} correct')

0 stats: total0=4284 | correct0=4110 | 95.94%
1 stats: total1=2687 | correct1=159 | 5.92%
Total stats: 61.24%
Model guessed 0: 6638, getting 4110 correct
Model guessed 1: 333, getting 159 correct


In [None]:
# look into the 159 correct , cos distance for checking similarity between all 0's and 1's maybe 90%?