In [None]:
# Notebook setup: imports, plotting configuration, project import path and
# the torch device used by the notebook.
import pickle
import numpy as np
import math
import panel as pn
import gc
import holoviews as hv
import altair as alt
# Remove Altair's cap on the number of rows a chart dataset may contain.
alt.data_transformers.disable_max_rows()
# Use Plotly as the HoloViews backend and load Plotly support in Panel,
# then switch both to a dark theme.
hv.extension("plotly")
pn.extension("plotly")
pn.config.theme = 'dark'
hv.renderer('plotly').theme = 'dark'

import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import seaborn as sns
# Seaborn defaults for all matplotlib figures: dark grid, muted palette,
# "notebook" context with larger fonts and thicker lines.
sns.set_style('darkgrid')
sns.set_palette('muted')
sns.set_context("notebook", font_scale=1.5,
                rc={"lines.linewidth": 2.5})
import torch
# NOTE(review): this `tqdm.auto` binding is overridden by the plain
# `from tqdm import tqdm` imports further down, so the name `tqdm` ends up
# referring to the standard (non-auto) progress bar.
from tqdm.auto import tqdm
import torch
import statsmodels.api as sm
import numpy as np
from tqdm import tqdm
import os
import matplotlib.pyplot as plt
from tqdm import tqdm
import yaml
from torch.utils.data import Dataset
from sklearn.model_selection import train_test_split
from dataclasses import dataclass
import sys

# Put the parent of the current working directory on the import path so the
# project-local `src` package below can be imported. The path is resolved
# relative to the working directory at the time this cell runs.
sys.path.append(os.path.abspath(os.path.join('..')))
from src.preprocessing import ChromoData
from src.md import create_md_ds, MDDenseSet, MDDenseDatasetConfig, MDDataset, generate_splits, generate_gausian_data_groups_chromo, generate_mock_chromos, generate_phases_for_dense_validation


# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
# TRAINING MD
# Generate one MD training dataset per config found in `config_folder`:
# load the config, build the splits and gaussian data groups, convert them
# with `create_md_ds`, save the dataset, and collect it in `md_datasets`.
gc.collect()
config_folder = '../configs/MD'
# os.listdir() returns entries in arbitrary order and also lists hidden
# entries (e.g. `.ipynb_checkpoints`, `.DS_Store`). Sort for a reproducible
# processing order and skip dot-entries so they are not loaded as configs.
configs = sorted(
    name for name in os.listdir(config_folder) if not name.startswith('.'))
md_datasets = []

print(f"Found {len(configs)} configs")
for config_name in configs:
    print(config_name)

# `config_name` is the file name; `config` is the loaded config object.
# Keeping them apart avoids rebinding the loop variable to a different type.
for config_name in tqdm(configs):
    print("=====================================")
    # NOTE(review): `config_folder` is not passed to `load` here (the dense
    # validation cell passes it explicitly), so `MDDataset.load` presumably
    # falls back to its own default folder - confirm it matches
    # `config_folder`.
    md_dataset = MDDataset.load(config_name=config_name, skip_data_load=True)
    config = md_dataset.config
    print(f"Generating dataset for {config.description}")
    print("Generating splits")
    # Only `splits_m` is consumed below; the other two results are kept as
    # notebook variables in case other cells inspect them.
    split_size, splits_m, deviation_m = generate_splits(
        min_pos=config.min_pos,
        max_pos=config.max_pos,
        split_count=config.split_count)
    print("Generating distributions")
    data = generate_gausian_data_groups_chromo(
        splits_m=splits_m,
        bell_shift_deviation=config.bell_shift_deviation,
        std=config.std,
        std_deviation=config.std_deviation,
        min_pos=config.min_pos,
        max_pos=config.max_pos,
        count_per_set=config.count_per_set,
        ceiling_top_deviation=config.ceiling_top_deviation,
        window=config.window)
    print("Generating tensors")
    dataset = create_md_ds(data)
    # Attach the freshly generated data to the dataset object before saving.
    md_dataset.data = dataset
    print("Generated dataset with:")
    print(f"len(labels_classifier): {len(dataset['labels_classifier'])}")
    print(f"len(labels_values): {len(dataset['labels_values'])}")
    print(f"len(gxx): {len(dataset['gxx'])}")
    print("Saving dataset")
    md_dataset.save()
    md_datasets.append(md_dataset)

In [None]:
# Validation MD
# Dense MD for validating averaging algorithm.
# For every config in `config_folder`: generate gaussian data groups, draw a
# random subsample, turn it into mock chromo data, generate phases from it,
# save the result, and collect the dataset in `md_datasets`.
gc.collect()
config_folder = '../configs/MDDense'
dataset_folder = '../data/MDDense'
# os.listdir() returns entries in arbitrary order and also lists hidden
# entries (e.g. `.ipynb_checkpoints`, `.DS_Store`). Sort for a reproducible
# processing order and skip dot-entries so they are not loaded as configs.
configs = sorted(
    name for name in os.listdir(config_folder) if not name.startswith('.'))
# NOTE(review): this rebinds `md_datasets`, which the training cell also
# populates; after this cell runs it holds only the dense datasets.
md_datasets = []

# Seeded generator so the validation subsample is the same on every run.
# This only fixes the subsampling step; whether the upstream generators are
# reproducible depends on their own random sources, which are not visible
# here.
SAMPLE_SEED = 42
sample_rng = np.random.default_rng(SAMPLE_SEED)

print(f"Found {len(configs)} configs")
for config_name in configs:
    print(config_name)

# `config_name` is the file name; `config` is the loaded config object.
for config_name in tqdm(configs):
    print("=====================================")
    dense_md_dataset = MDDenseSet.load(
        config_name=config_name,
        config_folder=config_folder,
        dataset_folder=dataset_folder,
        config_cls=MDDenseDatasetConfig,
        dataset_cls=MDDenseSet,
        skip_data_load=True)
    config = dense_md_dataset.config
    print(f"Generating dataset for {config.description}")
    print("Generating splits")
    # Only `splits_m` is consumed below; the other two results are kept as
    # notebook variables in case other cells inspect them.
    split_size, splits_m, deviation_m = generate_splits(
        min_pos=config.min_pos,
        max_pos=config.max_pos,
        split_count=config.split_count)
    print("Generating distributions")
    data = generate_gausian_data_groups_chromo(
        splits_m=splits_m,
        bell_shift_deviation=config.bell_shift_deviation,
        std=config.std,
        std_deviation=config.std_deviation,
        min_pos=config.min_pos,
        max_pos=config.max_pos,
        count_per_set=config.count_per_set,
        ceiling_top_deviation=config.ceiling_top_deviation,
        window=config.window)
    print("Generating mock chromo data")

    # Draw `config.sample_size` random indices into `data`.
    # NOTE(review): indices are drawn with replacement (as before), so the
    # subsample may contain duplicates - confirm this is intended.
    random_idxs = sample_rng.integers(0, len(data), size=config.sample_size)
    data = [data[i] for i in random_idxs]

    dense_mds = generate_mock_chromos(data)
    print("Generating phases")
    md_Dense_data = generate_phases_for_dense_validation(
        dense_mds,
        window_len=config.phase_window_size,
        do_real_mu=config.use_mu)
    print(f"Generated {len(md_Dense_data)} phases")
    # Attach the generated phases to the dataset object before saving.
    dense_md_dataset.data = md_Dense_data
    dense_md_dataset.save(dataset_folder=dataset_folder)
    md_datasets.append(dense_md_dataset)