## Pytorch Ignite Example

In [1]:
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator, Engine
from ignite.metrics import Accuracy, Loss
from ignite.contrib.handlers import ProgressBar
import torch
import torchvision
import torchvision.transforms as transforms
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

batch_size = 8

train_ds = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
train_dl = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True)

val_ds = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
val_dl = torch.utils.data.DataLoader(val_ds, batch_size=batch_size, shuffle=False)

classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

Files already downloaded and verified
Files already downloaded and verified


In [2]:
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = torch.flatten(x, 1) # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x


model = Net()

In [3]:
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
epochs = 10
model.to(device)

Net(
  (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120, bias=True)
  (fc2): Linear(in_features=120, out_features=84, bias=True)
  (fc3): Linear(in_features=84, out_features=10, bias=True)
)

In [4]:
device

'cuda'

In [5]:
def train_step(engine, batch):
    x, y = batch
    x = x.to(device)
    y = y.to(device)

    model.train()
    x = x.requires_grad_()
    result = model(x)
    loss = criterion(result, y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss

In [6]:
def val_step(engine, batch):
    x, y = batch

    model.eval()
    with torch.no_grad():
        x = x.to(device)
        y = y.to(device)

        result = model(x)

        return result, y

In [7]:
trainer = Engine(train_step)
validator = Engine(val_step)

In [8]:
@trainer.on(Events.STARTED)
def start_message():
    print("Training now!")

In [9]:
@trainer.on(Events.COMPLETED)
def done_message():
    print("Training done!")

In [10]:
@trainer.on(Events.EPOCH_COMPLETED)
def run_test():
    validator.run(val_dl)

In [11]:
val_metrics = {
    "accuracy": Accuracy(),
    "loss": Loss(criterion)
}

In [12]:
for name, metric in val_metrics.items():
    metric.attach(validator, name)

In [13]:
train_evaluator = create_supervised_evaluator(model, metrics=val_metrics, device=device)

In [14]:
@trainer.on(Events.EPOCH_COMPLETED)
def run_validation():
    train_evaluator.run(train_dl)

In [15]:
@train_evaluator.on(Events.COMPLETED)
def show_train_results():
    metrics = train_evaluator.state.metrics
    acc = val_metrics["accuracy"].compute()
    loss = val_metrics["loss"].compute()
    print(f"Training results for Epoch: {trainer.state.epoch} - accuracy: {acc:.3f} - Loss: {loss:.3f}")

In [16]:
@validator.on(Events.COMPLETED)
def show_valid_results():
    metrics = validator.state.metrics
    acc = val_metrics["accuracy"].compute()
    loss = val_metrics["loss"].compute()
    print(f"Validation results for Epoch {validator.state.epoch} - accuracy: {acc:.3f} - Loss: {loss:.3f}")

In [17]:
#ProgressBar().attach(trainer, output_transform=lambda x: {'batch loss': x})
ProgressBar().attach(trainer)

In [18]:
trainer.run(train_dl, max_epochs=epochs)

Training now!


[1/6250]   0%|           [00:00<?]

Validation results for Epoch 1 - accuracy: 0.432 - Loss: 1.526
Training results for Epoch: 1 - accuracy: 0.432 - Loss: 1.527


[1/6250]   0%|           [00:00<?]

Validation results for Epoch 1 - accuracy: 0.520 - Loss: 1.330
Training results for Epoch: 2 - accuracy: 0.530 - Loss: 1.307


[1/6250]   0%|           [00:00<?]

Validation results for Epoch 1 - accuracy: 0.554 - Loss: 1.248
Training results for Epoch: 3 - accuracy: 0.572 - Loss: 1.196


[1/6250]   0%|           [00:00<?]

Validation results for Epoch 1 - accuracy: 0.590 - Loss: 1.173
Training results for Epoch: 4 - accuracy: 0.620 - Loss: 1.082


[1/6250]   0%|           [00:00<?]

Validation results for Epoch 1 - accuracy: 0.618 - Loss: 1.103
Training results for Epoch: 5 - accuracy: 0.656 - Loss: 0.986


[1/6250]   0%|           [00:00<?]

Validation results for Epoch 1 - accuracy: 0.639 - Loss: 1.044
Training results for Epoch: 6 - accuracy: 0.685 - Loss: 0.898


[1/6250]   0%|           [00:00<?]

Validation results for Epoch 1 - accuracy: 0.640 - Loss: 1.045
Training results for Epoch: 7 - accuracy: 0.697 - Loss: 0.865


[1/6250]   0%|           [00:00<?]

Validation results for Epoch 1 - accuracy: 0.647 - Loss: 1.034
Training results for Epoch: 8 - accuracy: 0.712 - Loss: 0.822


[1/6250]   0%|           [00:00<?]

Validation results for Epoch 1 - accuracy: 0.629 - Loss: 1.094
Training results for Epoch: 9 - accuracy: 0.701 - Loss: 0.843


[1/6250]   0%|           [00:00<?]

Validation results for Epoch 1 - accuracy: 0.650 - Loss: 1.033
Training results for Epoch: 10 - accuracy: 0.735 - Loss: 0.756
Training done!


State:
	iteration: 62500
	epoch: 10
	epoch_length: 6250
	max_epochs: 10
	output: <class 'torch.Tensor'>
	batch: <class 'list'>
	metrics: <class 'dict'>
	dataloader: <class 'torch.utils.data.dataloader.DataLoader'>
	seed: <class 'NoneType'>
	times: <class 'dict'>