# Notebook for the final output of the test dataset

This notebook adapts the code of main.py so that the predictions for the test dataset are written in the required output format. The code is essentially the same as in main.py; refer to that file for the project's reference implementation.

In [80]:
from __future__ import annotations
import os
import torch
import warnings
from utils.data_loader import ImageDataset
from model.cnn import CNN
from model.cnn2 import CNN2
from model.cnn3 import CNN3
import torch.nn as nn
from torchvision import transforms
import numpy as np
from utils.performance_measure import precision_recall_f1
warnings.filterwarnings('ignore')
from torchsummary import summary

torch.set_printoptions(threshold=10_000)
torch.set_num_threads(22)

TRAIN_SIZE = 0.8
BATCH_SIZE_TRAIN = 2000
BATCH_SIZE_TEST = 2000
TRANSFORM = False

"""
SETUP
    - device: CUDA when available, otherwise CPU
    - classes: sorted list of class names, one per annotation file
"""
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Get all the classes for one-hot encoding: one class per annotation file
# (file name without its extension). Sorting makes the class order, and hence
# the one-hot column order, identical across runs: iteration order of a set of
# strings depends on hash randomization, and os.listdir order is arbitrary.
classes = sorted({os.path.splitext(filename)[0] for filename in os.listdir('../data/annotations')})

"""
COLLATE FUNCTION
Input:
    - batch: a simple batch from data loader
Output: 
    - zip of samples in order to form a batch which can be used by the standard implementation of the training procedure
"""
def collate_fn(batch):
    """Transpose a batch of samples into one tuple per field.

    Input:
        - batch: list of samples produced by the dataset, e.g. [(x0, y0), (x1, y1)]
    Output:
        - tuple holding one tuple per field, e.g. ((x0, x1), (y0, y1)), which the
          training/testing loops then combine with torch.stack
    """
    columns = zip(*batch)
    return tuple(column for column in columns)

"""
DATA AUGMENTATION
    - ColorJitter: random brightness and contrast change
    - RandomAdjustSharpness: random change sharpness (not too hardly and with low probability)
    - RandomInvert: random invert colors
    - RandomRotation: just a small rotation
Note: depends on TRANSFORM
"""
# Augmentation pipeline, built only when TRANSFORM is True; otherwise no transform.
if TRANSFORM:
    train_transform = transforms.Compose([
        transforms.ColorJitter(brightness=.5, contrast=.3),
        transforms.RandomAdjustSharpness(sharpness_factor=1.1, p=.1),
        transforms.RandomInvert(p=.1),
        transforms.RandomRotation(degrees=2),
    ])
else:
    train_transform = None

"""
DATA LOADING
    - Load all data with custom ImageDataset class
    - Create the train/test split and the DataLoader objects
"""
# NOTE(review): train_transform is given to the whole dataset, so (presumably,
# if ImageDataset applies it per item) the held-out split used for early
# stopping and testing is augmented as well when TRANSFORM is True - confirm.
data = ImageDataset(label_dir='../data/annotations', img_dir='../data/images', classes=classes, transform=train_transform)

train_size = int(TRAIN_SIZE * len(data))
test_size = len(data) - train_size
train_set, test_set = torch.utils.data.random_split(data, [train_size, test_size])

# Reshuffle the training samples at every epoch so batch composition varies
# between epochs; the evaluation loader keeps a fixed order.
train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=BATCH_SIZE_TRAIN, shuffle=True, collate_fn=collate_fn)
test_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=BATCH_SIZE_TEST, shuffle=False, collate_fn=collate_fn)


In [81]:
# ---------------------------------------------------------------------------
# HYPERPARAMETERS AND CONSTANTS
# (TRAIN_SIZE, BATCH_SIZE_TRAIN, BATCH_SIZE_TEST and TRANSFORM are set in the
# previous cell: TRAIN_SIZE is the fraction of the dataset used for training,
# the batch sizes apply to the training / testing loaders, and TRANSFORM
# enables data augmentation.)
# ---------------------------------------------------------------------------
LR = 0.001                  # learning rate for Adam
N_EPOCHS = 10               # maximum number of epochs to run
PATIENCE = 4                # how many earlier validation losses must be lower than the current one to stop early
IS_VERBOSE = False          # when True, print the running loss after every batch
ACTIVATION_TRESHOLD = 0.25  # an output above this value counts as "class present"
WEIGHT_DECAY = 0.2          # weight decay passed to Adam as regularization

