In [None]:
import datetime
import json
import os

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import torchvision.transforms.functional as F
from PIL import Image
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader, Dataset
from torchvision.datasets import ImageFolder
from torchvision.models import resnet152, ResNet152_Weights
from tqdm import tqdm

from trainer import train

In [None]:
# Timestamp of this run (minute resolution); later cells use it to name output files and folders.
date_time = f"{datetime.datetime.now():%Y-%m-%d_%H-%M}"

In [None]:
# Build a ResNet-152 initialised with torchvision's default pretrained weights.
resnet_model = resnet152(
    weights=ResNet152_Weights.DEFAULT,
)


In [None]:
class RectAngularPadTransform(torch.nn.Module):
    """Pad an image along its shorter side so that the result is square.

    The padding is split as evenly as possible between the two opposite
    edges; when the width/height difference is odd the extra pixel goes to
    the right or bottom edge, so the output is always exactly square.
    Images that are already square are returned with zero padding.
    """

    def __init__(self):
        super().__init__()

    def forward(self, img):
        """Return ``img`` padded to a square.

        Args:
            img: Image whose ``size`` attribute is a ``(width, height)`` pair.
                NOTE(review): this matches PIL images; tensors expose ``size``
                as a method and are not supported here -- confirm callers only
                pass PIL images.

        Returns:
            The result of ``F.pad`` applied to ``img`` (fill value and padding
            mode are left at ``F.pad``'s defaults).
        """
        width, height = img.size
        # Only the shorter dimension receives padding; the other total is 0.
        pad_horizontal = max(0, height - width)
        pad_vertical = max(0, width - height)
        left = pad_horizontal // 2
        top = pad_vertical // 2
        # Give the remainder to right/bottom so an odd difference does not
        # leave the image one pixel short of square.
        padding = [left, top, pad_horizontal - left, pad_vertical - top]
        return F.pad(img, padding)

    def __repr__(self):
        return self.__class__.__name__

In [None]:

