# Checkpoints

In [None]:
import torch
import os

def load_checkpoint(r_path, model, optimizer=None, scheduler=None):
    """Restore model (and optionally optimizer / LR-scheduler) state from a file.

    Args:
        r_path: Path to a checkpoint file (a dict saved with ``torch.save``,
            e.g. by ``save_checkpoint``).
        model: Module whose weights are overwritten in place from
            ``checkpoint['model_state_dict']``.
        optimizer: Optional optimizer; restored from ``'optimizer_state_dict'``
            when that key is present in the checkpoint.
        scheduler: Optional LR scheduler. Restored from
            ``'scheduler_state_dict'`` when present. Older checkpoints without
            that key fall back to calling ``scheduler.step()`` once per stored
            epoch (approximate, legacy behaviour).

    Returns:
        ``(model, optimizer, scheduler, epoch)`` where ``epoch`` is the value
        stored under ``'epoch'`` (0 when the key is absent). If no file exists
        at ``r_path`` the inputs are returned unchanged with ``epoch == 0``.
    """
    if not os.path.isfile(r_path):
        print(f"No checkpoint found at '{r_path}'")
        return model, optimizer, scheduler, 0

    print(f"Loading checkpoint '{r_path}'")
    # Map to CPU first so checkpoints written on a GPU also open on CPU-only
    # hosts; load_state_dict then copies the tensors into the model's params.
    # NOTE: torch.load unpickles data — only load checkpoints you trust.
    checkpoint = torch.load(r_path, map_location='cpu')

    model.load_state_dict(checkpoint['model_state_dict'])

    if optimizer is not None and 'optimizer_state_dict' in checkpoint:
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    epoch = checkpoint.get('epoch', 0)

    if scheduler is not None:
        if 'scheduler_state_dict' in checkpoint:
            # Exact restore: last_epoch, step counters and last LRs.
            scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
        else:
            # Legacy fallback for checkpoints saved without scheduler state.
            for _ in range(epoch):
                scheduler.step()

    return model, optimizer, scheduler, epoch

def save_checkpoint(model, epoch, tag, base_directory, optimizer=None, current_val_score=None, best_scores=None, checkpoint_freq=1, scheduler=None, keep_top_k=2):
    """Save a periodic checkpoint and, optionally, a "best validation score" one.

    Args:
        model: Module whose ``state_dict`` is saved.
        epoch: 0-based index of the epoch that just finished; stored as-is
            under ``'epoch'`` and shown as ``epoch + 1`` in filenames.
        tag: Prefix for every checkpoint filename.
        base_directory: Output directory (created if missing).
        optimizer: Optional optimizer whose state is saved as well.
        current_val_score: Validation score for this epoch (lower is better).
            ``None`` disables best-score tracking for this call.
        best_scores: ``[(score, filename), ...]`` returned by the previous
            call, or ``None`` to start fresh. The list is not modified.
        checkpoint_freq: Write ``{tag}_checkpoint_epoch_{epoch+1}.pth`` every
            ``checkpoint_freq`` epochs; ``0`` disables periodic checkpoints.
        scheduler: Optional LR scheduler whose state is saved as well.
        keep_top_k: Number of best scores to track (default 2).

    Returns:
        The updated best-score list (ascending by score), or ``best_scores``
        unchanged when ``current_val_score`` is ``None``.

    Note:
        Files for scores that fall out of the top-k are NOT deleted from disk.
    """
    os.makedirs(base_directory, exist_ok=True)

    save_dict = {
        'model_state_dict': model.state_dict(),
        'epoch': epoch,
    }
    if optimizer is not None:
        save_dict['optimizer_state_dict'] = optimizer.state_dict()
    if scheduler is not None:
        save_dict['scheduler_state_dict'] = scheduler.state_dict()

    if current_val_score is not None:
        if best_scores is None:
            # Placeholder slots so the first real scores always rank in.
            best_scores = [(float('inf'), '')] * keep_top_k

        new_entry = (
            current_val_score,
            f'{tag}_checkpoint_epoch_{epoch + 1}_val{current_val_score:.4f}.pth',
        )
        # Build a new list instead of appending to the caller's list.
        # sorted() is stable, so on a tie the older entry ranks first.
        ranked = sorted([*best_scores, new_entry], key=lambda entry: entry[0])
        best_scores = ranked[:keep_top_k]

        # Identity test: comparing scores would match an older entry with the
        # same value and overwrite *its* file with the current weights.
        if any(entry is new_entry for entry in best_scores):
            torch.save(save_dict, os.path.join(base_directory, new_entry[1]))

    # Regular checkpoint updates (checkpoint_freq == 0 disables them).
    if checkpoint_freq and (epoch + 1) % checkpoint_freq == 0:
        checkpoint_path = os.path.join(base_directory, f'{tag}_checkpoint_epoch_{epoch + 1}.pth')
        torch.save(save_dict, checkpoint_path)

    return best_scores

