# Imports

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import numpy as np
import os
import pandas as pd
from PIL import Image
import matplotlib.pyplot as plt

import ssl
ssl._create_default_https_context = ssl._create_unverified_context

import warnings
warnings.filterwarnings("ignore")

### SELECT DEVICE ###
# GPU device configuration
if torch.cuda.is_available():
  DEVICE = torch.device('cuda')
  print('Using CUDA')
elif torch.backends.mps.is_available():
  DEVICE = torch.device('mps')
  print('Using MPS')
else:
  DEVICE = torch.device('cpu')
  print('Using CPU')

# Dataloader

In [94]:
### DEFINE TRANSFORMATIONS ###
normalizer = transforms.Normalize(mean = [0.485, 0.456, 0.406],
                                  std = [0.229, 0.224, 0.225])
train_transforms = transforms.Compose([
                    transforms.Resize((224,224)),
                    transforms.ToTensor(),
                    normalizer
                    ])

class GuessTheCorrelationDataset(Dataset):
  def __init__(self, root, transform=None, indexes=None):
    self.root = root
    self.transform = transform
    self.img_dir = os.path.join(root, 'train_imgs')
    
    # Load correlation values
    csv_path = os.path.join(root, 'train_responses.csv')
    df = pd.read_csv(csv_path)
    
    # Filter by indexes if provided
    if indexes is not None:
      df = df.loc[df.index.isin(indexes)]
    
    self.img_files = df['id'].astype(str) + '.png'
    self.correlations = df['corr'].values
  
  def __len__(self):
    return len(self.img_files)
  
  def __getitem__(self, idx):
    img_path = os.path.join(self.img_dir, self.img_files.iloc[idx])
    image = Image.open(img_path).convert('RGB')
    
    if self.transform:
      image = self.transform(image)
    
    label = torch.tensor(self.correlations[idx], dtype=torch.float32)
    
    return image, label

In [95]:
if os.name == 'posix':
  root = os.path.expanduser("~/Documents/PyTorch_Data/correlation/guess-the-correlation")
else:
  root = "d:\\PyTorch_Data\\guess-the-correlation"
  
train_imgs_path = os.path.join(root, 'train_imgs')
# num_images = len([name for name in os.listdir(train_imgs_path) if os.path.isfile(os.path.join(train_imgs_path, name))])
# print(f"Number of images: {num_images}")

# Model training function

In [96]:
def train(model, device, epochs, optimizer, loss_fn, batch_size, trainloader, valloader):
    log_training = {"epoch": [],
                    "training_loss": [],
                    "validation_loss": []}

    for epoch in range(1, epochs + 1):
        print(f"Starting Epoch {epoch}")
        training_losses = []
        validation_losses = []

        for image, label in tqdm(trainloader, ncols = 60):
            image, label = image.to(device), label.to(device)
            out = model.forward(image)
        
            ### CALCULATE LOSS ##
            optimizer.zero_grad()
            loss = loss_fn(out, label)
            training_losses.append(np.sqrt(loss.item()))

            loss.backward()
            optimizer.step()

        for image, label in tqdm(valloader, ncols = 60):
            image, label = image.to(device), label.to(device)
            with torch.no_grad():
                out = model.forward(image)

                ### CALCULATE LOSS ##
                loss = loss_fn(out, label)
                validation_losses.append(np.sqrt(loss.item()))

        training_loss_mean = np.mean(training_losses)
        valid_loss_mean = np.mean(validation_losses)

        log_training["epoch"].append(epoch)
        log_training["training_loss"].append(training_loss_mean)
        log_training["validation_loss"].append(valid_loss_mean)

        print("Training Loss:", training_loss_mean) 
        print("Validation Loss:", valid_loss_mean)
        print("=====================================\n")
        
    return log_training, model

## Load PreTrained Weights but Only Train the Final Classifier Layer

In [None]:
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet101', pretrained=True)
model.fc
model.fc = nn.Linear(2048, 1)

# Check the name of all the parameters
for name, param in model.named_parameters():
  # print(name)
  if "fc" not in name:
    param.requires_grad_(False) # Inplace turn of gradient updates

In [None]:
print(model)

In [None]:
total_parameters = 0
for name, params in model.named_parameters():
  num_params = int(torch.prod(torch.tensor(params.shape)))
  print(name, ":", params.shape, "Num Parameters:", num_params)
  total_parameters += num_params

print("------------------------")
print("Total Parameters in Model", total_parameters)

In [None]:
model = torch.hub.load('pytorch/vision:v0.10.0', 'resnet101', pretrained=True)
model.fc = nn.Linear(2048, 1)

# Check the name of all the parameters
for name, param in model.named_parameters():
  # print(name)
  if "fc" not in name:
    param.requires_grad_(False) # Inplace turn of gradient updates

model = model.to(DEVICE)

### MODEL TRAINING INPUTS ###
epochs = 50
optimizer = optim.SGD(params=model.parameters(), lr=0.0001)
loss_fn = nn.MSELoss()
batch_size = 64

# Dataset creation
train_ds = GuessTheCorrelationDataset(root,
                                      transform=train_transforms,
                                      indexes=range(30000))
valid_ds = GuessTheCorrelationDataset(root,
                                      transform=train_transforms,
                                      indexes=range(30000, 40000))
test_ds = GuessTheCorrelationDataset(root,
                                     transform=train_transforms,
                                     indexes=range(40000, 50000))

### BUILD DATALOADERS ###
trainloader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=0)
valloader = DataLoader(valid_ds, batch_size=batch_size, shuffle=False, num_workers=0)
testloader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=0)

In [None]:
data_iterator = iter(trainloader)
data = next(data_iterator)
features, labels = data

plt.imshow(features[0].permute(1, 2, 0), cmap='gray')
plt.axis('off')
plt.show()

In [None]:
train_logs, model = train(model=model,
                          device=DEVICE,
                          epochs=epochs,
                          optimizer=optimizer,
                          loss_fn=loss_fn,
                          batch_size=batch_size,
                          trainloader=trainloader,
                          valloader=valloader)

In [104]:
train_log_df = pd.DataFrame({'Epoch': train_logs['epoch'],
                             'Training_Loss': train_logs['training_loss'],
                             'Validation_Loss': train_logs['validation_loss']})

3000 training data @ 50 epochs:

- Training Loss: 0.4556598829628105
- Validation Loss: 0.4534302428237914

30000 training data @ 50 epochs:

- Training Loss: 0.4431861734305133
- Validation Loss: 0.4413213843223881

In [None]:
plt.figure(figsize=(10, 6))

# Example: Plot loss over epochs
plt.plot(train_log_df['Epoch'], train_log_df['Training_Loss'],
         label='Training Loss', marker='o')
plt.plot(train_log_df['Epoch'], train_log_df['Validation_Loss'],
         label='Validation Loss', marker='o')
plt.title('Loss Over Epochs')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
plt.show()