In [1]:
from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import time
import os
import copy
from PIL import Image
from tqdm import tqdm
from scipy import stats

In [2]:
device = torch.device('cuda:0')

In [3]:
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}


In [4]:
class ImageDataset(Dataset):
    def __init__(self, paths, labels, transform=None, stride = 1, max_size=None):
        self.paths = paths[::stride]
        if max_size is not None:
            self.paths = self.paths[:max_size]
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.paths)

    def __getitem__(self, idx):
        img_name = self.paths[idx]
        img = Image.open(img_name).convert('RGB')
        
        img_tensor = self.transform(img)

        return img_tensor, self.labels[idx]

In [5]:
train_file = 'conversation_export/data/train.txt'
val_file = 'conversation_export/data/val.txt'
test_file = 'conversation_export/data/test.txt'

In [6]:
def read_file(filename):
    paths = []
    labels = []
    with open(filename, 'r') as f:
        for line in f.readlines():
            video, image, label = line.split(' ')
            paths.append('conversation_export/images/{}/{:04d}.jpg'.format(video, int(image)))
            labels.append(int(label))
            
    return paths, labels

In [7]:
train_paths, Y_train = read_file(train_file)
val_paths, Y_val = read_file(val_file)
test_paths, Y_test = read_file(test_file)

In [8]:
image_datasets = {
    'train': ImageDataset(train_paths, Y_train, transform = data_transforms['train'],
                          stride=1, max_size=None),
    'val_train': ImageDataset(val_paths, Y_val, transform = data_transforms['train'],
                       stride=1, max_size=None),
    'val': ImageDataset(val_paths, Y_val, transform = data_transforms['val'],
                       stride=1, max_size=None),
    'test': ImageDataset(test_paths, Y_test, transform = data_transforms['val'],
                        stride=1, max_size=None)
}

In [9]:
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x],
                                              batch_size=4,
                                              shuffle='train' in x,
                                              num_workers=4,
                                              pin_memory=True)
                                              for x in image_datasets}

In [10]:
dataset_sizes = {
    x: len(image_datasets[x])
    for x in image_datasets
}

In [11]:
dataset_sizes

{'train': 32410, 'val_train': 4710, 'val': 4710, 'test': 5230}

In [12]:
def safe_divide(a, b):
    return a / b if b > 0 else 0

