In [None]:
from albumentations.pytorch import ToTensorV2
import numpy as np
from pytorch_lightning.callbacks import EarlyStopping
import logging
from sklearn.model_selection import StratifiedKFold
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from copy import copy
from argparse import Namespace
from tqdm import tqdm
import torch
import os
from sklearn.model_selection import train_test_split
import pandas as pd
from pytorch_lightning import Trainer

In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:

import logging
import sys

# Route log records to the notebook output. The root logger defaults to
# WARNING, which would silently drop the logging.info(...) calls made by the
# pipeline functions in this notebook, so the level is lowered explicitly.
_root_logger = logging.getLogger()
_root_logger.setLevel(logging.INFO)
# Attach the plain stream handler only once so that re-running this cell does
# not duplicate every log line.
if not any(type(_handler) is logging.StreamHandler for _handler in _root_logger.handlers):
    _root_logger.addHandler(logging.StreamHandler())


In [None]:

# Copy the EfficientNet-B0 checkpoint from the attached Kaggle dataset into
# the torch hub checkpoint cache directory.
# NOTE(review): presumably this lets timm.create_model(..., pretrained=True)
# load the weights without network access - confirm the expected file name.
!ls /kaggle/input/timm-pretrained-efficientnet
!mkdir -p /root/.cache/torch/hub/checkpoints/
!cp /kaggle/input/timm-pretrained-efficientnet/efficientnet/efficientnet_b0_ra-3dd342df.pth /root/.cache/torch/hub/checkpoints/efficientnet_b0_ra-3dd342df.pth


In [None]:

# Install timm and lmdb from package files shipped as Kaggle input datasets
# (local paths, so no package index is contacted for these two artifacts).
!pip install /kaggle/input/timm-package/timm-0.1.26-py3-none-any.whl
!pip install /kaggle/input/lmdb-python-package/lmdb-1.0.0/dist/lmdb-1.0.0.tar


# Functions

In [None]:
# file transforms.py


import albumentations as A
from albumentations.pytorch import ToTensorV2
from torchvision import transforms

# Minimal pipeline used for BYOL pre-training: converts an 8-bit HWC image
# into a float CHW tensor scaled to [0, 1].
# max_value=255.0 is required here: ToFloat divides by max_value, so the
# previous max_value=1.0 only changed the dtype and left pixel values in
# 0..255, while the kornia augmentations applied afterwards (ColorJitter,
# Normalize with ImageNet mean/std) operate on [0, 1] inputs.
dummy_transforms = A.Compose([
    A.ToFloat(max_value=255.0),
    ToTensorV2(),
])

# Applied once while the raw images are written to LMDB: every stored image
# is resized to 400x400.
lmdb_transforms = A.Compose([
    A.Resize(400, 400),
])


def get_train_transforms():
    """Build the augmentation pipeline used for classifier training.

    Steps, in order: small hue/saturation/value jitter, cast to float,
    random resized crop to 256x256, random horizontal and vertical flips,
    ImageNet mean/std normalisation and conversion to a torch tensor.

    Returns:
        An ``A.Compose`` object; call it as ``pipeline(image=img)['image']``.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    steps = [
        A.HueSaturationValue(hue_shift_limit=5, sat_shift_limit=5, val_shift_limit=5, p=1),
        # max_value=1.0 means the values are divided by 1: only the dtype changes.
        A.ToFloat(max_value=1.0),
        A.RandomResizedCrop(256, 256, scale=(0.3, 0.9)),
        A.HorizontalFlip(p=0.5),
        A.VerticalFlip(p=0.5),
        A.Normalize(mean=imagenet_mean, std=imagenet_std),
        ToTensorV2(),
    ]
    return A.Compose(steps)


def get_test_transforms():
    """Build the deterministic pipeline used for validation and inference.

    Steps, in order: cast to float, resize to 400x400, centre crop to
    256x256, ImageNet mean/std normalisation and conversion to a torch
    tensor.

    Returns:
        An ``A.Compose`` object; call it as ``pipeline(image=img)['image']``.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]

    steps = [
        # max_value=1.0 means the values are divided by 1: only the dtype changes.
        A.ToFloat(max_value=1.0),
        A.Resize(400, 400),
        A.CenterCrop(256, 256),
        A.Normalize(mean=imagenet_mean, std=imagenet_std),
        ToTensorV2(),
    ]
    return A.Compose(steps)



In [None]:
# file utils.py

import numpy as np
import pandas as pd
import seaborn as sns
import os
from skimage import io
from torch.utils.data import Dataset
import torch
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid


class Unnormalize:
    """Callable that reverts a per-channel mean/std normalisation in place."""

    def __init__(self, mean, std):
        # One entry per channel, in channel order.
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        """Undo ``(x - mean) / std`` channel by channel.

        Args:
            tensor (Tensor): image tensor iterated along its first dimension
                (one slice per channel). It is modified IN PLACE.
        Returns:
            Tensor: the same tensor object, now holding ``x * std + mean``.
        """
        channel_stats = zip(self.mean, self.std)
        for channel, (channel_mean, channel_std) in zip(tensor, channel_stats):
            # Inverse of the normalisation step t.sub_(m).div_(s).
            channel.mul_(channel_std)
            channel.add_(channel_mean)
        return tensor


