In [None]:
# Re-import edited project modules (e.g. `src.loader`) before running code, so
# changes to files on disk are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import sys

sys.path.insert(0, '..')

import numpy as np
import pandas as pd
import torch

from batchflow import Pipeline, B, V, I, M, C, plot
from batchflow.models.torch import TorchModel, VGGBlock

from src.loader import ImagesDataset

In [None]:
# One seed drives every random source used in this notebook:
# - `rng` (numpy Generator) picks the preview images;
# - `seed` is passed to `dataset.split(shuffle=seed)`;
# - the global numpy and torch RNGs are seeded too, since library code presumably
#   draws from them for batch shuffling, weight init and dropout
#   (TODO confirm which RNG batchflow uses for `shuffle=True`).
# `rng` is assigned last so the cell produces no output
# (`torch.manual_seed` returns a Generator object).
seed = 11
np.random.seed(seed)
torch.manual_seed(seed)
rng = np.random.default_rng(seed)

In [None]:
%%time

DATA_PATH = '../images'
SHAPE = (128, 128, 3)

dataset = ImagesDataset(path=DATA_PATH, encode_labels=True, normalize=True, resize_shape=SHAPE)
n_classes = dataset.label_encoder.classes_.size

dataset.split(shuffle=seed)

In [None]:
# Sanity-check loading: draw a few distinct random images and show them with
# their labels. `Generator.choice` samples with replacement by default, which
# could show the same image twice. `min` keeps this valid on tiny datasets.
n_preview = min(8, dataset.size)
preview_indices = rng.choice(dataset.size, size=n_preview, replace=False)
preview_images = list(dataset.images[preview_indices])
preview_labels = list(dataset.labels[preview_indices])
plot(data=preview_images, title=preview_labels, combine='separate')

In [None]:
model_config = {
    # Both pipelines feed channels-first arrays (`to_array(channels='first')`), so
    # the channel axis of SHAPE is moved to the front. Assuming SHAPE is
    # (height, width, channels), this gives (channels, height, width). Reversing
    # SHAPE would give (channels, width, height), which is right only for square images.
    'inputs/shapes': (SHAPE[-1], *SHAPE[:-1]),

    'body': {
        'type': 'encoder',
        'output_type': 'tensor',
        'order': ['block', 'downsampling'],
        'num_stages': 3,
        'blocks': {
            'base_block': VGGBlock,
            'channels': [64, 128, 256],  # presumably one entry per stage (num_stages = 3)
            'depth3': 2,
            'depth1': [0, 0, 1],
        },
    },

    'head': {
        # presumably batchflow layout letters: global pooling -> dropout -> dense
        'layout': 'Vdf',
        'dropout_rate': 0.4,
        'classes': n_classes,
    },

    'common/conv/bias': False,

    # Model training details:
    'init_model_weights': 'xavier',
    'loss': 'ce',
    'optimizer': 'Adam',
    # Named extra outputs; `infer_template` requests 'predicted_proba' from these.
    'output': {'predicted': ['proba', 'labels']},
    'device': 'cpu',
}

In [None]:
def evaluate(iteration, frequency, model, metrics, agg):
    """Periodically score `model` on the test split.

    Evaluation happens when ``(iteration - 1)`` is a multiple of `frequency`,
    i.e. at iterations 1, 1 + frequency, 1 + 2 * frequency, ... (assuming
    `iteration` is 1-based, as the ``- 1`` suggests). On those iterations the
    model is plugged into the global `infer_template`, run once over
    `dataset.test` as a single batch, and the gathered metrics are reduced with
    ``evaluate(metrics=metrics, agg=agg)``.

    Returns:
        ``frequency`` copies of the score on evaluation iterations and an empty
        list otherwise. The training pipeline extends `test_metrics` with this
        result (``mode='e'``), so the metric history grows in step with the
        per-iteration loss history.
    """
    if (iteration - 1) % frequency != 0:
        return []

    pipeline = infer_template << dataset.test << {'model': model}
    pipeline.run(batch_size=dataset.test.size, n_epochs=1, drop_last=False)
    score = pipeline.v('metrics').evaluate(metrics=metrics, agg=agg)
    return [score] * frequency
        
