In [None]:
import torch
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, utils, models
from torchsummary import summary
import pandas as pd
import numpy as np
import os
import sys
import matplotlib.pyplot as plt
from PIL import Image

: 

In [None]:
# Pick the accelerator when PyTorch can see a CUDA GPU; otherwise stay on the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Using {device} device")

: 

In [None]:
# ImageNet-pretrained ResNet-18 backbone. `weights=` replaces the deprecated
# `pretrained=True` flag (torchvision >= 0.13); for ResNet-18, DEFAULT resolves to
# the same IMAGENET1K_V1 weights that `pretrained=True` loaded.
model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT).to(device)

# Freeze every backbone parameter so only layers attached afterwards are trained.
for param in model.parameters():
    param.requires_grad = False

: 

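Notes on the ResNet model:

- **TODO:** try converting the images to black & white (grayscale). The pretrained ResNet still expects 3 input channels, so use `transforms.Grayscale(num_output_channels=3)`, not 1.
- The standard ImageNet pipeline first resizes input images to (256, 256) and then crops (**we need to resize our images**). This notebook currently resizes directly to (224, 224).
- The replacement fully connected layer maps the backbone's pooled features to one score per class. For ResNet-18 that is a 512 → `classes` layer (4 season classes in this notebook). The originally planned (2048, 10) layer — 10 classes for the 10 style types — matches the 2048-feature output of ResNet-50/101/152, not ResNet-18.
- Image preprocessing requires:
  1. a (224, 224) center crop;
  2. normalization with mean = 255·[0.485, 0.456, 0.406] and std = 255·[0.229, 0.224, 0.225] for pixel values in [0, 255]. Equivalently, use mean = [0.485, 0.456, 0.406] and std = [0.229, 0.224, 0.225] after `ToTensor()` has scaled pixels to [0, 1];
  3. transposing the image from HWC to CHW layout (`ToTensor()` does this).
- Post-processing: apply a softmax to the model's outputs to get per-class probability scores. During training the model should output raw logits, because `CrossEntropyLoss` applies log-softmax internally.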

In [None]:
classes = 4  # one output per season label: 0 spring, 1 summer, 2 fall, 3 winter
# Replace the ImageNet classifier with a fresh head. The input width is read from the
# existing layer (512 for ResNet-18) rather than hard-coded, so swapping the backbone
# does not break this cell, and re-running it keeps the same width.
# The new Linear layer is trainable (requires_grad defaults to True) even though the
# backbone parameters were frozen above.
model.fc = torch.nn.Linear(model.fc.in_features, classes).to(device)
print(model)

: 

In [None]:
loss_fn = torch.nn.CrossEntropyLoss()  # multi-class loss; expects raw logits (applies log-softmax internally)
# Give the optimizer only the parameters that are still trainable (the new fc head).
# The frozen backbone never receives gradients, so listing it would only bloat the
# optimizer's param groups and state_dict.
optimizer = torch.optim.SGD(
    (param for param in model.parameters() if param.requires_grad),
    lr=0.001,
    momentum=0.9,
)

: 

In [None]:
class MyDataset(Dataset):
    """Image dataset indexed by a CSV file of (file name, integer label) rows.

    Column 0 of the CSV holds the image file name relative to ``root_dir``;
    column 1 holds the integer label (season: 0 spring, 1 summer, 2 fall,
    3 winter). The CSV is read once in ``__init__``; images are opened lazily,
    one at a time, in ``__getitem__``.
    """

    def __init__(self,
                 csv_file,         # path to the CSV index (column 0 = file name, column 1 = label)
                 root_dir,         # directory that the file names in column 0 are relative to
                 transform=None):  # optional callable applied to each loaded PIL image
        """Read the CSV index and remember the transform.

        Images are NOT read or stored here; they are read in ``__getitem__``
        as needed.
        """
        self.csv_file = csv_file
        self.root_dir = root_dir
        self.images = pd.read_csv(self.csv_file)
        self.transform = transform

    def __len__(self):
        """Return the number of rows in the CSV index."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for row ``idx`` of the CSV index.

        The image is loaded as a 3-channel RGB PIL image and passed through
        ``transform`` when one was given. The label is returned as a Python int.
        """
        img_name = self.images.iloc[idx, 0]
        img_path = os.path.join(self.root_dir, img_name)
        # Image.open is lazy and keeps the file handle open; convert() forces the
        # pixels into memory so the `with` block can close the file right away.
        # Converting to RGB also normalizes grayscale/palette/RGBA files to the
        # 3 channels expected by the ImageNet Normalize and the pretrained ResNet.
        with Image.open(img_path) as img:
            im = img.convert("RGB")

        if self.transform is not None:
            im = self.transform(im)

        # Label encodes the season: 0 = spring, 1 = summer, 2 = fall, 3 = winter.
        label = int(self.images.iloc[idx, 1])

        return im, label

: 

In [None]:
# Display the working directory: the relative paths used to build the dataset below
# ('./filtered_style_stats.csv', './yolov5/yolov5/crop-images') resolve against it.
os.getcwd()

: 

In [None]:
# ImageNet preprocessing for the pretrained ResNet:
#   1. resize the PIL image to 224x224. Resizing before ToTensor means the PIL
#      resampler (which always antialiases) is used; tensor resizing only
#      antialiases by default from torchvision 0.17 onward.
#   2. ToTensor: HWC uint8 -> CHW float in [0, 1].
#   3. normalize with the ImageNet per-channel mean/std (given for [0, 1] inputs).
# For black & white experiments, add transforms.Grayscale(num_output_channels=3)
# before ToTensor; the pretrained backbone requires 3 input channels.
image_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

dataset = MyDataset(csv_file='./filtered_style_stats.csv',
                    root_dir='./yolov5/yolov5/crop-images',
                    transform=image_transforms)

# 70/15/15 split; the seed is pinned explicitly so split membership is fixed in code.
SPLIT_SEED = 42
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [0.7, 0.15, 0.15], generator=torch.Generator().manual_seed(SPLIT_SEED))

: 

In [None]:
# Referenced from
# https://towardsdatascience.com/pytorch-basics-sampling-samplers-2a0f29f0bf2a

def get_class_distribution(dataset_obj):
    """Count how many samples of each season label a dataset contains.

    Args:
        dataset_obj: a ``MyDataset``, a ``torch.utils.data.Subset`` of one (as
            returned by ``random_split``), or any indexable dataset whose items
            are ``(data, label)`` pairs with integer labels 0-3.

    Returns:
        dict mapping 'spring', 'summer', 'fall', 'winter' (in that order) to counts.

    Raises:
        KeyError: if a label is not one of 0, 1, 2, 3.
    """
    idx_to_class = {
        0: 'spring',
        1: 'summer',
        2: 'fall',
        3: 'winter',
    }
    count_dict = {name: 0 for name in idx_to_class.values()}
    for label in _iter_labels(dataset_obj):
        count_dict[idx_to_class[int(label)]] += 1
    return count_dict


def _iter_labels(dataset_obj):
    """Return an iterable of labels, reading the CSV index instead of images when possible."""
    # Fast path: MyDataset keeps labels in column 1 of its `images` DataFrame, and a
    # Subset indexes that frame positionally via `indices`. Reading labels from the
    # frame avoids opening and transforming every image just to count classes.
    base = getattr(dataset_obj, 'dataset', dataset_obj)
    frame = getattr(base, 'images', None)
    if isinstance(frame, pd.DataFrame):
        if base is dataset_obj:
            return frame.iloc[:, 1].tolist()
        indices = getattr(dataset_obj, 'indices', None)
        if indices is not None:
            return frame.iloc[list(indices), 1].tolist()
    # Generic fallback: index every element (this loads each sample).
    return (dataset_obj[i][1] for i in range(len(dataset_obj)))

: 

In [None]:
#### CHECKING IT'S DECENT: split sizes and per-season label counts ####
_splits = (
    ("train_dataset", train_dataset),
    ("val_dataset", val_dataset),
    ("test_dataset", test_dataset),
)
total_data = sum(len(subset) for _, subset in _splits)
print(f"Length of total data: {total_data}")
for position, (split_name, subset) in enumerate(_splits):
    # A blank line separates the size report from the class distributions.
    trailer = "\n" if position == len(_splits) - 1 else ""
    share = len(subset) / total_data * 100
    print(f"Length of {split_name}: {len(subset)}; {share:.2f}%{trailer}")
for split_name, subset in _splits:
    print(f"Class Distribution of {split_name}: {get_class_distribution(subset)}")

: 

In [None]:
batch_size = 32

# Reshuffle the training split every epoch so SGD does not see the same batch order
# each pass. Evaluation loaders keep a fixed order; it does not affect their metrics.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size)

: 

In [None]:
def train(dataloader, model, loss_fn, optimizer, log_interval=200, device=None):
    """Run one training epoch over ``dataloader``.

    Args:
        dataloader: DataLoader yielding ``(X, y)`` batches; its ``dataset`` length
            is used for the progress display.
        model: module to train; it is switched to training mode.
        loss_fn: callable ``loss_fn(pred, y)`` returning a scalar loss tensor.
        optimizer: optimizer over (a subset of) ``model``'s parameters.
        log_interval: print the current loss every ``log_interval`` batches
            (with batch_size 32, the default 200 means every 6400 images).
        device: where to move each batch; defaults to the device of the model's
            first parameter (CPU if the model has none).
    """
    size = len(dataloader.dataset)
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else "cpu"
    model.train()
    seen = 0  # samples processed so far, including the current batch
    for batch, (X, y) in enumerate(dataloader):
        X, y = X.to(device), y.to(device)

        # Compute prediction error
        pred = model(X)
        loss = loss_fn(pred, y)

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        seen += len(X)
        if batch % log_interval == 0:
            loss_value = loss.item()
            print(f"loss: {loss_value:>7f}  [{seen:>5d}/{size:>5d}]")

: 

In [None]:
def test(dataloader, model, loss_fn, incorrect_examples, correct_examples):
    size = len(val_dataset)
    num_batches = len(dataloader)
    model.eval()
    val_loss, correct = 0, 0
    with torch.no_grad(): 
        for X, y in dataloader:
            X, y = X.to(device), y.to(device)
            pred = model(X)
            val_loss += loss_fn(pred, y).item()

            val = pred.argmax(1).to(device) 
            correct += ((val == y).type(torch.float).sum().item()) 

            if torch.all(torch.eq(val, y)) and len(correct_examples) < 6:
                correct_examples.append(X.cpu())
            if (not torch.all(torch.eq(val, y))) and len(incorrect_examples) < 6:
                incorrect_examples.append(X.cpu())

    val_loss /= num_batches
    correct /= size
    print(f"Val Error ---\n Accuracy: {(100*correct):>0.1f}%, Avg loss: {val_loss:>8f} \n")
    return correct_examples, incorrect_examples

: 

In [None]:
epochs = 10

# Alternate one training pass with one validation pass. The example lists start
# empty for every validation pass, so after the loop they hold the examples
# gathered during the final epoch.
for epoch_number in range(1, epochs + 1):
    print(f"Epoch {epoch_number}\n-------------------------------")
    train(train_dataloader, model, loss_fn, optimizer)
    correct_examples, incorrect_examples = test(val_dataloader, model, loss_fn, [], [])
print("Done!\n")

: 