In [149]:
import json
import os
import random
from types import SimpleNamespace

import numpy as np
import pandas as pd
import timm
import torch
import torch.nn.functional as F
import torchaudio.transforms as T
from sklearn.metrics import f1_score
from sklearn.model_selection import StratifiedGroupKFold
from torch import nn
from torch.optim import AdamW
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import get_cosine_schedule_with_warmup

In [150]:
# Show the kernel's working directory: relative paths in the config cell
# (e.g. cfg.train_path = "../data/train.csv") are resolved against it.
!pwd

Now using node v22.16.0 (npm v11.4.1)
/home/karunru/Home/Kaggle/kaggle_monorepo/projects/CMI_Detect_Behavior_with_Sensor_Data/notebooks


In [151]:
# Run configuration shared by every cell below: data paths, column groups,
# CV split settings, model and optimisation hyper-parameters.
cfg = SimpleNamespace()

# Data location (the commented path is the Kaggle-kernel equivalent).
# cfg.train_path = '/kaggle/input/cmi-detect-behavior-with-sensor-data/train.csv'
cfg.train_path = "../data/train.csv"

# IMU channels fed to the model, in this order.
cfg.imu_cols = [
    "acc_x",
    "acc_y",
    "acc_z",
    "rot_w",
    "rot_x",
    "rot_y",
    "rot_z",
]
# Per-sequence columns: fast_seq_agg keeps the first value of each sequence
# instead of turning them into per-timestep arrays.
cfg.static_cols = [
    "sequence_id",
    "sequence_type",
    "gesture",
    "orientation",
    "subject",
    "adult_child",
    "age",
    "sex",
    "handedness",
    "height_cm",
    "shoulder_to_wrist_cm",
    "elbow_to_wrist_cm",
]

cfg.main_target = "gesture"
cfg.main_num_classes = 18
cfg.group = "subject"  # CV groups: no subject appears in both train and validation
cfg.seq_len = 100  # sequences are left-padded with zeros / truncated to their last seq_len steps
cfg.n_splits = 5
cfg.curr_fold = 1  # only used by the (commented-out) single-fold filter in the training loop
cfg.seed = 42

cfg.model_dir = "weights"
cfg.oof_dir = "oofs"

# cfg.encoder_name = 'timm/resnet50.a1_in1k' # just for test :)
# Fixed a duplicated "timm/timm/" prefix in the model name.
cfg.encoder_name = "timm/caformer_s18.sail_in22k"
cfg.img_sz = 224  # length the spectrogram's last (time) axis is pooled to
cfg.sepc_model_dropout = 0.3  # (sic) name is referenced by IMG_CMIModel; drop_rate of the encoder
cfg.im_pretrained = False  # build the timm encoder without pretrained weights

cfg.bs = 256
cfg.n_epochs = 50
cfg.patience = 5  # epochs without validation improvement before early stopping
cfg.lr = 1e-4
cfg.weight_decay = 1e-2
cfg.num_warmup_steps_ratio = 0.03  # fraction of total optimiser steps used for LR warm-up
cfg.max_norm = 2.0  # gradient-norm clipping threshold

In [152]:
class SpecNormalize(nn.Module):
    """Min-max scale each 2-D slice over the last two axes to [0, 1].

    For an input laid out as (batch, channel, freq, time), every
    (batch, channel) spectrogram is rescaled independently.

    Args:
        eps: added to the value range so a constant slice does not divide by zero.
    """

    def __init__(self, eps: float = 1e-8):
        super().__init__()
        self.eps = eps

    def forward(self, x):
        # Joint reduction over the trailing two axes; keepdim for broadcasting.
        lo = x.amin(dim=(-2, -1), keepdim=True)
        hi = x.amax(dim=(-2, -1), keepdim=True)
        span = hi - lo + self.eps
        return (x - lo) / span

In [153]:
class SpecFeatureExtractor(nn.Module):
    """Convert multichannel 1-D signals into normalised dB spectrogram "images".

    Pipeline: ``T.Spectrogram`` -> ``T.AmplitudeToDB(top_db=80)`` ->
    ``SpecNormalize`` (per-slice min-max to [0, 1]). The FFT size is
    ``2 * height - 1`` so the one-sided spectrum has ``height`` frequency bins.
    When ``out_size`` is given, only the last (time) axis is resized to
    ``out_size`` by adaptive average pooling.

    Args:
        in_channels: number of signal channels (kept as ``out_chans``).
        height: desired number of frequency bins.
        hop_length: STFT hop length in samples.
        win_length: STFT window length; ``None`` lets torchaudio choose its default.
        out_size: optional target length of the time axis.
    """

    def __init__(
        self,
        in_channels: int,
        height: int,
        hop_length: int,
        win_length: int | None = None,
        out_size: int | None = None,
    ):
        super().__init__()
        self.height = height
        self.out_chans = in_channels

        fft_size = 2 * height - 1
        stages = [
            T.Spectrogram(n_fft=fft_size, hop_length=hop_length, win_length=win_length),
            T.AmplitudeToDB(top_db=80),
            SpecNormalize(),
        ]
        # Attribute names are part of the state_dict keys; keep them stable.
        self.feature_extractor = nn.Sequential(*stages)

        self.out_size = out_size
        if out_size is not None:
            # (None, W): leave the frequency axis alone, resize time to W.
            self.pool = nn.AdaptiveAvgPool2d((None, out_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        spec = self.feature_extractor(x)
        return spec if self.out_size is None else self.pool(spec)

In [154]:
class SpecCNN2d(nn.Module):
    """Small CNN classifier for spectrogram images.

    Four 3x3 conv + BatchNorm + ReLU stages (32 -> 64 -> 128 -> 256 channels),
    2x2 max pooling after the first two stages, global average pooling, then an
    MLP head 256 -> 512 -> 256 -> ``n_classes`` with ReLU and dropout.

    Args:
        in_channels: number of input channels.
        n_classes: number of output logits.
        dropout: dropout probability used inside the MLP head.
    """

    def __init__(self, in_channels, n_classes, dropout=0.2):
        super().__init__()

        # Registered as conv1/bn1 ... conv4/bn4 in the same order as a hand-written
        # version, so state_dict keys and seeded initialisation are unchanged.
        widths = [in_channels, 32, 64, 128, 256]
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:]), start=1):
            setattr(self, f"conv{stage}", nn.Conv2d(c_in, c_out, kernel_size=3, padding=1))
            setattr(self, f"bn{stage}", nn.BatchNorm2d(c_out))

        self.pool = nn.MaxPool2d(2)
        # NOTE(review): registered but never applied in forward(); only the head uses dropout.
        self.dropout = nn.Dropout(dropout)

        self.gap = nn.AdaptiveAvgPool2d(1)

        self.classifier = nn.Sequential(
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(256, n_classes),
        )

    def forward(self, x):
        # (conv, bn, downsample-after?) for each stage.
        stages = (
            (self.conv1, self.bn1, True),
            (self.conv2, self.bn2, True),
            (self.conv3, self.bn3, False),
            (self.conv4, self.bn4, False),
        )
        for conv, bn, downsample in stages:
            x = F.relu(bn(conv(x)))
            if downsample:
                x = self.pool(x)

        pooled = self.gap(x).flatten(1)
        return self.classifier(pooled)

