In [1]:
from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import time
import os
import copy
from PIL import Image
from tqdm import tqdm
from scipy import stats

In [2]:
device = torch.device('cuda:0')

In [3]:
# data_transforms = {
#     'train': transforms.Compose([
# #         transforms.Resize((310, 310)),
#         transforms.Resize((224, 224)),
# #         transforms.CenterCrop(224),
# #         transforms.RandomResizedCrop(224),
# #         transforms.RandomHorizontalFlip(),
#         transforms.ToTensor(),
#         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
#     ]),
#     'val': transforms.Compose([
# #         transforms.Resize((310, 310)),
#         transforms.Resize((224, 224)),
# #         transforms.Resize(256),
# #         transforms.CenterCrop(224),
# #         transforms.CenterCrop(224),
#         transforms.ToTensor(),
#         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
#     ]),
# }
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}


In [4]:
class ImageDataset(Dataset):
    def __init__(self, paths, labels, transform=None, stride = 1, max_size=None):
        self.paths = paths[::stride]
        if max_size is not None:
            self.paths = self.paths[:max_size]
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img_name = self.paths[idx]
        img = Image.open(img_name).convert('RGB')
        
        img_tensor = self.transform(img)

        return img_tensor, self.labels[idx]

In [5]:
train_file = 'tvnews_sandbox/data_panels/train.txt'
val_file = 'tvnews_sandbox/data_panels/val.txt'
test_file = 'tvnews_sandbox/data_panels/test.txt'

In [6]:
def read_file(filename):
    paths = []
    labels = []
    with open(filename, 'r') as f:
        for line in f.readlines():
            video, image, label = line.split(' ')
            paths.append('tvnews_sandbox/images/{}/{:04d}.jpg'.format(video, int(image)))
            labels.append(int(label))
            
    return paths, labels

In [7]:
train_paths, Y_train = read_file(train_file)
val_paths, Y_val = read_file(val_file)
test_paths, Y_test = read_file(test_file)

In [8]:
def class_balance(filename):
    labels = []
    with open(filename, 'r') as f:
        for line in f.readlines():
            video, image, label = line.split(' ')
            labels.append(int(label))
    
    cb = len([lb for lb in labels if lb == 1]) / len(labels)
    
    return cb

print(class_balance(train_file))
print(class_balance(val_file))
print(class_balance(test_file))

0.12611558156547184
0.11103767349636484
0.12489475161380859


In [9]:
image_datasets = {
    'train': ImageDataset(train_paths, Y_train, transform = data_transforms['train'],
                          stride=1, max_size=None),
    'val_train': ImageDataset(val_paths, Y_val, transform = data_transforms['train'],
                       stride=1, max_size=None),
    'val': ImageDataset(val_paths, Y_val, transform = data_transforms['val'],
                       stride=1, max_size=None),
    'test': ImageDataset(test_paths, Y_test, transform = data_transforms['val'],
                        stride=1, max_size=None)
}

In [10]:
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x],
                                              batch_size=4,
                                              shuffle='train' in x,
                                              num_workers=4,
                                              pin_memory=True)
                                              for x in image_datasets}

In [11]:
dataset_sizes = {
    x: len(image_datasets[x])
    for x in image_datasets
}

In [12]:
dataset_sizes

{'train': 6835, 'val_train': 3026, 'val': 3026, 'test': 3563}

In [13]:
def safe_divide(a, b):
    return a / b if b > 0 else 0

