In [1]:
!pip install -q timm pytorch-metric-learning[with-hooks]

[0m

In [2]:
import os
import glob
import pandas as pd
import numpy as np
import logging
import timm
from tqdm.notebook import tqdm

import torch
import torch.nn as nn
import torch.optim as optim

from torch.utils.data import Dataset, DataLoader
from torchvision.io import ImageReadMode, read_image
from torchvision.transforms import Compose, Lambda, Normalize, AutoAugment, AutoAugmentPolicy

import pytorch_metric_learning
import pytorch_metric_learning.utils.logging_presets as LP
from pytorch_metric_learning.utils import common_functions
from pytorch_metric_learning import losses, miners, samplers, testers, trainers
from pytorch_metric_learning.utils.accuracy_calculator import AccuracyCalculator
from pytorch_metric_learning.utils.inference import InferenceModel

In [3]:
# Select the CUDA device when PyTorch reports one as available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
device

device(type='cuda')

### config

In [4]:
# --- model ---
MODEL_NAME = 'tf_efficientnet_b6_ns'  # model name passed to timm.create_model
N_CLASSES = 15587       # num_classes given to the ArcFace loss
OUTPUT_SIZE = 2304      # input width of the embedder's Linear layer
EMBEDDING_SIZE = 512    # output width of the embedder / ArcFace embedding size

# --- training loop ---
N_EPOCH = 5             # epochs passed to trainer.train
BATCH_SIZE = 8          # batch size for both the trainer and the tester
PATIENCE = 5            # patience passed to the end-of-epoch hook
N_WORKER = 2            # dataloader worker count for trainer and tester

# --- learning-rate settings ---
# NOTE(review): these two values are not referenced by the cells shown in this
# notebook (the optimizers use literal learning rates) -- confirm whether they
# are still needed.
MODEL_LR = 1e-3
PCT_START = 0.3

### directories

In [5]:
# Input image folders (384x384 JPEG copies of the competition images).
TRAIN_DIR = "../input/jpeg-happywhale-384x384/train_images-384-384/train_images-384-384"
TEST_DIR = "../input/jpeg-happywhale-384x384/test_images-384-384/test_images-384-384"

# Output folders, namespaced by model so runs of different backbones do not collide.
LOG_DIR = f"/kaggle/working/logs/{MODEL_NAME}"
MODEL_DIR = f"/kaggle/working/model/{MODEL_NAME}"

### Dataset

In [6]:
class HappyWhaleDataset(Dataset):
    """Image dataset backed by a DataFrame of file names (and optional labels).

    Each item is read from ``image_dir`` using the file name stored in the
    ``"image"`` column of ``df``, passed through ``self.image_transform`` and
    returned either alone or together with the value of the ``"label"`` column.

    Args:
        df: Frame with an ``"image"`` column; a ``"label"`` column is also
            required when ``return_labels`` is true.
        image_dir: Directory that contains the image files.
        return_labels: When true, ``__getitem__`` returns ``(image, label)``;
            otherwise it returns only the image.
        augment: When true (the default, matching the previous behaviour),
            random ``AutoAugment`` with the ImageNet policy is applied before
            scaling. Pass ``False`` for evaluation/inference so that the same
            image always produces the same tensor.
    """

    def __init__(
        self,
        df: pd.DataFrame,
        image_dir: str,
        return_labels=True,
        augment=True,
    ):
        self.df = df
        self.images = self.df["image"]
        self.image_dir = image_dir
        self.return_labels = return_labels
        self.augment = augment

        steps = []
        if augment:
            # AutoAugment operates on the uint8 tensor returned by read_image,
            # so it has to run before the division below.
            steps.append(AutoAugment(AutoAugmentPolicy.IMAGENET))
        # Convert uint8 pixel values to floats in [0, 1].
        steps.append(Lambda(lambda x: x / 255))
        self.image_transform = Compose(steps)

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load, transform and return the sample at positional index ``idx``."""
        image_path = os.path.join(self.image_dir, self.images.iloc[idx])
        # Force 3-channel decoding: with the default mode a grayscale file comes
        # back with a single channel, which cannot be batched with RGB images.
        image = read_image(path=image_path, mode=ImageReadMode.RGB)
        image = self.image_transform(image)

        if not self.return_labels:
            return image

        label = self.df['label'].iloc[idx]
        return image, label

In [7]:
# Load the competition metadata and give every individual an integer class id.
df = pd.read_csv('../input/happy-whale-and-dolphin/train.csv')
df['label'] = df.groupby('individual_id').ngroup()

# Hold out a fixed-seed random sample of rows for validation; everything whose
# image name is not in that sample becomes the training set.
valid_proportion = 0.05
held_out = df.sample(frac=valid_proportion, replace=False, random_state=1)
is_held_out = df['image'].isin(held_out['image'])

# Both frames get a fresh 0..n-1 index because the dataset indexes rows by position.
train_df = df.loc[~is_held_out].reset_index(drop=True)
valid_df = held_out.reset_index(drop=True)

In [8]:
train_dataset = HappyWhaleDataset(df=train_df, image_dir=TRAIN_DIR, return_labels=True)
valid_dataset = HappyWhaleDataset(df=valid_df, image_dir=TRAIN_DIR, return_labels=True)

# The dataset's default transform includes random AutoAugment. Validation images
# must be embedded deterministically, otherwise the reported metric changes from
# run to run, so keep only the [0, 1] scaling step for the validation set.
# NOTE(review): train_dataset is also used as the retrieval reference set during
# evaluation and still carries the random augmentation -- consider a separate,
# non-augmented copy for that role.
valid_dataset.image_transform = Compose([Lambda(lambda x: x / 255)])

In [9]:
# Split name -> dataset mapping consumed by the end-of-epoch evaluation hook.
dataset_dict = dict(train=train_dataset, val=valid_dataset)

### trunk

In [10]:
# Build the pretrained timm backbone, replace its classifier attribute with an
# identity module, and move the model to the selected device.
trunk = timm.create_model(model_name=MODEL_NAME, pretrained=True)
trunk.classifier = common_functions.Identity()
trunk = trunk.to(device)

Downloading: "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/tf_efficientnet_b6_ns-51548356.pth" to /root/.cache/torch/hub/checkpoints/tf_efficientnet_b6_ns-51548356.pth


### embedder

In [11]:
# Linear projection from the trunk's OUTPUT_SIZE features to EMBEDDING_SIZE dimensions.
embedder = nn.Linear(in_features=OUTPUT_SIZE, out_features=EMBEDDING_SIZE)
embedder = embedder.to(device)

### loss function

Reference: [ArcFaceLoss documentation](https://kevinmusgrave.github.io/pytorch-metric-learning/losses/#arcfaceloss)

In [12]:
# ArcFace loss over N_CLASSES classes; it is moved to the device like the models
# and gets its own optimizer further below.
loss_func = losses.ArcFaceLoss(
    num_classes=N_CLASSES,
    embedding_size=EMBEDDING_SIZE,
).to(device)

### optimizer

In [13]:
# One Adam optimizer per trainable component, each with its own learning rate.
trunk_optimizer = optim.Adam(params=trunk.parameters(), lr=5e-3)
embedder_optimizer = optim.Adam(params=embedder.parameters(), lr=1e-3)
loss_optimizer = optim.Adam(params=loss_func.parameters(), lr=1e-3)

In [14]:
# Optimizers keyed as "<name>_optimizer". The loss function is registered with
# the trainer under the name "metric_loss", so its optimizer key must be
# "metric_loss_optimizer" to agree with the dict passed to the trainer.
optimizers = {
    "trunk_optimizer": trunk_optimizer,
    "embedder_optimizer": embedder_optimizer,
    "metric_loss_optimizer": loss_optimizer,
}

### logging, hooks, tester

In [15]:
# get_record_keeper returns several objects; only the first one (the record
# keeper itself) is used here.
record_keeper = LP.get_record_keeper(LOG_DIR)[0]

# The hook container tracks 'mean_average_precision' as the primary metric.
hooks = LP.get_hook_container(
    record_keeper,
    primary_metric='mean_average_precision',
)

In [16]:
# Accuracy is computed on the CPU and restricted to mean average precision with k=5.
accuracy_calculator = AccuracyCalculator(
    include=['mean_average_precision'],
    device=torch.device("cpu"),
    k=5,
)

# Tester that embeds whole datasets and reports results through the hook container.
tester = testers.GlobalEmbeddingSpaceTester(
    end_of_testing_hook=hooks.end_of_testing_hook,
    accuracy_calculator=accuracy_calculator,
    dataloader_num_workers=N_WORKER,
    batch_size=BATCH_SIZE,
)

Reference: [logging presets documentation](https://kevinmusgrave.github.io/pytorch-metric-learning/logging_presets/)

In [17]:
# Evaluate the 'val' split as queries against the 'train' split as references.
splits_to_eval = [('val', ['train'])]

# Hook run after every epoch (test_interval=1): it evaluates with the tester,
# writes to MODEL_DIR and applies the configured patience.
end_of_epoch_hook = hooks.end_of_epoch_hook(
    tester,
    dataset_dict,
    MODEL_DIR,
    test_interval=1,
    patience=PATIENCE,
    splits_to_eval=splits_to_eval,
)

## Trainers

In [18]:
# Trainer that optimizes only the metric loss (ArcFace); no mining function is used.
trainer = trainers.MetricLossOnly(
    models={
        "trunk": trunk,
        "embedder": embedder,
    },
    optimizers={
        "trunk_optimizer": trunk_optimizer,
        "embedder_optimizer": embedder_optimizer,
        "metric_loss_optimizer": loss_optimizer,
    },
    loss_funcs={"metric_loss": loss_func},
    mining_funcs=dict(),
    dataset=train_dataset,
    batch_size=BATCH_SIZE,
    dataloader_num_workers=N_WORKER,
    end_of_epoch_hook=end_of_epoch_hook,
)

In [19]:
# Run training for N_EPOCH epochs with the trainer configured above.
trainer.train(num_epochs=N_EPOCH)

total_loss=41.68263: 100%|██████████| 6060/6060 [1:07:53<00:00,  1.49it/s]
100%|██████████| 6061/6061 [17:07<00:00,  5.90it/s]
100%|██████████| 319/319 [00:54<00:00,  5.80it/s]
total_loss=36.22073: 100%|██████████| 6060/6060 [1:08:07<00:00,  1.48it/s]
100%|██████████| 6061/6061 [17:14<00:00,  5.86it/s]
100%|██████████| 319/319 [00:54<00:00,  5.84it/s]
total_loss=33.38531: 100%|██████████| 6060/6060 [1:08:26<00:00,  1.48it/s]
100%|██████████| 6061/6061 [17:11<00:00,  5.87it/s]
100%|██████████| 319/319 [00:53<00:00,  5.95it/s]
total_loss=28.89172: 100%|██████████| 6060/6060 [1:08:21<00:00,  1.48it/s]
100%|██████████| 6061/6061 [17:19<00:00,  5.83it/s]
100%|██████████| 319/319 [00:55<00:00,  5.79it/s]
total_loss=22.93649: 100%|██████████| 6060/6060 [1:08:34<00:00,  1.47it/s]
100%|██████████| 6061/6061 [17:18<00:00,  5.84it/s]
100%|██████████| 319/319 [00:55<00:00,  5.77it/s]
