# Colab setup

In [1]:
# import sys
# # if "google.colab" in str(get_ipython()):
# ! {sys.executable} -m pip install pytorch-lifestream
# ! {sys.executable} -m pip install catboost
# ! {sys.executable} -m pip install torchmetrics

# Supervised task

## Prepare your data

- Use `Pyspark` in local or cluster mode for big dataset and `Pandas` for small.
- Split data into required parts (train, valid, test, ...).
- Use `ptls.preprocessing` for simple data preparation.
- Transform features to compatible format using `Pyspark` or `Pandas` functions.
You can also use `ptls.data_load.preprocessing` for common data transformation patterns.
- Split sequences to `ptls-data` format with `ptls.data_load.split_tools`. Save prepared data into `Parquet` format or
keep it in memory (`Pickle` also works).
- Use one of the available `ptls.data_load.datasets` to define input for the models.

In [2]:
import torch

import numpy as np
import pandas as pd
import torchmetrics
import pytorch_lightning as pl

from sklearn.metrics import confusion_matrix, accuracy_score
from sklearn.model_selection import train_test_split
from functools import partial
from ptls.frames import PtlsDataModule
from ptls.nn import TrxEncoder, RnnSeqEncoder, Head
from ptls.data_load.datasets import MemoryMapDataset
from ptls.preprocessing import PandasDataPreprocessor
from ptls.frames.supervised import SeqToTargetDataset, SequenceToTarget
from ptls.data_load.utils import collate_feature_dict
from ptls.frames.inference_module import InferenceModule
from ptls.frames.coles import CoLESModule
from ptls.frames.coles.multimodal_dataset import MultiModalIterableDataset
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping


In [3]:
from functools import partial
from datetime import timedelta
from time import time

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns
import catboost

import torch
import pytorch_lightning as pl
from torch.utils.data.dataloader import DataLoader
from pytorch_lightning.loggers import TensorBoardLogger

from sklearn.model_selection import train_test_split

from ptls.nn import TrxEncoder
from ptls.nn.seq_encoder.rnn_encoder import RnnEncoder
from ptls.frames import PtlsDataModule
from ptls.frames.coles import CoLESModule
from ptls.frames.coles.split_strategy import SampleSlices
from ptls.frames.coles.multimodal_dataset import MultiModalDataset
from ptls.frames.coles.multimodal_dataset import MultiModalIterableDataset
from ptls.frames.coles.multimodal_dataset import MultiModalSortTimeSeqEncoderContainer
from ptls.frames.coles.multimodal_inference_dataset import MultiModalInferenceDataset
from ptls.frames.coles.multimodal_inference_dataset import MultiModalInferenceIterableDataset
from ptls.frames.inference_module import InferenceModuleMultimodal
from ptls.data_load.iterable_processing import SeqLenFilter
from ptls.data_load import IterableProcessingDataset
from ptls.data_load.utils import collate_feature_dict
from ptls.data_load.datasets import MemoryMapDataset
from ptls.preprocessing import PandasDataPreprocessor

In [4]:
df_target = pd.read_csv(
    "https://huggingface.co/datasets/dllllb/age-group-prediction/resolve/main/train_target.csv?download=true"
)
df_target

Unnamed: 0,client_id,bins
0,24662,2
1,1046,0
2,34089,2
3,34848,1
4,47076,3
...,...,...
29995,14303,1
29996,22301,2
29997,25731,0
29998,16820,3


In [5]:
# df_target_train, df_target_test = train_test_split(
#     df_target, test_size=7000, stratify=df_target["bins"], random_state=142)
# df_target_train, df_target_valid = train_test_split(
#     df_target_train, test_size=3000, stratify=df_target_train["bins"], random_state=142)
# print("Split {} records to train: {}, valid: {}, test: {}".format(
#     *[
#       len(df)
#       for df in [df_target, df_target_train, df_target_valid, df_target_test]
#     ]
# ))

In [6]:
df_trx = pd.read_csv(
    "https://huggingface.co/datasets/dllllb/age-group-prediction/resolve/main/transactions_train.csv.gz?download=true",
    compression="gzip"
)
df_trx