In [155]:
class IMG_CMIModel(nn.Module):
    """IMU sequence classifier: spectrogram front-end + timm image backbone.

    ``SpecFeatureExtractor`` turns every IMU channel into a normalised
    spectrogram (time axis pooled to ``img_sz``); the stacked per-channel
    images are classified by ``timm.create_model(encoder, in_chans=imu_vars, ...)``.

    Default arguments are read from ``cfg`` once, when this class is defined.
    """

    def __init__(
        self,
        imu_vars=len(cfg.imu_cols),
        n_classes=cfg.main_num_classes,
        dropout=cfg.sepc_model_dropout,
        img_sz=cfg.img_sz,
        spec_height=12,
        spec_hop_length=4,
        spec_win_length=None,
        pretrained=cfg.im_pretrained,
        encoder=cfg.encoder_name,
    ):
        super().__init__()

        self.spec_feature_extractor = SpecFeatureExtractor(
            in_channels=imu_vars,
            height=spec_height,
            hop_length=spec_hop_length,
            win_length=spec_win_length,
            out_size=img_sz,  # pass None instead when pairing with SpecCNN2d
        )

        backbone_kwargs = dict(
            in_chans=imu_vars,
            drop_rate=dropout,
            num_classes=n_classes,
            pretrained=pretrained,
        )
        self.tmodel = timm.create_model(encoder, **backbone_kwargs)
        # Lightweight alternative backbone:
        # self.tmodel = SpecCNN2d(in_channels=imu_vars, n_classes=n_classes, dropout=dropout)

    def forward(self, _x):
        # Input is (batch, time, channels) like the "imu" tensors of TS_CMIDataset;
        # the spectrogram runs over the last axis, so move time there.
        channels_first = _x.transpose(1, 2)
        spec = self.spec_feature_extractor(channels_first)
        return self.tmodel(spec)

In [156]:
class TS_CMIDataset(Dataset):
    """Per-sequence IMU dataset over a frame produced by ``fast_seq_agg``.

    Each row is one sequence whose ``cfg.imu_cols`` cells hold per-timestep
    arrays. An item is ``{"imu": float32 tensor (seq_len, len(cfg.imu_cols))}``
    plus ``"main_target"`` (int64 scalar tensor) when ``main_target`` is a
    column of ``dataframe``.
    """

    def __init__(self, dataframe, seq_len=cfg.seq_len, main_target=cfg.main_target):
        self.df = dataframe.copy().reset_index(drop=True)
        self.seq_len = seq_len
        self.main_target = main_target
        self.imu_cols = cfg.imu_cols
        self.has_target = main_target in self.df.columns

    def _prepare_sensor_data_raw(self, row, sensor_cols):
        """Stack ``sensor_cols`` of ``row`` into a (T, C) float64 array with NaNs filled.

        Columns that are entirely NaN become 0; partially-NaN columns are
        linearly interpolated in both directions, then ffill / bfill / 0 as fallbacks.
        """
        data = np.stack([np.asarray(row[name], dtype=np.float64) for name in sensor_cols], axis=1)

        for col_idx in range(data.shape[1]):
            nan_mask = np.isnan(data[:, col_idx])
            if not nan_mask.any():
                continue
            if nan_mask.all():
                data[:, col_idx] = 0.0
                continue
            filled = (
                pd.Series(data[:, col_idx])
                .interpolate(method="linear", limit_direction="both")
                .ffill()
                .bfill()
                .fillna(0.0)
            )
            data[:, col_idx] = filled.values

        return data

    def __len__(self):
        return self.df.shape[0]

    def _pad_or_truncate_final(self, data, target_len):
        """Keep the last ``target_len`` rows, or left-pad with zero rows up to ``target_len``."""
        n_rows = data.shape[0]
        if n_rows < target_len:
            pad = np.zeros((target_len - n_rows, data.shape[1]), dtype=data.dtype)
            return np.concatenate([pad, data], axis=0)
        if n_rows > target_len:
            return data[-target_len:]
        return data

    def _prepare_sensor_data(self, row, sensor_cols):
        raw = self._prepare_sensor_data_raw(row, sensor_cols)
        return self._pad_or_truncate_final(raw, self.seq_len)

    def __getitem__(self, idx):
        row = self.df.iloc[idx]
        imu = self._prepare_sensor_data(row, self.imu_cols)

        sample = {"imu": torch.tensor(imu, dtype=torch.float32)}  # (seq_len, n_imu)
        if self.has_target:
            sample["main_target"] = torch.tensor(row[self.main_target], dtype=torch.long)
        return sample

In [157]:
def fast_seq_agg(df, static_cols=None):
    """Collapse a long (one row per timestep) frame into one row per sequence.

    Rows are ordered by ``sequence_id`` then ``sequence_counter``. For each
    sequence, every static column keeps the value of its first row, and every
    other column (except ``sequence_counter`` and ``row_id``) becomes a numpy
    array of that sequence's per-timestep values.

    Args:
        df: long-format frame with at least ``sequence_id`` and ``sequence_counter``.
        static_cols: names treated as per-sequence constants; defaults to
            ``cfg.static_cols`` (read at call time). Names missing from ``df``
            are ignored.

    Returns:
        DataFrame with one row per sequence, sorted by ``sequence_id``. Its
        columns are ``sequence_id``, the present static columns, then the
        sequential columns. An empty input gives an empty frame with those columns.
    """
    if static_cols is None:
        static_cols = cfg.static_cols
    static_cols = list(static_cols)

    # sequence_id is the grouping key, so it is never split into per-step arrays.
    excluded = set(static_cols) | {"sequence_id", "sequence_counter", "row_id"}
    seq_cols = [c for c in df.columns if c not in excluded]
    present_static = [c for c in static_cols if c in df.columns]

    df = df.sort_values(["sequence_id", "sequence_counter"]).reset_index(drop=True)

    # After sorting, a sequence starts wherever sequence_id changes (O(n), no
    # extra sort). Written so 0- and 1-row inputs need no special case.
    seq_ids = df["sequence_id"].values
    n_rows = len(seq_ids)
    is_start = np.ones(n_rows, dtype=bool)
    is_start[1:] = seq_ids[1:] != seq_ids[:-1]
    starts = np.flatnonzero(is_start)

    res = {"sequence_id": seq_ids[starts]}
    for c in present_static:
        res[c] = df[c].values[starts]
    for c in seq_cols:
        # np.split on an empty array yields one spurious empty chunk, which
        # would mismatch the zero-length static columns.
        res[c] = np.split(df[c].values, starts[1:]) if n_rows else []

    return pd.DataFrame(res)

In [158]:
def le(df, target_col=None):
    """Label-encode the gesture column in place and return the same frame.

    Codes 0-7 are the classes ``comp_metric`` treats as targets (``<= 7``);
    codes 8-17 are pooled together as non-target there.

    Args:
        df: frame holding the string labels; modified in place.
        target_col: column to encode; defaults to ``cfg.main_target``.

    Returns:
        ``df`` itself, with ``target_col`` mapped to integer codes.

    Raises:
        ValueError: if any value has no code, e.g. an unknown label or a
            column that is already encoded (re-running the cell). Without
            this check, those values would silently become NaN. ``df`` is left
            unchanged when this is raised.
    """
    if target_col is None:
        target_col = cfg.main_target

    mapper_main = {
        "Above ear - pull hair": 0,
        "Cheek - pinch skin": 1,
        "Eyebrow - pull hair": 2,
        "Eyelash - pull hair": 3,
        "Forehead - pull hairline": 4,
        "Forehead - scratch": 5,
        "Neck - pinch skin": 6,
        "Neck - scratch": 7,
        "Drink from bottle/cup": 8,
        "Feel around in tray and pull out an object": 9,
        "Glasses on/off": 10,
        "Pinch knee/leg skin": 11,
        "Pull air toward your face": 12,
        "Scratch knee/leg skin": 13,
        "Text on phone": 14,
        "Wave hello": 15,
        "Write name in air": 16,
        "Write name on leg": 17,
    }

    encoded = df[target_col].map(mapper_main)
    unmapped = df.loc[encoded.isna(), target_col].unique()
    if len(unmapped):
        raise ValueError(
            f"le(): {len(unmapped)} unmapped label(s) in column {target_col!r}, "
            f"e.g. {list(unmapped[:5])!r}"
        )

    df[target_col] = encoded
    return df

