Imports and Device Selection

In [2]:
from ts_model import Net
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import torchvision

# Select the compute device once: the first CUDA GPU when PyTorch can see one,
# otherwise the CPU. The choice is announced on stdout.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

if device.type == 'cuda':
    print('Running on GPU')
    print(torch.cuda.get_device_name(0))
else:
    print('Running on CPU')

Running on GPU
GeForce GTX 1650 Ti with Max-Q Design


Create Model

In [None]:
# Instantiate the network class imported from the project-local ts_model module.
model = Net()
# Place the model on the device chosen in the device-selection cell. Because this
# is the cell's last expression, its return value is also echoed as cell output.
# NOTE(review): assumes Net is a torch.nn.Module, whose .to() moves it in place — confirm.
model.to(device)

Data Processing

In [None]:
# TODO: replace these placeholders with real dataset objects before running this cell.
# NOTE(review): while they are still None, building the shuffled training loader
# below is expected to fail — confirm against the installed PyTorch version.
train_dataset = None
test_dataset = None

batch_size = 16
# Training batches are drawn in a fresh random order on every pass.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
# Evaluation gains nothing from shuffling; a fixed order keeps test batches reproducible.
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

Optimizer & Loss Function

In [None]:
# Loss function: cross-entropy, built with its default settings.
criterion = torch.nn.CrossEntropyLoss()
# Optimizer: Adam over every parameter of the model, using the library-default hyperparameters.
optimizer = torch.optim.Adam(params=model.parameters())

Train Model

In [None]:
def train_one_epoch(model, train_loader, optimizer, criterion, device):
    """Run one optimisation pass over ``train_loader`` and print the final batch loss.

    Args:
        model: Callable model exposing ``train()``; invoked as ``model(batch_inputs)``.
        train_loader: Iterable yielding ``(inputs, labels)`` pairs whose elements
            support ``.to(device)``.
        optimizer: Object exposing ``zero_grad()`` and ``step()``.
        criterion: Callable ``criterion(outputs, labels)`` returning a loss that
            supports ``backward()`` and ``item()``.
        device: Device each batch is moved to before the forward pass.

    Returns:
        None. Prints the loss of the last processed batch rounded to 3 decimals,
        or a notice when the loader produced no batches.
    """
    model.train()

    # Stays None when the loader yields nothing, so the report below cannot
    # hit an unbound name.
    loss = None
    for batch_inputs, batch_labels in train_loader:
        batch_inputs, batch_labels = batch_inputs.to(device), batch_labels.to(device)

        optimizer.zero_grad()
        batch_outputs = model(batch_inputs)
        loss = criterion(batch_outputs, batch_labels)
        loss.backward()
        optimizer.step()

    if loss is None:
        print("End of epoch loss: n/a (no training batches)")
        return

    # Round the value itself; passing 3 as an extra print argument would just
    # print a literal "3" after the unrounded loss.
    print("End of epoch loss:", round(loss.item(), 3))

Test Model

In [None]:
def test(model, test_loader, device):
    model.eval()

    correct = 0
    for batch_inputs, batch_labels in test_loader:
        batch_inputs, batch_labels = batch_inputs.to(device), batch_labels.to(device)

        predictions = model(batch_inputs).argmax(axis=1)
        correct += (predictions == batch_labels).sum().item()

    print('End of epoch accuracy:', 100*correct/len(test_dataset), '%')

Train-Test Loop

In [None]:
# Number of full train-then-evaluate rounds to run.
NUM_EPOCHS = 10

# Each round performs one training pass followed by one evaluation pass,
# reusing the model, loaders, optimizer, criterion and device built in earlier cells.
for epoch in range(NUM_EPOCHS):
    print("Epoch: ", epoch + 1)  # epoch is 0-based; report it 1-based
    train_one_epoch(model, train_loader, optimizer, criterion, device)
    test(model, test_loader, device)

Save Model

In [None]:
# Compile the model to its TorchScript form, then write it to disk.
m = torch.jit.script(model)
torch.jit.save(m, "model.pt")