In [1]:
%load_ext autoreload
%autoreload 2

import os
from collections import defaultdict
from tqdm import tqdm
from IPython.display import clear_output
import torch.multiprocessing
torch.multiprocessing.set_sharing_strategy('file_system')

# import logging
import torch
import pytorch_lightning as pl
import warnings
import numpy as np
import pandas as pd


from functools import partial
from ptls.data_load.datasets import SyntheticDataset, ParquetFiles, ParquetDataset
from ptls.frames.supervised import SeqToTargetDataset, SeqToTargetIterableDataset, SequenceToTarget
from ptls.frames import PtlsDataModule
from functools import partial


from ptls.data_load.iterable_processing import SeqLenFilter
from ptls.frames.coles import ColesDataset
from ptls.frames.coles.split_strategy import SampleSlices


import torch
import torchmetrics
from ptls.nn import TrxEncoder, RnnSeqEncoder, Head
from ptls.frames.coles import CoLESModule


import pytorch_lightning as pl
import pickle

warnings.filterwarnings('ignore')
# logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)

2024-02-20 06:51:29.149006: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [2]:
from ptls.data_load.datasets import Config, SyntheticClient, SimpleSchedule, PlaneClassAssigner
from random import shuffle


def get_clients(n, noise):
    """Build exactly ``n`` synthetic clients split (almost) evenly between two labels.

    A two-state chain config is built ("A" -> "B"), with labeling key 0 tied to
    state "A". The first ``n // 2`` clients are created with label dict
    ``{0: 0}`` and the remaining ``n - n // 2`` with ``{0: 1}``. The list is
    then shuffled in place using the global ``random`` state (no seed is set here).

    Args:
        n: total number of clients to return. When ``n`` is odd, the extra
            client gets label dict ``{0: 1}`` so that exactly ``n`` are returned.
        noise: third element of the "A" entry in the chain config passed to ``Config``.

    Returns:
        A shuffled list of ``n`` ``SyntheticClient`` objects sharing one config and schedule.
    """
    chain_confs = {
        "A": (4, 4, noise),
        "B": (8, 16, 0),
    }
    state_from = ["A"]
    state_to = ["B"]
    labeling_conf = {
        0: ("A",),
    }

    config = Config(chain_confs, state_from, state_to, labeling_conf)
    schedule = SimpleSchedule(config)
    # Round-trip the assigners through the "config" path on disk.
    # NOTE(review): this rewrites ./config on every call (once per generated file
    # in write_dataset); confirm the reload is actually required.
    config.save_assigners("config")
    config.load_assigners("config")

    # Split n exactly. Using int(n / 2) for both classes drops one client when n
    # is odd, which breaks callers that assert len(clients) == n.
    n_label0 = n // 2
    n_label1 = n - n_label0

    clients = [SyntheticClient({0: 0}, config, schedule) for _ in range(n_label0)] + \
              [SyntheticClient({0: 1}, config, schedule) for _ in range(n_label1)]
    shuffle(clients)

    return clients