In [14]:
def train_model(model, criterion, optimizer, scheduler, train_dl, val_dl, test_dl=None,
                num_epochs=25, return_best=False, verbose=True, log_file=None):
    print(train_dl, val_dl, test_dl)
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    best_epoch = 0
    best_test_acc = 0.0
    best_f1 = 0.0
    best_precision = 0.0
    best_recall = 0.0
    
    phases = ['train', 'val', 'test'] if test_dl is not None else ['train', 'val']

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in phases:
            if phase == 'train':
                scheduler.step()
                model.train()  # Set model to training mode
                dl = dataloaders[train_dl]
                dataset_size = dataset_sizes[train_dl]
                
            elif phase == 'val':
                model.eval()   # Set model to evaluate mode
                dl = dataloaders[val_dl]
                dataset_size = dataset_sizes[val_dl]
            else:
                model.eval()
                dl = dataloaders[test_dl]
                dataset_size = dataset_sizes[test_dl]

            running_loss = 0.0
            running_corrects = 0
            true_positives = 0.
            true_negatives = 0.
            false_positives = 0.
            false_negatives = 0.

            # Iterate over data.
            for inputs, labels in dl:
                inputs = inputs.to(device)
                labels = labels.to(device).float()

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    preds = torch.where(
                        outputs >= 0.,
                        torch.tensor([1.]).to(device),
                        torch.tensor([0.]).to(device))
                    target = torch.where(
                        labels >= 0.5,
                        torch.tensor([1.]).to(device),
                        torch.tensor([0.]).to(device)
                    )
                    loss = criterion(outputs.view(target.shape), target)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                label_vals = torch.where(
                    labels >= 0.5,
                    torch.tensor([1.]).to(device),
                    torch.tensor([0.]).to(device)
                )
                correct = preds.view(label_vals.shape) == label_vals.data
                running_corrects += torch.sum(correct)
                
                true_positives += torch.sum(
                    torch.where(
                        (correct == 1.) * (label_vals == 1.),
                        torch.tensor([1.]).to(device),
                        torch.tensor([0.]).to(device))
                )
                true_negatives += torch.sum(
                    torch.where(
                        (correct == 1.) * (label_vals == 0.),
                        torch.tensor([1.]).to(device),
                        torch.tensor([0.]).to(device))
                )
                false_positives += torch.sum(
                    torch.where(
                        (correct == 0.) * (label_vals == 0.),
                        torch.tensor([1.]).to(device),
                        torch.tensor([0.]).to(device))
                )
                false_negatives += torch.sum(
                    torch.where(
                        (correct == 0.) * (label_vals == 1.),
                        torch.tensor([1.]).to(device),
                        torch.tensor([0.]).to(device))
                )
            
            epoch_loss = running_loss / dataset_size
            epoch_acc = running_corrects.double() / dataset_size
            epoch_pre = safe_divide(true_positives, (true_positives + false_positives))
            epoch_recall = safe_divide(true_positives, (true_positives + false_negatives))
            epoch_f1 = safe_divide(2 * epoch_pre * epoch_recall, (epoch_pre + epoch_recall))

            if verbose:
                print('{} Loss: {:.4f} Acc: {:.4f} Pre: {:.4f} Rec: {:.4f} F1: {:.4f}'.format(
                    phase, epoch_loss, epoch_acc, epoch_pre, epoch_recall, epoch_f1))
                print('TP: {} TN: {} FP: {} FN: {}'.format(
                    true_positives.data, true_negatives.data, false_positives.data, false_negatives.data))
            if log_file is not None:
                log_file.write('Phase: {0}\t'
                               'Epoch: [{1}/{2}]\t'
                               'Loss: {loss_c:.4f}\t'
                               'Acc: {acc:.4f}\t'
                               'Pre: {pre:.4f}\t'
                               'Rec: {rec:.4f}\t'
                               'F1: {f1:.4f}\t'
                               'TP: {tp} '
                               'TN: {tn} '
                               'FP: {fp} '
                               'FN: {fn}\n'.format(
                                   phase, epoch + 1, num_epochs, loss_c=epoch_loss,
                                   acc=epoch_acc, pre=epoch_pre, rec=epoch_recall,
                                   f1=epoch_f1, tp=int(true_positives.data), tn=int(true_negatives.data),
                                   fp=int(false_positives.data), fn=int(false_negatives.data)
                               ))
                log_file.flush()

            # deep copy the model
            if phase == 'val' and epoch_f1 > best_f1:
                best_acc = epoch_acc
                best_f1 = epoch_f1
                best_precision = epoch_pre
                best_recall = epoch_recall
                best_epoch = epoch
                best_model_wts = copy.deepcopy(model.state_dict())
            if phase == 'test' and best_epoch == epoch:
                best_test_acc = epoch_acc

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    
    if return_best:
        print('Best epoch: {}'.format(best_epoch))
        print('Best val Acc: {:4f}'.format(best_acc))
        print('Best val Pre: {:4f}'.format(best_precision))
        print('Best val Rec: {:4f}'.format(best_recall))
        print('Best val F1: {:4f}'.format(best_f1))
        print('Test Acc: {:4f}'.format(best_test_acc))

        # load best model weights
        model.load_state_dict(best_model_wts)
    return model

