In [1]:
%pip install ignite

Note: you may need to restart the kernel to use updated packages.


In [2]:
%pip install pytorch-ignite


Note: you may need to restart the kernel to use updated packages.


In [3]:
import pandas as pd

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.models import resnet18
from torchvision.transforms import Compose, Normalize, ToTensor

from ignite.engine import Engine, EventEnum, Events, create_supervised_evaluator
from ignite.metrics import Accuracy, Loss
from ignite.handlers import Timer
from ignite.contrib.handlers import BasicTimeProfiler, HandlersTimeProfiler


In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()

        self.model = resnet18(num_classes=10)
        self.model.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)

    def forward(self, x):
        return self.model(x)


model = Net().to(device)

data_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])

train_loader = DataLoader(
    MNIST(download=True, root=".", transform=data_transform, train=True),
    batch_size=128,
    shuffle=True,
)

optimizer = torch.optim.RMSprop(model.parameters(), lr=0.005)
criterion = nn.CrossEntropyLoss()



In [5]:
class BackpropEvents(EventEnum):
    BACKWARD_STARTED = 'backward_started'
    BACKWARD_COMPLETED = 'backward_completed'
    OPTIM_STEP_COMPLETED = 'optim_step_completed'


In [6]:
def train_step(engine, batch):
    model.train()
    optimizer.zero_grad()
    x, y = batch[0].to(device), batch[1].to(device)
    y_pred = model(x)
    loss = criterion(y_pred, y)
    
    engine.fire_event(BackpropEvents.BACKWARD_STARTED)
    loss.backward()
    engine.fire_event(BackpropEvents.BACKWARD_COMPLETED)

    optimizer.step()
    engine.fire_event(BackpropEvents.OPTIM_STEP_COMPLETED)

    return loss.item()


trainer = Engine(train_step)


In [7]:
trainer.register_events(*BackpropEvents)


In [8]:
@trainer.on(BackpropEvents.BACKWARD_COMPLETED)
def function_before_backprop(engine):
    print(f"Iter[{engine.state.iteration}] Function fired after backward pass")


In [9]:
trainer.run(train_loader, max_epochs=3)


Iter[1] Function fired after backward pass
Iter[2] Function fired after backward pass
Iter[3] Function fired after backward pass
Iter[4] Function fired after backward pass
Iter[5] Function fired after backward pass
Iter[6] Function fired after backward pass
Iter[7] Function fired after backward pass
Iter[8] Function fired after backward pass
Iter[9] Function fired after backward pass
Iter[10] Function fired after backward pass
Iter[11] Function fired after backward pass
Iter[12] Function fired after backward pass
Iter[13] Function fired after backward pass
Iter[14] Function fired after backward pass
Iter[15] Function fired after backward pass
Iter[16] Function fired after backward pass
Iter[17] Function fired after backward pass
Iter[18] Function fired after backward pass
Iter[19] Function fired after backward pass
Iter[20] Function fired after backward pass
Iter[21] Function fired after backward pass
Iter[22] Function fired after backward pass
Iter[23] Function fired after backward pa

State:
	iteration: 1407
	epoch: 3
	epoch_length: 469
	max_epochs: 3
	output: 0.0011817800113931298
	batch: <class 'list'>
	metrics: <class 'dict'>
	dataloader: <class 'torch.utils.data.dataloader.DataLoader'>
	seed: <class 'NoneType'>
	times: <class 'dict'>