In [13]:
def train_model(model, criterion, optimizer, scheduler, train_dl, val_dl, test_dl=None,
                num_epochs=25, return_best=False, verbose=True, log_file=None):
    print(train_dl, val_dl, test_dl)
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    best_epoch = 0
    best_test_acc = 0.0
    best_f1 = 0.0
    best_precision = 0.0
    best_recall = 0.0
    
    phases = ['train', 'val', 'test'] if test_dl is not None else ['train', 'val']

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        # Each epoch has a training and validation phase
        for phase in phases:
            if phase == 'train':
                scheduler.step()
                model.train()  # Set model to training mode
                dl = dataloaders[train_dl]
                dataset_size = dataset_sizes[train_dl]
                
            elif phase == 'val':
                model.eval()   # Set model to evaluate mode
                dl = dataloaders[val_dl]
                dataset_size = dataset_sizes[val_dl]
            else:
                model.eval()
                dl = dataloaders[test_dl]
                dataset_size = dataset_sizes[test_dl]

            running_loss = 0.0
            running_corrects = 0
            true_positives = 0.
            true_negatives = 0.
            false_positives = 0.
            false_negatives = 0.

            # Iterate over data.
            for inputs, labels in dl:
                inputs = inputs.to(device)
                labels = labels.to(device).float()

                # zero the parameter gradients
                optimizer.zero_grad()

                # forward
                # track history if only in train
                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    preds = torch.where(
                        outputs >= 0.,
                        torch.tensor([1.]).to(device),
                        torch.tensor([0.]).to(device))
                    target = torch.where(
                        labels >= 0.5,
                        torch.tensor([1.]).to(device),
                        torch.tensor([0.]).to(device)
                    )
                    loss = criterion(outputs.view(target.shape), target)

                    # backward + optimize only if in training phase
                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                # statistics
                running_loss += loss.item() * inputs.size(0)
                label_vals = torch.where(
                    labels >= 0.5,
                    torch.tensor([1.]).to(device),
                    torch.tensor([0.]).to(device)
                )
                correct = preds.view(label_vals.shape) == label_vals.data
                running_corrects += torch.sum(correct)
                
                true_positives += torch.sum(
                    torch.where(
                        (correct == 1.) * (label_vals == 1.),
                        torch.tensor([1.]).to(device),
                        torch.tensor([0.]).to(device))
                )
                true_negatives += torch.sum(
                    torch.where(
                        (correct == 1.) * (label_vals == 0.),
                        torch.tensor([1.]).to(device),
                        torch.tensor([0.]).to(device))
                )
                false_positives += torch.sum(
                    torch.where(
                        (correct == 0.) * (label_vals == 0.),
                        torch.tensor([1.]).to(device),
                        torch.tensor([0.]).to(device))
                )
                false_negatives += torch.sum(
                    torch.where(
                        (correct == 0.) * (label_vals == 1.),
                        torch.tensor([1.]).to(device),
                        torch.tensor([0.]).to(device))
                )
            
            epoch_loss = running_loss / dataset_size
            epoch_acc = running_corrects.double() / dataset_size
            epoch_pre = safe_divide(true_positives, (true_positives + false_positives))
            epoch_recall = safe_divide(true_positives, (true_positives + false_negatives))
            epoch_f1 = safe_divide(2 * epoch_pre * epoch_recall, (epoch_pre + epoch_recall))

            if verbose:
                print('{} Loss: {:.4f} Acc: {:.4f} Pre: {:.4f} Rec: {:.4f} F1: {:.4f}'.format(
                    phase, epoch_loss, epoch_acc, epoch_pre, epoch_recall, epoch_f1))
                print('TP: {} TN: {} FP: {} FN: {}'.format(
                    true_positives.data, true_negatives.data, false_positives.data, false_negatives.data))
            if log_file is not None:
                log_file.write('Phase: {0}\t'
                               'Epoch: [{1}/{2}]\t'
                               'Loss: {loss_c:.4f}\t'
                               'Acc: {acc:.4f}\t'
                               'Pre: {pre:.4f}\t'
                               'Rec: {rec:.4f}\t'
                               'F1: {f1:.4f}\t'
                               'TP: {tp} '
                               'TN: {tn} '
                               'FP: {fp} '
                               'FN: {fn}\n'.format(
                                   phase, epoch + 1, num_epochs, loss_c=epoch_loss,
                                   acc=epoch_acc, pre=epoch_pre, rec=epoch_recall,
                                   f1=epoch_f1, tp=int(true_positives.data), tn=int(true_negatives.data),
                                   fp=int(false_positives.data), fn=int(false_negatives.data)
                               ))
                log_file.flush()

            # deep copy the model
            if phase == 'val' and epoch_f1 > best_f1:
                best_acc = epoch_acc
                best_f1 = epoch_f1
                best_precision = epoch_pre
                best_recall = epoch_recall
                best_epoch = epoch
                best_model_wts = copy.deepcopy(model.state_dict())
            if phase == 'test' and best_epoch == epoch:
                best_test_acc = epoch_acc

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    
    if return_best:
        print('Best epoch: {}'.format(best_epoch))
        print('Best val Acc: {:4f}'.format(best_acc))
        print('Best val Pre: {:4f}'.format(best_precision))
        print('Best val Rec: {:4f}'.format(best_recall))
        print('Best val F1: {:4f}'.format(best_f1))
        print('Test Acc: {:4f}'.format(best_test_acc))

        # load best model weights
        model.load_state_dict(best_model_wts)
    return model

In [14]:
path = 'models/transfer_learning_tutorial'
for seed in range(5):
    torch.manual_seed(seed)
    model_ts = models.resnet50(pretrained=True)
    num_ftrs = model_ts.fc.in_features
    model_ts.fc = nn.Linear(num_ftrs, 1)

    model_ts = model_ts.to(device)

