<a href="https://colab.research.google.com/github/kkotsche1/Image_Classification_Model_Training_Timm/blob/main/Image_Classification_Training_Timm_Library.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# !pip install barbar pytorch-ignite
# !pip install torchvision
# !pip install barbar
# !pip install pytorch-ignite
# !pip install timm
# !pip install torchsummary
# !pip install torchsampler

import torch
from torch import nn 
from torch.nn import functional as F
from torch.optim import Adam
from torch.utils.data import Dataset, DataLoader

from torchvision import transforms, datasets
import timm
import numpy as np 
from tqdm.notebook import tqdm
import glob 
import os
import matplotlib.pyplot as plt
import cv2

import torch
import os
import shutil
import random
import torchvision
import pandas as pd
import torch.nn as nn
import numpy as np
from torchvision import transforms, datasets
from shutil import copyfile, move
from torch.utils.tensorboard import SummaryWriter
from barbar import Bar
from torchsummary import summary
from ignite.metrics import Accuracy
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, classification_report, roc_auc_score
from torchsampler import ImbalancedDatasetSampler

In [None]:
# This requires the training and validation data to be stored in the
# respective folders at the root of this project, i.e. next to this notebook.
#
# Each class should be a folder containing every image belonging to that
# class (the same layout is expected under ./validation):
#
#   ./train
#      -class1
#         -image_class1_1.jpg
#         -image_class1_2.jpg
#         -image_class1_3.jpg
#      -class2
#         -image_class2_1.jpg
#         -image_class2_2.jpg
# etc.


train_dir = "./train"
valid_dir = "./validation"

# Aliases used by the ImageFolder / DataLoader cell; kept pointing at the
# same folders so the two names can never drift apart.
traindir = train_dir
valdir = valid_dir

# Paths of every image file one level below each class folder.
train_image_dirs = glob.glob(os.path.join(train_dir, "*", "*.*"))
valid_image_dirs = glob.glob(os.path.join(valid_dir, "*", "*.*"))

# Class names in the same order torchvision's ImageFolder assigns label
# indices (sorted sub-directory names). Plain os.listdir order is arbitrary
# and may include stray files such as .DS_Store.
unique_labels = sorted(
    entry.name for entry in os.scandir(train_dir) if entry.is_dir()
)
print(unique_labels)

In [None]:
import albumentations as A
import cv2

