# Colab setup

In [1]:
import sys
# if "google.colab" in str(get_ipython()):
! {sys.executable} -m pip install pytorch-lifestream
! {sys.executable} -m pip install umap-learn
! {sys.executable} -m pip install catboost



# CoLES-demo-multimodal

**In this demo, we will try to show how Multimodal CoLES handles event data of different modalities.**

In [2]:
import warnings
warnings.filterwarnings("ignore")

In [3]:
import os

if not os.path.exists("lightning_logs/CoLES-demo-multimodal"):
    !mkdir -p lightning_logs/CoLES-demo-multimodal

if not os.path.exists("CatBoostClassifier"):
    !mkdir -p CatBoostClassifier

if not os.path.exists("model"):
    !mkdir -p model

# Libraries

In [4]:
! pip install pytorch-lifestream



In [5]:
from functools import partial
from datetime import timedelta

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns
import catboost
from time import time

import torch
import pytorch_lightning as pl
from torch.utils.data.dataloader import DataLoader
from pytorch_lightning.loggers import TensorBoardLogger

from sklearn.model_selection import train_test_split

from ptls.nn import TrxEncoder
from ptls.nn.seq_encoder.rnn_encoder import RnnEncoder
from ptls.frames import PtlsDataModule
from ptls.frames.coles import CoLESModule
from ptls.frames.coles.split_strategy import SampleSlices
from ptls.frames.coles.multimodal_dataset import MultiModalDataset
from ptls.frames.coles.multimodal_dataset import MultiModalIterableDataset
from ptls.frames.coles.multimodal_dataset import MultiModalSortTimeSeqEncoderContainer
from ptls.frames.coles.multimodal_inference_dataset import MultiModalInferenceDataset
from ptls.frames.coles.multimodal_inference_dataset import MultiModalInferenceIterableDataset
from ptls.frames.inference_module import InferenceModuleMultimodal
from ptls.data_load.iterable_processing import SeqLenFilter
from ptls.data_load import IterableProcessingDataset
from ptls.data_load.utils import collate_feature_dict
from ptls.data_load.datasets import MemoryMapDataset
from ptls.preprocessing import PandasDataPreprocessor

# Working with data

## Data load

In [6]:
transactions = pd.read_csv("https://huggingface.co/datasets/dllllb/transactions-gender/resolve/main/transactions.csv.gz?download=true", compression="gzip")
targets = pd.read_csv("https://huggingface.co/datasets/dllllb/transactions-gender/resolve/main/gender_train.csv?download=true")

In [7]:
transactions = transactions.dropna().reset_index(drop=True)
transactions

Unnamed: 0,customer_id,tr_datetime,mcc_code,tr_type,amount,term_id
0,39026145,208 14:18:12,5499,1010,-7254.31,453799
1,39026145,208 14:24:23,5499,1010,-1392.47,453306
2,39026145,209 15:09:16,5499,1010,-2852.31,453306
3,39026145,209 15:13:51,5499,1010,-853.45,058435
4,39026145,209 15:13:17,5499,1010,-1410.44,058435
...,...,...,...,...,...,...
4084146,61870738,453 16:03:02,5499,1010,-5176.84,10217113
4084147,61870738,454 10:54:60,5411,1010,-1652.77,022915
4084148,61870738,454 14:23:59,5499,1010,-4687.23,10217113
4084149,61870738,454 16:11:53,5541,1110,-4491.83,RU570124


Here
1. `customer_id` is the id of some user
2. `tr_datetime` is the time of the transaction
3. `mcc_code` is, in fact, the mcc code of the transaction
4. `tr_type` is the type of transaction (what was paid for)
5. `amount` is the amount of the transaction
6. `term-id` is the id of the terminal where the transaction was carried out

We will predict the gender of the user based on his transactions.

In [8]:
targets

Unnamed: 0,customer_id,gender
0,10928546,1
1,69348468,1
2,61009479,0
3,74045822,0
4,27979606,1
...,...,...
8395,90417572,0
8396,66837341,0
8397,10758984,1
8398,11376556,0


In [9]:
n_cutomers = len(pd.unique(transactions["customer_id"]))
n_labeling_cutomers = len(pd.unique(targets["customer_id"]))

print("n_cutomers:", n_cutomers)
print("n_labeling_cutomers:", n_labeling_cutomers)

n_cutomers: 14973
n_labeling_cutomers: 8400


In [10]:
list(transactions.columns)

['customer_id', 'tr_datetime', 'mcc_code', 'tr_type', 'amount', 'term_id']

In [11]:
sourceA = transactions[["customer_id", "tr_datetime", "mcc_code", "term_id"]]
sourceB = transactions[["customer_id", "tr_datetime", "tr_type", "amount"]]

In [12]:
sourceA

Unnamed: 0,customer_id,tr_datetime,mcc_code,term_id
0,39026145,208 14:18:12,5499,453799
1,39026145,208 14:24:23,5499,453306
2,39026145,209 15:09:16,5499,453306
3,39026145,209 15:13:51,5499,058435
4,39026145,209 15:13:17,5499,058435
...,...,...,...,...
4084146,61870738,453 16:03:02,5499,10217113
4084147,61870738,454 10:54:60,5411,022915
4084148,61870738,454 14:23:59,5499,10217113
4084149,61870738,454 16:11:53,5541,RU570124


In [13]:
sourceB

Unnamed: 0,customer_id,tr_datetime,tr_type,amount
0,39026145,208 14:18:12,1010,-7254.31
1,39026145,208 14:24:23,1010,-1392.47
2,39026145,209 15:09:16,1010,-2852.31
3,39026145,209 15:13:51,1010,-853.45
4,39026145,209 15:13:17,1010,-1410.44
...,...,...,...,...
4084146,61870738,453 16:03:02,1010,-5176.84
4084147,61870738,454 10:54:60,1010,-1652.77
4084148,61870738,454 14:23:59,1010,-4687.23
4084149,61870738,454 16:11:53,1110,-4491.83


In [14]:
sourceA_drop_indices = np.random.choice(sourceA.index, 130000, replace=False)
sourceB_drop_indices = np.random.choice(sourceB.index, 420000, replace=False)

sourceA = sourceA.drop(sourceA_drop_indices).reset_index(drop=True)
sourceB = sourceB.drop(sourceB_drop_indices).reset_index(drop=True)

In [15]:
sourceA

Unnamed: 0,customer_id,tr_datetime,mcc_code,term_id
0,39026145,208 14:18:12,5499,453799
1,39026145,208 14:24:23,5499,453306
2,39026145,209 15:09:16,5499,453306
3,39026145,209 15:13:51,5499,058435
4,39026145,209 15:13:17,5499,058435
...,...,...,...,...
3954146,61870738,453 16:03:02,5499,10217113
3954147,61870738,454 10:54:60,5411,022915
3954148,61870738,454 14:23:59,5499,10217113
3954149,61870738,454 16:11:53,5541,RU570124


In [16]:
sourceB

Unnamed: 0,customer_id,tr_datetime,tr_type,amount
0,39026145,208 14:18:12,1010,-7254.31
1,39026145,209 15:13:51,1010,-853.45
2,39026145,209 15:13:17,1010,-1410.44
3,39026145,210 00:00:00,1110,-5277.90
4,39026145,210 06:55:10,1030,-2245.92
...,...,...,...,...
3664146,61870738,452 19:33:04,1110,-4491.83
3664147,61870738,453 16:03:02,1010,-5176.84
3664148,61870738,454 10:54:60,1010,-1652.77
3664149,61870738,454 14:23:59,1010,-4687.23


## Preprocessing

In [17]:
mcc_code_in = len(np.unique((sourceA["mcc_code"])))
term_id_in = len(np.unique((sourceA["term_id"])))
tr_type_in = len(np.unique((sourceB["tr_type"])))

print("mcc_code_in:", mcc_code_in)
print("term_id_in:", term_id_in)
print("tr_type_in", tr_type_in)

mcc_code_in: 184
term_id_in: 435085
tr_type_in 73


In [18]:
def tr_datetime_preprocess(tr_datetime):
    days, hms = tr_datetime.split()
    hh, mm, ss = hms.split(":")

    seconds = timedelta(hours=int(hh), minutes=int(mm), seconds=int(ss))
    seconds = seconds.total_seconds()
    seconds += int(days) * 24 * 3600

    return int(seconds)