# Gradient accumulation

In [None]:
# Gradient accumulation: sum gradients over `accumulation_steps` mini-batches
# before each optimizer update, so the effective batch size is
# batch_size * accumulation_steps.
# NOTE(review): relies on `num_epochs`, `model`, `train_loader`,
# `loss_function` and `optimizer` being defined in an earlier cell.
accumulation_steps = 4

for epoch in range(num_epochs):
    model.train()
    optimizer.zero_grad()
    num_batches = len(train_loader)
    for batch_idx, (inputs, targets) in enumerate(train_loader):
        # Forward pass
        outputs = model(inputs)
        loss = loss_function(outputs, targets)

        # Scale so the accumulated gradient approximates that of a
        # mean-reduced loss over the whole effective batch.
        loss = loss / accumulation_steps
        loss.backward()  # gradients add up in .grad until zero_grad()

        # Update every `accumulation_steps` batches, and flush the final
        # (possibly shorter) group at the end of the epoch.
        is_last_batch = batch_idx + 1 == num_batches
        if (batch_idx + 1) % accumulation_steps == 0 or is_last_batch:
            optimizer.step()
            optimizer.zero_grad()

In [None]:
# -- torchvision `transforms` is used by the DATA PREP section below and must be imported (e.g. `from torchvision import transforms`)
import pickle
import os
import sys
import argparse
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
import json
import pandas as pd
import copy
import torchvision.ops
from images.scheduler.scheduler import SchedulerManager
from images.datasets.singlefolder_dataset import SingleFolderDataset
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader

# Use the GPU when torch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# ================ INIT ================= #
# Optimizer / LR-schedule hyper-parameters consumed by `initialization`.
learning_rate = 0.001
scheduler_config = dict(T_0=100, T_mult=1, eta_min=0.0001)
# Scheduler *name*, used as a key into SchedulerManager.configs. The training
# cell later rebinds `scheduler` to the scheduler object itself.
scheduler = 'CosineAnnealingWarmRestarts'

def initialization(learning_rate, model_class, scheduler, sceduler_config, criterion):
    """Build the model, Adam optimizer, LR scheduler and loss function.

    Args:
        learning_rate: Learning rate passed to ``torch.optim.Adam``.
        model_class: Zero-argument callable (e.g. an ``nn.Module`` subclass)
            that creates the model; the result is moved to the module-level
            ``device``.
        scheduler: Name of the scheduler to build through ``SchedulerManager``.
        sceduler_config: Config stored under
            ``SchedulerManager.configs[scheduler]`` before the scheduler is
            built; ``None`` leaves the manager's existing entry untouched.
            (Parameter name kept as-is for backward compatibility.)
        criterion: Loss module to use; ``None`` selects
            ``nn.CrossEntropyLoss()``.

    Returns:
        ``(model, optimizer, scheduler, criterion)``.
    """
    model = model_class().to(device)  # e.g. model_class = JNet
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    scheduler_manager = SchedulerManager()
    if sceduler_config is not None:
        # Previously the global `scheduler_config` was used and this argument ignored.
        scheduler_manager.configs[scheduler] = sceduler_config
    # Previously hard-coded to 'CosineAnnealingWarmRestarts' regardless of `scheduler`.
    lr_scheduler = scheduler_manager.initialize_scheduler(optimizer, scheduler)

    if criterion is None:
        criterion = nn.CrossEntropyLoss()
    return model, optimizer, lr_scheduler, criterion

