In [1]:
%load_ext autoreload
%autoreload 2 
%matplotlib agg

import visualization
from data import datasets, evaluation
import transformations
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.core import freeze
import optax
import inference
from utils import settings
import numpy as np
import os
import global_settings

rng_key = jax.random.PRNGKey(0)
fraction = 0.2

# benchmark

In [2]:
dataset_list =[
    os.path.join(
        os.path.join(global_settings.PATH_DATASETS, "benchmark_data"),
        file_name) for file_name in os.listdir("/home/gw/data/datasets/benchmark_data/") if file_name.split('.')[-1] == "data"
]

In [3]:
dataset_indices = {}
for dataset_element in dataset_list:
    print(dataset_element)
    with open(dataset_element, "r") as f:
        lines = f.readlines()
        lines_count = len(lines)
        train_count = int((1.0 - fraction) * lines_count)
        rng_key, rng_key_ = jax.random.split(rng_key)
        indices = jax.random.permutation(rng_key_, lines_count)
        indices_train = np.array(indices[:train_count]).tolist()
        indices_val = np.array(indices[train_count:]).tolist()
        print(len(indices_train), len(indices_val))
        
        # save to dict/json
        dataset_name = dataset_element.split('/')[-1].split('.')[0]
        dataset_indices[dataset_name] = {"train": indices_train, "validate": indices_val}

/home/gw/data/datasets/benchmark_data/naval_compressor.data
9547 2387
/home/gw/data/datasets/benchmark_data/airfoil.data
1202 301
/home/gw/data/datasets/benchmark_data/protein.data
36584 9146
/home/gw/data/datasets/benchmark_data/boston.data
404 102
/home/gw/data/datasets/benchmark_data/diabetes.data
353 89
/home/gw/data/datasets/benchmark_data/video_time.data
55027 13757
/home/gw/data/datasets/benchmark_data/concrete.data
824 206
/home/gw/data/datasets/benchmark_data/yacht.data
246 62
/home/gw/data/datasets/benchmark_data/video_mem.data
55027 13757
/home/gw/data/datasets/benchmark_data/gpu.data
193280 48320
/home/gw/data/datasets/benchmark_data/naval_turbine.data
9547 2387
/home/gw/data/datasets/benchmark_data/forest_fire.data
413 104
/home/gw/data/datasets/benchmark_data/energy.data
614 154
/home/gw/data/datasets/benchmark_data/wine.data
142 36


# toy

In [2]:
# toy datasets
dataset_indices = {}

## izmailov
dd = datasets.Izmailov()
N = len(dd._data)
print(N, fraction)

rng_key, rng_key_ = jax.random.split(rng_key)
train_count = int((1.0 - fraction) * N)
indices = jax.random.permutation(rng_key_, N)
indices_train = np.array(indices[:train_count]).tolist()
indices_val = np.array(indices[train_count:]).tolist()
print(len(indices_train), len(indices_val))

dataset_name = "izmailov"
dataset_indices[dataset_name] = {"train": indices_train, "validate": indices_val}

## sinusoidal
dd = datasets.Sinusoidal()
N = len(dd._data)
print(N, fraction)

rng_key, rng_key_ = jax.random.split(rng_key)
train_count = int((1.0 - fraction) * N)
indices = jax.random.permutation(rng_key_, N)
indices_train = np.array(indices[:train_count]).tolist()
indices_val = np.array(indices[train_count:]).tolist()
print(len(indices_train), len(indices_val))

dataset_name = "sinusoidal"
dataset_indices[dataset_name] = {"train": indices_train, "validate": indices_val}

## sinusoidal
dd = datasets.Regression2d()
N = len(dd._data)
print(N, fraction)

rng_key, rng_key_ = jax.random.split(rng_key)
train_count = int((1.0 - fraction) * N)
indices = jax.random.permutation(rng_key_, N)
indices_train = np.array(indices[:train_count]).tolist()
indices_val = np.array(indices[train_count:]).tolist()
print(len(indices_train), len(indices_val))

dataset_name = "regression2d"
dataset_indices[dataset_name] = {"train": indices_train, "validate": indices_val}

400 0.2
320 80
150 0.2
120 30
256 0.2
204 52


In [3]:
import json
with open(os.path.join(global_settings.PATH_DATASETS, "toy_dataset_indices_0.2.json"), 'w') as foutput:
    json.dump(dataset_indices, foutput)

# export toy data as .data files

In [28]:
dataset_indices = {}

## izmailov
dd = datasets.Izmailov()
N = len(dd._data)

In [29]:
result_string = " "
for element in dd._data:
    line = ""
    x = element[jnp.array(dd._conditional_indices)]
    y = element[jnp.array(dd._dependent_indices)]
    for feature in x:
        line += f"{feature} "
    for i in range(len(y)):
        if i == len(y) - 1:
            line += f"{y[i]}"
        else:
            line += f"{y[i]} "
    result_string += f"{line}\n"

print(result_string)
with open(os.path.join(global_settings.PATH_DATASETS, "izmailov.data"), 'w') as f:
    f.write(result_string)

 0.14994968473911285 0.714125394821167
-1.5646891593933105 -1.5090986490249634
-0.2670217454433441 0.4182179868221283
-1.3702987432479858 -1.42794668674469
-0.01059589721262455 0.5898424983024597
-0.07479147613048553 0.6236246228218079
1.0403910875320435 0.5589092969894409
-1.464500904083252 -1.6007733345031738
0.9993542432785034 0.9104956388473511
1.2155258655548096 0.5303336977958679
-0.12780259549617767 0.5286206007003784
0.9839214086532593 1.0920649766921997
0.012068305164575577 0.8030752539634705
0.0004735661786980927 0.6259732842445374
0.08514846116304398 0.47762641310691833
-0.297279417514801 0.2222561240196228
-0.20509883761405945 0.35332435369491577
1.3173505067825317 0.682981014251709
1.1775751113891602 0.4646731913089752
1.3564509153366089 0.4781087040901184
-0.24211657047271729 0.3683837950229645
-1.4100456237792969 -1.3796005249023438
-1.180680513381958 -1.6184237003326416
-1.5557782649993896 -1.5875425338745117
-0.033010803163051605 0.8252583742141724
-1.1979141235351562 