In [19]:
sourceA["tr_datetime"] = sourceA["tr_datetime"].apply(tr_datetime_preprocess)
sourceB["tr_datetime"] = sourceB["tr_datetime"].apply(tr_datetime_preprocess)

In [20]:
sourceA_preprocessor = PandasDataPreprocessor(
    col_id="customer_id",
    col_event_time="tr_datetime",
    event_time_transformation="none",
    cols_category=["mcc_code", "term_id"],
    return_records=False,
)

sourceB_preprocessor = PandasDataPreprocessor(
    col_id="customer_id",
    col_event_time="tr_datetime",
    event_time_transformation="none",
    cols_numerical=["tr_type", "amount"],
    return_records=False,
)

In [21]:
processed_sourceA = sourceA_preprocessor.fit_transform(sourceA)
processed_sourceB = sourceB_preprocessor.fit_transform(sourceB)

In [22]:
processed_sourceA.columns = [
    "sourceA_" + str(col) if str(col) != "customer_id" else str(col)
    for col in processed_sourceA.columns
]

In [23]:
processed_sourceB.columns = [
    "sourceB_" + str(col) if str(col) != "customer_id" else str(col)
    for col in processed_sourceB.columns
]

In [24]:
joined_data = processed_sourceA.merge(processed_sourceB, how="outer", on="customer_id")

In [25]:
joined_data = joined_data.applymap(lambda x: torch.tensor([]) if pd.isna(x) else x)

In [26]:
train_df, test_df = train_test_split(joined_data,
                                     test_size=0.1,
                                     random_state=42)
train_df, valid_df = train_test_split(train_df,
                                      test_size=0.1,
                                      random_state=42)

In [27]:
train_df = train_df.reset_index(drop=True)
valid_df = valid_df.reset_index(drop=True)
test_df = test_df.reset_index(drop=True)

In [28]:
train_dict = train_df.to_dict("records")
valid_dict = valid_df.to_dict("records")
test_dict = test_df.to_dict("records")

In [29]:
source_features = {
    "sourceA": ["event_time", "mcc_code", "term_id"],
    "sourceB": ["event_time", "tr_type", "amount"]
}

In [30]:
splitter = SampleSlices(split_count=5, cnt_min=25, cnt_max=50)

In [31]:
train_multimodal_data = MultiModalIterableDataset(
    data = train_dict,
    splitter = splitter,
    source_features = source_features,
    col_id = "customer_id",
    col_time = "event_time",
    source_names = ("sourceA", "sourceB")
)

valid_multimodal_data = MultiModalIterableDataset(
    data = valid_dict,
    splitter = splitter,
    source_features = source_features,
    col_id = "customer_id",
    col_time = "event_time",
    source_names = ("sourceA", "sourceB")
)

In [32]:
train_loader = PtlsDataModule(
    train_data = train_multimodal_data,
    train_num_workers = 16,
    train_batch_size = 64,

    valid_data = valid_multimodal_data
)

In [33]:
sourceA_encoder_params = dict(
    embeddings_noise = 0.003,
    linear_projection_size = 64,
    embeddings = {
        "mcc_code": {"in": mcc_code_in, "out": 32},
        "term_id": {"in": term_id_in, "out": 32}
    },
)

sourceB_encoder_params = dict(
    embeddings_noise = 0.003,
    linear_projection_size = 64,
    embeddings = {
        "tr_type": {"in": tr_type_in, "out": 32},
    },
    numeric_values = {"amount": "identity"},
)

In [34]:
sourceA_encoder = TrxEncoder(**sourceA_encoder_params)
sourceB_encoder = TrxEncoder(**sourceB_encoder_params)

In [35]:
seq_encoder = MultiModalSortTimeSeqEncoderContainer(
    trx_encoders = {
        "sourceA": sourceA_encoder,
        "sourceB": sourceB_encoder,
    },

    input_size = 64,
    hidden_size = 256,
    seq_encoder_cls = RnnEncoder,
    type = "gru"
)

In [36]:
model = CoLESModule(
    seq_encoder = seq_encoder,
    optimizer_partial = partial(torch.optim.Adam, lr=0.004),
    lr_scheduler_partial = partial(torch.optim.lr_scheduler.StepLR, step_size=30, gamma=0.5)
)

In [37]:
# model.load_state_dict(torch.load("model/multimodal_coles.pt"))

In [38]:
# logger = TensorBoardLogger("lightning_logs", name="CoLES-demo-multimodal")

pl_trainer = pl.Trainer(
    # logger = logger,
    max_epochs = 1,
    accelerator = "gpu",
    devices = 1,
    enable_progress_bar = True
)

Using default `ModelCheckpoint`. Consider installing `litmodels` package to enable `LitModelCheckpoint` for automatic upload to the Lightning model registry.
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [39]:
pl_trainer.fit(model, train_loader)

You are using a CUDA device ('NVIDIA A100 80GB PCIe') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name               | Type                                  | Params | Mode 
-------------------------------------------------------------------------------------
0 | _loss              | ContrastiveLoss                       | 0      | train
1 | _seq_encoder       | MultiModalSortTimeSeqEncoderContainer | 14.2 M | train
2 | _validation_metric | BatchRecallTopK                       | 0      | train
3 | _head              | Head                                  | 0      | train
-------------------------------------------------------------------------------------
14.2 M    Trainable p

Sanity Checking: |                                                                                       | 0/?…

Training: |                                                                                              | 0/?…

Validation: |                                                                                            | 0/?…

`Trainer.fit` stopped: `max_epochs=1` reached.


`Trainer.fit` stopped: `max_epochs=1` reached.


In [40]:
inf_test_data = MultiModalInferenceIterableDataset(
    data = test_dict,
    source_features = source_features,
    col_id = "customer_id",
    col_time = "event_time",
    source_names = ("sourceA", "sourceB")
)

In [41]:
inf_test_loader = DataLoader(
    dataset = inf_test_data,
    collate_fn = partial(inf_test_data.collate_fn, col_id="customer_id"),
    shuffle = False,
    num_workers = 0,
    batch_size = 8
)

In [42]:
inf_train_data = MultiModalInferenceIterableDataset(
    data = train_dict,
    source_features = source_features,
    col_id = "customer_id",
    col_time = "event_time",
    source_names = ("sourceA", "sourceB")
)

In [43]:
inf_train_loader = DataLoader(
    dataset = inf_train_data,
    collate_fn = partial(inf_train_data.collate_fn, col_id="customer_id"),
    shuffle = False,
    num_workers = 0,
    batch_size = 8
)

In [44]:
# inference_module = InferenceModuleMultimodal(
#     model = model,
#     pandas_output = True,
#     drop_seq_features = True,
#     model_out_name = "emb",
#     col_id = "customer_id"
# )

# inference_module.model.is_reduce_sequence = True

In [45]:
# inf_test_embeddings = pd.concat(
#     pl_trainer.predict(inference_module, inf_test_loader),
#     axis = 0
# )

In [46]:
# inf_test_embeddings

## Tunning hyperparams using RankMe metric

In [47]:
# !git clone https://github.com/google-research/google-research.git

In [48]:
# !ls google-research/graph_embedding/metrics

In [49]:
import sys
sys.path.append("google-research/graph_embedding/metrics")

In [50]:
from metrics import (rankme,
        coherence,
        pseudo_condition_number,
        alpha_req,
        stable_rank,
        ne_sum,
        self_clustering)

In [51]:
from itertools import product

In [52]:
# !pip install git+https://github.com/simonzhang00/ripser-plusplus.git

In [53]:
!ls 

CatBoostClassifier	    csv_results		   logs
age_notebook.ipynb	    gender		   model
catboost_info		    gender_notebook.ipynb  sample_frac_results.ipynb
churn_tr_hidden_size.ipynb  google-research	   simple_exps.ipynb
compare_components.ipynb    lightning_logs	   vizualize_results.ipynb