# ---------------------------------------------------------------------------
# MODEL INITIALIZATION
#   - optimizer: Adam with weight decay as regularization
#   - loss: binary cross entropy, one independent binary label per class
# NOTE(review): nn.BCELoss expects probabilities in [0, 1], so CNN presumably
# ends with a sigmoid - confirm in model/cnn.py.
# ---------------------------------------------------------------------------
model = CNN(dropout=False).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
loss_function = nn.BCELoss()

# ---------------------------------------------------------------------------
# TRAIN (next statements)
#   - Early stopping: training stops once PATIENCE or more of the validation
#     losses recorded so far are lower than the current one.
#   - The validation loss is computed on the held-out split (test_loader);
#     this notebook has no separate validation split.
#   - A class is predicted when its output exceeds ACTIVATION_TRESHOLD.
#   - Precision, recall and F1 are computed per batch and averaged per epoch.
# ---------------------------------------------------------------------------
pre_valid_losses = []
for epoch in range(N_EPOCHS):
    # The validation pass at the end of each epoch calls model.eval(), so
    # training mode must be restored at the start of every epoch.
    model.train()
    train_loss = 0
    valid_losses = []
    precision = 0
    recall = 0
    f1 = 0

    # Loop variables are not named `data` so the dataset object is not clobbered.
    for batch_num, (images, targets) in enumerate(train_loader):
        optimizer.zero_grad()
        images = torch.stack(images, dim=0).to(device)
        targets = torch.stack(targets, dim=0).to(device)
        outputs = model(images.float())
        loss = loss_function(outputs, targets.float())
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        # Index pairs of predicted / true positive labels, as returned by
        # torch.argwhere; moved to CPU before being handed to the metric helper.
        predictions = torch.argwhere(outputs.detach() > ACTIVATION_TRESHOLD).cpu()
        true_labels = torch.argwhere(targets).cpu()
        _precision, _recall, _f1 = precision_recall_f1(predictions, true_labels)
        precision += _precision
        recall += _recall
        f1 += _f1

        if IS_VERBOSE:
            print('Training: Epoch %d - Batch %d/%d: Loss: %.4f' %
                  (epoch, batch_num, len(train_loader), train_loss / (batch_num + 1)))

    # Metrics were summed once per batch, so average over the actual number of
    # batches (the last one may be partial), not train_size / BATCH_SIZE_TRAIN.
    n_train_batches = max(len(train_loader), 1)
    print('EPOCH', epoch, 'PRECISION:', (precision / n_train_batches))
    print('EPOCH', epoch, 'RECALL:', (recall / n_train_batches))
    print('EPOCH', epoch, 'F1-SCORE:', (f1 / n_train_batches))

    # Validation loss on the held-out split; no gradients are needed here.
    model.eval()
    with torch.no_grad():
        for images, targets in test_loader:
            images = torch.stack(images, dim=0).to(device)
            targets = torch.stack(targets, dim=0).to(device)
            output = model(images.float())
            valid_losses.append(loss_function(output, targets.float()).item())

    valid_loss = np.average(valid_losses)
    pre_valid_losses.append(valid_loss)

    print('Epoch', epoch, 'Validation loss', valid_loss)

    # EARLY STOPPING: count the recorded losses strictly lower than the current
    # one (the current loss itself never counts) and stop once there are at
    # least PATIENCE of them. Fewer than PATIENCE recorded losses can never
    # reach the threshold, so no separate length check is needed.
    j = sum(1 for previous in pre_valid_losses if previous < valid_loss)
    if j >= PATIENCE:
        break

"""
TEST
"""
# Evaluate the trained model on the held-out split: accumulate the loss and
# the per-batch precision/recall/F1, then print the per-batch averages.
test_loss = 0
precision = 0
recall = 0
f1 = 0
model.eval()
with torch.no_grad():
    for batch_num, (images, targets) in enumerate(test_loader):
        images = torch.stack(images, dim=0).to(device)
        targets = torch.stack(targets, dim=0).to(device)
        outputs = model(images.float())
        loss = loss_function(outputs, targets.float())
        test_loss += loss.item()
        # Index pairs of predicted / true positive labels (torch.argwhere),
        # moved to CPU before being handed to the metric helper.
        predictions = torch.argwhere(outputs > ACTIVATION_TRESHOLD).cpu()
        true_labels = torch.argwhere(targets).cpu()

        _precision, _recall, _f1 = precision_recall_f1(predictions, true_labels)
        precision += _precision
        recall += _recall
        f1 += _f1

        if IS_VERBOSE:
            print('Evaluating: Batch %d/%d: Loss: %.4f' %
                  (batch_num, len(test_loader), test_loss / (batch_num + 1)))
    # Average over the actual number of batches (the last one may be partial).
    n_test_batches = max(len(test_loader), 1)
    print('TEST PRECISION:', (precision / n_test_batches))
    print('TEST RECALL:', (recall / n_test_batches))
    print('TEST F1-SCORE', (f1 / n_test_batches))