# Preprocessing for the ResNet: pad to a square, resize to 224x224,
# convert to a tensor and normalise each channel.
transform = transforms.Compose([
    RectAngularPadTransform(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Load the image dataset and keep the per-sample class indices for stratification.
dataset_path = 'datasets/2023-12-08-16-54'
dataset = ImageFolder(root=dataset_path, transform=transform)
targets = dataset.targets

# Stratified 70/20/10 split into train/test/validation indices.
# Step 1: hold out 20% of all samples for testing.
train_val_idx, test_idx = train_test_split(
    np.arange(len(targets)),
    test_size=0.2,
    shuffle=True,
    stratify=targets,
    random_state=42,
)
train_val_idx_list = train_val_idx.tolist()
# Labels of the remaining samples, in the same order as train_val_idx.
train_val_stratifier = np.take(targets, train_val_idx_list)
# Step 2: 12.5% of the remaining 80% (= 10% of the whole dataset) becomes validation.
train_idx, validation_idx = train_test_split(
    train_val_idx,
    test_size=0.125,
    shuffle=True,
    stratify=train_val_stratifier,
    random_state=42,
)

# Replace the classification head with one output per dataset class.
# The input width is read from the existing head instead of being hard-coded.
resnet_model.fc = nn.Linear(resnet_model.fc.in_features, len(dataset.classes))

# All three loaders share the same dataset object; each sampler restricts it to one split.
batch_size = 64

train_sampler = torch.utils.data.SubsetRandomSampler(train_idx)
validation_sampler = torch.utils.data.SubsetRandomSampler(validation_idx)
test_sampler = torch.utils.data.SubsetRandomSampler(test_idx)

test_loader = DataLoader(dataset, batch_size=batch_size, sampler=test_sampler)
train_loader = DataLoader(dataset, batch_size=batch_size, sampler=train_sampler)
validation_loader = DataLoader(dataset, batch_size=batch_size, sampler=validation_sampler)

In [None]:
# Show the model's module structure as text.
print(resnet_model)

In [None]:
# Sanity check: the train index array and the train loader's sampler should report the same size.
print("Number of train samples: ", len(train_idx), len(train_loader.sampler))

In [None]:
def plot_data_loader_class_distribution(loader: DataLoader, title: str, class_names=None):
    """Draw a bar chart of how many samples of each class ``loader`` yields.

    The chart is drawn on the current matplotlib figure; the caller is
    responsible for saving and/or showing it.

    Args:
        loader: Iterable of ``(inputs, labels)`` batches, where ``labels``
            provides ``tolist()`` returning integer class indices (as a
            1-D label tensor does).
        title: Title of the plot.
        class_names: Optional sequence of class names, indexed by class
            index. Defaults to ``dataset.classes`` of the module-level
            ``dataset``.

    Returns:
        The object returned by ``plt.bar``.

    Raises:
        KeyError: If a label is outside ``range(len(class_names))``.
    """
    if class_names is None:
        class_names = dataset.classes
    class_names = list(class_names)

    # Start every class at zero so classes absent from this split still get a bar.
    class_counts = {class_idx: 0 for class_idx in range(len(class_names))}

    # Count the labels of the loader that was passed in (not a global loader).
    for _, labels in loader:
        for class_idx in labels.tolist():
            class_counts[class_idx] += 1

    counts = [class_counts[class_idx] for class_idx in range(len(class_names))]

    plot = plt.bar(class_names, counts)
    plt.xlabel('Classes')
    plt.ylabel('Count')
    plt.title(title)
    plt.xticks(rotation=45, ha="right")  # Rotate x-axis labels for better readability
    return plot

In [None]:
# Set to True to (re)generate the class-distribution plots for every split.
PLOT_CLASS_DISTRIBUTIONS = False

if PLOT_CLASS_DISTRIBUTIONS:
    plot_dir = f'dataset_plots/{date_time}'
    # The per-run folder must exist before savefig writes into it.
    os.makedirs(plot_dir, exist_ok=True)

    split_plots = (
        (train_loader, 'Train Data Class Distribution', 'train_data.png'),
        (validation_loader, 'Validation Data Class Distribution', 'validation_data.png'),
        (test_loader, 'Test Data Class Distribution', 'test_data.png'),
    )
    for split_loader, split_title, file_name in split_plots:
        plot_data_loader_class_distribution(split_loader, split_title)
        # Save before show(): the figure is saved while it is still the current one.
        plt.savefig(f'{plot_dir}/{file_name}')
        plt.show()

In [None]:
# Plot one example image without and with the preprocessing transform.
# Path to the example image.
path = 'datasets/2023-12-07-12-57/age-related macular degeneration/43_left.jpg'

# cv2.imread returns None instead of raising when the file cannot be read,
# so fail here with a clear message rather than inside cvtColor.
img = cv2.imread(path)
if img is None:
    raise FileNotFoundError(f'Could not read example image: {path}')
# OpenCV loads images in BGR channel order; convert to RGB for matplotlib.
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
print(img_rgb.shape)
plt.imshow(img_rgb)
plt.show()

# Run the same image through the dataset transform (the transform indexes
# `img.size`, so it is given a PIL image rather than the OpenCV array).
pil_img = Image.open(path).convert('RGB')
transformed_img = transform(pil_img)
# The tensor is channels-first; imshow needs channels-last.
# The values are normalised, so the colours shown are not the original ones.
plt.imshow(transformed_img.permute(1, 2, 0))
plt.show()

In [None]:
# Number of classes in the loaded dataset.
print(len(dataset.classes))

In [None]:
# Training configuration.
model = resnet_model
n_classes = len(dataset.classes)
# Use the batch size of the loader that is actually passed to train(),
# so the value cannot drift from the data-loading cell.
batch_size = train_loader.batch_size
epochs = 50
lr = 0.001
best_weights_save_path = f'models/{date_time}_{resnet_model.__class__.__name__}.pth'
# Make sure the output folder exists before anything is written below it.
os.makedirs(os.path.dirname(best_weights_save_path), exist_ok=True)

# Train the model. `train` comes from the project-local `trainer` module;
# NOTE(review): its return value is treated below as a JSON-serialisable
# history dict with a ['validation']['accuracy'] list -- confirm against trainer.py.
model_history_dict = train(
    model=model,
    n_classes=n_classes,
    batch_size=batch_size,
    epochs=epochs,
    lr=lr,
    dataset_path=dataset_path,
    best_weights_save_path=best_weights_save_path,
    train_loader=train_loader,
    validation_loader=validation_loader,
)
print(model_history_dict)

In [None]:
# Write the training history as JSON to "<weights path>.json".
history_path = f'{best_weights_save_path}.json'
with open(history_path, 'w') as history_file:
    json.dump(model_history_dict, history_file)

In [None]:
# Report the highest validation accuracy recorded in the training history.
validation_accuracies = model_history_dict['validation']['accuracy']
best_val_accuracy = max(validation_accuracies)
print(f'Best validation accuracy: {best_val_accuracy}', 'n_classes', n_classes)