In [1]:
import polars as pl
import gc
import pickle
from pathlib import Path, PosixPath
from tqdm.auto import tqdm

import sys
sys.path.append('..')

from src.utils import seed_everything, get_logger, get_config, TimeUtil
from src.utils.competition_utils import clipping_input
from src.data import DataProvider, FeatureEngineering, Preprocessor, HFPreprocessor
from src.train import get_dataloader

In [2]:
# コマンドライン引数
exp = '146'

In [3]:
config = get_config(exp, config_dir=Path('../config'))
logger = get_logger(config.output_path)
logger.info(f'exp: {exp} | run_mode={config.run_mode}, multi_task={config.multi_task}, loss_type={config.loss_type}')

seed_everything(config.seed)

[ [32m2024-10-10 14:50:56[0m | [1mINFO ] exp: 146 | run_mode=hf, multi_task=False, loss_type=mae[0m


In [4]:
config.run_mode = 'dev'
config.multi_task = False

In [5]:
with TimeUtil.timer('Data Loading...'):
    dpr = DataProvider(config)
    train_df, test_df = dpr.load_data()



In [6]:
with TimeUtil.timer('Feature Engineering...'):
    fer = FeatureEngineering(config)
    train_df = fer.feature_engineering(train_df)
    test_df = fer.feature_engineering(test_df)

[Feature Engineering...] done [69.4GB(30.2%)(+2.975GB)] 27.8554 s


In [7]:
with TimeUtil.timer('Scaling and Clipping Features...'):
    ppr = Preprocessor(config)
    train_df, test_df = ppr.scaling(train_df, test_df)
    input_cols, target_cols = ppr.input_cols, ppr.target_cols
    if config.task_type == 'grid_pred':
        train_df = train_df.drop(target_cols)

    valid_df = train_df.filter(pl.col('fold') == 0)
    train_df = train_df.filter(pl.col('fold') != 0)
    valid_df, input_clip_dict = clipping_input(train_df, valid_df, input_cols)
    test_df, _ = clipping_input(None, test_df, input_cols, input_clip_dict)
    pickle.dump(input_clip_dict, open(config.output_path / 'input_clip_dict.pkl', 'wb'))

[Scaling and Clipping Features...] done [34.4GB(22.5%)(-35.001GB)] 36.5791 s


In [8]:
with TimeUtil.timer('Converting to arrays for NN...'):
    array_data = ppr.convert_numpy_array(train_df, valid_df, test_df)
    del train_df, valid_df, test_df
    gc.collect()

[Converting to arrays for NN...] done [53.5GB(40.7%)(+19.076GB)] 179.8489 s


In [9]:
# Prepare HF Data
if config.run_mode == 'hf':
    with TimeUtil.timer('HF Data Preprocessing...'):
        hf_ppr = HFPreprocessor(config)
        # hf_pcr.preprocess_data()
        # hf_pcr.convert_numpy_array(near_target=False)
        # del train_loader; gc.collect()
        # train_loader = None

In [10]:
with TimeUtil.timer('Creating Torch DataLoader...'):
    train_loader = get_dataloader(
        config,
        array_data['train_ids'],
        array_data['X_train'],
        array_data['y_train'],
        is_train=True
    )
    valid_loader = get_dataloader(
        config,
        array_data['valid_ids'],
        array_data['X_valid'],
        array_data['y_valid'],
        is_train=False
    )
    test_loader = get_dataloader(
        config,
        array_data['test_ids'],
        array_data['X_test'],
        is_train=False
    )
    del array_data
    gc.collect()

[Creating Torch DataLoader...] done [53.5GB(40.6%)(+0.000GB)] 0.0876 s


# Trainer

In [31]:
import pickle
from collections import defaultdict
from typing import Dict, List, Literal, Tuple

import loguru
import numpy as np
import polars as pl
import torch
from omegaconf import DictConfig, OmegaConf
from torch import nn
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from src.train import ComponentFactory
from src.train.train_utils import AverageMeter
from src.utils import clean_message
from src.utils.competition_utils import evaluate_metric, get_io_columns, get_sub_factor
from src.utils.constant import (
    PP_TARGET_COLS,
    SCALER_TARGET_COLS,
    TARGET_MIN_MAX,
    VERTICAL_TARGET_COLS,
)


class Trainer:
    def __init__(self, config: DictConfig, logger: loguru._Logger, save_suffix: str = ""):
        self.config = config
        self.eval_step = config.eval_step[config.run_mode]
        self.logger = logger
        self.save_suffix = save_suffix
        self.detail_pbar = True

        self.model = ComponentFactory.get_model(config)
        self.model = self.model.to(config.device)
        n_device = torch.cuda.device_count()
        if n_device > 1:
            self.model = nn.DataParallel(self.model)
        self.loss_fn = ComponentFactory.get_loss(config)
        self.train_loss = AverageMeter()
        self.valid_loss = AverageMeter()

        _, self.target_cols = get_io_columns(config)
        self.model_target_cols = self.get_model_target_cols()
        self.factor_dict = get_sub_factor(config.input_path, old=False)
        self.old_factor_dict = get_sub_factor(config.input_path, old=True)

        self.y_numerators = np.load(
            config.output_path / f"y_numerators_{config.target_scale_method}.npy"
        )
        self.y_denominators = np.load(
            config.output_path / f"y_denominators_{config.target_scale_method}.npy"
        )
        self.target_min_max = [TARGET_MIN_MAX[col] for col in self.target_cols]

        self.valid_ids = None
        self.test_ids = None
        self.valid_pp_df = None
        self.test_pp_df = None
        self.pp_run = True
        self.pp_y_cols = PP_TARGET_COLS
        self.pp_x_cols = [col.replace("ptend", "state") for col in self.pp_y_cols]

        self.best_score_dict = defaultdict(lambda: (-1, -np.inf))

    def train(
        self,
        train_loader: DataLoader,
        valid_loader: DataLoader,
        colwise_mode: bool = True,
        retrain: bool = False,
        retrain_weight_name: str = "",
        retrain_best_score: float = -np.inf,
        eval_only: bool = False,
    ):
        if eval_only:
            self.best_score_dict = pickle.load(
                open(self.config.output_path / f"best_score_dict{self.save_suffix}.pkl", "rb")
            )
            eval_method = "colwise" if colwise_mode else "single"
            score, cw_score, preds, _ = self.valid_evaluate(
                valid_loader, current_epoch=-1, eval_count=-1, eval_method=eval_method
            )
            self.save_oof_df(self.valid_ids, preds)
            return score, cw_score, -1

        self.optimizer = ComponentFactory.get_optimizer(self.config, self.model)
        self.scheduler = ComponentFactory.get_scheduler(
            self.config, self.optimizer, steps_per_epoch=len(train_loader)
        )
        global_step = 0
        eval_count = 0
        best_score = -np.inf

        if retrain:
            self.best_score_dict = pickle.load(
                open(self.config.output_path / f"best_score_dict{self.save_suffix}.pkl", "rb")
            )
            self.model.load_state_dict(
                torch.load(self.config.output_path / f"{retrain_weight_name}.pth")
            )
            weight_numbers = [
                int(file.stem.split("_")[-1].replace("eval", ""))
                for file in list(self.config.output_path.glob(f"model{self.save_suffix}_eval*.pth"))
            ]
            eval_count = sorted(weight_numbers)[-1] + 1
            best_score = retrain_best_score

        # 学習ループの開始
        for epoch in tqdm(range(self.config.epochs)):
            self.model.train()
            self.train_loss.reset()

            iterations = (
                tqdm(train_loader, total=len(train_loader)) if self.detail_pbar else train_loader
            )
            for data in iterations:
                _, loss = self.forward_step(data, calc_loss=True)
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                self.scheduler.step()
                self.train_loss.update(loss.item(), n=data[0].size(0))
                global_step += 1

                if global_step % self.eval_step == 0:
                    score, _, preds, update_num = self.valid_evaluate(
                        valid_loader,
                        current_epoch=epoch,
                        eval_count=eval_count,
                        eval_method="single",
                    )
                    if colwise_mode and update_num > 0:
                        torch.save(
                            self.model.state_dict(),
                            self.config.output_path
                            / f"model{self.save_suffix}_eval{eval_count}.pth",
                        )

                    if score > best_score:
                        best_score = score
                        best_preds = preds
                        best_epochs = epoch
                        torch.save(
                            self.model.state_dict(),
                            self.config.output_path / f"model{self.save_suffix}_best.pth",
                        )

                    eval_count += 1
                    self.model.train()

            message = f"""
                [Train] :
                    Epoch={epoch},
                    Loss={self.train_loss.avg:.5f},
                    LR={self.optimizer.param_groups[0]["lr"]:.5e}
            """
            self.logger.info(clean_message(message))

        if colwise_mode:
            self.remove_unuse_weights()
            best_score, best_cw_score, best_preds, _ = self.valid_evaluate(
                valid_loader, current_epoch=-1, eval_count=-1, eval_method="colwise"
            )

        self.save_oof_df(self.valid_ids, best_preds)
        return best_score, best_cw_score, best_epochs

    def valid_evaluate(
        self,
        valid_loader: DataLoader,
        current_epoch: int,
        eval_count: int,
        eval_method: Literal["single", "colwise"] = "single",
    ):
        if self.valid_ids is None:
            self.valid_ids = valid_loader.dataset.ids

        if eval_method == "single":
            load_best_weight = True if eval_count == -1 else False
            preds = self.inference_loop(
                valid_loader, mode="valid", load_best_weight=load_best_weight
            )
        elif eval_method == "colwise":
            preds = self.inference_loop_colwise(valid_loader, "valid", self.best_score_dict)

        labels = valid_loader.dataset.y
        if self.config.target_shape == "3dim":
            labels = self.convert_target_3dim_to_2dim(labels)
        preds = self.restore_pred(preds)
        labels = self.restore_pred(labels)

        if self.pp_run and self.valid_pp_df is None:
            self.load_postprocess_input("valid")
        if self.pp_run:
            preds = self.postprocess(preds, run_type="valid")
        if self.config.out_clip:
            preds = self.clipping_pred(preds)

        eval_idx = [
            i for i, col in enumerate(self.target_cols) if self.factor_dict[col] != 0
        ]  # factor_dictの値が0のものは自動でR2=1になるようにする
        score, indiv_scores = evaluate_metric(preds, labels, eval_idx=eval_idx)
        cw_score, update_num = self.update_best_score(indiv_scores, eval_count)

        message = f"""
            [Valid] :
                Epoch={current_epoch},
                Loss={self.valid_loss.avg:.5f},
                Score={score:.5f},
                Best Col-Wise Score={cw_score:.5f}
        """
        self.logger.info(clean_message(message))
        return score, cw_score, preds, update_num

    def test_predict(
        self, test_loader: DataLoader, eval_method: Literal["single", "colwise"] = "single"
    ):
        if self.test_ids is None:
            self.test_ids = test_loader.dataset.ids

        if eval_method == "single":
            preds = self.inference_loop(test_loader, mode="test", load_best_weight=True)
        elif eval_method == "colwise":
            self.best_score_dict = pickle.load(
                open(self.config.output_path / f"best_score_dict{self.save_suffix}.pkl", "rb")
            )
            preds = self.inference_loop_colwise(test_loader, "test", self.best_score_dict)

        preds = self.restore_pred(preds)
        if self.pp_run and self.test_pp_df is None:
            self.load_postprocess_input("test")
        if self.pp_run:
            preds = self.postprocess(preds, run_type="test")
        if self.config.out_clip:
            preds = self.clipping_pred(preds)

        pred_df = pl.DataFrame(preds, schema=self.target_cols)
        pred_df = pred_df.with_columns(sample_id=pl.Series(self.test_ids))
        return pred_df

    def inference_loop(
        self,
        eval_loader: DataLoader,
        mode: Literal["valid", "test"],
        load_best_weight: bool = False,
    ):
        self.model.eval()
        if mode == "valid":
            self.valid_loss.reset()

        # テストデータを推論するときはbest_weightを読み込む
        if load_best_weight:
            self.model.load_state_dict(
                torch.load(self.config.output_path / f"model{self.save_suffix}_best.pth")
            )

        preds = []
        with torch.no_grad():
            iterations = (
                tqdm(eval_loader, total=len(eval_loader)) if self.detail_pbar else eval_loader
            )
            for data in iterations:
                if mode == "valid":
                    out, loss = self.forward_step(data, calc_loss=True)
                    self.valid_loss.update(loss.item(), n=data[0].size(0))
                elif mode == "test":
                    out, _ = self.forward_step(data, calc_loss=False)
                preds.append(out.detach().cpu().numpy())
        preds = np.concatenate(preds, axis=0)
        return preds

    def inference_loop_colwise(
        self,
        test_loader: DataLoader,
        mode: Literal["valid", "test"],
        best_score_dict: dict[str, tuple[int, float]],
    ):
        self.model.eval()
        if mode == "valid":
            self.valid_loss.reset()

        selected_counts = list(set([eval_count for eval_count, _ in best_score_dict.values()]))
        all_preds = np.zeros((len(test_loader.dataset), len(self.target_cols)))
        for eval_count in tqdm(selected_counts):
            self.model.load_state_dict(
                torch.load(
                    self.config.output_path / f"model{self.save_suffix}_eval{eval_count}.pth"
                )
            )
            preds = []
            with torch.no_grad():
                iterations = (
                    tqdm(test_loader, total=len(test_loader)) if self.detail_pbar else test_loader
                )
                for data in iterations:
                    if mode == "valid":
                        out, loss = self.forward_step(data, calc_loss=True)
                        self.valid_loss.update(loss.item(), n=data[0].size(0))
                    elif mode == "test":
                        out, _ = self.forward_step(data, calc_loss=False)
                    preds.append(out.detach().cpu().numpy())
            preds = np.concatenate(preds, axis=0)

            target_cols = [
                col for col, (count, _) in best_score_dict.items() if count == eval_count
            ]
            for col in target_cols:
                idx = self.target_cols.index(col)
                all_preds[:, idx] = preds[:, idx]
        return all_preds

    def update_best_score(self, indiv_scores: list[float], eval_count: int):
        update_num = 0
        for col, score in zip(self.target_cols, indiv_scores):
            if score > self.best_score_dict[col][1] and eval_count != -1:
                self.best_score_dict[col] = (eval_count, score)
                update_num += 1

        best_cw_score = (
            np.sum([score for _, score in self.best_score_dict.values()])
            + (368 - len(self.target_cols))
        ) / 368
        if update_num > 0 and eval_count != -1:
            pickle.dump(
                dict(self.best_score_dict),
                open(self.config.output_path / f"best_score_dict{self.save_suffix}.pkl", "wb"),
            )
        return best_cw_score, update_num

    def remove_unuse_weights(self):
        selected_counts = set([v[0] for v in self.best_score_dict.values()])
        weight_paths = list(self.config.output_path.glob(f"model{self.save_suffix}_eval*.pth"))
        for path in weight_paths:
            eval_count = int(path.stem.split("_")[-1].replace("eval", ""))
            if eval_count not in selected_counts:
                path.unlink()

    def forward_step(self, data: torch.Tensor, calc_loss: bool = True):
        if calc_loss:
            x, y = data
            x, y = x.to(self.config.device), y.to(self.config.device)
            out = self.model(x)
            loss = self.loss_fn(out, y)
        else:
            x = data[0]
            x = x.to(self.config.device)
            out = self.model(x)
            loss = None

        if self.config.target_shape == "3dim":
            out = self.convert_target_3dim_to_2dim(out)
        return out, loss

    def convert_target_3dim_to_2dim(
        self, y: np.ndarray | torch.Tensor
    ) -> np.ndarray | torch.Tensor:
        y_v = y[:, :, : len(VERTICAL_TARGET_COLS)]
        y_s = y[:, :, len(VERTICAL_TARGET_COLS) :]
        if isinstance(y, np.ndarray):
            y_v = np.transpose(y_v, (0, 2, 1)).reshape(y.shape[0], -1)
            y_s = y_s.mean(axis=1)
            y = np.concatenate([y_v, y_s], axis=-1)
        elif isinstance(y, torch.Tensor):
            y_v = y_v.permute(0, 2, 1).reshape(y.size(0), -1)
            y_s = y_s.mean(dim=1)
            y = torch.cat([y_v, y_s], dim=-1)
        y = self.alignment_target_idx(y)
        return y

    def alignment_target_idx(self, y: np.ndarray | torch.Tensor) -> np.ndarray | torch.Tensor:
        """
        target_colsとモデルの出力の順番を合わせる
        """
        align_order = [self.model_target_cols.index(col) for col in self.target_cols]
        assert len(y.shape) == 2
        y = y[:, align_order]
        return y

    def get_model_target_cols(self):
        model_target_cols = []
        for col in VERTICAL_TARGET_COLS:
            model_target_cols.extend([f"{col}_{i}" for i in range(60)])
        for col in SCALER_TARGET_COLS:
            model_target_cols.append(col)
        return model_target_cols

    def restore_pred(self, preds: np.ndarray):
        return preds * self.y_denominators + self.y_numerators

    def clipping_pred(self, preds: np.ndarray):
        for i in range(preds.shape[1]):
            preds[:, i] = np.clip(preds[:, i], self.target_min_max[i][0], self.target_min_max[i][1])
        return preds

    def save_oof_df(self, sample_ids: np.ndarray, preds: np.ndarray):
        oof_df = pl.DataFrame(preds, schema=self.target_cols)
        oof_df = oof_df.with_columns(sample_id=pl.Series(sample_ids))
        oof_df.write_parquet(self.config.oof_path / f"oof{self.save_suffix}.parquet")

    def postprocess(self, preds: np.ndarray, run_type: Literal["valid", "test"]):
        pp_x = self.valid_pp_df if run_type == "valid" else self.test_pp_df
        for x_col, y_col in zip(self.pp_x_cols, self.pp_y_cols):
            if y_col in self.target_cols:
                idx = self.target_cols.index(y_col)
                old_factor = self.old_factor_dict[y_col] if self.config.mul_old_factor else 1
                preds[:, idx] = (-1 * pp_x[x_col].to_numpy() / 1200) * old_factor
        return preds

    def load_postprocess_input(self, data_type: Literal["valid", "test"]):
        if data_type == "valid":
            valid_path = (
                self.config.input_path / "18_shrinked.parquet"
                if self.config.shared_valid
                else self.config.input_path / "train_shrinked.parquet"
            )
            self.valid_pp_df = (
                pl.scan_parquet(valid_path)
                .select(["sample_id"] + self.pp_x_cols)
                .filter(pl.col("sample_id").is_in(self.valid_ids))
                .collect()
            )
            id_df = pl.DataFrame({"sample_id": self.valid_ids})
            self.valid_pp_df = id_df.join(self.valid_pp_df, on="sample_id", how="left")

        elif data_type == "test":
            self.test_pp_df = pl.read_parquet(
                self.config.input_path / "test_shrinked.parquet",
                columns=["sample_id"] + self.pp_x_cols,
            )
            id_df = pl.DataFrame({"sample_id": self.test_ids})
            self.test_pp_df = id_df.join(self.test_pp_df, on="sample_id", how="left")


In [98]:
trainer = Trainer(config, logger)

In [99]:
oof_df = trainer.train(
    train_loader,
    valid_loader,
    colwise_mode=True,
)

27 28 0.6665454148778593


  torch.load(self.config.output_path / f"{retrain_weight_name}.pth")


  0%|          | 0/240 [00:00<?, ?it/s]

  0%|          | 0/3375 [00:00<?, ?it/s]

  0%|          | 0/153 [00:00<?, ?it/s]

[ [32m2024-10-10 14:24:55[0m | [1mINFO ] [Valid] : Epoch=0, Loss=0.16988, Score=0.66968, Best Col-Wise Score=0.67182[0m


KeyboardInterrupt: 

In [91]:
oof_df[0]

0.6665454148778593

In [72]:
pred_df = trainer.test_predict(test_loader, eval_method="single")

  torch.load(self.config.output_path / f"model{self.save_suffix}_best.pth")


  0%|          | 0/153 [00:00<?, ?it/s]

In [71]:
pred_df

ptend_t_0,ptend_t_1,ptend_t_2,ptend_t_3,ptend_t_4,ptend_t_5,ptend_t_6,ptend_t_7,ptend_t_8,ptend_t_9,ptend_t_10,ptend_t_11,ptend_t_12,ptend_t_13,ptend_t_14,ptend_t_15,ptend_t_16,ptend_t_17,ptend_t_18,ptend_t_19,ptend_t_20,ptend_t_21,ptend_t_22,ptend_t_23,ptend_t_24,ptend_t_25,ptend_t_26,ptend_t_27,ptend_t_28,ptend_t_29,ptend_t_30,ptend_t_31,ptend_t_32,ptend_t_33,ptend_t_34,ptend_t_35,ptend_t_36,…,ptend_v_32,ptend_v_33,ptend_v_34,ptend_v_35,ptend_v_36,ptend_v_37,ptend_v_38,ptend_v_39,ptend_v_40,ptend_v_41,ptend_v_42,ptend_v_43,ptend_v_44,ptend_v_45,ptend_v_46,ptend_v_47,ptend_v_48,ptend_v_49,ptend_v_50,ptend_v_51,ptend_v_52,ptend_v_53,ptend_v_54,ptend_v_55,ptend_v_56,ptend_v_57,ptend_v_58,ptend_v_59,cam_out_NETSW,cam_out_FLWDS,cam_out_PRECSC,cam_out_PRECC,cam_out_SOLS,cam_out_SOLL,cam_out_SOLSD,cam_out_SOLLD,sample_id
f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,…,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,f64,i32
0.072387,-1.447465,-1.59126,-1.238313,-1.039916,-0.886447,-0.759536,-0.794192,-0.957017,-1.032793,-0.944681,-0.791783,-0.54014,-0.306998,-0.069964,0.115484,0.786033,0.542847,0.249372,-0.374216,0.418973,-0.165187,-0.561241,-0.873283,-0.567651,-0.470924,-0.43505,-0.433015,-0.505128,-0.621606,-0.677111,-0.545636,-0.367299,-0.272195,-0.2401,-0.230634,-0.248424,…,0.256084,-0.121401,-0.344587,-0.102638,-0.091313,-0.140771,-0.460923,-0.049854,-0.018083,0.053758,-0.040552,0.309686,0.049827,0.035883,-0.286636,-0.126785,-1.287177,-2.573998,-2.5162,-1.333147,1.358848,1.754033,1.687201,0.85524,0.49837,0.349224,0.274072,0.017806,0.0,5.419843,0.0,0.06175,0.0,0.0,0.0,0.0,0
-0.485621,-1.135914,-0.649161,-0.742407,-0.981201,-1.199448,-1.358868,-1.201348,-0.899092,-0.662138,-0.59916,-0.602477,-0.670352,-0.698934,-0.682521,-0.733334,-0.659958,-0.355007,-0.285042,-0.208013,-0.244101,-0.429523,-0.397009,-0.343185,-0.364166,-0.398045,-0.450494,-0.475154,-0.492672,-0.480598,-0.430091,-0.338034,-0.181345,-0.009316,0.11043,0.261088,0.230349,…,-0.090072,0.020498,0.086971,0.050052,0.027946,0.063024,0.045506,0.026537,0.026066,-0.021002,-0.080584,-0.01897,0.002647,0.106812,-0.013483,-0.139997,-0.227484,-0.258315,-0.204068,-0.163616,-0.085005,-0.025923,-0.004916,0.093526,0.099184,0.026542,0.081104,0.141567,0.0,4.673466,0.0,0.039099,0.0,0.0,0.004282,0.007173,10
-0.234881,-1.656182,-0.601238,-0.433519,-0.789218,-0.849147,-0.75333,-0.671963,-0.595635,-0.537732,-0.55565,-0.608882,-0.711146,-0.758988,-0.750639,-0.721898,-0.592721,-0.432625,-0.277662,-0.208864,-0.280137,-0.387638,-0.296331,-0.302601,-0.214972,-0.178604,-0.198711,-0.220267,-0.217826,-0.22134,-0.288729,-0.324898,-0.344786,-0.334182,-0.341993,-0.344061,-0.341947,…,-0.214113,-0.028538,0.044159,0.040136,0.044838,0.076347,0.07342,0.041257,0.025995,0.00335,-0.010037,-0.010856,0.09887,0.040718,-0.088835,-0.134853,-0.455447,-0.750787,-1.185783,-1.05277,0.492555,0.40803,0.560029,0.34848,0.13647,0.193376,0.150923,0.374985,0.0,4.636291,0.035202,0.051106,0.0,0.0,0.0,0.0,100
0.841697,0.207503,0.660262,0.41642,-0.099689,-0.4974,-0.445526,-0.274708,-0.160325,-0.089188,-0.028821,-0.033144,-0.002816,0.130474,0.288755,0.572247,0.807359,0.651383,0.094121,0.031484,-0.032943,0.033807,-0.465906,-0.893544,-0.641397,-0.434837,-0.411204,-0.475317,-0.526809,-0.462754,-0.374046,-0.329558,-0.277853,-0.227232,-0.224293,-0.261834,-0.335193,…,-0.008861,0.15785,-0.316546,-0.006402,0.263694,-0.090587,0.003734,-0.151393,0.007677,-0.442158,1.166751,0.571663,0.326836,0.528868,0.470603,0.35881,0.213614,0.124506,-0.078239,-0.124431,-0.086841,-0.133892,-0.240608,-0.402967,-0.661201,-1.068136,-1.686045,0.885702,0.081611,4.937044,0.011568,0.000146,0.045463,0.068802,0.0976,0.022714,1000
1.500539,0.631652,0.84002,1.230621,1.381638,1.600752,1.813521,1.874617,1.767358,1.601102,1.472667,1.553341,1.733075,1.589896,1.570783,1.608648,1.623847,1.326061,0.824457,0.409224,0.334102,0.062525,-0.189031,-0.269723,-0.355983,-0.37082,-0.199703,-0.09418,-0.072377,-0.062984,-0.037737,-0.056851,-0.068275,-0.043301,0.070918,0.046332,0.015303,…,0.007482,-0.00586,-0.005332,0.001344,0.000645,-0.007996,0.004655,-0.015023,0.003851,-0.002278,-0.006703,-0.003428,-0.018436,-0.000147,-0.002932,-0.002815,0.002792,-0.002735,0.004233,0.015278,-0.014504,-0.015983,-0.016892,-0.01117,0.007782,0.004151,0.012318,-0.052589,3.550347,5.46045,0.006965,0.0,3.898687,3.928175,1.272887,0.199447,10000
…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…,…
1.555742,-0.231307,0.31507,1.357152,1.375476,1.425383,1.345898,1.348876,1.34179,1.276865,1.287688,1.414489,1.469185,1.401805,1.287487,1.31015,1.376972,0.790754,-0.067635,0.538727,0.436302,0.008843,-0.175665,-0.016893,-0.012854,-0.127514,-0.235109,-0.415306,-0.446855,-0.346145,-0.234438,-0.26318,-0.193878,-0.137943,-0.12264,-0.142851,-0.127268,…,-1.367339,-1.414375,-2.515401,-0.806765,0.050769,-0.602515,-0.922203,-1.136588,-1.362953,-1.078827,-0.543785,-1.221819,-0.858087,-1.07465,-0.963729,0.010406,0.177658,0.676587,2.175902,2.617899,1.8335,1.274838,0.734899,0.221132,0.056085,-0.055042,-0.020237,0.096456,1.792494,5.401479,0.0,0.206009,1.71325,1.731584,1.342879,0.859979,99994
0.125676,-1.379655,-2.742562,-1.084558,-0.698364,-0.503584,-0.441944,-0.475176,-0.569462,-0.535571,-0.495798,-0.482753,-0.552679,-0.686463,-0.874307,-1.022084,-1.059304,-0.9773,-0.583171,-0.433945,-0.282584,-0.250732,-0.193453,-0.190563,-0.16169,-0.125672,-0.121418,-0.116236,-0.111314,-0.09722,-0.095947,-0.095008,-0.091815,-0.10781,-0.100974,-0.112554,-0.124686,…,-0.001767,-0.005422,-0.004283,-0.001575,-0.002043,-0.005807,-0.005647,-0.006422,-0.006536,-0.006953,-0.006118,0.00739,-0.006179,0.00661,-0.002298,0.000249,0.000919,-0.006099,-0.008529,0.006642,-0.011856,-0.012575,-0.016818,-0.023979,0.007209,-0.023306,0.020196,-0.071072,0.007836,2.042937,0.443394,0.02423,0.006911,0.02251,0.00699,0.0,99995
0.955648,0.108612,0.245928,0.292953,0.357661,0.491642,0.48425,0.453527,0.421007,0.352576,0.29612,0.254613,0.269711,0.341166,0.443086,0.490532,0.475491,0.151646,0.142014,0.029529,0.00432,-0.013435,-0.032033,-0.051299,-0.010686,-0.01092,-0.012744,-0.003986,-0.015199,-0.012578,-0.022269,-0.026731,-0.034253,-0.038397,-0.044256,-0.050118,-0.058607,…,-0.003791,-0.005287,-0.005374,-0.003278,-0.002977,-0.00763,-0.006406,-0.00563,-0.006016,-0.007582,-0.008841,0.001931,-0.008538,0.000548,-0.003473,-0.002969,-0.001571,-0.005772,0.000074,0.000333,-0.010182,-0.009303,-0.002546,-0.00463,-0.007142,-0.009745,-0.00918,0.004694,1.047021,2.473471,0.542881,0.059441,1.069734,1.22308,1.082318,0.704674,99996
-0.589185,-0.345286,-0.698004,-0.778368,-0.880136,-0.989507,-0.938784,-0.869011,-0.888787,-0.958739,-0.952675,-0.919699,-0.800661,-0.702907,-0.280697,0.195618,0.560036,0.503415,0.283352,0.027114,-0.317059,-0.35842,-0.547165,-0.717264,-0.894454,-0.940556,-1.024346,-0.862114,-0.67173,-0.507013,-0.453553,-0.405618,-0.385944,-0.383616,-0.443181,-0.412004,-0.422172,…,-0.091295,-0.099374,-0.059406,-0.045226,-0.037378,0.006381,0.307105,-0.884168,-0.738261,-0.27823,-0.080536,0.073662,0.23916,0.286868,0.379763,0.30674,0.257128,0.140809,0.16433,0.104309,0.046667,0.010246,-0.004849,-0.019308,-0.050999,-0.07442,-0.094153,-0.275987,0.0,5.24048,0.00854,0.0,0.0,0.0,0.0,0.0,99997


In [60]:
best_score, best_cw_score, best_epochs = trainer.train(
    train_loader,
    valid_loader,
    colwise_mode=True,
    eval_only=False
)

  0%|          | 0/240 [00:00<?, ?it/s]

  0%|          | 0/3375 [00:00<?, ?it/s]

  0%|          | 0/153 [00:00<?, ?it/s]

[ [32m2024-10-10 10:07:30[0m | [1mINFO ] [Valid] : Epoch=0, Loss=0.28119, Score=0.31449, Best Col-Wise Score=0.31449[0m


  0%|          | 0/153 [00:00<?, ?it/s]

[ [32m2024-10-10 10:08:49[0m | [1mINFO ] [Valid] : Epoch=0, Loss=0.26269, Score=0.36343, Best Col-Wise Score=0.36363[0m


  0%|          | 0/153 [00:00<?, ?it/s]

[ [32m2024-10-10 10:10:08[0m | [1mINFO ] [Valid] : Epoch=0, Loss=0.25043, Score=0.41416, Best Col-Wise Score=0.41457[0m


  0%|          | 0/153 [00:00<?, ?it/s]

[ [32m2024-10-10 10:11:27[0m | [1mINFO ] [Valid] : Epoch=0, Loss=0.22722, Score=0.48158, Best Col-Wise Score=0.48226[0m


  0%|          | 0/153 [00:00<?, ?it/s]

[ [32m2024-10-10 10:12:46[0m | [1mINFO ] [Valid] : Epoch=0, Loss=0.21943, Score=0.50966, Best Col-Wise Score=0.51021[0m


  0%|          | 0/153 [00:00<?, ?it/s]

[ [32m2024-10-10 10:14:04[0m | [1mINFO ] [Valid] : Epoch=0, Loss=0.20788, Score=0.55110, Best Col-Wise Score=0.55116[0m


  0%|          | 0/153 [00:00<?, ?it/s]

[ [32m2024-10-10 10:15:23[0m | [1mINFO ] [Valid] : Epoch=0, Loss=0.20354, Score=0.56777, Best Col-Wise Score=0.56811[0m


  0%|          | 0/153 [00:00<?, ?it/s]

[ [32m2024-10-10 10:16:42[0m | [1mINFO ] [Valid] : Epoch=0, Loss=0.20009, Score=0.58099, Best Col-Wise Score=0.58203[0m


  0%|          | 0/153 [00:00<?, ?it/s]

[ [32m2024-10-10 10:18:02[0m | [1mINFO ] [Valid] : Epoch=0, Loss=0.19554, Score=0.59370, Best Col-Wise Score=0.59386[0m


  0%|          | 0/153 [00:00<?, ?it/s]

[ [32m2024-10-10 10:19:21[0m | [1mINFO ] [Valid] : Epoch=0, Loss=0.19324, Score=0.59970, Best Col-Wise Score=0.60087[0m


  0%|          | 0/153 [00:00<?, ?it/s]

[ [32m2024-10-10 10:20:40[0m | [1mINFO ] [Valid] : Epoch=0, Loss=0.19273, Score=0.60122, Best Col-Wise Score=0.60463[0m


  0%|          | 0/153 [00:00<?, ?it/s]

[ [32m2024-10-10 10:21:59[0m | [1mINFO ] [Valid] : Epoch=0, Loss=0.18903, Score=0.61432, Best Col-Wise Score=0.61505[0m


  0%|          | 0/153 [00:00<?, ?it/s]

[ [32m2024-10-10 10:23:18[0m | [1mINFO ] [Valid] : Epoch=0, Loss=0.18824, Score=0.61602, Best Col-Wise Score=0.61906[0m


  0%|          | 0/153 [00:00<?, ?it/s]

[ [32m2024-10-10 10:24:38[0m | [1mINFO ] [Valid] : Epoch=0, Loss=0.18484, Score=0.62110, Best Col-Wise Score=0.62384[0m


  0%|          | 0/153 [00:00<?, ?it/s]

[ [32m2024-10-10 10:25:57[0m | [1mINFO ] [Valid] : Epoch=0, Loss=0.18413, Score=0.63048, Best Col-Wise Score=0.63128[0m


  0%|          | 0/153 [00:00<?, ?it/s]

[ [32m2024-10-10 10:27:16[0m | [1mINFO ] [Valid] : Epoch=0, Loss=0.18317, Score=0.63415, Best Col-Wise Score=0.63609[0m


  0%|          | 0/153 [00:00<?, ?it/s]

[ [32m2024-10-10 10:28:36[0m | [1mINFO ] [Valid] : Epoch=0, Loss=0.17991, Score=0.63892, Best Col-Wise Score=0.64082[0m


  0%|          | 0/153 [00:00<?, ?it/s]

[ [32m2024-10-10 10:29:55[0m | [1mINFO ] [Valid] : Epoch=0, Loss=0.17891, Score=0.63920, Best Col-Wise Score=0.64284[0m


  0%|          | 0/153 [00:00<?, ?it/s]

[ [32m2024-10-10 10:31:14[0m | [1mINFO ] [Valid] : Epoch=0, Loss=0.17844, Score=0.64732, Best Col-Wise Score=0.64767[0m


  0%|          | 0/153 [00:00<?, ?it/s]

[ [32m2024-10-10 10:32:33[0m | [1mINFO ] [Valid] : Epoch=0, Loss=0.17695, Score=0.64971, Best Col-Wise Score=0.65146[0m


  0%|          | 0/153 [00:00<?, ?it/s]

[ [32m2024-10-10 10:33:53[0m | [1mINFO ] [Valid] : Epoch=0, Loss=0.17680, Score=0.65087, Best Col-Wise Score=0.65314[0m


  0%|          | 0/153 [00:00<?, ?it/s]

[ [32m2024-10-10 10:35:12[0m | [1mINFO ] [Valid] : Epoch=0, Loss=0.17572, Score=0.65483, Best Col-Wise Score=0.65643[0m


  0%|          | 0/153 [00:00<?, ?it/s]

[ [32m2024-10-10 10:36:31[0m | [1mINFO ] [Valid] : Epoch=0, Loss=0.17561, Score=0.65641, Best Col-Wise Score=0.65850[0m


  0%|          | 0/153 [00:00<?, ?it/s]

[ [32m2024-10-10 10:37:51[0m | [1mINFO ] [Valid] : Epoch=0, Loss=0.17379, Score=0.65722, Best Col-Wise Score=0.66052[0m


  0%|          | 0/153 [00:00<?, ?it/s]

[ [32m2024-10-10 10:39:10[0m | [1mINFO ] [Valid] : Epoch=0, Loss=0.17354, Score=0.66167, Best Col-Wise Score=0.66341[0m


  0%|          | 0/153 [00:00<?, ?it/s]

[ [32m2024-10-10 10:40:29[0m | [1mINFO ] [Valid] : Epoch=0, Loss=0.17422, Score=0.65987, Best Col-Wise Score=0.66460[0m


  0%|          | 0/153 [00:00<?, ?it/s]

[ [32m2024-10-10 10:41:48[0m | [1mINFO ] [Valid] : Epoch=0, Loss=0.17202, Score=0.66506, Best Col-Wise Score=0.66720[0m


  0%|          | 0/153 [00:00<?, ?it/s]

[ [32m2024-10-10 10:43:19[0m | [1mINFO ] [Valid] : Epoch=0, Loss=0.17192, Score=0.66655, Best Col-Wise Score=0.66900[0m


  0%|          | 0/153 [00:00<?, ?it/s]

KeyboardInterrupt: 

# PostProcess

In [29]:
from omegaconf import DictConfig
import loguru
from typing import Literal
from sklearn.metrics import r2_score

from src.utils.competition_utils import get_sub_factor, get_io_columns

class PostProcess:
    def __init__(self, config: DictConfig, logger: loguru._Logger, additional: bool = True):
        self.config = config
        self.logger = logger
        self.additional = additional

        _, self.target_cols = get_io_columns(config)
        self.old_factor_dict = get_sub_factor(config.input_path, old=True)
        self.sub_cols = pl.read_parquet(config.input_path / 'sample_submission.parquet', n_rows=1).columns

        self.pp_x_cols = [f'state_q0002_{i}' for i in range(12, 27)]
        self.pp_y_cols = [f'ptend_q0002_{i}' for i in range(12, 27)]
        self.valid_pp_df = pl.read_parquet(
            config.input_path / '18_shrinked.parquet',
            columns=['sample_id'] + self.pp_x_cols
        )
        self.test_pp_df = pl.read_parquet(
            config.input_path / 'test_shrinked.parquet',
            columns=['sample_id'] + self.pp_x_cols
        )

        add_pp_y_cols = (
            [f'ptend_q0002_{i}' for i in range(60)] +
            [f'ptend_q0003_{i}' for i in range(60)]
        )
        self.add_pp_y_cols = [col for col in add_pp_y_cols if col in self.target_cols]
        self.add_pp_x_cols = [col.replace('ptend', 'state') for col in self.add_pp_y_cols]
        self.add_valid_pp_df = pl.read_parquet(
            config.input_path / '18_shrinked.parquet',
            columns=self.sub_cols + self.add_pp_x_cols
        )
        self.add_test_pp_df = pl.read_parquet(
            config.input_path / 'test_shrinked.parquet',
            columns=['sample_id'] + self.add_pp_x_cols
        )
        self.th_dict = None

    def postprocess(self, oof_df: pl.DataFrame, sub_df: pl.DataFrame):
        oof_df = self.complement_columns(oof_df)
        oof_df = self.reverse_sub_factor(oof_df)
        oof_df = self.replace_postprocess(oof_df, 'oof')

        sub_df = self.complement_columns(sub_df)
        sub_df = self.reverse_sub_factor(sub_df)
        sub_df = self.replace_postprocess(sub_df, 'sub')

        if self.additional:
            oof_df = self.additional_postprocess(oof_df, 'oof')
            sub_df = self.additional_postprocess(sub_df, 'sub')

        oof_df = self.create_oof_df(oof_df)
        sub_df = self.create_sub_df(sub_df)
        return oof_df, sub_df

    def complement_columns(self, pred_df: pl.DataFrame):
        lack_cols = list(set(self.sub_cols) - set(pred_df.columns))
        for col in lack_cols:
            pred_df = pred_df.with_columns([pl.lit(0).alias(col)])
        return pred_df

    def reverse_sub_factor(self, pred_df: pl.DataFrame):
        if self.config.mul_old_factor:
            exprs = []
            for col in self.target_cols:
                if self.old_factor_dict[col] != 0:
                    exprs.append((pl.col(col) / self.old_factor_dict[col]).alias(col))

            pred_df = pred_df.with_columns(exprs)
        return pred_df

    def replace_postprocess(self, pred_df: pl.DataFrame, pred_type: Literal['oof', 'sub']):
        pp_df = self.valid_pp_df if pred_type == 'oof' else self.test_pp_df
        pred_df = pred_df.join(pp_df, on=['sample_id'], how='left')

        exprs = []
        for x_col, y_col in zip(self.pp_x_cols, self.pp_y_cols):
            exprs.append((-1 * pl.col(x_col) / 1200).alias(y_col))
        pred_df = pred_df.with_columns(exprs)
        pred_df = pred_df.drop(self.pp_x_cols)
        return pred_df

    def additional_postprocess(self, pred_df: pl.DataFrame, pred_type: Literal['oof', 'sub']):
        pp_df = self.add_valid_pp_df if pred_type == 'oof' else self.add_test_pp_df
        pred_df = pred_df.join(pp_df, on=['sample_id'], how='left', suffix='_gt')
        exprs = []
        for x_col, y_col in zip(self.add_pp_x_cols, self.add_pp_y_cols):
            exprs.append((pl.col(x_col) + pl.col(y_col) * 1200).alias(f'{x_col}_next'))
        pred_df = pred_df.with_columns(exprs)

        if pred_type == 'oof':
            self.tuning_threshold(pred_df)

        assert self.th_dict is not None # oofから実行する必要がある
        exprs = []
        for y_col, (best_th, _) in self.th_dict.items():
            x_col = y_col.replace('ptend', 'state')
            exprs.append(
                pl.when(pl.col(f'{x_col}_next') < best_th)
                .then(-1 * pl.col(x_col) / 1200)
                .otherwise(pl.col(y_col))
                .alias(y_col)
            )
        pred_df = pred_df.with_columns(exprs)

        if pred_type == 'oof':
            scores = []
            for col in self.target_cols:
                score = r2_score(pred_df[f'{col}_gt'].to_numpy(), pred_df[col].to_numpy())
                scores.append(score)
            total_score = (np.sum(scores) + (368 - len(scores))) / 368
            self.logger.info(f'After Additional Postprocess: {total_score:.5f}')

        drop_cols = (
            self.add_pp_x_cols +
            [f'{col}_next' for col in self.add_pp_x_cols] +
            [col for col in pred_df.columns if '_gt' in col]
        )
        pred_df = pred_df.drop(drop_cols)
        return pred_df

    def tuning_threshold(self, pred_df: pl.DataFrame):
        iterations = tqdm(zip(self.add_pp_x_cols, self.add_pp_y_cols), total=len(self.add_pp_x_cols))
        for x_col, y_col in iterations:
            best_score = r2_score(pred_df[f'{y_col}_gt'].to_numpy(), pred_df[y_col].to_numpy())
            best_th = None
            for th_base in [0, 1e-10, 1e-9, 1e-8, 1e-7, 1e-6, 1e-5]:
                for corr in range(1, 10):
                    if th_base == 0 and corr >= 2:
                        break

                    th = th_base * corr
                    preds = pred_df.select(
                        pl.when(pl.col(f'{x_col}_next') < th)
                        .then(-1 * pl.col(x_col) / 1200)
                        .otherwise(pl.col(y_col))
                    ).to_numpy()

                    truths = pred_df[f'{y_col}_label'].to_numpy()
                    score = r2_score(truths, preds)
                    if score > best_score:
                        best_score = score
                        best_th = th

            if best_th is not None:
                self.th_dict[y_col] = (best_th, best_score)


    def create_oof_df(self, oof_df: pl.DataFrame):
        oof_df = oof_df.select(self.sub_cols)
        oof_df.write_parquet(self.config.oof_path / 'oof_pp.parquet')
        return oof_df

    def create_sub_df(self, sub_df: pl.DataFrame):
        sub_df = sub_df.with_columns(sample_id = pl.concat_str([pl.lit('test_'), pl.col('sample_id')]))
        sub_df = sub_df.select(self.sub_cols)
        sub_df.write_csv(self.config.output_path / 'submission_pp.csv')
        return sub_df