In [None]:
import os

# Expose only physical GPU 3 to this process. This is set before `torch` is
# imported in a later cell. NOTE(review): presumably the single visible GPU is
# then addressed as device index 0 by the rest of the notebook — confirm.
os.environ['CUDA_VISIBLE_DEVICES'] = '3'

In [None]:
# Name of the training run; it identifies both the checkpoint directory and the split file.
MODEL_NAME = 'qa_accept_cogito_skips_03-04-2020_stratified'

# Directory holding this run's checkpoints (its entries are iterated as epochs later on).
MODEL_PATH = os.path.join(
    '/root/data/sid/skip_classifier_checkpoints/',
    MODEL_NAME,
)

# JSON file describing the dataset splits that belong to this run.
SPLIT_PATH = os.path.join(
    '/root/data/sid/skip_classifier_datasets/splits',
    f'{MODEL_NAME}_splits.json',
)

In [None]:
import torch

# Sanity check: display whether PyTorch reports a usable CUDA device.
torch.cuda.is_available()

In [None]:
# CUDA device index used for the model and the input batches in later cells.
# NOTE(review): with CUDA_VISIBLE_DEVICES restricted to one GPU at the top of
# the notebook, index 0 presumably refers to that GPU — confirm.
device = 0
# Make it the default CUDA device for this process.
torch.cuda.set_device(device)

### Pick Best Model Using Validation Metrics

In [None]:
import json

# For every tracked validation metric, remember the best epoch seen so far as a
# tuple: (metric value, full metrics dict of that epoch, epoch directory name).
best_epoch = {'precision': (0, None, None), 'recall': (0, None, None), 'auc': (0, None, None)}

# Iterate in sorted order: os.listdir returns entries in arbitrary order and only
# a strictly better value replaces the current best, so sorting makes the winner
# of a tie the same on every run.
for epoch in sorted(os.listdir(MODEL_PATH)):
    metrics_path = os.path.join(MODEL_PATH, epoch, 'val', 'metrics.json')
    if not os.path.isfile(metrics_path):
        # Entry without validation metrics (stray file, unfinished epoch, ...).
        continue
    # Context manager so the file handle is closed immediately instead of leaked.
    with open(metrics_path) as metrics_file:
        metrics = json.load(metrics_file)['acc']
    for m in best_epoch:
        if metrics[m] > best_epoch[m][0]:
            best_epoch[m] = (metrics[m], metrics, epoch)

In [None]:
# Show, per metric, the best (value, metrics dict, epoch directory) tuple found.
best_epoch

In [None]:
# Evaluate the checkpoint with the highest validation AUC. The value is the
# (auc value, metrics dict, epoch directory name) tuple recorded for 'auc'.
BEST_EPOCH = best_epoch['auc']

### Reconstruct Test Set

In [None]:
import json

# Load the split definition saved for this run. The context manager closes the
# file handle instead of leaving it open for the lifetime of the kernel.
with open(SPLIT_PATH) as split_file:
    splits = json.load(split_file)
splits.keys()

In [None]:
# Rebuild the test set by looking up every test index in the original sample list.
test_set = list(map(splits['original'].__getitem__, splits['test_indices']))
len(test_set)

In [None]:
# The first element of every test-set entry is the local path of the cropped image.
all_image_files = [x[0] for x in test_set]
# The metadata JSON is located by swapping the crop suffix for the metadata suffix.
all_metadata_files = [s.replace('_crop.jpg', '_metadata.json') for s in all_image_files]
all_metadata_data = []
for i, (metadata_path, image_path) in enumerate(zip(all_metadata_files, all_image_files)):
    # Print a progress counter every 100 files.
    if i % 100 == 0:
        print(i)
    # Both file names must start with the same prefix (the part before the first
    # underscore); otherwise a crop would be paired with the wrong metadata.
    assert metadata_path.split('/')[-1].split('_')[0] == image_path.split('/')[-1].split('_')[0], (metadata_path, image_path)
    # Context manager so each metadata file is closed as soon as it is parsed
    # (one leaked handle per test image otherwise).
    with open(metadata_path) as metadata_file:
        metadata = json.load(metadata_file)
    # Remember which local image this metadata record describes.
    metadata['local_image_path'] = image_path
    all_metadata_data.append(metadata)

In [None]:
import pandas as pd

# One row per test image; the columns are the metadata fields plus `local_image_path`.
eval_dataset = pd.DataFrame(all_metadata_data)
eval_dataset

In [None]:
# Keep only the first row for every distinct left-crop URL.
eval_dataset = eval_dataset.drop_duplicates(subset=['left_crop_url'], keep='first')

