In [36]:
import os
import torch
from torch import nn, optim
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader, random_split
from torch.utils.data import Subset
import numpy as np
from tqdm import tqdm
import pandas as pd
import random
import json

device = torch.device("cuda")

# Function to calculate accuracy
def binary_accuracy(preds, y):
    rounded_preds = torch.round(torch.sigmoid(preds))
    correct = (rounded_preds == y).float()
    acc = correct.sum() / len(correct)
    return acc

def best_cv_val(nested_list):
    array = np.array(nested_list)
    means = np.mean(array, axis=0)
    best_epoch = np.argmax(means)
    return best_epoch + 1, means[best_epoch]

class RandomGaussianBlur:
    def __init__(self, kernel_size=3, probability=0.5):
        self.kernel_size = kernel_size
        self.probability = probability
        self.gaussian_blur = transforms.GaussianBlur(self.kernel_size)

    def __call__(self, img):
        if random.random() < self.probability:
            return self.gaussian_blur(img)
        return img

def split_dataset(holdout_clusters, full_dataset):
    # Determine the validation indices by checking if the file matches any holdout cluster
    val_indices = [i for i, (path, _) in enumerate(full_dataset.imgs)
                   if any(f"_{cluster}.png" in path for cluster in holdout_clusters)]

    # The training set includes all indices that are not in the validation set
    train_indices = [i for i in range(len(full_dataset)) if i not in val_indices]

    return train_indices, val_indices

os.chdir('/home/kdoherty/spurge/data_release')

train_dir = './data/crop_39/train'

df = pd.read_csv('./results/best_lr.csv')
best_row = df.loc[df['accuracy'].idxmax()]
learning_rate = best_row['lr']

with open('./results/best_augs.json', 'r') as file:
    augs = json.load(file)

gaussian_blur = augs['gaussian_blur']
flip_horizontal = augs['flip_horizontal']
flip_vertical = augs['flip_vertical']
brightness = augs['brightness']
contrast = augs['contrast']
saturation = augs['saturation']
hue = augs['hue']
rotation = augs['rotation']

stats = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

transform_list = [transforms.ToTensor(), stats]

if gaussian_blur:
    transform_list.insert(0, RandomGaussianBlur())
if flip_horizontal:
    transform_list.insert(0, transforms.RandomHorizontalFlip())
if flip_vertical:
    transform_list.insert(0, transforms.RandomVerticalFlip())

transform_list.insert(0, transforms.ColorJitter(hue=hue, contrast=contrast, brightness=brightness, saturation=saturation))
transform_list.insert(0, transforms.RandomRotation(rotation))

data_transforms = {
    'train': transforms.Compose(transform_list),
    'val': transforms.Compose([
        transforms.ToTensor(),
        stats
    ])
}

seed = 0
batch_size = 32
n_epochs = 500
holdout_sets = [[0], [1], [2], [4], [5], [6,7], [8]]

full_dataset = datasets.ImageFolder(train_dir, transform=data_transforms['train'])

array_idx = int(os.environ['SLURM_ARRAY_TASK_ID'])
#array_idx = 0
holdout_set = holdout_sets[array_idx]

print(f'Validating cluster {holdout_set}')

# Now split the dataset
train_indices, val_indices = split_dataset(holdout_set, full_dataset)

# Create subsets
train_subset = Subset(full_dataset, train_indices)
val_subset = Subset(full_dataset, val_indices)

# Update transformations for the validation set
val_subset.dataset.transform = data_transforms['val']

# Create the data loaders
train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_subset, batch_size=batch_size, shuffle=False)

# Load pre-trained resnet50 model + higher level layers
model = models.resnet50(pretrained=True)

# Change the last layer to have 1 output
num_ftrs = model.fc.in_features

# Modify fc layers for binary classification
model.fc = nn.Sequential(
  nn.Linear(num_ftrs, 1),
  )

# Loss function and optimizer
criterion = nn.BCEWithLogitsLoss().to(device)

model = model.to(device)
optimizer = optim.AdamW(model.parameters(), lr=learning_rate)

epochs = range(n_epochs)
epoch_accs = []
epoch_losses = []

with tqdm(total=n_epochs*len(train_loader), unit="batch", desc="Training Progress") as pbar:
    for epoch in epochs:
        model.train()
        running_loss = 0

        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            optimizer.zero_grad()
            output = model(images).squeeze(1)  # Remove only the second dimension if it's size 1
            loss = criterion(output, labels.float())  # No need to squeeze again
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            pbar.update(1)
        
        train_loss = running_loss / len(train_loader)
        
        # Validate the model
        model.eval()
        running_loss = 0
        running_acc = 0
        
        for images, labels in val_loader:
            images, labels = images.to(device), labels.to(device)
            output = model(images).squeeze(1)  # Remove only the second dimension if it's size 1
            loss = criterion(output, labels.float())  # No need to squeeze again
            acc = binary_accuracy(output, labels.float())  # No need to squeeze here either
            running_loss += loss.item()
            running_acc += acc.item()

        val_loss = running_loss/len(val_loader)
        val_acc = running_acc/len(val_loader)
        epoch_accs.append(val_acc)
        epoch_losses.append(val_loss)
        
        pbar.set_postfix({'Epoch': epoch, 
                          'Validation Loss': f'{val_loss:.3f}', 
                          'Validation Accuracy': f'{val_acc:.3f}'})

os.makedirs('./results/epoch_tune', exist_ok=True)

# Convert the lists to a DataFrame
results_df = pd.DataFrame({
    'array_idx': [array_idx] * n_epochs,
    'epoch': list(epochs),
    'epoch_loss': epoch_losses,
    'epoch_acc': epoch_accs
})

# Save to CSV
csv_path = f'./results/epoch_tune/{array_idx}.csv'
results_df.to_csv(csv_path, index=False)

Validating cluster [0]


Training Progress: 100%|██████████████████████████████████████████████████████████████████████████████████████████| 66/66 [00:05<00:00, 12.33batch/s, Epoch=2, Validation Loss=1.146, Validation Accuracy=0.615]