In [23]:
path = 'models/transfer_learning'
for seed in range(5):
    torch.manual_seed(seed)
    model_ts = models.resnet50(pretrained=True)
    num_ftrs = model_ts.fc.in_features
    model_ts.fc = nn.Linear(num_ftrs, 1)

    model_ts = model_ts.to(device)

#     criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([10.]).to(device))
#     criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([5.]).to(device))
    
    criterion = nn.BCEWithLogitsLoss()
    
    # Observe that all parameters are being optimized
    optimizer_ts = optim.SGD(model_ts.parameters(), lr=0.001, momentum=0.9)

    # Decay LR by a factor of 0.1 every 7 epochs
    exp_lr_scheduler_ts = lr_scheduler.StepLR(optimizer_ts, step_size=7, gamma=0.1)
    
    if not os.path.exists(path):
        os.makedirs(path)
    with open(os.path.join(path, 'seed_{}.log'.format(seed)), 'w') as log_file:
        model_ts = train_model(model_ts, criterion, optimizer_ts, exp_lr_scheduler_ts,
                               'val_train', 'val', test_dl='test', num_epochs=25, verbose=True,
                               log_file=log_file, return_best=False)
        torch.save(model_ts.state_dict(), os.path.join(path, 'seed_{}.pth'.format(seed)))

val_train val test
Epoch 0/24
----------
train Loss: 0.3359 Acc: 0.8840 Pre: 0.3770 Rec: 0.0685 F1: 0.1159
TP: 23.0 TN: 2652.0 FP: 38.0 FN: 313.0
val Loss: 0.2770 Acc: 0.8949 Pre: 1.0000 Rec: 0.0536 F1: 0.1017
TP: 18.0 TN: 2690.0 FP: 0.0 FN: 318.0
test Loss: 0.3914 Acc: 0.8774 Pre: 0.7857 Rec: 0.0247 F1: 0.0479
TP: 11.0 TN: 3115.0 FP: 3.0 FN: 434.0

Epoch 1/24
----------
train Loss: 0.2983 Acc: 0.8876 Pre: 0.4836 Rec: 0.1756 F1: 0.2576
TP: 59.0 TN: 2627.0 FP: 63.0 FN: 277.0
val Loss: 0.1922 Acc: 0.9213 Pre: 0.8451 Rec: 0.3571 F1: 0.5021
TP: 120.0 TN: 2668.0 FP: 22.0 FN: 216.0
test Loss: 0.3213 Acc: 0.8846 Pre: 0.6349 Rec: 0.1798 F1: 0.2802
TP: 80.0 TN: 3072.0 FP: 46.0 FN: 365.0

Epoch 2/24
----------
train Loss: 0.2733 Acc: 0.9009 Pre: 0.6139 Rec: 0.2887 F1: 0.3927
TP: 97.0 TN: 2629.0 FP: 61.0 FN: 239.0
val Loss: 0.1854 Acc: 0.9309 Pre: 0.7292 Rec: 0.6012 F1: 0.6591
TP: 202.0 TN: 2615.0 FP: 75.0 FN: 134.0
test Loss: 0.3218 Acc: 0.8748 Pre: 0.4981 Rec: 0.2876 F1: 0.3647
TP: 128.0 TN: 29

train Loss: 0.3311 Acc: 0.8876 Pre: 0.4767 Rec: 0.1220 F1: 0.1943
TP: 41.0 TN: 2645.0 FP: 45.0 FN: 295.0
val Loss: 0.2267 Acc: 0.9088 Pre: 0.7632 Rec: 0.2589 F1: 0.3867
TP: 87.0 TN: 2663.0 FP: 27.0 FN: 249.0
test Loss: 0.3319 Acc: 0.8875 Pre: 0.7619 Rec: 0.1438 F1: 0.2420
TP: 64.0 TN: 3098.0 FP: 20.0 FN: 381.0

Epoch 1/24
----------
train Loss: 0.3155 Acc: 0.8900 Pre: 0.5135 Rec: 0.1696 F1: 0.2550
TP: 57.0 TN: 2636.0 FP: 54.0 FN: 279.0
val Loss: 0.2014 Acc: 0.9280 Pre: 0.7379 Rec: 0.5446 F1: 0.6267
TP: 183.0 TN: 2625.0 FP: 65.0 FN: 153.0
test Loss: 0.3197 Acc: 0.8914 Pre: 0.6295 Rec: 0.3169 F1: 0.4215
TP: 141.0 TN: 3035.0 FP: 83.0 FN: 304.0