In [None]:
# Display the de-duplicated evaluation frame.
eval_dataset

In [None]:
# Rows without any skip reason are the accepted images.
accepts = eval_dataset.loc[eval_dataset['skip_reasons'].isna()]
accepts.shape

In [None]:
# Number of accepted images per pen.
accepts['pen_id'].value_counts()

In [None]:
# Rows that carry a skip reason are the skipped images.
skips = eval_dataset.loc[eval_dataset['skip_reasons'].notna()]
skips.shape

In [None]:
# Upper bound on the number of skipped images kept per pen.
MAX_SKIPS_PER_PEN = 1000

# Down-sample each pen's skipped images to at most MAX_SKIPS_PER_PEN rows.
# A fixed random_state makes the drawn subset (and therefore every metric
# computed later) identical across kernel restarts.
skips = skips.groupby('pen_id', group_keys=False).apply(
    lambda x: x.sample(min(len(x), MAX_SKIPS_PER_PEN), random_state=0)
)
skips.pen_id.value_counts()

In [None]:
# Size of the skipped subset after the per-pen down-sampling.
skips.shape

In [None]:
# Evaluation set = down-sampled skipped images + all accepted images.
eval_set = pd.concat([skips, accepts])
# Shuffle the rows. The fixed random_state keeps the row order — and with it the
# order in which images are fed to the model — reproducible between runs.
eval_set = eval_set.sample(frac=1, random_state=0)
print(eval_set.pen_id.value_counts())
eval_set.shape

In [None]:
import json

# Skip reasons that get their own boolean column and their own evaluation later on.
useful_labels = [
    'BLURRY',
    'BAD_CROP',
    'BAD_ORIENTATION',
    'OBSTRUCTION',
    'TOO_DARK'
]


def _parse_skip_reasons(raw):
    """Decode a single ``skip_reasons`` cell.

    Args:
        raw: The stored value. Strings are decoded by replacing single quotes
            with double quotes and parsing the result as JSON (presumably the
            stored text is a list written with single quotes — confirm).

    Returns:
        The decoded JSON value for strings, ``None`` for missing values
        (``None`` or float NaN), and ``raw`` unchanged for anything else, e.g.
        a value that was already decoded when this cell is run a second time.
    """
    if isinstance(raw, str):
        return json.loads(raw.replace("'", "\""))
    if raw is None or isinstance(raw, float):
        return None
    return raw


eval_dataset['skip_reasons'] = eval_dataset['skip_reasons'].apply(_parse_skip_reasons)
# `eval_set` was assembled from copies of `eval_dataset` rows in earlier cells,
# so decoding `eval_dataset` alone never reaches it. Decode it as well so the
# per-label columns are computed from the decoded values instead of raw text.
eval_set['skip_reasons'] = eval_set['skip_reasons'].apply(_parse_skip_reasons)

In [None]:
def _has_skip_reason(reasons, label):
    """Return True when ``label`` occurs in one image's skip reasons.

    Args:
        reasons: The ``skip_reasons`` value of one row. Missing values may show
            up as ``None`` or as float NaN; both mean "no skip reason".
        label: Skip reason to look for.

    Returns:
        False for missing values, otherwise the result of ``label in reasons``
        (membership for a decoded list, substring match for a raw string).
    """
    if reasons is None or isinstance(reasons, float):
        # NaN is a float and does not support `in`; treat it like None.
        return False
    return label in reasons


# One boolean column per useful skip reason.
for label in useful_labels:
    eval_set[label] = eval_set['skip_reasons'].apply(_has_skip_reason, args=(label,))

### Get Model Predictions

In [None]:
# Show which checkpoint (value, metrics, epoch directory) is about to be loaded.
BEST_EPOCH

In [None]:
from model import ImageClassifier
from train import ACCEPT_LABEL, SKIP_LABEL

# Checkpoint of the selected epoch; BEST_EPOCH[2] is the epoch directory name.
path = os.path.join(MODEL_PATH, BEST_EPOCH[2], 'val', 'model.pt')
# Use the notebook-wide `device` instead of repeating the literal index.
model = ImageClassifier([ACCEPT_LABEL, SKIP_LABEL], device=device, savename=None)
# map_location remaps the stored tensors onto the device used here, so a
# checkpoint saved from a different GPU index (or from CPU) still loads when
# only one GPU is visible to this process.
model.load_state_dict(torch.load(path, map_location=f'cuda:{device}'))
model.to(device)

In [None]:
from torchvision import get_image_backend

# Display the image backend torchvision is currently configured to use.
get_image_backend()

In [None]:
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import cv2

