In [16]:
import numpy as np
from tqdm import tqdm
import time

import torch
import torch.nn as nn
from torchvision.transforms import transforms

from src.dataset import get_train_dataloader, get_test_dataloader

In [7]:
train_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

test_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

In [9]:
train_loader = get_train_dataloader(128, transforms=train_transforms)

In [8]:
test_loader = get_test_dataloader(128, transforms=test_transforms)

In [10]:
class FCNet(nn.Module):

    def __init__(self):
        super(FCNet, self).__init__()

        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(in_features=28 * 28, out_features=32)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(in_features=32, out_features=10)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x: torch.Tensor):
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.relu(x)
        x = self.softmax(x)
        return x


In [11]:
model = FCNet()

In [12]:
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)

In [13]:
compute_loss = torch.nn.CrossEntropyLoss(reduction='mean')

In [14]:
# Check to use cuda
use_cuda: bool = torch.cuda.is_available()
if use_cuda:
    model = model.cuda()

In [17]:
epochs = 20
start_n_iter = 0
start_epoch = 0

# main loop
n_iter = start_n_iter
for epoch in range(epochs):
    model.train()
    
    # use prefetch_generator and tqdm for iterating through data
    pbar = tqdm(enumerate(train_loader),
                total=len(train_loader))
    start_time = time.time()

    # for loop going through dataset
    for i, data in pbar:
        # data preparation
        image, label = data
        if use_cuda:
            image = image.cuda()
            label = label.cuda()

        # keep track of preparation time
        prepare_time = start_time - time.time()

        # forward and backward pass
        out = model(image)
        loss = compute_loss(out, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # compute computation time and *compute_efficiency*
        process_time = start_time-time.time()-prepare_time
        compute_efficiency = process_time/(process_time+prepare_time)
        pbar.set_description(
            f'Compute efficiency: {compute_efficiency:.2f}, ' 
            f'loss: {loss.item():.2f},  epoch: {epoch}/{epochs}')
        start_time = time.time()

    # test data every N=1 epochs
    if epoch % 1 == 0:
        model.eval()

        correct = 0
        total = 0
        

        pbar = tqdm(enumerate(test_loader),
                total=len(test_loader)) 
        with torch.no_grad():
            for i, data in pbar:
                # data preparation
                image, label = data
                if use_cuda:
                    image = image.cuda()
                    label = label.cuda()
                
                out = model(image)
                _, predicted = torch.max(out.data, 1)
                total += label.size(0)
                correct += (predicted == label).sum().item()

        print(f'Accuracy on test set: {100*correct/total:.2f}')


  0%|          | 0/469 [00:00<?, ?it/s]


RuntimeError: output with shape [1, 28, 28] doesn't match the broadcast shape [3, 28, 28]