#     criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([10.]).to(device))
#     criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([1.]).to(device))
#     criterion = nn.CrossEntropyLoss()
    criterion = nn.BCEWithLogitsLoss().to(device)

    # Observe that all parameters are being optimized
    optimizer_ts = optim.SGD(model_ts.parameters(), lr=0.001, momentum=0.9)

    # Decay LR by a factor of 0.1 every 7 epochs
    exp_lr_scheduler_ts = lr_scheduler.StepLR(optimizer_ts, step_size=7, gamma=0.1)
    
    if not os.path.exists(path):
        os.makedirs(path)
    with open(os.path.join(path, 'seed_{}.log'.format(seed)), 'w') as log_file:
        model_ts = train_model(model_ts, criterion, optimizer_ts, exp_lr_scheduler_ts,
                               'val_train', 'val', test_dl='test', num_epochs=25,
                               verbose=True, log_file=log_file, return_best=False)
        torch.save(model_ts.state_dict(), os.path.join(path, 'seed_{}.pth'.format(seed)))

val_train val test
Epoch 0/24
----------
train Loss: 0.6486 Acc: 0.6726 Pre: 0.7286 Rec: 0.8032 F1: 0.7641
TP: 2497.0 TN: 671.0 FP: 930.0 FN: 612.0
val Loss: 0.6799 Acc: 0.7187 Pre: 0.7024 Rec: 0.9958 F1: 0.8237
TP: 3096.0 TN: 289.0 FP: 1312.0 FN: 13.0
test Loss: 1.0372 Acc: 0.6310 Pre: 0.6261 Rec: 0.9975 F1: 0.7694
TP: 3219.0 TN: 81.0 FP: 1922.0 FN: 8.0

Epoch 1/24
----------
train Loss: 0.5703 Acc: 0.7225 Pre: 0.7712 Rec: 0.8241 F1: 0.7968
TP: 2562.0 TN: 841.0 FP: 760.0 FN: 547.0
val Loss: 0.4794 Acc: 0.7760 Pre: 0.9161 Rec: 0.7272 F1: 0.8108
TP: 2261.0 TN: 1394.0 FP: 207.0 FN: 848.0
test Loss: 0.8110 Acc: 0.5736 Pre: 0.7151 Rec: 0.5135 F1: 0.5978
TP: 1657.0 TN: 1343.0 FP: 660.0 FN: 1570.0

Epoch 2/24
----------
train Loss: 0.5274 Acc: 0.7529 Pre: 0.7962 Rec: 0.8408 F1: 0.8179
TP: 2614.0 TN: 932.0 FP: 669.0 FN: 495.0
val Loss: 0.3528 Acc: 0.8669 Pre: 0.8733 Rec: 0.9337 F1: 0.9025
TP: 2903.0 TN: 1180.0 FP: 421.0 FN: 206.0
test Loss: 1.0453 Acc: 0.6141 Pre: 0.6998 Rec: 0.6560 F1: 0.677

train Loss: 0.2654 Acc: 0.8930 Pre: 0.9069 Rec: 0.9337 F1: 0.9201
TP: 2903.0 TN: 1303.0 FP: 298.0 FN: 206.0
val Loss: 0.1185 Acc: 0.9590 Pre: 0.9655 Rec: 0.9727 F1: 0.9691
TP: 3024.0 TN: 1493.0 FP: 108.0 FN: 85.0
test Loss: 0.9156 Acc: 0.5990 Pre: 0.7280 Rec: 0.5590 F1: 0.6324
TP: 1804.0 TN: 1329.0 FP: 674.0 FN: 1423.0

Training complete in 119m 4s
val_train val test
Epoch 0/24
----------
train Loss: 0.6365 Acc: 0.6864 Pre: 0.7326 Rec: 0.8266 F1: 0.7768
TP: 2570.0 TN: 663.0 FP: 938.0 FN: 539.0
val Loss: 0.4078 Acc: 0.8166 Pre: 0.8245 Rec: 0.9173 F1: 0.8685
TP: 2852.0 TN: 994.0 FP: 607.0 FN: 257.0
test Loss: 0.6344 Acc: 0.6728 Pre: 0.6944 Rec: 0.8392 F1: 0.7599
TP: 2708.0 TN: 811.0 FP: 1192.0 FN: 519.0

