### 📝 Imports

In [15]:
import torch
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms

import os

from PIL import Image

from neural_net import NeuralNet


### 🔧 Config

In [16]:
# Checkpoint file: loaded at start-up when it exists and overwritten by the
# final "Save Progress" cell.
saved_model_path = "trained_networks/cifar_10.pth"
# Number of full passes over the training loader performed by the Train cell.
number_of_epochs = 30

### 🌐 Create Transforms

In [17]:
# Preprocessing applied to every image: convert to a tensor, then normalise
# each of the three channels with mean 0.5 and std 0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

### 🚦 Load Training Data

In [18]:
# CIFAR-10 training split, stored under ./data (downloaded if not present),
# with the preprocessing pipeline applied to each image.
train_data = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=transform,
)

# Shuffled mini-batches of 32 images, prepared by two worker processes.
train_loader = torch.utils.data.DataLoader(
    dataset=train_data,
    batch_size=32,
    shuffle=True,
    num_workers=2,
)

Files already downloaded and verified


### 🥾 Initialize the Neural Net

In [19]:
# Build the model (NeuralNet is defined in the project's neural_net module).
net = NeuralNet()

# Resume from the previous run's weights when a checkpoint is present.
if os.path.isfile(saved_model_path):
    net.load_state_dict(
        torch.load(
            saved_model_path,
            # Load onto the CPU regardless of the device the tensors were
            # saved from; the model is moved to the GPU right below.
            map_location="cpu",
            # The file only holds a state_dict, so restrict unpickling to
            # tensors and plain containers instead of arbitrary objects.
            weights_only=True,
        )
    )

# Move the parameters to the GPU *before* creating the optimizer so the
# optimizer holds references to the on-device parameters.
net.cuda()
loss_function = torch.nn.CrossEntropyLoss()
# Plain SGD (no momentum) with a fixed learning rate.
optimizer = optim.SGD(net.parameters(), lr=0.001)

  net.load_state_dict(torch.load(saved_model_path))


### 🏃‍♂️‍➡️ Train

In [20]:
# Feed batches to whichever device the model's parameters live on, instead of
# repeating a hard-coded device string here. The optimizer was built from
# net.parameters(), so at least one parameter exists.
device = next(net.parameters()).device

# Explicitly select training mode so the loop behaves the same even if the
# model was switched to eval mode earlier in the session.
net.train()

for epoch in range(number_of_epochs):
    print(f'Training epoch {epoch}')

    # Sum of per-batch losses for this epoch; averaged when reported below.
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        # Gradients accumulate across backward() calls, so clear them first.
        optimizer.zero_grad()

        outputs = net(inputs)
        loss = loss_function(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Mean loss per batch over the epoch.
    print(f'Loss: {running_loss / len(train_loader):.4f}')

Training epoch 0
Loss: 1.3451
Training epoch 1
Loss: 1.3345
Training epoch 2
Loss: 1.3233
Training epoch 3
Loss: 1.3135
Training epoch 4
Loss: 1.3024
Training epoch 5
Loss: 1.2919
Training epoch 6
Loss: 1.2824
Training epoch 7
Loss: 1.2722
Training epoch 8
Loss: 1.2616
Training epoch 9
Loss: 1.2502
Training epoch 10
Loss: 1.2415
Training epoch 11
Loss: 1.2311
Training epoch 12
Loss: 1.2208
Training epoch 13
Loss: 1.2114
Training epoch 14
Loss: 1.2018
Training epoch 15
Loss: 1.1915
Training epoch 16
Loss: 1.1820
Training epoch 17
Loss: 1.1725
Training epoch 18
Loss: 1.1639
Training epoch 19
Loss: 1.1557
Training epoch 20
Loss: 1.1451
Training epoch 21
Loss: 1.1376
Training epoch 22
Loss: 1.1290
Training epoch 23
Loss: 1.1199
Training epoch 24
Loss: 1.1114
Training epoch 25
Loss: 1.1026
Training epoch 26
Loss: 1.0955
Training epoch 27
Loss: 1.0858
Training epoch 28
Loss: 1.0788
Training epoch 29
Loss: 1.0700


### 💾 Save Progress

In [21]:
# torch.save does not create missing parent directories, so make sure the
# checkpoint folder exists (e.g. on a fresh checkout) before writing.
checkpoint_dir = os.path.dirname(saved_model_path)
if checkpoint_dir:  # empty when the path has no directory component
    os.makedirs(checkpoint_dir, exist_ok=True)

# Persist only the parameters (state_dict), which is what the start-up cell
# reloads with load_state_dict.
torch.save(net.state_dict(), saved_model_path)