In [1]:
cd ../

/mnt/NVME1TB/Projects/people-with-glasses-classifier


In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import os
import os.path as osp
import pandas as pd
import numpy as np

import torchvision as tv
import torch
from torch import nn

from modules.comp_tools import preprocessing_fn, ClsDataset
from torch.utils.data import DataLoader as BaseDataLoader

from modules.mobilenetv2 import MobileNetV2
from modules.mobilenetv3 import mobilenetv3_large

from tqdm.auto import tqdm
from sklearn.metrics import f1_score, accuracy_score, classification_report
import time

In [4]:
TRAIN_FOLDS = (0, 1, 2)
VALID_FOLDS = (3,)
TEST_FOLDS = (4,)
# CONTINUE = 'logs/mobilenetv3-binary/checkpoints/last.pth'
CONTINUE = 'logs/mobilenetv2-adam-binary/checkpoints/last.pth'

MODE = 'multiclass'
BINARY= True
ACTIVATION = 'sigmoid'

BATCH_SIZE = 256
TRAIN_IMAGES = 'data/crops/'

In [5]:
df = pd.read_csv('data/all.csv')
test_df = df[df.fold_num.isin(TEST_FOLDS)]
valid_df = df[df.fold_num.isin(VALID_FOLDS)]

In [6]:
dataset_params = dict(
    img_size=(120, 120),
    img_prefix=TRAIN_IMAGES, 
    augmentations=None,
    preprocess_img=preprocessing_fn,
    mode=MODE,
    binary=BINARY,
)

