#### Imports

In [None]:
try: 
    import cv2
    import torch
    import torchvision
    import sklearn.svm
except:
    %pip install opencv-python-headless==4.9.0.80
    %pip install torch
    %pip install torchvision
    %pip install torchsummary 

import torch
from torch.utils.data import Dataset
from torch import cuda
from torchvision import transforms, datasets, models
import torch.optim as optim
import torch.nn as nn
from torch.optim import lr_scheduler

from pathlib import Path
from timeit import default_timer as timer
from tqdm import tqdm
import matplotlib.pyplot as plt
from collections import Counter

from skimage.feature import hog
from sklearn.svm import SVC, LinearSVC
from sklearn.model_selection import GridSearchCV, KFold, train_test_split, cross_val_score
from sklearn.metrics import accuracy_score, confusion_matrix, classification_report

import random
import numpy as np
import time
import copy
import pickle 
import re
import shutil

import warnings
warnings.filterwarnings('ignore', category=FutureWarning)

from torchsummary import summary
from PIL import Image

# Print NumPy arrays in full instead of summarising them with "...", so large
# arrays (such as the confusion matrix printed during training) are shown whole.
np.set_printoptions(threshold=np.inf)
print('import successful')

#### Data Paths

In [None]:
# Dataset Parameters
BASE_PATH = Path('/home/jovyan/work/data/out')

EMOREACT = 'EmoReact'
FER = 'FER-2013'
KDEF = 'KDEF-AKDEF'
NIMH = 'NIMH-CHEFS'
INTERNAL = Path('/home/jovyan/work/output/extracted_faces')
CROSS = Path('/home/jovyan/work/cross-label/extracted_faces')

DATASET = 'Internal Cross Label'

# Default location of a named dataset under BASE_PATH; overridden on the next
# line to use the cross-label extracted faces instead.
DATA_PATH = BASE_PATH / DATASET
DATA_PATH = CROSS

# Dataset-specific paths.
# sorted() makes the label order deterministic: iterdir() order is
# filesystem-dependent.
LABELS = sorted(f.name for f in DATA_PATH.iterdir() if f.is_dir())
IMAGE_PATHS = list(DATA_PATH.rglob('*.jpg'))
print(len(IMAGE_PATHS))

# Constants for splitting dataset
TRAIN = 'train'
TEST = 'test'
VAL = 'val'

# Model parameters
MODEL_PATH = Path('/home/jovyan/work/models')
BATCH_SIZE = 16
SUBSET_RATIO = 0.1
SUBSET = False

# Constants for feature extraction
FEATURES = 'feature-extraction'
TRANSFER = 'transfer-learning'
FINETUNE = 'fine-tuning'

# Cuda parameters: fall back to the CPU when no GPU is available.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.cuda.empty_cache()
# Querying the device name unconditionally would raise on a CPU-only machine.
if DEVICE.type == 'cuda':
    print(torch.cuda.get_device_name(0))

# HOG feature-extraction parameters (keyword arguments for skimage.feature.hog)
orientations = 7
pixels_per_cell = 8
cells_per_block = 4

# skimage.feature.hog unpacks pixels_per_cell / cells_per_block as
# (rows, cols) pairs, so the square sizes are passed as 2-tuples.
hog_params = {
    'orientations': orientations,
    'pixels_per_cell': (pixels_per_cell, pixels_per_cell),
    'cells_per_block': (cells_per_block, cells_per_block),
}

In [None]:
# Switch to the improved, pre-split subset (expected to contain train/val/test
# sub-folders, one per phase, each holding one folder per class).
DATA_PATH = Path('/home/jovyan/work/output/extracted_faces-improved/subset')
# Recompute LABELS for the new dataset; otherwise it would still describe the
# previous DATA_PATH. sorted() gives a deterministic class order.
LABELS = sorted(f.name for f in (DATA_PATH / TRAIN).iterdir() if f.is_dir())