In [159]:
def comp_metric(y_true, y_pred):
    """Competition score: the mean of a binary F1 and a collapsed macro F1.

    Labels ``<= 7`` are target classes. The binary F1 scores target vs.
    non-target. The macro F1 keeps each target class separate and pools every
    other label into one class (``99``).

    Returns:
        ``(score, binary_f1, macro_f1)``, where ``score`` is the mean of the two.
    """
    truth = np.array(y_true)
    guess = np.array(y_pred)
    truth_is_target = truth <= 7
    guess_is_target = guess <= 7

    binary_f1 = f1_score(
        np.where(truth_is_target, 1, 0),
        np.where(guess_is_target, 1, 0),
        zero_division=0.0,
    )
    macro_f1 = f1_score(
        np.where(truth_is_target, truth, 99),
        np.where(guess_is_target, guess, 99),
        average="macro",
        zero_division=0.0,
    )

    return (binary_f1 + macro_f1) / 2, binary_f1, macro_f1

In [160]:
def seed_everything(seed: int = cfg.seed):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = True

In [161]:
def train_epoch(train_loader, model, optimizer, main_criterion, device, scheduler, current_step=0, fold=None):
    """Run one optimisation pass over ``train_loader``.

    For each batch:
    1. Move every tensor in the batch dict to ``device`` and forward ``batch["imu"]``.
    2. Compute ``main_criterion`` against ``batch["main_target"]`` and backpropagate.
    3. Clip the global gradient norm to ``cfg.max_norm``.
    4. Step the optimizer, then the per-batch scheduler.

    Args:
        current_step: running global step counter, incremented once per batch.
        fold: unused; accepted for call-site compatibility.

    Returns:
        ``(avg_loss, avg_f1, binary_f1, macro_f1, current_step)``. The loss is
        the sample-weighted mean; the scores come from ``comp_metric`` on the
        argmax predictions.
    """
    model.train()

    loss_sum = 0
    n_seen = 0
    y_true, y_hat = [], []

    progress = tqdm(train_loader, desc="train", leave=False)
    for batch in progress:
        optimizer.zero_grad()

        for name in batch:
            batch[name] = batch[name].to(device)

        logits = model(batch["imu"])
        labels = batch["main_target"]
        loss = main_criterion(logits, labels)
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=cfg.max_norm)

        optimizer.step()
        scheduler.step()

        batch_loss = loss.item()
        batch_size = labels.size(0)
        loss_sum += batch_loss * batch_size
        n_seen += batch_size

        y_true.extend(labels.cpu().numpy())
        y_hat.extend(logits.argmax(dim=1).cpu().numpy())

        progress.set_postfix(loss=batch_loss)
        current_step += 1

    avg_loss = loss_sum / n_seen
    avg_m, bm, mm = comp_metric(y_true, y_hat)
    return avg_loss, avg_m, bm, mm, current_step

In [162]:
def valid_epoch(val_loader, model, main_criterion, device):
    """Evaluate ``model`` on ``val_loader`` in eval mode without gradients.

    Returns:
        ``(avg_loss, avg_f1, binary_f1, macro_f1, targets, preds)``. The loss
        is the sample-weighted mean, the scores come from ``comp_metric``, and
        ``targets``/``preds`` are per-sample label lists in loader order.
    """
    model.eval()

    loss_sum = 0
    n_seen = 0
    y_true, y_hat = [], []

    with torch.no_grad():
        progress = tqdm(val_loader, desc="val", leave=False)
        for batch in progress:
            batch = {name: tensor.to(device) for name, tensor in batch.items()}

            logits = model(batch["imu"])
            labels = batch["main_target"]
            batch_loss = main_criterion(logits, labels).item()
            batch_size = labels.size(0)

            loss_sum += batch_loss * batch_size
            n_seen += batch_size

            y_true.extend(labels.cpu().numpy())
            y_hat.extend(logits.argmax(dim=1).cpu().numpy())

            progress.set_postfix(loss=batch_loss)

    avg_loss = loss_sum / n_seen
    avg_m, bm, mm = comp_metric(y_true, y_hat)
    return avg_loss, avg_m, bm, mm, y_true, y_hat

In [163]:
def _seed_worker(worker_id):
    """Seed NumPy in a DataLoader worker (module-level, so it stays picklable under "spawn")."""
    np.random.seed(cfg.seed + worker_id)


def _make_loader(dataset, shuffle):
    """Build a DataLoader with the settings shared by the train and validation splits."""
    # Both loaders use the shared generator `g`: DataLoader draws its base seed
    # from it, so keeping it on the val loader preserves the random stream.
    return DataLoader(
        dataset,
        batch_size=cfg.bs,
        shuffle=shuffle,
        pin_memory=True,
        persistent_workers=True,
        prefetch_factor=4,
        num_workers=4,
        generator=g,
        worker_init_fn=_seed_worker,
    )


def _prune_checkpoints(checkpoints, keep=5):
    """Return ``checkpoints`` sorted best-first, deleting the files of all but the top ``keep``."""
    ranked = sorted(checkpoints, key=lambda c: c["score"], reverse=True)
    for stale in ranked[keep:]:
        if os.path.exists(stale["model_path"]):
            os.remove(stale["model_path"])
    return ranked[:keep]


def _predict_logits(model, loader, device):
    """Return the model's raw outputs for every batch of ``loader``, concatenated in order."""
    model.eval()
    chunks = []
    with torch.no_grad():
        for batch in loader:
            chunks.append(model(batch["imu"].to(device)).cpu().numpy())
    return np.concatenate(chunks, axis=0)


