## Torch+Ignite CIFAR10

In [39]:
import torch
import torchvision
import ignite

# Print the installed version of each library this notebook relies on, so the
# recorded outputs can be tied to a concrete environment.
for _lib_label, _lib_version in (
    ("pytorch ", torch.version.__version__),
    ("torchvision ", torchvision.version.__version__),
    ("ignite ", ignite.__version__),
):
    print(_lib_label, _lib_version)

pytorch  2.4.0
torchvision  0.19.0
ignite  0.5.1


Libraries used:
1. PyTorch - [Main](https://pytorch.org/) / [conda channel](https://anaconda.org/pytorch/repo/files)
2. PyTorch Ignite - https://pytorch-ignite.ai/
3. Torchvision - https://pytorch.org/vision/stable/index.html

### Load dataset

In [41]:
# Preprocessing pipeline: convert each image to a tensor, then shift/scale every
# RGB channel with mean 0.5 and std 0.5 (i.e. [0, 1] -> [-1, 1]).
# NOTE: `transform` is reassigned further down in this cell, so this pipeline
# is not the one that ends up being passed to the datasets.
transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Callable transform that maps values from the symmetric range [-1, 1] back to [0, 1]
class ToUnitRange:
    """Apply the affine rescaling ``x -> (x + 1) / 2``.

    -1 maps to 0, 0 maps to 0.5 and 1 maps to 1. The input only needs to
    support addition and true division (e.g. a tensor or a float).
    """

    def __call__(self, tensor):
        shifted = tensor + 1
        return shifted / 2

# Preprocessing actually used by the datasets below. ToTensor converts each
# image to a float tensor scaled to [0, 1]; Normalize with mean 0 and std 1
# then leaves those values unchanged.
_preprocess_steps = [
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.0,), (1.0,)),
]
transform = torchvision.transforms.Compose(_preprocess_steps)

# CIFAR-10 train and test splits, stored under ./data (downloaded on demand)
_dataset_kwargs = dict(root='./data', download=True, transform=transform)
trainset = torchvision.datasets.CIFAR10(train=True, **_dataset_kwargs)
testset = torchvision.datasets.CIFAR10(train=False, **_dataset_kwargs)

# Mini-batch loaders of 64 samples; only the training split is shuffled
_loader_kwargs = dict(batch_size=64, num_workers=1)
trainloader = torch.utils.data.DataLoader(trainset, shuffle=True, **_loader_kwargs)
testloader = torch.utils.data.DataLoader(testset, shuffle=False, **_loader_kwargs)

Files already downloaded and verified
Files already downloaded and verified


In [17]:
# Define CNN model and optimizer

import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