# ================ LOADING STEP ================= #
# Path of the checkpoint *file* to resume from (despite the name). When no
# file exists there, load_checkpoint prints a notice and returns epoch 0.
load_dir = ''
load_states = True  # whether the training cell should try to resume


def loading_states(load_dir, model, optimizer, scheduler):
    """Resume model/optimizer/scheduler from ``load_dir`` via ``load_checkpoint``.

    Returns:
        ``(model, optimizer, scheduler, start_epoch)`` as produced by
        ``load_checkpoint``.
    """
    restored_model, restored_optimizer, restored_scheduler, resume_epoch = load_checkpoint(
        load_dir, model, optimizer, scheduler
    )
    return restored_model, restored_optimizer, restored_scheduler, resume_epoch

# ================ DATA PREP ================= #
# `transforms` is used below but was never imported; torchvision is already a
# dependency of this notebook (see `import torchvision.ops`).
from torchvision import transforms

# TODO: assign the training / validation DataFrames before running (the
# original lines `train_df =` / `val_df =` were syntax errors).
train_df = None
val_df = None
# NOTE(review): `custom_collate_fn` is not defined in the visible code — define
# it in an earlier cell, or set `collater = False` to use DataLoader's default.
collater = custom_collate_fn  # or False
batch_sizes = 24
shuffle = False
# Resize to 224x224, convert to a tensor, and normalize with (presumably)
# the ImageNet per-channel mean/std used by torchvision pretrained models.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

def dataset_prep(train_df, val_df, collater, batch_sizes, shuffle, transform, val_shuffle=False):
    """Wrap the train/val DataFrames in datasets and DataLoaders.

    Args:
        train_df: Training DataFrame; read via ``DataFrameDataset_custom``
            using the ``'img_path'`` column and target columns ``'1'``..``'16'``.
        val_df: Validation DataFrame with the same layout.
        collater: Custom ``collate_fn``, or a falsy value to use the
            DataLoader default.
        batch_sizes: Batch size for both loaders.
        shuffle: Whether to shuffle the *training* loader.
        transform: Transform forwarded to both datasets.
        val_shuffle: Whether to shuffle the validation loader. Defaults to
            False; previously validation followed ``shuffle``.

    Returns:
        ``(train_dataloader, val_dataloader)``.
    """
    # Target (label) columns are named by the strings '1' .. '16'.
    target_columns = [str(i) for i in range(1, 16 + 1)]
    # DataLoader treats collate_fn=None as "use the default collate".
    collate_fn = collater if collater else None

    def _make_loader(df, shuffle_flag):
        # Build the dataset for one split and wrap it in a DataLoader.
        dataset = DataFrameDataset_custom(df, 'img_path', target_columns, transform=transform)
        return DataLoader(dataset, batch_size=batch_sizes, shuffle=shuffle_flag, collate_fn=collate_fn)

    train_dataloader = _make_loader(train_df, shuffle)
    val_dataloader = _make_loader(val_df, val_shuffle)
    return train_dataloader, val_dataloader


# ================ TRAINING STEP ================= #
# TODO: set the total number of training epochs before running (the original
# line `num_epochs =` was a syntax error).
num_epochs = None


def training_step(model, scheduler, train_dataloader, optimizer, criterion = None, epoch=None):
    """Train for one epoch, then advance the LR scheduler once.

    Args:
        model: Module to train; inputs/targets are moved to the module-level
            ``device``.
        scheduler: LR scheduler; ``step()`` is called once after the epoch.
        train_dataloader: Iterable of ``(inputs, targets)`` batches. Targets
            are cast to ``long`` before the loss is computed.
        optimizer: Optimizer updated after every batch.
        criterion: Loss function; ``None`` selects ``nn.CrossEntropyLoss()``
            (previously ``None`` crashed when the loss was called).
        epoch: 0-based epoch index, used only in progress/log labels. ``None``
            falls back to a module-level ``epoch`` variable (the notebook's
            loop counter), or 0 if there is none.

    Returns:
        ``(model, average_training_loss, num_epochs, optimizer)``. The third
        item is the module-level ``num_epochs``, kept for backward compatibility.

    Raises:
        ValueError: if ``train_dataloader`` yields no batches.
    """
    if criterion is None:
        criterion = nn.CrossEntropyLoss()
    if epoch is None:
        epoch = globals().get('epoch', 0)

    model.train()
    total_loss = 0.0
    num_batches = 0

    for batch_data, batch_target in tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{num_epochs}"):
        batch_data, batch_target = batch_data.to(device), batch_target.to(device)
        optimizer.zero_grad()
        outputs = model(batch_data)

        batch_target = batch_target.long()  # loss expects integer class targets
        loss = criterion(outputs, batch_target)

        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1

    if num_batches == 0:
        raise ValueError("train_dataloader yielded no batches")

    average_training_loss = total_loss / num_batches
    scheduler.step()  # one scheduler step per epoch
    print(f"Epoch {epoch + 1}/{num_epochs}, Training Loss: {average_training_loss}")

    return model, average_training_loss, num_epochs, optimizer