def run_training_with_stratified_group_kfold():
    """Train one model per StratifiedGroupKFold fold and collect out-of-fold logits.

    Reads the module-level ``train_seq`` (one row per sequence), ``g`` (the
    DataLoader generator) and ``device``; these are defined in later cells and
    must exist before this is called.

    For each fold:
    1. Train for up to ``cfg.n_epochs`` epochs with AdamW and a cosine schedule with warm-up.
    2. After every epoch, save a checkpoint and keep only the 5 best on disk,
       ranked by validation avg F1.
    3. Stop early after ``cfg.patience`` epochs without improvement.
    4. Reload the best checkpoint and predict the validation fold.

    Side effects:
        Writes checkpoints to ``cfg.model_dir``. Writes ``oof_preds.npy``,
        ``oof_targets.npy``, ``oof_pred_labels.npy`` and ``oof_info.json`` to
        ``cfg.oof_dir``.

    Returns:
        ``(best_models, oof_preds)``: the per-fold models holding their best
        weights, and a ``(len(train_seq), cfg.main_num_classes)`` array of
        out-of-fold logits.
    """
    os.makedirs(cfg.model_dir, exist_ok=True)
    os.makedirs(cfg.oof_dir, exist_ok=True)

    sgkf = StratifiedGroupKFold(n_splits=cfg.n_splits, shuffle=True, random_state=cfg.seed)
    targets = train_seq[cfg.main_target].values
    groups = train_seq[cfg.group].values

    oof_preds = np.zeros((len(train_seq), cfg.main_num_classes))
    oof_targets = targets

    best_models = []
    best_f1_scores = []

    # torch.compile stays disabled, as in the original Kaggle P100 setup;
    # set this to True on hardware where compilation is supported.
    use_torch_compile = False

    for fold, (train_idx, val_idx) in enumerate(sgkf.split(train_seq, targets, groups)):
        print(f"fold {fold + 1}/{cfg.n_splits}")
        # To train a single fold: `if fold != cfg.curr_fold: continue`

        train_subset = train_seq.iloc[train_idx].reset_index(drop=True)
        val_subset = train_seq.iloc[val_idx].reset_index(drop=True)

        train_loader = _make_loader(
            TS_CMIDataset(dataframe=train_subset, seq_len=cfg.seq_len, main_target=cfg.main_target),
            shuffle=True,
        )
        val_loader = _make_loader(
            TS_CMIDataset(dataframe=val_subset, seq_len=cfg.seq_len, main_target=cfg.main_target),
            shuffle=False,
        )

        model = IMG_CMIModel().to(device)
        if use_torch_compile:
            model = torch.compile(model, mode="max-autotune")

        optimizer = AdamW(model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay)
        main_criterion = nn.CrossEntropyLoss()

        # The scheduler is stepped once per batch inside train_epoch.
        num_training_steps = cfg.n_epochs * len(train_loader)
        scheduler = get_cosine_schedule_with_warmup(
            optimizer=optimizer,
            num_warmup_steps=int(cfg.num_warmup_steps_ratio * num_training_steps),
            num_training_steps=num_training_steps,
        )
        current_step = 0

        best_val_score = -np.inf
        patience_counter = 0
        fold_checkpoints = []

        for epoch in range(cfg.n_epochs):
            print(f"{epoch=}")

            # Keywords matter here: passed positionally, `fold` landed in the
            # `current_step` slot and the step counter restarted every epoch.
            train_loss, avg_m_train, bm_train, mm_train, current_step = train_epoch(
                train_loader,
                model,
                optimizer,
                main_criterion,
                device,
                scheduler,
                current_step=current_step,
                fold=fold,
            )
            val_loss, avg_m_val, bm_val, mm_val, _, _ = valid_epoch(val_loader, model, main_criterion, device)

            print(f"{train_loss=}, {avg_m_train=}, {bm_train=}, {mm_train=},")
            print(f"{val_loss=}, {avg_m_val=}, {bm_val=}, {mm_val=}")

            model_path = os.path.join(cfg.model_dir, f"model_fold{fold}_val_f1_{avg_m_val:.4f}_epoch{epoch:03d}.pt")
            torch.save(model.state_dict(), model_path)
            fold_checkpoints.append({"score": avg_m_val, "epoch": epoch, "model_path": model_path})
            fold_checkpoints = _prune_checkpoints(fold_checkpoints, keep=5)

            if avg_m_val > best_val_score:
                best_val_score = avg_m_val
                patience_counter = 0
            else:
                patience_counter += 1

            if patience_counter >= cfg.patience:
                print("early stopping")
                break

        # fold_checkpoints is sorted best-first; it is empty only if no epoch ran
        # (indexing it unconditionally would raise IndexError in that case).
        if fold_checkpoints:
            best_state = torch.load(fold_checkpoints[0]["model_path"], map_location=device, weights_only=True)
            model.load_state_dict(best_state)

        best_models.append(model)
        best_f1_scores.append(best_val_score)

        oof_preds[val_idx] = _predict_logits(model, val_loader, device)

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    oof_pred_labels = np.argmax(oof_preds, axis=1)
    oof_m, oof_bm, oof_mm = comp_metric(oof_targets, oof_pred_labels)
    print(f"{oof_m=}, {oof_bm=}, {oof_mm=}, ")

    np.save(os.path.join(cfg.oof_dir, "oof_preds.npy"), oof_preds)
    np.save(os.path.join(cfg.oof_dir, "oof_targets.npy"), oof_targets)
    np.save(os.path.join(cfg.oof_dir, "oof_pred_labels.npy"), oof_pred_labels)

    # Cast to built-in floats so json never sees a non-float NumPy scalar.
    oof_info = {
        "oof_avg_f1": float(oof_m),
        "oof_binary_f1": float(oof_bm),
        "oof_macro_f1": float(oof_mm),
        "best_avg_f1_scores_per_fold": [float(s) for s in best_f1_scores],
        "mean_cv_avg_f1": float(np.mean(best_f1_scores)),
        "std_cv_avg_f1": float(np.std(best_f1_scores)),
    }

    oof_info_path = os.path.join(cfg.oof_dir, "oof_info.json")
    with open(oof_info_path, "w") as f:
        json.dump(oof_info, f, indent=2)

    return best_models, oof_preds

In [164]:
# Seed every RNG up front. `g` is the generator shared by all fold DataLoaders.
seed_everything(cfg.seed)
g = torch.Generator(device="cpu")
g.manual_seed(cfg.seed)

In [165]:
# Raw training CSV; its rows are grouped per sequence_id by fast_seq_agg below.
train = pd.read_csv(cfg.train_path)

In [166]:
# Encode labels only while they are still strings: `le` rewrites `train` in
# place, so re-running this cell would otherwise re-map the already-encoded
# integer labels.
if not pd.api.types.is_numeric_dtype(train[cfg.main_target]):
    train = le(train)
train_seq = fast_seq_agg(train)

In [167]:
# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [168]:
# Full cross-validation run (cfg.n_splits folds): checkpoints go to cfg.model_dir,
# OOF arrays and the metrics JSON to cfg.oof_dir.
best_models, oof_preds = run_training_with_stratified_group_kfold()

fold 1/5
epoch=0


                                                                 

train_loss=2.970078431064226, avg_m_train=0.3875060245284323, bm_train=0.6505708848715509, mm_train=0.12444116418531365,
val_loss=2.5684060338904096, avg_m_val=0.4458647497956081, bm_val=0.7758700696055685, mm_val=0.11585942998564772
epoch=1


                                                                 

train_loss=2.620968243875263, avg_m_train=0.4320659928550969, bm_train=0.7042074488369137, mm_train=0.15992453687328012,
val_loss=2.405199665049608, avg_m_val=0.4961886763344045, bm_val=0.8252013263855993, mm_val=0.1671760262832097
epoch=2


                                                                 

train_loss=2.3281754471826246, avg_m_train=0.49423230797614226, bm_train=0.7845208845208845, mm_train=0.20394373143140004,
val_loss=2.260236127214282, avg_m_val=0.436338684143031, bm_val=0.7091346153846154, mm_val=0.16354275290144657
epoch=3


                                                                 

train_loss=2.1436343612797613, avg_m_train=0.5242841147251216, bm_train=0.8203373258372036, mm_train=0.22823090361303955,
val_loss=2.061369274923315, avg_m_val=0.5417707868748466, bm_val=0.857293868921776, mm_val=0.22624770482791723
epoch=4


                                                                 

train_loss=1.9301973270646362, avg_m_train=0.5700334423309343, bm_train=0.8625429553264605, mm_train=0.27752392933540804,
val_loss=1.7408432473686977, avg_m_val=0.6089346515416187, bm_val=0.9110651499482937, mm_val=0.3068041531349437
epoch=5


                                                                 

train_loss=1.7673310538027989, avg_m_train=0.6056742423320608, bm_train=0.8916340508806262, mm_train=0.31971443378349546,
val_loss=1.7093125538052065, avg_m_val=0.6185239751534326, bm_val=0.9096446700507614, mm_val=0.3274032802561037
epoch=6


                                                                 

train_loss=1.6320012746288501, avg_m_train=0.628083650948097, bm_train=0.9165956410568611, mm_train=0.339571660839333,
val_loss=1.626168594310421, avg_m_val=0.6321466391176618, bm_val=0.9242183359830418, mm_val=0.3400749422522818
epoch=7


                                                                 

