In [2]:
import logging
import os
import sys
sys.path.append("/home/hiramatsu/kaggle/hms-harmful-brain-activity-classification/")
from pathlib import Path

import hydra
import numpy as np
import pandas as pd
import yaml
import torch
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.callbacks import (EarlyStopping, LearningRateMonitor,
                                         ModelCheckpoint, RichModelSummary,
                                         RichProgressBar)
from pytorch_lightning.loggers import WandbLogger
from omegaconf import OmegaConf

import wandb
from src.conf import TrainConfig
from src.datamodule import HMSDataModule
from src.modelmodule import HMSModel

from typing import Optional

import albumentations as A
import numpy as np
import pandas as pd
import torch
import torch.optim as optim
from pytorch_lightning import LightningModule
from transformers import get_cosine_schedule_with_warmup

from src.conf import TrainConfig
from src.models.common import get_model
from src.utils.augmentation import cutmix_data, mixup_data
from src.utils.kaggle_kl_div import score
from src.utils.loss_functions import (KLDivLossWithLogits,
                                      KLDivLossWithLogitsForVal,
                                      KLDWithContrastiveLoss)


In [None]:
class HMSModel(LightningModule):
    """Lightning wrapper around the network returned by ``get_model(cfg)``.

    Handles optional mixup/cutmix augmentation of the training batch, loss
    logging, per-epoch validation scoring with ``score`` and saving of the
    best-scoring predictions and weights for one fold.
    """

    # Model names that are trained with the combined KLD + contrastive loss.
    _CONTRASTIVE_MODEL_NAMES = ("HMSSpecPararellModel", "HMSSpecEEGPararellModel")

    def __init__(
        self,
        cfg: TrainConfig,
        val_df: pd.DataFrame,
        fold_id: int
    ):
        """Build the model and loss for one fold.

        Args:
            cfg: Training configuration; ``cfg.model.name`` selects the loss.
            val_df: Validation frame. It must contain ``label_id`` and the
                columns listed in ``cfg.labels``; rows are presumably in the
                same order as the validation dataloader yields them
                (TODO confirm against the datamodule).
            fold_id: Fold index, used for the ``fold_{fold_id}`` output folder.
        """
        super().__init__()
        self.cfg = cfg
        self.model = get_model(cfg)

        if cfg.model.name in self._CONTRASTIVE_MODEL_NAMES:
            self.loss_func = KLDWithContrastiveLoss()
        else:
            self.loss_func = KLDivLossWithLogits()

        # One (labels, predictions) pair of numpy arrays per validation batch.
        self.validation_step_outputs: list = []
        # Lower is better, so start from +inf.
        self.best_score = np.inf

        self.val_df = val_df
        self.fold_id = fold_id

    def forward(
        self,
        x: torch.Tensor
        ):
        """Delegate to the wrapped model."""
        return self.model(x)

    def _apply_half_batch_mix(self, batch, mix_fn):
        """Apply ``mix_fn`` to the second half of the batch, in place.

        The first half of ``spec_img`` / ``target`` / ``level_target`` is kept
        untouched; the second half is replaced by the output of ``mix_fn``
        (``mixup_data`` or ``cutmix_data``). When ``cfg.use_raw_eeg`` is set,
        the second half of ``raw_eeg`` is blended with the same permutation
        ``index`` and weight ``lam`` that ``mix_fn`` returned.
        """
        img = batch["spec_img"]
        target = batch["target"]
        level_target = batch["level_target"]
        # All entries share the batch dimension, so one split point suffices.
        half = img.shape[0] // 2

        img_mixed, target_mixed, level_mixed, index, lam = mix_fn(
            img[half:], target[half:], level_target[half:]
        )
        batch["spec_img"] = torch.cat([img[:half], img_mixed], 0)
        batch["target"] = torch.cat([target[:half], target_mixed], 0)
        batch["level_target"] = torch.cat([level_target[:half], level_mixed])

        if self.cfg.use_raw_eeg:
            eeg = batch["raw_eeg"]
            eeg_tail = eeg[half:]
            eeg_tail = lam * eeg_tail + (1 - lam) * eeg_tail[index]
            batch["raw_eeg"] = torch.cat([eeg[:half], eeg_tail], 0)

    def _log_loss(self, prefix, loss):
        """Log ``loss`` as epoch-level metrics and return the scalar to optimise.

        A dict loss is logged entry by entry as ``{prefix}_{key}`` and its
        ``"loss"`` entry is returned; any other loss is logged as
        ``{prefix}_loss`` and returned unchanged.
        """
        log_kwargs = dict(on_step=False, on_epoch=True, logger=True, prog_bar=True)
        if isinstance(loss, dict):
            for key, value in loss.items():
                self.log(f"{prefix}_{key}", value, **log_kwargs)
            return loss["loss"]

        self.log(f"{prefix}_loss", loss, **log_kwargs)
        return loss

    def training_step(self, batch, batch_idx):
        """Run one training step, optionally with mixup or cutmix.

        Mixup is tried first (if enabled) with probability 0.5; only when it
        is not applied is cutmix tried (if enabled), again with probability
        0.5. The random draws happen in the same order as the conditions.
        """
        mix_fn = None
        if self.cfg.aug.do_mixup and (np.random.random() > 0.5):
            mix_fn = mixup_data
        elif self.cfg.aug.do_cutmix and (np.random.random() > 0.5):
            mix_fn = cutmix_data

        if mix_fn is not None:
            self._apply_half_batch_mix(batch, mix_fn)

        t, level_t = batch["target"], batch["level_target"]
        output = self.model(batch)
        loss = self.loss_func(output, t, level_t)

        return self._log_loss("train", loss)

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss and stash labels/predictions."""
        t = batch["target"]
        level_t = batch['level_target']
        # NOTE(review): only flips the flag on the top-level module; Lightning
        # presumably has already put the model in eval mode here - confirm.
        self.model.training = False
        output = self.model(batch)
        # The loss receives the raw model output (possibly a dict).
        loss = self.loss_func(output, t, level_t)

        if isinstance(output, dict):
            output = output['weighted_output']

        loss = self._log_loss("valid", loss)

        self.validation_step_outputs.append(
            (
                t.detach().cpu().numpy(),
                output.softmax(dim=1).detach().cpu().numpy(),
            )
        )

        return loss

    def on_validation_epoch_end(self):
        """Score the epoch's predictions and save artefacts on improvement.

        Logs ``valid_score``; when it beats ``self.best_score`` the labels,
        predictions, prediction CSV and model weights are written to
        ``{cfg.dir.save_dir}/fold_{fold_id}``. The directory must already
        exist.
        """
        labels = np.concatenate([x[0] for x in self.validation_step_outputs])
        preds = np.concatenate([x[1] for x in self.validation_step_outputs])

        label_cols = list(self.cfg.labels)
        val_pred_df = pd.DataFrame(preds, columns=label_cols)
        # Inserted exactly once: a second insert of an existing column makes
        # DataFrame.insert raise ValueError.
        val_pred_df.insert(0, "label_id", self.val_df["label_id"].values)

        val_score = score(solution=self.val_df[["label_id"] + label_cols].copy().reset_index(drop=True),
                          submission=val_pred_df,
                          row_id_column_name='label_id')

        self.log("valid_score", val_score, on_step=False, on_epoch=True, logger=True, prog_bar=True)

        if val_score < self.best_score:
            fold_dir = self.cfg.dir.save_dir + f"/fold_{self.fold_id}"
            np.save(fold_dir + "/labels.npy", labels)
            np.save(fold_dir + "/preds.npy", preds)
            val_pred_df.to_csv(fold_dir + "/val_pred_df.csv", index=False)
            torch.save(self.model.state_dict(), fold_dir + "/best_model.pth")
            print(f"Saved best model {self.best_score} -> {val_score}")
            self.best_score = val_score

        self.validation_step_outputs.clear()

    def configure_optimizers(self):
        """Return AdamW plus a per-step cosine schedule with warmup.

        The schedule length is ``self.trainer.max_steps``; the remaining
        scheduler arguments come from ``cfg.scheduler``.
        """
        optimizer = optim.AdamW(self.parameters(), lr=self.cfg.optimizer.lr)
        scheduler = get_cosine_schedule_with_warmup(
            optimizer, num_training_steps=self.trainer.max_steps, **self.cfg.scheduler
        )
        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]

In [None]:
# Re-use the resolved Hydra config of a previous experiment run.
cfg = OmegaConf.load("/home/hiramatsu/kaggle/hms-harmful-brain-activity-classification/result/exp045/.hydra/config.yaml")

seed_everything(cfg.seed)

save_dir = cfg.dir.save_dir
os.makedirs(save_dir, exist_ok=True)

best_scores = []
folds = []

for fold_id in range(cfg.n_folds):
    fold_dir = save_dir + f"/fold_{fold_id}"
    # HMSModel writes its best predictions/weights here, so create it first.
    os.makedirs(fold_dir, exist_ok=True)

    # init lightning model
    datamodule = HMSDataModule(cfg, fold_id)
    model = HMSModel(cfg, datamodule.val_df, fold_id)

    # Build the train dataloader once, only to learn the number of batches.
    steps_per_epoch = len(datamodule.train_dataloader())

    # set callbacks
    checkpoint_cb = ModelCheckpoint(
        dirpath=fold_dir,
        verbose=True,
        monitor=cfg.trainer.monitor,
        mode=cfg.trainer.monitor_mode,
        save_top_k=1,
        save_last=False,
    )
    lr_monitor = LearningRateMonitor("epoch")
    progress_bar = RichProgressBar()
    early_stopping = EarlyStopping(monitor='valid_score', patience=cfg.early_stopping_rounds)
    model_summary = RichModelSummary(max_depth=2)

    trainer = Trainer(
        # env
        default_root_dir=Path.cwd(),
        accelerator=cfg.trainer.accelerator,
        devices=cfg.trainer.device,
        precision=16 if cfg.trainer.use_amp else 32,
        # training
        fast_dev_run=cfg.trainer.debug,  # run only 1 train batch and 1 val batch
        max_epochs=cfg.trainer.epochs,
        # NOTE(review): this counts batches, not optimizer steps; with
        # accumulate_grad_batches > 1 the scheduler horizon is presumably
        # longer than the steps actually taken - confirm this is intended.
        max_steps=cfg.trainer.epochs * steps_per_epoch,
        gradient_clip_val=cfg.trainer.gradient_clip_val,
        accumulate_grad_batches=cfg.trainer.accumulate_grad_batches,
        callbacks=[checkpoint_cb, lr_monitor, progress_bar, model_summary, early_stopping],
        num_sanity_val_steps=0,
        # Log roughly ten times per epoch; never 0, which happens with
        # fewer than 10 batches per epoch.
        log_every_n_steps=max(1, int(steps_per_epoch * 0.1)),
        sync_batchnorm=True,
        check_val_every_n_epoch=cfg.trainer.check_val_every_n_epoch,
    )

    trainer.fit(model, datamodule=datamodule)

    best_scores.append(model.best_score)
    folds.append(f"fold_{fold_id}")
    print(f'fold_{fold_id}: best_score is {model.best_score}')

    wandb.finish()

# Compute the CV mean once, from the per-fold scores only, before it is
# appended as the summary row.
cv_score = np.mean(best_scores)
best_scores.append(cv_score)
folds.append("mean")

print(f'CV score is {cv_score}.')

best_score_df = pd.DataFrame(data = {"fold_id": folds, "scores": best_scores})

best_score_df.to_csv(save_dir + '/best_scores.csv', index=False)