In [None]:
COLAB = True if 'google.colab' in str(get_ipython()) else False

if COLAB:
    !rm -rf interview
    !git clone https://github.com/lukoshkin/interview.git
    !mv -n interview/CV/* .
    !unzip -nq EyesDataset.zip

In [None]:
if COLAB:
    %matplotlib inline
else:
    %matplotlib notebook

import os
import random
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from sklearn.cluster import KMeans
from sklearn.model_selection import train_test_split
from IPython.display import clear_output

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as torch_data

from src.model.architecture import SimpleNet
from src.model.scoring import LabelSmoothedBCEwLL, ROC_EER
from src.model.utils import train_open_eyes_clf
from src.data.loaders import BatchLoader, MRLEyesData
from src.data.utils import mend_labels

# Seed every RNG in play. scikit-learn helpers called with
# `random_state=None` fall back to NumPy's global RNG, so seeding NumPy
# covers them too; the stdlib `random` module is seeded defensively in case
# some dependency draws from it.
seed = 0
for _seed_rng in (torch.manual_seed, np.random.seed, random.seed):
    _seed_rng(seed)

# Run on the GPU whenever CUDA is available, otherwise on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# First Attempt: K-Means Pseudo-Labels

In [None]:
! unzip -qn EyesDataset.zip

dset = []
for file in Path('EyesDataset').iterdir():
    img = plt.imread(str(file), )
    dset.append(img)
    
dset = np.array(dset, dtype='f4') / 255

In [None]:
# Cluster the flattened images into two groups to obtain pseudo-labels.
dset_flat = dset.reshape(len(dset), -1)
# Explicit seed: with `random_state=None` KMeans draws from the global NumPy
# RNG, which has advanced since seeding, so re-running this cell alone would
# give a different clustering.
kmeans = KMeans(n_clusters=2, random_state=seed).fit(dset_flat)
# Distance of every sample to each centroid, shape (n_samples, 2).
dist = kmeans.transform(dset_flat)

fig, ax = plt.subplots()
ax.plot(np.sort(dist[:, 0]), label='centroid 0')
ax.plot(np.sort(dist[:, 1]), label='centroid 1')
ax.set(title='Sorted distances to K-Means centroids',
       xlabel='sample rank', ylabel='distance')
ax.legend();

# NOTE: K-Means cluster ids are arbitrary -- cluster 1 is not guaranteed to be
# the "open eyes" class; inspect samples before using the ids as targets.
labels = kmeans.labels_
print(labels.mean())  # fraction of samples assigned to cluster 1

In [None]:
# Rank samples by the margin between their two centroid distances: a small
# `dist[:, 1] - dist[:, 0]` means the sample sits much closer to centroid 1.
# Plotting the K-Means labels along each ranking shows how far down the list
# the assignment stays consistent.
n_samples = 1800
margin = dist[:, 1] - dist[:, 0]
ids1 = np.argsort(margin)[:n_samples]   # most confidently cluster 1 first
ids0 = np.argsort(-margin)[:n_samples]  # most confidently cluster 0 first

fig, ax = plt.subplots()
for ranked_ids in (ids1, ids0):
    ax.plot(labels[ranked_ids]);

In [None]:
# Boolean mask of the samples K-Means assigned to cluster 1.
mask1 = labels.astype(bool)

# Eyeball one cluster-0 sample to see which eye state that cluster holds.
cluster0_imgs = dset[~mask1]
fig, ax = plt.subplots()
ax.imshow(cluster0_imgs[32], cmap='gray');

In [None]:
# Add a channel axis for the conv net: presumably (N, H, W) grayscale crops
# become (N, 1, H, W) -- TODO confirm against SimpleNet's expected input.
X = dset[:, None]
y = labels  # K-Means pseudo-labels

# A fixed `random_state` keeps the split identical when only this cell is
# re-run (the global NumPy RNG would otherwise have advanced).
X_train, X_test, y_train, y_test = map(torch.Tensor,
    train_test_split(X, y, test_size=.2, random_state=seed))

In [None]:
label_smoothing = True

net = SimpleNet().to(device)
opt = optim.Adam(net.parameters(), lr=3e-3)

# Plain BCE-with-logits, or the project's label-smoothed variant (factor 0.4),
# presumably to soften the noisy K-Means pseudo-labels.
criterion = LabelSmoothedBCEwLL(.4) if label_smoothing else nn.BCEWithLogitsLoss()

# Project batch loaders; the third argument is presumably the batch size.
train_bl = BatchLoader(X_train, y_train, 40)
val_bl = BatchLoader(X_test, y_test, 100)

In [None]:
# Start with nothing to beat (the ROC EER validation score is an error rate).
best_score = float('inf')
best_score, state = train_open_eyes_clf(
    net, criterion, opt, train_bl, val_bl,
    device=device,
    val_criterion=ROC_EER,
    epochs=10,
    continue_val_score=best_score,
)

# Persist the best checkpoint, if the run produced one.
if state is not None:
    torch.save(state, 'dummy.pth')

In [None]:
# Rebuild the model from the saved checkpoint, if one exists.
# NOTE: the fresh SimpleNet is created on the CPU and is not moved to
# `device`; `map_location` only decides where the stored tensors are loaded.
ckpt_path = 'dummy.pth'
if Path(ckpt_path).exists():
    net = SimpleNet()
    state_dict = torch.load(ckpt_path, map_location=device)
    net.load_state_dict(state_dict)

In [None]:
# Score the hold-out split on whichever device the model currently lives on.
# It may be on the CPU (rebuilt from a checkpoint) or on `device` (freshly
# trained). Scores are brought back to the CPU next to `y_test`.
model_device = next(net.parameters()).device
y_pred = net.predict(X_test.to(model_device)).detach().cpu()
eer, fpr, tpr = ROC_EER(y_test, y_pred, return_roc=True)

In [None]:
# Plot the ROC curve against the diagonal y = 1 - x. A point on that line has
# TPR = 1 - FPR, i.e. FNR = FPR, so the two curves cross at the equal error rate.
x = np.linspace(0, 1, num=len(fpr))
print('EER =', eer)

fig, ax = plt.subplots()
ax.plot(x, 1 - x, linestyle='--', label='FPR = FNR')
ax.plot(fpr, tpr, label='ROC')
ax.set(xlabel='False positive rate', ylabel='True positive rate',
       title='ROC curve')
ax.legend();

# Third-Party Dataset: MRL Eye Dataset

In [None]:
! wget -nc http://mrl.cs.vsb.cz/data/eyedataset/mrlEyes_2018_01.zip
! unzip -nq mrlEyes_2018_01.zip
fnames = list(Path('mrlEyes_2018_01').rglob('*.png'))
train_files, test_files = train_test_split(fnames, test_size=.2)

In [None]:
train_ds = MRLEyesData(fnames=train_files)
val_ds = MRLEyesData(fnames=test_files)

# One loader worker per logical CPU. `os.cpu_count()` is portable, unlike
# parsing `lscpu`, and may return None -- then load in the main process.
num_workers = os.cpu_count() or 0

train_bl = torch_data.DataLoader(
    train_ds, batch_size=100, shuffle=True, num_workers=num_workers)
# Validation needs no shuffling; a fixed order makes validation passes
# comparable across epochs and re-runs (`test_files` is already shuffled by
# the split).
val_bl = torch_data.DataLoader(
    val_ds, batch_size=100, shuffle=False, num_workers=num_workers)

# Share of positive targets in the training split -- assumes 0/1 targets.
print(sum(train_ds.targets) / len(train_ds))

In [None]:
# Disabled: cache one validation batch to 'mrleye_valset.npz' for reuse.
# NOTE(review): `dloader` is not defined anywhere in this notebook --
# presumably `val_bl` was meant; fix before re-enabling.
#     X, y = next(iter(dloader))
#     np.savez('mrleye_valset', X=X, y=y)

In [None]:
# Disabled: reload the cached MRL validation batch from 'mrleye_valset.npz'.
# If re-enabled, it overwrites the EyesDataset `X_test` / `y_test` used below.
#     data = np.load('mrleye_valset.npz')
#     X_test = torch.Tensor(data['X'])
#     y_test = torch.Tensor(data['y'])

In [None]:
# Fresh model and optimizer for the MRL data. Label smoothing is off here;
# flip the flag to use the project's smoothed loss with factor 0.2.
label_smoothing = False

net = SimpleNet().to(device)
opt = optim.Adam(net.parameters(), lr=3e-3)

criterion = LabelSmoothedBCEwLL(.2) if label_smoothing else nn.BCEWithLogitsLoss()

In [None]:
# Start with nothing to beat (the ROC EER validation score is an error rate).
best_score = float('inf')
best_score, state = train_open_eyes_clf(
    net, criterion, opt, train_bl, val_bl,
    device=device,
    val_criterion=ROC_EER,
    epochs=1,
    continue_val_score=best_score,
)

# Persist the best checkpoint, if the run produced one.
if state is not None:
    torch.save(state, 'mrl_eyes_weights.pth')

In [None]:
# Rebuild the MRL-trained model from its checkpoint, if one exists.
# NOTE: the fresh SimpleNet is created on the CPU and is not moved to
# `device`; `map_location` only decides where the stored tensors are loaded.
ckpt_path = 'mrl_eyes_weights.pth'
if Path(ckpt_path).exists():
    net = SimpleNet()
    state_dict = torch.load(ckpt_path, map_location=device)
    net.load_state_dict(state_dict)

In [None]:
# Score the images with the MRL-trained model. The input goes to the model's
# own device: a SimpleNet rebuilt from a checkpoint may sit on the CPU even
# when `device` is CUDA, and mixing devices would raise.
# NOTE: `X_test` is the EyesDataset hold-out split from the first section,
# unless the cached MRL validation batch was reloaded.
probs = net.predict(X_test.to(next(net.parameters()).device)).detach().cpu()

In [None]:
# Sanity check: threshold the model's probabilities at 0.5 and hand the hard
# labels, together with the channel-squeezed images, to the project helper
# `mend_labels` (its exact semantics are not visible here).
hard_labels = (probs > .5).long()
mend_labels(X_test.squeeze(1), hard_labels);