In [2]:
import numpy as np
from tqdm.notebook import tqdm
import time

import torch
import torch.nn as nn
from torchvision.transforms import transforms

from src.dataset import get_train_dataloader, get_test_dataloader
from src.models import FCNet

In [3]:
# Preprocessing pipelines. Both currently only apply ToTensor, which converts
# PIL images / HxWxC ndarrays into CxHxW tensors (scaling uint8 input to [0, 1]).
# They are kept as separate objects so train-time augmentation can be added
# without affecting evaluation.
train_transforms = transforms.Compose([transforms.ToTensor()])

test_transforms = transforms.Compose([transforms.ToTensor()])

In [4]:
# NOTE(review): the positional 128 is presumably the batch size; confirm in src.dataset.
# NOTE(review): this cell emits a warning raised at `torch.from_numpy(labels)`
# inside src.dataset; worth checking what it reports.
train_loader = get_train_dataloader(
    128,
    transforms=train_transforms,
)

  self.labels = torch.from_numpy(labels).to(torch.long)


In [5]:
# Held-out loader; uses the evaluation transforms.
# NOTE(review): the positional 128 is presumably the batch size; confirm in src.dataset.
test_loader = get_test_dataloader(
    128,
    transforms=test_transforms,
)

In [6]:
# Network defined in src.models, built with its default constructor arguments.
# It is used below via .parameters()/.train()/.eval()/.cuda(), i.e. as a torch Module.
model: nn.Module = FCNet()

In [7]:
# Adam over every model parameter. Note that lr=5e-5 is well below Adam's
# default of 1e-3, so convergence may be slow.
optimizer = torch.optim.Adam(params=model.parameters(), lr=5e-5)

In [8]:
# Batch-averaged cross-entropy. CrossEntropyLoss applies log-softmax itself and
# expects raw (unnormalised) logits.
# NOTE(review): training output points at `x = self.softmax(x)` inside the model,
# so FCNet appears to softmax its outputs before this loss (double softmax).
# Confirm in src.models and consider returning logits instead.
compute_loss = nn.CrossEntropyLoss(reduction='mean')

In [9]:
# Check to use cuda
use_cuda: bool = torch.cuda.is_available()
if use_cuda:
    model = model.cuda()

In [10]:
epochs = 30
eval_every = 1  # evaluate on the test set every N epochs; the final epoch is always evaluated

start_n_iter = 0
start_epoch = 0  # set > 0 to continue epoch numbering, e.g. after restoring a checkpoint


def _train_one_epoch(model, loader, optimizer, loss_fn, epoch, epochs, use_cuda, n_iter):
    """Run one optimisation pass over ``loader``, updating ``model`` in place.

    Args:
        model: network being trained; switched to train mode here.
        loader: sized iterable of ``(image, label)`` batches.
        optimizer: optimiser holding ``model``'s parameters.
        loss_fn: criterion called as ``loss_fn(model(image), label)``.
        epoch: zero-based index of the current epoch (shown 1-based in the bar).
        epochs: total number of epochs, used only for the progress-bar text.
        use_cuda: move each batch to the GPU when True.
        n_iter: global iteration counter before this epoch.

    Returns:
        The global iteration counter after this epoch.
    """
    model.train()
    pbar = tqdm(loader, total=len(loader))
    start_time = time.perf_counter()  # monotonic and high resolution, unlike time.time()

    for image, label in pbar:
        if use_cuda:
            image = image.cuda()
            label = label.cuda()

        # Time spent waiting on the data loader (and the host->device copy).
        prepare_time = time.perf_counter() - start_time

        out = model(image)
        loss = loss_fn(out, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        n_iter += 1

        # Share of the iteration spent in forward/backward/step rather than data
        # preparation. With CUDA, kernels launch asynchronously, so this can
        # under-count GPU work.
        elapsed = time.perf_counter() - start_time
        process_time = elapsed - prepare_time
        compute_efficiency = process_time / elapsed if elapsed > 0 else 0.0
        pbar.set_description(
            f'Compute efficiency: {compute_efficiency:.2f}, '
            f'loss: {loss.item():.2f},  epoch: {epoch + 1}/{epochs}')
        start_time = time.perf_counter()

    return n_iter


def _evaluate(model, loader, use_cuda):
    """Return the top-1 accuracy of ``model`` on ``loader`` as a percentage.

    Switches ``model`` to eval mode and disables autograd. Returns 0.0 when
    ``loader`` yields no samples, to avoid dividing by zero.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for image, label in tqdm(loader, total=len(loader)):
            if use_cuda:
                image = image.cuda()
                label = label.cuda()
            predicted = model(image).argmax(dim=1)
            total += label.size(0)
            correct += (predicted == label).sum().item()
    return 100 * correct / total if total else 0.0


# main loop
n_iter = start_n_iter
for epoch in range(start_epoch, epochs):
    n_iter = _train_one_epoch(model, train_loader, optimizer, compute_loss,
                              epoch, epochs, use_cuda, n_iter)

    if (epoch + 1) % eval_every == 0 or epoch == epochs - 1:
        accuracy = _evaluate(model, test_loader, use_cuda)
        print(f'Accuracy on test set: {accuracy:.2f}')


  0%|          | 0/469 [00:00<?, ?it/s]

  x = self.softmax(x)


  0%|          | 0/79 [00:00<?, ?it/s]

Accuracy on test set: 66.33


  0%|          | 0/469 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Accuracy on test set: 72.36


  0%|          | 0/469 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Accuracy on test set: 74.47


  0%|          | 0/469 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Accuracy on test set: 75.31


  0%|          | 0/469 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Accuracy on test set: 82.61


  0%|          | 0/469 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Accuracy on test set: 85.04


  0%|          | 0/469 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Accuracy on test set: 85.30


  0%|          | 0/469 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Accuracy on test set: 85.85


  0%|          | 0/469 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Accuracy on test set: 86.00


  0%|          | 0/469 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Accuracy on test set: 86.24


  0%|          | 0/469 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Accuracy on test set: 86.34


  0%|          | 0/469 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Accuracy on test set: 86.50


  0%|          | 0/469 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Accuracy on test set: 86.57


  0%|          | 0/469 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Accuracy on test set: 86.65


  0%|          | 0/469 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Accuracy on test set: 86.80


  0%|          | 0/469 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Accuracy on test set: 86.87


  0%|          | 0/469 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Accuracy on test set: 86.92


  0%|          | 0/469 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Accuracy on test set: 87.05


  0%|          | 0/469 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Accuracy on test set: 87.02


  0%|          | 0/469 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Accuracy on test set: 87.05


  0%|          | 0/469 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Accuracy on test set: 87.10


  0%|          | 0/469 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Accuracy on test set: 87.06


  0%|          | 0/469 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Accuracy on test set: 87.16


  0%|          | 0/469 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Accuracy on test set: 87.26


  0%|          | 0/469 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Accuracy on test set: 87.17


  0%|          | 0/469 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Accuracy on test set: 87.31


  0%|          | 0/469 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Accuracy on test set: 87.22


  0%|          | 0/469 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Accuracy on test set: 87.34


  0%|          | 0/469 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Accuracy on test set: 87.34


  0%|          | 0/469 [00:00<?, ?it/s]

  0%|          | 0/79 [00:00<?, ?it/s]

Accuracy on test set: 87.28
