# Schema

In [None]:
import os
import numpy as np

from alphacnn import paths
from alphacnn.database.dataset_schema import *

# Connect using the project's config file and make sure the
# '<SCHEMA_PREFIX>dataset' schema and its tables exist.
connect_to_database(
    schema_name=paths.SCHEMA_PREFIX + 'dataset',
    create_schema=True,
    create_tables=True,
    dj_config_file=paths.CONFIG_FILE,
)
# Keep the schema object as the final expression so the notebook renders it.
dataset_schema

# ERD

In [None]:
import warnings

# Render the schema diagram; FutureWarnings raised while drawing are
# silenced only for the duration of this block.
with warnings.catch_warnings():
    warnings.simplefilter("ignore", category=FutureWarning)
    display(
        dj.ERD(dataset_schema)
    )

# Load Data-Sets

In [None]:
# Sub-directory of paths.DATASET_PATH that holds the data-set files.
src_dir = 'database_v1'

# Keep only directory entries whose name starts with 'dataset'.
available_data_set_files = list(filter(
    lambda name: name.startswith('dataset'),
    os.listdir(os.path.join(paths.DATASET_PATH, src_dir)),
))

available_data_set_files

In [None]:
# Components that are combined into the data-set file names below.
rgcs = ['nsl', 'tmp', 'tmp_ws', 'tmp_ss']
bc_noise_lvls = ['med']
suffix_list = ['']


def _data_set_file_name(rgc, bc_noise_lvl, suffix):
    """Return the file name for one (rgc, noise level, suffix) combination."""
    return f'dataset_f002_f003_rot_1975_w_and_wo_test_{rgc}_bcns{bc_noise_lvl}{suffix}.pkl'


# One file name per combination; rgc varies slowest, suffix fastest.
data_set_files = [
    _data_set_file_name(rgc_name, noise_lvl, file_suffix)
    for rgc_name in rgcs
    for noise_lvl in bc_noise_lvls
    for file_suffix in suffix_list
]
data_set_files

In [None]:
# Add every requested data-set file that exists on disk and is not yet stored
# in DataSet.
# The stored file names are fetched once up front (instead of once per loop
# iteration) and kept in a set that is updated after each successful add.
registered_data_set_files = set(DataSet().fetch('data_set_file'))

for data_set_file in data_set_files:
    if data_set_file not in available_data_set_files:
        print(f'data_set_file `{data_set_file}` does not exist')
        continue

    if data_set_file in registered_data_set_files:
        print('Skip (already there):', data_set_file)
        continue

    print('Add:', data_set_file)
    DataSet().add(data_set_file=data_set_file, skip_duplicates=True, src_dir=src_dir)
    # Remember the new entry so a repeated name in data_set_files is skipped.
    registered_data_set_files.add(data_set_file)

In [None]:
# Evaluate DataSet() as the cell's final expression so the notebook renders it
# (presumably a preview of the stored data-set entries — confirm in dataset_schema).
DataSet()

In [None]:
# Render DataSet.DataPoint, the DataPoint class nested in DataSet
# (presumably the per-sample part table — confirm in dataset_schema).
DataSet.DataPoint()

In [None]:
# Plot frame 0 via DataSet.plot1. key=None leaves the entry selection to plot1
# (NOTE(review): which entry is chosen for key=None is defined in dataset_schema — confirm).
DataSet().plot1(key=None, frame_i=0)

In [None]:
# Restrict DataSet to the '..._nsl_bcnsmed.pkl' file by its data_set_file value
# and plot frame 10 of that entry.
(DataSet & "data_set_file='dataset_f002_f003_rot_1975_w_and_wo_test_nsl_bcnsmed.pkl'").plot1(frame_i=10)

# Normalize data

In [None]:
# Render DataNorm as the cell's final expression, before populate is called in
# the following cell.
DataNorm()

In [None]:
# Run DataNorm.populate with progress display enabled
# (presumably computes normalization entries that are still missing — confirm in dataset_schema).
DataNorm.populate(display_progress=True)
# Render the table again to show the result of the populate call.
DataNorm()

In [None]:
# Render DataNorm.NormPoint, the NormPoint class nested in DataNorm
# (presumably the per-sample normalized data — confirm in dataset_schema).
DataNorm.NormPoint()

In [None]:
# Plot frame 0 via DataNorm.plot1; no key is passed, so plot1's own default
# entry selection applies.
DataNorm().plot1(frame_i=0)

# Create Splits

In [None]:
# Render DataSplit as the cell's final expression, before splits are added in
# the following cell.
DataSplit()

In [None]:
# Create split 0 for every stored data-set file, using a fixed seed.
# skip_duplicates=True is passed so the cell can be re-run.
for data_set_file in DataSet.fetch('data_set_file'):
    print(data_set_file)
    DataSplit().add_distance_stratified(
        data_set_file,
        split_id=0,
        seed=431,
        skip_duplicates=True,
    )

In [None]:
# For every split id, print per split kind: the summed x_len, the summed p_sum
# and their ratio (sum of p_sum divided by sum of x_len).
for split_id in np.unique(DataSplit.fetch('split_id')):
    print(split_id)
    for kind in ('train', 'dev', 'test'):
        # Split points belonging to this split id and kind.
        split_points = DataSplit.SplitPoint() & f"split_id={split_id}" & f"split_kind='{kind}'"
        x_len, p_sum = (DataSet.DataPoint() & split_points).fetch('x_len', 'p_sum')

        total_x_len = np.sum(x_len)
        total_p_sum = np.sum(p_sum)
        print(kind, total_x_len, total_p_sum, total_p_sum / total_x_len)
    print()