In [2]:
import polars as pl
import gc
import pickle
from pathlib import Path, PosixPath
from tqdm.auto import tqdm

import sys
sys.path.append('..')

from src.utils import seed_everything, get_logger, get_config, TimeUtil
from src.utils.competition_utils import clipping_input
from src.data import DataProvider, FeatureEngineering, Preprocessor, HFPreprocessor
from src.train import get_dataloader

In [3]:
# コマンドライン引数
exp = '146'

In [4]:
config = get_config(exp, config_dir=Path('../config'))
logger = get_logger(config.output_path)
logger.info(f'exp: {exp} | run_mode={config.run_mode}, multi_task={config.multi_task}, loss_type={config.loss_type}')

seed_everything(config.seed)

[ [32m2024-10-12 23:37:46[0m | [1mINFO ] exp: 146 | run_mode=hf, multi_task=False, loss_type=mae[0m


In [6]:
config.run_mode = 'hf'

In [8]:
with TimeUtil.timer('Data Loading...'):
    dpr = DataProvider(config)
    train_df, test_df = dpr.load_data()

with TimeUtil.timer('Feature Engineering...'):
    fer = FeatureEngineering(config)
    train_df = fer.feature_engineering(train_df)
    test_df = fer.feature_engineering(test_df)

with TimeUtil.timer('Scaling and Clipping Features...'):
    ppr = Preprocessor(config)
    train_df, test_df = ppr.scaling(train_df, test_df)
    input_cols, target_cols = ppr.input_cols, ppr.target_cols
    if config.task_type == 'grid_pred':
        train_df = train_df.drop(target_cols)

    valid_df = train_df.filter(pl.col('fold') == 0)
    train_df = train_df.filter(pl.col('fold') != 0)
    valid_df, input_clip_dict = clipping_input(train_df, valid_df, input_cols)
    test_df, _ = clipping_input(None, test_df, input_cols, input_clip_dict)
    pickle.dump(input_clip_dict, open(config.output_path / 'input_clip_dict.pkl', 'wb'))

with TimeUtil.timer('Converting to arrays for NN...'):
    array_data = ppr.convert_numpy_array(train_df, valid_df, test_df)
    del train_df, valid_df, test_df
    gc.collect()



[Data Loading...] done [109.1GB(45.0%)(+108.550GB)] 112.9162 s
[Feature Engineering...] done [133.5GB(33.6%)(+24.372GB)] 43.2264 s
[Scaling and Clipping Features...] done [108.4GB(22.6%)(-25.101GB)] 94.7392 s
[Converting to arrays for NN...] done [148.0GB(46.0%)(+39.667GB)] 314.0020 s


In [13]:
if config.run_mode == 'hf':
    with TimeUtil.timer('HF Data Preprocessing...'):
        del array_data['train_ids'], array_data['X_train'], array_data['y_train']
        gc.collect()

        hf_ppr = HFPreprocessor(config)
        hf_ppr.shrink_file_size()
        hf_ppr.convert_numpy_array(unlink_parquet=True)



  0%|          | 0/16 [00:00<?, ?it/s]

[HF Data Preprocessing...] done [61.8GB(6.5%)(+13.077GB)] 555.8223 s


In [50]:
with TimeUtil.timer('Creating Torch DataLoader...'):
    if config.run_mode == 'hf':
        train_loader = get_dataloader(
            config,
            from_hdf5=False,
            is_train=True
        )
    else:
        train_loader = get_dataloader(
            config,
            array_data['train_ids'],
            array_data['X_train'],
            array_data['y_train'],
            is_train=True
        )
    valid_loader = get_dataloader(
        config,
        array_data['valid_ids'],
        array_data['X_valid'],
        array_data['y_valid'],
        is_train=False
    )
    test_loader = get_dataloader(
        config,
        array_data['test_ids'],
        array_data['X_test'],
        is_train=False
    )
    del array_data
    gc.collect()

# Trainer

In [51]:
from copy import deepcopy
from typing import Optional

import torch
from torch import nn


# https://github.com/huggingface/pytorch-image-models/blob/main/timm/utils/model_ema.py
class ModelEmaV3(nn.Module):
    """Model Exponential Moving Average V3

    Keep a moving average of everything in the model state_dict (parameters and buffers).
    V3 of this module leverages for_each and in-place operations for faster performance.

    Decay warmup based on code by @crowsonkb, her comments:
      If inv_gamma=1 and power=1, implements a simple average. inv_gamma=1, power=2/3 are
      good values for models you plan to train for a million or more steps (reaches decay
      factor 0.999 at 31.6K steps, 0.9999 at 1M steps), inv_gamma=1, power=3/4 for models
      you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
      215.4k steps).

    This is intended to allow functionality like
    https://www.tensorflow.org/api_docs/python/tf/train/ExponentialMovingAverage

    To keep EMA from using GPU resources, set device='cpu'. This will save a bit of memory but
    disable validation of the EMA weights. Validation will have to be done manually in a separate
    process, or after the training stops converging.

    This class is sensitive where it is initialized in the sequence of model init,
    GPU assignment and distributed training wrappers.
    """

    def __init__(
        self,
        model,
        decay: float = 0.9999,
        min_decay: float = 0.0,
        update_after_step: int = 0,
        use_warmup: bool = False,
        warmup_gamma: float = 1.0,
        warmup_power: float = 2 / 3,
        device: torch.device | None = None,
        foreach: bool = True,
        exclude_buffers: bool = False,
    ):
        super().__init__()
        # make a copy of the model for accumulating moving average of weights
        self.module = deepcopy(model)
        self.module.eval()
        self.decay = decay
        self.min_decay = min_decay
        self.update_after_step = update_after_step
        self.use_warmup = use_warmup
        self.warmup_gamma = warmup_gamma
        self.warmup_power = warmup_power
        self.foreach = foreach
        self.device = (
            "cuda:0" if device == "cuda" else device
        )  # perform ema on different device from model if set
        self.exclude_buffers = exclude_buffers
        if self.device is not None and device != next(model.parameters()).device:
            self.foreach = False  # cannot use foreach methods with different devices
            self.module.to(device=device)

    def get_decay(self, step: int | None = None) -> float:
        """
        Compute the decay factor for the exponential moving average.
        """
        if step is None:
            return self.decay

        step = max(0, step - self.update_after_step - 1)
        if step <= 0:
            return 0.0

        if self.use_warmup:
            decay = 1 - (1 + step / self.warmup_gamma) ** -self.warmup_power
            decay = max(min(decay, self.decay), self.min_decay)
        else:
            decay = self.decay

        return decay

    @torch.no_grad()
    def update(self, model, step: int | None = None):
        decay = self.get_decay(step)
        if self.exclude_buffers:
            self.apply_update_no_buffers_(model, decay)
        else:
            self.apply_update_(model, decay)

    def apply_update_(self, model, decay: float):
        # interpolate parameters and buffers
        if self.foreach:
            ema_lerp_values = []
            model_lerp_values = []
            for ema_v, model_v in zip(
                self.module.state_dict().values(), model.state_dict().values()
            ):
                if ema_v.is_floating_point():
                    ema_lerp_values.append(ema_v)
                    model_lerp_values.append(model_v)
                else:
                    ema_v.copy_(model_v)

            if hasattr(torch, "_foreach_lerp_"):
                torch._foreach_lerp_(ema_lerp_values, model_lerp_values, weight=1.0 - decay)
            else:
                torch._foreach_mul_(ema_lerp_values, scalar=decay)
                torch._foreach_add_(ema_lerp_values, model_lerp_values, alpha=1.0 - decay)
        else:
            for ema_v, model_v in zip(
                self.module.state_dict().values(), model.state_dict().values()
            ):
                if ema_v.is_floating_point():
                    ema_v.lerp_(model_v, weight=1.0 - decay)
                else:
                    ema_v.copy_(model_v)

    def apply_update_no_buffers_(self, model, decay: float):
        # interpolate parameters, copy buffers
        ema_params = tuple(self.module.parameters())
        model_params = tuple(model.parameters())
        if self.foreach:
            if hasattr(torch, "_foreach_lerp_"):
                torch._foreach_lerp_(ema_params, model_params, weight=1.0 - decay)
            else:
                torch._foreach_mul_(ema_params, scalar=decay)
                torch._foreach_add_(ema_params, model_params, alpha=1 - decay)
        else:
            for ema_p, model_p in zip(ema_params, model_params):
                ema_p.lerp_(model_p, weight=1.0 - decay)

        for ema_b, model_b in zip(self.module.buffers(), model.buffers()):
            ema_b.copy_(model_b.to(device=self.device))

    @torch.no_grad()
    def set(self, model):
        for ema_v, model_v in zip(self.module.state_dict().values(), model.state_dict().values()):
            ema_v.copy_(model_v.to(device=self.device))

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)


In [52]:
import pickle
from collections import defaultdict
from typing import Dict, List, Literal, Tuple

import loguru
import numpy as np
import polars as pl
import torch
from omegaconf import DictConfig, OmegaConf
from torch import nn
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from src.train import ComponentFactory
from src.train.train_utils import AverageMeter
from src.utils import clean_message
from src.utils.competition_utils import evaluate_metric, get_io_columns, get_sub_factor
from src.utils.constant import (
    PP_TARGET_COLS,
    SCALER_TARGET_COLS,
    TARGET_MIN_MAX,
    VERTICAL_TARGET_COLS,
)


class Trainer:
    def __init__(self, config: DictConfig, logger: loguru._Logger, save_suffix: str = ""):
        self.config = config
        self.eval_step = config.eval_step[config.run_mode]
        self.logger = logger
        self.save_suffix = save_suffix
        self.detail_pbar = True

        self.model = ComponentFactory.get_model(config)
        self.model = self.model.to(config.device)
        n_device = torch.cuda.device_count()
        if n_device > 1:
            self.model = nn.DataParallel(self.model)
        if config.ema:
            self.model_ema = None

        self.loss_fn = ComponentFactory.get_loss(config)
        self.train_loss = AverageMeter()
        self.valid_loss = AverageMeter()

        _, self.target_cols = get_io_columns(config)
        self.model_target_cols = self.get_model_target_cols()
        self.factor_dict = get_sub_factor(config.input_path, old=False)
        self.old_factor_dict = get_sub_factor(config.input_path, old=True)

        self.y_numerators = np.load(
            config.output_path / f"y_numerators_{config.target_scale_method}.npy"
        )
        self.y_denominators = np.load(
            config.output_path / f"y_denominators_{config.target_scale_method}.npy"
        )
        self.target_min_max = [TARGET_MIN_MAX[col] for col in self.target_cols]

        self.valid_ids = None
        self.test_ids = None
        self.valid_pp_df = None
        self.test_pp_df = None
        self.pp_run = True
        self.pp_y_cols = PP_TARGET_COLS
        self.pp_x_cols = [col.replace("ptend", "state") for col in self.pp_y_cols]

        self.best_score_dict = defaultdict(lambda: (-1, -np.inf))

    def train(
        self,
        train_loader: DataLoader,
        valid_loader: DataLoader,
        colwise_mode: bool = True,
        retrain: bool = False,
        retrain_weight_name: str = "",
        retrain_best_score: float = -np.inf,
        eval_only: bool = False,
    ):
        if eval_only:
            self.best_score_dict = pickle.load(
                open(self.config.output_path / f"best_score_dict{self.save_suffix}.pkl", "rb")
            )
            eval_method = "colwise" if colwise_mode else "single"
            score, cw_score, preds, _ = self.valid_evaluate(
                valid_loader, current_epoch=-1, eval_count=-1, eval_method=eval_method
            )
            self.save_oof_df(self.valid_ids, preds)
            return score, cw_score, -1

        self.optimizer = ComponentFactory.get_optimizer(self.config, self.model)
        steps_per_epoch = len(train_loader) if self.config.run_mode != 'hf' else self.config.eval_step[self.config.run_mode]
        self.scheduler = ComponentFactory.get_scheduler(
            self.config, self.optimizer, steps_per_epoch=steps_per_epoch
        )
        global_step = 0
        eval_count = 0
        best_score = -np.inf

        if retrain:
            self.best_score_dict = pickle.load(
                open(self.config.output_path / f"best_score_dict{self.save_suffix}.pkl", "rb")
            )
            self.model.load_state_dict(
                torch.load(self.config.output_path / f"{retrain_weight_name}.pth")
            )
            weight_numbers = [
                int(file.stem.split("_")[-1].replace("eval", ""))
                for file in list(self.config.output_path.glob(f"model{self.save_suffix}_eval*.pth"))
            ]
            eval_count = sorted(weight_numbers)[-1] + 1
            best_score = retrain_best_score

        # 学習ループの開始
        for epoch in tqdm(range(self.config.epochs)):
            self.model.train()
            self.train_loss.reset()

            iterations = (
                tqdm(train_loader, total=len(train_loader)) if self.detail_pbar else train_loader
            )
            for data in iterations:
                _, loss = self.forward_step(data, calc_loss=True)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                self.scheduler.step()
                self.train_loss.update(loss.item(), n=data[0].size(0))
                global_step += 1

                if global_step % self.eval_step == 0:
                    score, _, preds, update_num = self.valid_evaluate(
                        valid_loader,
                        current_epoch=epoch,
                        eval_count=eval_count,
                        eval_method="single",
                    )
                    if colwise_mode and update_num > 0:
                        torch.save(
                            self.model.state_dict(),
                            self.config.output_path
                            / f"model{self.save_suffix}_eval{eval_count}.pth",
                        )

                    if score > best_score:
                        best_score = score
                        best_preds = preds
                        best_epochs = epoch
                        torch.save(
                            self.model.state_dict(),
                            self.config.output_path / f"model{self.save_suffix}_best.pth",
                        )

                    eval_count += 1
                    self.model.train()

            message = f"""
                [Train] :
                    Epoch={epoch},
                    Loss={self.train_loss.avg:.5f},
                    LR={self.optimizer.param_groups[0]["lr"]:.5e}
            """
            self.logger.info(clean_message(message))

            if self.config.run_mode == 'hf':
                train_loader = self.update_train_loader(train_loader)

        if colwise_mode:
            self.remove_unuse_weights()
            best_score, best_cw_score, best_preds, _ = self.valid_evaluate(
                valid_loader, current_epoch=-1, eval_count=-1, eval_method="colwise"
            )

        self.save_oof_df(self.valid_ids, best_preds)
        return best_score, best_cw_score, best_epochs

    def valid_evaluate(
        self,
        valid_loader: DataLoader,
        current_epoch: int,
        eval_count: int,
        eval_method: Literal["single", "colwise"] = "single",
    ):
        if self.valid_ids is None:
            self.valid_ids = valid_loader.dataset.ids

        if eval_method == "single":
            load_best_weight = True if eval_count == -1 else False
            preds = self.inference_loop(
                valid_loader, mode="valid", load_best_weight=load_best_weight
            )
        elif eval_method == "colwise":
            preds = self.inference_loop_colwise(valid_loader, "valid", self.best_score_dict)

        labels = valid_loader.dataset.y
        if self.config.target_shape == "3dim":
            labels = self.convert_target_3dim_to_2dim(labels)
        preds = self.restore_pred(preds)
        labels = self.restore_pred(labels)

        if self.pp_run and self.valid_pp_df is None:
            self.load_postprocess_input("valid")
        if self.pp_run:
            preds = self.postprocess(preds, run_type="valid")
        if self.config.out_clip:
            preds = self.clipping_pred(preds)

        eval_idx = [
            i for i, col in enumerate(self.target_cols) if self.factor_dict[col] != 0
        ]  # factor_dictの値が0のものは自動でR2=1になるようにする
        score, indiv_scores = evaluate_metric(preds, labels, eval_idx=eval_idx)
        cw_score, update_num = self.update_best_score(indiv_scores, eval_count)

        message = f"""
            [Valid] :
                Epoch={current_epoch},
                Loss={self.valid_loss.avg:.5f},
                Score={score:.5f},
                Best Col-Wise Score={cw_score:.5f}
        """
        self.logger.info(clean_message(message))
        return score, cw_score, preds, update_num

    def test_predict(
        self, test_loader: DataLoader, eval_method: Literal["single", "colwise"] = "single"
    ):
        if self.test_ids is None:
            self.test_ids = test_loader.dataset.ids

        if eval_method == "single":
            preds = self.inference_loop(test_loader, mode="test", load_best_weight=True)
        elif eval_method == "colwise":
            self.best_score_dict = pickle.load(
                open(self.config.output_path / f"best_score_dict{self.save_suffix}.pkl", "rb")
            )
            preds = self.inference_loop_colwise(test_loader, "test", self.best_score_dict)

        preds = self.restore_pred(preds)
        if self.pp_run and self.test_pp_df is None:
            self.load_postprocess_input("test")
        if self.pp_run:
            preds = self.postprocess(preds, run_type="test")
        if self.config.out_clip:
            preds = self.clipping_pred(preds)

        pred_df = pl.DataFrame(preds, schema=self.target_cols)
        pred_df = pred_df.with_columns(sample_id=pl.Series(self.test_ids))
        return pred_df

    def forward_step(self, data: torch.Tensor, calc_loss: bool = True):
        if calc_loss:
            x, y = data
            x, y = x.to(self.config.device), y.to(self.config.device)
            out = self.model(x)
            loss = self.loss_fn(out, y)
        else:
            x = data[0]
            x = x.to(self.config.device)
            out = self.model(x)
            loss = None

        if self.config.multi_task:
            out = out[:, :, :self.config.out_dim]

        if self.config.target_shape == "3dim":
            out = self.convert_target_3dim_to_2dim(out)
        return out, loss

    def inference_loop(
        self,
        eval_loader: DataLoader,
        mode: Literal["valid", "test"],
        load_best_weight: bool = False,
    ):
        self.model.eval()
        if mode == "valid":
            self.valid_loss.reset()

        if load_best_weight:
            self.model.load_state_dict(
                torch.load(self.config.output_path / f"model{self.save_suffix}_best.pth")
            )

        preds = []
        with torch.no_grad():
            iterations = (
                tqdm(eval_loader, total=len(eval_loader)) if self.detail_pbar else eval_loader
            )
            for data in iterations:
                if mode == "valid":
                    out, loss = self.forward_step(data, calc_loss=True)
                    self.valid_loss.update(loss.item(), n=data[0].size(0))
                elif mode == "test":
                    out, _ = self.forward_step(data, calc_loss=False)
                preds.append(out.detach().cpu().numpy())
        preds = np.concatenate(preds, axis=0)
        return preds

    def inference_loop_colwise(
        self,
        test_loader: DataLoader,
        mode: Literal["valid", "test"],
        best_score_dict: dict[str, tuple[int, float]],
    ):
        self.model.eval()
        if mode == "valid":
            self.valid_loss.reset()

        selected_counts = list(set([eval_count for eval_count, _ in best_score_dict.values()]))
        all_preds = np.zeros((len(test_loader.dataset), len(self.target_cols)))
        for eval_count in tqdm(selected_counts):
            self.model.load_state_dict(
                torch.load(
                    self.config.output_path / f"model{self.save_suffix}_eval{eval_count}.pth"
                )
            )
            preds = []
            with torch.no_grad():
                iterations = (
                    tqdm(test_loader, total=len(test_loader)) if self.detail_pbar else test_loader
                )
                for data in iterations:
                    if mode == "valid":
                        out, loss = self.forward_step(data, calc_loss=True)
                        self.valid_loss.update(loss.item(), n=data[0].size(0))
                    elif mode == "test":
                        out, _ = self.forward_step(data, calc_loss=False)
                    preds.append(out.detach().cpu().numpy())
            preds = np.concatenate(preds, axis=0)

            target_cols = [
                col for col, (count, _) in best_score_dict.items() if count == eval_count
            ]
            for col in target_cols:
                idx = self.target_cols.index(col)
                all_preds[:, idx] = preds[:, idx]
        return all_preds

    def update_train_loader(self, train_loader: DataLoader):
        train_dataset = train_loader.dataset
        train_dataset.update()
        train_loader = DataLoader(
            train_dataset,
            batch_size=self.config.train_batch,
            shuffle=True,
            pin_memory=True,
            drop_last=True
        )
        return train_loader

    def update_best_score(self, indiv_scores: list[float], eval_count: int):
        update_num = 0
        for col, score in zip(self.target_cols, indiv_scores):
            if score > self.best_score_dict[col][1] and eval_count != -1:
                self.best_score_dict[col] = (eval_count, score)
                update_num += 1

        best_cw_score = (
            np.sum([score for _, score in self.best_score_dict.values()])
            + (368 - len(self.target_cols))
        ) / 368
        if update_num > 0 and eval_count != -1:
            pickle.dump(
                dict(self.best_score_dict),
                open(self.config.output_path / f"best_score_dict{self.save_suffix}.pkl", "wb"),
            )
        return best_cw_score, update_num

    def remove_unuse_weights(self):
        selected_counts = set([v[0] for v in self.best_score_dict.values()])
        weight_paths = list(self.config.output_path.glob(f"model{self.save_suffix}_eval*.pth"))
        for path in weight_paths:
            eval_count = int(path.stem.split("_")[-1].replace("eval", ""))
            if eval_count not in selected_counts:
                path.unlink()

    def convert_target_3dim_to_2dim(
        self, y: np.ndarray | torch.Tensor
    ) -> np.ndarray | torch.Tensor:
        y_v = y[:, :, : len(VERTICAL_TARGET_COLS)]
        y_s = y[:, :, len(VERTICAL_TARGET_COLS) :]
        if isinstance(y, np.ndarray):
            y_v = np.transpose(y_v, (0, 2, 1)).reshape(y.shape[0], -1)
            y_s = y_s.mean(axis=1)
            y = np.concatenate([y_v, y_s], axis=-1)
        elif isinstance(y, torch.Tensor):
            y_v = y_v.permute(0, 2, 1).reshape(y.size(0), -1)
            y_s = y_s.mean(dim=1)
            y = torch.cat([y_v, y_s], dim=-1)
        y = self.alignment_target_idx(y)
        return y

    def alignment_target_idx(self, y: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor:
        align_order = [self.model_target_cols.index(col) for col in self.target_cols]
        assert len(y.shape) == 2
        y = y[:, align_order]
        return y

    def get_model_target_cols(self):
        model_target_cols = []
        for col in VERTICAL_TARGET_COLS:
            model_target_cols.extend([f"{col}_{i}" for i in range(60)])
        for col in SCALER_TARGET_COLS:
            model_target_cols.append(col)
        return model_target_cols

    def restore_pred(self, preds: np.ndarray):
        return preds * self.y_denominators + self.y_numerators

    def clipping_pred(self, preds: np.ndarray):
        for i in range(preds.shape[1]):
            preds[:, i] = np.clip(preds[:, i], self.target_min_max[i][0], self.target_min_max[i][1])
        return preds

    def save_oof_df(self, sample_ids: np.ndarray, preds: np.ndarray):
        oof_df = pl.DataFrame(preds, schema=self.target_cols)
        oof_df = oof_df.with_columns(sample_id=pl.Series(sample_ids))
        oof_df.write_parquet(self.config.oof_path / f"oof{self.save_suffix}.parquet")

    def postprocess(self, preds: np.ndarray, run_type: Literal["valid", "test"]):
        pp_x = self.valid_pp_df if run_type == "valid" else self.test_pp_df
        for x_col, y_col in zip(self.pp_x_cols, self.pp_y_cols):
            if y_col in self.target_cols:
                idx = self.target_cols.index(y_col)
                old_factor = self.old_factor_dict[y_col] if self.config.mul_old_factor else 1
                preds[:, idx] = (-1 * pp_x[x_col].to_numpy() / 1200) * old_factor
        return preds

    def load_postprocess_input(self, data_type: Literal["valid", "test"]):
        if data_type == "valid":
            valid_path = (
                self.config.input_path / "18_shrinked.parquet"
                if self.config.shared_valid
                else self.config.input_path / "train_shrinked.parquet"
            )
            self.valid_pp_df = (
                pl.scan_parquet(valid_path)
                .select(["sample_id"] + self.pp_x_cols)
                .filter(pl.col("sample_id").is_in(self.valid_ids))
                .collect()
            )
            id_df = pl.DataFrame({"sample_id": self.valid_ids})
            self.valid_pp_df = id_df.join(self.valid_pp_df, on="sample_id", how="left")

        elif data_type == "test":
            self.test_pp_df = pl.read_parquet(
                self.config.input_path / "test_shrinked.parquet",
                columns=["sample_id"] + self.pp_x_cols,
            )
            id_df = pl.DataFrame({"sample_id": self.test_ids})
            self.test_pp_df = id_df.join(self.test_pp_df, on="sample_id", how="left")


In [53]:
trainer = Trainer(config, logger)

In [54]:
# config.eval_step['dev'] = 50

In [55]:
oof_df = trainer.train(
    train_loader,
    valid_loader,
    colwise_mode=True,
    # eval_only=True
)

  0%|          | 0/240 [00:00<?, ?it/s]

  0%|          | 0/1582 [00:00<?, ?it/s]

In [None]:
pred_df = trainer.test_predict(test_loader, eval_method="colwise")

  torch.load(self.config.output_path / f"model{self.save_suffix}_best.pth")


  0%|          | 0/153 [00:00<?, ?it/s]

# PostProcess

In [29]:
from omegaconf import DictConfig
import loguru
from typing import Literal
from sklearn.metrics import r2_score

from src.utils.competition_utils import get_sub_factor, get_io_columns

class PostProcess:
    def __init__(self, config: DictConfig, logger: loguru._Logger, additional: bool = True):
        self.config = config
        self.logger = logger
        self.additional = additional

        _, self.target_cols = get_io_columns(config)
        self.old_factor_dict = get_sub_factor(config.input_path, old=True)
        self.sub_cols = pl.read_parquet(config.input_path / 'sample_submission.parquet', n_rows=1).columns

        self.pp_x_cols = [f'state_q0002_{i}' for i in range(12, 27)]
        self.pp_y_cols = [f'ptend_q0002_{i}' for i in range(12, 27)]
        self.valid_pp_df = pl.read_parquet(
            config.input_path / '18_shrinked.parquet',
            columns=['sample_id'] + self.pp_x_cols
        )
        self.test_pp_df = pl.read_parquet(
            config.input_path / 'test_shrinked.parquet',
            columns=['sample_id'] + self.pp_x_cols
        )

        add_pp_y_cols = (
            [f'ptend_q0002_{i}' for i in range(60)] +
            [f'ptend_q0003_{i}' for i in range(60)]
        )
        self.add_pp_y_cols = [col for col in add_pp_y_cols if col in self.target_cols]
        self.add_pp_x_cols = [col.replace('ptend', 'state') for col in self.add_pp_y_cols]
        self.add_valid_pp_df = pl.read_parquet(
            config.input_path / '18_shrinked.parquet',
            columns=self.sub_cols + self.add_pp_x_cols
        )
        self.add_test_pp_df = pl.read_parquet(
            config.input_path / 'test_shrinked.parquet',
            columns=['sample_id'] + self.add_pp_x_cols
        )
        self.th_dict = None

    def postprocess(self, oof_df: pl.DataFrame, sub_df: pl.DataFrame):
        oof_df = self.complement_columns(oof_df)
        oof_df = self.reverse_sub_factor(oof_df)
        oof_df = self.replace_postprocess(oof_df, 'oof')

        sub_df = self.complement_columns(sub_df)
        sub_df = self.reverse_sub_factor(sub_df)
        sub_df = self.replace_postprocess(sub_df, 'sub')

        if self.additional:
            oof_df = self.additional_postprocess(oof_df, 'oof')
            sub_df = self.additional_postprocess(sub_df, 'sub')

        oof_df = self.create_oof_df(oof_df)
        sub_df = self.create_sub_df(sub_df)
        return oof_df, sub_df

    def complement_columns(self, pred_df: pl.DataFrame):
        lack_cols = list(set(self.sub_cols) - set(pred_df.columns))
        for col in lack_cols:
            pred_df = pred_df.with_columns([pl.lit(0).alias(col)])
        return pred_df

    def reverse_sub_factor(self, pred_df: pl.DataFrame):
        if self.config.mul_old_factor:
            exprs = []
            for col in self.target_cols:
                if self.old_factor_dict[col] != 0:
                    exprs.append((pl.col(col) / self.old_factor_dict[col]).alias(col))

            pred_df = pred_df.with_columns(exprs)
        return pred_df

    def replace_postprocess(self, pred_df: pl.DataFrame, pred_type: Literal['oof', 'sub']):
        pp_df = self.valid_pp_df if pred_type == 'oof' else self.test_pp_df
        pred_df = pred_df.join(pp_df, on=['sample_id'], how='left')

        exprs = []
        for x_col, y_col in zip(self.pp_x_cols, self.pp_y_cols):
            exprs.append((-1 * pl.col(x_col) / 1200).alias(y_col))
        pred_df = pred_df.with_columns(exprs)
        pred_df = pred_df.drop(self.pp_x_cols)
        return pred_df

    def additional_postprocess(self, pred_df: pl.DataFrame, pred_type: Literal['oof', 'sub']):
        pp_df = self.add_valid_pp_df if pred_type == 'oof' else self.add_test_pp_df
        pred_df = pred_df.join(pp_df, on=['sample_id'], how='left', suffix='_gt')
        exprs = []
        for x_col, y_col in zip(self.add_pp_x_cols, self.add_pp_y_cols):
            exprs.append((pl.col(x_col) + pl.col(y_col) * 1200).alias(f'{x_col}_next'))
        pred_df = pred_df.with_columns(exprs)

        if pred_type == 'oof':
            self.tuning_threshold(pred_df)

        assert self.th_dict is not None # oofから実行する必要がある
        exprs = []
        for y_col, (best_th, _) in self.th_dict.items():
            x_col = y_col.replace('ptend', 'state')
            exprs.append(
                pl.when(pl.col(f'{x_col}_next') < best_th)
                .then(-1 * pl.col(x_col) / 1200)
                .otherwise(pl.col(y_col))
                .alias(y_col)
            )
        pred_df = pred_df.with_columns(exprs)

        if pred_type == 'oof':
            scores = []
            for col in self.target_cols:
                score = r2_score(pred_df[f'{col}_gt'].to_numpy(), pred_df[col].to_numpy())
                scores.append(score)
            total_score = (np.sum(scores) + (368 - len(scores))) / 368
            self.logger.info(f'After Additional Postprocess: {total_score:.5f}')

        drop_cols = (
            self.add_pp_x_cols +
            [f'{col}_next' for col in self.add_pp_x_cols] +
            [col for col in pred_df.columns if '_gt' in col]
        )
        pred_df = pred_df.drop(drop_cols)
        return pred_df

    def tuning_threshold(self, pred_df: pl.DataFrame):
        iterations = tqdm(zip(self.add_pp_x_cols, self.add_pp_y_cols), total=len(self.add_pp_x_cols))
        for x_col, y_col in iterations:
            best_score = r2_score(pred_df[f'{y_col}_gt'].to_numpy(), pred_df[y_col].to_numpy())
            best_th = None
            for th_base in [0, 1e-10, 1e-9, 1e-8, 1e-7, 1e-6, 1e-5]:
                for corr in range(1, 10):
                    if th_base == 0 and corr >= 2:
                        break

                    th = th_base * corr
                    preds = pred_df.select(
                        pl.when(pl.col(f'{x_col}_next') < th)
                        .then(-1 * pl.col(x_col) / 1200)
                        .otherwise(pl.col(y_col))
                    ).to_numpy()

                    truths = pred_df[f'{y_col}_label'].to_numpy()
                    score = r2_score(truths, preds)
                    if score > best_score:
                        best_score = score
                        best_th = th

            if best_th is not None:
                self.th_dict[y_col] = (best_th, best_score)


    def create_oof_df(self, oof_df: pl.DataFrame):
        oof_df = oof_df.select(self.sub_cols)
        oof_df.write_parquet(self.config.oof_path / 'oof_pp.parquet')
        return oof_df

    def create_sub_df(self, sub_df: pl.DataFrame):
        sub_df = sub_df.with_columns(sample_id = pl.concat_str([pl.lit('test_'), pl.col('sample_id')]))
        sub_df = sub_df.select(self.sub_cols)
        sub_df.write_csv(self.config.output_path / 'submission_pp.csv')
        return sub_df