Epoch 1/24
----------
train Loss: 0.5569 Acc: 0.7312 Pre: 0.7773 Rec: 0.8308 F1: 0.8032
TP: 2583.0 TN: 861.0 FP: 740.0 FN: 526.0
val Loss: 0.3671 Acc: 0.8369 Pre: 0.8862 Rec: 0.8639 F1: 0.8749
TP: 2686.0 TN: 1256.0 FP: 345.0 FN: 423.0
test Loss: 0.7037 Acc: 0.6757 Pre: 0.7257 Rec: 0.7626

train Loss: 0.2979 Acc: 0.8775 Pre: 0.8984 Rec: 0.9183 F1: 0.9082
TP: 2855.0 TN: 1278.0 FP: 323.0 FN: 254.0
val Loss: 0.1161 Acc: 0.9603 Pre: 0.9600 Rec: 0.9807 F1: 0.9702
TP: 3049.0 TN: 1474.0 FP: 127.0 FN: 60.0
test Loss: 0.8857 Acc: 0.6159 Pre: 0.7056 Rec: 0.6477 F1: 0.6754
TP: 2090.0 TN: 1131.0 FP: 872.0 FN: 1137.0

Epoch 24/24
----------
train Loss: 0.2882 Acc: 0.8883 Pre: 0.9047 Rec: 0.9286 F1: 0.9165
TP: 2887.0 TN: 1297.0 FP: 304.0 FN: 222.0
val Loss: 0.1155 Acc: 0.9588 Pre: 0.9715 Rec: 0.9659 F1: 0.9687
TP: 3003.0 TN: 1513.0 FP: 88.0 FN: 106.0
test Loss: 0.9364 Acc: 0.5897 Pre: 0.7276 Rec: 0.5355 F1: 0.6169
TP: 1728.0 TN: 1356.0 FP: 647.0 FN: 1499.0

Training complete in 91m 43s
val_train val test
Epoch 0/24
----------
train Loss: 0.6247 Acc: 0.6815 Pre: 0.7323 Rec: 0.8157 F1: 0.7718
TP: 2536.0 TN: 674.0 FP: 927.0 FN: 573.0
val Loss: 0.5190 Acc: 0.7318 Pre: 0.8320 Rec: 0.7440 F1: 0.7855
TP: 2313.0 TN: 1134.0 FP: 467.0 FN: 796.0
test Loss: 0.7055 Acc: 0.6434 Pre: 0.7077 Rec: 0.7

train Loss: 0.2887 Acc: 0.8834 Pre: 0.9000 Rec: 0.9263 F1: 0.9130
TP: 2880.0 TN: 1281.0 FP: 320.0 FN: 229.0
val Loss: 0.1402 Acc: 0.9552 Pre: 0.9770 Rec: 0.9546 F1: 0.9657
TP: 2968.0 TN: 1531.0 FP: 70.0 FN: 141.0
test Loss: 0.9149 Acc: 0.5943 Pre: 0.7454 Rec: 0.5200 F1: 0.6126
TP: 1678.0 TN: 1430.0 FP: 573.0 FN: 1549.0

Epoch 23/24
----------
train Loss: 0.2900 Acc: 0.8849 Pre: 0.8963 Rec: 0.9337 F1: 0.9146
TP: 2903.0 TN: 1265.0 FP: 336.0 FN: 206.0
val Loss: 0.1314 Acc: 0.9565 Pre: 0.9730 Rec: 0.9608 F1: 0.9668
TP: 2987.0 TN: 1518.0 FP: 83.0 FN: 122.0
test Loss: 0.8505 Acc: 0.6052 Pre: 0.7360 Rec: 0.5615 F1: 0.6370
TP: 1812.0 TN: 1353.0 FP: 650.0 FN: 1415.0

Epoch 24/24
----------
train Loss: 0.2980 Acc: 0.8786 Pre: 0.8909 Rec: 0.9299 F1: 0.9100
TP: 2891.0 TN: 1247.0 FP: 354.0 FN: 218.0
val Loss: 0.1237 Acc: 0.9561 Pre: 0.9624 Rec: 0.9714 F1: 0.9669
TP: 3020.0 TN: 1483.0 FP: 118.0 FN: 89.0
test Loss: 0.7886 Acc: 0.6400 Pre: 0.7264 Rec: 0.6681 F1: 0.6960
TP: 2156.0 TN: 1191.0 FP: 812.0 

