In [1]:
import os
import torch
from torch import nn, optim
from torchvision import datasets, transforms, models
from torchvision.transforms import CenterCrop
from torch.utils.data import DataLoader, random_split
from torch.utils.data import Subset
from PIL import Image
import numpy as np
from tqdm import tqdm
import pandas as pd
import random
from glob import glob
import optuna
import json
import plotly

# Prefer the GPU, but fall back to the CPU so the script still runs (slowly)
# on machines without CUDA instead of failing at the first `.to(device)`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Function to calculate accuracy
def binary_accuracy(preds, y):
    """Return the fraction of thresholded predictions that equal the targets.

    Args:
        preds: Tensor of raw logits; a sigmoid is applied and the result is
            rounded with ``torch.round`` to obtain 0/1 predictions.
        y: Tensor of targets compared element-wise against the rounded
            predictions.

    Returns:
        A 0-d float tensor holding the share of matching elements.
    """
    rounded_preds = torch.round(torch.sigmoid(preds))
    correct = (rounded_preds == y).float()
    # numel() rather than len(): len() raises on a 0-d tensor (what
    # `.squeeze()` produces for a single-sample batch) and would count only
    # the first dimension of a multi-dimensional input.
    return correct.sum() / correct.numel()

class RandomGaussianBlur:
    """Transform that blurs its input with a fixed probability.

    On each call a number is drawn from ``random.random()``; when it is below
    ``probability`` the wrapped ``transforms.GaussianBlur`` is applied,
    otherwise the input is returned unchanged.
    """

    def __init__(self, kernel_size=3, probability=0.5):
        # Kernel size handed to the underlying GaussianBlur transform.
        self.kernel_size = kernel_size
        # Threshold compared against random.random() on every call.
        self.probability = probability
        # Build the blur once so it is reused across calls.
        self.gaussian_blur = transforms.GaussianBlur(kernel_size)

    def __call__(self, img):
        apply_blur = random.random() < self.probability
        return self.gaussian_blur(img) if apply_blur else img

# Switch the working directory to the data release so that the relative
# paths used in this file ('./data/...', './results/...') resolve against it.
# NOTE(review): this is a machine-specific absolute path -- confirm it exists
# on the machine running this notebook.
os.chdir('/home/kdoherty/spurge/data_release')

# Image folder handed to datasets.ImageFolder inside objective().
train_dir = './data/crop_39/train'

def objective(trial):
    """Score one augmentation configuration for the Optuna study.

    Samples the normalisation statistics and augmentation settings from
    ``trial``. For each of 8 seeds it then fine-tunes an ImageNet-pretrained
    ResNet-50 (single output logit) on a 128-image subset of ``train_dir`` and
    evaluates it after every epoch on a disjoint subset of up to 128 images.

    Args:
        trial: The ``optuna`` trial used to sample the hyperparameters.

    Returns:
        The mean validation accuracy over every epoch of every seed.
    """
    # Hyperparameters to be optimized
    imagenet = trial.suggest_categorical("imagenet", [True, False])
    gaussian_blur = trial.suggest_categorical("gaussian_blur", [True, False])
    flip_horizontal = trial.suggest_categorical("flip_horizontal", [True, False])
    flip_vertical = trial.suggest_categorical("flip_vertical", [True, False])
    brightness = trial.suggest_float("brightness", 0.0, 1.0, step=0.1)
    contrast = trial.suggest_float("contrast", 0.0, 1.0, step=0.1)
    saturation = trial.suggest_float("saturation", 0.0, 1.0, step=0.1)
    hue = trial.suggest_float("hue", 0.0, 0.5, step=0.1)
    rotation = trial.suggest_int("rotation", 0, 90, step=5)

    if imagenet:
        stats = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    else:
        stats = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

    # Augmentations in the order they are applied: rotation, colour jitter,
    # optional flips, optional blur, then tensor conversion + normalisation.
    augmentations = [
        transforms.RandomRotation(rotation),
        transforms.ColorJitter(hue=hue, contrast=contrast, brightness=brightness, saturation=saturation),
    ]
    if flip_vertical:
        augmentations.append(transforms.RandomVerticalFlip())
    if flip_horizontal:
        augmentations.append(transforms.RandomHorizontalFlip())
    if gaussian_blur:
        augmentations.append(RandomGaussianBlur())

    data_transforms = {
        'train': transforms.Compose(augmentations + [transforms.ToTensor(), stats]),
        'val': transforms.Compose([
            transforms.ToTensor(),
            stats
        ])
    }

    batch_size = 32
    learning_rate = 0.00005
    n_epochs = 20
    n_train = 128
    n_val = 128

    # Two independent dataset objects over the same folder. A Subset only
    # holds a reference to its parent dataset, so assigning the validation
    # transform through `subset.dataset.transform` on a shared dataset would
    # also strip the augmentations from the training subset.
    train_dataset = datasets.ImageFolder(train_dir, transform=data_transforms['train'])
    val_dataset = datasets.ImageFolder(train_dir, transform=data_transforms['val'])

    seeds = range(8)

    val_accs = []

    for seed in seeds:
        torch.manual_seed(seed)
        np.random.seed(seed)
        random.seed(seed)

        # Shuffle and create disjoint index sets for training and validation
        indices = list(range(len(train_dataset)))
        random.shuffle(indices)
        train_indices = indices[:n_train]
        val_indices = indices[n_train:n_train + n_val]  # No overlap with the training set

        train_subset = Subset(train_dataset, train_indices)
        val_subset = Subset(val_dataset, val_indices)

        train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)

        # Load pre-trained resnet50 and replace the head with a single logit
        # for binary classification.
        model = models.resnet50(pretrained=True)
        num_ftrs = model.fc.in_features
        model.fc = nn.Sequential(
          nn.Linear(num_ftrs, 1),
          )

        # Loss function and optimizer
        criterion = nn.BCEWithLogitsLoss().to(device)

        model = model.to(device)
        optimizer = optim.AdamW(model.parameters(), lr=learning_rate)

        with tqdm(total=n_epochs*len(train_loader), unit="batch", desc="Training Progress") as pbar:
            for epoch in range(n_epochs):
                model.train()

                for images, labels in train_loader:
                    images, labels = images.to(device), labels.to(device)
                    optimizer.zero_grad()
                    # squeeze(1) drops only the logit dimension, so a
                    # single-sample batch keeps its batch dimension.
                    output = model(images).squeeze(1)
                    loss = criterion(output, labels.float())
                    loss.backward()
                    optimizer.step()
                    pbar.update(1)

                # Validate the model
                model.eval()
                running_loss = 0
                running_acc = 0

                # No gradients are needed for evaluation; skipping the graph
                # saves memory and time.
                with torch.no_grad():
                    for images, labels in val_loader:
                        images, labels = images.to(device), labels.to(device)
                        output = model(images).squeeze(1)
                        targets = labels.float()
                        loss = criterion(output, targets)
                        acc = binary_accuracy(output, targets)
                        running_loss += loss.item()
                        running_acc += acc.item()

                val_loss = running_loss/len(val_loader)
                val_acc = running_acc/len(val_loader)

                # Every epoch contributes to the score, not just the last one.
                val_accs.append(val_acc)
                pbar.set_postfix({'Epoch': epoch+1, 'Validation Loss': f'{val_loss:.3f}', 'Validation Accuracy': f'{val_acc:.3f}'})

    return np.mean(val_accs)

