# Library

In [59]:
import polars as pl
import gc
import pickle
from pathlib import Path, PosixPath
from tqdm.auto import tqdm
from collections import defaultdict, Counter

import sys

# Make the parent of the current working directory importable so the `src.*` imports below resolve.
sys.path.append("..")

from src.utils import seed_everything, get_logger, get_config, TimeUtil
from src.preprocess import DataProvider1st
from src.train import get_train_loaders

# Setup

In [60]:
# Command-line argument: experiment ID (hard-coded here for interactive notebook use).
exp = "087"

In [61]:
# Load this experiment's config from ../config, build a logger from config.output_path,
# log the start of the run, and fix the random seed via seed_everything(config.seed).
config = get_config(exp, config_dir=Path("../config"))
logger = get_logger(config.output_path)
logger.info(f"exp:{exp} start")

seed_everything(config.seed)

[ 2024-10-11 05:04:12 | INFO ] exp:087 start


In [4]:
# Turn on debug mode for this interactive run.
# NOTE(review): what `debug` changes (data size, epochs, ...) is decided elsewhere — confirm.
config.debug = True
# Display the external-dataset settings; entries appear to be [name, enabled] pairs (see output).
config.exter_dataset

[['nicholas', True], ['mpware', False], ['pjma', False]]

# Data

In [5]:
# Load the "train" data with DataProvider1st and display how many records were loaded.
dpr = DataProvider1st(config, "train")
data = dpr.load_data()
len(data)

400

In [6]:
# Build one (train_loader, valid_loader) pair per fold; the training cells below unpack them.
dataloaders = get_train_loaders(config, data)



# Model

In [None]:
from omegaconf import DictConfig
from torch import nn
from torch.optim.optimizer import Optimizer

from src.model.models.conv1d import LEAPConv1D
from src.model.models.lstm import LEAPLSTM
from src.model.models.transformer import LEAPTransformer
from src.train.loss import LEAPLoss
from src.train.optimizer import get_optimizer
from src.train.scheduler import get_scheduler


class ComponentFactory:
    """Build the model, loss, optimizer and LR scheduler described by a config.

    Every builder is a ``@staticmethod`` that reads only the config fields it
    needs. Unsupported selector values (``task_type``, ``scheduler_type``)
    raise ``ValueError`` instead of falling through with an unbound local.
    """

    # [TODO] Needs editing: DetectModel / ClassifyModel are not imported in this notebook.
    @staticmethod
    def get_model(config: "DictConfig"):
        """Instantiate the model selected by ``config.task_type``.

        Args:
            config: Must provide ``task_type`` ("detect" or "classify"),
                ``reinit_layer_num`` and ``freeze_layer_num``.

        Returns:
            The constructed model. Its layers are re-initialised and/or frozen
            when the corresponding counts are positive.

        Raises:
            ValueError: If ``config.task_type`` is not "detect" or "classify".
        """
        if config.task_type == "detect":
            model = DetectModel(config)
        elif config.task_type == "classify":
            model = ClassifyModel()
        else:
            raise ValueError(f"Invalid task_type: {config.task_type}")

        if config.reinit_layer_num > 0:
            model.reinit_layers(config.reinit_layer_num)
        if config.freeze_layer_num > 0:
            model.freeze_layers(config.freeze_layer_num)
        return model

    # [TODO] Needs editing: SmoothingCELoss / WeightedBCELoss are not imported in this notebook.
    @staticmethod
    def get_loss(config: "DictConfig"):
        """Instantiate the loss selected by ``config.task_type``.

        Args:
            config: Must provide ``task_type``. For "detect" it also needs
                ``smooth_type`` and, unless that is "online", ``loss_class_weight``.

        Returns:
            The loss object, or ``None`` for task_type "detect" with
            smooth_type "online" (no loss object is built here in that mode).

        Raises:
            ValueError: If ``config.task_type`` is not "detect" or "classify".
        """
        if config.task_type == "detect":
            if config.smooth_type == "online":
                # NOTE(review): presumably the online-smoothing loss is computed elsewhere — confirm.
                return None
            return SmoothingCELoss(config, weight=config.loss_class_weight)
        if config.task_type == "classify":
            return WeightedBCELoss()
        raise ValueError(f"Invalid task_type: {config.task_type}")

    @staticmethod
    def get_optimizer(config: "DictConfig", model):
        """Build the optimizer for ``model`` from the config's optimizer settings.

        The call below resolves to the module-level ``get_optimizer`` imported
        from ``src.train.optimizer``, not to this staticmethod: class attributes
        are not in scope inside method bodies.
        """
        optimizer = get_optimizer(
            model,
            optimizer_type=config.optimizer_type,
            pretrained_lr=config.pretrained_lr,
            head_lr=config.head_lr,
            weight_decay=config.weight_decay,
            betas=config.betas,
        )
        return optimizer

    @staticmethod
    def get_scheduler(config: "DictConfig", optimizer: "Optimizer", steps_per_epoch: int):
        """Build the LR scheduler named by ``config.scheduler_type``.

        Args:
            config: Must provide ``epochs``, ``scheduler_type`` and the fields
                used by the chosen scheduler type (see the branches below).
            optimizer: Optimizer whose learning rate will be scheduled.
            steps_per_epoch: Optimizer steps per epoch. "linear" and "cosine"
                receive ``epochs * steps_per_epoch`` as ``num_training_steps``;
                "cosine_custom" receives ``first_cycle_epochs * steps_per_epoch``
                as ``first_cycle_steps``.

        Returns:
            Whatever the module-level ``get_scheduler`` (``src.train.scheduler``)
            returns for these arguments.

        Raises:
            ValueError: If ``config.scheduler_type`` is not supported.
        """
        total_steps = config.epochs * steps_per_epoch
        if config.scheduler_type == "linear":
            scheduler_args = {
                "num_warmup_steps": config.num_warmup_steps,
                "num_training_steps": total_steps,
            }
        elif config.scheduler_type == "cosine":
            scheduler_args = {
                "num_warmup_steps": config.num_warmup_steps,
                "num_training_steps": total_steps,
                "num_cycles": config.num_cycles,
            }
        elif config.scheduler_type == "cosine_custom":
            # The first cycle is configured in epochs but the scheduler counts steps.
            first_cycle_steps = config.first_cycle_epochs * steps_per_epoch
            scheduler_args = {
                "first_cycle_steps": first_cycle_steps,
                "cycle_factor": config.cycle_factor,
                "num_warmup_steps": config.num_warmup_steps,
                "min_lr": config.min_lr,
                "gamma": config.gamma,
            }
        elif config.scheduler_type == "reduce_on_plateau":
            scheduler_args = {
                "mode": config.mode,
                "factor": config.factor,
                "patience": config.patience,
                "min_lr": config.min_lr,
            }
        else:
            raise ValueError(f"Invalid scheduler_type: {config.scheduler_type}")

        scheduler = get_scheduler(optimizer, scheduler_type=config.scheduler_type, scheduler_args=scheduler_args)
        return scheduler