train Loss: 0.2894 Acc: 0.8832 Pre: 0.9005 Rec: 0.9254 F1: 0.9128
TP: 2877.0 TN: 1283.0 FP: 318.0 FN: 232.0
val Loss: 0.1495 Acc: 0.9418 Pre: 0.9742 Rec: 0.9366 F1: 0.9551
TP: 2912.0 TN: 1524.0 FP: 77.0 FN: 197.0
test Loss: 0.8610 Acc: 0.6000 Pre: 0.7232 Rec: 0.5699 F1: 0.6374
TP: 1839.0 TN: 1299.0 FP: 704.0 FN: 1388.0

Epoch 22/24
----------
train Loss: 0.2879 Acc: 0.8851 Pre: 0.8983 Rec: 0.9315 F1: 0.9146
TP: 2896.0 TN: 1273.0 FP: 328.0 FN: 213.0
val Loss: 0.1435 Acc: 0.9463 Pre: 0.9789 Rec: 0.9389 F1: 0.9585
TP: 2919.0 TN: 1538.0 FP: 63.0 FN: 190.0
test Loss: 0.9617 Acc: 0.5665 Pre: 0.7353 Rec: 0.4648 F1: 0.5696
TP: 1500.0 TN: 1463.0 FP: 540.0 FN: 1727.0

Epoch 23/24
----------
train Loss: 0.2806 Acc: 0.8856 Pre: 0.8971 Rec: 0.9337 F1: 0.9151
TP: 2903.0 TN: 1268.0 FP: 333.0 FN: 206.0
val Loss: 0.3397 Acc: 0.9161 Pre: 0.9572 Rec: 0.9138 F1: 0.9350
TP: 2841.0 TN: 1474.0 FP: 127.0 FN: 268.0
test Loss: 0.8461 Acc: 0.6424 Pre: 0.7168 Rec: 0.6951 F1: 0.7058
TP: 2243.0 TN: 1117.0 FP: 886.0

train Loss: 0.2848 Acc: 0.8890 Pre: 0.9026 Rec: 0.9325 F1: 0.9173
TP: 2899.0 TN: 1288.0 FP: 313.0 FN: 210.0
val Loss: 0.1297 Acc: 0.9524 Pre: 0.9555 Rec: 0.9733 F1: 0.9643
TP: 3026.0 TN: 1460.0 FP: 141.0 FN: 83.0
test Loss: 0.8602 Acc: 0.6463 Pre: 0.6952 Rec: 0.7598 F1: 0.7261
TP: 2452.0 TN: 928.0 FP: 1075.0 FN: 775.0

Epoch 21/24
----------
train Loss: 0.2877 Acc: 0.8796 Pre: 0.8925 Rec: 0.9296 F1: 0.9107
TP: 2890.0 TN: 1253.0 FP: 348.0 FN: 219.0
val Loss: 0.1316 Acc: 0.9539 Pre: 0.9635 Rec: 0.9669 F1: 0.9652
TP: 3006.0 TN: 1487.0 FP: 114.0 FN: 103.0
test Loss: 0.8376 Acc: 0.6287 Pre: 0.6980 Rec: 0.7019 F1: 0.6999
TP: 2265.0 TN: 1023.0 FP: 980.0 FN: 962.0

Epoch 22/24
----------
train Loss: 0.2891 Acc: 0.8837 Pre: 0.8993 Rec: 0.9276 F1: 0.9132
TP: 2884.0 TN: 1278.0 FP: 323.0 FN: 225.0
val Loss: 0.1409 Acc: 0.9514 Pre: 0.9728 Rec: 0.9530 F1: 0.9628
TP: 2963.0 TN: 1518.0 FP: 83.0 FN: 146.0
test Loss: 0.8444 Acc: 0.6163 Pre: 0.7220 Rec: 0.6148 F1: 0.6641
TP: 1984.0 TN: 1239.0 FP: 764.0 F

# Smooth Results

In [15]:
def pos_negs(predictions, gt):
    correct = np.where(np.array(predictions) == np.array(gt), 1, 0)
    incorrect = np.where(np.array(predictions) == np.array(gt), 0, 1)
    
    tp = np.where(correct * np.where(predictions == np.array(1), 1, 0), 1, 0)
    tn = np.where(correct * np.where(predictions == np.array(0), 1, 0), 1, 0)
    fp = np.where(incorrect * np.where(predictions == np.array(1), 1, 0), 1, 0)
    fn = np.where(incorrect * np.where(predictions == np.array(0), 1, 0), 1, 0)
    
    return tp, tn, fp, fn

