In [1]:
import torch
import torchvision
import torchvision.transforms as transforms

In [2]:
# Per-image preprocessing: convert the image to a tensor, then normalize its
# single channel with mean 0.5 and standard deviation 0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

In [3]:
# Both splits live under the same root directory and share one preprocessing
# pipeline; `train` selects which split is loaded.
training_set = torchvision.datasets.FashionMNIST(
    root='./data',
    train=True,
    transform=transform,
    download=True,
)
validation_set = torchvision.datasets.FashionMNIST(
    root='./data',
    train=False,
    transform=transform,
    download=True,
)


In [4]:
# Mini-batches of 4 samples each; only the training loader shuffles its data.
training_loader = torch.utils.data.DataLoader(
    training_set,
    batch_size=4,
    shuffle=True,
)
validation_loader = torch.utils.data.DataLoader(
    validation_set,
    batch_size=4,
    shuffle=False,
)

In [5]:
# Human-readable class names.
# NOTE(review): presumably ordered to match the dataset's integer labels 0-9;
# `classes` is not used elsewhere in the visible cells, so confirm before
# indexing it with a label.
classes = (
    'T-shirt/top',
    'Trouser',
    'Pullover',
    'Dress',
    'Coat',
    'Sandal',
    'Shirt',
    'Sneaker',
    'Bag',
    'Ankle Boot',
)

In [6]:
# Show how many samples each split contains (one line per split).
print(
    'Training set has {} instances'.format(len(training_set)),
    'Validation set has {} instances'.format(len(validation_set)),
    sep='\n',
)

Training set has 60000 instances
Validation set has 10000 instances


In [7]:
# CNN_Simple is defined in the project-local `models` module (not shown in
# this notebook).
from models import CNN_Simple
# Build the network with its default constructor arguments.
model = CNN_Simple()

In [8]:
# Stochastic gradient descent with momentum over all of the model's parameters.
optimizer = torch.optim.SGD(
    params=model.parameters(),
    lr=0.001,
    momentum=0.9,
)

In [9]:
# Cross-entropy classification loss; called as loss_fn(outputs, labels) by the
# training and validation code in this notebook.
loss_fn = torch.nn.CrossEntropyLoss()

In [10]:
def train_one_epoch(epoch_index, report_every=1000):
    """Run one optimization pass over ``training_loader``.

    Relies on the notebook-level ``model``, ``optimizer``, ``loss_fn`` and
    ``training_loader`` objects.

    Args:
        epoch_index: Index of the current epoch. Not used by the function
            itself; kept so existing callers keep working.
        report_every: Number of batches per reporting window. After each
            complete window the mean per-batch loss of that window is printed.
            Defaults to 1000.

    Returns:
        float: Mean per-batch loss of the most recent complete reporting
        window. If the loader yields fewer than ``report_every`` batches (so
        no window ever completes), the mean over all batches seen is reported
        and returned instead of a misleading 0.0. Returns 0.0 only when the
        loader yields no batches at all.

    Raises:
        ValueError: If ``report_every`` is smaller than 1.
    """
    if report_every < 1:
        raise ValueError('report_every must be >= 1')

    running_loss = 0.
    last_loss = 0.
    batches_seen = 0
    window_reported = False

    # enumerate() gives us the batch index for intra-epoch reporting.
    for i, (inputs, labels) in enumerate(training_loader):
        # Gradients accumulate by default, so clear them for every batch.
        optimizer.zero_grad()

        # Forward pass, loss, backward pass, parameter update.
        outputs = model(inputs)
        loss = loss_fn(outputs, labels)
        loss.backward()
        optimizer.step()

        # .item() turns the scalar loss tensor into a plain Python float.
        running_loss += loss.item()
        batches_seen = i + 1
        if batches_seen % report_every == 0:
            last_loss = running_loss / report_every  # loss per batch
            print('  batch {} loss: {}'.format(batches_seen, last_loss))
            running_loss = 0.
            window_reported = True

    # Short loaders never complete a window; fall back to the mean over
    # whatever was seen so the caller does not get a bogus 0.0.
    if not window_reported and batches_seen > 0:
        last_loss = running_loss / batches_seen
        print('  batch {} loss: {}'.format(batches_seen, last_loss))

    return last_loss