# Training pipeline:
# - convert batches to channels-first float32 arrays;
# - train the model and append each loss to `loss_history`;
# - every `evaluate/frequency` iterations, extend `test_metrics` with the test-split
#   score computed by `evaluate`.
train_template = (
    Pipeline()
    .to_array(channels='first', dtype=np.float32)
    .init_variable(name='loss_history', default=[])
    .init_variable(name='test_metrics', default=[])
    .init_model(name='model', model_class=TorchModel, mode='dynamic', config=model_config)
    .train_model(name='model', inputs=B('images'), targets=B('labels'),
                 outputs='loss', save_to=V('loss_history', mode='a'))
    # `agg` is an aggregation mode, not a metric name, so it reads its own config
    # key instead of reusing 'evaluate/metrics'.
    .call(evaluate, iteration=I(), model=M('model'), frequency=C('evaluate/frequency'),
          metrics=C('evaluate/metrics'), agg=C('evaluate/agg'), save_to=V('test_metrics', mode='e'))
)

# Inference pipeline:
# - predict with the model passed in through the config key 'model';
# - accumulate classification metrics against the batch labels.
# The raw 'predictions' output is treated as logits (`fmt='logits'`).
infer_template = (
    Pipeline()
    .to_array(channels='first', dtype=np.float32)
    .init_variables('proba', 'predictions', 'metrics')
    .import_model('model', C('model'))
    .predict_model(name='model', inputs=B('images'),
                   outputs=['predicted_proba', 'predictions'],
                   save_to=[V('proba'), V('predictions')])
    .gather_metrics('classification', targets=B('labels'), predictions=V('predictions'),
                    fmt='logits', num_classes=n_classes,
                    axis=1, save_to=V('metrics', mode='update'))
)

train_config = {
    'evaluate': {
        'frequency': 50,        # score the test split every 50 iterations
        'metrics': 'accuracy',  # metric(s) passed to `Metrics.evaluate`
        'agg': 'mean',          # aggregation passed to `Metrics.evaluate`
    }
}

train_pipeline = train_template << dataset.train << train_config

In [None]:
BATCH_SIZE = 128
# Redraw the notifier at the same cadence that `evaluate` refreshes the test
# metrics, instead of duplicating the magic number 50.
FREQUENCY = train_config['evaluate']['frequency']
EPOCH_NUM = 200

notifier = {
    'bar': 'n', 'frequency': FREQUENCY,
    'graphs': ['loss_history', 'cpu', 'test_metrics'],
}

# The trailing `;` hides the returned pipeline's repr without rebinding IPython's `_`.
train_pipeline.run(batch_size=BATCH_SIZE, n_epochs=EPOCH_NUM, shuffle=True, notifier=notifier);

In [None]:
# Plot the loss history of the trained model (presumably the per-iteration
# training loss tracked by TorchModel — see batchflow docs).
train_pipeline.model.plot_loss()

In [None]:
# Plug the trained model into the inference template and feed it the test split.
infer_pipeline = (
    infer_template
    << dataset.test
    << {'model': train_pipeline.model}
)

In [None]:
# One pass over the whole test split, processed as a single batch.
infer_pipeline.run(
    batch_size=dataset.test.size,
    n_epochs=1,
    drop_last=False,
    bar='t',
)

In [None]:
# Confusion matrix of the test-split predictions gathered by `infer_pipeline`,
# normalized (presumably per true class — confirm in batchflow docs).
infer_pipeline.v('metrics').plot_confusion_matrix(normalize=True)

In [None]:
# Per-class precision and recall on the test split, shown next to each class's
# share of the test samples.
test_labels = dataset.labels[dataset.test.indices]
# `minlength` keeps one count per class even when a class is absent from the
# test split, so `shares` stays aligned with `label_encoder.classes_`.
# np.unique would silently drop such classes.
# Assumes labels are integer-encoded as 0..n_classes-1 (encode_labels=True) — TODO confirm.
counts = np.bincount(np.ravel(test_labels), minlength=n_classes)
shares = counts / counts.sum()

metrics = ['precision', 'recall']
metrics_dict = infer_pipeline.v('metrics').evaluate(metrics, multiclass=None)
metrics_df = pd.DataFrame({'names': dataset.label_encoder.classes_, 'shares': shares, **metrics_dict})


def formatter(value):
    """Render numbers as rounded percentages; pass strings (class names) through."""
    # The '%' format spec rounds (0.999 -> '100%') and tolerates NaN. In contrast,
    # int(value * 100) truncated and raised ValueError on NaN scores.
    return value if isinstance(value, str) else f"{value:.0%}"


metrics_df.style.background_gradient('RdYlGn', vmin=0, vmax=1, subset=metrics).format(formatter)