In [None]:
!pip install python-box timm

In [None]:
import gc
import os
import warnings
from pprint import pprint
from glob import glob
import numpy as np
import pandas as pd
from box import Box
from timm import create_model
import matplotlib.pyplot as plt
from sklearn.model_selection import StratifiedKFold

import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
from torchvision.io import ImageReadMode, read_image
import pytorch_lightning as pl
from pytorch_lightning import callbacks
from pytorch_lightning.utilities.seed import seed_everything
from pytorch_lightning.callbacks.early_stopping import EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

warnings.filterwarnings('ignore')

In [None]:
# Experiment configuration, wrapped in a Box so nested keys can be read as
# attributes (e.g. `config.train_loader.batch_size`).
config = Box(dict(
    seed=2021,
    root='../input/petfinder-pawpularity-score',
    n_splits=5,
    epochs=20,
    image_size=640,
    # Keyword arguments unpacked directly into torch DataLoader.
    train_loader=dict(
        batch_size=32,
        shuffle=True,
        num_workers=4,
        pin_memory=False,
        drop_last=True,
    ),
    val_loader=dict(
        batch_size=32,
        shuffle=False,
        num_workers=4,
        pin_memory=False,
        drop_last=False,
    ),
    # Keyword arguments unpacked directly into pl.Trainer.
    # NOTE(review): GPU indices 1-3 are hard-coded; this fails on hosts exposing
    # fewer GPUs (e.g. a single-GPU kernel) — confirm the target machine.
    # NOTE(review): `progress_bar_refresh_rate` / `resume_from_checkpoint` are
    # deprecated in later PyTorch Lightning releases — pin the PL version or migrate.
    trainer=dict(
        gpus=[1, 2, 3],
        accumulate_grad_batches=1,
        progress_bar_refresh_rate=1,
        fast_dev_run=False,
        num_sanity_val_steps=0,
        resume_from_checkpoint=None,
    ),
    model=dict(
        # timm backbone. Other candidates: tf_efficientnet_b7_ns,
        # tf_efficientnetv2_m_in21k, tf_efficientnetv2_l_in21k.
        name='tf_efficientnet_b6_ns',
        output_dim=1,
    ),
    # The `name` strings below are presumably eval'd by the Model cell — TODO confirm.
    optimizer=dict(
        name='optim.AdamW',
        params=dict(lr=1e-5),
    ),
    scheduler=dict(
        name='optim.lr_scheduler.CosineAnnealingWarmRestarts',
        # NOTE(review): eta_min (1e-4) is larger than the optimizer lr (1e-5), so
        # within each cycle the cosine schedule moves the LR *up* from 1e-5 toward
        # 1e-4 instead of decaying it — confirm this is intended.
        params=dict(T_0=20, eta_min=1e-4),
    ),
    # NOTE(review): BCEWithLogitsLoss needs targets in [0, 1]; Pawpularity labels
    # must be rescaled by the model before the loss — confirm.
    loss='nn.BCEWithLogitsLoss',
))

In [None]:
class PetfinderDataset(Dataset):
    """Map-style dataset of resized image tensors, with optional Pawpularity labels.

    Args:
        df: DataFrame with an "Id" column holding image file paths and, optionally,
            a "Pawpularity" column used as the label.
        image_size: side length in pixels; every image is resized to
            ``image_size x image_size``.
    """

    def __init__(self, df, image_size=224):
        self._X = df["Id"].values
        # Labels are optional so the same class also works on an unlabeled frame.
        self._y = df["Pawpularity"].values if "Pawpularity" in df.columns else None
        self._transform = T.Resize([image_size, image_size])

    def __len__(self):
        return len(self._X)

    def __getitem__(self, idx):
        """Return ``(image, label)`` when labels exist, otherwise just ``image``.

        ``image`` is the tensor returned by ``read_image`` (decoded as 3-channel RGB),
        resized by ``T.Resize``.
        """
        # Force RGB decoding. With the default mode, a grayscale file would yield a
        # 1-channel tensor that cannot be stacked with RGB images in a batch and
        # does not match 3-channel normalization.
        image = read_image(self._X[idx], mode=ImageReadMode.RGB)
        image = self._transform(image)
        if self._y is None:
            return image
        return image, self._y[idx]

class PetfinderDataModule(pl.LightningDataModule):
    """LightningDataModule serving already-split train/validation DataFrames.

    Args:
        train_df: training rows, passed to ``PetfinderDataset``.
        val_df: validation rows, passed to ``PetfinderDataset``.
        cfg: config object; reads ``image_size``, ``train_loader`` and ``val_loader``.
    """

    def __init__(self, train_df, val_df, cfg):
        super().__init__()
        self._train_df = train_df
        self._val_df = val_df
        self._cfg = cfg

    def _dataset_for(self, train=True):
        """Build a PetfinderDataset over the train split (``train=True``) or the val split."""
        split_df = self._train_df if train else self._val_df
        return PetfinderDataset(split_df, self._cfg.image_size)

    def train_dataloader(self):
        """DataLoader over the training split, configured by ``cfg.train_loader``."""
        loader_kwargs = self._cfg.train_loader
        return DataLoader(self._dataset_for(True), **loader_kwargs)

    def val_dataloader(self):
        """DataLoader over the validation split, configured by ``cfg.val_loader``."""
        loader_kwargs = self._cfg.val_loader
        return DataLoader(self._dataset_for(False), **loader_kwargs)