Epoch 2/24
----------
train Loss: 0.2845 Acc: 0.8976 Pre: 0.6000 Rec: 0.2321 F1: 0.3348
TP: 78.0 TN: 2638.0 FP: 52.0 FN: 258.0
val Loss: 0.2527 Acc: 0.9147 Pre: 0.5777 Rec: 0.8631 F1: 0.6921
TP: 290.0 TN: 2478.0 FP: 212.0 FN: 46.0
test Loss: 0.3586 Acc: 0.8588 Pre: 0.4478 Rec: 0.5596 F1: 0.4975
TP: 249.0 TN: 2811.0 FP: 307.0 FN: 196.0

Epoch 3/24
-

val Loss: 0.0819 Acc: 0.9696 Pre: 0.8280 Rec: 0.9167 F1: 0.8701
TP: 308.0 TN: 2626.0 FP: 64.0 FN: 28.0
test Loss: 0.3528 Acc: 0.8861 Pre: 0.5549 Rec: 0.4427 F1: 0.4925
TP: 197.0 TN: 2960.0 FP: 158.0 FN: 248.0

Training complete in 28m 8s
val_train val test
Epoch 0/24
----------
train Loss: 0.3436 Acc: 0.8820 Pre: 0.2000 Rec: 0.0208 F1: 0.0377
TP: 7.0 TN: 2662.0 FP: 28.0 FN: 329.0
val Loss: 0.2458 Acc: 0.9065 Pre: 0.5943 Rec: 0.4970 F1: 0.5413
TP: 167.0 TN: 2576.0 FP: 114.0 FN: 169.0
test Loss: 0.3063 Acc: 0.8880 Pre: 0.5737 Rec: 0.4022 F1: 0.4729
TP: 179.0 TN: 2985.0 FP: 133.0 FN: 266.0

Epoch 1/24
----------
train Loss: 0.2925 Acc: 0.8969 Pre: 0.6277 Rec: 0.1756 F1: 0.2744
TP: 59.0 TN: 2655.0 FP: 35.0 FN: 277.0
val Loss: 0.2435 Acc: 0.9095 Pre: 0.5671 Rec: 0.7798 F1: 0.6566
TP: 262.0 TN: 2490.0 FP: 200.0 FN: 74.0
test Loss: 0.3323 Acc: 0.8647 Pre: 0.4635 Rec: 0.5281 F1: 0.4937
TP: 235.0 TN: 2846.0 FP: 272.0 FN: 210.0

Epoch 2/24
----------
train Loss: 0.2802 Acc: 0.8942 Pre: 0.5541 Re

test Loss: 0.3716 Acc: 0.8869 Pre: 0.5593 Rec: 0.4449 F1: 0.4956
TP: 198.0 TN: 2962.0 FP: 156.0 FN: 247.0

Epoch 24/24
----------
train Loss: 0.1546 Acc: 0.9428 Pre: 0.8382 Rec: 0.6012 F1: 0.7002
TP: 202.0 TN: 2651.0 FP: 39.0 FN: 134.0
val Loss: 0.0743 Acc: 0.9719 Pre: 0.8555 Rec: 0.8988 F1: 0.8766
TP: 302.0 TN: 2639.0 FP: 51.0 FN: 34.0
test Loss: 0.3755 Acc: 0.8838 Pre: 0.5425 Rec: 0.4449 F1: 0.4889
TP: 198.0 TN: 2951.0 FP: 167.0 FN: 247.0

Training complete in 28m 50s
val_train val test
Epoch 0/24
----------
train Loss: 0.3398 Acc: 0.8833 Pre: 0.3396 Rec: 0.0536 F1: 0.0925
TP: 18.0 TN: 2655.0 FP: 35.0 FN: 318.0
val Loss: 0.2919 Acc: 0.8827 Pre: 0.4667 Rec: 0.3958 F1: 0.4283
TP: 133.0 TN: 2538.0 FP: 152.0 FN: 203.0
test Loss: 0.3461 Acc: 0.8656 Pre: 0.4267 Rec: 0.2225 F1: 0.2925
TP: 99.0 TN: 2985.0 FP: 133.0 FN: 346.0