train_loss=1.515561062206445, avg_m_train=0.6520410385450655, bm_train=0.9318654434250765, mm_train=0.3722166336650543,
val_loss=1.44750104961595, avg_m_val=0.6667998934732787, bm_val=0.947261663286004, mm_val=0.38633812366055337
epoch=8


                                                                 

train_loss=1.380582689824969, avg_m_train=0.6775600507657217, bm_train=0.9462209302325582, mm_train=0.40889917129888526,
val_loss=1.492293016448695, avg_m_val=0.6739664338236533, bm_val=0.9458631256384066, mm_val=0.40206974200890006
epoch=9


                                                                 

train_loss=1.2693782211807978, avg_m_train=0.7020714856837609, bm_train=0.9585498489425982, mm_train=0.44559312242492344,
val_loss=1.50287403236509, avg_m_val=0.6595330030621944, bm_val=0.9340482573726542, mm_val=0.38501774875173456
epoch=10


                                                                 

train_loss=1.1665306929175485, avg_m_train=0.7218247442656035, bm_train=0.9631604459524964, mm_train=0.4804890425787106,
val_loss=1.4422335381283187, avg_m_val=0.697001391230369, bm_val=0.9470899470899471, mm_val=0.4469128353707908
epoch=11


                                                                  

train_loss=1.0189905037571076, avg_m_train=0.7513215680725285, bm_train=0.971656976744186, mm_train=0.530986159400871,
val_loss=1.4575936887900867, avg_m_val=0.6782146439399352, bm_val=0.9513070220399795, mm_val=0.4051222658398909
epoch=12


                                                                  

train_loss=0.9187415484105571, avg_m_train=0.7779255128638161, bm_train=0.9801692865779927, mm_train=0.5756817391496395,
val_loss=1.6123106616953904, avg_m_val=0.6826076378007988, bm_val=0.9398220826792255, mm_val=0.42539319292237215
epoch=13


                                                                  

train_loss=0.7390435466715239, avg_m_train=0.8199524381729777, bm_train=0.988056460369164, mm_train=0.6518484159767916,
val_loss=1.6172711580835712, avg_m_val=0.6991445729367536, bm_val=0.9528795811518325, mm_val=0.44540956472167476
epoch=14


                                                                  

train_loss=0.6182145302018136, avg_m_train=0.8524997430172421, bm_train=0.9893848009650181, mm_train=0.7156146850694661,
val_loss=1.6276367765446609, avg_m_val=0.6974430678372907, bm_val=0.9528301886792453, mm_val=0.44205594699533624
epoch=15


                                                                  

train_loss=0.43413412831943055, avg_m_train=0.8963857417978812, bm_train=0.9942154736080984, mm_train=0.7985560099876641,
val_loss=2.000461612696423, avg_m_val=0.6926901234611662, bm_val=0.9566115702479339, mm_val=0.4287686766743985
epoch=16


                                                                  

train_loss=0.32041807523553834, avg_m_train=0.9223637169425725, bm_train=0.994942196531792, mm_train=0.8497852373533529,
val_loss=2.3270598344153757, avg_m_val=0.6888082852738834, bm_val=0.9441517386722866, mm_val=0.4334648318754801
epoch=17


                                                                  

train_loss=0.2512119218097359, avg_m_train=0.941761982899415, bm_train=0.9960245753523672, mm_train=0.8874993904464628,
val_loss=2.3526468264494893, avg_m_val=0.6856885062959749, bm_val=0.9488517745302714, mm_val=0.4225252380616784
epoch=18


                                                                  

train_loss=0.18961622309157375, avg_m_train=0.9582063418761966, bm_train=0.9963846710050615, mm_train=0.9200280127473318,
val_loss=2.652000518369425, avg_m_val=0.6968406603134287, bm_val=0.9556701030927836, mm_val=0.4380112175340739
early stopping
fold 2/5
epoch=0


                                                                 

train_loss=3.0175282473139173, avg_m_train=0.3948230550936003, bm_train=0.6699354221478115, mm_train=0.11971068803938906,
val_loss=2.6149880699082915, avg_m_val=0.3756951379036487, bm_val=0.6490735541830432, mm_val=0.10231672162425423
epoch=1


                                                                 

train_loss=2.6426258880789613, avg_m_train=0.43496397328486647, bm_train=0.7131954702117184, mm_train=0.15673247635801463,
val_loss=2.3126407698089, avg_m_val=0.4151577251258204, bm_val=0.6431181485992692, mm_val=0.18719730165237164
epoch=2


                                                                 

train_loss=2.280160304647517, avg_m_train=0.5018696657781191, bm_train=0.7936707271128971, mm_train=0.21006860444334116,
val_loss=2.1061875072180056, avg_m_val=0.5128955787936141, bm_val=0.7843986998916577, mm_val=0.24139245769557066
epoch=3


                                                                 

train_loss=1.980233474786451, avg_m_train=0.5618897149142155, bm_train=0.8645950330712592, mm_train=0.2591843967571717,
val_loss=1.9578259388605754, avg_m_val=0.5890662212100355, bm_val=0.8574380165289256, mm_val=0.3206944258911455
epoch=4


                                                                 

train_loss=1.8579667685569146, avg_m_train=0.5879207623410894, bm_train=0.8811918063314711, mm_train=0.2946497183507075,
val_loss=1.8644541642245125, avg_m_val=0.5825213136121845, bm_val=0.8568507157464212, mm_val=0.3081919114779479
epoch=5


                                                                 

train_loss=1.7086644136278446, avg_m_train=0.6099520307678743, bm_train=0.9006242197253433, mm_train=0.3192798418104052,
val_loss=1.8427469520007862, avg_m_val=0.5760391982491094, bm_val=0.8456659619450317, mm_val=0.3064124345531872
epoch=6


                                                                 

train_loss=1.5521696478626739, avg_m_train=0.6468048532775751, bm_train=0.9276258814796486, mm_train=0.3659838250755015,
val_loss=1.9494335931890152, avg_m_val=0.5858141883562886, bm_val=0.8543689320388349, mm_val=0.31725944467374245
epoch=7


                                                                 

train_loss=1.4475643678317864, avg_m_train=0.6654563357415144, bm_train=0.9425713227121156, mm_train=0.3883413487709131,
val_loss=1.6400898227504654, avg_m_val=0.6430852407234322, bm_val=0.9016227180527383, mm_val=0.38454776339412605
epoch=8


                                                                 

train_loss=1.3009663790825559, avg_m_train=0.6960862846898528, bm_train=0.9551944854751354, mm_train=0.43697808390457005,
val_loss=1.725272353957681, avg_m_val=0.6344944217162056, bm_val=0.8836477987421384, mm_val=0.38534104469027286
epoch=9


                                                                  

train_loss=1.1472883577415558, avg_m_train=0.7262075708782365, bm_train=0.9677657480314961, mm_train=0.48464939372497706,
val_loss=1.667509957855823, avg_m_val=0.651380715982012, bm_val=0.9211700545364403, mm_val=0.38159137742758376
epoch=10


                                                                  

train_loss=1.0508998395804523, avg_m_train=0.7502973500653918, bm_train=0.9758074419747021, mm_train=0.5247872581560815,
val_loss=1.8626521147933661, avg_m_val=0.6253206722234095, bm_val=0.8989473684210526, mm_val=0.35169397602576646
epoch=11


                                                                  

train_loss=0.9256875243339094, avg_m_train=0.7762111231648594, bm_train=0.9801275760549558, mm_train=0.572294670274763,
val_loss=2.044174820769067, avg_m_val=0.6107434162192172, bm_val=0.882921589688507, mm_val=0.3385652427499275
epoch=12


                                                                  