# Image file extensions. NOTE(review): not referenced by any code in the cells
# shown here — possibly a leftover; confirm before relying on it.
IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp')

def albumentations_loader(file_path):
    """Read an image with OpenCV and return it in RGB channel order.

    Args:
        file_path: Path of the image file to read.

    Returns:
        The image returned by ``cv2.imread`` after BGR -> RGB conversion.

    Raises:
        OSError: If OpenCV could not read the file. ``cv2.imread`` signals a
            failed read by returning ``None`` rather than raising, which would
            otherwise surface as an opaque error inside ``cv2.cvtColor``.
    """
    image = cv2.imread(file_path)
    if image is None:
        raise OSError(f"cv2.imread could not read image file: {file_path!r}")

    # OpenCV loads colour images as BGR; convert to RGB.
    return cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

class ImageDataset(Dataset):
    """Dataset over an explicit list of ``(image_path, target)`` samples.

    Args:
        classes: Sequence of class names; a name's position is its class index.
        samples: Sequence of ``(path, target)`` pairs.
        loader: Callable that takes a path and returns the loaded image.
        extensions: Optional iterable of file extensions. It is only stored and
            mentioned in the empty-dataset error message.
        transform: Optional callable invoked as ``transform(image=sample)``;
            the ``'image'`` entry of its result replaces the sample.
        target_transform: Optional callable applied to the target before it is
            returned.
        is_valid_file: Accepted for signature compatibility; not used.

    Raises:
        RuntimeError: If ``samples`` is empty.
    """
    def __init__(self, classes, samples, loader=albumentations_loader, extensions=None, transform=None,
                 target_transform=None, is_valid_file=None):
        if len(samples) == 0:
            # This class has no root directory and `extensions` may be None, so
            # the message is built from what is actually available.
            message = "Found 0 samples"
            if extensions:
                message += "\nSupported extensions are: " + ",".join(extensions)
            raise RuntimeError(message)

        self.loader = loader
        self.transform = transform
        self.target_transform = target_transform
        self.extensions = extensions

        self.classes = classes
        self.class_to_idx = {c: classes.index(c) for c in classes}
        self.samples = samples
        self.targets = [s[1] for s in samples]

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (sample, target) where target is class_index of the target class.
        """
        path, target = self.samples[index]
        # Honour the loader passed to the constructor rather than a fixed one.
        sample = self.loader(path)
        if self.transform is not None:
            sample = self.transform(image=sample)['image']
        if self.target_transform is not None:
            target = self.target_transform(target)

        return sample, target

    def __len__(self):
        """Return the number of samples."""
        return len(self.samples)

In [None]:
# The position in this list is the integer target: 0 = accept, 1 = skip.
classes = [ACCEPT_LABEL, SKIP_LABEL]

eval_set['paths'] = eval_set['local_image_path']
# An image is labelled 1 when it has a skip reason, otherwise 0.
eval_set['labels'] = eval_set['skip_reasons'].notna().apply(int)

# (image path, integer label) pairs in the row order of the evaluation set.
samples = list(zip(eval_set['paths'], eval_set['labels']))
len(samples)

In [None]:
from loader import TRANSFORMS

# Wrap the (path, label) samples in a dataset that applies the 'pad' entry of
# the project's TRANSFORMS mapping to every image.
dataset = ImageDataset(classes, samples, transform=TRANSFORMS['pad'])
# Smoke test: load the first (sample, target) pair and show it together with
# the shape of the transformed image.
example = dataset[0]
print(example)
print(example[0].shape)

In [None]:
import torch

# Batches are produced in dataset order (no shuffling), which lets the model
# outputs be matched back to the rows of the evaluation set by position.
loader = torch.utils.data.DataLoader(
    dataset=dataset,
    batch_size=64,
    shuffle=False,
    num_workers=1,
)
loader

In [None]:
# Make sure the model lives on the GPU before running inference.
model.cuda()

In [None]:
# Switch to inference mode so layers such as dropout or batch normalisation do
# not behave as they do during training.
# NOTE(review): assumes ImageClassifier is a torch.nn.Module (it is already used
# through load_state_dict / to / cuda) — confirm.
model.eval()

# Collect the per-batch results and concatenate once at the end; concatenating
# inside the loop re-copies everything gathered so far on every batch.
output_batches = []
label_batches = []

with torch.no_grad():
    for i, (inputs, tgts) in enumerate(loader):
        cuda_inputs = inputs.to(device)
        # Move the outputs back to the CPU so GPU memory is not held per batch.
        outputs = model(cuda_inputs).cpu()
        output_batches.append(outputs)
        label_batches.append(tgts)
        print(f'batch:{i}...')

# Model outputs and DataLoader targets for the whole set, in dataset order.
all_outputs = torch.cat(output_batches)
all_labels = torch.cat(label_batches)

In [None]:
# Sanity check: both tensors are the concatenation of all batches, so their
# first dimensions should match.
print(all_outputs.shape)
print(all_labels.shape)

In [None]:
# Convert the accumulated tensors to NumPy arrays for the pandas / scikit-learn
# code below.
# NOTE: this cell cannot be re-run on its own — after the first run these names
# hold NumPy arrays, which have no `.detach()` method.
all_outputs = all_outputs.detach().numpy()
all_labels = all_labels.detach().numpy()

### Evaluation results

In [None]:
# Keep as many rows as there are model outputs. The explicit copy makes the
# result an independent frame, so the columns added in the following cells are
# written to it directly rather than to a slice of the previous frame (which
# triggers pandas' SettingWithCopyWarning).
eval_set = eval_set.iloc[:all_outputs.shape[0]].copy()

In [None]:
# Shape of the evaluation frame after trimming it to the number of model outputs.
eval_set.shape

In [None]:
# There must be exactly one row of model output per evaluation row.
assert len(eval_set) == all_outputs.shape[0]

# Keep column 0 of the model output as the score of each image.
# NOTE(review): column 0 presumably corresponds to ACCEPT_LABEL, the first class
# passed to ImageClassifier — confirm.
eval_set['model_outputs'] = all_outputs[:, 0]

In [None]:
# Boolean ground truth built from the targets yielded by the DataLoader: True
# where the target index is 0, i.e. ACCEPT_LABEL in `classes` (accepted images).
eval_set['loaded_labels'] = (all_labels == 0)

In [None]:
# Disabled consistency check between the labels derived from `skip_reasons` and
# the ones read back from the DataLoader.
# NOTE(review): as written it cannot pass on a non-empty set — `labels` is 1 for
# skipped images while `loaded_labels` is True for accepted images, so the two
# columns use opposite encodings.
#assert (eval_set['labels'] == eval_set['loaded_labels']).sum() == len(eval_set), eval_set[['labels', 'loaded_labels']]

In [None]:
import numpy as np

# Predict the positive class for every score strictly above 0.5.
eval_set['model_preds'] = eval_set['model_outputs'].gt(0.5)

In [None]:
from sklearn.metrics import roc_curve
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from sklearn.metrics import roc_auc_score

def plot_roc(fpr, tpr, auc, pen_id, skip_reason, ax):
    """Draw a ROC curve and the chance diagonal on ``ax``.

    Args:
        fpr: False-positive rates (x values of the curve).
        tpr: True-positive rates (y values of the curve).
        auc: Area under the curve, shown in the legend with two decimals.
        pen_id: Pen identifier shown in the title.
        skip_reason: Skip reason shown in the title.
        ax: Axes-like object to draw on.
    """
    line_width = 2
    curve_label = 'ROC curve (area = %0.2f)' % auc
    title = f'ROC Curve Pen:{pen_id} SkipReason:{skip_reason}'

    # The curve itself, followed by the diagonal of a random classifier.
    ax.plot(fpr, tpr, color='darkorange', lw=line_width, label=curve_label)
    ax.plot([0, 1], [0, 1], color='navy', lw=line_width, linestyle='--')

    # Axis ranges, labels, title and legend.
    ax.set_xlim([0.0, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xlabel('False Positive Rate (skiprate)')
    ax.set_ylabel('Recall (KPI)')
    ax.set_title(title, size=20)
    ax.legend(loc="lower right")

def evaluate(eval_set, pen_id, skip_reason):
    """Compute precision, recall, AUC and the ROC curve for one evaluation subset.

    Args:
        eval_set: DataFrame with the columns ``loaded_labels`` (ground truth),
            ``model_preds`` (thresholded predictions) and ``model_outputs``
            (scores).
        pen_id: Pen identifier, copied into the result.
        skip_reason: Skip reason, copied into the result.

    Returns:
        ``(results, (fpr, tpr, thresholds))``. ``results`` is a dict with the
        keys ``pen_id``, ``skip_reason``, ``n``, ``prec``, ``rec`` and ``auc``.
        ``prec`` is None when nothing was predicted positive, ``rec`` is None
        when there are no positive labels, and ``auc`` plus all three curve
        arrays are None when scikit-learn rejects the input with a ValueError.

    Raises:
        KeyError: If one of the required columns is missing. Only ValueError
            from the ROC computation is treated as "metric undefined"; any
            other error propagates instead of being silently turned into None.
    """
    labels = eval_set['loaded_labels']
    preds = eval_set['model_preds']
    scores = eval_set['model_outputs']

    results = dict()
    results['pen_id'] = pen_id
    results['skip_reason'] = skip_reason
    results['n'] = len(eval_set)
    # Precision is undefined without positive predictions, recall without
    # positive labels.
    results['prec'] = precision_score(labels, preds) if preds.sum() else None
    results['rec'] = recall_score(labels, preds) if labels.sum() else None
    try:
        results['auc'] = roc_auc_score(labels, scores)
        fpr, tpr, thresholds = roc_curve(labels, scores)
    except ValueError:
        # NOTE(review): presumably raised by scikit-learn for subsets on which
        # the ROC is undefined (e.g. a single class) — confirm against its docs.
        results['auc'] = None
        fpr, tpr, thresholds = None, None, None
    return results, (fpr, tpr, thresholds)

In [None]:
import matplotlib.pyplot as plt

results = []
all_pens = ['overall']
# Deliberately not named `all_labels`: that name holds the array of targets
# collected during inference, and rebinding it here would silently break any
# earlier cell that is re-run afterwards.
all_skip_reasons = ['overall']
nrows = len(all_pens)
ncols = len(all_skip_reasons)
# With a 1x1 grid `axes` is a single Axes object, which is passed on as-is below.
fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(10*ncols, 10*nrows))

# 'overall' means "do not filter"; set either value to restrict the evaluation.
pen_id, skip_reason = 'overall', 'overall'

this_eval_set = eval_set
if pen_id != 'overall':
    this_eval_set = this_eval_set[this_eval_set['pen_id'] == pen_id]

if skip_reason != 'overall':
    # Keep the rows flagged with this skip reason plus every accepted row
    # (`loaded_labels` is True for accepted images).
    skipped_with_this_reason = this_eval_set[skip_reason]
    accepted = this_eval_set['loaded_labels']
    this_eval_set = this_eval_set[skipped_with_this_reason | accepted]

result, (fpr, tpr, thresholds) = evaluate(this_eval_set, pen_id, skip_reason)
results.append(result)
# `fpr` is None when no ROC curve could be computed for this subset.
if fpr is not None:
    plot_roc(fpr, tpr, result['auc'], pen_id, skip_reason, axes)

# One row per (pen, skip reason); shown transposed so the metrics read as rows.
out = pd.DataFrame.from_dict(results)
out = out.set_index(['pen_id', 'skip_reason'])
out.T

In [None]:
import matplotlib.pyplot as plt

def _select_eval_subset(frame, pen_id, skip_reason):
    """Return the rows of ``frame`` that belong to one (pen, skip reason) cell.

    Args:
        frame: Evaluation DataFrame with a ``pen_id`` column, a boolean
            ``loaded_labels`` column (True for accepted images) and one boolean
            column per skip reason.
        pen_id: Pen to keep, or ``'overall'`` for all pens.
        skip_reason: Name of the skip-reason column to keep, or ``'overall'``
            for all rows.

    Returns:
        The filtered frame. For a specific skip reason it contains the rows
        flagged with that reason plus every accepted row.
    """
    subset = frame
    if pen_id != 'overall':
        subset = subset[subset['pen_id'] == pen_id]
    if skip_reason != 'overall':
        subset = subset[subset[skip_reason] | subset['loaded_labels']]
    return subset


results = []
all_pens = ['overall'] + list(eval_set['pen_id'].unique())
# Deliberately not named `all_labels`: that name holds the array of targets
# collected during inference and must not be rebound to a list of strings.
all_skip_reasons = ['overall'] + useful_labels
nrows = len(all_pens)
ncols = len(all_skip_reasons)
# squeeze=False keeps `axes` two-dimensional even when the grid has a single
# row, so `axes[i][j]` is always valid.
fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(10*ncols, 10*nrows), squeeze=False)

# One subplot and one result row per (pen, skip reason) combination.
for i, pen_id in enumerate(all_pens):
    for j, skip_reason in enumerate(all_skip_reasons):
        this_eval_set = _select_eval_subset(eval_set, pen_id, skip_reason)
        result, (fpr, tpr, thresholds) = evaluate(this_eval_set, pen_id, skip_reason)
        results.append(result)
        # `fpr` is None when no ROC curve could be computed; that subplot stays empty.
        if fpr is not None:
            plot_roc(fpr, tpr, result['auc'], pen_id, skip_reason, axes[i][j])

out = pd.DataFrame.from_dict(results)
out = out.set_index(['pen_id', 'skip_reason'])
out.T

In [None]:
# Full results table, indexed by (pen_id, skip_reason).
out