In [2]:
!pip install torch torchvision ipywidgets widgetsnbextension

Collecting torch
  Using cached torch-1.10.2-cp39-cp39-manylinux1_x86_64.whl (881.9 MB)
Collecting torchvision
  Using cached torchvision-0.11.3-cp39-cp39-manylinux1_x86_64.whl (23.2 MB)
Collecting ipywidgets
  Using cached ipywidgets-7.6.5-py2.py3-none-any.whl (121 kB)
Collecting widgetsnbextension
  Using cached widgetsnbextension-3.5.2-py2.py3-none-any.whl (1.6 MB)
Collecting jupyterlab-widgets>=1.0.0
  Using cached jupyterlab_widgets-1.0.2-py3-none-any.whl (243 kB)
Installing collected packages: widgetsnbextension, torch, jupyterlab-widgets, torchvision, ipywidgets
Successfully installed ipywidgets-7.6.5 jupyterlab-widgets-1.0.2 torch-1.10.2 torchvision-0.11.3 widgetsnbextension-3.5.2


In [3]:
import torch
from torch import nn
from torch.optim import SGD
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision.transforms import Compose, ToTensor, Normalize, RandomHorizontalFlip
from torchvision.models import resnet18

def get_mnist_data_loaders(path, train_transform, train_batch_size, val_transform, val_batch_size):

    train_loader = DataLoader(
        MNIST(download=True, root=path, transform=train_transform, train=True),
        batch_size=train_batch_size,
        shuffle=True,
    )

    val_loader = DataLoader(
        MNIST(download=False, root=path, transform=val_transform, train=False), batch_size=val_batch_size, shuffle=False
    )
    return train_loader, val_loader

In [4]:
def train_mnist_classifier():
    seed = 12
    debug = False
    train_batch_size = 128
    val_batch_size = 512
    
    train_transform = Compose([RandomHorizontalFlip(), ToTensor(), Normalize((0.1307,), (0.3081,))])
    val_transform = Compose([ToTensor(), Normalize((0.1307,), (0.3081,))])
    
    path = '/tmp/mnist'
    # Load the MNIST dataset
    train_loader, test_loader = get_mnist_data_loaders(
        path, train_transform, train_batch_size, val_transform, val_batch_size
    )

    model = resnet18(num_classes=10)
    model.conv1 = nn.Conv2d(1, 64, 3)

    learning_rate = 0.01

    optimizer = SGD(model.parameters(), lr=learning_rate)
    criterion = nn.CrossEntropyLoss()

    # Train the network
    for epoch in range(1):
        for batch_idx, (data, target) in enumerate(train_loader):
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()
            if batch_idx % 100 == 0:
                print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.item()))

    # Test the network
    with torch.no_grad():
        correct = 0
        total = 0
        for data in test_loader:
            images, labels = data
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        print('Accuracy of the network on the 10000 test images: {} %'.format(100 * correct / total))

    # Save the model
    torch.save(model, 'mnist_classifier.pth')

if __name__ == '__main__':
    train_mnist_classifier()

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting /tmp/mnist/MNIST/raw/train-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting /tmp/mnist/MNIST/raw/train-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting /tmp/mnist/MNIST/raw/t10k-images-idx3-ubyte.gz to /tmp/mnist/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting /tmp/mnist/MNIST/raw/t10k-labels-idx1-ubyte.gz to /tmp/mnist/MNIST/raw

Accuracy of the network on the 10000 test images: 94.6 %


In [6]:
model = resnet18(num_classes=10)
model.conv1 = nn.Conv2d(1, 64, 3)
model.load_state_dict(torch.load('/home/jovyan/python-elemeno-ai-sdk/tests/conversion/torch/mnist_classifier.pth'))

<All keys matched successfully>

In [None]:
torch.save(model, 'home/jovyan/python-elemeno-ai-sdk/tests/conversion/torch/mnist_classifier.pt')