In [76]:
from kekas import Keker, DataOwner, DataKek
from kekas.transformations import Transformer, to_torch, normalize
from kekas.metrics import accuracy, bce_accuracy
from kekas.modules import Flatten, AdaptiveConcatPool2d
from kekas.callbacks import Callback, Callbacks, DebuggerCallback
import pretrainedmodels as pm
from albumentations import Compose, JpegCompression, CLAHE, RandomRotate90, Transpose, ShiftScaleRotate, \
        Blur, OpticalDistortion, GridDistortion, HueSaturationValue, Flip, VerticalFlip

In [46]:
import argparse
from itertools import islice
import json
from pathlib import Path
import shutil
import warnings
from typing import Dict

import numpy as np
import pandas as pd
from sklearn.metrics import fbeta_score
from sklearn.exceptions import UndefinedMetricWarning
import torch
from torch import nn, cuda
from torch.optim import Adam
from torchvision import transforms
import tqdm

from imet.models import get_model
from imet.dataset import TrainDataset, TTADataset, get_ids, DATA_ROOT
from imet.transforms import train_transform, test_transform
from imet.utils import (
    write_event, load_model, mean_df,
    ON_KAGGLE, set_models_path_env, seed_everything, 
    _reduce_loss, _make_mask, binarize_prediction, N_CLASSES)
from imet.losses import loss_function
from imet.optimizers import optimizer
import cv2
from PIL import Image
from torch.utils.data import DataLoader

In [47]:
# Training configuration shared by the cells below.
fold = 0           # fold index held out as the validation split
input_size = 288   # image size handed to get_model and to the transform builder
batch_size = 32    # mini-batch size used by the DataLoaders
# NOTE(review): this string is rebound by the later `model = get_model('resnet50', ...)`
# cell, so 'resnet34' is never used to build a network.
model = 'resnet34'

In [48]:
# Load the fold assignment table and split it into training / validation rows.
folds = pd.read_csv('folds.csv')
train_root = DATA_ROOT / 'train'  # directory the training images are read from
# Rows whose 'fold' differs from the held-out fold are used for training ...
train_fold = folds.loc[folds['fold'].ne(fold)]
# ... and rows of the held-out fold are used for validation.
valid_fold = folds.loc[folds['fold'].eq(fold)]

In [49]:
def make_loader(df: pd.DataFrame, image_transform) -> DataLoader:
    """Wrap ``df`` in a ``TrainDataset`` and return a shuffling ``DataLoader``.

    Args:
        df: rows handed to ``TrainDataset``.
        image_transform: transform handed to ``TrainDataset``.

    Returns:
        A ``DataLoader`` over the dataset with ``shuffle=True``, the
        notebook-level ``batch_size`` and 6 worker processes.
    """
    dataset = TrainDataset(train_root, df, image_transform, debug=False)
    loader_options = {
        'shuffle': True,
        'batch_size': batch_size,
        'num_workers': 6,
    }
    return DataLoader(dataset, **loader_options)

In [50]:
# Build the network. This rebinds `model`, which the config cell set to a string.
model = get_model(
    'resnet50',
    num_classes=N_CLASSES,
    pretrained=True,
    input_size=input_size,
)

In [51]:
# Materialise two parameter lists: every parameter of the model, and only
# those belonging to `model._classifier`.
all_params = [*model.parameters()]
fresh_params = [*model._classifier.parameters()]

In [67]:
def reader_fn(i, row):
    """Read one training sample (image + multi-hot labels) for ``DataKek``.

    Args:
        i: first argument supplied by the caller (presumably the row index —
            confirm against DataKek); it is not used here.
        row: record providing ``row["id"]`` (image file stem under
            ``train_root``) and ``row["attribute_ids"]`` (whitespace-separated
            class indices).

    Returns:
        A dict with ``"image"`` (a PIL image built from the array returned by
        ``cv2.imread`` with its channel axis reversed) and ``"labels"`` (a
        tensor of length ``N_CLASSES`` holding 1 at every listed class index
        and 0 elsewhere).

    Raises:
        FileNotFoundError: if OpenCV could not read the image file.
    """
    image_path = train_root / f'{row["id"]}.png'
    pixels = cv2.imread(str(image_path))
    if pixels is None:
        # cv2.imread reports a missing/unreadable file by returning None rather
        # than raising; fail here with the offending path instead of an opaque
        # "'NoneType' object is not subscriptable" further down.
        raise FileNotFoundError(f'cannot read image: {image_path}')
    # Reverse the channel axis (OpenCV's channel order -> the order PIL expects).
    image = Image.fromarray(pixels[:, :, ::-1])

    labels = torch.zeros(N_CLASSES)
    for cls in row["attribute_ids"].split():
        labels[int(cls)] = 1
    return {"image": image, "labels": labels}