In [None]:
# Autograd anomaly detection is a debugging aid that slows every backward pass.
# Keep it off for real training and set this flag to True only when chasing NaN/inf
# gradients. Setting it explicitly also resets a value left over in the kernel.
DETECT_ANOMALY = False
torch.autograd.set_detect_anomaly(DETECT_ANOMALY)
seed_everything(config.seed)

# Load the training table and turn each Id into the path of its image file.
df = pd.read_csv(os.path.join(config.root, "train.csv"))
df["Id"] = df["Id"].apply(lambda x: os.path.join(config.root, "train", x + ".jpg"))

In [None]:
IMAGENET_MEAN = [0.485, 0.456, 0.406]  # RGB
IMAGENET_STD = [0.229, 0.224, 0.225]  # RGB

def get_default_transforms():
    """Return torchvision pipelines keyed by split: ``{"train": ..., "val": ...}``.

    Both pipelines finish by converting the image tensor to ``torch.float`` and
    normalizing with the ImageNet mean/std above. ``"train"`` first applies random
    horizontal/vertical flips, a small random affine warp and colour jitter;
    ``"val"`` applies no augmentation.

    NOTE(review): not called by any visible cell; presumably used by the Model
    cell — confirm.
    """

    def to_normalized_float():
        # Build fresh transform objects for each pipeline.
        return [
            T.ConvertImageDtype(torch.float),
            T.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
        ]

    augmentations = [
        T.RandomHorizontalFlip(),
        T.RandomVerticalFlip(),
        T.RandomAffine(15, translate=(0.1, 0.1), scale=(0.9, 1.1)),
        T.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
    ]
    train_pipeline = T.Compose(augmentations + to_normalized_float())
    val_pipeline = T.Compose(to_normalized_float())
    return {"train": train_pipeline, "val": val_pipeline}

In [None]:
skf = StratifiedKFold(n_splits=config.n_splits, shuffle=True, random_state=config.seed)

# NOTE(review): `Model` (the LightningModule) is not defined in any visible cell, so a
# fresh-kernel "Run All" raises NameError below. Restore its definition cell before
# this one. It presumably applies get_default_transforms() and logs
# train_loss / val_loss — TODO confirm.
for fold, (train_idx, val_idx) in enumerate(skf.split(df["Id"], df["Pawpularity"])):
    # StratifiedKFold yields *positional* indices, so select rows with iloc;
    # loc selects by index label and would pick the wrong rows for a non-default index.
    train_df = df.iloc[train_idx].reset_index(drop=True)
    val_df = df.iloc[val_idx].reset_index(drop=True)
    dm = PetfinderDataModule(train_df, val_df, config)
    model = Model(config)
    earlystopping = EarlyStopping(monitor="val_loss")
    lr_monitor = callbacks.LearningRateMonitor()
    loss_checkpoint = callbacks.ModelCheckpoint(
        filename="best_loss",
        monitor="val_loss",
        save_top_k=1,
        mode="min",
        save_last=False,
    )
    logger = TensorBoardLogger(config.model.name)

    trainer = pl.Trainer(
        logger=logger,
        max_epochs=config.epochs,
        callbacks=[lr_monitor, loss_checkpoint, earlystopping],
        **config.trainer,
    )
    trainer.fit(model, dm)

    # Drop this fold's model/trainer and release cached GPU memory before the next
    # fold builds a fresh model; otherwise earlier folds can keep GPU memory occupied.
    del model, trainer
    gc.collect()
    torch.cuda.empty_cache()

In [None]:
from tensorboard.backend.event_processing.event_accumulator import EventAccumulator

# TensorBoardLogger auto-increments the version for each fold, so folds land in
# version_0, version_1, ... Version 0 is the first fold only if the log directory was
# empty before training; a re-run appends new versions — confirm.
LOG_VERSION = 0
log_dir = f'./{config.model.name}/default/version_{LOG_VERSION}'
event_files = sorted(glob(f'{log_dir}/events*'))
if not event_files:
    raise FileNotFoundError(f'No TensorBoard event file found in {log_dir}')
path = event_files[0]

# The size_guidance key is 'scalars' (a 'scalar' key is not recognized).
# 0 means keep every scalar event rather than sampling down to the default cap.
event_acc = EventAccumulator(path, size_guidance={'scalars': 0})
event_acc.Reload()

# Map each scalar tag to the list of values logged under it.
scalars = {
    tag: [event.value for event in event_acc.Scalars(tag)]
    for tag in event_acc.Tags()['scalars']
}

In [None]:
import seaborn as sns
sns.set()

# Left: learning-rate trace from LearningRateMonitor. Right: train/val curves.
fig, (lr_ax, loss_ax) = plt.subplots(1, 2, figsize=(16, 6))

lr_values = scalars['lr-AdamW']
lr_ax.plot(range(len(lr_values)), lr_values)
lr_ax.set_xlabel('epoch')
lr_ax.set_ylabel('lr')
lr_ax.set_title('adamw lr')

for series_name in ('train_loss', 'val_loss'):
    series = scalars[series_name]
    loss_ax.plot(range(len(series)), series, label=series_name)
loss_ax.legend()
loss_ax.set_ylabel('rmse')
loss_ax.set_xlabel('epoch')
loss_ax.set_title('train/val rmse')
plt.show()

In [None]:
# Smallest value logged under the 'val_loss' tag.
best_val_loss = min(scalars['val_loss'])
print('best_val_loss', best_val_loss)