train_loss=0.8252246792394337, avg_m_train=0.8023918882515539, bm_train=0.9850343473994112, mm_train=0.6197494291036967,
val_loss=1.8242149423150456, avg_m_val=0.6560986132739859, bm_val=0.9223205506391348, mm_val=0.3898766759088369
epoch=13


                                                                  

train_loss=0.6131560997301714, avg_m_train=0.8483842914577571, bm_train=0.9894736842105263, mm_train=0.7072948987049879,
val_loss=2.281434592078714, avg_m_val=0.6253186464783786, bm_val=0.9062181447502549, mm_val=0.34441914820650255
epoch=14


                                                                  

train_loss=0.534334380100212, avg_m_train=0.8683892988955921, bm_train=0.9874877330716388, mm_train=0.7492908647195454,
val_loss=2.228264995649749, avg_m_val=0.661980109162173, bm_val=0.9199803632793323, mm_val=0.4039798550450136
epoch=15


                                                                  

train_loss=0.3458686455670521, avg_m_train=0.9190500256874428, bm_train=0.9957207482577333, mm_train=0.8423793031171523,
val_loss=2.696786361582139, avg_m_val=0.6411184651424248, bm_val=0.9129770992366413, mm_val=0.36925983104820825
epoch=16


                                                                  

train_loss=0.29611478821597487, avg_m_train=0.9292308548671147, bm_train=0.9959613266430057, mm_train=0.8625003830912235,
val_loss=2.83736522992452, avg_m_val=0.6450503580967619, bm_val=0.9131538852209243, mm_val=0.3769468309725994
epoch=17


                                                                  

train_loss=0.21948470131378814, avg_m_train=0.9496050006225882, bm_train=0.9975538160469667, mm_train=0.9016561851982097,
val_loss=2.8957901515212714, avg_m_val=0.6452220860167155, bm_val=0.9226907630522089, mm_val=0.3677534089812221
epoch=18


                                                                   

train_loss=0.15822797259371565, avg_m_train=0.9642798992106372, bm_train=0.9962070231249235, mm_train=0.9323527752963509,
val_loss=3.0950391853556916, avg_m_val=0.648103973728004, bm_val=0.9164179104477612, mm_val=0.3797900370082467
epoch=19


                                                                   

train_loss=0.11959325397379432, avg_m_train=0.9765343322161054, bm_train=0.9965761799951088, mm_train=0.956492484437102,
val_loss=3.4124203523000083, avg_m_val=0.6406798103873252, bm_val=0.9157472417251755, mm_val=0.3656123790494749
early stopping
fold 3/5
epoch=0


                                                                 

train_loss=3.080847628453448, avg_m_train=0.3919468429935111, bm_train=0.6665084826195278, mm_train=0.11738520336749442,
val_loss=2.698048718232375, avg_m_val=0.4265932196717681, bm_val=0.7658662092624356, mm_val=0.08732023008110051
epoch=1


                                                                 

train_loss=2.6757369889023477, avg_m_train=0.4420646999793449, bm_train=0.7293296425562643, mm_train=0.1547997574024255,
val_loss=2.5606718595944917, avg_m_val=0.3816671622788058, bm_val=0.6376811594202898, mm_val=0.12565316513732183
epoch=2


                                                                 

train_loss=2.4660954707756826, avg_m_train=0.4715246815357111, bm_train=0.7618322547900013, mm_train=0.18121710828142093,
val_loss=2.35240916633606, avg_m_val=0.38462482874228854, bm_val=0.6083123425692695, mm_val=0.16093731491530766
epoch=3


                                                                 

train_loss=2.198534878930394, avg_m_train=0.5145566142759812, bm_train=0.8134019324883925, mm_train=0.21571129606356976,
val_loss=2.1968914466271032, avg_m_val=0.5286131728968497, bm_val=0.8280450358239508, mm_val=0.22918130996974856
epoch=4


                                                                 

train_loss=1.9677621922608428, avg_m_train=0.5617647822887631, bm_train=0.8590237237610234, mm_train=0.26450584081650286,
val_loss=1.9889327692618737, avg_m_val=0.5706532921193006, bm_val=0.8606183476938672, mm_val=0.28068823654473396
epoch=5


                                                                 

train_loss=1.7936789397991693, avg_m_train=0.5972867831281844, bm_train=0.8920863309352518, mm_train=0.302487235321117,
val_loss=1.8625640327013455, avg_m_val=0.5764697651348604, bm_val=0.8699271592091571, mm_val=0.28301237106056365
epoch=6


                                                                 

train_loss=1.6665985543341348, avg_m_train=0.6179452097391458, bm_train=0.9068506653523903, mm_train=0.32903975412590125,
val_loss=1.7873140371029193, avg_m_val=0.5964707076410254, bm_val=0.8771750255885363, mm_val=0.3157663896935144
epoch=7


                                                                 

train_loss=1.511274488542915, avg_m_train=0.6570759222402701, bm_train=0.9336797822321208, mm_train=0.3804720622484195,
val_loss=1.7062798019555898, avg_m_val=0.6187626594301221, bm_val=0.9182879377431906, mm_val=0.31923738111705346
epoch=8


                                                                 

train_loss=1.3962102105491117, avg_m_train=0.6754521049288924, bm_train=0.9414531501423443, mm_train=0.4094510597154405,
val_loss=1.851147768167349, avg_m_val=0.6228207814983567, bm_val=0.9067490984028851, mm_val=0.33889246459382844
epoch=9


                                                                 

train_loss=1.2881455561006285, avg_m_train=0.6930962175325686, bm_train=0.9542709232096636, mm_train=0.4319215118554735,
val_loss=1.7408592448601357, avg_m_val=0.6291095485487631, bm_val=0.9293608841902932, mm_val=0.32885821290723305
epoch=10


                                                                  

train_loss=1.1914650757108918, avg_m_train=0.7192413681988394, bm_train=0.9619411257544033, mm_train=0.4765416106432757,
val_loss=1.7099907430502084, avg_m_val=0.6455821836724224, bm_val=0.9228441754916793, mm_val=0.3683201918531654
epoch=11


                                                                  

train_loss=1.076175738067814, avg_m_train=0.7413835718267655, bm_train=0.9735099337748344, mm_train=0.5092572098786966,
val_loss=1.7460321214382466, avg_m_val=0.6350719002242382, bm_val=0.9265143992055611, mm_val=0.3436294012429152
epoch=12


                                                                  

train_loss=0.9902703933945252, avg_m_train=0.7668018693451603, bm_train=0.9798525798525799, mm_train=0.5537511588377406,
val_loss=1.7611452328608586, avg_m_val=0.6414984915454501, bm_val=0.9172168613509396, mm_val=0.36578012173996066
epoch=13


                                                                  

train_loss=0.8690564529938892, avg_m_train=0.7884148905029039, bm_train=0.9830674846625767, mm_train=0.5937622963432312,
val_loss=1.7693878092399011, avg_m_val=0.6502501661768281, bm_val=0.9300291545189504, mm_val=0.3704711778347056
epoch=14


                                                                  

train_loss=0.74078660613366, avg_m_train=0.8193048076264721, bm_train=0.9883763611892817, mm_train=0.6502332540636624,
val_loss=2.0310716361265917, avg_m_val=0.6492104629432701, bm_val=0.9162067235323633, mm_val=0.3822142023541769
epoch=15


                                                                  

train_loss=0.6083665415439143, avg_m_train=0.8499148087282398, bm_train=0.9898371495041018, mm_train=0.7099924679523777,
val_loss=2.191507838029128, avg_m_val=0.6552017577670004, bm_val=0.9326171875, mm_val=0.37778632803400086
epoch=16


                                                                  