class Net(nn.Module):
    """Small LeNet-style CNN for 3x32x32 images such as CIFAR-10.

    Two (5x5 convolution -> ReLU -> 2x2 max-pool) stages are followed by
    three fully connected layers. The first fully connected layer expects
    16 * 5 * 5 features per sample, which is what a 32x32 input produces.

    Args:
        num_classes: Number of output logits. Defaults to 10.
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, num_classes)

    def forward(self, x):
        """Return raw class logits of shape (batch, num_classes).

        Args:
            x: Image batch of shape (batch, 3, 32, 32). A single unbatched
                image of shape (3, 32, 32) is treated as a batch of one.

        Raises:
            RuntimeError: If the spatial size does not yield 16 * 5 * 5
                features per sample (raised by the first linear layer).
        """
        # Keep accepting a single (C, H, W) image by promoting it to a batch of one.
        if x.dim() == 3:
            x = x.unsqueeze(0)
        x = self.pool(F.relu(self.conv1(x)))  # 32x32 input -> 6 x 14 x 14
        x = self.pool(F.relu(self.conv2(x)))  # -> 16 x 5 x 5
        # Flatten per sample. Unlike view(-1, 400), this never moves values
        # across the batch dimension, so a wrong input size fails loudly in
        # fc1 instead of silently producing a different batch size.
        x = x.flatten(start_dim=1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

# Loss: cross-entropy computed on the network's raw outputs
criterion = nn.CrossEntropyLoss()

# Network instance whose parameters the optimizer below updates
model = Net()

# AdamW over all model parameters with a learning rate of 1e-3
optimizer = optim.AdamW(params=model.parameters(), lr=1e-3)

In [65]:
# Define Ignite training loop

from ignite.engine import create_supervised_trainer, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss
from ignite.engine import Events
from ignite.handlers import ModelCheckpoint
from ignite.contrib.handlers import global_step_from_engine

# Training engine built from the model, optimizer and loss defined earlier
trainer = create_supervised_trainer(model=model, optimizer=optimizer, loss_fn=criterion)

# Metrics computed on every evaluator run: accuracy and the value of `criterion`
metrics = dict(accuracy=Accuracy(), loss=Loss(criterion))

# Evaluation engine sharing the same model; its results are read from
# `evaluator.state.metrics` after each run
evaluator = create_supervised_evaluator(model, metrics=metrics)

# Objects whose state is written into every checkpoint file
to_save = {'model': model, 'optimizer': optimizer, 'trainer': trainer}

# Checkpoint writer: files go to ./checkpoints with the 'cifar10' prefix and
# are tagged with the trainer's global step; n_saved=3 limits how many are kept.
# include_self=True also stores the handler's own state in the checkpoint.
checkpoint_handler = ModelCheckpoint(
    dirname='./checkpoints',
    filename_prefix='cifar10',
    global_step_transform=global_step_from_engine(trainer),
    n_saved=3,
    include_self=True,
    create_dir=True,
    require_empty=False,
    atomic=True,
)

def _evaluate_and_report(loader, split_name):
    """Run `evaluator` over `loader` and print its accuracy and loss.

    Args:
        loader: Data loader to evaluate the model on.
        split_name: Label used in the printed line (e.g. "Validation").
    """
    evaluator.run(loader)
    metrics = evaluator.state.metrics
    print(f"Epoch: {trainer.state.epoch},  {split_name} accuracy: {metrics['accuracy']}",
          f"Loss: {metrics['loss']:.3f}")


@trainer.on(Events.EPOCH_COMPLETED)
def run_validation():
    """Report test-set accuracy and loss after every training epoch."""
    _evaluate_and_report(testloader, "Validation")


@trainer.on(Events.EPOCH_COMPLETED(every=3))
def run_training_loss():
    """Report training-set accuracy and loss on every third EPOCH_COMPLETED event."""
    _evaluate_and_report(trainloader, "Training")

@trainer.on(Events.EPOCH_COMPLETED)
def save_checkpoint(engine):
    """Write a checkpoint of the objects in `to_save` after each epoch.

    Args:
        engine: The engine that fired the event; forwarded to the
            checkpoint handler.
    """
    checkpoint_handler(engine, to_save)

In [56]:
# Restore the state of every object in `to_save` from a saved checkpoint.
# - map_location='cpu' lets a checkpoint written on a GPU machine load on a
#   CPU-only one; the tensors are copied into the existing objects afterwards.
# - weights_only=True restricts unpickling to tensors and plain containers,
#   which avoids executing arbitrary pickled code and the FutureWarning that
#   torch.load emits when the argument is left implicit.
checkpoint = torch.load(
    './checkpoints/cifar10_checkpoint_2.pt',
    map_location='cpu',
    weights_only=True,
)
checkpoint_handler.load_objects(to_save, checkpoint)

  checkpoint = torch.load('./checkpoints/cifar10_checkpoint_2.pt')


In [66]:
# Run more epochs: train for EPOCHS additional epochs on top of whatever the
# trainer has already completed (0 on a fresh trainer, or the epoch restored
# from a checkpoint).
EPOCHS = 7

# An Ignite Engine always carries a State object, so `trainer.state` is never
# None and no fallback branch is needed: the target is current epoch + EPOCHS.
current_epoch = trainer.state.epoch
max_epochs = current_epoch + EPOCHS

trainer.run(trainloader, max_epochs=max_epochs)

Epoch: 1,  Validation accuracy: 0.6136 Loss: 1.570
Epoch: 2,  Validation accuracy: 0.6114 Loss: 1.598
Epoch: 3,  Validation accuracy: 0.6123 Loss: 1.617
Epoch: 4,  Validation accuracy: 0.6083 Loss: 1.608
Epoch: 4,  Training accuracy: 0.85746 Loss: 0.398
Epoch: 5,  Validation accuracy: 0.6102 Loss: 1.621
Epoch: 6,  Validation accuracy: 0.6092 Loss: 1.615
Epoch: 7,  Validation accuracy: 0.598 Loss: 1.669


State:
	iteration: 5474
	epoch: 7
	epoch_length: 782
	max_epochs: 7
	output: 0.5639845728874207
	batch: <class 'list'>
	metrics: <class 'dict'>
	dataloader: <class 'torch.utils.data.dataloader.DataLoader'>
	seed: <class 'NoneType'>
	times: <class 'dict'>