In [54]:
import ripserplusplus as rpp
def ripser_metric(embeddings, u=None, s=None):
    
    diagrams = rpp.run("--format point-cloud", embeddings)
    persistence = {}
    persistence["ripser_sum"] = 0

    for k in range(len(diagrams)):
        persistence_sum = sum([death - birth for birth, death in diagrams[k] if death > birth])
        persistence[f"ripser_sum_H{k}"] = persistence_sum
        persistence["ripser_sum"]+= persistence_sum

    return persistence

In [55]:
# batch_sizes = [64, 128]
# learning_rates = [0.001, 0.004]
# split_counts = [3, 5]
# cnt_min_values = [10, 25]
# cnt_max_values = [50, 100]

# # Генерация сетки гиперпараметров
# hyperparameter_grid = [
#     {
#         "batch_size": batch_size,
#         "learning_rate": lr,
#         "split_count": split_count,
#         "cnt_min": cnt_min,
#         "cnt_max": cnt_max,
#     }
#     for batch_size, lr, split_count, cnt_min, cnt_max in product(
#         batch_sizes, learning_rates, split_counts, cnt_min_values, cnt_max_values
#     )
# ]

In [56]:
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping

In [57]:
import os

checkpoints_path = "gender/checkpoints"
os.makedirs(checkpoints_path, exist_ok=True)

In [58]:
def create_datasets(train_dict, valid_dict, params, source_features):
    splitter = SampleSlices(
        split_count=params["split_count"],
        cnt_min=params["cnt_min"],
        cnt_max=params["cnt_max"],
    )

    train_data = MultiModalIterableDataset(
        data=train_dict,
        splitter=splitter,
        source_features=source_features,
        col_id="customer_id",
        col_time="event_time",
        source_names=("sourceA", "sourceB"),
    )

    valid_data = MultiModalIterableDataset(
        data=valid_dict,
        splitter=splitter,
        source_features=source_features,
        col_id="customer_id",
        col_time="event_time",
        source_names=("sourceA", "sourceB"),
    )

    data_loader = PtlsDataModule(
        train_data=train_data,
        train_batch_size=params["batch_size"],
        train_num_workers=0,
        valid_data=valid_data,
    )

    return data_loader

In [59]:
from time import time
import glob

In [60]:
import logging

os.makedirs('logs/gender', exist_ok=True)

logger = logging.getLogger("my_logger")
logger.setLevel(logging.INFO)

file_handler = logging.FileHandler("logs/gender/fraction_experiment.log")
formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(message)s')
file_handler.setFormatter(formatter)

# Удалим другие обработчики
if logger.hasHandlers():
    logger.handlers.clear()

logger.addHandler(file_handler)
logger.info("🔧 Логгер настроен вручную")

In [61]:
def compute_metrics(model, pl_trainer, inf_test_loader, selected_metrics=None, n_samples=10, sample_fraction=1/20):
    import gc
    from sklearn.utils import resample
    from time import time

    logger.info(f"{sample_fraction=}")
    model.eval()
    inference_module = InferenceModuleMultimodal(
        model=model,
        pandas_output=True,
        drop_seq_features=True,
        model_out_name="emb",
        col_id="customer_id",
    )
    inference_module.model.is_reduce_sequence = True

    # Получение эмбеддингов
    inf_test_embeddings = pd.concat(
        pl_trainer.predict(inference_module, inf_test_loader),
        axis=0,
    )
    embeddings_np = inf_test_embeddings.drop(columns=["customer_id"]).to_numpy(dtype=np.float32)
    sample_size = max(1, int(sample_fraction * embeddings_np.shape[0]))

    # Метрики
    available_metrics = {
        "rankme": rankme,
        "coherence": coherence,
        "pseudo_condition_number": pseudo_condition_number,
        "alpha_req": alpha_req,
        "stable_rank": stable_rank,
        "ne_sum": ne_sum,
        "self_clustering": self_clustering,
        "ripser": ripser_metric
    }
    if selected_metrics is None:
        selected_metrics = list(available_metrics.keys())

    metrics = {name: [] for name in selected_metrics}
    times = {name: [] for name in selected_metrics}

    for i in range(n_samples):
        sample = resample(embeddings_np, n_samples=sample_size, replace=False, random_state=42 + i)
        u, s, _ = np.linalg.svd(sample, compute_uv=True, full_matrices=False)

        for metric_name in selected_metrics:
            if metric_name not in available_metrics:
                continue

            try:
                t0 = time()
                result = available_metrics[metric_name](sample, u=u, s=s)
                t = time() - t0

                if isinstance(result, dict):
                    for subname, val in result.items():
                        if subname not in metrics:
                            metrics[subname] = []
                            times[subname] = []
                        metrics[subname].append(val)
                        times[subname].append(t)
                else:
                    if metric_name not in metrics:
                        metrics[metric_name] = []
                        times[metric_name] = []
                    metrics[metric_name].append(result)
                    times[metric_name].append(t)
            except Exception as e:
                print(f"⚠️ Failed to compute {metric_name} on sample {i}: {e}")

        gc.collect()

    averaged_metrics = {k: np.mean(v) for k, v in metrics.items()}
    std_metrics = {k: np.std(v) for k, v in metrics.items()}
    
    averaged_times = {k: np.mean(v) for k, v in times.items()}
    std_times = {k: np.std(v) for k, v in times.items()}

    print("\n📊 Средние значения метрик и время вычисления:")
    for metric_name in averaged_metrics:
        metric_value = averaged_metrics[metric_name]
        metric_time = averaged_times.get(metric_name, None)
        print(f"🧠 {metric_name:30s} = {metric_value:.4f} | ⏱ {metric_time:.4f} сек")

    return averaged_metrics, averaged_times, inf_test_embeddings


In [62]:
def evaluate_model(model, pl_trainer, inf_loader, selected_metrics=None, sample_fraction=1/20):
    model.eval()
    metrics, times, inf_test_embeddings = compute_metrics(model, pl_trainer, inf_loader, selected_metrics, sample_fraction=sample_fraction)
    targets_df = targets.set_index("customer_id")
    inf_test_df = inf_test_embeddings.merge(targets_df, how="inner", on="customer_id").set_index("customer_id")
    
    X = inf_test_df.drop(columns=["gender"])
    y = inf_test_df["gender"]
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
    
    classifier = catboost.CatBoostClassifier(
        iterations=150,
        random_seed=42,
        verbose=0,
    )
    classifier.fit(X_train, y_train)
    
    accuracy = classifier.score(X_test, y_test)
    
    return metrics, times, accuracy

In [63]:
fixed_params = {
    "batch_size": 64,
    "learning_rate": 0.001,
    "split_count": 3,
    "cnt_min": 10,
    "cnt_max": 50,
    "embedding_dim": 16,  # Размерность эмбеддингов
    "category_embedding_dim": 8,  # Размерность категорий эмбеддингов
    "hidden_size": 128,  # Размер скрытого слоя по умолчанию
}

# Список гиперпараметров для перебора
variable_params = {
    "batch_size": [32, 64, 128], 
    "learning_rate": [0.001, 0.005, 0.01, 0.05],
    "split_count": [3, 5, 7],
    "cnt_min": [10, 15, 20],
    "cnt_max": [80, 100, 150],
    "embedding_dim": [8, 16, 24, 32],
    "category_embedding_dim": [8, 16, 24],
    "hidden_size": [64, 128, 256],
}

# Создание списка всех гиперпараметров, которые нужно перебрать
all_hyperparameter_grids = []
for variable_param_name, variable_param_values in variable_params.items():
    for value in variable_param_values:
        hyperparameter_grid = {**fixed_params, variable_param_name: value}
        all_hyperparameter_grids.append((variable_param_name, hyperparameter_grid))


In [64]:
# fixed_params = {
#     "batch_size": 64,
#     "learning_rate": 0.001,
#     "split_count": 3,
#     "cnt_min": 10,
#     "cnt_max": 50,
#     "embedding_dim": 16,  
#     "category_embedding_dim": 8,  
# }

# # Перебираемые значения hidden_size
# hidden_sizes = [64, 256, 512, 1024, 1468, 2048]  # От маленьких значений до очень больших

# # Формирование гиперпараметров для перебора
# all_hyperparameter_grids = [
#     {**fixed_params, "hidden_size": h_size} for h_size in hidden_sizes
# ]