EPOCH 0 PRECISION: 0.20445800719561363
EPOCH 0 RECALL: 0.541637768288763
EPOCH 0 F1-SCORE: 0.2841073312350703
Epoch 0 Validation loss 0.40095336735248566
EPOCH 1 PRECISION: 0.3107196519103894
EPOCH 1 RECALL: 0.5120679953327992
EPOCH 1 F1-SCORE: 0.38533047381122076
Epoch 1 Validation loss 0.36685608327388763
EPOCH 2 PRECISION: 0.3290206401997256
EPOCH 2 RECALL: 0.5723475119919789
EPOCH 2 F1-SCORE: 0.41510309653188887
Epoch 2 Validation loss 0.3622354120016098
EPOCH 3 PRECISION: 0.3292167621629845
EPOCH 3 RECALL: 0.567923097757969
EPOCH 3 F1-SCORE: 0.41560185132908795
Epoch 3 Validation loss 0.36408497393131256
EPOCH 4 PRECISION: 0.32877297591903154
EPOCH 4 RECALL: 0.5725582285058626
EPOCH 4 F1-SCORE: 0.4157034295662583
Epoch 4 Validation loss 0.36419475078582764
EPOCH 5 PRECISION: 0.328258782530228
EPOCH 5 RECALL: 0.554960365300006
EPOCH 5 F1-SCORE: 0.4114663473207939
Epoch 5 Validation loss 0.3661970794200897
EPOCH 6 PRECISION: 0.32876518097014684
EPOCH 6 RECALL: 0.5657663813228753
EPO

In [82]:
from torchvision import transforms, datasets
from matplotlib import pyplot as plt
from utils.data_loader_test import ImageDatasetTest

# No data augmentation at inference time: train_transform is random when
# TRANSFORM is True, and predictions must use the unmodified test images.
new_data = ImageDatasetTest(label_dir='../data/test_fake_annotations', img_dir='../data/test_images', classes=classes, transform=None)
# shuffle=False keeps the images in dataset order, so predictions can be matched
# with image_names afterwards (the order was checked manually). batch_size=1
# makes the inference loop produce exactly one prediction entry per image.
new_test_loader = torch.utils.data.DataLoader(dataset=new_data, batch_size=1, shuffle=False, collate_fn=collate_fn)
image_names = new_data.image_names  # get image names
# Map each model output column index to its label name.
# NOTE(review): assumes new_data.enc_labels lists the label names in output-column order - confirm.
enc_labels = dict(enumerate(new_data.enc_labels))

# Get the predictions in the usual way. Each entry of `preds` is the
# torch.argwhere result for one image (batch size 1), i.e. a list of
# [row, class_index] pairs for every output above ACTIVATION_TRESHOLD.
preds = []
model.eval()
with torch.no_grad():
    for data, _ in new_test_loader:
        # Inputs must live on the same device as the model.
        data = torch.stack(data, dim=0).to(device)
        outputs = model(data.float())
        preds.append(torch.argwhere(outputs > ACTIVATION_TRESHOLD).tolist())

# Convert encoded to decoded labels, in place: every [row, class_index] pair
# in `preds` becomes [row, label_name].
for image_preds in preds:
    for pair in image_preds:
        pair[1] = enc_labels[pair[1]]

# Predictions are matched to images purely by position, so both lists must
# have the same length.
if len(image_names) != len(preds):
    raise ValueError('Got %d image names but %d predictions' % (len(image_names), len(preds)))

# Transform each image's predictions into a one-hot string: one '1 ' / '0 '
# token per label, in enc_labels column order.
preds_onehot = {}
for image_name, image_preds in zip(image_names, preds):
    predicted_labels = {label for _, label in image_preds}
    preds_onehot[image_name] = ''.join(
        '1 ' if label in predicted_labels else '0 ' for label in enc_labels.values()
    )

# Write to file as tab-separated values: a header ('image_name' followed by one
# column per label) and one row per image with its 0/1 predictions. The header
# is joined with tabs so it has exactly as many columns as the data rows
# (no trailing tab).
header = '\t'.join(['image_name', *enc_labels.values()])
with open('final_results_041_2.txt', 'w', encoding='utf-8') as f:
    f.write(header + '\n')
    for key, value in preds_onehot.items():
        # value is a space-separated 0/1 string such as "1 0 1 "; split()
        # drops the trailing separator.
        f.write(key.strip() + '\t' + '\t'.join(value.split()) + '\n')