In [1]:
import os
import re
import copy
import itertools
import warnings

import numpy as np

import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch2trt import torch2trt
from torch2trt import TRTModule

# Suppress every Python warning for the rest of the session so library
# deprecation chatter does not clutter the cell outputs.
warnings.filterwarnings(action="ignore")

# NOTE(review): not referenced anywhere in the visible notebook; the
# configuration cell defines a separate `random_seed`.
random_state = 42

In [2]:
# Experiment configuration: values a reader may want to tune.
n_epochs = 3
batch_size_train = 64
batch_size_test = 1000
learning_rate = 0.01
momentum = 0.5
log_interval = 10

random_seed = 1
# Apply the seed: without this call `random_seed` was defined but never used,
# so weight initialisation and DataLoader shuffling changed on every run.
# (GPU kernels may still introduce small run-to-run differences.)
torch.manual_seed(random_seed)

In [3]:
# Preprocessing shared by both splits: convert the image to a tensor, then
# normalise the single channel.
# NOTE(review): 0.1307 / 0.3081 are presumably the MNIST mean / std - confirm.
mnist_transform = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize((0.1307,), (0.3081,)),
])

# Datasets are downloaded into '/files/' on first use.
train_dataset = torchvision.datasets.MNIST(
    '/files/', train=True, download=True, transform=mnist_transform)
test_dataset = torchvision.datasets.MNIST(
    '/files/', train=False, download=True, transform=mnist_transform)

# Shuffled mini-batch loaders; batch sizes come from the configuration cell.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=batch_size_train, shuffle=True)

test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=batch_size_test, shuffle=True)

In [4]:
class Net(nn.Module):
    """Small convolutional classifier: two conv/pool stages, then two linear layers.

    The flatten size ``4*4*50`` means the network expects single-channel
    28x28 inputs (28 -> conv5 -> 24 -> pool -> 12 -> conv5 -> 8 -> pool -> 4).
    ``forward`` returns raw scores for 10 classes; no softmax is applied.
    """

    def __init__(self):
        super().__init__()
        # Layers are created in the same order as before so that a seeded
        # initialisation yields the same parameters.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=20, kernel_size=5, stride=1)
        self.conv2 = nn.Conv2d(in_channels=20, out_channels=50, kernel_size=5, stride=1)
        self.fc1 = nn.Linear(in_features=4 * 4 * 50, out_features=500)
        self.fc2 = nn.Linear(in_features=500, out_features=10)

    def forward(self, x):
        """Map a batch of images to a ``(batch, 10)`` tensor of class scores."""
        # Feature extractor: conv -> ReLU -> 2x2 max-pool, twice.
        features = F.max_pool2d(F.relu(self.conv1(x)), 2, 2)
        features = F.max_pool2d(F.relu(self.conv2(features)), 2, 2)
        # Flatten each sample's 50x4x4 feature map into one vector.
        flat = features.view(-1, 4 * 4 * 50)
        # Classifier head.
        hidden = F.relu(self.fc1(flat))
        return self.fc2(hidden)
    
# Smoke check: push one training batch through an untrained network on the
# CPU and display the scores of the first sample.
train_batches = iter(train_loader)
inputs, labels = next(train_batches)
model = Net()
model(inputs)[0:1]

tensor([[ 0.1809,  0.1472,  0.2091,  0.0225, -0.0389,  0.0180,  0.0023,  0.0622,
         -0.1889,  0.0579]], grad_fn=<SliceBackward>)

In [5]:
# NOTE(review): `patience` is never read in the visible notebook - presumably
# left over from an early-stopping setup.
patience = 10
n_epochs = 1  # overrides the n_epochs = 3 set in the configuration cell

model = Net()
model.cuda()
criterion = nn.CrossEntropyLoss()
# NOTE(review): lr / momentum are literals here; the configuration cell's
# `learning_rate` and `momentum` (0.5) are not used - confirm which is intended.
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9)

# Per-batch losses of the epoch in progress (reset at the end of each epoch).
train_losses = []
valid_losses = []

# Per-epoch history.
avg_train_losses = []
avg_valid_losses = []
train_accuracies = []
valid_accuracies = []

