# Cluster Loss

In this notebook we cover how to use the clustering module for unsupervised pretraining of embeddings.

In [1]:
import os
import numpy as np
import pandas as pd
import torch
import ptls
import pytorch_lightning as pl
from ptls.data_load.datasets import ParquetFiles, ParquetDataset
from ptls.frames import PtlsDataModule
from ptls.frames.coles import ClusterIterableDataset, ClusterCallback, ClusterModule
from ptls.frames.supervised import SeqToTargetIterableDataset
from ptls.frames.coles.split_strategy import SampleSlices
from ptls.frames.coles.losses import ClusterLoss, ContrastiveLoss, ClusterAndContrastive
from functools import partial
from lightgbm import LGBMClassifier
from sklearn.metrics import roc_auc_score
from itertools import chain
from ptls.frames.coles.sampling_strategies import HardNegativePairSelector

  warn(
2024-09-02 04:36:30.303695: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
  torch.utils._pytree._register_pytree_node(


First we need to define a dataset. The clustering dataset requires a column with a datapoint id so that each view of a datapoint can be mapped to its corresponding cluster centroid. If the dataset does not have such a column, we must add one.

In [2]:
# Root folder of the synthetic example dataset (the trailing slash is relied on below).
path = "./syndata/example_data/"

# Load a single training shard and show its first rows to inspect the columns.
sample_df = pd.read_parquet(f"{path}train/train_0.parquet")
sample_df.head()

Unnamed: 0,A,B,C,event_time,class_label
0,"[21, 40, 7, 60, 34, 19, 24, 2, 23, 63, 60, 34,...","[26, 22, 51, 26, 17, 12, 32, 1, 12, 39, 60, 38...","[31, 57, 10, 21, 41, 10, 22, 48, 2, 21, 45, 42...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...",0
1,"[57, 11, 24, 2, 17, 13, 45, 45, 45, 43, 24, 0,...","[18, 23, 57, 9, 12, 37, 44, 34, 17, 15, 59, 30...","[61, 43, 28, 37, 40, 2, 22, 53, 41, 9, 9, 14, ...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...",1
2,"[37, 47, 60, 37, 44, 37, 41, 13, 44, 37, 40, 3...","[47, 61, 40, 7, 57, 14, 49, 15, 57, 12, 32, 1,...","[24, 2, 19, 29, 41, 15, 56, 2, 20, 32, 2, 20, ...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...",0
3,"[43, 30, 50, 16, 2, 23, 60, 38, 49, 13, 44, 32...","[26, 17, 8, 7, 61, 43, 31, 61, 44, 33, 15, 61,...","[0, 5, 41, 8, 6, 53, 41, 8, 6, 53, 45, 47, 61,...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...",0
4,"[38, 54, 53, 45, 41, 11, 29, 45, 41, 8, 5, 42,...","[51, 25, 11, 27, 30, 49, 9, 8, 2, 23, 63, 59, ...","[63, 59, 26, 17, 13, 44, 34, 17, 15, 63, 61, 4...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...",1


In [3]:
def add_global_index(data_files, src_root, dst_root, start_index=0, col_name='col_id'):
    """Write copies of parquet files with an extra column of consecutive integer ids.

    Each file in ``data_files`` is read, given a ``col_name`` column whose values
    continue the running counter, and written under ``dst_root`` at the same
    path it had relative to ``src_root``.

    Args:
        data_files: iterable of parquet file paths located under ``src_root``.
        src_root: directory the input paths are taken relative to.
        dst_root: directory the indexed copies are written to (created if missing).
        start_index: id assigned to the first row of the first file.
        col_name: name of the id column to add.

    Returns:
        The next unused id, i.e. ``start_index`` plus the total number of rows written.
    """
    next_index = start_index
    for file_path in data_files:
        frame = pd.read_parquet(file_path)
        n_rows = len(frame)
        # Number rows by position so the ids do not depend on the stored DataFrame index.
        frame[col_name] = np.arange(next_index, next_index + n_rows, dtype=np.int64)
        next_index += n_rows

        # Mirror the file's location under the new root instead of patching a
        # fixed component of the split path.
        dst_path = os.path.join(dst_root, os.path.relpath(file_path, src_root))
        os.makedirs(os.path.dirname(dst_path), exist_ok=True)
        frame.to_parquet(dst_path)
    return next_index


train_files = ParquetFiles(os.path.join(path, "train"))
eval_files = ParquetFiles(os.path.join(path, "eval"))

# Indexed copies go to a sibling folder: "<data root>_with_ind".
indexed_path = path.rstrip('/') + '_with_ind'

# Let's add a unique index to every row in the dataset.
# Train rows receive ids 0 .. n_train-1, so the id -> position mapping is the identity.
global_index = add_global_index(train_files.data_files, path, indexed_path)
train_idx_dict = {i: i for i in range(global_index)}

# Eval rows continue the same counter, so their ids start right after the train ids.
train_size = global_index
global_index = add_global_index(eval_files.data_files, path, indexed_path, start_index=train_size)
eval_index = global_index - train_size

# Map the global eval ids back to positions 0 .. n_eval-1.
eval_idx_dict = {train_size + i: i for i in range(eval_index)}

"train_idx_dict" and "eval_idx_dict" are needed to map row ids in case the ids are not integers ranging from 0 to len(dataset)-1.

In [4]:
# Folder that holds the copies of the data with the added `col_id` column.
path = "./syndata/example_data_with_ind/"

# Show the first rows of one indexed shard to check that `col_id` is present.
sample_df = pd.read_parquet(f"{path}train/train_10.parquet")
sample_df.head()

Unnamed: 0,A,B,C,event_time,class_label,col_id
0,"[32, 3, 29, 47, 57, 10, 23, 61, 42, 22, 49, 10...","[58, 20, 37, 45, 42, 18, 19, 25, 13, 42, 22, 5...","[3, 28, 39, 60, 39, 60, 39, 56, 3, 28, 38, 51,...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...",1,46080
1,"[29, 45, 45, 45, 45, 45, 45, 44, 32, 3, 30, 49...","[32, 5, 42, 19, 30, 51, 26, 16, 7, 61, 42, 19,...","[63, 62, 54, 54, 54, 54, 54, 54, 54, 54, 54, 5...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...",1,46081
2,"[11, 30, 48, 4, 34, 19, 24, 4, 34, 23, 62, 48,...","[14, 54, 53, 41, 14, 52, 34, 22, 53, 40, 5, 43...","[12, 36, 36, 38, 48, 3, 28, 33, 12, 33, 12, 33...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...",1,46082
3,"[26, 22, 53, 45, 47, 56, 7, 63, 60, 32, 7, 56,...","[10, 22, 52, 39, 61, 47, 59, 25, 14, 50, 21, 4...","[3, 24, 3, 24, 3, 24, 3, 24, 3, 31, 63, 56, 3,...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...",1,46083
4,"[59, 24, 7, 58, 22, 52, 39, 59, 26, 19, 28, 37...","[16, 0, 2, 17, 11, 28, 39, 56, 2, 21, 42, 16, ...","[10, 18, 18, 18, 18, 18, 23, 62, 48, 5, 43, 31...","[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...",1,46084


In [5]:
path = "./syndata/example_data_with_ind/"

# Parquet shards of the indexed data; only the training files are shuffled.
train_files = ParquetFiles(os.path.join(path, "train"))
eval_files = ParquetFiles(os.path.join(path, "eval"))
train_dataset = ParquetDataset(train_files, shuffle_files=True)
eval_dataset = ParquetDataset(eval_files)


def _make_cluster_dataset(dataset, idx_dict):
    """Wrap `dataset` in a ClusterIterableDataset using the settings shared by both splits.

    Args:
        dataset: the ParquetDataset to wrap.
        idx_dict: mapping from the values of the `col_id` column to row positions.
    """
    return ClusterIterableDataset(
        dataset,
        splitter=SampleSlices(split_count=5, cnt_min=50, cnt_max=100),
        col_time='event_time',
        col_idx='col_id',
        idx_dict=idx_dict,
    )


# Train and validation use identical splitter settings and differ only in data and id mapping.
cluster_datamodule = PtlsDataModule(
    train_data=_make_cluster_dataset(train_dataset, train_idx_dict),
    valid_data=_make_cluster_dataset(eval_dataset, eval_idx_dict),
    train_num_workers=4,
    train_batch_size=512,
    valid_num_workers=4,
    valid_batch_size=512,
)

Now let's create the model.

In [6]:
# Index of the GPU; the same value is handed to the clustering callback and to the trainer.
device = 0

# Numbers of clusters we will seek in the data. More than one clustering may be requested.
# Defined once because ClusterCallback and ClusterModule must be given the same values.
NUM_CLUSTERS = (10, 100)


# This callback does the clustering operation every epoch or every given number of steps.
# This object must be passed to the pl trainer.
cluster_callback = ClusterCallback(
    cluster_datamodule=cluster_datamodule,
    num_cluster=list(NUM_CLUSTERS),  # separate list copy, so the module's list is not shared
    temperature=0.5,  # variance modifier
    device=device,  # must be int!
    use_portion_to_train=None,  # If set to None will use all available data to do clusterization
                                # (alt. float <0;1>)
    run_each_n_train_steps=None,  # If set to None will run at the beginning of each epoch
)


# We will use both clustering and contrastive loss.
loss = ClusterAndContrastive(
    cluster_loss=ClusterLoss(),
    contrastive_loss=ContrastiveLoss(
        margin=1.,
        sampling_strategy=HardNegativePairSelector(neg_count=5),
    ),
    cluster_weight=1.,  # weights of losses
    contrastive_weight=1.,
)


# The model itself. Only the categorical features 'A' and 'B' are embedded (16 dims each).
trx_conf = {
    'embeddings_noise': 0.001,
    'embeddings': {
        'A': {'in': 64, 'out': 16},
        'B': {'in': 64, 'out': 16},
    },
}

cluster_module = ClusterModule(
    seq_encoder=ptls.nn.RnnSeqEncoder(
        trx_encoder=ptls.nn.TrxEncoder(**trx_conf),
        input_size=32,  # 16 + 16: the two embedding sizes configured in trx_conf
        type='gru',
        hidden_size=32,
        is_reduce_sequence=True,
    ),
    head=ptls.nn.Head(use_norm_encoder=True),
    loss=loss,
    optimizer_partial=partial(torch.optim.Adam, lr=1e-3),
    lr_scheduler_partial=partial(torch.optim.lr_scheduler.StepLR, step_size=30, gamma=0.9025),
    num_cluster=list(NUM_CLUSTERS),  # same values as in the callback
)

In [7]:
# Train for 10 epochs on the GPU selected by `device`; the clustering callback
# is attached to the trainer here.
trainer = pl.Trainer(
    callbacks=[cluster_callback],
    enable_progress_bar=True,
    max_epochs=10,
    gpus=[device],
)
trainer.fit(model=cluster_module, datamodule=cluster_datamodule)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name         | Type                  | Params
-------------------------------------------------------
0 | _loss        | ClusterAndContrastive | 0     
1 | _seq_encoder | RnnSeqEncoder         | 8.4 K 
2 | _head        | Head                  | 0     
-------------------------------------------------------
8.4 K     Trainable params
0         Non-trainable params
8.4 K     Total params
0.034     Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Clustering 51200 points in 32D to 10 clusters, redo 5 times, 20 iterations
  Preprocessing in 0.00 s
Outer iteration 0 / 5
  Iteration 19 (0.13 s, search 0.12 s): objective=31571.2 imbalance=1.008 nsplit=0       
Objective improved: keep new clusters
Outer iteration 1 / 5
  Iteration 19 (0.26 s, search 0.23 s): objective=31543.5 imbalance=1.020 nsplit=0       
Objective improved: keep new clusters
Outer iteration 2 / 5
  Iteration 19 (0.39 s, search 0.34 s): objective=31528.9 imbalance=1.015 nsplit=0       
Objective improved: keep new clusters
Outer iteration 3 / 5
  Iteration 19 (0.52 s, search 0.46 s): objective=31644 imbalance=1.020 nsplit=0         
Outer iteration 4 / 5
  Iteration 19 (0.66 s, search 0.58 s): objective=31622.8 imbalance=1.029 nsplit=0       
Clustering 51200 points in 32D to 100 clusters, redo 5 times, 20 iterations
  Preprocessing in 0.00 s
Outer iteration 0 / 5
  Iteration 19 (0.13 s, search 0.12 s): objective=21950.6 imbalance=1.055 nsplit=0       
Objective i

Training: 0it [00:00, ?it/s]


Clustering 256000 points in 32D to 10 clusters, redo 5 times, 20 iterations
  Preprocessing in 0.00 s
Outer iteration 0 / 5
  Iteration 19 (0.73 s, search 0.57 s): objective=158474 imbalance=1.011 nsplit=0       
Objective improved: keep new clusters
Outer iteration 1 / 5
  Iteration 19 (1.48 s, search 1.14 s): objective=158344 imbalance=1.008 nsplit=0       
Objective improved: keep new clusters
Outer iteration 2 / 5
  Iteration 19 (2.21 s, search 1.71 s): objective=159369 imbalance=1.029 nsplit=0       
Outer iteration 3 / 5
  Iteration 19 (2.94 s, search 2.28 s): objective=158298 imbalance=1.020 nsplit=0       
Objective improved: keep new clusters
Outer iteration 4 / 5
  Iteration 19 (3.66 s, search 2.84 s): objective=157996 imbalance=1.022 nsplit=0       
Objective improved: keep new clusters
Clustering 256000 points in 32D to 100 clusters, redo 5 times, 20 iterations
  Preprocessing in 0.00 s
Outer iteration 0 / 5
  Iteration 19 (0.62 s, search 0.58 s): objective=110049 imbalanc

Validation: 0it [00:00, ?it/s]


Objective improved: keep new clusters
Clustering 51200 points in 32D to 10 clusters, redo 5 times, 20 iterations
  Preprocessing in 0.00 s
Outer iteration 0 / 5
  Iteration 19 (0.14 s, search 0.11 s): objective=15316.4 imbalance=1.007 nsplit=0       
Objective improved: keep new clusters
Outer iteration 1 / 5
  Iteration 19 (0.28 s, search 0.23 s): objective=15332.1 imbalance=1.011 nsplit=0       
Outer iteration 2 / 5
  Iteration 19 (0.43 s, search 0.34 s): objective=15333 imbalance=1.008 nsplit=0         
Outer iteration 3 / 5
  Iteration 19 (0.57 s, search 0.46 s): objective=15328.8 imbalance=1.005 nsplit=0       
Outer iteration 4 / 5
  Iteration 19 (0.71 s, search 0.57 s): objective=15353.8 imbalance=1.005 nsplit=0       
Clustering 51200 points in 32D to 100 clusters, redo 5 times, 20 iterations
  Preprocessing in 0.00 s
Outer iteration 0 / 5
  Iteration 19 (0.12 s, search 0.12 s): objective=11924.7 imbalance=1.031 nsplit=0       
Objective improved: keep new clusters
Outer iter

Validation: 0it [00:00, ?it/s]


Clustering 51200 points in 32D to 10 clusters, redo 5 times, 20 iterations
  Preprocessing in 0.00 s
Outer iteration 0 / 5
  Iteration 19 (0.14 s, search 0.11 s): objective=16651.2 imbalance=1.008 nsplit=0       
Objective improved: keep new clusters
Outer iteration 1 / 5
  Iteration 19 (0.28 s, search 0.23 s): objective=16660.4 imbalance=1.023 nsplit=0       
Outer iteration 2 / 5
  Iteration 19 (0.43 s, search 0.34 s): objective=16657.8 imbalance=1.008 nsplit=0       
Outer iteration 3 / 5
  Iteration 19 (0.57 s, search 0.46 s): objective=16657.6 imbalance=1.006 nsplit=0       
Outer iteration 4 / 5
  Iteration 19 (0.71 s, search 0.57 s): objective=16655.5 imbalance=1.006 nsplit=0       
Clustering 51200 points in 32D to 100 clusters, redo 5 times, 20 iterations
  Preprocessing in 0.00 s
Outer iteration 0 / 5
  Iteration 19 (0.12 s, search 0.12 s): objective=13884 imbalance=1.014 nsplit=0         
Objective improved: keep new clusters
Outer iteration 1 / 5
  Iteration 19 (0.25 s, se

Validation: 0it [00:00, ?it/s]


Clustering 51200 points in 32D to 10 clusters, redo 5 times, 20 iterations
  Preprocessing in 0.00 s
Outer iteration 0 / 5
  Iteration 19 (0.13 s, search 0.11 s): objective=20122.4 imbalance=1.005 nsplit=0       
Objective improved: keep new clusters
Outer iteration 1 / 5
  Iteration 19 (0.26 s, search 0.23 s): objective=20099.2 imbalance=1.007 nsplit=0       
Objective improved: keep new clusters
Outer iteration 2 / 5
  Iteration 19 (0.39 s, search 0.34 s): objective=20076.5 imbalance=1.005 nsplit=0       
Objective improved: keep new clusters
Outer iteration 3 / 5
  Iteration 19 (0.52 s, search 0.46 s): objective=20059.9 imbalance=1.003 nsplit=0       
Objective improved: keep new clusters
Outer iteration 4 / 5
  Iteration 19 (0.66 s, search 0.57 s): objective=20059.9 imbalance=1.005 nsplit=0       
Clustering 51200 points in 32D to 100 clusters, redo 5 times, 20 iterations
  Preprocessing in 0.00 s
Outer iteration 0 / 5
  Iteration 19 (0.12 s, search 0.12 s): objective=16050.9 imba

Validation: 0it [00:00, ?it/s]


Clustering 51200 points in 32D to 10 clusters, redo 5 times, 20 iterations
  Preprocessing in 0.00 s
Outer iteration 0 / 5
  Iteration 19 (0.16 s, search 0.13 s): objective=24495.5 imbalance=1.002 nsplit=0       
Objective improved: keep new clusters
Outer iteration 1 / 5
  Iteration 19 (0.32 s, search 0.26 s): objective=24454.1 imbalance=1.003 nsplit=0       
Objective improved: keep new clusters
Outer iteration 2 / 5
  Iteration 19 (0.47 s, search 0.38 s): objective=24546.7 imbalance=1.004 nsplit=0       
Outer iteration 3 / 5
  Iteration 19 (0.63 s, search 0.51 s): objective=24445.2 imbalance=1.002 nsplit=0       
Objective improved: keep new clusters
Outer iteration 4 / 5
  Iteration 19 (0.79 s, search 0.63 s): objective=24561.2 imbalance=1.003 nsplit=0       
Clustering 51200 points in 32D to 100 clusters, redo 5 times, 20 iterations
  Preprocessing in 0.00 s
Outer iteration 0 / 5
  Iteration 19 (0.13 s, search 0.13 s): objective=18711.7 imbalance=1.009 nsplit=0       
Objective 

Validation: 0it [00:00, ?it/s]


Clustering 51200 points in 32D to 10 clusters, redo 5 times, 20 iterations
  Preprocessing in 0.00 s
Outer iteration 0 / 5
  Iteration 19 (0.15 s, search 0.12 s): objective=29154.8 imbalance=1.001 nsplit=0       
Objective improved: keep new clusters
Outer iteration 1 / 5
  Iteration 19 (0.30 s, search 0.25 s): objective=29173.7 imbalance=1.003 nsplit=0       
Outer iteration 2 / 5
  Iteration 19 (0.45 s, search 0.38 s): objective=29051.9 imbalance=1.001 nsplit=0       
Objective improved: keep new clusters
Outer iteration 3 / 5
  Iteration 19 (0.60 s, search 0.50 s): objective=29090 imbalance=1.003 nsplit=0         
Outer iteration 4 / 5
  Iteration 19 (0.75 s, search 0.63 s): objective=29181 imbalance=1.001 nsplit=0         
Clustering 51200 points in 32D to 100 clusters, redo 5 times, 20 iterations
  Preprocessing in 0.00 s
Outer iteration 0 / 5
  Iteration 19 (0.14 s, search 0.13 s): objective=22193 imbalance=1.009 nsplit=0         
Objective improved: keep new clusters
Outer iter

Validation: 0it [00:00, ?it/s]


Clustering 51200 points in 32D to 10 clusters, redo 5 times, 20 iterations
  Preprocessing in 0.00 s
Outer iteration 0 / 5
  Iteration 19 (0.16 s, search 0.13 s): objective=32555.5 imbalance=1.001 nsplit=0       
Objective improved: keep new clusters
Outer iteration 1 / 5
  Iteration 19 (0.31 s, search 0.25 s): objective=32637.2 imbalance=1.001 nsplit=0       
Outer iteration 2 / 5
  Iteration 19 (0.47 s, search 0.38 s): objective=32516.4 imbalance=1.001 nsplit=0       
Objective improved: keep new clusters
Outer iteration 3 / 5
  Iteration 19 (0.62 s, search 0.50 s): objective=32532.5 imbalance=1.002 nsplit=0       
Outer iteration 4 / 5
  Iteration 19 (0.78 s, search 0.63 s): objective=32603.3 imbalance=1.003 nsplit=0       
Clustering 51200 points in 32D to 100 clusters, redo 5 times, 20 iterations
  Preprocessing in 0.00 s
Outer iteration 0 / 5
  Iteration 19 (0.14 s, search 0.13 s): objective=25087.8 imbalance=1.010 nsplit=0       
Objective improved: keep new clusters
Outer iter

Validation: 0it [00:00, ?it/s]


Clustering 51200 points in 32D to 10 clusters, redo 5 times, 20 iterations
  Preprocessing in 0.00 s
Outer iteration 0 / 5
  Iteration 19 (0.15 s, search 0.12 s): objective=34853.1 imbalance=1.002 nsplit=0       
Objective improved: keep new clusters
Outer iteration 1 / 5
  Iteration 19 (0.30 s, search 0.25 s): objective=34975.4 imbalance=1.000 nsplit=0       
Outer iteration 2 / 5
  Iteration 19 (0.46 s, search 0.38 s): objective=34879.9 imbalance=1.002 nsplit=0       
Outer iteration 3 / 5
  Iteration 19 (0.61 s, search 0.50 s): objective=34905.2 imbalance=1.001 nsplit=0       
Outer iteration 4 / 5
  Iteration 19 (0.77 s, search 0.62 s): objective=34949.2 imbalance=1.001 nsplit=0       
Clustering 51200 points in 32D to 100 clusters, redo 5 times, 20 iterations
  Preprocessing in 0.00 s
Outer iteration 0 / 5
  Iteration 19 (0.13 s, search 0.12 s): objective=27202.2 imbalance=1.013 nsplit=0       
Objective improved: keep new clusters
Outer iteration 1 / 5
  Iteration 19 (0.27 s, se

Validation: 0it [00:00, ?it/s]


Objective improved: keep new clusters
Clustering 51200 points in 32D to 10 clusters, redo 5 times, 20 iterations
  Preprocessing in 0.00 s
Outer iteration 0 / 5
  Iteration 19 (0.16 s, search 0.13 s): objective=36296.9 imbalance=1.002 nsplit=0       
Objective improved: keep new clusters
Outer iteration 1 / 5
  Iteration 19 (0.32 s, search 0.25 s): objective=36349 imbalance=1.002 nsplit=0         
Outer iteration 2 / 5
  Iteration 19 (0.47 s, search 0.38 s): objective=36323.3 imbalance=1.002 nsplit=0       
Outer iteration 3 / 5
  Iteration 19 (0.63 s, search 0.50 s): objective=36353.5 imbalance=1.005 nsplit=0       
Outer iteration 4 / 5
  Iteration 19 (0.78 s, search 0.63 s): objective=36330.2 imbalance=1.003 nsplit=0       
Clustering 51200 points in 32D to 100 clusters, redo 5 times, 20 iterations
  Preprocessing in 0.00 s
Outer iteration 0 / 5
  Iteration 19 (0.14 s, search 0.13 s): objective=28498.6 imbalance=1.016 nsplit=0       
Objective improved: keep new clusters
Outer iter

Validation: 0it [00:00, ?it/s]


Clustering 51200 points in 32D to 10 clusters, redo 5 times, 20 iterations
  Preprocessing in 0.00 s
Outer iteration 0 / 5
  Iteration 19 (0.14 s, search 0.12 s): objective=37216.1 imbalance=1.002 nsplit=0       
Objective improved: keep new clusters
Outer iteration 1 / 5
  Iteration 19 (0.28 s, search 0.25 s): objective=37189.6 imbalance=1.001 nsplit=0       
Objective improved: keep new clusters
Outer iteration 2 / 5
  Iteration 19 (0.42 s, search 0.37 s): objective=37242.1 imbalance=1.002 nsplit=0       
Outer iteration 3 / 5
  Iteration 19 (0.56 s, search 0.50 s): objective=37254.9 imbalance=1.001 nsplit=0       
Outer iteration 4 / 5
  Iteration 19 (0.70 s, search 0.62 s): objective=37180.2 imbalance=1.002 nsplit=0       
Objective improved: keep new clusters
Clustering 51200 points in 32D to 100 clusters, redo 5 times, 20 iterations
  Preprocessing in 0.00 s
Outer iteration 0 / 5
  Iteration 19 (0.14 s, search 0.13 s): objective=29207.7 imbalance=1.014 nsplit=0       
Objective 

Validation: 0it [00:00, ?it/s]


Clustering 51200 points in 32D to 10 clusters, redo 5 times, 20 iterations
  Preprocessing in 0.00 s
Outer iteration 0 / 5
  Iteration 19 (0.14 s, search 0.12 s): objective=38010.2 imbalance=1.002 nsplit=0       
Objective improved: keep new clusters
Outer iteration 1 / 5
  Iteration 19 (0.28 s, search 0.25 s): objective=38029.3 imbalance=1.002 nsplit=0       
Outer iteration 2 / 5
  Iteration 19 (0.42 s, search 0.37 s): objective=37936.1 imbalance=1.001 nsplit=0       
Objective improved: keep new clusters
Outer iteration 3 / 5
  Iteration 19 (0.56 s, search 0.50 s): objective=37947.1 imbalance=1.004 nsplit=0       
Outer iteration 4 / 5
  Iteration 19 (0.70 s, search 0.62 s): objective=37920.2 imbalance=1.002 nsplit=0       
Objective improved: keep new clusters
Clustering 51200 points in 32D to 100 clusters, redo 5 times, 20 iterations
  Preprocessing in 0.00 s
Outer iteration 0 / 5
  Iteration 19 (0.13 s, search 0.12 s): objective=29929.8 imbalance=1.016 nsplit=0       
Objective 

In [8]:
# now let's test it

def get_synthetic_sup_datamodule():
    """Build the supervised data module over the original synthetic data.

    The train split reads the ``train`` folder and the test split reads the
    ``eval`` folder of ``./syndata/example_data/``; both use ``class_label``
    as a ``torch.long`` target and shuffle their files.

    Returns:
        A PtlsDataModule with train and test data, batch size 512 and 4 workers each.
    """
    data_root = "./syndata/example_data/"

    def _target_dataset(subdir):
        # One split: parquet files -> shuffled dataset -> (sequence, target) dataset.
        files = ParquetFiles(os.path.join(data_root, subdir))
        shuffled = ParquetDataset(files, shuffle_files=True)
        return SeqToTargetIterableDataset(shuffled, target_col_name='class_label', target_dtype=torch.long)

    return PtlsDataModule(
        train_data=_target_dataset("train"),
        test_data=_target_dataset("eval"),
        train_batch_size=512,
        test_batch_size=512,
        train_num_workers=4,
        test_num_workers=4,
    )


def eval_dataloader(model, dl, device='cuda:0'):
    """Run `model` over every batch of `dl` and collect outputs and targets.

    The model is moved to `device` and switched to eval mode; it is left in
    that state when the function returns.

    Args:
        model: callable module applied to each batch input.
        dl: iterable yielding ``(x, y)`` pairs.
        device: device the model and the inputs are moved to.

    Returns:
        dict with ``'x'`` (model outputs concatenated along axis 0) and
        ``'y'`` (targets concatenated along axis 0), both as numpy arrays.
    """
    model.to(device)
    model.eval()

    output_chunks = []
    target_chunks = []
    # Inference only: no autograd graph is needed for the forward passes.
    with torch.no_grad():
        for x, y in dl:
            target_chunks.append(y.numpy())
            output_chunks.append(model(x.to(device)).cpu().numpy())

    return {
        'x': np.concatenate(output_chunks, axis=0),
        'y': np.concatenate(target_chunks, axis=0),
    }
        


def eval_embeddings(coles_model, data):
    """Embed the train and test splits of `data` with `coles_model`.

    Args:
        coles_model: model passed to `eval_dataloader`.
        data: object providing `train_dataloader()` and `test_dataloader()`.

    Returns:
        Tuple ``(train_result, test_result)`` of the dicts produced by
        `eval_dataloader` for the train and test dataloaders respectively.
    """
    train_loader = data.train_dataloader()
    train_result = eval_dataloader(coles_model, train_loader)

    test_loader = data.test_dataloader()
    test_result = eval_dataloader(coles_model, test_loader)

    return train_result, test_result


def gbm(train_gbm_data, test_gbm_data):
    """Fit LightGBM on the embeddings five times and print the mean and std ROC AUC.

    Each run uses the same hyperparameters and differs only in `random_state`
    (42, 43, ..., 46). Nothing is returned; the summary is printed.

    Args:
        train_gbm_data: dict with features under ``'x'`` and binary targets under ``'y'``.
        test_gbm_data: dict of the same layout used for scoring.
    """
    # Hyperparameters shared by every run. The sklearn-style names are set to None
    # next to their LightGBM-native counterparts -- presumably so that only one
    # spelling of each option is active; confirm against the LightGBM docs.
    shared_params = {
        'n_estimators': 50,
        'boosting_type': 'gbdt',
        'objective': 'binary',
        'learning_rate': 0.02,
        'subsample': 0.75,
        'subsample_freq': 1,
        'feature_fraction': 0.75,
        'colsample_bytree': None,
        'max_depth': 12,
        'lambda_l1': 1,
        'reg_alpha': None,
        'lambda_l2': 1,
        'reg_lambda': None,
        'min_data_in_leaf': 50,
        'min_child_samples': None,
        'num_leaves': 50,
        'n_jobs': 4,
    }

    auc_scores = []
    for seed_offset in range(5):
        classifier = LGBMClassifier(**shared_params, random_state=42 + seed_offset)
        classifier.fit(train_gbm_data['x'], train_gbm_data['y'])

        # Column 1 of predict_proba holds the probability of the positive class.
        positive_proba = classifier.predict_proba(test_gbm_data['x'])[:, 1]
        auc_scores.append(roc_auc_score(test_gbm_data['y'], positive_proba))

    mean, std = np.mean(auc_scores), np.std(auc_scores)
    print(f'mean roc_auc: {mean:.4f} std : {std:.4f}')


# Embed both splits with the trained module, then score the embeddings with LightGBM.
eval_datamodule = get_synthetic_sup_datamodule()
train_gbm_data, test_gbm_data = eval_embeddings(coles_model=cluster_module, data=eval_datamodule)
gbm(train_gbm_data=train_gbm_data, test_gbm_data=test_gbm_data)

[LightGBM] [Info] Number of positive: 25600, number of negative: 25600
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.002419 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 8160
[LightGBM] [Info] Number of data points in the train set: 51200, number of used features: 32
[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.500000 -> initscore=0.000000
[LightGBM] [Info] Number of positive: 25600, number of negative: 25600
[LightGBM] [Info] Auto-choosing col-wise multi-threading, the overhead of testing was 0.002419 seconds.
You can set `force_col_wise=true` to remove the overhead.
[LightGBM] [Info] Total Bins 8160
[LightGBM] [Info] Number of data points in the train set: 51200, number of used features: 32
[LightGBM] [Info] [binary:BoostFromScore]: pavg=0.500000 -> initscore=0.000000
[LightGBM] [Info] Number of positive: 25600, number of negative: 25600
[LightGBM] [Info] Auto-choosing col-wise multi-thre