In [30]:
# Take the first fold's training loader and build a model for quick interactive checks.
train_loader = dataloaders[0][0]
# `get_model` is a staticmethod of ComponentFactory; there is no module-level `get_model`.
model = ComponentFactory.get_model(config)



# Train

In [None]:
# Scratch cell: walk the folds and only log a start message; no training happens here.
for fold, (train_loader, valid_loader) in enumerate(dataloaders):
    logger.info(f"\n FOLD{fold} : Training Start \n")

In [31]:
# Train one model per CV fold, collect the out-of-fold (OOF) predictions, pick the
# negative threshold on them, then optionally retrain on the full data.
#
# NOTE(review): train_model, get_best_negative_threshold, get_full_train_loader,
# full_train_model and `suffix` are not defined or imported in any visible cell;
# they must be provided before this cell runs.
import torch  # only `nn` was imported from torch; `torch.cuda` is needed below

# train over folds
oof_dfs = []
best_steps_list, best_add_steps_list = [], []
for fold, (train_loader, valid_loader) in enumerate(dataloaders):
    # Build through the factory: there is no module-level `get_model`, and the
    # module-level `get_optimizer` (src.train.optimizer) takes the model as its
    # first argument plus keyword hyper-parameters, not `(config, model)`.
    model = ComponentFactory.get_model(config)
    optimizer = ComponentFactory.get_optimizer(config, model)
    oof_df, score, best_steps, best_add_steps = train_model(
        config,
        model,
        train_loader,
        valid_loader,
        optimizer,
        logger,
        fold,
        suffix=suffix,
    )
    oof_df.write_parquet(config.oof_path / f"oof_fold{fold}{suffix}.parquet")
    oof_dfs.append(oof_df)
    # NOTE(review): this per-fold threshold stays in `config` while the next fold
    # trains; confirm train_model does not read config.negative_th.
    best_score, best_th = get_best_negative_threshold(config, oof_df)
    config.negative_th = best_th
    message = f"FOLD: {fold}, Steps: {best_steps} + {best_add_steps}, Best Score: {best_score}, Best Negative Threshold: {best_th}"
    logger.info(message)
    best_steps_list.append(best_steps)
    best_add_steps_list.append(best_add_steps)

    # Drop the optimizer as well as the model. The optimizer keeps references to
    # the model's parameters (and any per-parameter state), so deleting only
    # `model` would leave that memory allocated.
    del train_loader, valid_loader, model, optimizer
    gc.collect()
    torch.cuda.empty_cache()
del dataloaders
gc.collect()

# save oof
oof_df = pl.concat(oof_dfs)
oof_df.write_parquet(config.oof_path / f"oof_{config.exp}{suffix}.parquet")

# get best threshold
best_score, best_th = get_best_negative_threshold(config, oof_df)
message = f"Overall OOF Best Score: {best_score}, Best Negative Threshold: {best_th}"
logger.info(message)
config.negative_th = best_th

# full train
if config.full_train:
    # Use the largest best-step counts seen across folds. Built-in max() replaces
    # np.max because numpy is not imported in any visible cell.
    full_train_steps, full_train_add_steps = max(best_steps_list), max(best_add_steps_list)
    logger.info(f"\n Full Train : Training Start, Num of Steps : {full_train_steps} + {full_train_add_steps}\n")
    # `train_data` is not defined anywhere in this notebook; the dataset loaded
    # in the Data section (and used to build the CV loaders) is `data`.
    train_loader = get_full_train_loader(config, data)
    model = ComponentFactory.get_model(config)
    optimizer = ComponentFactory.get_optimizer(config, model)
    full_train_model(config, model, train_loader, optimizer, full_train_steps, full_train_add_steps, logger, suffix)
    message = "Full Train Completed"
    logger.info(message)

In [32]:
# Smoke-test a single forward pass.
# NOTE(review): input_ids, attention_mask and positions are not defined in any visible
# cell, and the training loop deletes `model` after every fold (it is only rebound when
# config.full_train is set). Define a batch and a model before running this.
out = model(input_ids, attention_mask, positions)

In [33]:
# Inspect the forward-pass output shape (recorded output: torch.Size([16, 128, 2]);
# presumably (batch, seq_len, num_classes) — confirm against the model definition).
out.shape

torch.Size([16, 128, 2])