def get_sup_datamodule(noise, n_train=5000, n_eval=2000, n_test=3000,
                       seq_len=1024, batch_size=256, num_workers=16):
    """Create a ``PtlsDataModule`` over freshly generated synthetic clients.

    Generates ``n_train + n_eval + n_test`` clients with ``get_clients`` and
    slices them, in order, into train / valid / test splits. Each split is
    wrapped in a ``SyntheticDataset`` and then a ``SeqToTargetDataset`` whose
    target is the ``'class_label'`` column cast to ``torch.long``.

    Args:
        noise: forwarded to ``get_clients``.
        n_train: number of clients in the train split.
        n_eval: number of clients in the validation split.
        n_test: number of clients in the test split.
        seq_len: ``seq_len`` passed to every ``SyntheticDataset`` (default 1024).
        batch_size: batch size for the train, valid and test loaders (default 256).
        num_workers: worker count for the train, valid and test loaders (default 16).

    Returns:
        A ``PtlsDataModule`` with train, valid and test data attached.
    """
    n = n_train + n_eval + n_test
    clients = get_clients(n, noise)
    assert len(clients) == n

    # Build all three SyntheticDatasets before wrapping them, keeping the
    # original construction order in case construction consumes random state.
    dataset_train = SyntheticDataset(clients[:n_train], seq_len=seq_len)
    dataset_valid = SyntheticDataset(clients[n_train:n_train + n_eval], seq_len=seq_len)
    dataset_test = SyntheticDataset(clients[n_train + n_eval:], seq_len=seq_len)

    def to_target(dataset):
        # Supervised view of a split: 'class_label' becomes the long target.
        return SeqToTargetDataset(dataset, target_col_name='class_label', target_dtype=torch.long)

    sup_data = PtlsDataModule(
        train_data=to_target(dataset_train),
        valid_data=to_target(dataset_valid),
        test_data=to_target(dataset_test),
        train_batch_size=batch_size,
        valid_batch_size=batch_size,
        test_batch_size=batch_size,
        train_num_workers=num_workers,
        valid_num_workers=num_workers,
        test_num_workers=num_workers,
    )
    return sup_data




In [3]:
def write_dataset(main_folder, noise=0.,
                  train_num_files=200, eval_num_files=100, test_num_files=100,
                  n_train=256*10, n_eval=256*10, n_test=256*10):
    
    train_folder = os.path.join(main_folder, "train")
    eval_folder = os.path.join(main_folder, "eval")
    test_folder = os.path.join(main_folder, "test")
    os.makedirs(train_folder, exist_ok=True)
    os.makedirs(eval_folder, exist_ok=True)
    os.makedirs(test_folder, exist_ok=True)
    
    for fn in tqdm(range(max(train_num_files, eval_num_files, test_num_files))):
        data = get_sup_datamodule(noise, n_train, n_eval, n_test)
        
        if fn < train_num_files:
            df = defaultdict(list)
            for i, batch in enumerate(data.train_dataloader()):
                x, y = batch
                x_d = x.payload
                for k in x_d:
                    df[k].extend(x_d[k].int().tolist())
                df['class_label'].extend(y.int().tolist())
            df = pd.DataFrame(df)
            df.to_parquet(os.path.join(train_folder, "train_"+str(fn)+".parquet"))
        
        
        if fn < eval_num_files:
            df = defaultdict(list)
            for i, batch in enumerate(data.val_dataloader()):
                x, y = batch
                x_d = x.payload
                for k in x_d:
                    df[k].extend(x_d[k].int().tolist())
                df['class_label'].extend(y.int().tolist())
            df = pd.DataFrame(df)
            df.to_parquet(os.path.join(eval_folder, "eval_"+str(fn)+".parquet"))
        
        
        if fn < test_num_files:
            df = defaultdict(list)
            for i, batch in enumerate(data.test_dataloader()):
                x, y = batch
                x_d = x.payload
                for k in x_d:
                    df[k].extend(x_d[k].int().tolist())
                df['class_label'].extend(y.int().tolist())
            df = pd.DataFrame(df)
            df.to_parquet(os.path.join(test_folder, "test_"+str(fn)+".parquet"))

In [4]:
# Generate 250 train and 50 eval parquet files (no test files) under syndata/new_data_0.
# Each file is produced from a fresh pool of n_train + n_eval + n_test clients
# (1024 each); n_test still sizes the pool even though no test files are written.
write_dataset(
    "syndata/new_data_0",
    noise=0.,
    train_num_files=250,
    eval_num_files=50,
    test_num_files=0,
    n_train=4 * 256,
    n_eval=4 * 256,
    n_test=4 * 256,
)

  3%|█▍                                         | 8/250 [01:37<49:23, 12.25s/it]


KeyboardInterrupt: 