DATASET = 'internal-improved'

#### Dataset class

In [None]:
class Dataset(Dataset):
    """Image-folder dataset for one split (``phase``) of ``data_path``.

    Expected layout: ``data_path/<phase>/<class_name>/*.jpg``. The name of the
    folder directly containing an image is its label. Classes are sorted by
    name, so every split built from the same class folders gets the same
    class -> index mapping.

    NOTE: this class shadows the ``torch.utils.data.Dataset`` it subclasses;
    after this cell runs, ``Dataset`` refers to this class.
    """

    def __init__(self, data_path, transforms=None, phase='train'):
        """
        Args:
            data_path: root folder containing one sub-folder per phase.
            transforms: mapping ``phase -> callable`` applied to each PIL
                image, or ``None`` for no transform.
            phase: split sub-folder to load, e.g. 'train', 'val' or 'test'.
        """
        self.data_path = Path(data_path) / phase
        # Honour the declared default instead of failing on None[phase].
        self.transform = transforms[phase] if transforms is not None else None
        self.phase = phase

        self.classes = self._get_classes()
        self.image_paths = self._get_image_paths()
        self.num_classes = len(self.classes)

        self.class_to_int = {class_name: idx for idx, class_name in enumerate(self.classes)}
        self.int_to_class = {idx: class_name for class_name, idx in self.class_to_int.items()}

    def _get_classes(self):
        """Return the sub-folder names of this split, sorted by name.

        Sorting is required: ``iterdir()`` order is filesystem-dependent, and
        each split computes its own list, so unsorted lists could map the same
        class to different indices in train and val/test.
        """
        return sorted(f.name for f in self.data_path.iterdir() if f.is_dir())

    def _get_image_paths(self):
        """Return every ``*.jpg`` below this split's folder, in random order."""
        paths = list(self.data_path.rglob('*.jpg'))
        random.shuffle(paths)
        return paths

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``(image, label_index)`` for sample ``idx``.

        ``image`` is the transformed image when a transform is set, otherwise
        the PIL image.
        """
        img_path = self.image_paths[idx]
        # Use a context manager so the file handle is released even when no
        # transform forces the pixel data to be read.
        with Image.open(img_path) as img:
            img.load()  # read pixels now; leaving the block closes the file
            sample = self.transform(img) if self.transform else img
        label = self.class_to_int[img_path.parent.name]  # Convert label to integer
        return sample, label

    def show_distribution(self):
        """Plot a bar chart of the number of images per class in this split.

        Counts come from the image paths alone, so no image is decoded or
        transformed.
        """
        counts_by_idx = Counter(self.class_to_int[p.parent.name] for p in self.image_paths)
        if not counts_by_idx:
            print(f'No images found in {self.data_path}')
            return
        indices, counts = zip(*sorted(counts_by_idx.items()))
        names = self.idx_to_class(indices)

        plt.figure(figsize=(10, 3))
        bars = plt.bar(names, counts, color='skyblue')
        plt.xlabel('Class')
        plt.ylabel('Count')
        plt.title('Counts per Class')
        plt.xticks(rotation=45, ha='right')
        plt.grid(axis='y', linestyle='--', alpha=0.7)
        for bar, count in zip(bars, counts):
            plt.text(bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.5, count,
                     ha='center', va='bottom', color='black', fontsize=8)
        plt.tight_layout()
        plt.show()

    def idx_to_class(self, idx_list):
        """Map a sequence of integer labels to class names."""
        return [self.int_to_class[idx] for idx in idx_list]

    def class_to_idx(self, class_list):
        """Map a sequence of class names to integer labels."""
        return [self.class_to_int[class_name] for class_name in class_list]

#### Create Transforms & Datasets

In [None]:
# Channel statistics the torchvision ImageNet-pretrained weights (used for
# VGG16 below) were trained with; inputs should be normalised the same way.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

# Deterministic pipeline shared by validation and test: grayscale replicated
# to 3 channels, resize, centre crop to 224x224, tensor, normalise.
_eval_transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

data_transforms = {
    # Training adds random crop/scale and horizontal-flip augmentation.
    TRAIN: transforms.Compose([
        transforms.Grayscale(num_output_channels=3),
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
    ]),
    VAL: _eval_transform,
    TEST: _eval_transform,
}

In [None]:
# One Dataset and DataLoader per split. Only the training loader is shuffled:
# evaluation metrics do not depend on sample order.
# NOTE: ``datasets`` shadows the torchvision ``datasets`` module imported
# earlier; the name is kept because later cells index into it.
datasets = {x: Dataset(DATA_PATH, transforms=data_transforms, phase=x) for x in [TRAIN, VAL, TEST]}
dataloaders = {
    x: torch.utils.data.DataLoader(
        datasets[x], batch_size=BATCH_SIZE, shuffle=(x == TRAIN), num_workers=4
    )
    for x in [TRAIN, VAL, TEST]
}
dataset_sizes = {x: len(datasets[x]) for x in [TRAIN, VAL, TEST]}
NUM_CLASSES = datasets[TRAIN].num_classes
print(datasets[TRAIN].classes)
print(datasets[TRAIN].data_path)
print(datasets[TEST].data_path)
print(datasets[VAL].data_path)

#### Show distributions

In [None]:
# Bar chart of image counts per class in the validation split.
datasets[VAL].show_distribution()

In [None]:
# Bar chart of image counts per class in the test split.
datasets[TEST].show_distribution()

In [None]:
# Bar chart of image counts per class in the training split.
datasets[TRAIN].show_distribution()

#### Default Training

In [None]:
from sklearn.metrics import confusion_matrix, classification_report, balanced_accuracy_score


def _run_epoch(model, loader, criterion=None, optimizer=None, desc=None):
    """Run ``model`` once over every batch of ``loader``.

    If ``optimizer`` is given, the model is put in training mode and its
    parameters are updated after each batch. Otherwise the model is put in
    evaluation mode and gradients are disabled.

    Args:
        model: network to run; batches are moved to the global ``DEVICE``.
        loader: iterable of ``(inputs, labels)`` batches (a DataLoader).
        criterion: loss function. Required when ``optimizer`` is given;
            optional otherwise (``None`` means no loss is computed).
        optimizer: optimizer to step, or ``None`` for evaluation only.
        desc: progress prefix printed every 100 batches; ``None`` is silent.

    Returns:
        ``(loss_sum, labels, preds)``: the loss summed over samples, plus flat
        lists of the true and predicted class indices.

    Raises:
        ValueError: if ``optimizer`` is given without a ``criterion``.
    """
    training = optimizer is not None
    if training and criterion is None:
        raise ValueError('criterion is required when training')
    model.train(training)

    num_batches = len(loader)
    loss_sum = 0.0
    all_labels, all_preds = [], []

    with torch.set_grad_enabled(training):
        for i, (inputs, labels) in enumerate(loader):
            if desc is not None and i % 100 == 0:
                print(f"\r{desc} batch {i}/{num_batches}", flush=True)

            inputs = inputs.to(DEVICE)
            labels = labels.to(DEVICE)

            if training:
                optimizer.zero_grad()
            outputs = model(inputs)
            preds = outputs.argmax(dim=1)

            if criterion is not None:
                loss = criterion(outputs, labels)
                if training:
                    loss.backward()
                    optimizer.step()
                # With the default reduction='mean' the criterion returns a
                # batch mean; weight it by the batch size so that dividing by
                # the sample count yields a true per-sample mean.
                loss_sum += loss.item() * inputs.size(0)

            all_labels.extend(labels.cpu().tolist())
            all_preds.extend(preds.cpu().tolist())

    return loss_sum, all_labels, all_preds


def _plot_history(train_losses, val_losses, train_accuracies, val_accuracies):
    """Plot per-epoch loss and balanced-accuracy curves side by side."""
    epochs = range(1, len(train_losses) + 1)

    plt.figure(figsize=(14, 5))

    plt.subplot(1, 2, 1)
    plt.plot(epochs, train_losses, label='Training Loss')
    plt.plot(epochs, val_losses, label='Validation Loss')
    plt.title('Loss over Epochs')
    plt.xlabel('Epochs')
    plt.ylabel('Loss')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(epochs, train_accuracies, label='Training Accuracy')
    plt.plot(epochs, val_accuracies, label='Validation Accuracy')
    plt.title('Balanced Accuracy over Epochs')
    plt.xlabel('Epochs')
    plt.ylabel('Balanced Accuracy')
    plt.legend()

    plt.show()


def train_model(vgg, criterion, optimizer, scheduler, num_epochs=10, patience=5):
    """Train ``vgg`` with early stopping and return it with its best weights.

    Uses the global ``dataloaders[TRAIN]`` and ``dataloaders[VAL]``. Each epoch:
    1. runs one training pass;
    2. steps ``scheduler`` once (when it is not ``None``);
    3. evaluates on the validation set.

    The weights with the highest validation balanced accuracy are kept.
    Training stops early after ``patience`` consecutive epochs without
    improvement. The best model is then re-evaluated on the validation set,
    and a confusion matrix, a classification report and loss/accuracy curves
    are shown. Training and validation accuracy are both balanced accuracy,
    so the two curves are comparable.

    Args:
        vgg: model to train (updated in place).
        criterion: loss function, e.g. ``nn.CrossEntropyLoss()``.
        optimizer: optimizer over ``vgg``'s parameters.
        scheduler: learning-rate scheduler stepped once per epoch, or ``None``.
        num_epochs: maximum number of epochs.
        patience: epochs without validation improvement before stopping.

    Returns:
        ``vgg`` with the best validation weights loaded, in eval mode.
    """
    since = time.time()

    train_loader = dataloaders[TRAIN]
    val_loader = dataloaders[VAL]
    # Class names in model-output index order, taken from the dataset actually
    # being trained on; the global LABELS may describe a different directory
    # or order.
    class_names = list(getattr(train_loader.dataset, 'classes', None) or LABELS)

    best_acc = 0.0
    best_model_wts = copy.deepcopy(vgg.state_dict())
    epochs_no_improve = 0

    # Per-epoch metrics for plotting
    train_losses, val_losses = [], []
    train_accuracies, val_accuracies = [], []

    for epoch in range(num_epochs):
        print(f"Epoch {epoch}/{num_epochs}")
        print('-' * 10)

        loss_sum, labels, preds = _run_epoch(vgg, train_loader, criterion, optimizer, desc="Training")
        print()
        avg_loss = loss_sum / len(labels)
        avg_acc = balanced_accuracy_score(labels, preds)

        # Advance the LR schedule once per epoch; StepLR's step_size counts
        # these calls.
        if scheduler is not None:
            scheduler.step()

        loss_sum, labels, preds = _run_epoch(vgg, val_loader, criterion, desc="Validation")
        avg_loss_val = loss_sum / len(labels)
        avg_acc_val = balanced_accuracy_score(labels, preds)

        train_losses.append(avg_loss)
        val_losses.append(avg_loss_val)
        train_accuracies.append(avg_acc)
        val_accuracies.append(avg_acc_val)

        print()
        print(f"Epoch {epoch} result: ")
        print(f"Avg loss (train): {avg_loss:.4f}")
        print(f"Avg acc (train): {avg_acc:.4f}")
        print(f"Avg loss (val): {avg_loss_val:.4f}")
        print(f"Avg acc (val): {avg_acc_val:.4f}")
        print('-' * 10)
        print()

        # Keep the weights with the best validation balanced accuracy
        if avg_acc_val > best_acc:
            best_acc = avg_acc_val
            best_model_wts = copy.deepcopy(vgg.state_dict())
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        # Early stopping
        if epochs_no_improve >= patience:
            print(f'Early stopping triggered after {patience} epochs without improvement')
            break

    elapsed_time = time.time() - since

    print()
    print(f"Training completed in {elapsed_time // 60:.0f}m {elapsed_time % 60:.0f}s")
    print(f"Best acc: {best_acc:.4f}")
    print()

    vgg.load_state_dict(best_model_wts)

    # Evaluation metrics for the best model (no loss needed here)
    _, all_labels, all_preds = _run_epoch(vgg, val_loader)
    # Explicit labels keep the matrix/report aligned with class_names even when
    # some class never appears in the validation labels or predictions.
    label_ids = list(range(len(class_names)))
    cm = confusion_matrix(all_labels, all_preds, labels=label_ids)
    report = classification_report(all_labels, all_preds, labels=label_ids,
                                   target_names=class_names, zero_division=0)
    acc = balanced_accuracy_score(all_labels, all_preds)

    print("Confusion Matrix:")
    print(cm)
    print("Classification Report:")
    print(report)
    print(f"Balanced Accuracy: {acc:.4f}")

    _plot_history(train_losses, val_losses, train_accuracies, val_accuracies)

    return vgg


#### Transfer Learning

#### Init the VGG model

In [None]:
# Build an ImageNet-pretrained VGG16 for transfer learning: the convolutional
# feature extractor is frozen and only the classifier head is trainable.
vgg16 = models.vgg16(weights='IMAGENET1K_V1').to(DEVICE)

# Frozen backbone: its parameters receive no gradients.
vgg16.features.requires_grad_(False)
# Trainable head.
vgg16.classifier.requires_grad_(True)

# Replace the final (ImageNet) classification layer with one sized for this
# dataset's NUM_CLASSES; the new layer's parameters are trainable by default.
vgg16.classifier[-1] = torch.nn.Linear(
    in_features=vgg16.classifier[-1].in_features,
    out_features=NUM_CLASSES,
)

# Move again so the freshly created layer lives on DEVICE too.
vgg16 = vgg16.to(DEVICE)

In [None]:
# Loss, optimiser and learning-rate schedule for the transfer-learning run.
criterion = nn.CrossEntropyLoss()
# Adam over all VGG16 parameters; frozen ones never receive gradients, so
# Adam leaves them untouched.
optimizer_ft = optim.Adam(params=vgg16.parameters(), lr=1e-4)
# Multiply the learning rate by 0.1 after every 10 scheduler.step() calls.
exp_lr_scheduler = lr_scheduler.StepLR(optimizer=optimizer_ft, step_size=10, gamma=0.1)

#### Train and Pickle

In [None]:
# Train VGG16 (up to 50 epochs, early stopping inside train_model) and time it.
since = time.time()

trained_vgg16 = train_model(vgg16, criterion, optimizer_ft, exp_lr_scheduler, num_epochs=50)

elapsed_time = time.time() - since
print(f"Training completed in {elapsed_time // 60:.0f}m {elapsed_time % 60:.0f}s")

# Persist the whole trained model with pickle under MODEL_PATH/PICKLE.
# Create the directory first so a long training run is not lost to a
# missing folder when the file is opened for writing.
pkl_name = f'{DATASET}_VGG16.pkl'
(MODEL_PATH / 'PICKLE').mkdir(parents=True, exist_ok=True)
pkl_path = str(MODEL_PATH / 'PICKLE' / pkl_name)

# NOTE(review): pickling the full module ties the file to this code/torch
# version; torch.save(trained_vgg16.state_dict(), ...) is the more portable
# option if loaders can be updated.
with open(pkl_path, 'wb') as f:
    pickle.dump(trained_vgg16, f)