In [1]:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim

from utils.nets import *
from utils.model_tools import *
from utils.feature_extractor import *
from utils.dataset_tools import *
from utils.cosine_similarity import *

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5), (0.5))])

LEARNING_RATE = 0.001
EXP_DECAY = 0.0001

batch_size = 64

FMNIST_train_gen = torchvision.datasets.FashionMNIST(root='./data', train=True,
                                        download=True, transform=transform)
FMNIST_trainloader_gen = torch.utils.data.DataLoader(FMNIST_train_gen, batch_size=batch_size,
                                          shuffle=True, num_workers=2)

FMNIST_test_gen = torchvision.datasets.FashionMNIST(root='./data', train=False,
                                       download=True, transform=transform)
FMNIST_testloader_gen = torch.utils.data.DataLoader(FMNIST_test_gen, batch_size=batch_size,
                                         shuffle=False, num_workers=2)


no_boot_bag_train_idx = np.where((np.array(FMNIST_train_gen.targets) != 8) & 
                        (np.array(FMNIST_train_gen.targets) != 9))[0]
no_boot_bag_train_subset = torch.utils.data.Subset(FMNIST_train_gen, no_boot_bag_train_idx)
no_boot_bag_train_dl = torch.utils.data.DataLoader(no_boot_bag_train_subset, batch_size=batch_size, shuffle=True, num_workers=2)

no_boot_bag_test_idx = np.where((np.array(FMNIST_test_gen.targets) != 8) & 
                        (np.array(FMNIST_test_gen.targets) != 9))[0]
no_boot_bag_test_subset = torch.utils.data.Subset(FMNIST_test_gen, no_boot_bag_test_idx)
no_boot_bag_test_dl = torch.utils.data.DataLoader(no_boot_bag_test_subset, batch_size=batch_size, shuffle=True, num_workers=2)

In [3]:
class LinearFashionMNIST_alt(nn.Module):
  def __init__(self, input_size, num_classes: int):
    super(LinearFashionMNIST_alt, self).__init__()

    self.flatten = nn.Flatten()
    self.input_layer = nn.Linear(input_size, 128)
    self.output_layer = nn.Linear(128, num_classes)

  def forward(self, x):
    x = self.flatten(x)
    return self.output_layer(self.input_layer(x))

In [4]:
criterion = nn.CrossEntropyLoss()
linear_model = LinearFashionMNIST_alt(28*28, 8)
FMNIST_optim = optim.Adam(linear_model.parameters(), lr=LEARNING_RATE)

num_epochs = 15

decay_rate = (EXP_DECAY/LEARNING_RATE)**(1/num_epochs)

lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer=FMNIST_optim, gamma=decay_rate) 
# TODO: we need to use the scheduler for cnn too if we use that

In [5]:
from utils.exceptions import ArchitectureError

import torchmetrics
from torchmetrics.classification import MulticlassRecall

from sklearn.metrics import classification_report

In [6]:
def test(dataloader, model, loss_fn, device, swap=False, swap_labels=[], classes = 9) -> float:
    '''
        Model test loop. Performs a single epoch of model updates.

        * USAGE *
        Within a training loop of range(num_epochs) to perform epoch validation, or after training to perform testing.

        * PARAMETERS *
        dataloader: A torch.utils.data.DataLoader object
        model: A torch model which subclasses torch.nn.Module
        loss_fn: A torch loss function, such as torch.nn.CrossEntropyLoss
        optimizer: A torch.optim optimizer
        device: 'cuda' or 'cpu'

        * RETURNS *
        float: The average test loss
    '''

    size = len(dataloader.dataset)
    num_batches = len(dataloader)
    test_loss, correct = 0, 0
    y_pred_list, targets = [], []

    model.eval()
    with torch.no_grad():
        for X, y in dataloader:
            if swap:
                for i in range(len(y)):
                    if y[i] == swap_labels[0]:
                        y[i] = swap_labels[1]
            X, y = X.to(device), y.to(device)
            pred = model(X)
            #preds.append(pred)
            targets.append(y.numpy())
            test_loss += loss_fn(pred, y).item()
            correct += (pred.argmax(1) == y).type(torch.float).sum().item()
            
            _, y_pred_tags = torch.max(pred, dim=1)
            y_pred_list.append(y_pred_tags.cpu().numpy())
            
    y_pred_list = [a.squeeze().tolist() for a in y_pred_list]

    test_loss /= num_batches
    correct /= size
    
    #print(preds)
    #print(targets)
    
    recall = MulticlassRecall(classes)
    # torch.IntTensor(targets)
    recall_val = recall(torch.FloatTensor(np.asarray(y_pred_list)), torch.IntTensor(np.asarray(targets)))
    # should I be calling it on preds[0]?

    print(
        f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f}, Recall val: {recall_val:>8f} \n")

    return test_loss, np.asarray(y_pred_list), np.asarray(targets)

In [7]:
train_losses = []
test_losses = []

for epoch in range(num_epochs):
    train_loss = train(no_boot_bag_train_dl, linear_model, criterion, FMNIST_optim, 'cpu')
    test_loss, y_pred_list, y_test = test(no_boot_bag_test_dl, linear_model, criterion, 'cpu')
    
    print("Epoch", epoch, "train loss:", train_loss, "test loss:", test_loss)
    
    lr_scheduler.step()
    
    train_losses.append(train_loss)
    test_losses.append(test_loss)
    
    #print(classification_report(y_test, y_pred_list))

loss: 2.038709  [    0/48000]
Test Error: 
 Accuracy: 81.5%, Avg loss: 0.501730, Recall val: 0.724333 

Epoch 0 train loss: 0.5428749349514643 test loss: 0.5017296605110169
loss: 0.386660  [    0/48000]
Test Error: 
 Accuracy: 82.3%, Avg loss: 0.487703, Recall val: 0.731667 

Epoch 1 train loss: 0.47309311193227765 test loss: 0.48770318055152895
loss: 0.397758  [    0/48000]
Test Error: 
 Accuracy: 81.2%, Avg loss: 0.497151, Recall val: 0.722222 

Epoch 2 train loss: 0.4534300406376521 test loss: 0.49715096735954284
loss: 0.653621  [    0/48000]
Test Error: 
 Accuracy: 82.1%, Avg loss: 0.481427, Recall val: 0.730000 

Epoch 3 train loss: 0.44205059425036114 test loss: 0.4814268772602081
loss: 0.281445  [    0/48000]
Test Error: 
 Accuracy: 82.5%, Avg loss: 0.484571, Recall val: 0.732889 

Epoch 4 train loss: 0.43446280564864476 test loss: 0.4845711535215378
loss: 0.343484  [    0/48000]
Test Error: 
 Accuracy: 82.0%, Avg loss: 0.483754, Recall val: 0.729000 

Epoch 5 train loss: 0.4286