Unnamed: 0,client_id,trans_date,small_group,amount_rur
0,33172,6,4,71.463
1,33172,6,35,45.017
2,33172,8,11,13.887
3,33172,9,11,15.983
4,33172,10,11,21.341
...,...,...,...,...
26450572,43300,727,25,7.602
26450573,43300,727,15,3.709
26450574,43300,727,1,6.448
26450575,43300,727,11,24.669


In [7]:
len(df_target)

30000

In [8]:
len(df_trx)

26450577

In [9]:
sourceA = df_trx[["client_id", "trans_date", "small_group"]]
sourceB = df_trx[["client_id", "trans_date", "amount_rur"]]

In [10]:
sourceA_drop_indices = np.random.choice(sourceA.index, int(1500000), replace=False)
sourceB_drop_indices = np.random.choice(sourceB.index, int(4500000), replace=False)

sourceA = sourceA.drop(sourceA_drop_indices).reset_index(drop=True)
sourceB = sourceB.drop(sourceB_drop_indices).reset_index(drop=True)

In [11]:
len(sourceA), len(sourceB)

(24950577, 21950577)

In [12]:
sourceA["trans_date"] = sourceA["trans_date"].apply(lambda x: x * 3600)
sourceB["trans_date"] = sourceB["trans_date"].apply(lambda x: x * 3600)

In [13]:
sourceA_preprocessor = PandasDataPreprocessor(
    col_id="client_id",
    col_event_time="trans_date",
    event_time_transformation="none",
    cols_category=["small_group"],
    return_records=False,
)

sourceB_preprocessor = PandasDataPreprocessor(
    col_id="client_id",
    col_event_time="trans_date",
    event_time_transformation="none",
    cols_numerical=["amount_rur"],
    return_records=False,
)

In [14]:
processed_sourceA = sourceA_preprocessor.fit_transform(sourceA)
processed_sourceB = sourceB_preprocessor.fit_transform(sourceB)

In [15]:
processed_sourceA.columns = [
    "sourceA_" + str(col) if str(col) != "client_id" else str(col)
    for col in processed_sourceA.columns
]

In [16]:
processed_sourceB.columns = [
    "sourceB_" + str(col) if str(col) != "client_id" else str(col)
    for col in processed_sourceB.columns
]

In [17]:
joined_data = processed_sourceA.merge(processed_sourceB, how="outer", on="client_id")

In [18]:
joined_data