In [65]:
num_epochs = 30
columns = [
    *fixed_params.keys(), "checkpoint", "epoch_num", "accuracy", "early_stop_epoch", "sample_fraction"
] + [
    "metric_" + key for key in [
        "rankme", "coherence", "pseudo_condition_number", 
        "alpha_req", "stable_rank", "ne_sum", "self_clustering", "ripser_sum_H0", "ripser_sum_H1", "ripser_sum"
    ]
] + ["time_" + key for key in [
    "rankme", "coherence", "pseudo_condition_number", 
    "alpha_req", "stable_rank", "ne_sum", "self_clustering", "ripser_sum_H0", "ripser_sum_H1", "ripser_sum"
]]

In [66]:
class CustomLogger(pl.Callback):
    def __init__(self):
        super().__init__()
        self.early_stopping_epoch = None
    
    def on_train_epoch_end(self, trainer, pl_module):
        train_loss = trainer.callback_metrics.get("train_loss", None)
        val_loss = trainer.callback_metrics.get("val_loss", None)
        
        if train_loss is not None and val_loss is not None:
            print(f"Epoch {trainer.current_epoch}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
        
        # Если валидационный лосс увеличивается - фиксируем эпоху остановки
        if trainer.early_stopping_callback is not None and trainer.early_stopping_callback.wait_count == 0:
            self.early_stopping_epoch = trainer.current_epoch


custom_logger = CustomLogger()
early_stopping_callback = EarlyStopping(
    monitor="val_loss",
    patience=5,
    mode="min",
    verbose=True
)

In [67]:
import gc

In [68]:
! rm -rf gender/checkpoints

In [None]:
cur_time = time()

for sample_fraction in np.linspace(1/20, 1, 5):
    print(f"{sample_fraction:.2f}")
    logger.info(f"{sample_fraction=:.2f}")
    output_csv = f"csv_results/gender_sample_fraction_{sample_fraction:.3f}".rstrip('0').rstrip('.') + ".csv"


    for param in all_hyperparameter_grids:
        
        logger.info(f'All params are frozen except {param[0]}')
        params = param[1]
        
        logger.info(f"Testing parameters: {params}")

        train_loader = create_datasets(train_dict, valid_dict, params, source_features)

        sourceA_encoder_params = dict(
            embeddings_noise=0.003,
            linear_projection_size=64,
            embeddings={
                "mcc_code": {"in": mcc_code_in, "out": 32},
                "term_id": {"in": term_id_in, "out": 32},
            },
        )
        
        sourceB_encoder_params = dict(
            embeddings_noise=0.003,
            linear_projection_size=64,
            embeddings={
                "tr_type": {"in": tr_type_in, "out": 32},
            },
            numeric_values={"amount": "identity"},
        )
        
        sourceA_encoder = TrxEncoder(**sourceA_encoder_params)
        sourceB_encoder = TrxEncoder(**sourceB_encoder_params)
        
        seq_encoder = MultiModalSortTimeSeqEncoderContainer(
            trx_encoders={
                "sourceA": sourceA_encoder,
                "sourceB": sourceB_encoder,
            },
            input_size=64,
            hidden_size=params["hidden_size"],  # Используем только текущее значение hidden_size
            seq_encoder_cls=RnnEncoder,
            type="gru",
        )

        model = CoLESModule(
            seq_encoder=seq_encoder,
            optimizer_partial=partial(torch.optim.Adam, lr=params["learning_rate"]),
            lr_scheduler_partial=partial(torch.optim.lr_scheduler.StepLR, step_size=10, gamma=0.5),
        )

        early_stopping_callback = EarlyStopping(
            monitor="loss",
            patience=5,
            mode="min",
            verbose=True
        )

        checkpoint_callback = ModelCheckpoint(
            dirpath=checkpoints_path,
            filename=f"model_{params['batch_size']}_{params['learning_rate']}_{params['split_count']}_{params['cnt_min']}_{params['cnt_max']}_{params['hidden_size']}{{epoch:02d}}",
            save_top_k=-1,
            every_n_epochs=1,
        )

        # Обучение модели
        pl_trainer = pl.Trainer(
            callbacks=[checkpoint_callback, early_stopping_callback, custom_logger],
            default_root_dir=checkpoints_path,
            check_val_every_n_epoch=1,
            max_epochs=num_epochs,
            accelerator="gpu",
            devices=1,
            enable_progress_bar=True,
            precision=16
        )
        model.train()
        pl_trainer.fit(model, train_loader)

        early_stop_epoch = custom_logger.early_stopping_epoch
        if early_stop_epoch is None:
            early_stop_epoch = num_epochs

        # Обработка чекпоинтов
        checkpoint_files = glob.glob(f"{checkpoints_path}/model_{params['batch_size']}_{params['learning_rate']}_{params['split_count']}_{params['cnt_min']}_{params['cnt_max']}_{params['hidden_size']}*.ckpt")
        checkpoint_files.sort()
        logger.info(f"Elapsed time: {time() - cur_time:.2f} seconds")

        logger.info(f'Early stop is {early_stop_epoch}')

        for i, checkpoint in enumerate(checkpoint_files):
            logger.info(f"Processing checkpoint number {i}")
            model = CoLESModule.load_from_checkpoint(checkpoint, seq_encoder=seq_encoder)

            # Вычисление метрик и времени
            metrics, times, accuracy = evaluate_model(model, pl_trainer, inf_test_loader, sample_fraction=sample_fraction)
            metrics_flattened = {f"metric_{k}": round(v, 4) for k, v in metrics.items()}
            times_flattened = {f"time_{k}": round(v, 4) for k, v in times.items()}

            # Сохранение результатов
            new_result = {
                **params,
                "checkpoint": checkpoint,
                "epoch_num": int(i),
                "accuracy": accuracy,
                **metrics_flattened,
                **times_flattened,
                "early_stop_epoch": int(early_stop_epoch),
                "sample_fraction": sample_fraction
            }

            # Сохранение в CSV
            results = pd.DataFrame([new_result], columns=columns)

            if not os.path.exists(output_csv):  
                pd.DataFrame(columns=columns).to_csv(output_csv, mode="w", index=False, header=True)
            
            results.to_csv(output_csv, mode="a", header=False, index=False)

            del metrics, accuracy, new_result
            torch.cuda.empty_cache()
            gc.collect()

        for checkpoint in checkpoint_files:
            os.remove(checkpoint)

        del model
        del train_loader
        torch.cuda.empty_cache()
        gc.collect()
    logger.info(f"results for {sample_fraction=} complete")

print("Optimization complete!")

0.05


Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name               | Type                                  | Params | Mode 
-------------------------------------------------------------------------------------
0 | _loss              | ContrastiveLoss                       | 0      | train
1 | _seq_encoder       | MultiModalSortTimeSeqEncoderContainer | 14.0 M | train
2 | _validation_metric | BatchRecallTopK                       | 0      | train
3 | _head              | Head                                  | 0      | train
-------------------------------------------------------------------------------------
14.0 M    Trainable params
0         Non-trainable params
14.0 M    Total params
56.048    Total estimated model params size (MB)
27        Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                                                       | 0/?…

Training: |                                                                                              | 0/?…

Validation: |                                                                                            | 0/?…

Metric loss improved. New best score: 17.697


Validation: |                                                                                            | 0/?…

Metric loss improved by 0.888 >= min_delta = 0.0. New best score: 16.808


Validation: |                                                                                            | 0/?…

Metric loss improved by 2.969 >= min_delta = 0.0. New best score: 13.839


Validation: |                                                                                            | 0/?…

Validation: |                                                                                            | 0/?…

Validation: |                                                                                            | 0/?…

Validation: |                                                                                            | 0/?…

Validation: |                                                                                            | 0/?…

Monitored metric loss did not improve in the last 5 records. Best score: 13.839. Signaling Trainer to stop.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                            | 0/?…


📊 Средние значения метрик и время вычисления:
🧠 rankme                         = 31.3527 | ⏱ 0.0001 сек
🧠 coherence                      = 1.0000 | ⏱ 0.0000 сек
🧠 pseudo_condition_number        = 0.0039 | ⏱ 0.0000 сек
🧠 alpha_req                      = 3.5649 | ⏱ 0.0001 сек
🧠 stable_rank                    = 1.2144 | ⏱ 0.0000 сек
🧠 ne_sum                         = 7.5380 | ⏱ 0.0008 сек
🧠 self_clustering                = 0.6821 | ⏱ 0.0001 сек
🧠 ripser                         = nan | ⏱ nan сек
🧠 ripser_sum                     = 126.4459 | ⏱ 0.0122 сек
🧠 ripser_sum_H0                  = 124.1731 | ⏱ 0.0122 сек
🧠 ripser_sum_H1                  = 2.2728 | ⏱ 0.0122 сек


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                            | 0/?…


📊 Средние значения метрик и время вычисления:
🧠 rankme                         = 31.7480 | ⏱ 0.0001 сек
🧠 coherence                      = 1.0000 | ⏱ 0.0000 сек
🧠 pseudo_condition_number        = 0.0039 | ⏱ 0.0000 сек
🧠 alpha_req                      = 3.4903 | ⏱ 0.0001 сек
🧠 stable_rank                    = 1.1966 | ⏱ 0.0000 сек
🧠 ne_sum                         = 8.1196 | ⏱ 0.0008 сек
🧠 self_clustering                = 0.6993 | ⏱ 0.0001 сек
🧠 ripser                         = nan | ⏱ nan сек
🧠 ripser_sum                     = 153.3335 | ⏱ 0.0110 сек
🧠 ripser_sum_H0                  = 150.3634 | ⏱ 0.0110 сек
🧠 ripser_sum_H1                  = 2.9701 | ⏱ 0.0110 сек


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                            | 0/?…


📊 Средние значения метрик и время вычисления:
🧠 rankme                         = 32.9319 | ⏱ 0.0001 сек
🧠 coherence                      = 1.0000 | ⏱ 0.0000 сек
🧠 pseudo_condition_number        = 0.0041 | ⏱ 0.0000 сек
🧠 alpha_req                      = 3.3856 | ⏱ 0.0001 сек
🧠 stable_rank                    = 1.1977 | ⏱ 0.0000 сек
🧠 ne_sum                         = 9.0111 | ⏱ 0.0008 сек
🧠 self_clustering                = 0.6955 | ⏱ 0.0001 сек
🧠 ripser                         = nan | ⏱ nan сек
🧠 ripser_sum                     = 175.2336 | ⏱ 0.0105 сек
🧠 ripser_sum_H0                  = 171.5645 | ⏱ 0.0105 сек
🧠 ripser_sum_H1                  = 3.6690 | ⏱ 0.0105 сек


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                            | 0/?…


📊 Средние значения метрик и время вычисления:
🧠 rankme                         = 33.4202 | ⏱ 0.0001 сек
🧠 coherence                      = 1.0000 | ⏱ 0.0000 сек
🧠 pseudo_condition_number        = 0.0045 | ⏱ 0.0000 сек
🧠 alpha_req                      = 3.2757 | ⏱ 0.0001 сек
🧠 stable_rank                    = 1.1920 | ⏱ 0.0000 сек
🧠 ne_sum                         = 9.1144 | ⏱ 0.0008 сек
🧠 self_clustering                = 0.7014 | ⏱ 0.0001 сек
🧠 ripser                         = nan | ⏱ nan сек
🧠 ripser_sum                     = 183.6301 | ⏱ 0.0123 сек
🧠 ripser_sum_H0                  = 180.0147 | ⏱ 0.0123 сек
🧠 ripser_sum_H1                  = 3.6153 | ⏱ 0.0123 сек


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                            | 0/?…


📊 Средние значения метрик и время вычисления:
🧠 rankme                         = 33.9344 | ⏱ 0.0001 сек
🧠 coherence                      = 1.0000 | ⏱ 0.0000 сек
🧠 pseudo_condition_number        = 0.0044 | ⏱ 0.0000 сек
🧠 alpha_req                      = 3.1619 | ⏱ 0.0001 сек
🧠 stable_rank                    = 1.1849 | ⏱ 0.0000 сек
🧠 ne_sum                         = 9.3910 | ⏱ 0.0008 сек
🧠 self_clustering                = 0.7097 | ⏱ 0.0001 сек
🧠 ripser                         = nan | ⏱ nan сек
🧠 ripser_sum                     = 192.5663 | ⏱ 0.0109 сек
🧠 ripser_sum_H0                  = 188.4109 | ⏱ 0.0109 сек
🧠 ripser_sum_H1                  = 4.1554 | ⏱ 0.0109 сек


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                            | 0/?…


📊 Средние значения метрик и время вычисления:
🧠 rankme                         = 35.3055 | ⏱ 0.0001 сек
🧠 coherence                      = 1.0000 | ⏱ 0.0000 сек
🧠 pseudo_condition_number        = 0.0050 | ⏱ 0.0000 сек
🧠 alpha_req                      = 3.0453 | ⏱ 0.0001 сек
🧠 stable_rank                    = 1.1962 | ⏱ 0.0000 сек
🧠 ne_sum                         = 9.2820 | ⏱ 0.0008 сек
🧠 self_clustering                = 0.6959 | ⏱ 0.0001 сек
🧠 ripser                         = nan | ⏱ nan сек
🧠 ripser_sum                     = 200.5509 | ⏱ 0.0127 сек
🧠 ripser_sum_H0                  = 196.5061 | ⏱ 0.0127 сек
🧠 ripser_sum_H1                  = 4.0449 | ⏱ 0.0127 сек


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                            | 0/?…


📊 Средние значения метрик и время вычисления:
🧠 rankme                         = 34.5874 | ⏱ 0.0001 сек
🧠 coherence                      = 1.0000 | ⏱ 0.0000 сек
🧠 pseudo_condition_number        = 0.0048 | ⏱ 0.0000 сек
🧠 alpha_req                      = 3.0329 | ⏱ 0.0001 сек
🧠 stable_rank                    = 1.1816 | ⏱ 0.0000 сек
🧠 ne_sum                         = 9.0161 | ⏱ 0.0008 сек
🧠 self_clustering                = 0.7130 | ⏱ 0.0001 сек
🧠 ripser                         = nan | ⏱ nan сек
🧠 ripser_sum                     = 203.1643 | ⏱ 0.0114 сек
🧠 ripser_sum_H0                  = 199.1070 | ⏱ 0.0114 сек
🧠 ripser_sum_H1                  = 4.0572 | ⏱ 0.0114 сек


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]


