# Schemas

In [1]:
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt
from sklearn import metrics as sk_metrics

In [None]:
from alphacnn.visualize import plot_decoding
from alphacnn import paths
from alphacnn.database.dataset_schema import *

# Connect to the '<prefix>dataset' schema. Both creation flags are off, so
# this cell presumably expects the schema and its tables to exist already
# (TODO confirm against connect_to_database).
connect_to_database(
    dj_config_file=paths.CONFIG_FILE,
    schema_name=paths.SCHEMA_PREFIX + 'dataset',
    create_schema=False,
    create_tables=False,
)

# Bare expression: shown as the cell output.
dataset_schema

In [None]:
from alphacnn.database.pres_decoder_schema import *

# Connect to the '<prefix>decoder' schema. Unlike a connection made with the
# creation flags off, both flags are enabled here, so running this cell may
# create the schema and its tables (TODO confirm against connect_to_database).
connect_to_database(
    dj_config_file=paths.CONFIG_FILE,
    schema_name=paths.SCHEMA_PREFIX + 'decoder',
    create_schema=True,
    create_tables=True,
)

# Bare expression: shown as the cell output.
pres_decoder_schema

# ERD

In [None]:
import warnings

# Render the diagram of the decoder schema. FutureWarnings raised while
# building or displaying it are silenced only inside this context manager;
# the previous warning filters are restored on exit.
with warnings.catch_warnings():
    warnings.simplefilter(action="ignore", category=FutureWarning)
    schema_diagram = dj.ERD(pres_decoder_schema)
    display(schema_diagram)

# Pres Decoders

## Define

In [5]:
# Register the decoder configuration stored under the id 'cnn_ensemble_10'.
# skip_duplicates=True presumably makes re-running this cell a no-op when
# the entry already exists (TODO confirm against PresDecoderKind.add).
PresDecoderKind().add(
    decoder_id='cnn_ensemble_10',
    kind='cnn',
    params={
        # L2 weights.
        'w_l2': 0.003,
        'w_l2_conv': 0.001,
        # Layer counts.
        'n_convs': 5,
        'n_dense': 8,
        # Convolution kernel sizes and filter counts.
        'first_conv_size': 3,
        'other_conv_size': 3,
        'first_conv_nfilt': 3,
        'other_conv_nfilt': 3,
        # Padding modes and loss name.
        'padding': 'same',
        'pool_padding': 'same',
        'pres_loss': 'binary_crossentropy',
    },
    skip_duplicates=True,
)

## Predict

In [None]:
# Values substituted into the dataset file-name pattern below. Every
# combination of the three lists yields one file name.
rgcs = ['nsl', 'tmp', 'tmp_ws', 'tmp_ss']
bc_noise_lvls = ['med']
suffix_list = ['']

# Build the names with `rgc` varying slowest and `suffix` fastest, i.e. in
# the same order a nested comprehension over the three lists would give.
data_set_files = []
for rgc in rgcs:
    for bc_noise_lvl in bc_noise_lvls:
        data_set_files.extend(
            f'dataset_f002_f003_rot_1975_w_and_wo_test_{rgc}_bcns{bc_noise_lvl}{suffix}.pkl'
            for suffix in suffix_list
        )

# Bare expression: shown as the cell output.
data_set_files

In [None]:
# Populate predictions using two restrictions: a list with one dict per
# dataset file, and a dict fixing split_id to 0. In DataJoint a list of
# dicts acts as an OR over its members, and separate restrictions are
# combined with AND -- see the DataJoint restriction docs to confirm.
PresDecoderPrediction().populate(
    [{'data_set_file': file_name} for file_name in data_set_files],
    {'split_id': 0},
    display_progress=False,
)

In [None]:
# Bare expression: displays the PresDecoderPrediction table as the cell output.
PresDecoderPrediction()

## Plot

In [None]:
# Fetch one dict per prediction entry matching split_kind == 'test'.
# proj() with no arguments keeps only the primary-key attributes, so each
# dict can be splatted straight into plot_loss as keyword arguments.
test_loss_keys = (
    (PresDecoderPrediction() & {'split_kind': 'test'})
    .proj()
    .fetch(as_dict=True)
)

for key in test_loss_keys:
    print(key)  # label the figure that follows with its key
    PresDecoderPrediction().plot_loss(**key)
    plt.show()

In [None]:
# Restrict the prediction table to split_kind == 'test', then draw one
# prediction plot per entry. proj() with no arguments keeps only the
# primary-key attributes, which are passed to plot() as keyword arguments.
test_predictions = PresDecoderPrediction() & {'split_kind': 'test'}

for key in test_predictions.proj().fetch(as_dict=True):
    print(key)  # label the figure that follows with its key
    PresDecoderPrediction().plot(**key)
    plt.show()