Unnamed: 0,client_id,sourceA_trans_date,sourceA_event_time,sourceA_small_group,sourceB_trans_date,sourceB_event_time,sourceB_amount_rur
0,4,"[tensor(0), tensor(7200), tensor(10800), tenso...","[tensor(0), tensor(7200), tensor(10800), tenso...","[tensor(1), tensor(3), tensor(1), tensor(1), t...","[tensor(0), tensor(7200), tensor(10800), tenso...","[tensor(0), tensor(7200), tensor(10800), tenso...","[tensor(10.2090, dtype=torch.float64), tensor(..."
1,6,"[tensor(0), tensor(18000), tensor(36000), tens...","[tensor(0), tensor(18000), tensor(36000), tens...","[tensor(4), tensor(3), tensor(1), tensor(3), t...","[tensor(18000), tensor(36000), tensor(39600), ...","[tensor(18000), tensor(36000), tensor(39600), ...","[tensor(13.7380, dtype=torch.float64), tensor(..."
2,7,"[tensor(3600), tensor(7200), tensor(43200), te...","[tensor(3600), tensor(7200), tensor(43200), te...","[tensor(3), tensor(52), tensor(1), tensor(2), ...","[tensor(3600), tensor(7200), tensor(43200), te...","[tensor(3600), tensor(7200), tensor(43200), te...","[tensor(18.3190, dtype=torch.float64), tensor(..."
3,10,"[tensor(14400), tensor(14400), tensor(14400), ...","[tensor(14400), tensor(14400), tensor(14400), ...","[tensor(16), tensor(1), tensor(52), tensor(10)...","[tensor(14400), tensor(14400), tensor(14400), ...","[tensor(14400), tensor(14400), tensor(14400), ...","[tensor(9.3420, dtype=torch.float64), tensor(5..."
4,11,"[tensor(0), tensor(7200), tensor(21600), tenso...","[tensor(0), tensor(7200), tensor(21600), tenso...","[tensor(3), tensor(9), tensor(1), tensor(1), t...","[tensor(7200), tensor(21600), tensor(28800), t...","[tensor(7200), tensor(21600), tensor(28800), t...","[tensor(22.8980, dtype=torch.float64), tensor(..."
...,...,...,...,...,...,...,...
29995,49993,"[tensor(3600), tensor(14400), tensor(14400), t...","[tensor(3600), tensor(14400), tensor(14400), t...","[tensor(10), tensor(49), tensor(37), tensor(21...","[tensor(3600), tensor(14400), tensor(14400), t...","[tensor(3600), tensor(14400), tensor(14400), t...","[tensor(78.8800, dtype=torch.float64), tensor(..."
29996,49995,"[tensor(0), tensor(3600), tensor(3600), tensor...","[tensor(0), tensor(3600), tensor(3600), tensor...","[tensor(3), tensor(9), tensor(2), tensor(9), t...","[tensor(3600), tensor(3600), tensor(7200), ten...","[tensor(3600), tensor(3600), tensor(7200), ten...","[tensor(2.6520, dtype=torch.float64), tensor(9..."
29997,49996,"[tensor(3600), tensor(3600), tensor(7200), ten...","[tensor(3600), tensor(3600), tensor(7200), ten...","[tensor(13), tensor(1), tensor(5), tensor(2), ...","[tensor(3600), tensor(3600), tensor(7200), ten...","[tensor(3600), tensor(3600), tensor(7200), ten...","[tensor(215.6500, dtype=torch.float64), tensor..."
29998,49997,"[tensor(3600), tensor(7200), tensor(10800), te...","[tensor(3600), tensor(7200), tensor(10800), te...","[tensor(1), tensor(1), tensor(1), tensor(1), t...","[tensor(3600), tensor(7200), tensor(10800), te...","[tensor(3600), tensor(7200), tensor(10800), te...","[tensor(32.1940, dtype=torch.float64), tensor(..."


In [19]:
joined_data = joined_data.applymap(lambda x: torch.tensor([]) if pd.isna(x) else x)

  joined_data = joined_data.applymap(lambda x: torch.tensor([]) if pd.isna(x) else x)


In [20]:
train_df, test_df = train_test_split(joined_data,
                                     test_size=0.4,
                                     random_state=42)
train_df, valid_df = train_test_split(train_df,
                                      test_size=0.1,
                                      random_state=42)

In [21]:
train_df = train_df.reset_index(drop=True)
valid_df = valid_df.reset_index(drop=True)
test_df = test_df.reset_index(drop=True)

In [22]:
train_dict = train_df.to_dict("records")
valid_dict = valid_df.to_dict("records")
test_dict = test_df.to_dict("records")

In [23]:
source_features = {
    "sourceA": {
        "categorical": ["small_group"],
        "numeric": [],
    },
    "sourceB": {
        "categorical": [],
        "numeric": ["amount_rur"],
    },
}

In [24]:
inf_test_data = MultiModalInferenceIterableDataset(
    data = test_dict,
    source_features = source_features,
    col_id = "client_id",
    col_time = "trans_date",
    source_names = ("sourceA", "sourceB")
)

In [25]:
inf_test_loader = DataLoader(
    dataset = inf_test_data,
    collate_fn = partial(inf_test_data.collate_fn, col_id="client_id"),
    shuffle = False,
    num_workers = 0,
    batch_size = 8
)

## MY CODe

In [26]:
!git clone https://github.com/google-research/google-research.git

fatal: destination path 'google-research' already exists and is not an empty directory.


In [27]:
import sys
sys.path.append("google-research/graph_embedding/metrics")

In [28]:
from metrics import (rankme,
        coherence,
        pseudo_condition_number,
        alpha_req,
        stable_rank,
        ne_sum,
        self_clustering)

