In [None]:
%%capture

!pip install ucimlrepo
!pip install ml_collections

In [None]:
!git clone https://github.com/JayoungKim408/STaSy.git

Cloning into 'STaSy'...
remote: Enumerating objects: 61, done.[K
remote: Counting objects: 100% (11/11), done.[K
remote: Compressing objects: 100% (2/2), done.[K
remote: Total 61 (delta 10), reused 9 (delta 9), pack-reused 50 (from 1)[K
Receiving objects: 100% (61/61), 448.88 KiB | 16.63 MiB/s, done.
Resolving deltas: 100% (19/19), done.


In [None]:
%cd STaSy

/content/STaSy


In [None]:
import os
import json
import numpy as np
import pandas as pd
from prepare_dataset_utils import CATEGORICAL, CONTINUOUS, ORDINAL, verify
from ucimlrepo import fetch_ucirepo


# Directory that main() creates (if missing) and writes the dataset files into.
output_dir = 'tabular_datasets'
# Base file name shared by the generated "<name>.json" and "<name>.npz" files.
name = "adult"

def project_table(data, meta):
    values = np.zeros(shape=data.shape, dtype='float32')

    for id_, info in enumerate(meta):
        if info['type'] == CONTINUOUS:
            values[:, id_] = data.iloc[:, id_].values.astype('float32')
        else:
            mapper = dict([(item, id) for id, item in enumerate(info['i2s'])])
            mapped = data.iloc[:, id_].apply(lambda x: mapper[x]).values
            values[:, id_] = mapped

    return values

def main():
    """Fetch the UCI Adult table, encode it and write the dataset files.

    Steps:
      1. Create ``output_dir`` if needed and download the dataset with
         ``fetch_ucirepo(id=2)``.
      2. Keep the columns listed in ``col_type`` (in that order), drop rows
         with missing values and normalise the ``income`` labels.
      3. Build per-column metadata (min/max for continuous columns, the label
         list for categorical ones) and encode the table with
         ``project_table``.
      4. Shuffle with a fixed seed, split 80/20 into train/test and write
         ``{output_dir}/{name}.json`` (metadata) and ``{output_dir}/{name}.npz``
         (arrays), then pass both files to ``verify``.

    Takes no arguments and returns nothing; the result is the two files.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Load the dataset: features and target arrive as separate frames.
    dataset = fetch_ucirepo(id=2)
    X = dataset.data.features
    y = dataset.data.targets

    df = pd.concat([X, y], axis=1)

    # Column types for the Adult table, in the order they are encoded.
    col_type = [
        ('age', CONTINUOUS),
        ('workclass', CATEGORICAL),
        ('fnlwgt', CONTINUOUS),
        ('education', CATEGORICAL),
        ('education-num', CONTINUOUS),
        ('marital-status', CATEGORICAL),
        ('occupation', CATEGORICAL),
        ('relationship', CATEGORICAL),
        ('race', CATEGORICAL),
        ('sex', CATEGORICAL),
        ('capital-gain', CONTINUOUS),
        ('capital-loss', CONTINUOUS),
        ('hours-per-week', CONTINUOUS),
        ('native-country', CATEGORICAL),
        ('income', CATEGORICAL)
    ]

    # project_table() pairs meta entries with columns by position, so pin the
    # frame to exactly the columns (and order) declared above. A missing
    # column fails here with a KeyError instead of being silently mis-encoded.
    df = df[[col_name for col_name, _ in col_type]]

    # Missing values appear both as NaN and as the literal '?'.
    df = df.replace('?', np.nan).dropna()

    # The combined table spells each income label two ways ('>50K' and
    # '>50K.'). Without normalisation the target would get four classes,
    # contradicting the "binary_classification" problem type below.
    df = df.assign(income=df['income'].astype(str).str.strip().str.rstrip('.'))

    # Build the metadata, one entry per column.
    meta = []
    for col_name, kind in col_type:
        col_data = df[col_name]

        if kind == CONTINUOUS:
            meta.append({
                "name": col_name,
                "type": CONTINUOUS,
                "min": float(col_data.min()),
                "max": float(col_data.max())
            })
        else:
            # Labels in order of first appearance; the list index is the code
            # project_table() writes into the encoded matrix.
            categories = list(col_data.unique())
            meta.append({
                "name": col_name,
                "type": CATEGORICAL,
                "size": len(categories),
                "i2s": categories
            })

    # Encode the table as a float32 matrix.
    tdata = project_table(df, meta)

    # Configuration stored next to the arrays.
    config = {
        "columns": meta,
        "problem_type": "binary_classification"
    }

    # Shuffle rows with a fixed seed, then split 80/20 into train/test.
    np.random.seed(0)
    np.random.shuffle(tdata)

    split_ratio = int(tdata.shape[0] * 0.8)
    train_data = tdata[:split_ratio]
    test_data = tdata[split_ratio:]

    with open(f"{output_dir}/{name}.json", "w") as f:
        json.dump(config, f, indent=4)

    np.savez(f"{output_dir}/{name}.npz", train=train_data, test=test_data)

    # Hand the written files to the repository's own verify() helper.
    verify(f"{output_dir}/{name}.npz", f"{output_dir}/{name}.json")

In [None]:
main()

In [None]:
!python main.py --config configs/adult.py --mode train --workdir stasy

2025-04-29 18:05:08.267074: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1745949908.287650    1384 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1745949908.293966    1384 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-04-29 18:05:08.313829: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
the number of parameters 46660
W0429 18:05:22.503205 136791537222272 utils.py:12] No checkpoint found at stasy/checkp