def make_confusion_matrix(cf,
                          group_names=None,
                          categories='auto',
                          count=True,
                          percent=True,
                          cbar=True,
                          xyticks=True,
                          xyplotlabels=True,
                          sum_stats=True,
                          figsize=None,
                          cmap='Blues',
                          title=None):
    """Plot a confusion matrix as an annotated seaborn heatmap.

    Args:
        cf: 2-D numpy array holding the confusion matrix (rows are treated as
            true labels, columns as predicted labels).
        group_names: optional list with one name per cell (row-major order);
            ignored unless its length equals ``cf.size``.
        categories: tick labels forwarded to ``sns.heatmap``.
        count: show the raw count in each cell.
        percent: show each cell as a share of the matrix total.
        cbar: draw the colour bar.
        xyticks: show tick labels; when falsy ``categories`` is ignored.
        xyplotlabels: label the axes 'True label' / 'Predicted label'.
        sum_stats: append accuracy (plus precision/recall/F1 for a 2x2
            matrix) below the x axis.
        figsize: figure size; defaults to matplotlib's ``figure.figsize``.
        cmap: colormap name.
        title: optional figure title.
    """
    def _safe_ratio(numerator, denominator):
        # Report 0.0 instead of nan (and a numpy RuntimeWarning) when a class
        # has no predictions / no true samples.
        return numerator / denominator if denominator else 0.0

    # Text shown inside each square: optional name, count and percentage.
    cell_values = cf.flatten()
    blanks = [''] * cf.size

    if group_names and len(group_names) == cf.size:
        group_labels = ["{}\n".format(value) for value in group_names]
    else:
        group_labels = blanks

    group_counts = ["{0:0.0f}\n".format(value) for value in cell_values] if count else blanks
    group_percentages = (["{0:.2%}".format(value) for value in cell_values / np.sum(cf)]
                         if percent else blanks)

    box_labels = [f"{v1}{v2}{v3}".strip()
                  for v1, v2, v3 in zip(group_labels, group_counts, group_percentages)]
    box_labels = np.asarray(box_labels).reshape(cf.shape[0], cf.shape[1])

    # Summary statistics rendered under the x axis.
    stats_text = ""
    if sum_stats:
        # Accuracy is the sum of the diagonal divided by the total observations.
        accuracy = np.trace(cf) / float(np.sum(cf))

        if len(cf) == 2:
            # Extra metrics for binary matrices; class 1 is the positive class.
            true_positives = cf[1, 1]
            precision = _safe_ratio(true_positives, cf[:, 1].sum())
            recall = _safe_ratio(true_positives, cf[1, :].sum())
            f1_score = _safe_ratio(2 * precision * recall, precision + recall)
            stats_text = "\n\nAccuracy={:0.3f}\nPrecision={:0.3f}\nRecall={:0.3f}\nF1 Score={:0.3f}".format(
                accuracy, precision, recall, f1_score)
        else:
            stats_text = "\n\nAccuracy={:0.3f}".format(accuracy)

    # Identity check: `== None` would compare element-wise for array-likes.
    if figsize is None:
        figsize = plt.rcParams.get('figure.figsize')

    if not xyticks:
        # Do not show categories if tick labels are disabled.
        categories = False

    plt.figure(figsize=figsize)
    sns.heatmap(cf, annot=box_labels, fmt="", cmap=cmap, cbar=cbar, xticklabels=categories, yticklabels=categories)

    if xyplotlabels:
        plt.ylabel('True label')
        plt.xlabel('Predicted label' + stats_text)
    else:
        plt.xlabel(stats_text)

    if title:
        plt.title(title)


def plot_image(img, label=None, ax=None):
    """Draw a channel-first image on a matplotlib axis.

    Args:
        img: image with layout (C, H, W); anything ``np.array`` can convert.
        label: optional title. Integer class ids (Python or numpy integers)
            are translated to the disease name; any other value, including an
            unknown id, is shown as-is.
        ax: axis to draw on; the current axis (``plt.gca()``) when omitted.
    """
    img = torch.Tensor(np.array(img))
    label_num_to_disease_map = {0: 'Cassava Bacterial Blight (CBB)',
                                1: 'Cassava Brown Streak Disease (CBSD)',
                                2: 'Cassava Green Mottle (CGM)',
                                3: 'Cassava Mosaic Disease (CMD)',
                                4: 'Healthy'}

    if ax is None:
        ax = plt.gca()
    # (C, H, W) -> (H, W, C), the layout imshow expects. The former
    # permute(2, 1, 0) yielded (W, H, C), i.e. a transposed picture.
    ax.imshow(img.permute(1, 2, 0))
    ax.axis('off')
    if label is not None:
        # Labels read from numpy arrays are np.integer, not int.
        if isinstance(label, (int, np.integer)):
            label = label_num_to_disease_map.get(int(label), label)
        ax.set_title(f'{label}')


def plot_label_examples(dataset, targets, target_label):
    """Show six randomly drawn samples of one class in a 2x3 image grid.

    Args:
        dataset: indexable returning ``(image, label)`` pairs.
        targets: array of labels aligned with ``dataset``; compared to
            ``target_label`` element-wise.
        target_label: class whose examples are displayed.
    """
    matching_indices = np.where(targets == target_label)[0]
    # np.random.choice samples with replacement by default, so an index may repeat.
    chosen = np.random.choice(matching_indices, 6)

    figure = plt.figure(figsize=(20, 10))
    # 111 is the subplot position; 2 rows x 3 columns, 0.1 inch between axes.
    axes_grid = ImageGrid(figure, 111, nrows_ncols=(2, 3), axes_pad=0.1)

    for axis, dataset_index in zip(axes_grid, chosen):
        image, image_label = dataset[dataset_index]
        assert image_label == target_label
        plot_image(image, ax=axis)

    plt.suptitle(f'Label {target_label}')
    plt.show()


class DatasetFromSubset(Dataset):
    """Wrap an indexable of ``(x, y)`` pairs and apply a transform to ``x``.

    The transform is called as ``transform(image=x)`` and its ``'image'``
    entry replaces ``x``; a falsy transform leaves the sample untouched.
    """

    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, index):
        sample, target = self.subset[index]
        if not self.transform:
            return sample, target
        transformed = self.transform(image=sample)
        return transformed['image'], target


class CassavaDataset(Dataset):
    """Image dataset that reads files from ``root`` by image id.

    Args:
        root: directory containing the image files.
        image_ids: indexable of file names relative to ``root``.
        labels: indexable of labels aligned with ``image_ids``; also exposed
            as ``targets``.
        transform: optional callable invoked as ``transform(image=img)``;
            its ``'image'`` entry is returned.
    """

    def __init__(self, root, image_ids, labels, transform=None):
        super().__init__()
        self.root = root
        self.image_ids = image_ids
        # `targets` is an alias of `labels` (same object).
        self.labels = self.targets = labels
        self.transform = transform

    def __len__(self):
        return len(self.image_ids)

    def __getitem__(self, idx):
        target = self.labels[idx]
        image_path = os.path.join(self.root, self.image_ids[idx])
        image = io.imread(image_path)

        if self.transform:
            image = self.transform(image=image)['image']

        return image, target


In [None]:
# file models/model.py

from argparse import Namespace

import torch
from pytorch_lightning.metrics.functional import accuracy
from torch import nn
import timm
import pytorch_lightning as pl
import torch.nn.functional as F