Epoch 1/24
----------
train Loss: 0.2996 Acc: 0.8913 Pre: 0.5321 Rec: 0.1726 F1: 0.2607
TP: 58.0 TN: 2639.0 FP: 51.0 FN: 278.0
val Loss: 0.2111 Acc: 0.9207 Pre: 0.6472 

train Loss: 0.1504 Acc: 0.9471 Pre: 0.8411 Rec: 0.6458 F1: 0.7306
TP: 217.0 TN: 2649.0 FP: 41.0 FN: 119.0
val Loss: 0.0710 Acc: 0.9712 Pre: 0.8487 Rec: 0.9018 F1: 0.8745
TP: 303.0 TN: 2636.0 FP: 54.0 FN: 33.0
test Loss: 0.3657 Acc: 0.8855 Pre: 0.5562 Rec: 0.4112 F1: 0.4729
TP: 183.0 TN: 2972.0 FP: 146.0 FN: 262.0

Epoch 24/24
----------
train Loss: 0.1499 Acc: 0.9524 Pre: 0.8810 Rec: 0.6607 F1: 0.7551
TP: 222.0 TN: 2660.0 FP: 30.0 FN: 114.0
val Loss: 0.0652 Acc: 0.9739 Pre: 0.8725 Rec: 0.8958 F1: 0.8840
TP: 301.0 TN: 2646.0 FP: 44.0 FN: 35.0
test Loss: 0.3763 Acc: 0.8925 Pre: 0.6054 Rec: 0.4000 F1: 0.4817
TP: 178.0 TN: 3002.0 FP: 116.0 FN: 267.0

Training complete in 28m 5s
val_train val test
Epoch 0/24
----------
train Loss: 0.3297 Acc: 0.8866 Pre: 0.4286 Rec: 0.0625 F1: 0.1091
TP: 21.0 TN: 2662.0 FP: 28.0 FN: 315.0
val Loss: 0.2573 Acc: 0.9052 Pre: 0.5698 Rec: 0.5952 F1: 0.5822
TP: 200.0 TN: 2539.0 FP: 151.0 FN: 136.0
test Loss: 0.3207 Acc: 0.8799 Pre: 0.5275 Rec: 0.3663 F1: 0.4324
T

val Loss: 0.0897 Acc: 0.9670 Pre: 0.8089 Rec: 0.9196 F1: 0.8607
TP: 309.0 TN: 2617.0 FP: 73.0 FN: 27.0
test Loss: 0.3641 Acc: 0.8762 Pre: 0.5054 Rec: 0.4225 F1: 0.4602
TP: 188.0 TN: 2934.0 FP: 184.0 FN: 257.0

Epoch 23/24
----------
train Loss: 0.1388 Acc: 0.9511 Pre: 0.8950 Rec: 0.6339 F1: 0.7422
TP: 213.0 TN: 2665.0 FP: 25.0 FN: 123.0
val Loss: 0.0714 Acc: 0.9752 Pre: 0.8872 Rec: 0.8899 F1: 0.8886
TP: 299.0 TN: 2652.0 FP: 38.0 FN: 37.0
test Loss: 0.3662 Acc: 0.8849 Pre: 0.5641 Rec: 0.3461 F1: 0.4290
TP: 154.0 TN: 2999.0 FP: 119.0 FN: 291.0

Epoch 24/24
----------
train Loss: 0.1366 Acc: 0.9527 Pre: 0.8845 Rec: 0.6607 F1: 0.7564
TP: 222.0 TN: 2661.0 FP: 29.0 FN: 114.0
val Loss: 0.0662 Acc: 0.9749 Pre: 0.8916 Rec: 0.8810 F1: 0.8862
TP: 296.0 TN: 2654.0 FP: 36.0 FN: 40.0
test Loss: 0.3799 Acc: 0.8835 Pre: 0.5532 Rec: 0.3506 F1: 0.4292
TP: 156.0 TN: 2992.0 FP: 126.0 FN: 289.0

Training complete in 28m 14s


# Smooth Results

