# Imports

In [None]:
import timm

import wandb
import pandas as pd
import numpy as np
import plotly.express as px
import torch
import wandb
from matplotlib import pyplot as plt

# Kaggle cell

In [None]:
# Kaggle-only bootstrap, kept commented out for local runs. When enabled it:
#   1. upgrades wandb/timm/numpy and clones the project repository,
#   2. appends the cloned repository to sys.path so `happywhale` can be imported,
#   3. logs into W&B with the Kaggle secret named "WANDB", falling back to an
#      anonymous login when that fails.
# NOTE(review): the bare `except:` below would also swallow unrelated errors
# (e.g. network failures) if this cell is re-enabled -- consider narrowing it.
#
# !pip install wandb timm numpy --upgrade
# !rm -r kaggle_happywhale_2022
# !git clone https://github.com/btseytlin/kaggle_happywhale_2022.git
# import sys
# sys.path.append('kaggle_happywhale_2022')
#
# import wandb
#
# try:
#     from kaggle_secrets import UserSecretsClient
#     user_secrets = UserSecretsClient()
#     api_key = user_secrets.get_secret("WANDB")
#     wandb.login(key=api_key)
#     anonymous = None
# except:
#     anonymous = "must"
#     wandb.login(anonymous=anonymous)
#     print('wand secret missing')

In [None]:
from happywhale import (ImageDataMoodule,
                        ImageBackbone,
                        MetricLearner,
                        seed_torch,
                        load_train_test_dfs,
                        get_cv_splits)

In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

In [None]:
# Reload imported modules before each cell execution so edits to local
# packages are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Seeding is currently disabled.
# NOTE(review): `seed_torch` is imported from `happywhale` and presumably seeds
# the RNGs -- confirm, and consider enabling it for reproducible runs.
# seed_torch(42)

# Config

In [None]:
# --- Experiment tracking -----------------------------------------------------
EXP_NAME = 'metric_learning'
OFFLINE = False
tags = [
    'dataset_base_256',
    'backbone_efficientnet_b0',
    'metric_learning',
    'contrastive_loss',
]

# --- Hardware ----------------------------------------------------------------
# Fall back to CPU whenever CUDA is not available.
DEVICE = 'cpu' if not torch.cuda.is_available() else 'cuda'

# --- Data locations (Kaggle input layout) --------------------------------------
TRAIN_CSV_PATH = '../input/happy-whale-and-dolphin/train.csv'
TEST_CSV_PATH = '../input/happy-whale-and-dolphin/sample_submission.csv'
TRAIN_IMG_DIR = '../input/jpeg-happywhale-256x256/train_images-256-256/train_images-256-256'
TEST_IMG_DIR = '../input/jpeg-happywhale-256x256/test_images-256-256/test_images-256-256'

# --- Model ---------------------------------------------------------------------
BACKBONE = 'efficientnet_b0'
BACKBONE_EMBEDDING_DIM = 1000
EMBEDDING_SIZE = 256
LR = 1e-3

# --- Training loop ---------------------------------------------------------------
N_EPOCHS = 30
BATCH_SIZE = 64
NUM_WORKERS = 2

# Keyword arguments handed to the model as `trainer_kwargs`.
TRAINER_KWARGS = {
    'max_epochs': N_EPOCHS,
    'devices': 'auto',
    'accelerator': 'auto',
    'gradient_clip_val': 1,
    'accumulate_grad_batches': 2,
    # 'stochastic_weight_avg': True,
    # 'fast_dev_run': True,
}

# Request 16-bit precision only when running on a CUDA device.
if DEVICE != 'cpu':
    # Alternative (disabled): amp_backend='apex', amp_level='O2'
    TRAINER_KWARGS['precision'] = 16

# Local overrides

In [None]:
# Overrides for a quick local smoke test: local 128x128 data, offline mode
# and a 5-batch `fast_dev_run`.
OFFLINE = True
EXP_NAME = 'LOCAL_TEST'

TRAIN_CSV_PATH, TEST_CSV_PATH = '../data/train.csv', '../data/test.csv'
TRAIN_IMG_DIR, TEST_IMG_DIR = (
    '../data/images_128/train_images-128-128',
    '../data/images_128/test_images-128-128',
)

TRAINER_KWARGS.update(fast_dev_run=5)

# CV: split and prepare datamodules

In [None]:
# Load the train/test frames from the configured CSV files and image folders,
# then show their shapes as a quick sanity check.
train_df, test_df = load_train_test_dfs(
    train_csv_path=TRAIN_CSV_PATH,
    train_images_path=TRAIN_IMG_DIR,
    test_csv_path=TEST_CSV_PATH,
    test_images_path=TEST_IMG_DIR,
)
print(*(frame.shape for frame in (train_df, test_df)))