class LeafDoctorModel(pl.LightningModule):
    """EfficientNet-B0 classifier with five output classes as a LightningModule.

    Args:
        hparams: namespace of hyper-parameters. ``configure_optimizers`` reads
            ``lr`` (falling back to ``learning_rate``), ``weight_decay`` and
            ``reduce_lr_on_pleteau_patience`` from it. An empty ``Namespace``
            is used when omitted.
    """

    def __init__(self, hparams = None):
        super().__init__()
        self.hparams = hparams or Namespace()

        self.trunk = timm.create_model('efficientnet_b0', pretrained=True, num_classes=5)

    def forward(self, x):
        """Return the raw class logits produced by the trunk."""
        return self.trunk(x)

    def predict_proba(self, x):
        """Return softmax probabilities computed over dim 1 of the logits."""
        probabilities = nn.functional.softmax(self.forward(x), dim=1)
        return probabilities

    def predict(self, x):
        """Return the index of the largest logit along dim 1."""
        return torch.max(self.forward(x), 1)[1]

    def configure_optimizers(self):
        """Create AdamW plus a ReduceLROnPlateau scheduler driven by ``val_acc``."""
        # getattr keeps the `lr or learning_rate` fallback usable when only one
        # of the two attributes is present (plain attribute access would raise).
        learning_rate = (getattr(self.hparams, 'lr', None)
                         or getattr(self.hparams, 'learning_rate', None))
        optimizer = torch.optim.AdamW(self.parameters(),
                                      lr=learning_rate,
                                      weight_decay=self.hparams.weight_decay)
        # The monitored quantity is an accuracy, so a plateau means "stopped
        # increasing". The scheduler's default mode='min' would instead cut the
        # learning rate while accuracy is still improving.
        lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                                  mode='max',
                                                                  patience=self.hparams.reduce_lr_on_pleteau_patience,
                                                                  verbose=True)
        return {
            'optimizer': optimizer,
            'lr_scheduler': lr_scheduler,
            'monitor': 'val_acc',
            'interval': 'epoch',
            'frequency': 1
        }

    def training_step(self, batch, batch_idx):
        """Compute cross-entropy loss; log epoch-level ``train_acc`` / ``train_loss``."""
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        acc = accuracy(y_hat, y)
        self.log("train_acc", acc, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log("train_loss", loss, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        return loss

    @torch.no_grad()
    def validation_step(self, batch, batch_idx):
        """Log ``val_acc`` and ``val_loss`` for one validation batch."""
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)
        acc = accuracy(y_hat, y)
        self.log("val_acc", acc, prog_bar=True, logger=True)
        self.log("val_loss", loss, prog_bar=True, logger=True)


In [None]:
# file models/byol.py

import numpy as np
from argparse import Namespace
from copy import deepcopy, copy
from itertools import chain
from typing import Dict, List
import pytorch_lightning as pl
from torch import optim
import torch.nn.functional as f
import random
from typing import Callable, Tuple, Union
from kornia import augmentation as aug
from kornia import filters
from kornia.geometry import transform as tf
import torch
from torch import nn, Tensor


def normalized_mse(x: Tensor, y: Tensor) -> Tensor:
    """Return ``2 - 2 * cos(x, y)`` computed along the last dimension.

    Both inputs are L2-normalised over ``dim=-1`` first, so the result equals
    the squared Euclidean distance between the unit vectors.
    """
    x_unit = f.normalize(x, dim=-1)
    y_unit = f.normalize(y, dim=-1)
    cosine_similarity = (x_unit * y_unit).sum(dim=-1)
    return 2 - 2 * cosine_similarity


class RandomApply(nn.Module):
    """Apply ``fn`` to the input with probability ``p``, else pass it through.

    One ``random.random()`` draw is made per call, so the decision applies to
    the whole input rather than to individual samples.
    """

    def __init__(self, fn: Callable, p: float):
        super().__init__()
        self.fn = fn
        self.p = p

    def forward(self, x: Tensor) -> Tensor:
        skip = random.random() > self.p
        if skip:
            return x
        return self.fn(x)


def default_augmentation(image_size: Tuple[int, int] = (224, 224)) -> nn.Module:
    """Build the default tensor-level augmentation pipeline used by BYOL.

    Steps, in order: resize to ``image_size``, colour jitter (applied with
    probability 0.8), random vertical and horizontal flips, 3x3 Gaussian blur
    (probability 0.1), random resized crop and ImageNet mean/std
    normalisation.
    """
    imagenet_mean = torch.tensor([0.485, 0.456, 0.406])
    imagenet_std = torch.tensor([0.229, 0.224, 0.225])

    stages = [
        tf.Resize(size=image_size),
        RandomApply(aug.ColorJitter(0.2, 0.2, 0.2, 0.2), p=0.8),
        aug.RandomVerticalFlip(),
        aug.RandomHorizontalFlip(),
        RandomApply(filters.GaussianBlur2d((3, 3), (1.0, 1.0)), p=0.1),
        aug.RandomResizedCrop(size=image_size, scale=(0.3, 0.7)),
        aug.Normalize(mean=imagenet_mean, std=imagenet_std),
    ]
    return nn.Sequential(*stages)


def mlp(dim: int, projection_size: int = 256, hidden_size: int = 4096) -> nn.Module:
    """Return a Linear -> BatchNorm1d -> ReLU -> Linear projection head.

    Args:
        dim: number of input features.
        projection_size: number of output features.
        hidden_size: width of the hidden layer.
    """
    layers = [
        nn.Linear(dim, hidden_size),
        nn.BatchNorm1d(hidden_size),
        nn.ReLU(inplace=True),
        nn.Linear(hidden_size, projection_size),
    ]
    return nn.Sequential(*layers)


class EncoderWrapper(nn.Module):
    """Expose a projection of one intermediate layer of ``model``.

    A forward hook on the chosen layer flattens its output, passes it through
    a lazily created ``mlp`` projector and stores the result, which
    ``forward`` then returns.

    Args:
        model: network whose intermediate activation is projected.
        projection_size: output width of the projector.
        hidden_size: hidden width of the projector.
        layer: module name (str, looked up in ``named_modules``) or index
            (int, into ``children``) of the layer to hook.
    """

    def __init__(
        self,
        model: nn.Module,
        projection_size: int = 256,
        hidden_size: int = 4096,
        layer: Union[str, int] = -2,
    ):
        super().__init__()
        self.model = model
        self.projection_size = projection_size
        self.hidden_size = hidden_size
        self.layer = layer

        # The projector needs the hooked layer's feature width, which is only
        # known after the first forward pass.
        self._projector = None
        self._projector_dim = None
        self._encoded = torch.empty(0)
        self._register_hook()

    @property
    def projector(self):
        """Projection head, built on first access from the recorded feature width."""
        if self._projector is not None:
            return self._projector
        self._projector = mlp(
            self._projector_dim, self.projection_size, self.hidden_size
        )
        return self._projector

    def _hook(self, _, __, output):
        """Forward hook: flatten the layer output, project it, store the result."""
        flat = output.flatten(start_dim=1)
        if self._projector_dim is None:
            self._projector_dim = flat.shape[-1]
        self._encoded = self.projector(flat)

    def _register_hook(self):
        """Attach ``_hook`` to the layer selected by ``self.layer``."""
        if isinstance(self.layer, str):
            hooked_layer = dict(self.model.named_modules())[self.layer]
        else:
            hooked_layer = list(self.model.children())[self.layer]

        hooked_layer.register_forward_hook(self._hook)

    def forward(self, x: Tensor) -> Tensor:
        # The model output itself is discarded; the hook captures the projection.
        self.model(x)
        return self._encoded


class BYOL(pl.LightningModule):
    """Bootstrap-Your-Own-Latent self-supervised pre-training module.

    An online encoder plus predictor is trained to match the projections of a
    target encoder on a differently augmented view of the same batch. The
    target is an exponential moving average of the online encoder.

    Args:
        model: backbone to pre-train (wrapped in ``EncoderWrapper``).
        image_size: (H, W) used by the default augmentation and by the dummy
            forward pass that sizes the projector.
        hidden_layer: layer of ``model`` whose output is projected.
        projection_size: width of the projections and of the predictor.
        hidden_size: hidden width of the projector MLP.
        augment_fn: callable producing an augmented view of a batch; defaults
            to ``default_augmentation(image_size)``.
        beta: EMA decay used by ``update_target``.
        hparams: namespace providing ``lr``, ``weight_decay`` and
            ``reduce_lr_on_pleteau_patience``.
    """

    def __init__(
        self,
        model: nn.Module,
        image_size: Tuple[int, int] = (256, 256),
        hidden_layer: Union[str, int] = -2,
        projection_size: int = 256,
        hidden_size: int = 4096,
        augment_fn: Callable = None,
        beta: float = 0.99,
        hparams = None,
    ):
        super().__init__()
        self.augment = default_augmentation(image_size) if augment_fn is None else augment_fn
        self.beta = beta
        self.encoder = EncoderWrapper(
            model, projection_size, hidden_size, layer=hidden_layer
        )
        # nn.Linear's third positional argument is `bias`, so the former
        # nn.Linear(projection_size, projection_size, hidden_size) was just a
        # biased linear layer; spelled out here without the misleading argument.
        # NOTE(review): the BYOL paper uses an MLP predictor - consider
        # mlp(projection_size, projection_size, hidden_size) if that was intended.
        self.predictor = nn.Linear(projection_size, projection_size)
        self.hparams = hparams or Namespace()
        self._target = None

        # Dummy forward pass so the encoder's lazily built projector exists
        # before the optimizer collects parameters.
        self.encoder(torch.zeros(2, 3, *image_size))

    def forward(self, x: Tensor) -> Tensor:
        return self.predictor(self.encoder(x))

    @property
    def target(self):
        """Target encoder, created on first access as a deep copy of the online encoder."""
        if self._target is None:
            self._target = deepcopy(self.encoder)
        return self._target

    def update_target(self):
        """Move the target parameters towards the online encoder (EMA with decay ``beta``)."""
        for p, pt in zip(self.encoder.parameters(), self.target.parameters()):
            pt.data = self.beta * pt.data + (1 - self.beta) * p.data

    def on_before_zero_grad(self, *_) -> None:
        # Lightning hook invoked after an optimizer step, before gradients are
        # cleared. Without it update_target() was never called and the target
        # network stayed frozen at its initial copy.
        self.update_target()

    def configure_optimizers(self):
        """Create AdamW plus a ReduceLROnPlateau scheduler driven by ``train_loss``."""
        optimizer = optim.AdamW(self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay)
        lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                                  patience=self.hparams.reduce_lr_on_pleteau_patience,
                                                                  verbose=True)
        return {
            'optimizer': optimizer,
            'lr_scheduler': lr_scheduler,
            'monitor': 'train_loss',
            'interval': 'epoch',
            'frequency': 1
        }

    def _symmetric_loss(self, x: Tensor) -> Tensor:
        """BYOL loss for a batch: each view's prediction against the other view's target."""
        with torch.no_grad():
            x1, x2 = self.augment(x), self.augment(x)

        pred1, pred2 = self.forward(x1), self.forward(x2)
        with torch.no_grad():
            targ1, targ2 = self.target(x1), self.target(x2)
        return torch.mean(normalized_mse(pred1, targ2) + normalized_mse(pred2, targ1))

    def training_step(self, batch, *_) -> Dict[str, Union[Tensor, Dict]]:
        # Only the images are used; labels (batch[1]) are ignored.
        loss = self._symmetric_loss(batch[0])

        self.log("train_loss", loss.item())
        return {"loss": loss}

    @torch.no_grad()
    def validation_step(self, batch, *_) -> Dict[str, Union[Tensor, Dict]]:
        loss = self._symmetric_loss(batch[0])

        return {"loss": loss}

    @torch.no_grad()
    def validation_epoch_end(self, outputs: List[Dict]) -> None:
        """Log the mean of the per-batch validation losses as ``val_loss``."""
        val_loss = sum(x["loss"] for x in outputs) / len(outputs)
        self.log("val_loss", val_loss.item())