In [24]:
def pos_negs(predictions, gt):
    correct = np.where(np.array(predictions) == np.array(gt), 1, 0)
    incorrect = np.where(np.array(predictions) == np.array(gt), 0, 1)
    
    tp = np.where(correct * np.where(predictions == np.array(1), 1, 0), 1, 0)
    tn = np.where(correct * np.where(predictions == np.array(0), 1, 0), 1, 0)
    fp = np.where(incorrect * np.where(predictions == np.array(1), 1, 0), 1, 0)
    fn = np.where(incorrect * np.where(predictions == np.array(0), 1, 0), 1, 0)
    
    return tp, tn, fp, fn

def acc_prf1(predictions, gt):
    tp, tn, fp, fn = pos_negs(predictions, gt)
    
    acc = (np.sum(tp) + np.sum(tn)) / (np.sum(tp) + np.sum(tn) + np.sum(fp) + np.sum(fn))
    precision = np.sum(tp) / (np.sum(tp) + np.sum(fp))
    recall = np.sum(tp) / (np.sum(tp) + np.sum(fn))
    f1 = 2 * precision * recall / (precision + recall)
    
    return acc, precision, recall, f1, np.sum(tp), np.sum(tn), np.sum(fp), np.sum(fn)

def smooth_predictions(preds, window_radius = 3):
    result = []
    for i in range(len(preds)):
        start = max(0, i - window_radius)
        end = min(len(preds), i + window_radius)
        window = preds[start:end]
        result += [max(window, key=window.count)]
    
    return result

In [25]:
path = 'models/transfer_learning'

log_file_img = open(os.path.join(
    path, 'test_results_image_classifier.log'), 'w')
log_file_smoothed = open(os.path.join(
    path, 'test_results_smoothed.log'), 'w')

for seed in range(5):
    print(seed)
    model_ts = models.resnet50(pretrained=True)
    num_ftrs = model_ts.fc.in_features
    model_ts.fc = nn.Linear(num_ftrs, 1)
    model_ts.load_state_dict(torch.load(
        os.path.join(path, 'seed_{}.pth'.format(seed))
    ))
    model_ts.to(device)
    criterion = nn.BCEWithLogitsLoss()
    
    model = model_ts.eval()   # Set model to evaluate mode
    dl = dataloaders['test']
    dataset_size = dataset_sizes['test']
    
    running_corrects = 0
    true_positives = 0.
    true_negatives = 0.
    false_positives = 0.
    false_negatives = 0.

    i = 0

    predictions = []
    gt_labels = []
    # Iterate over data.
    for inputs, labels in dl:
        inputs = inputs.to(device)
        labels = labels.to(device).float()

        # forward
        # track history if only in train
        with torch.set_grad_enabled(False):
            outputs = model(inputs)
            preds = torch.where(
                outputs >= 0.,
                torch.tensor([1.]).to(device),
                torch.tensor([0.]).to(device))
            target = torch.where(
                labels >= 0.5,
                torch.tensor([1.]).to(device),
                torch.tensor([0.]).to(device)
            )

        predictions += preds.cpu().numpy().tolist()
        gt_labels += labels.cpu().numpy().tolist()

        # statistics
        label_vals = torch.where(
            labels >= 0.5,
            torch.tensor([1.]).to(device),
            torch.tensor([0.]).to(device)
        )
        correct = preds.view(label_vals.shape) == label_vals.data
        running_corrects += torch.sum(correct)

        true_positives += torch.sum(
            torch.where(
                (correct == 1.) * (label_vals == 1.),
                torch.tensor([1.]).to(device),
                torch.tensor([0.]).to(device))
        )
        true_negatives += torch.sum(
            torch.where(
                (correct == 1.) * (label_vals == 0.),
                torch.tensor([1.]).to(device),
                torch.tensor([0.]).to(device))
        )
        false_positives += torch.sum(
            torch.where(
                (correct == 0.) * (label_vals == 0.),
                torch.tensor([1.]).to(device),
                torch.tensor([0.]).to(device))
        )
        false_negatives += torch.sum(
            torch.where(
                (correct == 0.) * (label_vals == 1.),
                torch.tensor([1.]).to(device),
                torch.tensor([0.]).to(device))
        )

        num_fp = torch.sum(
        torch.where(
            (correct == 0.) * (label_vals == 0.),
            torch.tensor([1.]).to(device),
            torch.tensor([0.]).to(device))
        )

    #     if num_fp > 0:
    #         print(num_fp)
    #         print(torch.where(
    #             (correct == 0.) * (label_vals == 0.),
    #             torch.tensor([1.]).to(device),
    #             torch.tensor([0.]).to(device)))
    #         print(i)
    #         out = torchvision.utils.make_grid(inputs)
    # #         imshow(out, title=preds.cpu().numpy().tolist())
    #         imshow(out)

        i += 1

    epoch_acc = running_corrects.double() / dataset_size
    epoch_pre = safe_divide(true_positives, (true_positives + false_positives))
    epoch_recall = safe_divide(true_positives, (true_positives + false_negatives))
    epoch_f1 = safe_divide(2 * epoch_pre * epoch_recall, (epoch_pre + epoch_recall))

    print('Acc: {:.4f} Pre: {:.4f} Rec: {:.4f} F1: {:.4f}'.format(
        epoch_acc, epoch_pre, epoch_recall, epoch_f1))
    print('TP: {} TN: {} FP: {} FN: {}'.format(
        true_positives.data, true_negatives.data, false_positives.data, 
        false_negatives.data))
    
    log_file_img.write('Seed: {0}\t'
                   'Acc: {acc:.4f}\t'
                   'Pre: {pre:.4f}\t'
                   'Rec: {rec:.4f}\t'
                   'F1: {f1:.4f}\t'
                   'TP: {tp} '
                   'TN: {tn} '
                   'FP: {fp} '
                   'FN: {fn}\n'.format(
                       seed,
                       acc=epoch_acc, pre=epoch_pre, rec=epoch_recall,
                       f1=epoch_f1, tp=int(true_positives.data),
                       tn=int(true_negatives.data),
                       fp=int(false_positives.data), fn=int(false_negatives.data)
                   ))
    log_file_img.flush()

    predictions = [p[0] for p in predictions]

    smoothed_preds = smooth_predictions(predictions, 3)

    print("Smoothed stats:")
    print(acc_prf1(smoothed_preds, gt_labels))
    
    sm_acc, sm_pre, sm_rec, sm_f1, sm_tp, sm_tn, sm_fp, sm_fn = acc_prf1(
        smoothed_preds, gt_labels)
    
    log_file_smoothed.write('Seed: {0}\t'
                   'Acc: {acc:.4f}\t'
                   'Pre: {pre:.4f}\t'
                   'Rec: {rec:.4f}\t'
                   'F1: {f1:.4f}\t'
                   'TP: {tp} '
                   'TN: {tn} '
                   'FP: {fp} '
                   'FN: {fn}\n'.format(
                       seed,
                       acc=sm_acc, pre=sm_pre, rec=sm_rec,
                       f1=sm_f1, tp=sm_tp,
                       tn=sm_tn,
                       fp=sm_fp, fn=sm_fn
                   ))
    log_file_smoothed.flush()

