In [1]:
from kekas import Keker, DataOwner, DataKek
from kekas.transformations import Transformer, to_torch, normalize
from kekas.metrics import accuracy, bce_accuracy
from kekas.modules import Flatten, AdaptiveConcatPool2d
from kekas.callbacks import Callback, Callbacks, DebuggerCallback
import pretrainedmodels as pm
from albumentations import Compose, JpegCompression, CLAHE, RandomRotate90, Transpose, ShiftScaleRotate, \
        Blur, OpticalDistortion, GridDistortion, HueSaturationValue, Flip, VerticalFlip

In [6]:
import argparse
from itertools import islice
import json
from pathlib import Path
import shutil
import warnings
from typing import Dict

import numpy as np
import pandas as pd
from sklearn.metrics import fbeta_score
from sklearn.exceptions import UndefinedMetricWarning
import torch
from torch import nn, cuda
from torch.optim import Adam
from torchvision import transforms
import tqdm

from imet.models import get_model
from imet.dataset import TrainDataset, TTADataset, get_ids, DATA_ROOT
from imet.transforms import train_transform, test_transform
from imet.utils import (
    write_event, load_model, mean_df,
    ON_KAGGLE, set_models_path_env, seed_everything, 
    _reduce_loss, _make_mask, binarize_prediction, N_CLASSES)
from imet.losses import loss_function
from imet.optimizers import optimizer
import cv2
from PIL import Image
from torch.utils.data import DataLoader
import os
from imet.losses import FBeta

%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [3]:
batch_size = 32
input_size = 224
fold = 0
model = 'resnet34'

In [4]:
folds = pd.read_csv('folds.csv')
train_root = os.path.join(DATA_ROOT, 'train')
train_fold = folds[folds['fold'] != fold]
valid_fold = folds[folds['fold'] == fold]

In [5]:
def reader_fn(i, row):
    image = cv2.imread(os.path.join(train_root, f'{row["id"]}.png'))
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = Image.fromarray(image)
    labels = torch.zeros(N_CLASSES)
    for cls in row["attribute_ids"].split():
        labels[int(cls)] = 1
    return {"image": image, "labels": labels}


def get_transforms(dataset_key, size, p):
    # we need to use a Transformer class to apply transformations to DataKeks elements
    # dataset_key is an image key in dict returned by reader_fn
    
    TRAIN_AUGS = Transformer(dataset_key, train_transform(input_size))
    VAL_AUGS = Transformer(dataset_key, test_transform(input_size))
                                                      
    TO_ARRAY = Transformer(dataset_key, lambda x: np.array(x))

    NRM_TFMS = transforms.Compose([
        Transformer(dataset_key, to_torch()),
        Transformer(dataset_key, normalize())
    ])
    
    train_tfms = transforms.Compose([TRAIN_AUGS, TO_ARRAY, NRM_TFMS])
    val_tfms = transforms.Compose([VAL_AUGS, TO_ARRAY, NRM_TFMS])  # because we don't want to augment val set yet
    
    return train_tfms, val_tfms

In [7]:
train_tfms, val_tfms = get_transforms("image", input_size, 0.5)

train_dk = DataKek(df=train_fold, reader_fn=reader_fn, transforms=train_tfms)
val_dk = DataKek(df=valid_fold, reader_fn=reader_fn, transforms=val_tfms)

train_dl = DataLoader(train_dk, batch_size=batch_size, num_workers=6, shuffle=True, drop_last=True)
val_dl = DataLoader(val_dk, batch_size=batch_size, num_workers=6, shuffle=False)

In [8]:
dataowner = DataOwner(train_dl, val_dl, None)
criterion = FBeta(reduction="mean")

In [9]:
def step_fn(model: torch.nn.Module,
            batch: torch.Tensor) -> torch.Tensor:
    """Determine what your model will do with your data.

    Args:
        model: the pytorch module to pass input in
        batch: the batch of data from the DataLoader

    Returns:
        The models forward pass results
    """
    
    # you could define here whatever logic you want
    inp = batch["image"]  # here we get an "image" from our dataset
    return model(inp)

In [10]:
def fbeta(target: torch.Tensor,
          preds: torch.Tensor,
          thresh: bool = 0.1,
          beta: float = 2) -> float:
    target = target.cpu().detach().numpy()
    preds = (torch.sigmoid(preds).cpu().detach().numpy() > thresh).astype(int)
    return fbeta_score(target, preds, beta=beta, average='samples')

In [11]:
model = get_model(model, num_classes=N_CLASSES, pretrained=True, input_size=input_size)

In [12]:
keker = Keker(model=model,
              dataowner=dataowner,
              criterion=criterion,
              step_fn=step_fn,                 
              target_key="labels",               
              metrics={"acc": bce_accuracy, "fbeta": fbeta},      
              opt=torch.optim.SGD).to_fp16()

In [13]:
keker.freeze()
keker.kek_one_cycle(max_lr=1e-3, cycle_len=2, momentum_range=(0.95, 0.85), div_factor=25)