train_loss=0.48171027030749003, avg_m_train=0.8800806171296593, bm_train=0.992545521202493, mm_train=0.7676157130568257,
val_loss=2.403765301484328, avg_m_val=0.6572646728330291, bm_val=0.9265426052889324, mm_val=0.38798674037712594
epoch=17


                                                                  

train_loss=0.36972297679197924, avg_m_train=0.9121784073302385, bm_train=0.9932754615478665, mm_train=0.8310813531126104,
val_loss=2.5059396054194525, avg_m_val=0.6422366377776718, bm_val=0.9233716475095786, mm_val=0.36110162804576496
epoch=18


                                                                  

train_loss=0.2760834902150478, avg_m_train=0.9374241388602758, bm_train=0.996826171875, mm_train=0.8780221058455514,
val_loss=2.7709021315941444, avg_m_val=0.6552072635933206, bm_val=0.9217221135029354, mm_val=0.3886924136837057
epoch=19


                                                                  

train_loss=0.19357309994834937, avg_m_train=0.9544320452077855, bm_train=0.997071027581157, mm_train=0.911793062834414,
val_loss=3.048928103373601, avg_m_val=0.6499387210781873, bm_val=0.9298500241896468, mm_val=0.3700274179667277
epoch=20


                                                                  

train_loss=0.1537520718622931, avg_m_train=0.9629571901661512, bm_train=0.9975574010747436, mm_train=0.9283569792575589,
val_loss=3.3764411067962645, avg_m_val=0.6465940112635482, bm_val=0.9220338983050848, mm_val=0.37115412422201155
epoch=21


                                                                   

train_loss=0.14448251097782455, avg_m_train=0.9671362868113763, bm_train=0.9981687217677939, mm_train=0.9361038518549587,
val_loss=3.233098315018874, avg_m_val=0.652593947643784, bm_val=0.912850812407681, mm_val=0.39233708287988706
early stopping
fold 4/5
epoch=0


                                                                 

train_loss=3.0122476512073906, avg_m_train=0.4001066115890235, bm_train=0.6700920062134066, mm_train=0.13012121696464046,
val_loss=2.636300521738389, avg_m_val=0.3805576953199732, bm_val=0.6578669482576558, mm_val=0.10324844238229063
epoch=1


                                                                 

train_loss=2.68557313688512, avg_m_train=0.43900746877685665, bm_train=0.7183098591549296, mm_train=0.15970507839878367,
val_loss=2.3654188034581205, avg_m_val=0.4330249340544803, bm_val=0.7116228070175439, mm_val=0.15442706109141668
epoch=2


                                                                 

train_loss=2.410866508858393, avg_m_train=0.4783202426937743, bm_train=0.7654416505480335, mm_train=0.19119883483951514,
val_loss=2.1377025866040995, avg_m_val=0.5079924417708015, bm_val=0.817773339990015, mm_val=0.19821154355158802
epoch=3


                                                                 

train_loss=2.2046972307476427, avg_m_train=0.5202202178865151, bm_train=0.8162854988690625, mm_train=0.2241549369039676,
val_loss=2.0287340715819715, avg_m_val=0.5550309769777071, bm_val=0.864787386526517, mm_val=0.24527456742889728
epoch=4


                                                                 

train_loss=1.952546352037423, avg_m_train=0.5654082227719153, bm_train=0.863830858293481, mm_train=0.26698558725034943,
val_loss=1.7731016570446538, avg_m_val=0.5674894453659685, bm_val=0.8870636550308009, mm_val=0.24791523570113597
epoch=5


                                                                 

train_loss=1.8047909130803608, avg_m_train=0.5984489720722147, bm_train=0.8905935613682092, mm_train=0.30630438277622013,
val_loss=1.772482862659529, avg_m_val=0.5989457650126475, bm_val=0.889651790093183, mm_val=0.3082397399321121
epoch=6


                                                                 

train_loss=1.6779756888342627, avg_m_train=0.6214491015440468, bm_train=0.9050404480398258, mm_train=0.33785775504826776,
val_loss=1.6320841172162224, avg_m_val=0.6030727457455466, bm_val=0.8794102159031069, mm_val=0.32673527558798626
epoch=7


                                                                 

train_loss=1.5605852790982615, avg_m_train=0.6415962888658043, bm_train=0.9269817454363591, mm_train=0.35621083229524947,
val_loss=1.6706588011161954, avg_m_val=0.6383044190912246, bm_val=0.9221153846153847, mm_val=0.3544934535670645
epoch=8


                                                                 

train_loss=1.4060067483518985, avg_m_train=0.675765303040871, bm_train=0.9461833477669183, mm_train=0.4053472583148239,
val_loss=1.5453206814971625, avg_m_val=0.6509601386225544, bm_val=0.9318961293483586, mm_val=0.37002414789675026
epoch=9


                                                                 

train_loss=1.3239207693103778, avg_m_train=0.6864977021347501, bm_train=0.9547862510779844, mm_train=0.4182091531915158,
val_loss=1.6285149607003904, avg_m_val=0.6212584950026212, bm_val=0.9162462159434914, mm_val=0.326270774061751
epoch=10


                                                                 

train_loss=1.1884063466741073, avg_m_train=0.716570193998721, bm_train=0.9661037840502896, mm_train=0.46703660394715246,
val_loss=1.5808601005404603, avg_m_val=0.6484034987302177, bm_val=0.9264197530864198, mm_val=0.37038724437401577
epoch=11


                                                                  

train_loss=1.0691999137392438, avg_m_train=0.7398558614558247, bm_train=0.9709717097170971, mm_train=0.5087400131945523,
val_loss=1.5245495333391077, avg_m_val=0.6847161923466714, bm_val=0.9428975932043416, mm_val=0.42653479148900103
epoch=12


                                                                  

train_loss=0.9382727939231499, avg_m_train=0.7783510490137406, bm_train=0.9825766871165644, mm_train=0.574125410910917,
val_loss=1.562195382866205, avg_m_val=0.665495842272188, bm_val=0.9274193548387096, mm_val=0.40357232970566653
epoch=13


                                                                  

train_loss=0.8492955359272987, avg_m_train=0.7981381231172604, bm_train=0.9807338323720702, mm_train=0.6155424138624506,
val_loss=1.5926557405322206, avg_m_val=0.6789916876220995, bm_val=0.9394550958627649, mm_val=0.4185282793814341
epoch=14


                                                                  

train_loss=0.7259125243960979, avg_m_train=0.8238695142162418, bm_train=0.9874600442586673, mm_train=0.6602789841738164,
val_loss=1.6961167840396656, avg_m_val=0.6749246353435195, bm_val=0.9425061425061425, mm_val=0.4073431281808966
epoch=15


                                                                  

train_loss=0.5566107261160763, avg_m_train=0.8624620658886095, bm_train=0.9924075434729366, mm_train=0.7325165883042826,
val_loss=1.9001491864522297, avg_m_val=0.6796436012782512, bm_val=0.9507557289127255, mm_val=0.4085314736437768
epoch=16


                                                                  

train_loss=0.4069748452002696, avg_m_train=0.9024227835856058, bm_train=0.9957144606342598, mm_train=0.8091311065369516,
val_loss=2.1949165568632236, avg_m_val=0.668768662941376, bm_val=0.9322709163346613, mm_val=0.4052664095480908
early stopping
fold 5/5
epoch=0


                                                                 

train_loss=3.020294966161929, avg_m_train=0.40571911559882423, bm_train=0.6850789935634874, mm_train=0.12635923763416101,
val_loss=2.7655998310278305, avg_m_val=0.4144396144671693, bm_val=0.7762962962962963, mm_val=0.05258293263804234
epoch=1


                                                                 