for epoch in range(n_epochs):  # loop over the dataset multiple times
    # ---- training pass ----
    train_correct = 0
    train_total = 0
    model.train()
    for inputs, labels in train_loader:
        inputs = inputs.to("cuda")
        labels = labels.to("cuda")

        optimizer.zero_grad()

        outputs = model(inputs.float())

        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        train_losses.append(loss.item())

        # Accuracy bookkeeping works on a detached view so it never touches
        # the autograd graph.
        _, predicted = outputs.detach().max(dim=1)
        train_total += labels.size(0)
        train_correct += (predicted == labels).sum().item()

    # ---- validation pass ----
    valid_correct = 0
    valid_total = 0
    model.eval()
    # No gradients are needed for evaluation; disabling autograd avoids
    # building a graph for every 1000-sample batch.
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to("cuda")
            labels = labels.to("cuda")

            outputs = model(inputs.float())
            loss = criterion(outputs, labels)
            valid_losses.append(loss.item())

            _, predicted = outputs.max(dim=1)
            valid_total += labels.size(0)
            valid_correct += (predicted == labels).sum().item()

    # calculate average loss over an epoch (mean of the per-batch losses)
    train_loss = np.average(train_losses)
    valid_loss = np.average(valid_losses)
    train_accuracy = train_correct / train_total
    valid_accuracy = valid_correct / valid_total

    avg_train_losses.append(train_loss)
    avg_valid_losses.append(valid_loss)
    train_accuracies.append(train_accuracy)
    valid_accuracies.append(valid_accuracy)

    # Epoch index is printed zero-based, e.g. "[0/1]".
    log_message = f'[{epoch}/{n_epochs}] train_loss: {train_loss:.5f} valid_loss: {valid_loss:.5f} train_accuracy: {train_accuracy:.5f} valid_accuracy: {valid_accuracy:.5f}'
    print(log_message)

    # clear lists to track next epoch
    train_losses = []
    valid_losses = []

print('Finished Training')

[0/1] train_loss: 0.18067 valid_loss: 0.04064 train_accuracy: 0.94338 valid_accuracy: 0.98580
Finished Training


In [6]:
# Re-score the trained PyTorch model on the test loader; the cell's value is
# the fraction of correctly classified samples.
valid_correct = 0
valid_total = 0
model.eval()
# Evaluation needs no gradients; without no_grad() every forward pass would
# build an autograd graph that is never used.
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs = inputs.to("cuda")
        labels = labels.to("cuda")

        outputs = model(inputs.float())
        loss = criterion(outputs, labels)
        # NOTE(review): valid_losses is appended to here but not read again in
        # the visible notebook.
        valid_losses.append(loss.item())

        _, predicted = outputs.max(dim=1)
        valid_total += labels.size(0)
        valid_correct += (predicted == labels).sum().item()

valid_correct/valid_total

0.9858

In [7]:
# Work on an independent copy so the trained `model` itself is left untouched,
# and confirm the copy scores the same accuracy on the test loader.
model_trt = copy.deepcopy(model)
valid_correct = 0
valid_total = 0
model_trt.eval()
# Evaluation needs no gradients; without no_grad() every forward pass would
# build an autograd graph that is never used.
with torch.no_grad():
    for inputs, labels in test_loader:
        inputs = inputs.to("cuda")
        labels = labels.to("cuda")

        outputs = model_trt(inputs.float())
        loss = criterion(outputs, labels)
        # NOTE(review): valid_losses is appended to here but not read again in
        # the visible notebook.
        valid_losses.append(loss.item())

        _, predicted = outputs.max(dim=1)
        valid_total += labels.size(0)
        valid_correct += (predicted == labels).sum().item()

valid_correct/valid_total

0.9858

In [8]:
# Single-sample test loader, used both to draw the example input for the
# conversion and to evaluate the converted model afterwards.
# NOTE(review): batch_size=1 presumably matches the batch size the TensorRT
# engine is built for (the example input below has one sample) - confirm.
test_loader_trt = torch.utils.data.DataLoader(
  torchvision.datasets.MNIST('/files/', train=False, download=True,
                             transform=torchvision.transforms.Compose([
                               torchvision.transforms.ToTensor(),
                               torchvision.transforms.Normalize(
                                 (0.1307,), (0.3081,))
                             ])),
  batch_size=1, shuffle=True)

# Example input handed to torch2trt: one test image, on the GPU, in fp16.
inputs, labels = next(iter(test_loader_trt))
inputs = inputs.cuda().half()

# Convert from a fresh fp16 copy of the trained `model` rather than from
# `model_trt` itself. This keeps the cell idempotent: re-running it no longer
# feeds an already-converted module back into torch2trt, and the cell does not
# depend on an earlier cell having created `model_trt`.
model_fp16 = copy.deepcopy(model).eval().cuda().half()
model_trt = torch2trt(model_fp16, [inputs], fp16_mode=True)

In [9]:
# Score the TensorRT-converted model one test sample at a time; the cell's
# value is the fraction of correctly classified samples.
valid_total = 0
valid_correct = 0
model_trt.eval()
for inputs, labels in test_loader_trt:
    inputs = inputs.to("cuda")
    labels = labels.to("cuda")

    # Inputs are cast to half precision before being passed to the converted
    # model (they are already on the GPU at this point).
    outputs = model_trt(inputs.half())
    loss = criterion(outputs, labels)
    valid_losses.append(loss.item())

    # Index of the highest score per sample is the predicted class.
    predicted = outputs.data.max(dim=1)[1]
    valid_total += labels.shape[0]
    valid_correct += predicted.eq(labels).sum().item()

valid_correct/valid_total

0.9858