valid_dataset = ClsDataset(valid_df, **dataset_params)
valid_dl = BaseDataLoader(valid_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

test_dataset = ClsDataset(test_df, **dataset_params)
test_dl = BaseDataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

criterion = nn.BCEWithLogitsLoss()

In [7]:
def evaluate(model, data_loader, device=None, th=0.5, n_batches=None):
    if device is None:
        device = next(model.parameters()).device
    accs = []
    start = time.time()
    c = 0
    n = 0
    total = len(data_loader)
    if n_batches:
        total = n_batches
    pred_probas = []
    true_labels = []
    with torch.no_grad():
        for data_dict in tqdm(data_loader, total=total):
            if n_batches and c >= n_batches:
                break
            image = data_dict['features'].to(device)
            target = data_dict['targets'].to(device)
            output = torch.sigmoid(model(image))
            pred_probas.append(output.cpu().numpy())
            true_labels.append(target.cpu().numpy().astype(int))
            c += 1
            n += target.size(0)
    probas = np.concatenate(pred_probas)
    pred = (probas > th).astype(int)
    true = np.concatenate(true_labels)
    
    print(classification_report(true, pred))
    elapsed_time = time.time() - start
    fps = n / elapsed_time
    time_per_image = 1 / fps
    print(f'Elapsed time: {elapsed_time:0.2f}')
    print(f'{time_per_image:0.5f} sec/img')
    print(f'{fps:0.2f} img/sec (fps)')
    return true, probas

In [8]:
def get_model(weights=None):
    model = MobileNetV2(num_classes=1)
    if weights:
        state_dict = torch.load(weights)['model_state_dict']
        model.load_state_dict(state_dict)
    model.cuda()
    model.eval()
    return model

In [9]:
model = get_model(CONTINUE)

In [10]:
print('Validation')
true, probas = evaluate(model, valid_dl, th=0.5)

print('Test')
_, _ = evaluate(model, test_dl, th=0.5)

Validation


HBox(children=(FloatProgress(value=0.0, max=198.0), HTML(value='')))


              precision    recall  f1-score   support

           0       1.00      0.99      0.99     43867
           1       0.95      0.99      0.97      6785

    accuracy                           0.99     50652
   macro avg       0.97      0.99      0.98     50652
weighted avg       0.99      0.99      0.99     50652

Elapsed time: 19.74
0.00039 sec/img
2565.45 img/sec (fps)
Test


HBox(children=(FloatProgress(value=0.0, max=198.0), HTML(value='')))


              precision    recall  f1-score   support

           0       1.00      0.99      0.99     41692
           1       0.97      0.99      0.98      8958

    accuracy                           0.99     50650
   macro avg       0.98      0.99      0.99     50650
weighted avg       0.99      0.99      0.99     50650

Elapsed time: 18.59
0.00037 sec/img
2725.24 img/sec (fps)


In [11]:
print('Per dataset validation')
for dataset_type in test_df.dataset.unique():
    print(f'Dataset: {dataset_type}')
    sub_df = test_df[test_df.dataset==dataset_type]
    sub_dataset = ClsDataset(sub_df, **dataset_params)
    dl = BaseDataLoader(sub_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
    evaluate(model, dl)

Per dataset validation
Dataset: celeba


HBox(children=(FloatProgress(value=0.0, max=159.0), HTML(value='')))


              precision    recall  f1-score   support

           0       1.00      0.99      1.00     37801
           1       0.91      0.97      0.94      2718

    accuracy                           0.99     40519
   macro avg       0.95      0.98      0.97     40519
weighted avg       0.99      0.99      0.99     40519

Elapsed time: 15.64
0.00039 sec/img
2590.22 img/sec (fps)
Dataset: specface


HBox(children=(FloatProgress(value=0.0, max=1.0), HTML(value='')))


              precision    recall  f1-score   support

           1       1.00      1.00      1.00        64

    accuracy                           1.00        64
   macro avg       1.00      1.00      1.00        64
weighted avg       1.00      1.00      1.00        64

Elapsed time: 0.26
0.00407 sec/img
246.00 img/sec (fps)
Dataset: sof


HBox(children=(FloatProgress(value=0.0, max=2.0), HTML(value='')))


              precision    recall  f1-score   support

           0       0.00      0.00      0.00         0
           1       1.00      0.96      0.98       484

    accuracy                           0.96       484
   macro avg       0.50      0.48      0.49       484
weighted avg       1.00      0.96      0.98       484

Elapsed time: 0.55
0.00114 sec/img
880.98 img/sec (fps)
Dataset: meglass


  'recall', 'true', average, warn_for)


HBox(children=(FloatProgress(value=0.0, max=38.0), HTML(value='')))


              precision    recall  f1-score   support

           0       1.00      0.99      0.99      3891
           1       0.99      1.00      1.00      5692

    accuracy                           0.99      9583
   macro avg       0.99      0.99      0.99      9583
weighted avg       0.99      0.99      0.99      9583

Elapsed time: 3.98
0.00042 sec/img
2409.08 img/sec (fps)


# Static quantization

In [12]:
model.cpu()
model.fuse_model()
model.qconfig = torch.quantization.default_qconfig
torch.quantization.prepare(model, inplace=True)
print('Post Training Quantization Prepare: Inserting Observers')
evaluate(model, valid_dl, device='cpu', n_batches=25)
print('Post Training Quantization: Calibration done')

torch.quantization.convert(model, inplace=True)
print('Post Training Quantization: Convert done')

print('Quantized on valid dataset')
evaluate(model, valid_dl, device='cpu')
print('Quantized on test dataset')
evaluate(model, test_dl, device='cpu')

Post Training Quantization Prepare: Inserting Observers


HBox(children=(FloatProgress(value=0.0, max=25.0), HTML(value='')))

              precision    recall  f1-score   support

           0       1.00      0.99      0.99      6007
           1       0.88      0.97      0.92       393

    accuracy                           0.99      6400
   macro avg       0.94      0.98      0.96      6400
weighted avg       0.99      0.99      0.99      6400

Elapsed time: 76.28
0.01192 sec/img
83.90 img/sec (fps)
Post Training Quantization: Calibration done
Post Training Quantization: Convert done
Quantized on valid dataset


  Returning default scale and zero point.")


HBox(children=(FloatProgress(value=0.0, max=198.0), HTML(value='')))


              precision    recall  f1-score   support

           0       1.00      0.99      0.99     43867
           1       0.95      0.99      0.97      6785

    accuracy                           0.99     50652
   macro avg       0.97      0.99      0.98     50652
weighted avg       0.99      0.99      0.99     50652

Elapsed time: 69.53
0.00137 sec/img
728.53 img/sec (fps)
Quantized on test dataset


HBox(children=(FloatProgress(value=0.0, max=198.0), HTML(value='')))


              precision    recall  f1-score   support

           0       1.00      0.99      1.00     41692
           1       0.97      0.99      0.98      8958

    accuracy                           0.99     50650
   macro avg       0.98      0.99      0.99     50650
weighted avg       0.99      0.99      0.99     50650

Elapsed time: 68.76
0.00136 sec/img
736.62 img/sec (fps)


(array([[0],
        [0],
        [0],
        ...,
        [0],
        [0],
        [0]]), array([[3.5459205e-05],
        [3.1313755e-06],
        [1.3195903e-02],
        ...,
        [1.5790707e-05],
        [7.9624326e-05],
        [1.1794723e-03]], dtype=float32))

In [13]:
# torch.jit.save(torch.jit.script(model), 'quantized-mobilenetv2-scripted.pth')

# Full pipe

In [13]:
from modules.common import visualize
from modules.utils import crop_img, resize_shortest_edge, open_img
import cv2

In [14]:
def predict(model, img):
    with torch.no_grad():
        tensor = preprocessing_fn(img)
        tensor = torch.Tensor(tensor).unsqueeze(0)
        logit = model(tensor)
        prob = torch.sigmoid(logit).cpu().numpy()[0]
    return prob

In [15]:
model = torch.jit.load('trained_models/quantized-mobilenetv2-scripted.pth')

In [16]:
fld = 'data/cameos_dataset/without_glasses/'
for fname in os.listdir(fld):
    fp = osp.join(fld, fname)
    img = open_img(fp)
    crop = crop_img(img) / 255.
    crop = cv2.resize(crop, (120, 120))
    prob = predict(model, crop)
    print(prob)

[1.20584455e-05]
[2.3912426e-06]
[0.00010427]
[2.7078262e-05]
[1.20584455e-05]
[0.00452678]
[7.0318465e-06]
[0.13151582]
[0.00010427]
[4.7418968e-07]
[0.00090094]
[1.0648484e-06]
[0.0020209]
[6.20961e-07]
[2.3912426e-06]
[1.5790707e-05]
[3.5459205e-05]
[7.9624326e-05]
[7.9624326e-05]
[1.8260454e-06]


In [17]:
fld = 'data/cameos_dataset/with_glasses/'
for fname in os.listdir(fld):
    fp = osp.join(fld, fname)
    img = open_img(fp)
    crop = crop_img(img) / 255.
    crop = cv2.resize(crop, (120, 120))
    prob = predict(model, crop)
    print(prob)

[0.99947447]
[0.9973552]
[0.9999999]
[0.99982125]
[0.99998426]
[0.9999392]
[0.99998796]
[0.9999392]
[0.9999908]
[0.99999595]
[0.7938518]
[0.99992037]
[0.9999999]
[0.9979791]
[0.99999964]
[0.9999535]
[0.99999595]
[0.99999595]
[0.99999976]
[0.9998957]


# Some error analysis

In [88]:
from PIL import Image

In [43]:
sub_df = test_df[(test_df.dataset == 'celeba') & (test_df.has_glasses == 1)]
for ind, row in tqdm(sub_df.iterrows(), total=len(sub_df)):
    fp = row.filename
    img = open_img(fp)
    crop = crop_img(img) / 255.
    crop = cv2.resize(crop, (120, 120))
    prob = predict(model, crop)
    if prob < 0.85:
        print(prob, ind, fp)
        visualize(img=img, crop=crop)

HBox(children=(FloatProgress(value=0.0, max=2718.0), HTML(value='')))




NameError: name 'open_img' is not defined