def save_figures(study,trial):
    """Optuna callback that re-renders the study plots after each trial.

    Writes the optimization-history and slice plots to fixed file names under
    ``./results`` (overwriting the previous versions), and additionally the
    parameter-importance plot once the study holds more than one trial.

    Args:
        study: The ``optuna`` study being optimized.
        trial: The trial that just finished; unused, but part of the
            callback signature Optuna invokes.
    """
    # Make sure the output directory exists so write_image() does not abort
    # the whole study on the first callback.
    os.makedirs('./results', exist_ok=True)

    # Constant file names for overwriting
    opt_hist_path = './results/optimization_history_plot.png'
    fig1 = optuna.visualization.plot_optimization_history(study)
    fig1.write_image(opt_hist_path)

    opt_slice_path = './results/optimization_slice_plot.png'
    fig2 = optuna.visualization.plot_slice(study)
    fig2.write_image(opt_slice_path)

    # The importance plot is only produced once there is more than one trial.
    if len(study.trials) > 1:
        opt_importance_path = './results/optimization_importance_plot.png'
        fig3 = optuna.visualization.plot_param_importances(study)
        fig3.write_image(opt_importance_path)
    
# Number of Optuna trials to run.
trials = 3

best_params_path = './results/best_hyperparams.json'

# Create the results directory before any work is done, so neither the figure
# callback nor the final JSON dump can fail (and lose the study) because the
# folder is missing.
os.makedirs(os.path.dirname(best_params_path), exist_ok=True)

# Seeded TPE sampler so the sequence of suggested hyperparameters is repeatable.
study = optuna.create_study(direction='maximize', sampler=optuna.samplers.TPESampler(seed=0))
study.optimize(objective, n_trials=trials, callbacks=[save_figures])

# Persist the best hyperparameters found by the study.
best_params = study.best_params
with open(best_params_path, 'w') as f:
    json.dump(best_params, f)

[I 2023-11-02 12:45:31,735] A new study created in memory with name: no-name-da6c557f-613b-4027-9202-a1bdcbc6827e
Training Progress: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:02<00:00,  5.89batch/s, Epoch=3, Validation Loss=0.654, Validation Accuracy=0.578]
Training Progress: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:01<00:00,  7.98batch/s, Epoch=3, Validation Loss=0.645, Validation Accuracy=0.656]
Training Progress: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:01<00:00,  8.31batch/s, Epoch=3, Validation Loss=0.709, Validation Accuracy=0.453]
Training Progress: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 12/12 [00:01<00:00,  8.70batch/s, Epoch=3, Validation Loss=0.636, Validation Accuracy=0.656]
Training Progress: 100%|██████████████████████████