In [1]:
import torch
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import sys
import os
from AutoEncoderCNN import AE_CNN
from GridSearch import GridSearch

In [2]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.set_default_device(device)

In [3]:
PATH = '/groups/francescavitali/eb2/subImages_slide299/H&E' # has 506 images
BATCH_SIZE = 4

tensor_transform = transforms.ToTensor()

dataset = datasets.ImageFolder(PATH, 
                               transform = tensor_transform) #loads the images

train_set, val_set, test_set = torch.utils.data.random_split(dataset,
                                                           [404,51,51],# 80%, 10%, 10%
                                                           generator=torch.Generator(device=device))

loader = torch.utils.data.DataLoader(dataset = train_set,
                                            batch_size = BATCH_SIZE,
                                            shuffle = True,
                                            generator=torch.Generator(device=device))

In [4]:
len(dataset.targets)

506

In [5]:
model = AE_CNN(64,128).to(device)

model.load_state_dict(torch.load('./models/model_gs.pth')) # loading best model state
model.load_state_dict(torch.load('./models/Copy Models/model_gs_3-28-2024.pth'))

# setting the encoder and decoder for visualization
encoder = model.encoder
decoder = model.decoder

In [6]:
count = 0
with torch.no_grad():
    for(img, goal) in loader: # goal will be a tensor of len == batch_size
        if count == 3:
            break
        img = img.to(device)
        print(f'Img Shape: {img.shape}')
        encoded_img = encoder(img)
        print(f'Encoded Shape: {encoded_img.shape}')
        flattened = encoded_img.flatten(start_dim = 1)
        print(f'Flattened: {flattened.shape}')
        print(f'Goal: {goal.unsqueeze(1)}')
        print()
        count += 1

Img Shape: torch.Size([4, 3, 299, 299])
Encoded Shape: torch.Size([4, 128, 21, 21])
Flattened: torch.Size([4, 56448])
Goal: tensor([[0],
        [1],
        [1],
        [0]], device='cuda:0')

Img Shape: torch.Size([4, 3, 299, 299])
Encoded Shape: torch.Size([4, 128, 21, 21])
Flattened: torch.Size([4, 56448])
Goal: tensor([[1],
        [1],
        [1],
        [0]], device='cuda:0')

Img Shape: torch.Size([4, 3, 299, 299])
Encoded Shape: torch.Size([4, 128, 21, 21])
Flattened: torch.Size([4, 56448])
Goal: tensor([[1],
        [1],
        [0],
        [0]], device='cuda:0')



In [7]:
class NeuralNet(torch.nn.Module):
    def __init__(self, encoder):
        super().__init__()
        
        self.encoder = encoder
        self._feed_forward = torch.nn.Sequential(
            torch.nn.Linear(56448, 2048),
            torch.nn.ReLU(),
            torch.nn.Linear(2048, 512),
            torch.nn.ReLU(),
            torch.nn.Linear(512, 128),
            torch.nn.ReLU(),
            torch.nn.Linear(128, 32),
            torch.nn.ReLU(),
            torch.nn.Linear(32, 1),
            torch.nn.Sigmoid()
        )
        
    def encoded_without_training(self, x):
        with torch.no_grad():
            encoded = self.encoder(x)
            flattened = encoded.flatten(start_dim = 1)
            return flattened
    
    def forward(self,  x):
        #encoded = self.encoder(x) # this will update the encoder weights
        #flattened = encoded.flatten(start_dim = 1)
        flattened = self.encoded_without_training(x)
        output = self._feed_forward(flattened)
        return output
    

In [8]:
nn = NeuralNet(encoder).to(device)

In [9]:
lr = 0.00001
weight_decay = 1e-5
EPOCHS = 1000

verbose = 1

In [10]:
optimizer = torch.optim.Adam(nn.parameters(), lr = lr, weight_decay = weight_decay)
loss_function = torch.nn.BCELoss()

loss_arr = []
min_loss = None
outputs = []
early_stop = False
early_stop_depth = 20

for epoch in range(EPOCHS):

    if early_stop:
        if verbose != 0:
            print(f'\n\n------EARLY STOP {min_loss}------\n\n')
        break

    count = 0

    nn.train()
    for (image, label) in loader:
        image = image.to(device)
        #image = image.flatten(start_dim=1) # ignore the batch_size

        output = nn(image)
        loss = loss_function(output, goal.unsqueeze(1).float())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()


        # UI
        if verbose == 2:
            sys.stdout.write('\r')
            sys.stdout.write("Epoch: {} [{:{}}] {:.1f}% | Loss: {}".format(epoch+1, "="*count, 
                                                                       len(loader)-1, 
                                                                       (100/(len(loader)-1)*count), 
                                                                       loss.item()))
            sys.stdout.flush()

        count += 1

    loss_arr.append(loss.item())
    if not min_loss:
        min_loss = loss_arr[0]
    if early_stop_depth >= 1 and early_stop_depth < len(loss_arr[loss_arr.index(min_loss):]):
        early_stop = True
        for loss_item in loss_arr[loss_arr.index(min_loss):]:
            if loss_item < min_loss:
                min_loss = loss_item
                early_stop = False


    if verbose != 0:
        print(f'\nEpoch: {epoch + 1} | Loss: {loss.item():.4f}', end='\n'*2)


Epoch: 1 | Loss: 0.6693


Epoch: 2 | Loss: 0.6533


Epoch: 3 | Loss: 0.7175


Epoch: 4 | Loss: 0.6963


Epoch: 5 | Loss: 0.6203


Epoch: 6 | Loss: 0.6887


Epoch: 7 | Loss: 0.6716


Epoch: 8 | Loss: 0.7106


Epoch: 9 | Loss: 0.7856


Epoch: 10 | Loss: 0.7076


Epoch: 11 | Loss: 0.7316


Epoch: 12 | Loss: 0.7318


Epoch: 13 | Loss: 0.7334


Epoch: 14 | Loss: 0.7113


Epoch: 15 | Loss: 0.6679


Epoch: 16 | Loss: 0.7722


Epoch: 17 | Loss: 0.6411


Epoch: 18 | Loss: 0.7070


Epoch: 19 | Loss: 0.6867


Epoch: 20 | Loss: 0.6903


Epoch: 21 | Loss: 0.6777


Epoch: 22 | Loss: 0.7346


Epoch: 23 | Loss: 0.6774


Epoch: 24 | Loss: 0.6641


Epoch: 25 | Loss: 0.7162



------EARLY STOP 0.6203462481498718------




In [11]:
ans = []
nn.eval()
for x in range(len(val_set)):
    with torch.no_grad():
        inp = val_set.__getitem__(x)[0]
        exp = val_set.__getitem__(x)[1]
        pred = nn.cpu()(img).flatten() # why output as a tensor of shape 2 x 4?
        ans.append((exp, float(min(pred))))