In [None]:
# file node_helpers.py

import logging
from copy import deepcopy

import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping
from tqdm.auto import tqdm
from sklearn.metrics import accuracy_score, f1_score

from matplotlib import pyplot as plt


def score(predictions, labels):
    """Compute accuracy and weighted F1 for a set of predictions.

    Args:
        predictions: predicted class ids.
        labels: ground-truth class ids, aligned with ``predictions``.

    Returns:
        dict with keys ``'accuracy'`` and ``'f1_score'``.
    """
    # scikit-learn metrics take (y_true, y_pred). Accuracy is symmetric, but
    # the weighted F1 averages per-class scores by the support of y_true, so
    # passing the arguments swapped weighted by predicted counts instead.
    return {
        'accuracy': accuracy_score(labels, predictions),
        'f1_score': f1_score(labels, predictions, average='weighted'),
    }


def predict(model, dataset, indices, batch_size=10, num_workers=4, transform=None):
    """Run batched inference over a subset of ``dataset``.

    Args:
        model: object exposing ``eval()``, ``cuda()`` and ``predict_proba``.
        dataset: indexable of ``(image, label)`` pairs.
        indices: positions of ``dataset`` to predict, in output order.
        batch_size: loader batch size.
        num_workers: loader worker count.
        transform: transform applied to each image; defaults to
            ``get_test_transforms()``.

    Returns:
        ``(predictions, probas)``: a flat list of class ids and a list of
        per-class probability rows.
    """
    transform = transform or get_test_transforms()
    selected = torch.utils.data.Subset(dataset, indices=indices)
    inference_dataset = DatasetFromSubset(selected, transform=transform)

    loader = torch.utils.data.DataLoader(inference_dataset,
                                         batch_size=batch_size,
                                         num_workers=num_workers,
                                         shuffle=False,
                                         drop_last=False)

    use_cuda = torch.cuda.is_available()
    model.eval()
    if use_cuda:
        model = model.cuda()

    prediction_batches, proba_batches = [], []
    with torch.no_grad():
        for images, _ in tqdm(loader):
            if use_cuda:
                images = images.cuda()
            batch_probas = model.predict_proba(images)
            # Index of the most probable class per row.
            prediction_batches.append(torch.max(batch_probas, 1)[1])
            proba_batches.append(batch_probas)

    predictions = torch.hstack(prediction_batches).flatten().tolist()
    probas = torch.vstack(proba_batches).tolist()

    return predictions, probas