Predicting: |                                                                                            | 0/?…


📊 Средние значения метрик и время вычисления:
🧠 rankme                         = 35.5268 | ⏱ 0.0001 сек
🧠 coherence                      = 1.0000 | ⏱ 0.0000 сек
🧠 pseudo_condition_number        = 0.0051 | ⏱ 0.0000 сек
🧠 alpha_req                      = 2.9750 | ⏱ 0.0001 сек
🧠 stable_rank                    = 1.1939 | ⏱ 0.0000 сек
🧠 ne_sum                         = 8.8178 | ⏱ 0.0008 сек
🧠 self_clustering                = 0.6983 | ⏱ 0.0001 сек
🧠 ripser                         = nan | ⏱ nan сек
🧠 ripser_sum                     = 211.2965 | ⏱ 0.0155 сек
🧠 ripser_sum_H0                  = 206.7317 | ⏱ 0.0155 сек
🧠 ripser_sum_H1                  = 4.5648 | ⏱ 0.0155 сек


Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1]

  | Name               | Type                                  | Params | Mode 
-------------------------------------------------------------------------------------
0 | _loss              | ContrastiveLoss                       | 0      | train
1 | _seq_encoder       | MultiModalSortTimeSeqEncoderContainer | 14.0 M | train
2 | _validation_metric | BatchRecallTopK                       | 0      | train
3 | _head              | Head                                  | 0      | train
-------------------------------------------------------------------------------------
14.0 M    Trainable params
0         Non-trainable params
14.0 M    Total params
56.048    Total estimated model params size (MB)
27        Modules in train mode
0         Modules in eval mode


Sanity Checking: |                                                                                       | 0/?…

Training: |                                                                                              | 0/?…

Validation: |                                                                                            | 0/?…

Metric loss improved. New best score: 19.845


Validation: |                                                                                            | 0/?…

Metric loss improved by 0.569 >= min_delta = 0.0. New best score: 19.276


Validation: |                                                                                            | 0/?…