In [None]:
# Cross-validation splits of the training frame. Each split exposes `.train`,
# `.val` and `.test`, which later cells use as positional indexers for
# `train_df.iloc`.
cv_splits = get_cv_splits(train_df)

In [None]:
# Sizes of the (train, val, test) parts of the first CV split.
tuple(map(len, (cv_splits[0].train, cv_splits[0].val, cv_splits[0].test)))

In [None]:
# Build one ready-to-use datamodule per CV split. All three parts of a split
# are sliced out of `train_df` by position.
split_datamodules = []
for split in cv_splits:
    split_train_df, split_val_df, split_test_df = (
        train_df.iloc[rows] for rows in (split.train, split.val, split.test)
    )

    datamodule = ImageDataMoodule(
        sampler='m_per_class',
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
        train_df=split_train_df,
        val_df=split_val_df,
        test_df=split_test_df,
    )
    # Set up eagerly so the training cell can read attributes such as
    # `label_encoder` and `train` from the datamodule.
    datamodule.setup()
    split_datamodules.append(datamodule)

# Train model

In [None]:
# Train and evaluate one model per CV fold, logging each fold to its own
# W&B run. Trained models are moved to CPU and collected in `models`.
models = []
for i, datamodule in enumerate(split_datamodules):
    backbone = ImageBackbone(model_name=BACKBONE)
    model = MetricLearner(
        backbone=backbone,
        lr=LR,
        num_labels=len(datamodule.label_encoder.classes_),
        # NOTE(review): this counts batches per epoch times epochs. With
        # `accumulate_grad_batches` set in TRAINER_KWARGS the optimizer steps
        # less often than once per batch -- confirm how MetricLearner uses
        # `num_training_steps` (e.g. for an LR schedule).
        num_training_steps=len(datamodule.train)//datamodule.batch_size * N_EPOCHS,
        trainer_kwargs=TRAINER_KWARGS,
        backbone_embedding_dim=BACKBONE_EMBEDDING_DIM,
        embedding_size=EMBEDDING_SIZE,
        offline=OFFLINE,
    )

    # Use the run as a context manager so it is always finished, even when
    # fit/test raises. Otherwise a failed fold leaves the run open and the
    # next `wandb.init` call collides with it.
    # NOTE(review): OFFLINE is only forwarded to the model, not to
    # `wandb.init` -- confirm that is intended.
    with wandb.init(
        project='kaggle_happywhale',
        name=EXP_NAME + f'_fold_{i}',
        tags=tags,
    ) as run:
        model.fit(datamodule)
        model.test(datamodule)
        # Move the trained model off the accelerator before keeping it around.
        models.append(model.cpu())


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'test_loss': 0.0}
--------------------------------------------------------------------------------



VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
Running in fast_dev_run mode: will run a full train, val, test and prediction loop using 5 batch(es).


  rank_zero_deprecation(

  | Name     | Type               | Params
------------------------------------------------
0 | backbone | ImageBackbone      | 5.3 M 
1 | mlp      | Sequential         | 16.3 K
2 | loss     | ContrastiveLoss    | 0     
3 | miner    | BatchEasyHardMiner | 0     
------------------------------------------------
5.3 M     Trainable params
0         Non-trainable params
5.3 M     Total params
21.219    Total estimated model params size (MB)
  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x1415d3ca0>
Traceback (most recent call last):
  File "/Users/btseytlin/.pyenv/versions/3.8.12/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1328, in __del__
    self._shutdown_workers()
  File "/Users/btseytlin/.pyenv/versions/3.8.12/lib/python3.8/site-packages/torch/utils/data/dataloader.py", line 1301, in _shutdown_workers
    w.join(timeout=_utils.MP_STATUS_CHECK_INTERVAL)
  File "/Users/btseytlin/.pyenv/versions/3.8.12/lib/python3.8/multiprocessing/process.py", line 149, in join
    res = self._popen.wait(timeout)
  File "/Users/btseytlin/.pyenv/versions/3.8.12/lib/python3.8/multiprocessing/popen_fork.py", line 44, in wait
    if not wait([self.sentinel], timeout):
  File "/Users/btseytlin/.pyenv/versions/3.8.12/lib/python3.8/multiprocessing/connection.py", line 931, in wait
    ready = selector.select(timeout)
  File "/Users/btseytlin/.pyenv/versions/3.8.12/lib/python3.8/selectors.py", 