0
Acc: 0.8832 Pre: 0.5472 Rec: 0.3775 F1: 0.4468
TP: 168.0 TN: 2979.0 FP: 139.0 FN: 277.0
Smoothed stats:
(0.8975582374403592, 0.6904761904761905, 0.3258426966292135, 0.4427480916030535, 145, 3053, 65, 300)
1
Acc: 0.8861 Pre: 0.5549 Rec: 0.4427 F1: 0.4925
TP: 197.0 TN: 2960.0 FP: 158.0 FN: 248.0
Smoothed stats:
(0.9003648610721302, 0.6956521739130435, 0.3595505617977528, 0.4740740740740741, 160, 3048, 70, 285)
2
Acc: 0.8838 Pre: 0.5425 Rec: 0.4449 F1: 0.4889
TP: 198.0 TN: 2951.0 FP: 167.0 FN: 247.0
Smoothed stats:
(0.9003648610721302, 0.6875, 0.3707865168539326, 0.4817518248175182, 165, 3043, 75, 280)
3
Acc: 0.8925 Pre: 0.6054 Rec: 0.4000 F1: 0.4817
TP: 178.0 TN: 3002.0 FP: 116.0 FN: 267.0
Smoothed stats:
(0.902610159977547, 0.7474747474747475, 0.3325842696629214, 0.4603421461897356, 148, 3068, 50, 297)
4
Acc: 0.8835 Pre: 0.5532 Rec: 0.3506 F1: 0.4292
TP: 156.0 TN: 2992.0 FP: 126.0 FN: 289.0
Smoothed stats:
(0.8958742632612967, 0.7283950617283951, 0.2651685393258427, 0.3887973640856672