Metric loss improved by 3.807 >= min_delta = 0.0. New best score: 15.469


Validation: |                                                                                            | 0/?…

Validation: |                                                                                            | 0/?…

Metric loss improved by 0.669 >= min_delta = 0.0. New best score: 14.800


Validation: |                                                                                            | 0/?…

Validation: |                                                                                            | 0/?…

Validation: |                                                                                            | 0/?…

Metric loss improved by 0.494 >= min_delta = 0.0. New best score: 14.306


Validation: |                                                                                            | 0/?…

Metric loss improved by 0.355 >= min_delta = 0.0. New best score: 13.950


Validation: |                                                                                            | 0/?…

Metric loss improved by 0.178 >= min_delta = 0.0. New best score: 13.773


Validation: |                                                                                            | 0/?…

Validation: |                                                                                            | 0/?…

Validation: |          | 0/? [00:00<?, ?it/s]

## Try Optuna and recall TopK

In [None]:
# import optuna
# import pandas as pd
# import torch
# import gc
# from tqdm import tqdm

# # Файл для сохранения результатов
# output_csv = "optuna_hyperparameter_by_topK.csv"

# # Определение диапазонов гиперпараметров
# def define_search_space(trial):
#     return {
#         "batch_size": trial.suggest_categorical("batch_size", [32, 64]),
#         "learning_rate": trial.suggest_loguniform("learning_rate", 1e-4, 5e-2),
#         "hidden_size": trial.suggest_categorical("hidden_size", [64, 128]),
#         "embedding_dim": trial.suggest_categorical("embedding_dim", [8, 16, 32]),
#         "category_embedding_dim": trial.suggest_categorical("category_embedding_dim", [4, 8, 16]),
#         "split_count": trial.suggest_categorical("split_count", [2, 3, 5]),
#         "cnt_min": trial.suggest_categorical("cnt_min", [5, 10, 20]),
#         "cnt_max": trial.suggest_categorical("cnt_max", [50, 80, 100]),
#     }

# # Метрики, по которым проводится подбор гиперпараметров
# metric_names = [
#     "rankme", "coherence", "pseudo_condition_number",
#     "alpha_req", "stable_rank", "ne_sum", "self_clustering", "ripser"
# ]

# # Список для хранения всех результатов
# optuna_results = []

In [None]:
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [None]:
# from ptls.data_load.padded_batch import PaddedBatch


# class CustomCoLESModule(CoLESModule):
#     def __init__(self, custom_metric_name, *args, **kwargs):
#         super().__init__(*args, **kwargs)
#         self.custom_metric_name = custom_metric_name
#         # model.to(device)


#     def validation_step(self, batch, batch_idx):
#         print("valedation step")
#         x, y = batch
#         y = y.to(self.device)
#         for key in x:
#             if isinstance(x[key], PaddedBatch):
#                 # 🔄 Создаем новый PaddedBatch с seq_lens и payload на нужном устройстве
#                 x[key] = PaddedBatch(
#                     payload={k: v.to(self.device) for k, v in x[key].payload.items()},
#                     length=x[key].length.to(self.device)  # ⚠️ Используем length вместо seq_lens
#                 )
#             else:
#                 print(f"⚠️ [WARNING] Expected PaddedBatch but got {type(x[key])} for {key}")

#         print(f"Model is on device: {next(self.parameters()).device}")
#         print(f"x is on device: {[x[k].device for k in x]}")
#         print(f"y is on device: {y.device}")

#         y_hat = self(x)

#         # Compute loss (assuming classification task)
#         loss = torch.nn.functional.cross_entropy(y_hat, y)

#         metric_value, _, _ = compute_metrics(model, pl_trainer, inf_test_loader, selected_metrics=[self.custom_metric_name])

#         print(f"[DEBUG] Logging metric: valid/{self.custom_metric_name} = {metric_value[self.custom_metric_name]}")

#         self.trainer.logger.log_metrics({f"valid/{self.custom_metric_name}": metric_value[self.custom_metric_name]}, step=self.current_epoch)

#         return {"loss": loss, self.custom_metric_name: metric_value[self.custom_metric_name]}

In [None]:
! rm -rf /kaggle/working/checkpoints

In [None]:
! rm /kaggle/working/optuna_best_trials_accuracy.csv

In [None]:
# import optuna
# import time
# import pandas as pd
# import torch
# import gc
# import os
# import glob
# from functools import partial
# from pytorch_lightning import Trainer
# from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
# from tqdm import tqdm
# from time import time

# # Файл для логирования результатов
# output_csv = "optuna_results.csv"

# # Метрики для проверки
# metric_names = [
#     "rankme", "coherence", "pseudo_condition_number",
#     "alpha_req", "stable_rank", "ne_sum", "self_clustering", "ripser"
# ]

# optuna_columns = [
#     *fixed_params.keys(), "checkpoint", "epoch_num", "accuracy", "topk_accuracy", "early_stop_epoch", "hidden_size",
# ] + [
#     "metric_" + key for key in [
#         "rankme", "coherence", "pseudo_condition_number", 
#         "alpha_req", "stable_rank", "ne_sum", "self_clustering", "ripser"
#     ]
# ] + ["time_" + key for key in [
#     "rankme", "coherence", "pseudo_condition_number", 
#     "alpha_req", "stable_rank", "ne_sum", "self_clustering", "ripser"
# ]]


# def objective(trial):
#     # print(f'dealing with metric {metric_name}')
#     torch.cuda.empty_cache()
#     gc.collect()

#     params = define_search_space(trial)

#     # === Dataset ===
#     data_module = create_datasets(train_dict, valid_dict, params, source_features)

#     # === Encoders ===
#     sourceA_encoder = TrxEncoder(
#         embeddings_noise=0.003,
#         linear_projection_size=64,
#         embeddings={
#             "mcc_code": {"in": mcc_code_in, "out": 32},
#             "term_id": {"in": term_id_in, "out": 32},
#         },
#     )

#     sourceB_encoder = TrxEncoder(
#         embeddings_noise=0.003,
#         linear_projection_size=64,
#         embeddings={
#             "tr_type": {"in": tr_type_in, "out": 32},
#         },
#         numeric_values={"amount": "identity"},
#     )

#     seq_encoder = MultiModalSortTimeSeqEncoderContainer(
#         trx_encoders={"sourceA": sourceA_encoder, "sourceB": sourceB_encoder},
#         input_size=64,
#         hidden_size=params["hidden_size"],
#         seq_encoder_cls=RnnEncoder,
#         type="gru",
#     )

#     model = CoLESModule(
#         seq_encoder=seq_encoder,
#         optimizer_partial=partial(torch.optim.Adam, lr=params["learning_rate"]),
#         lr_scheduler_partial=partial(torch.optim.lr_scheduler.StepLR, step_size=10, gamma=0.5),
#     )

#     # === Callbacks ===
#     early_stopping_callback = EarlyStopping(
#         monitor=f"valid/recall_top_k", patience=5, mode="max", verbose=True
#     )

#     checkpoint_callback = ModelCheckpoint(
#         dirpath=checkpoints_path,
#         filename=f"model_optuna_trial_{trial.number}_epoch={{epoch:02d}}",
#         save_top_k=-1,
#         monitor="valid/recall_top_k",
#         mode="max",
#     )

#     trainer = pl.Trainer(
#         callbacks=[checkpoint_callback, early_stopping_callback, custom_logger],
#         default_root_dir=checkpoints_path,
#         check_val_every_n_epoch=1,
#         max_epochs= 1, # num_epochs,
#         accelerator="gpu",
#         devices=1,
#         enable_progress_bar=True,
#         precision=16,
#     )

#     trainer.fit(model, datamodule=data_module)

#     early_stop_epoch = custom_logger.early_stopping_epoch or num_epochs

#     # === Evaluate checkpoints ===
#     checkpoint_files = sorted(
#         glob.glob(f"{checkpoints_path}/model_optuna_trial_{trial.number}_epoch=*.ckpt")
#     )

#     best_acc = float("-inf")

#     for i, checkpoint in enumerate(checkpoint_files):
#         model = CoLESModule.load_from_checkpoint(checkpoint, seq_encoder=seq_encoder)
#         metrics, times, acc = evaluate_model(model, trainer, inf_train_loader)