In [29]:
# !pip install git+https://github.com/simonzhang00/ripser-plusplus.git

In [30]:
import ripserplusplus as rpp
def ripser_metric(embeddings, u=None, s=None):
    
    diagrams = rpp.run("--format point-cloud", embeddings)
    persistence = {}

    for k in range(len(diagrams)):
        persistence_sum = sum([death - birth for birth, death in diagrams[k] if death > birth])
        persistence[f"ripser_sum_H{k}"] = persistence_sum

    return persistence

In [31]:
import logging
import os

os.makedirs('logs/age', exist_ok=True)

logger = logging.getLogger("my_logger")
logger.setLevel(logging.INFO)

file_handler = logging.FileHandler("logs/age/hidden_size_experiment.log")
formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(message)s')
file_handler.setFormatter(formatter)

# Удалим другие обработчики
if logger.hasHandlers():
    logger.handlers.clear()

logger.addHandler(file_handler)
logger.info("🔧 Логгер настроен вручную")

In [32]:
# np.random.seed(42)
# random_embedding = np.random.rand(100, 8)

# # 🔹 Санити-чек
# score, elapsed = ripser_metric(random_embedding)
# print(f"Ripser metric: {score:.4f}, computed in {elapsed:.4f} seconds")

In [33]:
def create_datasets(train_dict, valid_dict, params, source_features):
    splitter = SampleSlices(
        split_count=params["split_count"],
        cnt_min=params["cnt_min"],
        cnt_max=params["cnt_max"],
    )

    train_data = MultiModalIterableDataset(
        data=train_dict,
        splitter=splitter,
        source_features=source_features,
        col_id="client_id",
        col_time="trans_date",
        source_names=("sourceA", "sourceB"),
    )

    valid_data = MultiModalIterableDataset(
        data=valid_dict,
        splitter=splitter,
        source_features=source_features,
        col_id="client_id",
        col_time="trans_date",
        source_names=("sourceA", "sourceB"),
    )

    data_loader = PtlsDataModule(
        train_data=train_data,
        train_batch_size=params["batch_size"],
        train_num_workers=0,
        valid_data=valid_data,
    )

    return data_loader

In [34]:
def compute_metrics(model, pl_trainer, inf_test_loader, selected_metrics=None, n_samples=10, sample_fraction=1/20):
    import gc
    from sklearn.utils import resample
    from time import time
    logger.info(f"{sample_fraction=}")

    model.eval()
    inference_module = InferenceModuleMultimodal(
        model=model,
        pandas_output=True,
        drop_seq_features=True,
        model_out_name="emb",
        col_id="client_id",
    )
    inference_module.model.is_reduce_sequence = True

    # Получение эмбеддингов
    inf_test_embeddings = pd.concat(
        pl_trainer.predict(inference_module, inf_test_loader),
        axis=0,
    )
    embeddings_np = inf_test_embeddings.drop(columns=["client_id"]).to_numpy(dtype=np.float32)
    sample_size = max(1, int(sample_fraction * embeddings_np.shape[0]))

    # Метрики
    available_metrics = {
        "rankme": rankme,
        "coherence": coherence,
        "pseudo_condition_number": pseudo_condition_number,
        "alpha_req": alpha_req,
        "stable_rank": stable_rank,
        "ne_sum": ne_sum,
        "self_clustering": self_clustering,
        "ripser": ripser_metric
    }
    if selected_metrics is None:
        selected_metrics = list(available_metrics.keys())

    metrics = {name: [] for name in selected_metrics}
    times = {name: [] for name in selected_metrics}

    for i in range(n_samples):
        sample = resample(embeddings_np, n_samples=sample_size, replace=False, random_state=42 + i)
        u, s, _ = np.linalg.svd(sample, compute_uv=True, full_matrices=False)

        for metric_name in selected_metrics:
            if metric_name not in available_metrics:
                continue

            try:
                t0 = time()
                result = available_metrics[metric_name](sample, u=u, s=s)
                t = time() - t0

                if isinstance(result, dict):
                    for subname, val in result.items():
                        if subname not in metrics:
                            metrics[subname] = []
                            times[subname] = []
                        metrics[subname].append(val)
                        times[subname].append(t)
                else:
                    if metric_name not in metrics:
                        metrics[metric_name] = []
                        times[metric_name] = []
                    metrics[metric_name].append(result)
                    times[metric_name].append(t)
            except Exception as e:
                print(f"⚠️ Failed to compute {metric_name} on sample {i}: {e}")

        gc.collect()

    averaged_metrics = {k: np.mean(v) for k, v in metrics.items()}
    std_metrics = {k: np.std(v) for k, v in metrics.items()}
    
    averaged_times = {k: np.mean(v) for k, v in times.items()}
    std_times = {k: np.std(v) for k, v in times.items()}

    print("\n📊 Средние значения метрик и время вычисления:")
    for metric_name in averaged_metrics:
        metric_value = averaged_metrics[metric_name]
        metric_time = averaged_times.get(metric_name, None)
        print(f"🧠 {metric_name:30s} = {metric_value:.4f} | ⏱ {metric_time:.4f} сек")

    return averaged_metrics, averaged_times, std_metrics, std_times, inf_test_embeddings