def lr_find(trainer, model, train_data_loader, val_data_loader=None, plot=False):
    """Run the trainer's learning-rate finder and return its suggestion.

    Args:
        trainer: object exposing ``tuner.lr_find``.
        model: model to tune.
        train_data_loader: training loader handed to the finder.
        val_data_loader: optional validation loader (wrapped in a list).
        plot: when true, show the finder's plot with the suggested point.

    Returns:
        The learning rate returned by the finder's ``suggestion()``.
    """
    if val_data_loader:
        val_dataloaders = [val_data_loader]
    else:
        val_dataloaders = None

    finder = trainer.tuner.lr_find(model,
                                   train_dataloader=train_data_loader,
                                   val_dataloaders=val_dataloaders)
    if plot:
        plt.figure()
        plt.title('LR finder results')
        finder.plot(suggest=True)
        plt.show()

    suggested_lr = finder.suggestion()
    logging.info('LR finder suggestion: %f', suggested_lr)

    return suggested_lr


def train_byol(model, hparams, loader):
    """Pre-train ``model`` with BYOL and return the fitted BYOL module.

    Args:
        model: backbone handed to ``BYOL``.
        hparams: namespace used both for the BYOL module and for building the
            Trainer; ``early_stop_patience`` and ``auto_lr_find`` are read
            here. ``hparams.lr`` is overwritten when the LR finder runs.
        loader: data loader used for training AND validation.
    """
    byol = BYOL(model, image_size=(256, 256), hparams=hparams)

    # Stop when the training loss stops improving.
    stopper = EarlyStopping('train_loss',
                            patience=hparams.early_stop_patience,
                            verbose=True)
    trainer_overrides = dict(
        reload_dataloaders_every_epoch=True,
        terminate_on_nan=True,
        callbacks=[stopper],
    )
    trainer = Trainer.from_argparse_args(hparams, **trainer_overrides)

    if hparams.auto_lr_find:
        suggested_lr = lr_find(trainer, byol, loader, val_data_loader=loader)
        # Keep the caller's namespace and the module's hparams in sync.
        hparams.lr = suggested_lr
        byol.hparams.lr = suggested_lr

    trainer.fit(byol, loader, loader)
    return byol


In [None]:
# file lmdb_dataset.py

import logging
import os
from PIL import Image
import six

from torch.utils.data import DataLoader

import lmdb
from tqdm.auto import tqdm
import pyarrow as pa
import lz4framed

import torch.utils.data as data


def compress_serialize(thing):
    """Serialize ``thing`` with pyarrow and return the resulting buffer.

    Despite the name, no compression step is applied.
    """
    serialized = pa.serialize(thing)
    return serialized.to_buffer()


def deserialize_decompress(thing):
    """Deserialize a pyarrow-serialized buffer back into a Python object.

    Despite the name, no decompression step is applied.
    """
    restored = pa.deserialize(thing)
    return restored


def raw_reader(path):
    """Return the full content of the file at ``path`` as bytes."""
    with open(path, 'rb') as stream:
        return stream.read()


class ImageLMDBDataset(data.Dataset):
    """Read-only dataset backed by an LMDB database of ``(image, label)`` records.

    The database must contain the metadata entries ``__len__``, ``__keys__``
    and ``labels`` in addition to one record per key.

    Args:
        db_path: location of the LMDB database.
        transform: optional callable invoked as ``transform(image=image)``;
            its ``'image'`` entry is returned.
        target_transform: stored on the instance; not applied by ``__getitem__``.
    """

    def __init__(self, db_path, transform=None, target_transform=None):
        self.db_path = str(db_path)
        # Opened without locking and read-only; `subdir` follows whether the
        # path is a directory.
        self.env = lmdb.open(self.db_path, subdir=os.path.isdir(db_path),
                             readonly=True, lock=False,
                             readahead=False, meminit=False)

        metadata_keys = (b'__len__', b'__keys__', b'labels')
        with self.env.begin(write=False) as txn:
            self.length, self.keys, self.labels = [
                deserialize_decompress(txn.get(key)) for key in metadata_keys
            ]

        self.transform = transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        with self.env.begin(write=False) as txn:
            record = txn.get(self.keys[index])

        image, label = deserialize_decompress(record)

        if self.transform:
            image = self.transform(image=image)['image']

        return image, label

    def __len__(self):
        return self.length

    def __repr__(self):
        return f'{self.__class__.__name__} ({self.db_path})'