#         metrics_flattened = {f"metric_{k}": round(v, 4) for k, v in metrics.items()}
#         times_flattened = {f"time_{k}": round(v, 4) for k, v in times.items()}
#         trainer = pl.Trainer(accelerator="gpu", devices=1)

#         val_metrics = trainer.validate(model=model, datamodule=data_module)
#         recall_top_k = val_metrics[0].get("valid/recall_top_k", None)
#         result = {
#             **params,
#             "checkpoint": checkpoint,
#             "epoch_num": i,
#             "accuracy": acc,
#             "topk_accuracy": recall_top_k,
#             **metrics_flattened,
#             **times_flattened,
#             "early_stop_epoch": early_stop_epoch,
#         }

#         results_df = pd.DataFrame([result], columns=optuna_columns)

#         if not os.path.exists(output_csv):
#             pd.DataFrame(columns=optuna_columns).to_csv(output_csv, index=False, header=True)
#         results_df.to_csv(output_csv, mode="a", header=False, index=False)

#         # current_metric_value = metrics.get(metric_name, float("-inf"))
#         best_acc = max(best_acc, acc)

#         del model, result, metrics
#         torch.cuda.empty_cache()
#         gc.collect()

#     for ckpt in checkpoint_files:
#         os.remove(ckpt)

#     return best_acc


In [None]:
# import optuna
# import pandas as pd
# import os
# from time import time
# from functools import partial

# # Параметры для Optuna
# num_trials = 10
# cur_time = time()

# # Путь к файлу для сохранения лучших результатов
# best_trials_csv = "optuna_best_trials_accuracy.csv"

# # Если файла ещё нет, создаём его
# if not os.path.exists(best_trials_csv):
#     pd.DataFrame(columns=["value", *fixed_params.keys()]).to_csv(best_trials_csv, index=False)

# # Теперь оптимизируем только по accuracy
# study = optuna.create_study(direction="maximize")  # Оптимизируем именно accuracy!
# study.optimize(objective, n_trials=num_trials)     # objective теперь должна возвращать 1-accuracy или -accuracy

# # Достаем лучший результат
# best_trial = study.best_trial
# best_result = {
#     "value": best_trial.value,
#     **best_trial.params
# }

# # Сохраняем лучший результат
# df_best = pd.DataFrame([best_result])
# df_best.to_csv(best_trials_csv, mode="a", header=False, index=False)

# # Логи
# print(f"✅ Optimization completed (direction: maximize)")
# print(f"⏱️ Time passed: {time() - cur_time:.2f} sec")
# print(f"🥇 Best trial value: {best_trial.value}")
# print(f"📊 Params: {best_trial.params}")


## Eval model with best hyperparams

In [None]:
input_csv = "/kaggle/input/gender-tr-best-params/gender_tr_optuna_best_params.csv"
best_trials_df = pd.read_csv(input_csv)

In [None]:
best_trials_df.reset_index(inplace=True)
best_trials_df.rename(columns={"index": "metric", "metric":"value",
                              "value":"batch_size", "batch_size":"learning_rate", "learning_rate":"hidden_size"}, inplace=True)

In [None]:
best_trials_df.rename(columns={"cnt_min": "embedding_dim", "embedding_dim":"cnt_min",
                              "category_embedding_dim":"cnt_max", "cnt_max":"category_embedding_dim"}, inplace=True)

In [None]:
best_trials_df

In [None]:
! rm -rf /kaggle/working/checkpoints

In [None]:
! rm /kaggle/working/optuna_best_metrics_eval.csv

In [None]:
# import pandas as pd
# import torch
# import gc
# import os
# import glob
# from functools import partial
# from time import time
# from pytorch_lightning import Trainer
# from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
# from itertools import islice



# checkpoints_path = "checkpoints"
# os.makedirs(checkpoints_path, exist_ok=True)


# # Для хранения результатов
# columns = list(best_trials_df.columns) + [
#     "checkpoint", "epoch_num", "accuracy", "early_stop_epoch"
# ] + [f"metric_{m}" for m in metric_names] + [f"time_{m}" for m in metric_names]

# for idx, row in islice(best_trials_df.iterrows(), 2, None):
#     metric_name = row["metric"]
#     print(f"\n=== Processing best params for metric: {metric_name} ===")
#     output_csv = f"optuna_best_metrics_eval_{metric_name}.csv"

#     print(row)
    
#     # Собираем параметры
#     params = {
#         "batch_size": int(row["batch_size"]),
#         "learning_rate": float(row["learning_rate"]),
#         "split_count": int(row["split_count"]),
#         "cnt_min": int(row["cnt_min"]),
#         "cnt_max": int(row["cnt_max"]),
#         "embedding_dim": int(row["embedding_dim"]),
#         "category_embedding_dim": int(row["category_embedding_dim"]),
#         "hidden_size": int(row["hidden_size"]),  # добавь сюда если будет в CSV
#     }

#     # Загружаем датасет
#     train_loader = create_datasets(train_dict, valid_dict, params, source_features)

#     # Создаём энкодеры
#     sourceA_encoder = TrxEncoder(
#         embeddings={"mcc_code": {"in": mcc_code_in, "out": 32}, "term_id": {"in": term_id_in, "out": 32}},
#         embeddings_noise=0.003,
#         linear_projection_size=64,
#     )
#     sourceB_encoder = TrxEncoder(
#         embeddings={"tr_type": {"in": tr_type_in, "out": 32}},
#         numeric_values={"amount": "identity"},
#         embeddings_noise=0.003,
#         linear_projection_size=64,
#     )

#     seq_encoder = MultiModalSortTimeSeqEncoderContainer(
#         trx_encoders={"sourceA": sourceA_encoder, "sourceB": sourceB_encoder},
#         input_size=64,
#         hidden_size=params["hidden_size"],
#         seq_encoder_cls=RnnEncoder,
#         type="gru",
#     )

#     model = CoLESModule(
#         seq_encoder=seq_encoder,
#         optimizer_partial=partial(torch.optim.Adam, lr=params["learning_rate"]),
#         lr_scheduler_partial=partial(torch.optim.lr_scheduler.StepLR, step_size=10, gamma=0.5),
#     )

#     early_stopping_callback = EarlyStopping(
#         monitor="loss", patience=5, mode="min", verbose=True
#     )
#     checkpoint_callback = ModelCheckpoint(
#         dirpath=checkpoints_path,
#         filename=f"best_{metric_name}_trial_{idx}_epoch={{epoch:02d}}",
#         save_top_k=-1,
#         every_n_epochs=1,
#     )

#     trainer = Trainer(
#         callbacks=[checkpoint_callback, early_stopping_callback],
#         default_root_dir=checkpoints_path,
#         check_val_every_n_epoch=1,
#         max_epochs=30,
#         accelerator="gpu",
#         devices=1,
#         enable_progress_bar=True,
#         precision=16
#     )

#     # Обучение
#     trainer.fit(model, train_loader)
#     early_stop_epoch = getattr(trainer.logger, "early_stopping_epoch", None) or num_epochs

#     # Вычисление метрик
#     checkpoint_files = sorted(
#         glob.glob(f"{checkpoints_path}/best_{metric_name}_trial_{idx}_epoch=*.ckpt")
#     )
#     model.cpu()
#     del model
#     torch.cuda.empty_cache()

#     for i, checkpoint in enumerate(checkpoint_files):
#         print(f"Evaluating checkpoint #{i}")
#         model = CoLESModule.load_from_checkpoint(checkpoint, seq_encoder=seq_encoder)
#         metrics, times, accuracy = evaluate_model(model, trainer)

#         row_result = {
#             **params,
#             "metric": metric_name,
#             "checkpoint": checkpoint,
#             "epoch_num": i,
#             "accuracy": accuracy,
#             "early_stop_epoch": early_stop_epoch,
#             **{f"metric_{k}": round(v, 4) for k, v in metrics.items()},
#             **{f"time_{k}": round(v, 4) for k, v in times.items()}
#         }

#         # Сохраняем результат
#         result_df = pd.DataFrame([row_result], columns=columns)
#         if not os.path.exists(output_csv):
#             pd.DataFrame(columns=columns).to_csv(output_csv, index=False)
#         result_df.to_csv(output_csv, mode="a", index=False, header=False)