In [35]:
import catboost


def evaluate_model(model, pl_trainer, checkpoint=None, selected_metrics=None, sample_fraction=1/20):
    model.eval()
    metrics, times, std_metrics, std_times, inf_test_embeddings = compute_metrics(model, pl_trainer, inf_test_loader, selected_metrics, sample_fraction=sample_fraction)
    targets_df = df_target.set_index("client_id")
    inf_test_df = inf_test_embeddings.merge(targets_df, how="inner", on="client_id").set_index("client_id")
    
    X = inf_test_df.drop(columns=["bins"])
    y = inf_test_df["bins"]
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
    
    classifier = catboost.CatBoostClassifier(
        iterations=150,
        random_seed=42,
        verbose=0,
    )
    classifier.fit(X_train, y_train)
    
    accuracy = classifier.score(X_test, y_test)

    del classifier
    
    return metrics, times, std_metrics, std_times, accuracy

In [36]:
fixed_params = {
    "batch_size": 64,
    "learning_rate": 0.001,
    "split_count": 3,
    "cnt_min": 10,
    "cnt_max": 50,
    "embedding_dim": 16,  # Размерность эмбеддингов
    "category_embedding_dim": 8,  # Размерность категорий эмбеддингов
    "hidden_size": 128,  # Размер скрытого слоя по умолчанию
}

# Список гиперпараметров для перебора
variable_params = {
    # "batch_size": [32, 64, 128], 
    # "learning_rate": [0.0001, 0.001, 0.05],
    # "split_count": [2, 3, 5],
    # "cnt_min": [5, 10, 20],
    # "cnt_max": [50, 80, 100],
    # "embedding_dim": [8, 16, 32],
    # "category_embedding_dim": [8, 16, 24],
    "hidden_size": [64, 128, 256, 1024],
}

# Создание списка всех гиперпараметров, которые нужно перебрать
all_hyperparameter_grids = []
for variable_param_name, variable_param_values in variable_params.items():
    for value in variable_param_values:
        hyperparameter_grid = {**fixed_params, variable_param_name: value}
        all_hyperparameter_grids.append((variable_param_name, hyperparameter_grid))


In [37]:
metric_names = [
    "rankme", "coherence", "pseudo_condition_number",
    "alpha_req", "stable_rank", "ne_sum", "self_clustering", "ripser"
]

In [38]:
category_embedding_dims = {
    "small_group": (150, fixed_params["category_embedding_dim"]),
}

In [39]:
import os

In [40]:
checkpoints_path = "age/checkpoints"
os.makedirs(checkpoints_path, exist_ok=True)

In [41]:
splitter = SampleSlices(split_count=5, cnt_min=25, cnt_max=50)