In [11]:
# Run-level state. NOTE: these are re-initialized every time this cell runs;
# move them into their own cell if you want to re-run the loop below to add
# more epochs to the same run.
epoch_number = 0

EPOCHS = 5

# Lowest average validation loss seen so far (starts at a large sentinel).
best_vloss = 1_000_000.

for epoch in range(EPOCHS):
    print('EPOCH {}:'.format(epoch_number + 1))

    # Put the model in training mode and do one pass over the training data.
    model.train(True)
    avg_loss = train_one_epoch(epoch_number)

    running_vloss = 0.0
    num_vbatches = 0
    # Set the model to evaluation mode, disabling dropout and using population
    # statistics for batch normalization.
    model.eval()

    # Disable gradient computation and reduce memory consumption.
    with torch.no_grad():
        for vinputs, vlabels in validation_loader:
            voutputs = model(vinputs)
            # .item() keeps the accumulator a plain float, so the report below
            # prints a number rather than a tensor repr.
            running_vloss += loss_fn(voutputs, vlabels).item()
            num_vbatches += 1

    # Count batches explicitly instead of reusing a leaked loop index, which
    # would be undefined for an empty validation loader.
    avg_vloss = running_vloss / num_vbatches if num_vbatches else float('nan')
    print('LOSS train {} valid {}'.format(avg_loss, avg_vloss))

    # Track the best validation loss (a NaN average leaves it unchanged).
    if avg_vloss < best_vloss:
        best_vloss = avg_vloss

    epoch_number += 1

EPOCH 1:
  batch 1000 loss: 1.3700492917858065
  batch 2000 loss: 0.753765296482481
  batch 3000 loss: 0.6575536987413653
  batch 4000 loss: 0.617356703243684
  batch 5000 loss: 0.5597889542661142
  batch 6000 loss: 0.5574784479846712
  batch 7000 loss: 0.5291602074299008
  batch 8000 loss: 0.5166076289904303
  batch 9000 loss: 0.46495643792499325
  batch 10000 loss: 0.4735561103897635
  batch 11000 loss: 0.4517362857464468
  batch 12000 loss: 0.4417606803929666
  batch 13000 loss: 0.4557264124929206
  batch 14000 loss: 0.4376754614979727
  batch 15000 loss: 0.4274326150195848


KeyboardInterrupt: 

In [10]:
from pytorch_optim_training_manager import train_manager

In [11]:
# Hand the model, loss function, optimizer and both data loaders to the
# project-local train_manager helper (its implementation is not shown here).
manager = train_manager(model, loss_fn, optimizer, training_loader, validation_loader)

In [12]:
# Train via the manager; judging by the recorded output, the first argument is
# the number of epochs and verbose=True prints one training-loss line per epoch.
# NOTE(review): confirm against train_manager.train.
losses = manager.train(5, verbose=True)

epoch0: train_loss: 0.5853
epoch1: train_loss: 0.3673
epoch2: train_loss: 0.3249
epoch3: train_loss: 0.3027
epoch4: train_loss: 0.2856


In [13]:
# Display the value returned by manager.train (one loss per epoch in the
# recorded output).
losses

[0.5853274175085768,
 0.36727704638633,
 0.32490648102932745,
 0.3027269221981532,
 0.28560926965978056]

In [16]:
# Evaluate the model through the manager.
# NOTE(review): presumably this runs over validation_loader — confirm in
# train_manager.eval_model.
validation = manager.eval_model()

In [17]:
# Display the evaluation result (a dict with 'loss' and 'accuracy' keys in the
# recorded output).
validation

{'loss': 0.3151531573221726, 'accuracy': tensor(0.8860)}