def dataset_to_lmdb(dataset, out_path, write_frequency=2000, num_workers=8, map_size=1e11):
    """Write every ``(image, label)`` sample of ``dataset`` into an LMDB database.

    Records are stored under the ASCII-encoded sample index; the entries
    ``__keys__``, ``__len__`` and ``labels`` hold the metadata that
    ``ImageLMDBDataset`` reads back.

    Args:
        dataset: indexable dataset yielding ``(image, label)`` pairs. Its
            ``loader`` attribute is set to ``raw_reader`` as a side effect.
        out_path: destination of the database; missing parent directories
            are created.
        write_frequency: commit the write transaction every this many samples.
        num_workers: worker count of the loader iterating ``dataset``.
        map_size: maximum database size in bytes (converted to ``int``).

    Returns:
        An ``ImageLMDBDataset`` opened on ``out_path``.
    """
    dataset.loader = raw_reader
    # batch_size defaults to 1 and the identity collate keeps each batch as a
    # one-element list holding the raw (image, label) pair.
    data_loader = DataLoader(dataset, num_workers=num_workers, collate_fn=lambda x: x)

    lmdb_path = out_path
    # lmdb does not create missing parent directories.
    parent_dir = os.path.dirname(str(lmdb_path))
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    isdir = os.path.isdir(lmdb_path)

    logging.debug("Generate LMDB to %s", lmdb_path)
    # int(): the default 1e11 is a float, while the size is a byte count.
    db = lmdb.open(lmdb_path, subdir=isdir,
                   map_size=int(map_size), readonly=False,
                   meminit=False, map_async=True)

    try:
        labels = []
        # A format string is required: passing the two lengths as
        # (msg, arg) made the logging call fail to format.
        logging.debug('%d samples, %d batches', len(dataset), len(data_loader))
        txn = db.begin(write=True)
        for idx, batch in tqdm(enumerate(data_loader), total=len(data_loader)):
            image, label = batch[0]
            txn.put(u'{}'.format(idx).encode('ascii'), compress_serialize((image, label)))
            if idx % write_frequency == 0:
                # Periodic commit keeps the pending write transaction bounded.
                txn.commit()
                txn = db.begin(write=True)
            labels.append(int(label))

        # finish iterating through dataset
        logging.debug('Final commit')
        txn.commit()

        logging.debug('Writing keys and len')
        # len(labels) rather than the loop variable, which is undefined when
        # the dataset is empty.
        keys = [u'{}'.format(k).encode('ascii') for k in range(len(labels))]
        with db.begin(write=True) as txn:
            txn.put(b'__keys__', compress_serialize(keys))
            txn.put(b'__len__', compress_serialize(len(keys)))
            txn.put(b'labels', compress_serialize(list(labels)))

        logging.debug("Flushing database ...")
        db.sync()
    finally:
        # Release the environment even when writing fails part-way.
        db.close()

    return ImageLMDBDataset(out_path)


In [None]:
#Pipeline prepare

def prepare_lmdb(train_images_torch, test_images_torch):
    """Convert the train and test image datasets into LMDB databases.

    Both datasets get ``lmdb_transforms`` assigned as their ``transform``
    before being written to ``data/03_primary/{train,test}.lmdb``.

    Raises:
        FileExistsError: if either target path already exists. Nothing is
            modified in that case.

    Returns:
        ``(train_images_lmdb, test_images_lmdb)``.
    """
    train_lmdb_path = 'data/03_primary/train.lmdb'
    test_lmdb_path = 'data/03_primary/test.lmdb'

    # Check before touching the inputs so a refused run has no side effects.
    if os.path.exists(train_lmdb_path) or os.path.exists(test_lmdb_path):
        raise FileExistsError('LMDB files already exist, delete manually to overwrite.')

    train_images_torch.transform = lmdb_transforms
    test_images_torch.transform = lmdb_transforms

    train_images_lmdb = dataset_to_lmdb(train_images_torch, train_lmdb_path)
    test_images_lmdb = dataset_to_lmdb(test_images_torch, test_lmdb_path)

    return train_images_lmdb, test_images_lmdb


In [None]:
#Pipeline pretrain

def pretrain_model(train_images_lmdb, test_images_lmdb, parameters):
    """Pre-train the classifier trunk with BYOL on train + test images.

    Args:
        train_images_lmdb: LMDB dataset with the training images.
        test_images_lmdb: LMDB dataset with the test images.
        parameters: configuration dict; reads ``parameters['byol']``,
            ``parameters['classifier']`` and ``parameters['data_loader_workers']``.

    Returns:
        A fresh ``LeafDoctorModel`` whose trunk carries the pre-trained weights.
    """
    # Work on shallow copies: they share the LMDB handle but carry their own
    # `transform`. Assigning dummy_transforms to the caller's datasets left
    # them yielding tensors, which breaks the later stages that wrap the same
    # datasets with albumentations pipelines expecting raw arrays.
    byol_views = []
    for source in (train_images_lmdb, test_images_lmdb):
        view = copy(source)
        view.transform = dummy_transforms
        byol_views.append(view)

    dataset = torch.utils.data.ConcatDataset(byol_views)
    loader = torch.utils.data.DataLoader(dataset,
                                        batch_size=parameters['byol']['batch_size'],
                                        num_workers=parameters['data_loader_workers'],
                                        shuffle=True)

    classifier_params = Namespace(**parameters['classifier'])
    model = LeafDoctorModel(classifier_params)

    hparams = Namespace(**parameters['byol'])
    byol = train_byol(model.trunk, hparams, loader)

    # Copy the trained weights into a new model so the result carries no
    # forward hook registered by the BYOL encoder wrapper.
    state_dict = byol.encoder.model.state_dict()
    model = LeafDoctorModel(classifier_params)
    model.trunk.load_state_dict(state_dict)
    return model


In [None]:
#Pipeline train

def split_data(train_labels, parameters):
    """Split the training rows into stratified train and validation indices.

    Args:
        train_labels: table with a ``label`` column used for stratification.
        parameters: dict providing ``seed`` and ``validation_size``.

    Returns:
        ``(train_indices, val_indices)`` over ``range(len(train_labels))``.
    """
    row_positions = range(len(train_labels))
    split = train_test_split(row_positions,
                             stratify=train_labels.label,
                             random_state=parameters['seed'],
                             test_size=parameters['validation_size'])
    train_indices, val_indices = split
    return train_indices, val_indices


