In [1]:
from __future__ import print_function, division

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import numpy as np
import torchvision
from torchvision import datasets, models, transforms
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import time
import os
import copy
from PIL import Image
from tqdm import tqdm
from scipy import stats

In [2]:
# Prefer the first CUDA GPU, but fall back to the CPU so the notebook still
# runs (slowly) on machines without a usable GPU instead of failing at the
# first `.to(device)` call.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [3]:
# Per-channel mean / std used to normalise every image (the standard ImageNet
# statistics).
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Training pipeline: random crop + random horizontal flip as augmentation.
_train_pipeline = transforms.Compose([
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])

# Evaluation pipeline: deterministic resize + centre crop, no augmentation.
_eval_pipeline = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
])

data_transforms = {'train': _train_pipeline, 'val': _eval_pipeline}


In [4]:
class ImageDataset(Dataset):
    """Map-style dataset that loads RGB images from disk together with their labels.

    Args:
        paths: sequence of image file paths.
        labels: sequence of labels aligned index-for-index with ``paths``.
        transform: optional callable applied to the loaded PIL image. When
            ``None`` the PIL image is returned unchanged.
        stride: keep every ``stride``-th sample.
        max_size: optional cap on the number of samples kept after striding.
    """

    def __init__(self, paths, labels, transform=None, stride = 1, max_size=None):
        # Sub-sample the labels exactly like the paths; otherwise any
        # stride != 1 would pair image i with the label of a different image.
        self.paths = paths[::stride]
        self.labels = labels[::stride]
        if max_size is not None:
            self.paths = self.paths[:max_size]
            self.labels = self.labels[:max_size]
        self.transform = transform

    def __len__(self):
        """Number of samples kept after striding / truncation."""
        return len(self.paths)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for sample ``idx``."""
        # The context manager releases the file handle promptly; convert()
        # forces the pixel data to be read and returns an independent image.
        with Image.open(self.paths[idx]) as raw:
            img = raw.convert('RGB')

        if self.transform is not None:
            img = self.transform(img)

        return img, self.labels[idx]

In [5]:
# Split manifests: one text file per split, parsed below by read_file().
train_file = 'conversation_export/data/train.txt'
val_file = 'conversation_export/data/val.txt'
test_file = 'conversation_export/data/test.txt'

In [6]:
def read_file(filename, image_root='conversation_export/images'):
    """Parse a split manifest into parallel lists of image paths and labels.

    Every non-blank line must contain three whitespace-separated fields,
    ``<video> <image index> <label>``. The image path is built as
    ``<image_root>/<video>/<image index zero-padded to 4 digits>.jpg``.

    Args:
        filename: path of the manifest file to read.
        image_root: directory that contains one sub-directory per video.

    Returns:
        ``(paths, labels)``: a list of image path strings and a list of
        integer labels, in file order.

    Raises:
        ValueError: if a non-blank line does not have exactly three fields,
            or the image index / label is not an integer.
    """
    paths = []
    labels = []
    with open(filename, 'r') as f:
        # Iterate lazily instead of materialising the file with readlines().
        for line_no, line in enumerate(f, 1):
            fields = line.split()
            if not fields:
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                continue
            if len(fields) != 3:
                raise ValueError(
                    '{}:{}: expected "<video> <image> <label>", got {!r}'.format(
                        filename, line_no, line))
            video, image, label = fields
            paths.append('{}/{}/{:04d}.jpg'.format(image_root, video, int(image)))
            labels.append(int(label))

    return paths, labels

In [7]:
# Parse the three split manifests, in train / val / test order, into
# (image paths, labels) pairs.
(train_paths, Y_train), (val_paths, Y_val), (test_paths, Y_test) = (
    read_file(manifest) for manifest in (train_file, val_file, test_file)
)

In [15]:
# One dataset per split. 'val_train' reuses the validation images but runs
# them through the training-time augmentation pipeline.
image_datasets = {
    split: ImageDataset(split_paths, split_labels,
                        transform=data_transforms[pipeline],
                        stride=1, max_size=None)
    for split, split_paths, split_labels, pipeline in (
        ('train', train_paths, Y_train, 'train'),
        ('val_train', val_paths, Y_val, 'train'),
        ('val', val_paths, Y_val, 'val'),
        ('test', test_paths, Y_test, 'val'),
    )
}

In [16]:
# One loader per split; only splits whose name contains 'train'
# ('train', 'val_train') are shuffled.
dataloaders = {
    split: torch.utils.data.DataLoader(
        dataset,
        batch_size=4,
        shuffle=('train' in split),
        num_workers=4,
        pin_memory=True,
    )
    for split, dataset in image_datasets.items()
}

In [17]:
# Number of samples in each split, keyed like image_datasets.
dataset_sizes = {split: len(dataset) for split, dataset in image_datasets.items()}

In [18]:
# Display the per-split sample counts as a quick sanity check.
dataset_sizes

{'train': 3673, 'val_train': 550, 'val': 550, 'test': 657}

In [19]:
def safe_divide(a, b):
    """Return ``a / b``, or ``0`` whenever ``b`` is not strictly positive."""
    if b > 0:
        return a / b
    return 0

In [20]:
def train_model(model, criterion, optimizer, scheduler, train_dl, val_dl, test_dl=None,
                num_epochs=25, return_best=False, verbose=True, log_file=None):
    """Train a single-logit binary classifier and report per-epoch metrics.

    Each epoch runs a training pass, a validation pass and, when ``test_dl``
    is given, a test pass. For every pass the loss, accuracy, precision,
    recall, F1 and the confusion counts are computed; a logit ``>= 0`` counts
    as a positive prediction and a label ``>= 0.5`` as a positive target.
    The epoch with the highest validation F1 is remembered.

    Relies on the notebook-level globals ``dataloaders`` (split name ->
    DataLoader), ``device`` and ``safe_divide``.

    Args:
        model: network whose output can be viewed with the shape of the labels.
        criterion: loss called as ``criterion(logits, target)`` with float
            0/1 targets.
        optimizer: optimizer updating ``model``'s parameters.
        scheduler: learning-rate scheduler, stepped once per epoch *after*
            the training pass.
        train_dl: key in ``dataloaders`` used for the training pass.
        val_dl: key in ``dataloaders`` used for the validation pass.
        test_dl: optional key in ``dataloaders`` used for the test pass.
        num_epochs: number of epochs to run.
        return_best: when true, print a summary of the best validation epoch
            and load that epoch's weights into ``model`` before returning.
        verbose: when true, print the metrics of every pass.
        log_file: optional writable text file; one tab-separated line is
            written (and flushed) per pass.

    Returns:
        ``model`` with the last epoch's weights, or the best-validation-F1
        weights when ``return_best`` is true.
    """
    # Echo which splits are in use (kept so the notebook output is unchanged).
    print(train_dl, val_dl, test_dl)
    since = time.time()

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    best_epoch = 0
    best_test_acc = 0.0
    best_f1 = 0.0
    best_precision = 0.0
    best_recall = 0.0

    # (phase name, key into the global `dataloaders`) in execution order.
    phase_splits = [('train', train_dl), ('val', val_dl)]
    if test_dl is not None:
        phase_splits.append(('test', test_dl))

    for epoch in range(num_epochs):
        print('Epoch {}/{}'.format(epoch, num_epochs - 1))
        print('-' * 10)

        for phase, split in phase_splits:
            training = phase == 'train'
            if training:
                model.train()  # enable dropout / batch-norm updates
            else:
                model.eval()   # freeze dropout / batch-norm statistics
            dl = dataloaders[split]

            running_loss = 0.0
            # Count the samples actually iterated instead of trusting a
            # separately maintained size table.
            samples_seen = 0
            true_positives = 0.
            true_negatives = 0.
            false_positives = 0.
            false_negatives = 0.

            for inputs, labels in dl:
                inputs = inputs.to(device)
                labels = labels.to(device).float()
                # Binarise the labels: anything >= 0.5 is the positive class.
                target = (labels >= 0.5).float()

                if training:
                    optimizer.zero_grad()

                # Only build the autograd graph while training.
                with torch.set_grad_enabled(training):
                    logits = model(inputs).view(target.shape)
                    loss = criterion(logits, target)

                    if training:
                        loss.backward()
                        optimizer.step()

                batch_size = inputs.size(0)
                # criterion returns a per-batch mean; re-weight by batch size
                # so the epoch average is per-sample.
                running_loss += loss.item() * batch_size
                samples_seen += batch_size

                # A logit >= 0 corresponds to a sigmoid probability >= 0.5.
                preds = (logits.detach() >= 0.).float()
                pred_pos, pred_neg = preds == 1., preds == 0.
                actual_pos, actual_neg = target == 1., target == 0.
                true_positives += float((pred_pos & actual_pos).sum().item())
                true_negatives += float((pred_neg & actual_neg).sum().item())
                false_positives += float((pred_pos & actual_neg).sum().item())
                false_negatives += float((pred_neg & actual_pos).sum().item())

            if training:
                # Step the schedule only after this epoch's optimizer updates.
                # Stepping first (as before) makes PyTorch >= 1.1 skip the
                # first value of the schedule, i.e. decay one epoch early.
                scheduler.step()

            epoch_loss = safe_divide(running_loss, samples_seen)
            epoch_acc = safe_divide(true_positives + true_negatives, samples_seen)
            epoch_pre = safe_divide(true_positives, (true_positives + false_positives))
            epoch_recall = safe_divide(true_positives, (true_positives + false_negatives))
            epoch_f1 = safe_divide(2 * epoch_pre * epoch_recall, (epoch_pre + epoch_recall))

            if verbose:
                print('{} Loss: {:.4f} Acc: {:.4f} Pre: {:.4f} Rec: {:.4f} F1: {:.4f}'.format(
                    phase, epoch_loss, epoch_acc, epoch_pre, epoch_recall, epoch_f1))
                print('TP: {} TN: {} FP: {} FN: {}'.format(
                    true_positives, true_negatives, false_positives, false_negatives))
            if log_file is not None:
                log_file.write('Phase: {0}\t'
                               'Epoch: [{1}/{2}]\t'
                               'Loss: {loss_c:.4f}\t'
                               'Acc: {acc:.4f}\t'
                               'Pre: {pre:.4f}\t'
                               'Rec: {rec:.4f}\t'
                               'F1: {f1:.4f}\t'
                               'TP: {tp} '
                               'TN: {tn} '
                               'FP: {fp} '
                               'FN: {fn}\n'.format(
                                   phase, epoch + 1, num_epochs, loss_c=epoch_loss,
                                   acc=epoch_acc, pre=epoch_pre, rec=epoch_recall,
                                   f1=epoch_f1, tp=int(true_positives), tn=int(true_negatives),
                                   fp=int(false_positives), fn=int(false_negatives)
                               ))
                # Flush so the log can be followed while training is running.
                log_file.flush()

            # Snapshot the weights whenever validation F1 strictly improves.
            if phase == 'val' and epoch_f1 > best_f1:
                best_acc = epoch_acc
                best_f1 = epoch_f1
                best_precision = epoch_pre
                best_recall = epoch_recall
                best_epoch = epoch
                best_model_wts = copy.deepcopy(model.state_dict())
            # The test pass runs after validation, so this records the test
            # accuracy of the epoch that just became the best one.
            if phase == 'test' and best_epoch == epoch:
                best_test_acc = epoch_acc

        print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))

    if return_best:
        print('Best epoch: {}'.format(best_epoch))
        print('Best val Acc: {:4f}'.format(best_acc))
        print('Best val Pre: {:4f}'.format(best_precision))
        print('Best val Rec: {:4f}'.format(best_recall))
        print('Best val F1: {:4f}'.format(best_f1))
        print('Test Acc: {:4f}'.format(best_test_acc))

        # load best model weights
        model.load_state_dict(best_model_wts)
    return model

In [21]:
# Output directory for the per-seed training log and the saved weights.
path = 'models/transfer_learning_tutorial'
for seed in range(1):
    # Seed before the new head is created so its random initialisation is
    # the same on every run with this seed.
    torch.manual_seed(seed)
    model_ts = models.resnet18(pretrained=True)
    # Swap the classification head for a single-output linear layer
    # (one logit, matching the BCE-with-logits loss below).
    num_ftrs = model_ts.fc.in_features
    model_ts.fc = nn.Linear(num_ftrs, 1)

    model_ts = model_ts.to(device)

#     criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([10.]).to(device))
#     criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor([1.]).to(device))
#     criterion = nn.CrossEntropyLoss()
    criterion = nn.BCEWithLogitsLoss().to(device)

    # Observe that all parameters are being optimized
    optimizer_ts = optim.SGD(model_ts.parameters(), lr=0.001, momentum=0.9)

    # Decay LR by a factor of 0.1 every 7 epochs
    exp_lr_scheduler_ts = lr_scheduler.StepLR(optimizer_ts, step_size=7, gamma=0.1)
    
    if not os.path.exists(path):
        os.makedirs(path)
    # Train on the 'val_train' split and evaluate on 'val' and 'test'.
    # NOTE(review): 'val_train' and 'val' are built from the same val_paths
    # (only the transform differs), so the 'val' metrics are measured on images
    # the model is trained on -- confirm this is intended.
    # return_best=False, so the weights saved below are those of the final epoch.
    with open(os.path.join(path, 'seed_{}.log'.format(seed)), 'w') as log_file:
        model_ts = train_model(model_ts, criterion, optimizer_ts, exp_lr_scheduler_ts,
                               'val_train', 'val', test_dl='test', num_epochs=25,
                               verbose=True, log_file=log_file, return_best=False)
        torch.save(model_ts.state_dict(), os.path.join(path, 'seed_{}.pth'.format(seed)))

val_train val test
Epoch 0/24
----------
train Loss: 0.7291 Acc: 0.6164 Pre: 0.6685 Rec: 0.7305 F1: 0.6981
TP: 244.0 TN: 95.0 FP: 121.0 FN: 90.0
val Loss: 0.8596 Acc: 0.5545 Pre: 0.8938 Rec: 0.3024 F1: 0.4519
TP: 101.0 TN: 204.0 FP: 12.0 FN: 233.0
test Loss: 0.9269 Acc: 0.5160 Pre: 0.6538 Rec: 0.2374 F1: 0.3484
TP: 85.0 TN: 254.0 FP: 45.0 FN: 273.0

Epoch 1/24
----------
train Loss: 0.6511 Acc: 0.6691 Pre: 0.7099 Rec: 0.7695 F1: 0.7385
TP: 257.0 TN: 111.0 FP: 105.0 FN: 77.0
val Loss: 0.4896 Acc: 0.7527 Pre: 0.8133 Rec: 0.7695 F1: 0.7908
TP: 257.0 TN: 157.0 FP: 59.0 FN: 77.0
test Loss: 0.7332 Acc: 0.5875 Pre: 0.6160 Rec: 0.6453 F1: 0.6303
TP: 231.0 TN: 155.0 FP: 144.0 FN: 127.0

Epoch 2/24
----------
train Loss: 0.6407 Acc: 0.6636 Pre: 0.7110 Rec: 0.7515 F1: 0.7307
TP: 251.0 TN: 114.0 FP: 102.0 FN: 83.0
val Loss: 0.4556 Acc: 0.7945 Pre: 0.7576 Rec: 0.9731 F1: 0.8519
TP: 325.0 TN: 112.0 FP: 104.0 FN: 9.0
test Loss: 0.8876 Acc: 0.5495 Pre: 0.5527 Rec: 0.9078 F1: 0.6871
TP: 325.0 TN: 36.0 

test Loss: 0.7590 Acc: 0.6088 Pre: 0.6120 Rec: 0.7709 F1: 0.6823
TP: 276.0 TN: 124.0 FP: 175.0 FN: 82.0

Training complete in 11m 47s