def acc_prf1(predictions, gt):
    tp, tn, fp, fn = pos_negs(predictions, gt)
    
    acc = (np.sum(tp) + np.sum(tn)) / (np.sum(tp) + np.sum(tn) + np.sum(fp) + np.sum(fn))
    precision = np.sum(tp) / (np.sum(tp) + np.sum(fp))
    recall = np.sum(tp) / (np.sum(tp) + np.sum(fn))
    f1 = 2 * precision * recall / (precision + recall)
    
    return acc, precision, recall, f1, np.sum(tp), np.sum(tn), np.sum(fp), np.sum(fn)

def smooth_predictions(preds, window_radius = 3):
    result = []
    for i in range(len(preds)):
        start = max(0, i - window_radius)
        end = min(len(preds), i + window_radius)
        window = preds[start:end]
        result += [max(window, key=window.count)]
    
    return result

In [None]:
path = 'models/transfer_learning_tutorial'

log_file_img = open(os.path.join(
    path, 'test_results_image_classifier.log'), 'w')
log_file_smoothed = open(os.path.join(
    path, 'test_results_smoothed.log'), 'w')

for seed in range(5):
    print(seed)
    model_ts = models.resnet50(pretrained=True)
    num_ftrs = model_ts.fc.in_features
    model_ts.fc = nn.Linear(num_ftrs, 1)
    model_ts.load_state_dict(torch.load(
        os.path.join(path, 'seed_{}.pth'.format(seed))
    ))
    model_ts.to(device)
    criterion = nn.BCEWithLogitsLoss()
    
    model = model_ts.eval()   # Set model to evaluate mode
    dl = dataloaders['test']
    dataset_size = dataset_sizes['test']
    
    running_corrects = 0
    true_positives = 0.
    true_negatives = 0.
    false_positives = 0.
    false_negatives = 0.

    i = 0

    predictions = []
    gt_labels = []
    # Iterate over data.
    for inputs, labels in dl:
        inputs = inputs.to(device)
        labels = labels.to(device).float()

        # forward
        # track history if only in train
        with torch.set_grad_enabled(False):
            outputs = model(inputs)
            preds = torch.where(
                outputs >= 0.,
                torch.tensor([1.]).to(device),
                torch.tensor([0.]).to(device))
            target = torch.where(
                labels >= 0.5,
                torch.tensor([1.]).to(device),
                torch.tensor([0.]).to(device)
            )

        predictions += preds.cpu().numpy().tolist()
        gt_labels += labels.cpu().numpy().tolist()

        # statistics
        label_vals = torch.where(
            labels >= 0.5,
            torch.tensor([1.]).to(device),
            torch.tensor([0.]).to(device)
        )
        correct = preds.view(label_vals.shape) == label_vals.data
        running_corrects += torch.sum(correct)

        true_positives += torch.sum(
            torch.where(
                (correct == 1.) * (label_vals == 1.),
                torch.tensor([1.]).to(device),
                torch.tensor([0.]).to(device))
        )
        true_negatives += torch.sum(
            torch.where(
                (correct == 1.) * (label_vals == 0.),
                torch.tensor([1.]).to(device),
                torch.tensor([0.]).to(device))
        )
        false_positives += torch.sum(
            torch.where(
                (correct == 0.) * (label_vals == 0.),
                torch.tensor([1.]).to(device),
                torch.tensor([0.]).to(device))
        )
        false_negatives += torch.sum(
            torch.where(
                (correct == 0.) * (label_vals == 1.),
                torch.tensor([1.]).to(device),
                torch.tensor([0.]).to(device))
        )

        num_fp = torch.sum(
        torch.where(
            (correct == 0.) * (label_vals == 0.),
            torch.tensor([1.]).to(device),
            torch.tensor([0.]).to(device))
        )

    #     if num_fp > 0:
    #         print(num_fp)
    #         print(torch.where(
    #             (correct == 0.) * (label_vals == 0.),
    #             torch.tensor([1.]).to(device),
    #             torch.tensor([0.]).to(device)))
    #         print(i)
    #         out = torchvision.utils.make_grid(inputs)
    # #         imshow(out, title=preds.cpu().numpy().tolist())
    #         imshow(out)

        i += 1

    epoch_acc = running_corrects.double() / dataset_size
    epoch_pre = safe_divide(true_positives, (true_positives + false_positives))
    epoch_recall = safe_divide(true_positives, (true_positives + false_negatives))
    epoch_f1 = safe_divide(2 * epoch_pre * epoch_recall, (epoch_pre + epoch_recall))

    print('Acc: {:.4f} Pre: {:.4f} Rec: {:.4f} F1: {:.4f}'.format(
        epoch_acc, epoch_pre, epoch_recall, epoch_f1))
    print('TP: {} TN: {} FP: {} FN: {}'.format(
        true_positives.data, true_negatives.data, false_positives.data, 
        false_negatives.data))
    
    log_file_img.write('Seed: {0}\t'
                   'Acc: {acc:.4f}\t'
                   'Pre: {pre:.4f}\t'
                   'Rec: {rec:.4f}\t'
                   'F1: {f1:.4f}\t'
                   'TP: {tp} '
                   'TN: {tn} '
                   'FP: {fp} '
                   'FN: {fn}\n'.format(
                       seed,
                       acc=epoch_acc, pre=epoch_pre, rec=epoch_recall,
                       f1=epoch_f1, tp=int(true_positives.data),
                       tn=int(true_negatives.data),
                       fp=int(false_positives.data), fn=int(false_negatives.data)
                   ))
    log_file_img.flush()

    predictions = [p[0] for p in predictions]

    smoothed_preds = smooth_predictions(predictions, 3)

    print("Smoothed stats:")
    print(acc_prf1(smoothed_preds, gt_labels))
    
    sm_acc, sm_pre, sm_rec, sm_f1, sm_tp, sm_tn, sm_fp, sm_fn = acc_prf1(
        smoothed_preds, gt_labels)
    
    log_file_smoothed.write('Seed: {0}\t'
                   'Acc: {acc:.4f}\t'
                   'Pre: {pre:.4f}\t'
                   'Rec: {rec:.4f}\t'
                   'F1: {f1:.4f}\t'
                   'TP: {tp} '
                   'TN: {tn} '
                   'FP: {fp} '
                   'FN: {fn}\n'.format(
                       seed,
                       acc=sm_acc, pre=sm_pre, rec=sm_rec,
                       f1=sm_f1, tp=sm_tp,
                       tn=sm_tn,
                       fp=sm_fp, fn=sm_fn
                   ))
    log_file_smoothed.flush()