def train_model(pretrained_model, train_images_lmdb, train_indices, val_indices, parameters):
    """Fine-tune the classifier and return the best checkpoint by ``val_acc``.

    Args:
        pretrained_model: model whose ``state_dict`` initialises the classifier.
        train_images_lmdb: dataset of ``(image, label)`` pairs.
        train_indices: dataset positions used for training.
        val_indices: dataset positions used for validation.
        parameters: configuration dict; reads ``parameters['classifier']`` and
            ``parameters['data_loader_workers']``.

    Returns:
        ``LeafDoctorModel`` restored from the best checkpoint.
    """
    classifier_cfg = parameters['classifier']
    train_transform, val_transform = get_train_transforms(), get_test_transforms()

    train_dataset = DatasetFromSubset(torch.utils.data.Subset(train_images_lmdb, indices=train_indices),
                                      transform=train_transform)

    val_dataset = DatasetFromSubset(torch.utils.data.Subset(train_images_lmdb, indices=val_indices),
                                    transform=val_transform)

    train_data_loader = torch.utils.data.DataLoader(train_dataset,
                                                    batch_size=classifier_cfg['batch_size'],
                                                    num_workers=parameters['data_loader_workers'],
                                                    shuffle=True)

    val_data_loader = torch.utils.data.DataLoader(val_dataset,
                                                  num_workers=parameters['data_loader_workers'],
                                                  batch_size=classifier_cfg['batch_size'])

    # Callbacks. The monitored metric is an accuracy, so "better" means
    # larger; mode='max' states that explicitly instead of relying on the
    # callbacks' automatic / default mode.
    model_checkpoint = ModelCheckpoint(monitor="val_acc",
                                       mode='max',
                                       verbose=True,
                                       dirpath=classifier_cfg['checkpoints_dir'],
                                       filename="{epoch}_{val_acc:.4f}",
                                       save_top_k=classifier_cfg['save_top_k_checkpoints'])
    early_stopping = EarlyStopping('val_acc',
                                   mode='max',
                                   patience=classifier_cfg['early_stop_patience'],
                                   verbose=True,
                                   )

    hparams = Namespace(**classifier_cfg)

    trainer = Trainer.from_argparse_args(
        hparams,
        reload_dataloaders_every_epoch = True,
        terminate_on_nan=True,
        callbacks=[model_checkpoint, early_stopping],
    )

    # Model
    model = LeafDoctorModel(hparams)
    model.load_state_dict(pretrained_model.state_dict())

    # Training
    trainer.fit(model, train_data_loader, val_data_loader)
    logging.info('Training finished')

    # Restore the best checkpoint. load_from_checkpoint is a classmethod;
    # calling it on the class avoids building (and loading pretrained weights
    # for) a throwaway LeafDoctorModel instance first.
    best_checkpoint = model_checkpoint.best_model_path
    model = LeafDoctorModel.load_from_checkpoint(checkpoint_path=best_checkpoint)
    return model


def score_model(model, train_images_torch, indices, parameters):
    """Score ``model`` on the samples of ``train_images_torch`` at ``indices``.

    Args:
        model: model passed to ``predict``.
        train_images_torch: dataset exposing ``labels`` (array or list).
        indices: dataset positions to evaluate.
        parameters: configuration dict; reads ``parameters['classifier']``
            (``batch_size``, optional ``limit_val_batches``) and
            ``parameters['data_loader_workers']``.

    Returns:
        ``(scores, predictions)``: the dict produced by ``score`` and the
        list of predicted class ids.
    """
    logging.info('Scoring model')
    classifier_cfg = parameters['classifier']
    if classifier_cfg.get('limit_val_batches'):
        # Evaluate only as many samples as the limited number of batches holds.
        indices = indices[:classifier_cfg['limit_val_batches'] * classifier_cfg['batch_size']]
    # `labels` is a plain Python list for the LMDB dataset, and a list cannot
    # be indexed with a list/array of positions; go through numpy.
    labels = np.asarray(train_images_torch.labels)[np.asarray(indices, dtype=int)]
    predictions, probas = predict(model,
                          dataset=train_images_torch,
                          indices=indices,
                          batch_size=classifier_cfg['batch_size'],
                          num_workers=parameters['data_loader_workers'],
                          transform=get_test_transforms())

    scores = score(predictions, labels)

    logging.info(f'Validation scores:\n{scores}')
    return scores, predictions


In [None]:
#Pipeline predict

def predict_submission(model, test_images_lmdb, sample_submission, parameters):
    """Predict labels for the test set and store them in the submission table.

    Args:
        model: model passed to ``predict``.
        test_images_lmdb: dataset holding the test images.
        sample_submission: table whose ``label`` column is overwritten
            (modified in place).
        parameters: configuration dict; reads
            ``parameters['classifier']['batch_size']`` and
            ``parameters['data_loader_workers']``.

    Returns:
        ``sample_submission`` with the predicted labels.
    """
    logging.debug('Predicting with model')
    # predict() applies the test transform itself. Setting the same transform
    # on the dataset as well applied it twice (the second time on a tensor),
    # so a transform-free shallow copy is used as the source instead.
    raw_test_images = copy(test_images_lmdb)
    raw_test_images.transform = None

    predictions, probas = predict(model,
                                  dataset=raw_test_images,
                                  indices=list(range(len(raw_test_images))),
                                  batch_size=parameters['classifier']['batch_size'],
                                  num_workers=parameters['data_loader_workers'],
                                  transform=get_test_transforms())

    sample_submission.label = predictions

    return sample_submission


In [None]:
#Pipeline cv