Epoch 1/2: 100% 2731/2731 [05:42<00:00, 10.88it/s, loss=1.1859, val_loss=1.1868, acc=0.0002, fbeta=0.0159]
Epoch 2/2: 100% 2731/2731 [04:58<00:00, 11.53it/s, loss=1.1067, val_loss=1.1084, acc=0.0000, fbeta=0.0240]


In [14]:
keker.unfreeze()

In [17]:
keker.kek_lr(final_lr=3, logdir="logdir/resnet50/5")

Epoch 1/1: 100% 2731/2731 [05:48<00:00,  7.83it/s, loss=1.0092]
End of LRFinder



In [18]:
keker.plot_kek_lr("logdir/resnet50/5")

In [19]:
keker.kek_one_cycle(max_lr=0.05, cycle_len=30, momentum_range=(0.95, 0.85), div_factor=25)

Epoch 1/30: 100% 2731/2731 [06:16<00:00,  8.18it/s, loss=1.0138, val_loss=1.0135, acc=0.0007, fbeta=0.2430]
Epoch 2/30: 100% 2731/2731 [06:12<00:00,  8.39it/s, loss=1.0073, val_loss=1.0075, acc=0.0039, fbeta=0.2874]
Epoch 3/30: 100% 2731/2731 [06:10<00:00,  8.34it/s, loss=1.0039, val_loss=1.0040, acc=0.0080, fbeta=0.3183]
Epoch 4/30: 100% 2731/2731 [06:11<00:00,  8.30it/s, loss=1.0019, val_loss=1.0015, acc=0.0133, fbeta=0.3358]
Epoch 5/30: 100% 2731/2731 [06:11<00:00,  8.29it/s, loss=0.9996, val_loss=0.9993, acc=0.0163, fbeta=0.3520]
Epoch 6/30: 100% 2731/2731 [06:11<00:00,  8.31it/s, loss=0.9976, val_loss=0.9976, acc=0.0238, fbeta=0.3679]
Epoch 7/30: 100% 2731/2731 [06:12<00:00,  8.24it/s, loss=0.9956, val_loss=0.9960, acc=0.0219, fbeta=0.3825]
Epoch 8/30: 100% 2731/2731 [06:12<00:00,  8.35it/s, loss=0.9948, val_loss=0.9948, acc=0.0280, fbeta=0.3951]
Epoch 9/30: 100% 2731/2731 [06:11<00:00,  8.31it/s, loss=0.9935, val_loss=0.9938, acc=0.0324, fbeta=0.4014]
Epoch 10/30: 100% 2731/2731 


F-score is ill-defined and being set to 0.0 in samples with no predicted labels.



Epoch 15/30: 100% 2731/2731 [06:12<00:00,  8.33it/s, loss=0.9876, val_loss=0.9901, acc=0.0453, fbeta=0.4332]
Epoch 16/30: 100% 2731/2731 [06:09<00:00,  8.32it/s, loss=0.9896, val_loss=0.9896, acc=0.0445, fbeta=0.4381]
Epoch 17/30: 100% 2731/2731 [06:08<00:00,  8.32it/s, loss=0.9882, val_loss=0.9894, acc=0.0434, fbeta=0.4389]
Epoch 18/30: 100% 2731/2731 [06:09<00:00,  8.41it/s, loss=0.9885, val_loss=0.9890, acc=0.0495, fbeta=0.4443]
Epoch 19/30: 100% 2731/2731 [06:09<00:00,  8.39it/s, loss=0.9892, val_loss=0.9887, acc=0.0475, fbeta=0.4471]
Epoch 20/30: 100% 2731/2731 [06:09<00:00,  8.34it/s, loss=0.9870, val_loss=0.9885, acc=0.0496, fbeta=0.4471]
Epoch 21/30:  70% 1906/2731 [03:51<01:40,  8.23it/s, loss=0.9869]


error: Traceback (most recent call last):
  File "/home/allerria/anaconda3/envs/images/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 138, in _worker_loop
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/home/allerria/anaconda3/envs/images/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 138, in <listcomp>
    samples = collate_fn([dataset[i] for i in batch_indices])
  File "/home/allerria/anaconda3/envs/images/lib/python3.6/site-packages/kekas/data.py", line 24, in __getitem__
    datum = self.reader_fn(ind, data_dict)
  File "<ipython-input-5-bcd74728a171>", line 3, in reader_fn
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
cv2.error: OpenCV(4.1.0) /io/opencv/modules/imgproc/src/color.cpp:182: error: (-215:Assertion failed) !_src.empty() in function 'cvtColor'



In [None]:
keker.kek_lr(final_lr=4, logdir="logdir/resnet50/4")
keker.plot_kek_lr("logdir/resnet50/4")

In [None]:
keker.kek_one_cycle(max_lr=0.5, cycle_len=10, momentum_range=(0.95, 0.85), div_factor=25)