# Old

In [16]:
path = 'models/transfer_learning_tutorial'
model_ts = models.resnet50(pretrained=True)
num_ftrs = model_ts.fc.in_features
model_ts.fc = nn.Linear(num_ftrs, 1)
model_ts.load_state_dict(torch.load(
    os.path.join(path, 'seed_0.pth')
))
model_ts.to(device)
criterion = nn.BCEWithLogitsLoss()

In [17]:
model = model_ts.eval()   # Set model to evaluate mode
dl = dataloaders['test']
dataset_size = dataset_sizes['test']

In [40]:
def pos_negs(predictions, gt):
    correct = np.where(np.array(predictions) == np.array(gt), 1, 0)
    incorrect = np.where(np.array(predictions) == np.array(gt), 0, 1)
    
    tp = np.where(correct * np.where(predictions == np.array(1), 1, 0), 1, 0)
    tn = np.where(correct * np.where(predictions == np.array(0), 1, 0), 1, 0)
    fp = np.where(incorrect * np.where(predictions == np.array(1), 1, 0), 1, 0)
    fn = np.where(incorrect * np.where(predictions == np.array(0), 1, 0), 1, 0)
    
    return tp, tn, fp, fn

def acc_prf1(predictions, gt):
    tp, tn, fp, fn = pos_negs(predictions, gt)
    
    acc = (np.sum(tp) + np.sum(tn)) / (np.sum(tp) + np.sum(tn) + np.sum(fp) + np.sum(fn))
    precision = np.sum(tp) / (np.sum(tp) + np.sum(fp))
    recall = np.sum(tp) / (np.sum(tp) + np.sum(fn))
    f1 = 2 * precision * recall / (precision + recall)
    
    return acc, precision, recall, f1, np.sum(tp), np.sum(tn), np.sum(fp), np.sum(fn)