def cross_validation(pretrained_model, train_images_lmdb, test_images_lmdb, parameters):
    """Train and score one model per stratified fold.

    Args:
        pretrained_model: initial weights for every fold.
        train_images_lmdb: dataset exposing ``labels``.
        test_images_lmdb: unused; kept for interface compatibility.
        parameters: configuration dict; reads ``cv_models_dir``,
            ``cv_splits`` and ``seed`` in addition to what ``train_model``
            and ``score_model`` need.

    Raises:
        Exception: if ``parameters['cv_models_dir']`` already exists.

    Returns:
        dict with one ``fold_<n>`` entry per fold (model path, scores,
        validation indices, out-of-fold predictions) and a ``summary`` entry
        holding mean and standard deviation of every score.
    """
    cv_results = {}
    score_values = {}

    if os.path.exists(parameters['cv_models_dir']):
        raise Exception('CV models path already exists, please delete it explicitly to overwrite')
    os.makedirs(parameters['cv_models_dir'])

    # random_state only takes effect together with shuffle=True; scikit-learn
    # warns about or rejects a random_state given without shuffling.
    cv = StratifiedKFold(n_splits=parameters['cv_splits'], shuffle=True, random_state=parameters['seed'])
    indices = np.arange(len(train_images_lmdb))
    labels = train_images_lmdb.labels
    for fold_num, (train_idx, val_idx) in enumerate(cv.split(indices, labels)):
        logging.info('Fitting CV fold %d', fold_num)
        model_path = os.path.join(parameters['cv_models_dir'], f'model_fold_{fold_num}.pt')
        fold_parameters = copy(parameters)
        model = train_model(pretrained_model, train_images_lmdb, train_idx, val_idx, fold_parameters)
        torch.save(model.state_dict(), model_path)
        scores, oof_predictions = score_model(model, train_images_lmdb, val_idx, fold_parameters)
        cv_results[f'fold_{fold_num}'] = {
            'model_path': model_path,
            'scores': scores,
            'val_indices': val_idx,
            'oof_predictions': oof_predictions,
        }

        # Collect every metric across folds (names chosen so the module-level
        # `score` function is not shadowed).
        for metric_name, metric_value in scores.items():
            score_values.setdefault(metric_name, []).append(metric_value)

    cv_results['summary'] = {}
    for score_name, fold_values in score_values.items():
        cv_results['summary'][f'{score_name}_mean'] = np.mean(fold_values)
        cv_results['summary'][f'{score_name}_std'] = np.std(fold_values)

    # The %s placeholder previously had no argument and was logged literally.
    logging.info('Cross-validation results %s', cv_results['summary'])
    return cv_results


# Parameters

In [None]:
# Global configuration for the pipeline functions above.
# The "classifier" and "byol" sub-dicts are converted to a Namespace and passed
# both to the respective LightningModule (as hparams) and to
# Trainer.from_argparse_args.
parameters = {
    "seed": 42,  # random_state of the train/validation split and the CV folds
    "validation_size": 0.15,  # test_size of the train/validation split
    "data_loader_workers": 4,  # num_workers of every DataLoader
    "classifier": {
        "gpus": -1,
        "batch_size": 10,
        "max_epochs": 100,
        "max_steps": 0,
        "auto_lr_find": 0,
        "lr": 0.001,
        "weight_decay": 0.0001,
        "early_stop_patience": 4,
        # Spelling ("pleteau") must match the attribute read in configure_optimizers.
        "reduce_lr_on_pleteau_patience": 3,
        "save_top_k_checkpoints": 1,
        "checkpoints_dir": "data/06_models/classifier/checkpoints"
    },
    "byol": {
        "gpus": -1,
        "batch_size": 10,
        "max_epochs": 100,
        "max_steps": 0,
        "auto_lr_find": 1,  # train_byol runs the LR finder when this is truthy
        "lr": 0.01,
        # Spelling ("pleteau") must match the attribute read in configure_optimizers.
        "reduce_lr_on_pleteau_patience": 1,
        "weight_decay": 0.0001,
        "limit_train_batches": 300,
        "limit_val_batches": 1,
        "accumulate_grad_batches": 4,
        "early_stop_patience": 3,
        "from_checkpoint": 0
    }
}

In [None]:

# Root of the competition data on Kaggle.
DATA_DIR = '/kaggle/input/cassava-leaf-disease-classification'

train_labels = pd.read_csv(f'{DATA_DIR}/train.csv')
sample_submission = pd.read_csv(f'{DATA_DIR}/sample_submission.csv')
# The label map is a JSON file; read_csv would parse it as a one-line CSV.
# typ='series' returns the mapping as a Series (label id -> disease name).
label_num_to_disease_map = pd.read_json(f'{DATA_DIR}/label_num_to_disease_map.json', typ='series')

# File-based datasets over the raw images; the submission labels act as
# placeholders for the test set.
train_images_torch = CassavaDataset(image_ids=train_labels.image_id.values, labels=train_labels.label.values, root=f'{DATA_DIR}/train_images')
test_images_torch = CassavaDataset(image_ids=sample_submission.image_id.values, labels=sample_submission.label.values, root=f'{DATA_DIR}/test_images')

# Classifier initialised from a previously saved state dict.
pretrained_model_path = '/kaggle/input/byol-pretrained-cassava/pretrained_model_best.pt'
pretrained_model = LeafDoctorModel(hparams=Namespace(**parameters['classifier']))
pretrained_model.load_state_dict(torch.load(pretrained_model_path))

submission = pd.read_csv(f'{DATA_DIR}/sample_submission.csv')


# Execution

In [None]:
# Convert the raw image datasets to LMDB; raises if the LMDB files already exist.
train_images_lmdb, test_images_lmdb = prepare_lmdb(train_images_torch, test_images_torch)

In [None]:
# Stratified train/validation split over the rows of train_labels.
train_indices, val_indices = split_data(train_labels, parameters)

In [None]:
# BYOL pre-training; replaces the pretrained_model loaded from disk above.
pretrained_model = pretrain_model(train_images_lmdb, test_images_lmdb, parameters)

In [None]:
# Fine-tune the classifier; returns the model restored from the best checkpoint.
model = train_model(pretrained_model, train_images_lmdb, train_indices, val_indices, parameters)

In [None]:
# predict_submission writes the labels into sample_submission and returns that
# same object, so `submission` and `sample_submission` refer to one table.
submission = predict_submission(model, test_images_lmdb, sample_submission, parameters)

In [None]:
# Accuracy / weighted F1 of the trained model on the validation indices.
val_scores, val_predictions = score_model(model, train_images_lmdb, val_indices, parameters)

In [None]:
# Display the validation metrics dict.
print(val_scores)

In [None]:
# Write the predictions to submission.csv without the index column.
submission.to_csv('submission.csv', index=False)