# ================ VALIDATION STEP ================= #
def validation_step(model, val_dataloader, criterion = None):
    """Compute the mean per-batch loss over ``val_dataloader`` without gradients.

    Args:
        model: Module to evaluate (switched to ``eval()`` mode); batches are
            moved to the module-level ``device``.
        val_dataloader: Iterable of ``(inputs, targets)`` batches. Targets are
            cast to ``long`` before the loss is computed.
        criterion: Loss function; ``None`` selects ``nn.CrossEntropyLoss()``
            (previously ``None`` crashed when the loss was called).

    Returns:
        Average of the per-batch loss values.

    Raises:
        ValueError: if ``val_dataloader`` yields no batches.
    """
    if criterion is None:
        criterion = nn.CrossEntropyLoss()

    model.eval()
    val_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for batch_data, batch_target in val_dataloader:
            batch_data, batch_target = batch_data.to(device), batch_target.to(device)
            outputs = model(batch_data)

            batch_target = batch_target.long()  # Convert target labels to torch.long
            loss = criterion(outputs, batch_target)
            val_loss += loss.item()
            num_batches += 1

    if num_batches == 0:
        raise ValueError("val_dataloader yielded no batches")

    total_val_loss = val_loss / num_batches
    print(f"Average Validation Loss: {total_val_loss}")

    return total_val_loss

# ================ CHECK POINTS ================= #
# Placeholders: directory that receives checkpoints, and the filename prefix
# used for every checkpoint written there.
base_directory = 'directory'
tag = 'name of the model'


def check_points(model, epoch, tag, base_directory, optimizer, val_score, best_scores = None):
    """Save this epoch's checkpoint(s) via ``save_checkpoint`` (every epoch).

    Returns:
        The updated best-score list from ``save_checkpoint``; feed it back in
        on the next call so best-score tracking accumulates.
    """
    updated_best_scores = save_checkpoint(
        model,
        epoch,
        tag,
        base_directory,
        optimizer,
        current_val_score=val_score,
        best_scores=best_scores,
        checkpoint_freq=1,
    )
    return updated_best_scores

In [None]:
# -- init
# NOTE: `scheduler` enters this cell as the scheduler *name* (see INIT above)
# and is rebound here to the scheduler object, so re-run the previous cell
# before re-running this one. criterion=None -> nn.CrossEntropyLoss().
model, optimizer, scheduler, criterion = initialization(
    learning_rate, model_class, scheduler, scheduler_config, None
)

# -- load
start_epoch = 0
if load_states:
    model, optimizer, scheduler, last_epoch = loading_states(load_dir, model, optimizer, scheduler)
    if os.path.isfile(load_dir):
        # save_checkpoint stores the index of the epoch that just finished,
        # so resume from the next one instead of repeating it.
        start_epoch = last_epoch + 1

# -- data
train_dataloader, val_dataloader = dataset_prep(train_df, val_df, collater, batch_sizes, shuffle, transform)

# Carried across epochs so the best-checkpoint ranking accumulates.
best_scores = None
for epoch in range(start_epoch, num_epochs):
    # -- training (training_step reads the loop's `epoch` for its log labels)
    model, training_loss, _, optimizer = training_step(model, scheduler, train_dataloader, optimizer, criterion=criterion)

    # -- validation
    val_loss = validation_step(model, val_dataloader, criterion=criterion)

    # -- saving check points
    best_scores = check_points(model, epoch, tag, base_directory, optimizer, val_loss, best_scores=best_scores)