#         del model, result_df
#         torch.cuda.empty_cache()
#         gc.collect()

#     # Удаление чекпоинтов
#     for ckpt in checkpoint_files:
#         os.remove(ckpt)
#     del trainer, train_loader, seq_encoder
#     torch.cuda.empty_cache()
#     gc.collect()

# print("✅ Evaluation of best params complete.")

In [None]:
# import os

# # Загружаем сохраненные результаты, если файл уже существует
# if os.path.exists(output_csv):
#     print('exists')
#     processed_df = pd.read_csv(output_csv)
    
#     processed_params = [
#     tuple(row[["batch_size", "learning_rate", "split_count", "cnt_min", "cnt_max", "embedding_dim", "category_embedding_dim", "hidden_size"]])
#     for _, row in processed_df.iterrows()
# ]

# else:
#     processed_params = set()
#     print("don't exists")

In [None]:
# remaining_hyperparameter_grids = []

# for variable_param_name, hyperparameter_grid in all_hyperparameter_grids:
#     # Преобразуем hyperparameter_grid в кортеж значений (только значимые параметры)
#     param_values_tuple = tuple(hyperparameter_grid.values())  # Без сортировки!
#     if param_values_tuple not in processed_params:
#         print(param_values_tuple)
#         remaining_hyperparameter_grids.append((variable_param_name, hyperparameter_grid))
#     else:
#         print(1)

# print(f"Remaining hyperparameter sets to process: {len(remaining_hyperparameter_grids)}")


In [None]:
# # Запускаем обучение только для оставшихся гиперпараметров
# for variable_param_name, params in remaining_hyperparameter_grids:
#     print(f'variable param is {variable_param_name}')
#     print(f"Processing parameters: {params}")

#     train_loader = create_datasets(train_dict, valid_dict, params, source_features)

#     sourceA_encoder_params = dict(
#         embeddings_noise=0.003,
#         linear_projection_size=64,
#         embeddings={
#             "mcc_code": {"in": mcc_code_in, "out": 32},
#             "term_id": {"in": term_id_in, "out": 32},
#         },
#     )

#     sourceB_encoder_params = dict(
#         embeddings_noise=0.003,
#         linear_projection_size=64,
#         embeddings={
#             "tr_type": {"in": tr_type_in, "out": 32},
#         },
#         numeric_values={"amount": "identity"},
#     )

#     sourceA_encoder = TrxEncoder(**sourceA_encoder_params)
#     sourceB_encoder = TrxEncoder(**sourceB_encoder_params)

#     seq_encoder = MultiModalSortTimeSeqEncoderContainer(
#         trx_encoders={
#             "sourceA": sourceA_encoder,
#             "sourceB": sourceB_encoder,
#         },
#         input_size=64,
#         hidden_size=params["hidden_size"],
#         seq_encoder_cls=RnnEncoder,
#         type="gru",
#     )

#     model = CoLESModule(
#         seq_encoder=seq_encoder,
#         optimizer_partial=partial(torch.optim.Adam, lr=params["learning_rate"]),
#         lr_scheduler_partial=partial(torch.optim.lr_scheduler.StepLR, step_size=10, gamma=0.5),
#     )

#     early_stopping_callback = EarlyStopping(
#         monitor="loss",  # Следим за валидационным лоссом
#         patience=5,  # Количество эпох без улучшения перед остановкой
#         mode="min",  # Нужно минимизировать loss
#         verbose=True
#     )

#     checkpoint_callback = ModelCheckpoint(
#         dirpath=checkpoints_path,
#         filename=f"model_{params['batch_size']}_{params['learning_rate']}_{params['split_count']}_{params['cnt_min']}_{params['cnt_max']}_{params['hidden_size']}{{epoch:02d}}",
#         save_top_k=-1,
#         every_n_epochs=1,
#     )

#     # Обучение модели
#     pl_trainer = pl.Trainer(
#         callbacks=[checkpoint_callback, early_stopping_callback, custom_logger],
#         default_root_dir=checkpoints_path,
#         check_val_every_n_epoch=1,
#         max_epochs=num_epochs,
#         accelerator="gpu",
#         devices=1,
#         enable_progress_bar=True,
#         precision=16
#     )
#     model.train()
#     pl_trainer.fit(model, train_loader)

#     early_stop_epoch = custom_logger.early_stopping_epoch
#     if early_stop_epoch is None:
#         early_stop_epoch = num_epochs

#     # Обработка чекпоинтов
#     checkpoint_files = glob.glob(f"{checkpoints_path}/model_{params['batch_size']}_{params['learning_rate']}_{params['split_count']}_{params['cnt_min']}_{params['cnt_max']}_{params['hidden_size']}*.ckpt")
#     checkpoint_files.sort()
#     print(f"Elapsed time: {time() - cur_time:.2f} seconds")

#     print(f'Early stop is {early_stop_epoch}')

#     for i, checkpoint in enumerate(checkpoint_files):
#         print(f"Processing checkpoint number {i}")
#         model = CoLESModule.load_from_checkpoint(checkpoint, seq_encoder=seq_encoder)

#         metrics, times, accuracy = evaluate_model(model)
#         metrics_flattened = {f"metric_{k}": round(v, 4) for k, v in metrics.items()}
#         times_flattened = {f"time_{k}": round(v, 4) for k, v in times.items()}

#         new_result = {
#             **params,
#             "checkpoint": checkpoint,
#             "epoch_num": int(i),
#             "accuracy": accuracy,
#             **metrics_flattened,
#             **times_flattened,
#             "early_stop_epoch": int(early_stop_epoch)
#         }

#         new_result["epoch_num"] = int(new_result["epoch_num"])

#         results = pd.DataFrame([new_result], columns=columns)
#         print('----------')
#         print(results)

#         if not os.path.exists(output_csv):  # Проверяем, существует ли уже файл
#             pd.DataFrame(columns=columns).to_csv(output_csv, mode="w", index=False, header=True)
        
#         results.to_csv(output_csv, mode="a", header=False, index=False)

#         del metrics, accuracy, new_result
#         torch.cuda.empty_cache()
#         gc.collect()

#     print(f"Removing checkpoints for parameters: {params}")
#     for checkpoint in checkpoint_files:
#         os.remove(checkpoint)

#     del model
#     del train_loader
#     torch.cuda.empty_cache()
#     gc.collect()

# print("Optimization complete!")

In [None]:
# df_results = pd.DataFrame(results)
# df_results.to_csv("CoLES_hyperparameter_results.csv", index=False)
# print("Results saved to hyperparameter_results.csv")

### Calc correlation

In [None]:
hyperparameters = ["batch_size", "learning_rate", "split_count", "cnt_min", "cnt_max"]

# Корреляция между RankMe и accuracy
rankme_accuracy_corr, rankme_accuracy_pval = stats.pearsonr(df_results["rankme"], df_results["accuracy"])
print(f"Correlation between RankMe and Accuracy: {rankme_accuracy_corr:.4f}")
print(f"P-value: {rankme_accuracy_pval:.4e}\n")

# Корреляция между RankMe и каждым гиперпараметром
for param in hyperparameters:
    corr, pval = stats.pearsonr(df_results["rankme"], df_results[param])
    print(f"Correlation between RankMe and {param}: {corr:.4f}")
    print(f"P-value: {pval:.4e}\n")

### Plot smt

In [None]:
hyperparameters = ["batch_size", "learning_rate", "split_count", "cnt_min", "cnt_max"]

# Построение графиков для accuracy
for param in hyperparameters:
    plt.figure(figsize=(8, 5))
    plt.scatter(df_results[param], df_results["accuracy"], alpha=0.7, label="Accuracy")
    plt.xlabel(param)
    plt.ylabel("Accuracy")
    plt.title(f"Accuracy vs {param}")
    plt.grid(True)
    plt.legend()
    plt.show()

In [None]:
for param in hyperparameters:
    plt.figure(figsize=(8, 5))
    plt.scatter(df_results[param], df_results["rankme"], alpha=0.7, label="RankMe", color="orange")
    plt.xlabel(param)
    plt.ylabel("RankMe")
    plt.title(f"RankMe vs {param}")
    plt.grid(True)
    plt.legend()
    plt.show()