# Training-time augmentation: fixed 320x320 resize, random geometric and
# colour perturbations, then normalisation with the standard ImageNet
# mean/std (presumably matching the pretrained backbone's statistics).
train_transforms = transforms.Compose([
    transforms.Resize((320, 320)),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(360),
    transforms.RandomPerspective(distortion_scale=0.1, p=0.3,
                                 interpolation=transforms.InterpolationMode.BILINEAR,
                                 fill=0),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, hue=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Validation: deterministic resize + the same normalisation, no augmentation.
val_transforms = transforms.Compose([
    transforms.Resize((320, 320)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])


train_dataset = datasets.ImageFolder(
    traindir, transform=train_transforms)

val_dataset = datasets.ImageFolder(
    valdir, transform=val_transforms)

# Training batches are re-balanced across classes by ImbalancedDatasetSampler.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=32, sampler=ImbalancedDatasetSampler(train_dataset),
    pin_memory=False, drop_last=False)

# Validation must see every image exactly once, in a fixed order, so it uses
# a plain sequential loader. ImbalancedDatasetSampler draws with replacement,
# which would duplicate some validation images and skip others, making the
# reported validation loss/accuracy noisy and non-reproducible.
val_loader = torch.utils.data.DataLoader(
    val_dataset, batch_size=32, shuffle=False,
    pin_memory=False, drop_last=False)

In [None]:
# Training hyper-parameters.
EPOCHS = 100      # maximum number of passes over the training set
LR = 0.0001       # learning rate for the Adam optimizer
DROP_RATE = 0.5   # dropout rate intended for the model (timm's drop_rate)

In [None]:
# Use the GPU when one is available; fall back to CPU so the notebook still
# runs (slowly) on machines without CUDA instead of crashing.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

criterion = nn.CrossEntropyLoss().to(DEVICE)

# The number of output classes is derived from the training folder layout
# (ImageFolder's class list) instead of being hard-coded, so the classifier
# head always matches the dataset labels. DROP_RATE is forwarded as timm's
# `drop_rate` (dropout before the classifier head).
model = timm.create_model(
    'edgenext_small',
    pretrained=True,
    num_classes=len(train_dataset.classes),
    drop_rate=DROP_RATE,
).to(DEVICE)

optimizer = torch.optim.Adam(model.parameters(), lr=LR)

In [None]:
class EarlyStopping:
    """Stop training once a monitored validation metric stops improving.

    Each call compares the new metric value with the best one seen so far.
    An improvement (by at least ``delta``) saves ``model.state_dict()`` to
    ``path`` and resets the patience counter. Otherwise the counter is
    incremented, and ``early_stop`` is set once it reaches ``patience``.

    By default the metric is treated as "higher is better" (``mode='max'``),
    which is how the training loop in this notebook uses it (validation
    accuracy). Pass ``mode='min'`` to monitor a quantity that should
    decrease, such as a validation loss.
    """

    def __init__(self, patience=7, verbose=False, delta=0, path='checkpoint.pt', trace_func=print, mode='max'):
        """
        Args:
            patience (int): Number of consecutive calls without improvement
                            before ``early_stop`` is set.
                            Default: 7
            verbose (bool): If True, reports every improvement via
                            ``trace_func``.
                            Default: False
            delta (float): Minimum change in the monitored metric that counts
                           as an improvement.
                           Default: 0
            path (str): File the best model's state_dict is saved to.
                        Default: 'checkpoint.pt'
            trace_func (function): Function used to emit progress messages.
                                   Default: print
            mode (str): 'max' if larger metric values are better (e.g.
                        accuracy), 'min' if smaller values are better
                        (e.g. loss).
                        Default: 'max'

        Raises:
            ValueError: If ``mode`` is neither 'max' nor 'min'.
        """
        if mode not in ('max', 'min'):
            raise ValueError(f"mode must be 'max' or 'min', got {mode!r}")
        self.patience = patience
        self.verbose = verbose
        self.counter = 0
        # Best sign-adjusted score seen so far (larger is always better).
        self.best_score = None
        self.early_stop = False
        # Raw metric value of the last saved checkpoint; only used in the
        # verbose message.
        self.val_loss_min = 0 if mode == 'max' else float('inf')
        self.delta = delta
        self.path = path
        self.trace_func = trace_func
        self.mode = mode

    def __call__(self, val_loss, model):
        """Record a new metric value and checkpoint ``model`` if it improved.

        Args:
            val_loss: The monitored metric for this epoch. Despite the name,
                      with the default ``mode='max'`` this should be a
                      "higher is better" value such as accuracy.
            model: Object providing ``state_dict()``; saved on improvement.
        """
        # Normalise so that a larger score always means "better".
        score = val_loss if self.mode == 'max' else -val_loss

        if self.best_score is None:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.delta:
            self.counter += 1

            self.trace_func(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        '''Save ``model.state_dict()`` to ``self.path`` after an improvement.'''
        if self.verbose:
            self.trace_func(f'Monitored metric has improved ({self.val_loss_min} --> {val_loss}).  Saving model ...')
        # Write to the configured path so callers control checkpoint location.
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

In [None]:
# NOTE: patience=200 exceeds EPOCHS (100), so training never actually stops
# early; the helper then only checkpoints the best-scoring model. Lower the
# patience to enable real early stopping.
early_stopping = EarlyStopping(patience=200, verbose=True)

# Reuse the device the model and criterion were moved to, so inputs never end
# up on a different device than the model.
device = DEVICE

for epoch in range(EPOCHS):
    train_loss = 0.00
    val_loss = 0.00

    train_accuracy = Accuracy()
    val_accuracy = Accuracy()
    print(f'Epoch {epoch+1}')

    # Training loop: enable dropout/batch-norm updates once per epoch.
    model.train()
    for inputs, labels in Bar(train_loader):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
        train_accuracy.update((nn.functional.softmax(outputs, dim=1), labels))
    print(f"Train Accuracy: {train_accuracy.compute()}")
    train_loss /= len(train_loader)
    train_loss_formatted = f"{train_loss:.4f}"

    # Validation loop: eval mode and no gradient tracking.
    model.eval()
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            val_loss += loss.item()
            val_accuracy.update((nn.functional.softmax(outputs, dim=1), labels))
    val_acc = val_accuracy.compute()
    print(f"Val Accuracy: {val_acc}")
    val_loss /= len(val_loader)
    val_loss_formatted = f"{val_loss:.4f}"
    print(f'Training Loss: {train_loss_formatted}')
    print(f"Validation Loss: {val_loss_formatted}")

    # Early stopping monitors validation accuracy (higher is better).
    early_stopping(val_acc, model)
    if early_stopping.early_stop:
        print("Early stopping")
        break