def smooth_predictions(preds, window_radius = 3):
    result = []
    for i in range(len(preds)):
        start = max(0, i - window_radius)
        end = min(len(preds), i + window_radius)
        window = preds[start:end]
        result += max(window, key=window.count)
    
    return result

In [19]:
running_corrects = 0
true_positives = 0.
true_negatives = 0.
false_positives = 0.
false_negatives = 0.

i = 0

predictions = []
gt_labels = []
# Iterate over data.
for inputs, labels in dl:
    inputs = inputs.to(device)
    labels = labels.to(device).float()

    # forward
    # track history if only in train
    with torch.set_grad_enabled(False):
        outputs = model(inputs)
        preds = torch.where(
            outputs >= 0.,
            torch.tensor([1.]).to(device),
            torch.tensor([0.]).to(device))
        target = torch.where(
            labels >= 0.5,
            torch.tensor([1.]).to(device),
            torch.tensor([0.]).to(device)
        )
        
    predictions += preds.cpu().numpy().tolist()
    gt_labels += labels.cpu().numpy().tolist()

    # statistics
    label_vals = torch.where(
        labels >= 0.5,
        torch.tensor([1.]).to(device),
        torch.tensor([0.]).to(device)
    )
    correct = preds.view(label_vals.shape) == label_vals.data
    running_corrects += torch.sum(correct)

    true_positives += torch.sum(
        torch.where(
            (correct == 1.) * (label_vals == 1.),
            torch.tensor([1.]).to(device),
            torch.tensor([0.]).to(device))
    )
    true_negatives += torch.sum(
        torch.where(
            (correct == 1.) * (label_vals == 0.),
            torch.tensor([1.]).to(device),
            torch.tensor([0.]).to(device))
    )
    false_positives += torch.sum(
        torch.where(
            (correct == 0.) * (label_vals == 0.),
            torch.tensor([1.]).to(device),
            torch.tensor([0.]).to(device))
    )
    false_negatives += torch.sum(
        torch.where(
            (correct == 0.) * (label_vals == 1.),
            torch.tensor([1.]).to(device),
            torch.tensor([0.]).to(device))
    )
    
    num_fp = torch.sum(
    torch.where(
        (correct == 0.) * (label_vals == 0.),
        torch.tensor([1.]).to(device),
        torch.tensor([0.]).to(device))
    )

#     if num_fp > 0:
#         print(num_fp)
#         print(torch.where(
#             (correct == 0.) * (label_vals == 0.),
#             torch.tensor([1.]).to(device),
#             torch.tensor([0.]).to(device)))
#         print(i)
#         out = torchvision.utils.make_grid(inputs)
# #         imshow(out, title=preds.cpu().numpy().tolist())
#         imshow(out)
    
    i += 1

epoch_acc = running_corrects.double() / dataset_size
epoch_pre = safe_divide(true_positives, (true_positives + false_positives))
epoch_recall = safe_divide(true_positives, (true_positives + false_negatives))
epoch_f1 = safe_divide(2 * epoch_pre * epoch_recall, (epoch_pre + epoch_recall))

print('Acc: {:.4f} Pre: {:.4f} Rec: {:.4f} F1: {:.4f}'.format(
    epoch_acc, epoch_pre, epoch_recall, epoch_f1))
print('TP: {} TN: {} FP: {} FN: {}'.format(
    true_positives.data, true_negatives.data, false_positives.data, false_negatives.data))

Acc: 0.6117 Pre: 0.7095 Rec: 0.6275 F1: 0.6660
TP: 2025.0 TN: 1174.0 FP: 829.0 FN: 1202.0


In [41]:
smoothed_preds = smooth_predictions(predictions, 3)

print("Smoothed stats:")
print(acc_prf1(smoothed_preds, gt_labels))

Smoothed stats:
(0.6252390057361377, 0.7205012182387748, 0.641462658816238, 0.6786885245901639, 2070, 1200, 803, 1157)