def get_transforms(dataset_key, size, p):
    """Build the training and validation transform pipelines for ``DataKek``.

    Args:
        dataset_key: key of the image in the dict returned by ``reader_fn``;
            every step below is applied to that entry only.
        size: image size forwarded to ``train_transform``.
        p: accepted for backward compatibility; currently unused.

    Returns:
        ``(train_tfms, val_tfms)``: the training pipeline (augmentations,
        array conversion, tensor conversion + normalisation) and the
        validation pipeline (the same without the augmentations).
    """
    # A Transformer is needed to apply a transformation to one element of the
    # DataKek sample dict (here: the entry stored under `dataset_key`).
    # Use the `size` argument rather than the notebook-level `input_size`
    # global, so the function honours what the caller asked for.
    augment = Transformer(dataset_key, train_transform(size))

    to_array = Transformer(dataset_key, np.array)

    normalise = transforms.Compose([
        Transformer(dataset_key, to_torch()),
        Transformer(dataset_key, normalize()),
    ])

    train_tfms = transforms.Compose([augment, to_array, normalise])
    # The validation set is not augmented.
    val_tfms = transforms.Compose([to_array, normalise])

    return train_tfms, val_tfms

In [68]:
# Build both pipelines for the "image" entry of each sample ...
train_tfms, val_tfms = get_transforms(dataset_key="image", size=input_size, p=0.5)

# ... and wrap each fold in a DataKek that reads samples through reader_fn.
val_dk = DataKek(df=valid_fold, reader_fn=reader_fn, transforms=val_tfms)
train_dk = DataKek(df=train_fold, reader_fn=reader_fn, transforms=train_tfms)

In [69]:
# Validation loader: fixed order, keeps the final partial batch.
val_dl = DataLoader(
    val_dk,
    batch_size=batch_size,
    num_workers=6,
    shuffle=False,
)
# Training loader: reshuffled, and the final incomplete batch is dropped.
train_dl = DataLoader(
    train_dk,
    batch_size=batch_size,
    num_workers=6,
    shuffle=True,
    drop_last=True,
)

In [70]:
# Loss applied to the raw model outputs (logits) against the multi-hot labels.
criterion = nn.BCEWithLogitsLoss()
# Bundle the train and validation loaders; the third slot is left as None.
dataowner = DataOwner(train_dl, val_dl, None)

In [71]:
def step_fn(model: torch.nn.Module,
            batch: Dict[str, torch.Tensor]) -> torch.Tensor:
    """Run the model's forward pass on one batch.

    Args:
        model: the pytorch module to pass the input through
        batch: one batch from the DataLoader; a mapping from which only the
            ``"image"`` entry is read here

    Returns:
        Whatever ``model`` returns for ``batch["image"]``
    """
    # Any custom logic could go here; this step simply feeds the "image"
    # entry of the batch to the model.
    return model(batch["image"])

In [77]:
keker = Keker(
    model=model,
    dataowner=dataowner,
    criterion=criterion,
    step_fn=step_fn,                  # the previously defined step function
    target_key="labels",              # the key defined in reader_fn for DataKek
    metrics={"acc": bce_accuracy},    # optional; metrics can be omitted entirely
    opt=torch.optim.Adam,             # optimizer class; if not specified,
                                      # SGD is used by default
    opt_params={"weight_decay": 1e-5},
)

In [None]:
# Call Keker's `kek_lr` with a final learning rate of 0.1 and logdir "logdir"
# (presumably a learning-rate range test — confirm against the kekas docs).
keker.kek_lr(
    final_lr=0.1,
    logdir="logdir",
)

Epoch 1/1:  10% 275/2731 [01:29<13:23,  3.06it/s, loss=0.0207]