In [42]:
class CustomLogger(pl.Callback):
    def __init__(self):
        super().__init__()
        self.early_stopping_epoch = None  # Запомним, на какой эпохе произошла остановка
    
    def on_train_epoch_end(self, trainer, pl_module):
        train_loss = trainer.callback_metrics.get("train_loss", None)
        val_loss = trainer.callback_metrics.get("val_loss", None)
        
        if train_loss is not None and val_loss is not None:
            print(f"Epoch {trainer.current_epoch}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
        
        # Если валидационный лосс увеличивается - фиксируем эпоху остановки
        if trainer.early_stopping_callback is not None and trainer.early_stopping_callback.wait_count == 0:
            self.early_stopping_epoch = trainer.current_epoch


custom_logger = CustomLogger()
early_stopping_callback = EarlyStopping(
    monitor="val_loss",
    patience=5,
    mode="min",
    verbose=True
)

In [43]:
! rm -rf checkpoints

In [44]:
# ! rm age_tr_params_tun_full.csv

In [45]:
num_epochs = 30
output_csv = "age_tr_hidden_size2.csv"


metric_keys = [
    "rankme", "coherence", "pseudo_condition_number", 
    "alpha_req", "stable_rank", "ne_sum", "self_clustering", "ripser_sum_H0", "ripser_sum_H1"
]

columns = (
    list(fixed_params.keys()) +
    ["checkpoint", "epoch_num", "accuracy", "early_stop_epoch", "hidden_size", "sample_fraction"] +
    [f"metric_{k}" for k in metric_keys] +
    [f"std_metric_{k}" for k in metric_keys] +
    [f"time_{k}" for k in metric_keys] +
    [f"std_time_{k}" for k in metric_keys]
)

In [None]:
from time import time
import os
import gc
import torch
import pandas as pd
import glob
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from functools import partial

cur_time = time()

for param in all_hyperparameter_grids:
    
    logger.info(f'All params are frozen except {param[0]}')
    params = param[1]
    logger.info(f"Testing parameters: {params}")

    train_loader = create_datasets(train_dict, valid_dict, params, source_features)

    sourceA_encoder_params = dict(
        embeddings_noise=0.003,
        linear_projection_size=64,
        embeddings={
            "small_group": {"in": len(np.unique(sourceA['small_group'])), "out": 32}
        },
    )
    
    sourceB_encoder_params = dict(
        embeddings_noise=0.003,
        linear_projection_size=64,
        numeric_values={"amount_rur": "identity"},
    )
    
    sourceA_encoder = TrxEncoder(**sourceA_encoder_params)
    sourceB_encoder = TrxEncoder(**sourceB_encoder_params)
    
    seq_encoder = MultiModalSortTimeSeqEncoderContainer(
        trx_encoders={
            "sourceA": sourceA_encoder,
            "sourceB": sourceB_encoder,
        },
        input_size=64,
        hidden_size=params["hidden_size"],  # Используем только текущее значение hidden_size
        seq_encoder_cls=RnnEncoder,
        type="gru",
    )

    model = CoLESModule(
        seq_encoder=seq_encoder,
        optimizer_partial=partial(torch.optim.Adam, lr=params["learning_rate"]),
        lr_scheduler_partial=partial(torch.optim.lr_scheduler.StepLR, step_size=10, gamma=0.5),
    )

    early_stopping_callback = EarlyStopping(
        monitor="loss",
        patience=5,
        mode="min",
        verbose=True
    )

    checkpoint_callback = ModelCheckpoint(
        dirpath=checkpoints_path,
        filename=f"model_{params['batch_size']}_{params['learning_rate']}_{params['split_count']}_{params['cnt_min']}_{params['cnt_max']}_{params['hidden_size']}{{epoch:02d}}",
        save_top_k=-1,
        every_n_epochs=1,
    )

    # Обучение модели
    pl_trainer = pl.Trainer(
        callbacks=[checkpoint_callback, early_stopping_callback, custom_logger],
        default_root_dir=checkpoints_path,
        check_val_every_n_epoch=1,
        max_epochs= num_epochs,
        accelerator="gpu",
        devices=1,
        enable_progress_bar=True,
        precision=16
    )
    model.train()
    pl_trainer.fit(model, train_loader)

    early_stop_epoch = getattr(custom_logger, "early_stopping_epoch", None) or num_epochs

    # Обработка чекпоинтов
    checkpoint_files = glob.glob(f"{checkpoints_path}/model_{params['batch_size']}_{params['learning_rate']}_{params['split_count']}_{params['cnt_min']}_{params['cnt_max']}_{params['hidden_size']}*.ckpt")
    checkpoint_files.sort()
    logger.info(f"Elapsed time: {time() - cur_time:.2f} seconds")

    logger.info(f'Early stop is {early_stop_epoch}')

    for i, checkpoint in enumerate(checkpoint_files):
        logger.info(f"Processing checkpoint number {i}")
        model = CoLESModule.load_from_checkpoint(checkpoint, seq_encoder=seq_encoder)
    
        # Вычисление метрик, времени, дисперсий и accuracy
        metrics, times, std_metrics, std_times, accuracy = evaluate_model(model, pl_trainer, checkpoint, sample_fraction=sample_fraction)
    
        # Преобразование результатов в плоские словари
        metrics_flattened = {f"metric_{k}": round(v, 4) for k, v in metrics.items()}
        std_metrics_flattened = {f"std_metric_{k}": round(v, 4) for k, v in std_metrics.items()}
        times_flattened = {f"time_{k}": round(v, 4) for k, v in times.items()}
        std_times_flattened = {f"std_time_{k}": round(v, 4) for k, v in std_times.items()}
    
        # Сбор всех результатов
        new_result = {
            **params,
            "checkpoint": checkpoint,
            "epoch_num": int(i),
            "accuracy": accuracy,
            "early_stop_epoch": int(early_stop_epoch),
            "sample_fraction": sample_fraction,
            **metrics_flattened,
            **std_metrics_flattened,
            **times_flattened,
            **std_times_flattened,
        }
    
        # Сохранение в CSV
        results = pd.DataFrame([new_result], columns=columns)

        if not os.path.exists(output_csv):  
            pd.DataFrame(columns=columns).to_csv(output_csv, mode="w", index=False, header=True)
            logger.info('csv file was created!')
        
        results.to_csv(output_csv, mode="a", header=False, index=False)
        

        del metrics, accuracy, new_result
        torch.cuda.empty_cache()
        gc.collect()

    logger.info(f"Removing checkpoints for parameters: {params}")
    for checkpoint in checkpoint_files:
        os.remove(checkpoint)

    del model
    del train_loader
    torch.cuda.empty_cache()
    gc.collect()

logger.info("Optimization complete!")

/home/dpetrovitch/venv/lib/python3.12/site-packages/lightning_fabric/connector.py:571: `precision=16` is supported for historical reasons but its usage is discouraged. Please set your precision to 16-mixed instead!
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/dpetrovitch/venv/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:76: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
You are using a CUDA device ('NVIDIA A100 80GB PCIe') that has Tensor Cores. To properly 

Sanity Checking: |                                                                                       | 0/?…

/home/dpetrovitch/venv/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.
/home/dpetrovitch/venv/lib/python3.12/site-packages/pytorch_lightning/utilities/data.py:123: Your `IterableDataset` has `__len__` defined. In combination with multi-process data loading (when num_workers > 1), `__len__` could be inaccurate if each worker is not configured independently to avoid having duplicate data.
/home/dpetrovitch/venv/lib/python3.12/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.


Training: |                                                                                              | 0/?…

Validation: |                                                                                            | 0/?…

Metric loss improved. New best score: 6.007


Validation: |                                                                                            | 0/?…

Metric loss improved by 0.741 >= min_delta = 0.0. New best score: 5.267


Validation: |                                                                                            | 0/?…

Validation: |                                                                                            | 0/?…

Validation: |                                                                                            | 0/?…

Metric loss improved by 0.019 >= min_delta = 0.0. New best score: 5.248


Validation: |                                                                                            | 0/?…

Validation: |                                                                                            | 0/?…

Validation: |                                                                                            | 0/?…

Metric loss improved by 0.503 >= min_delta = 0.0. New best score: 4.744


Validation: |                                                                                            | 0/?…

Validation: |                                                                                            | 0/?…

Validation: |                                                                                            | 0/?…

Validation: |                                                                                            | 0/?…