# Preliminary steps

## Competition

[HMS - Harmful Brain Activity Classification](https://www.kaggle.com/competitions/hms-harmful-brain-activity-classification)

## Objective

Learn PyTorch by following [Tobi](https://www.kaggle.com/morodertobias)'s excellent [training](https://www.kaggle.com/code/morodertobias/hms-pytorch-baseline-training-private) and [inference](https://www.kaggle.com/code/morodertobias/hms-pytorch-baseline-inference-private) kernels. This notebook is mostly copied from them.

## Connected kernels
- [training](https://www.kaggle.com/fejust/24-hms-fj-02-pytorch-starter-training)
- [inference](https://www.kaggle.com/fejust/24-hms-fj-03-pytorch-starter-inference)

## References
- https://www.kaggle.com/code/morodertobias/hms-pytorch-baseline-training-private
- https://www.kaggle.com/code/morodertobias/hms-pytorch-baseline-inference-private
- pytorch image models [timm](https://pypi.org/project/timm/)
- https://www.kaggle.com/code/morodertobias/hms-pytorch-learning-nb
- https://www.kaggle.com/code/andreasbis/hms-train-efficientnetb0/notebook


## Imports

In [None]:
# standard imports
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from datetime import datetime
from datetime import date
from pytz import timezone
import pandas as pd
import numpy as np
import pathlib
import json
import os
import gc

# other imports
import random
import albumentations as A
from sklearn.model_selection import KFold

# pytorch
import timm
import torch
import torch.nn as nn  
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim.lr_scheduler import CosineAnnealingLR

import torchvision
import torchvision.transforms as transforms

# Report the installed PyTorch / torchvision versions (one per line).
print(
    f"pytorch version: {torch.__version__}",
    f"torchvision version: {torchvision.__version__}",
    sep="\n",
)

## Settings

In [None]:
class CFG:
    """Notebook-wide settings, read as class attributes (``CFG.lr`` etc.)."""
    # general
    seed = 42  # used as the KFold random_state; torch/numpy/random are not seeded from it in this notebook
    debug = True  # when True, training runs on only the first 200 spectrograms
    kernel_name = "24-hms-fj-02-pytorch-starter-training"
    
    # wandb
    wandb_tracking = False  # enables W&B login/init and per-epoch logging
    wandb_project = "24-HMS-Harmful-Brain-Activity-Classification"
    # kernel name + Berlin-time timestamp, evaluated once when this class body runs
    wandb_run_name = kernel_name + "-" + datetime.strftime(datetime.now(timezone("Europe/Berlin")), "%Y-%m-%d %H:%M")
    
    # paths
    base_dir = pathlib.Path("/kaggle/input/hms-harmful-brain-activity-classification")
    spec_dir = base_dir / "train_spectrograms"  # one "<spectrogram_id>.parquet" file per spectrogram
    model_dir = "models/"  # created during setup
    
    # spectra
    transform = transforms.Resize((512, 512), antialias=False)  # applied to every image tensor in to_tensor()
    
    # model
    model_name = "tf_efficientnet_b0_ns"  # timm model identifier
    ckpt_name  = "ckpt_" + model_name  # checkpoint file prefix; files are "<ckpt_name>__<fold>.pt"
        
    # training
    one_fold = False  # stop after the first CV fold
    n_fold = 5
    epochs = 15  # also T_max of the cosine learning-rate schedule
    batch_size = 16    
    lr = 0.001  # initial AdamW learning rate
        
# Snapshot of the public CFG settings as a plain dict (dunder entries such as
# __module__ / __doc__ are skipped); logged to W&B and dumped to cfg.json.
config_dict = {
    setting_name: setting_value
    for setting_name, setting_value in vars(CFG).items()
    if not setting_name.startswith("__")
}

# Setup

In [None]:
# Wall-clock start; the total runtime is stored with the config at the end.
start_time = datetime.now()

# Folder for model artefacts (parents created as needed, no error if it exists).
# NOTE(review): the training loop saves checkpoints to the working directory, not here.
pathlib.Path(CFG.model_dir).mkdir(parents=True, exist_ok=True)

## Device

In [None]:
# IPython shell escape: print the status of the visible NVIDIA GPUs
!nvidia-smi

In [None]:
# Select the compute device and record it in the run config.
# Only a single device is used (no DataParallel), so on a multi-GPU machine
# (e.g. Kaggle's T4 x2) training runs on the first/current GPU only.
if torch.cuda.is_available():
    DEVICE = "cuda"
    DEVICE_NAME = torch.cuda.get_device_name(0)
else:
    DEVICE = "cpu"
    # No GPU name to report: reuse the device string (the previous code read an
    # undefined name here and crashed on CPU-only sessions).
    DEVICE_NAME = DEVICE

print(DEVICE_NAME)
config_dict["device"] = DEVICE
config_dict["device_name"] = DEVICE_NAME

## wandb

In [None]:
# The W&B API key is only needed when tracking is enabled. Fetching it
# unconditionally breaks the notebook outside Kaggle (no `kaggle_secrets`
# module) or when the "wandb_api_key" secret is not attached.
wandb_api = None
if CFG.wandb_tracking:
    from kaggle_secrets import UserSecretsClient
    user_secrets = UserSecretsClient()
    wandb_api = user_secrets.get_secret("wandb_api_key")

In [None]:
import wandb
from wandb.keras import WandbCallback  # NOTE(review): unused in this PyTorch notebook

if CFG.wandb_tracking:
    # Authenticate, then open a run that carries the full config snapshot.
    wandb.login(key=wandb_api)
    run = wandb.init(
        project=CFG.wandb_project,
        name=CFG.wandb_run_name,
        group="GPU_model",
        job_type="train",
        config=config_dict,
    )

    # Extra run metadata.
    wandb.config.type = "baseline"
    wandb.config.kaggle_competition = CFG.wandb_project

# Data handling

## Average votes

In [None]:
# Load the competition metadata (one row per labelled sample) and preview it.
train_df = pd.read_csv(CFG.base_dir.joinpath("train.csv"))
train_df.head()

In [None]:
# Target columns: every column whose name contains "_vote", in file order.
label_columns = [column for column in train_df.columns if "_vote" in column]
label_columns

In [None]:
# Sum the votes of all rows sharing a spectrogram, then divide each row by its
# total so the label columns become per-spectrogram vote fractions.
data = train_df.groupby("spectrogram_id")[label_columns].sum()
n = data.sum(axis=1)
data = data.div(n, axis=0)

In [None]:
# Attach the parquet file path of each spectrogram, then turn spectrogram_id
# from the index back into a regular column.
data["path"] = [CFG.spec_dir / f"{spec_id}.parquet" for spec_id in data.index]
data = data.reset_index()
data

## Spectrograms

In [None]:
def to_image(x):
    """
    Convert a raw spectrogram array into an 8-bit image.

    Values are cast to float32, clipped to [exp(-6), exp(10)] and
    log-transformed, then min-max scaled to [0, 255] over the whole array
    (a 1e-8 guard avoids division by zero for constant input) and truncated
    to uint8.
    """
    floor, ceiling = np.exp(-6), np.exp(10)
    log_spec = np.log(np.clip(x.astype('float32'), floor, ceiling))

    lowest, highest = log_spec.min(), log_spec.max()
    scaled = 255.0 * (log_spec - lowest) / (highest - lowest + 1e-8)
    return scaled.astype('uint8')


def to_tensor(x):
    """
    Convert an 8-bit 2-D image into a 1-channel float tensor for the model.

    Pixel values are mapped from [0, 255] to [-1, 1], a leading channel axis
    is added, and the result is passed through ``CFG.transform`` (a resize).
    """
    centred = x.astype('float32') / 255.0 * 2 - 1.0
    with_channel = np.expand_dims(centred, axis=0)
    return CFG.transform(torch.Tensor(with_channel))

In [None]:
# # dev: load and plot spectra

# row = data.iloc[0]
# x = pd.read_parquet(row.path)
# x = x.fillna(-1).values[:, 1:].T
# x = to_image(x)
# print(x.shape, x.dtype, x.min(), x.max())

# plt.imshow(x)
# plt.show()

In [None]:
# Training-time augmentation of the 8-bit spectrogram image (used by
# SpecDataset when aug=True): horizontal flip (p=0.5), occasional
# brightness/contrast jitter (p=0.05) and Gaussian blur (p=0.3).
aug = A.Compose(
    [
        A.HorizontalFlip(p=0.5),
        A.RandomBrightnessContrast(p=0.05),
        A.GaussianBlur(blur_limit=5, p=0.3),
    ]
)

In [None]:
# # dev: check augmentations

# fig, axs = plt.subplots(nrows=2, ncols=5, figsize=(16, 8), sharex='all', sharey='all')
# for i, ax in enumerate(axs.flat):
#     if i == 0:
#         img1 = x
#     else:
#         img1 = aug(image=x)["image"]
#     ax.imshow(img1)
#     print((type(img1), img1.dtype, img1.min(), img1.max()))
# plt.show()

In [None]:
class SpecDataset(Dataset):
    """
    Map-style dataset yielding (image tensor, target tensor) pairs.

    Each row of ``df`` must provide a ``path`` to a spectrogram parquet file
    and the columns named in the module-level ``label_columns`` list.

    Args:
        df: table with one row per sample.
        aug: when truthy, the module-level albumentations pipeline ``aug`` is
            applied to the 8-bit image before tensor conversion.
    """

    def __init__(self, df, aug=False):
        self.df = df
        self.aug = aug

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, index):
        record = self.df.iloc[index]

        # input: fill gaps with -1, drop the first parquet column (presumably
        # the time axis), transpose, and scale to an 8-bit image
        frame = pd.read_parquet(record["path"]).fillna(-1)
        image = to_image(frame.values[:, 1:].T)
        if self.aug:
            image = aug(image=image)["image"]
        features = to_tensor(image)

        # target: this row's label columns as a float32 tensor
        target = torch.Tensor(np.array(record.loc[label_columns].values, 'float32'))
        return features, target

In [None]:
# Smoke-test the dataset: wrap the first 50 rows (augmentation on) and load one sample.
ds = SpecDataset(df=data.iloc[:50], aug=True)
x, y = ds[0]

# number of samples, then the shapes of one (input, target) pair
print(len(ds))
print(x.shape, y.shape)

In [None]:
# test dataloader: batch the 50-sample smoke-test dataset and inspect one batch

# drop_last=True discards a final batch smaller than CFG.batch_size
ld = DataLoader(dataset=ds, batch_size=CFG.batch_size, drop_last=True, num_workers=os.cpu_count())

x, y = next(iter(ld))
img = x[0, 0]  # channel 0 of the first sample in the batch
plt.imshow(img)

# number of batches, then the batch shapes of inputs and targets
print(len(ld))
print(x.shape, y.shape)

# Model

In [None]:
# Pretrained timm backbone with a single input channel and 6 outputs (one per
# vote column), moved to the selected device; then report its size.
model = timm.create_model(model_name=CFG.model_name, pretrained=True, num_classes=6, in_chans=1).to(DEVICE)
num_parameter = sum(param.numel() for param in model.parameters())
print(f"Model has {num_parameter} parameters.")

In [None]:
# # dev: model output
# y_out = model(x.to(DEVICE))
# y_out.shape

# Training

In [None]:
def KLDivLoss(logit, target):
    """
    Batch-mean KL divergence between ``target`` and ``softmax(logit)``.

    ``logit`` holds raw model outputs; the predicted distribution is the
    softmax over dim 1. ``target`` is expected to hold per-row probabilities
    with the same shape. Returns a scalar tensor: the summed pointwise KL
    divided by the batch size.
    """
    return F.kl_div(logit.log_softmax(dim=1), target, reduction="batchmean")

In [None]:
def compute_loss(model, data_loader, *, device=None, loss_fn=None):
    """
    Evaluate the mean per-sample loss of ``model`` over ``data_loader``.

    The model is switched to eval mode (and left there); gradients are
    disabled. Each batch loss is assumed to be a batch mean (as with
    ``KLDivLoss``), so it is weighted by the batch size before averaging.
    A smaller final batch therefore no longer skews the result.

    Args:
        model: module mapping an input batch to predictions.
        data_loader: iterable of ``(x, y)`` tensor batches.
        device: where to run the forward pass; defaults to the global ``DEVICE``.
        loss_fn: ``loss_fn(pred, target)`` returning a scalar batch-mean loss;
            defaults to the global ``KLDivLoss``.

    Returns:
        The sample-weighted mean loss as a float, or NaN if the loader is empty.
    """
    device = DEVICE if device is None else device
    loss_fn = KLDivLoss if loss_fn is None else loss_fn

    model.eval()
    total_loss = 0.0
    n_samples = 0
    with torch.no_grad():
        for x, y in data_loader:
            batch_size = y.size(0)
            y_pred = model(x.to(device))
            loss = loss_fn(y_pred, y.to(device))
            total_loss += loss.item() * batch_size
            n_samples += batch_size

    if n_samples == 0:
        return float("nan")
    return total_loss / n_samples

In [None]:
# sanity check: loss of the freshly created (untrained) model on the smoke-test loader
compute_loss(model, ld)

In [None]:
# drop the smoke-test model and batch before cross-validation, then release
# unused memory held by PyTorch's CUDA caching allocator
del model, x, y
torch.cuda.empty_cache()

In [None]:
# debug mode: keep only the first 200 spectrograms to shorten the training run
if CFG.debug:
    data = data.iloc[:200]

In [None]:
%%time

# K-fold cross-validation: train a fresh model per fold and keep the
# checkpoint with the lowest validation loss.
kf = KFold(n_splits=CFG.n_fold, shuffle=True, random_state=CFG.seed)

l_best_loss = []  # best validation loss of each fold
l_history = []    # per-fold DataFrame indexed by epoch with columns loss / val_loss

for fold, (iloc_train, iloc_valid) in enumerate(kf.split(data)):
    print(f"Fold {fold}:")

    # prepare data (augmentation only on the training split)
    train_ds = SpecDataset(df=data.iloc[iloc_train], aug=True)
    valid_ds = SpecDataset(df=data.iloc[iloc_valid])
    train_loader = DataLoader(dataset=train_ds, shuffle=True, batch_size=CFG.batch_size, num_workers=os.cpu_count(), drop_last=True)
    valid_loader = DataLoader(dataset=valid_ds, batch_size=CFG.batch_size, num_workers=os.cpu_count())

    # init training: fresh pretrained model, AdamW and a per-epoch cosine schedule
    model = timm.create_model(model_name=CFG.model_name, pretrained=True, num_classes=6, in_chans=1)
    model.to(DEVICE)
    optimizer = optim.AdamW(model.parameters(), lr=CFG.lr)
    scheduler = CosineAnnealingLR(optimizer=optimizer, T_max=CFG.epochs)
    best_loss = float("inf")
    history = []

    # run training
    for epoch in tqdm(range(CFG.epochs)):
        model.train()                           # compute_loss() switches the model to eval mode
        l_loss = []
        for x, y in tqdm(train_loader):         # go through batches
            x, y = x.to(DEVICE), y.to(DEVICE)
            y_pred = model(x)
            loss = KLDivLoss(y_pred, y)
            l_loss.append(loss.item())
            loss.backward()                     # calculate gradients
            optimizer.step()                    # update weights
            optimizer.zero_grad()               # reset gradients
        # Advance the cosine schedule once per epoch (T_max = CFG.epochs).
        # Without this call the learning rate stays fixed at CFG.lr.
        scheduler.step()

        train_loss = np.mean(l_loss)
        valid_loss = compute_loss(model, valid_loader)

        if CFG.wandb_tracking:
            wandb.log({
                "fold" : fold,
                "training loss": train_loss,
                "validation loss": valid_loss
            })

        history.append((epoch, train_loss, valid_loss))
        print(f"Epoch {epoch}")
        print(f"Train Loss: {train_loss:>10.6f}, Valid Loss: {valid_loss:>10.6f}")
        if valid_loss < best_loss:
            print(f"Loss improves from {best_loss:>10.6f} to {valid_loss:>10.6f}")
            # NOTE(review): saved to the working directory, not CFG.model_dir
            torch.save(model.state_dict(), f"{CFG.ckpt_name}__{fold}.pt")
            best_loss = valid_loss
    print(f"\nBest loss Model training with {best_loss}\n")

    history = pd.DataFrame(history, columns=["epoch", "loss", "val_loss"]).set_index("epoch")

    l_best_loss.append(best_loss)
    l_history.append(history)

    if CFG.one_fold:
        break

In [None]:
# One figure per fold with training and validation loss side by side.
# Folds are numbered from 0 to match the "Fold {fold}" training log lines
# and the "__{fold}.pt" checkpoint file names.
for i, h in enumerate(l_history):
    h.plot(subplots=True, layout=(1, 2), sharey="row", figsize=(8, 4))
    plt.gcf().suptitle(f"fold {i}")
    plt.show()

## OOF (best validation loss per fold and their mean)

In [None]:
# best validation loss of each fold, and their unweighted mean across folds
l_best_loss, np.mean(l_best_loss)

# Save config

In [None]:
# Record the total runtime and write the run configuration to cfg.json.
end_time = datetime.now()
# total_seconds() rather than .seconds: the .seconds attribute drops whole
# days, so runs longer than 24 h would be under-reported.
runtime = int((end_time - start_time).total_seconds())

config_dict["runtime"] = runtime

# Stringify keys and values (e.g. Path / transform objects) so json can serialize them.
config_dump = {str(key): str(value) for key, value in config_dict.items()}

with open("cfg.json", "w", encoding="utf-8") as f:
    json.dump(config_dump, f)

print(f"runtime: {runtime / 60 / 60:.2f} h")