train_loss=2.7196389024124277, avg_m_train=0.4253963105769015, bm_train=0.6953233329131476, mm_train=0.15546928824065548,
val_loss=2.744625991603495, avg_m_val=0.34581131482955607, bm_val=0.5966562173458725, mm_val=0.09496641231323967
epoch=2


                                                                 

train_loss=2.5765581784524034, avg_m_train=0.44073543253356895, bm_train=0.7152886115444618, mm_train=0.16618225352267613,
val_loss=2.4237796253811528, avg_m_val=0.3423737095423701, bm_val=0.5407725321888412, mm_val=0.14397488689589893
epoch=3


                                                                 

train_loss=2.29988223688614, avg_m_train=0.4926225884272301, bm_train=0.7786617973131603, mm_train=0.2065833795412999,
val_loss=2.173434287206524, avg_m_val=0.532436675001626, bm_val=0.8415349887133183, mm_val=0.2233383612899335
epoch=4


                                                                 

train_loss=2.082062268026785, avg_m_train=0.5330814323680888, bm_train=0.8218782791185729, mm_train=0.24428458561760477,
val_loss=1.8653691503828371, avg_m_val=0.5664561082225934, bm_val=0.8613625535969509, mm_val=0.27154966284823584
epoch=5


                                                                 

train_loss=1.9087905890476822, avg_m_train=0.5769554061491049, bm_train=0.8757803541852466, mm_train=0.27813045811296333,
val_loss=1.8763472647540436, avg_m_val=0.5467887202110416, bm_val=0.8314720812182741, mm_val=0.2621053592038092
epoch=6


                                                                 

train_loss=1.7544757363691088, avg_m_train=0.6055384822010512, bm_train=0.8997598280874731, mm_train=0.31131713631462943,
val_loss=1.703838436386302, avg_m_val=0.6041603046878848, bm_val=0.9228675136116152, mm_val=0.28545309576415445
epoch=7


                                                                 

train_loss=1.6337474143930513, avg_m_train=0.6277806123748567, bm_train=0.9141267231566966, mm_train=0.34143450159301686,
val_loss=1.6529521877515412, avg_m_val=0.6087693932246654, bm_val=0.8841698841698842, mm_val=0.3333689022794465
epoch=8


                                                                 

train_loss=1.5249130608267514, avg_m_train=0.6494416414780799, bm_train=0.927931339139215, mm_train=0.37095194381694474,
val_loss=1.7012025546862592, avg_m_val=0.6157313523080636, bm_val=0.9141592920353983, mm_val=0.31730341258072897
epoch=9


                                                                 

train_loss=1.4259817345186578, avg_m_train=0.6710998315696224, bm_train=0.9411617426139208, mm_train=0.40103792052532383,
val_loss=1.633458542438802, avg_m_val=0.6310262953630446, bm_val=0.9182624941616068, mm_val=0.3437900965644824
epoch=10


                                                                 

train_loss=1.334219566940128, avg_m_train=0.692545421226633, bm_train=0.950050454086781, mm_train=0.435040388366485,
val_loss=1.466841250004928, avg_m_val=0.6641670412318386, bm_val=0.9446730681298583, mm_val=0.38366101433381883
epoch=11


                                                                 

train_loss=1.2351690111491693, avg_m_train=0.7109970953136611, bm_train=0.9595795269678388, mm_train=0.4624146636594834,
val_loss=1.531120426789562, avg_m_val=0.6557038042030257, bm_val=0.938475665748393, mm_val=0.3729319426576582
epoch=12


                                                                  

train_loss=1.1304419582759886, avg_m_train=0.7357018301881622, bm_train=0.9665083729067733, mm_train=0.5048952874695511,
val_loss=1.5055307603624866, avg_m_val=0.6593086576806931, bm_val=0.928273947246645, mm_val=0.39034336811474124
epoch=13


                                                                  

train_loss=1.019066323525478, avg_m_train=0.7569208835336236, bm_train=0.9745826065287815, mm_train=0.5392591605384658,
val_loss=1.6354531348095358, avg_m_val=0.6612987881094169, bm_val=0.9300797747536368, mm_val=0.392517801465197
epoch=14


                                                                  

train_loss=0.920502627213573, avg_m_train=0.7759845709497378, bm_train=0.9775336994508238, mm_train=0.5744354424486519,
val_loss=1.6921590370015427, avg_m_val=0.6607000721942882, bm_val=0.9286035024696901, mm_val=0.3927966419188862
epoch=15


                                                                  

train_loss=0.8022805951978069, avg_m_train=0.8088399123850465, bm_train=0.9856662096472641, mm_train=0.6320136151228287,
val_loss=1.6488284223219927, avg_m_val=0.6731725276398385, bm_val=0.9427937915742793, mm_val=0.4035512637053976
epoch=16


                                                                  

train_loss=0.675384637981407, avg_m_train=0.8331872924743041, bm_train=0.9896598978447739, mm_train=0.6767146871038343,
val_loss=1.8247153855250093, avg_m_val=0.6835827917642137, bm_val=0.9383499546690843, mm_val=0.4288156288593431
epoch=17


                                                                  

train_loss=0.5560441906669159, avg_m_train=0.8698234952282722, bm_train=0.9911636589919104, mm_train=0.7484833314646341,
val_loss=2.0106090050117644, avg_m_val=0.668326632719142, bm_val=0.9361892012494422, mm_val=0.400464064188842
epoch=18


                                                                  

train_loss=0.43177175223688036, avg_m_train=0.8954052683404808, bm_train=0.9935307290370738, mm_train=0.7972798076438878,
val_loss=2.2584270512493974, avg_m_val=0.6671055613100879, bm_val=0.9302325581395349, mm_val=0.4039785644806409
epoch=19


                                                                  

train_loss=0.33312320549000973, avg_m_train=0.9221811785271874, bm_train=0.9931668530252206, mm_train=0.8511955040291542,
val_loss=2.3696280372734027, avg_m_val=0.6681124130098388, bm_val=0.941016333938294, mm_val=0.39520849208138364
epoch=20


                                                                  

train_loss=0.24128627789406942, avg_m_train=0.9435510460824666, bm_train=0.9947787170561909, mm_train=0.8923233751087424,
val_loss=2.749658064858724, avg_m_val=0.665593645457957, bm_val=0.9315315315315316, mm_val=0.39965575938438225
epoch=21


                                                                  

train_loss=0.17760244204152015, avg_m_train=0.9611590204922356, bm_train=0.996396172486641, mm_train=0.92592186849783,
val_loss=2.885305673605843, avg_m_val=0.6753359850403744, bm_val=0.9387199273717658, mm_val=0.41195204270898295
early stopping
oof_m=0.6785521357905444, oof_bm=0.9360093077370564, oof_mm=0.42109496384403233, 


In [170]:
# original oof_m=0.6603333109618074, oof_bm=0.9376289166093704, oof_mm=0.38303770531424436,
# pretrained timm/resnet50.a1_in1k oof_m=0.47988258895431113, oof_bm=0.7790515893694633, oof_mm=0.18071358853915895,
# not pretrained timm/resnet50.a1_in1k oof_m=0.6922992183479115, oof_bm=0.9439811635436084, oof_mm=0.4406172731522147,
# pretrained oof_m=0.530565182728787, oof_bm=0.8284389489953632, oof_mm=0.23269141646221086,
# not pretrained oof_m=0.6773831836576342, oof_bm=0.9327846364883402, oof_mm=0.4219817308269282,
# not pretrained resnest50d.in1k oof_m=0.6905919294402194, oof_bm=0.9451518119490695, oof_mm=0.4360320469313694,
# not pretrained oof_m=0.6785521357905444, oof_bm=0.9360093077370564, oof_mm=0.42109496384403233,