In [1]:
import numpy as np
from tqdm import tqdm
import time

import torch
import torch.nn as nn
from torchvision.transforms import transforms

from src.dataset import get_train_dataloader, get_test_dataloader

In [7]:
# Map ToTensor's [0, 1] pixel range to [-1, 1] via (x - 0.5) / 0.5.
# A single-element mean/std broadcasts over every channel, so this works for
# both 1-channel (grayscale) and 3-channel (RGB) images. The previous
# 3-tuples raise a RuntimeError on 1-channel tensors, and FCNet's first layer
# (28 * 28 inputs) suggests single-channel 28x28 images — TODO confirm against
# src.dataset.
NORMALIZE_MEAN = (0.5,)
NORMALIZE_STD = (0.5,)

train_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD),
])

test_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(NORMALIZE_MEAN, NORMALIZE_STD),
])

In [9]:
# Training batches of 128 samples, each preprocessed with train_transforms.
train_loader = get_train_dataloader(
    128,
    transforms=train_transforms,
)

In [8]:
# Evaluation batches of 128 samples, each preprocessed with test_transforms.
test_loader = get_test_dataloader(
    128,
    transforms=test_transforms,
)

In [10]:
class FCNet(nn.Module):
    """Two-layer fully connected classifier.

    Flattens each input sample, applies ``Linear -> ReLU -> Linear`` and
    returns raw, unnormalized class scores (logits).

    Logits (not probabilities) are returned because ``nn.CrossEntropyLoss``
    applies log-softmax internally; feeding it softmax outputs would apply
    softmax twice, distorting the loss and weakening its gradients. Use
    ``self.softmax`` (or ``torch.softmax(logits, dim=1)``) when probabilities
    are needed for inference or reporting.

    Args:
        in_features: Number of features per sample after flattening.
            Defaults to 28 * 28 (e.g. a flattened 28x28 single-channel image).
        hidden_features: Width of the hidden layer. Defaults to 32.
        num_classes: Number of output classes. Defaults to 10.
    """

    def __init__(self, in_features: int = 28 * 28, hidden_features: int = 32,
                 num_classes: int = 10):
        super(FCNet, self).__init__()

        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(in_features=in_features, out_features=hidden_features)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(in_features=hidden_features, out_features=num_classes)
        # Not used in forward(); kept so callers can turn logits into
        # probabilities (``model.softmax(model(x))``) and so existing
        # references to ``model.softmax`` keep working.
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return logits of shape ``(batch, num_classes)`` for input ``x``.

        ``x`` is flattened from dim 1 onward, so its trailing dimensions must
        multiply to ``in_features`` (e.g. ``(batch, 1, 28, 28)`` with the
        default sizes).
        """
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        # No activation on the output layer: a ReLU here would clamp every
        # negative logit to 0, and a softmax would be applied a second time
        # once the result reaches nn.CrossEntropyLoss.
        return self.fc2(x)


In [11]:
# Build a fresh, untrained FCNet instance to be trained below.
model: nn.Module = FCNet()

In [12]:
# Adam over every parameter of the model with a fixed learning rate of 5e-5.
# NOTE(review): the PyTorch optim docs recommend moving a model to the GPU
# before constructing its optimizer; here .cuda() is called in a later cell.
optimizer = torch.optim.Adam(
    params=model.parameters(),
    lr=5e-5,
)

In [13]:
# Cross-entropy averaged over the batch. nn.CrossEntropyLoss applies
# log-softmax internally, so it expects raw (unnormalized) logits as input.
compute_loss = nn.CrossEntropyLoss(reduction="mean")

In [14]:
# Place the model on the GPU when PyTorch can see a CUDA device;
# otherwise the model is left where it is (on the CPU).
use_cuda: bool = torch.cuda.is_available()
if use_cuda:
    # .cuda() moves parameters and buffers to the current CUDA device.
    model = model.cuda()

In [None]:
# Training-loop configuration: 20 epochs; the two start_* counters begin at 0
# (presumably for resuming training — their use is not visible here).
epochs, start_n_iter, start_epoch = 20, 0, 0

for epoch in range(epochs):
    model.train()
    