# Colab Setup, Installation of Prerequisites, Modifications
If you are running offline, many of the steps here will not work out of the box unless IPython is set up properly.

In [None]:
# Mount the user's Google Drive at /content/drive (requires the `google.colab` package).
from google.colab import drive
drive.mount('/content/drive')

In [None]:
%%shell

#!/bin/bash

### not sure what ml libraries hopper comes with - or your setup for that matter
### you should only need to install / use pytorch for now provided you don't modify the example calls
### assuming no sudo access...

# apt update && apt install ffmpeg not required for colab

# Notes on the package list:
#  - `importlib` is part of the Python 3 standard library, so it is NOT installed from PyPI
#    (the PyPI project of that name is an obsolete backport for very old interpreters).
#  - scikit-learn is installed under its real project name; the `sklearn` PyPI alias is deprecated.
pip install brian2hears
pip install brian2
pip install scikit-learn
pip install scipy
pip install librosa
pip install tqdm
pip install visualkeras

# Download the Fira Sans repository (master branch) as a zip archive.
wget https://github.com/bBoxType/FiraSans/archive/refs/heads/master.zip -O FiraSans.zip

# Unpack it and create the system font directory that the TTF files are copied into.
unzip FiraSans.zip
sudo mkdir -p /usr/local/share/fonts/truetype/fira

# Copy every *.ttf from the 4.301 TTF folder into that directory, then rebuild the font cache.
find FiraSans-master/Fira_Sans_4_3/Fonts/Fira_Sans_TTF_4301 -type f -name "*.ttf" -exec sudo cp {} /usr/local/share/fonts/truetype/fira/ \;
sudo fc-cache -fv


### your python version may vary! -- check path to this library once you've finished pip installing.
# Path of the installed brian2hears filterbank library that the perl commands below edit in place.
file_path="/usr/local/lib/python3.10/dist-packages/brian2hears/filtering/filterbanklibrary.py"
# Keep a .bak copy next to the original; abort the cell if the copy fails.
cp "$file_path" "${file_path}.bak" || { echo "Backup failed"; exit 1; }
# Insert a get_individual_filter_outputs method into class ApproximateGammatone, directly in
# front of the first `def __init__(self,` that follows the class header.
# NOTE(review): the `|| echo` fallback depends on perl's exit status; confirm it actually fires
# when the pattern does not match (rather than only on perl errors).
perl -i -0pe '
    s/(class ApproximateGammatone\(LinearFilterbank\)(?:.*?\n)*?)(\s*def __init__\(self,)/$1    def get_individual_filter_outputs(self, input_signal):\n        filter_outputs = []  # List to store the outputs of each filter for each bandwidth\n        for i in range(len(self.cf)):\n            bandwidth_filter_outputs = []  # List to store the outputs of each filter for the current bandwidth\n            for j in range(self.order):\n                # Apply the j-th filter for the i-th bandwidth to the input signal\n                output = self.filters[i * self.order + j].apply(input_signal)\n                bandwidth_filter_outputs.append(output)  # Append the output to the list of filter outputs\n            filter_outputs.append(bandwidth_filter_outputs)  # Append the outputs for the current bandwidth to the main list\n        return filter_outputs\n$2/s' "$file_path" || echo "Failed to insert get_individual_filter_outputs"

# appends the 'process' method if not already present
# The method text is inserted right after the `return filter_outputs` line of
# get_individual_filter_outputs.
# NOTE(review): the qq{} text is inserted verbatim. Its `def process` line is indented 8 spaces
# (the same depth as the body of get_individual_filter_outputs, not the 4 spaces of a class
# method), and the 4 spaces in front of the closing `}` are also inserted ahead of whatever
# follows. Confirm the patched file's indentation is what was intended.
perl -i -0pe '
    $process_method_code = qq{
        def process(self, func=None, duration=None, buffersize=32):
            if self.use_individual_outputs:
                if duration is None:
                    duration = self.duration
                if not isinstance(duration, int):
                    duration = int(duration * self.samplerate)

                self.buffer_init()
                total_output = []
                for start in range(0, duration, buffersize):
                    end = min(start + buffersize, duration)
                    input_signal = self.source.buffer_fetch(start, end)
                    output = self.get_individual_filter_outputs(input_signal)
                    total_output.append(output)
                return np.concatenate(total_output, axis=0)
            else:
                return super().process(func, duration, buffersize)
    };
    if (/class ApproximateGammatone\(LinearFilterbank\)/) {
        unless (/def process\(self, func=None, duration=None, buffersize=32\):/) {
            s/(def get_individual_filter_outputs\(self, input_signal\):.*?return filter_outputs\n)/$1\n$process_method_code/s;
        }
    }
' "$file_path" || echo "Failed to insert process method"

In [None]:
# optional
# !pip install arrayfire==3.8.0+cu112 -f https://repo.arrayfire.com/python/wheels/3.8.0/

In [None]:
import re
import os

def modify_file(file_path):
    """Replace ``ApproximateGammatone.__init__`` in a brian2hears library file.

    The original constructor, recognised by its exact signature line
    ``def __init__(self, source, cf,  bandwidth,order=4):``, is replaced up to
    (but not including) the next line that starts at column 0 with a word
    character. The new constructor takes an extra ``use_individual_outputs``
    keyword and stores it on the instance.

    Args:
        file_path: Path of the Python source file to patch in place.

    Returns:
        None. If the signature is not found, the file is left untouched and a
        message is printed, so re-running the cell is harmless.
    """
    new_init_code = """
    def __init__(self, source, cf, bandwidth, order=4, use_individual_outputs=False):
        self.cf = cf
        self.use_individual_outputs = use_individual_outputs
        cf = np.asarray(np.atleast_1d(cf))
        bandwidth = np.asarray(np.atleast_1d(bandwidth))
        self.samplerate = source.samplerate
        dt = float(1 / self.samplerate)
        phi = 2 * np.pi * bandwidth * dt
        theta = 2 * np.pi * cf * dt
        cos_theta = np.cos(theta)
        sin_theta = np.sin(theta)
        alpha = -np.exp(-phi) * cos_theta
        b0 = np.ones(len(cf))
        b1 = 2 * alpha
        b2 = np.exp(-2 * phi)
        z1 = (1 + alpha * cos_theta) - (alpha * sin_theta) * 1j
        z2 = (1 + b1 * cos_theta) - (b1 * sin_theta) * 1j
        z3 = (b2 * np.cos(2 * theta)) - (b2 * np.sin(2 * theta)) * 1j
        tf = (z2 + z3) / z1
        a0 = abs(tf)
        a1 = alpha * a0
        self.filt_a = np.dstack((np.array([b0, b1, b2]).T,)*order)
        self.filt_b = np.dstack((np.array([a0, a1, np.zeros(len(cf))]).T,)*order)
        super().__init__(source, self.filt_b, self.filt_a)
    """.strip()

    pattern = re.compile(r'(def __init__\(self, source, cf,  bandwidth,order=4\):)(.*?)(?=\n\w)', re.DOTALL)

    with open(file_path, 'r', encoding='utf-8') as file:
        content = file.read()

    # A callable replacement inserts the text verbatim, so nothing in the new
    # code can ever be misread as a backslash escape or group reference.
    patched, replacements = pattern.subn(lambda _match: new_init_code, content)
    if replacements == 0:
        print(
            f"modify_file: original __init__ signature not found in {file_path}; "
            "file left unchanged (already patched, or a different brian2hears version)."
        )
        return

    with open(file_path, 'w', encoding='utf-8') as file:
        file.write(patched)

# Patch the installed library in place; the path hard-codes a Python 3.10 dist-packages location.
modify_file('/usr/local/lib/python3.10/dist-packages/brian2hears/filtering/filterbanklibrary.py')


In [None]:
%%shell
#!/bin/bash

# Root folder that holds every generated graphics directory.
BASE_DIR="Model"

# One graphics folder per sample set; each receives one folder per sampling method.
SUBDIRS=("graphics" "graphics_sample2" "graphics_sample3" "graphics_sample4" "graphics_sample5")
SUBFOLDERS=("maximal_entropy" "rossler" "simple_random_shuffle" "ornstein-uhlenbeck")

# Create every entry of SUBFOLDERS inside the directory given as $1 (existing ones are skipped).
create_subfolders() {
  local parent_dir=$1
  for subfolder in "${SUBFOLDERS[@]}"; do
    [ -d "$parent_dir/$subfolder" ] || mkdir -p "$parent_dir/$subfolder"
  done
}

[ -d "$BASE_DIR" ] || mkdir "$BASE_DIR"

for subdir in "${SUBDIRS[@]}"; do
  target_dir="$BASE_DIR/$subdir"
  [ -d "$target_dir" ] || mkdir "$target_dir"
  create_subfolders "$target_dir"
done

echo "Directory structure of graphics initialized."


# Imports + Configurable Settings
## Note
You can assign global variables in this stage after the imports are finished.

In [None]:
# import arrayfire as af
import gc
import gzip
import importlib
import json
import librosa
import math
import matplotlib.animation as animation
import matplotlib.font_manager as fm
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import os
import pickle
import pprint
import random
import requests
import scipy.io
import scipy.signal
import seaborn as sns
import shutil
import soundfile as sf
import subprocess
import tarfile
import tempfile
import tensorflow as tf
import tensorflow_probability as tfp
import time
import torch
import torch.nn as nn
import torch.optim as optim
import visualkeras

from brian2 import *
from brian2hears import *
from collections import deque
from datetime import datetime
from IPython.display import HTML
from itertools import cycle
from matplotlib import cm, colors
from matplotlib.colors import to_rgba
from mpl_toolkits.mplot3d import Axes3D
from sklearn.manifold import TSNE
from tensorflow.keras import layers, Model
from tensorflow.keras import backend as K, layers, models, callbacks
from tensorflow.keras.callbacks import Callback, EarlyStopping, ModelCheckpoint
from tensorflow.keras.layers import Input, Dense, Lambda, Flatten, Reshape, Conv1D, Conv1DTranspose, Conv2D, Conv2DTranspose, Concatenate
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tqdm import tqdm
from types import NoneType
from typing import Union, List, Tuple, Dict, Any
from wave import Wave_read

In [None]:
### configurable parameters
# Torch device: the first CUDA GPU when one is available, otherwise the CPU.
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Directory / URL settings.
workingdirectory = '/content/'
audiosampledir = 'Model/macleod/'
graphpaths = 'Model/graphics'
arraydir = 'Model/preprocessed_arrays'
urlstxtdir = ''
incompetechurl = 'https://incompetech.com/music/royalty-free/mp3-royaltyfree/'
font_dirs = ["/usr/local/share/fonts/truetype"]

# Register every font file found under font_dirs with matplotlib, then make
# Fira Sans the default family and save figures with a transparent background.
font_files = fm.findSystemFonts(fontpaths=font_dirs)
for font_file in font_files:
    fm.fontManager.addfont(font_file)
font_check = fm.get_font_names()  # font names known to matplotlib; not used again in this cell
plt.rcParams['font.family'] = 'Fira Sans'
plt.rcParams['savefig.transparent'] = True

# af.set_backend("cuda")
# NOTE(review): `filterbanklibrary` is not imported by name in the visible cells; presumably it
# arrives through `from brian2hears import *` — confirm, otherwise this line raises NameError.
importlib.reload(filterbanklibrary)

# Preprocessing

In [None]:
# Preprocessing function definitions
def extract_data_arrays(
    sampled_data: Dict[str, Dict[str, Dict[str, ArrayContainer]]],
    file_id: str = None,
    processing_type: str = 'original',
    component_type: str = 'real'
) -> np.ndarray:
    """Collect the arrays stored under one processing/component type.

    Args:
        sampled_data: Nested mapping ``file_id -> processing_type ->
            component_type -> container``, where the container is a list,
            tuple or dict of arrays.
        file_id: When truthy, only this file is read; otherwise every file in
            ``sampled_data`` is read in iteration order.
        processing_type: Key looked up below each file id.
        component_type: Key looked up below the processing type.

    Returns:
        ``np.array`` built from all collected arrays (empty when nothing was
        found; a diagnostic message is printed for every missing level).
    """

    def extract_from_container(container: ArrayContainer) -> List[np.ndarray]:
        """Return the container's items (dict values) as a plain list."""
        if isinstance(container, (list, tuple)):
            return list(container)
        if isinstance(container, dict):
            return list(container.values())
        print(f"Expected a list, tuple, or dict but got {type(container)}")
        return []

    def fetch_data_for_file(fid: str) -> List[np.ndarray]:
        """Walk file -> processing type -> component type, reporting the first missing level."""
        if fid not in sampled_data:
            print(f"File ID '{fid}' not found in data.")
            return []
        per_file = sampled_data[fid]
        if processing_type not in per_file:
            print(f"Processing type '{processing_type}' not found for file ID '{fid}'.")
            return []
        per_processing = per_file[processing_type]
        if component_type not in per_processing:
            print(f"Component type '{component_type}' not found for file ID '{fid}'.")
            return []
        return extract_from_container(per_processing[component_type])

    wanted_ids = [file_id] if file_id else list(sampled_data.keys())

    collected = []
    for fid in wanted_ids:
        collected.extend(fetch_data_for_file(fid))

    if not collected:
        print(f"No data arrays were extracted for '{processing_type}' and '{component_type}'. Check your input data.")

    return np.array(collected)

def extract_and_save_data_arrays_with_timestamp(
    sampled_data: Dict[str, Dict[str, Dict[str, ArrayContainer]]],
    directory=save_path[0]
) -> Tuple[str, str]:
    """Flatten the nested sample structure and persist it to disk.

    Every array found at ``sampled_data[key][proc_type][comp_type][bin_id]``
    is written, in iteration order, to ``saved_arrays_<timestamp>.npz``. The
    matching keys are written as parallel lists to
    ``metadata_<timestamp>.json`` so the structure can be rebuilt later.

    Args:
        sampled_data: Nested mapping whose innermost level is a dict of
            ``bin_id -> array``.
        directory: Output directory; created if missing.

    Returns:
        ``(metadata_filename, timestamp)``. The arrays file is not returned,
        but its path is ``<directory>/saved_arrays_<timestamp>.npz``.
    """
    data_arrays = []
    metadata = {'keys': [], 'processing_types': [], 'component_types': [], 'bin_ids': []}
    os.makedirs(directory, exist_ok=True)
    # Whole seconds since the epoch; both output files share this suffix.
    timestamp = str(int(time.time()))
    arrays_filename = os.path.join(directory, f'saved_arrays_{timestamp}.npz')
    metadata_filename = os.path.join(directory, f'metadata_{timestamp}.json')
    for key, proc_data in sampled_data.items():
        for proc_type, comp_data in proc_data.items():
            for comp_type, bins in comp_data.items():
                for bin_id, array in bins.items():
                    # Position i in data_arrays corresponds to position i in every metadata list.
                    data_arrays.append(array)
                    metadata['keys'].append(key)
                    metadata['processing_types'].append(proc_type)
                    metadata['component_types'].append(comp_type)
                    metadata['bin_ids'].append(bin_id)
    # Positional arguments are stored under the names arr_0, arr_1, ...
    np.savez_compressed(arrays_filename, *data_arrays)
    with open(metadata_filename, 'w') as meta_file:
        json.dump(metadata, meta_file)

    return metadata_filename, timestamp

def load_and_rebuild_structure(metadata_filename: str, arrays_filename: str):
    """Rebuild the nested sample dictionary from a metadata/arrays file pair.

    Args:
        metadata_filename: JSON file holding the parallel lists ``keys``,
            ``processing_types``, ``component_types`` and ``bin_ids``.
        arrays_filename: ``.npz`` archive whose entries are named ``arr_0``,
            ``arr_1``, ... in the same order as the metadata lists.

    Returns:
        Nested dict ``key -> processing_type -> component_type -> bin_id -> array``.
    """
    with open(metadata_filename, 'r') as meta_file:
        metadata = json.load(meta_file)

    entries = zip(
        metadata['keys'],
        metadata['processing_types'],
        metadata['component_types'],
        metadata['bin_ids'],
    )

    reconstructed_data = {}
    # The archive is opened as a context manager so its file handle is closed
    # once every array has been read into memory.
    with np.load(arrays_filename) as array_data:
        for array_index, (key, proc_type, comp_type, bin_id) in enumerate(entries):
            bins = (
                reconstructed_data
                .setdefault(key, {})
                .setdefault(proc_type, {})
                .setdefault(comp_type, {})
            )
            bins[bin_id] = array_data[f'arr_{array_index}']
    return reconstructed_data

def colab_write_frequency_bins_to_wav_and_archive(filter_outputs, file_path, sample_rate):
    """Write one WAV file per filter channel and pack them into a .tar.gz.

    For every numpy array under ``filter_outputs['original']``, column ``i``
    is taken as the samples of filter ``i``. The columns of all arrays are
    concatenated in dict order and written as
    ``filter_<i+1>_<timestamp>.wav``. The export directory is then archived
    next to ``file_path`` and removed.

    Args:
        filter_outputs: Dict expected to hold an ``'original'`` dict of arrays.
        file_path: Source audio path; its extension-less form prefixes the
            export directory and tarball names.
        sample_rate: Sample rate passed to ``sf.write``.
    """
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    directory = f"{os.path.splitext(file_path)[0]}_iirexport_{timestamp}"
    os.makedirs(directory, exist_ok=True)

    if 'original' in filter_outputs and isinstance(filter_outputs['original'], dict):
        per_filter_frames = None
        for waveform in filter_outputs['original'].values():
            if not isinstance(waveform, np.ndarray):
                continue
            if per_filter_frames is None:
                # The filter count is fixed by the first array encountered.
                per_filter_frames = [[] for _ in range(waveform.shape[1])]
            for filter_index, frames in enumerate(per_filter_frames):
                frames.append(waveform[:, filter_index])

        if per_filter_frames is None:
            # Previously this fell through to range(None) and raised TypeError.
            print("No numpy waveforms found under 'original'; no WAV files written.")
        else:
            for filter_index, frames in enumerate(per_filter_frames):
                concatenated_waveform = np.concatenate(frames)
                filename = f"{directory}/filter_{filter_index+1:03d}_{timestamp}.wav"
                sf.write(filename, concatenated_waveform, sample_rate)
                print(f"Written WAV file for filter {filter_index+1:03d}")
    else:
        print("Expected 'original' key with a dictionary of numpy arrays. Structure not found.")

    # tarfile/shutil instead of shell strings: paths containing quotes or other
    # shell metacharacters can no longer break (or inject into) the commands.
    tarball_name = f"{directory}.tar.gz"
    with tarfile.open(tarball_name, "w:gz") as archive:
        archive.add(directory, arcname=".")
    print(f"Tarball created: {tarball_name}")
    shutil.rmtree(directory)

## FFMPEG SUBPROCESSES
def convert_to_mp3_and_back(input_wav_path, output_wav_path, bitrate='32k', file_sr=44100):
    """Round-trip a WAV file through MP3 encoding using ffmpeg.

    The input is encoded to a temporary MP3 (libmp3lame) at ``bitrate`` and
    decoded again to 16-bit PCM at ``file_sr`` Hz, written to
    ``output_wav_path``.

    Args:
        input_wav_path: Source audio file.
        output_wav_path: Destination WAV file (overwritten, ``-y``).
        bitrate: MP3 bitrate passed to ffmpeg's ``-b:a``.
        file_sr: Output sample rate passed to ffmpeg's ``-ar``.

    Raises:
        subprocess.CalledProcessError: If either ffmpeg call exits non-zero.
    """
    # Only a unique path is needed; ffmpeg writes the file itself, so the
    # handle is closed immediately instead of staying open during the run.
    temp_mp3 = tempfile.NamedTemporaryFile(suffix='.mp3', delete=False)
    temp_mp3.close()
    try:
        command_mp3 = [
            'ffmpeg', '-y',
            '-i', input_wav_path,
            '-codec:a', 'libmp3lame',
            '-b:a', bitrate,
            temp_mp3.name
        ]
        subprocess.run(command_mp3, check=True)

        command_wav = [
            'ffmpeg', '-y',
            '-i', temp_mp3.name,
            '-acodec', 'pcm_s16le',
            '-ar', str(file_sr),
            output_wav_path
        ]
        subprocess.run(command_wav, check=True)
    finally:
        # delete=False means nothing removes the file automatically; clean it
        # up on success and on failure alike.
        if os.path.exists(temp_mp3.name):
            os.remove(temp_mp3.name)

def prepare_audio_files(file_path, duration, UseDualMono=True):
    """Convert an audio file into one or two mono 44.1 kHz 16-bit WAV files.

    ffprobe reports the channel count of the first audio stream. A stereo
    file is split into ``<base>_L.wav`` / ``<base>_R.wav`` (both when
    ``UseDualMono`` is true, otherwise one randomly chosen side). A mono file
    is re-encoded to ``<base>_mono.wav``. Output is truncated to ``duration``
    seconds.

    Returns:
        List of the WAV paths that ffmpeg produced successfully; ``[]`` when
        probing fails, the channel count is neither 1 nor 2, or any
        exception occurs.
    """

    def build_command(output_file, channel_args=()):
        """ffmpeg argument list shared by the stereo and mono conversions."""
        return [
            'ffmpeg', '-y', '-i', file_path,
            *channel_args,
            '-acodec', 'pcm_s16le',
            '-ar', '44100',
            '-ac', '1',
            '-t', str(duration),
            output_file
        ]

    try:
        probe = subprocess.run(
            ['ffprobe', '-v', 'error', '-select_streams', 'a:0', '-show_entries', 'stream=channels', '-of', 'default=noprint_wrappers=1:nokey=1', file_path],
            text=True,
            capture_output=True,
        )
        if probe.returncode != 0:
            print("Error probing the file:", probe.stderr)
            return []

        channels = int(probe.stdout.strip())
        print(f"Detected {channels} channels.")
        base_name = os.path.splitext(file_path)[0]
        produced = []

        if channels == 2:
            labels = ['_L', '_R'] if UseDualMono else [random.choice(['_L', '_R'])]
            for label in labels:
                source_channel = 0 if label == "_L" else 1
                output_file = f"{base_name}{label}.wav"
                command = build_command(output_file, ('-map_channel', f'0.0.{source_channel}'))
                print("Running command:", ' '.join(command))
                outcome = subprocess.run(command, text=True, capture_output=True)
                if outcome.returncode == 0:
                    produced.append(output_file)
                else:
                    print("ffmpeg error:", outcome.stderr)
        elif channels == 1:
            # Mono input: only standardise the format.
            output_file = f"{base_name}_mono.wav"
            outcome = subprocess.run(build_command(output_file), text=True, capture_output=True)
            if outcome.returncode == 0:
                produced.append(output_file)

        return produced
    except Exception as e:
        print(f"An error occurred: {str(e)}")
        return []


def cochlear_iir(file_path, fft_size=256, downloadfile=False, rectify=True, min_frequency=20, max_frequency=20000, num_frequency_bins=100, bw=None, random_order=False, with_lossy_backprop=False, *args, **kwargs):
    """Run an audio file through an ApproximateGammatone filterbank, frame by frame.

    Args:
        file_path: Audio file to analyse.
        fft_size: Number of samples per output frame (also the processing buffer size).
        downloadfile: When true, also export the per-filter WAV files and tarball.
        rectify: When true, the sound is passed through a clip-at-zero /
            cube-root stage before the gammatone filterbank.
        min_frequency, max_frequency, num_frequency_bins: ERB-spaced centre
            frequency range (Hz) and count.
        bw: Filter bandwidths; derived from the centre frequencies when ``None``.
        random_order: Accepted for call compatibility; not used here.
        with_lossy_backprop: When true, the file is additionally round-tripped
            through MP3 and analysed again.

    Returns:
        Dict with key ``'original'`` (and ``'lossy'`` when requested), each a
        dict ``'bin_<i>' -> frame``. A frame dict is empty/partial if the
        filterbank raised (the error is printed).

    Side effects:
        Sets the module-level global ``center_frequencies``.
    """
    global center_frequencies

    def load_for_filterbank(path):
        """Load a sound and, if requested, wrap it in the rectification stage."""
        sound = loadsound(f"{path}")
        if rectify:
            sound = FunctionFilterbank(sound, lambda x: clip(x, 0, Inf)**(1.0/3.0))
        return sound

    def process_audio(audio_data, audio_brain, num_frames, file_sr):
        filter_outputs = {}
        try:
            gammatone = ApproximateGammatone(audio_brain, center_frequencies, bw, 4, use_individual_outputs=True)
            framefilter = gammatone.process(buffersize=fft_size)
            for i in range(num_frames):
                start_sample = i * fft_size
                end_sample = min(start_sample + fft_size, framefilter.shape[0])
                frame = framefilter[start_sample:end_sample]
                if len(frame) < fft_size:  # zero-pad if necessary
                    frame = np.pad(frame, ((0, fft_size - len(frame)), (0, 0)), mode='constant')
                filter_outputs[f'bin_{i}'] = frame
        except Exception as e:
            # Best effort: report and return whatever frames were produced.
            print("Error during processing:", e)
        print(f'Type of filter_outputs: {type(filter_outputs)}') #dictionary
        return filter_outputs

    y, file_sr = librosa.load(file_path, sr=None)
    brainload = load_for_filterbank(file_path)
    num_frames_original = int(np.ceil(len(y) / fft_size))
    print(f"now passing file: {file_path} through filterbank...")
    center_frequencies = erbspace(min_frequency*Hz, max_frequency*Hz, num_frequency_bins)
    if bw is None:
        bw = 10**(0.037 + 0.785 * np.log10(center_frequencies / Hz))

    results = {'original': process_audio(y, brainload, num_frames_original, file_sr)}

    if downloadfile:
        colab_write_frequency_bins_to_wav_and_archive(results, file_path, file_sr)

    if with_lossy_backprop:
        # splitext (not str.replace) so a non-".wav" input can never make the
        # output path equal the input path, which would then be deleted below.
        root, _extension = os.path.splitext(file_path)
        output_wav_path = f"{root}_back_from_mp3.wav"
        convert_to_mp3_and_back(file_path, output_wav_path, bitrate='64k', file_sr=file_sr)
        try:
            y_lossy, _ = librosa.load(output_wav_path, sr=file_sr)
            # Fixed: the lossy sound used to be assigned to a misspelled name,
            # so the "lossy" pass silently re-filtered the original sound.
            lossy_sound = load_for_filterbank(output_wav_path)
            # Same frame count as the original so both results have matching bins.
            results['lossy'] = process_audio(y_lossy, lossy_sound, num_frames_original, file_sr)
        finally:
            os.remove(output_wav_path)  # Clean up

    return results

def epochset(file_paths, *args, **kwargs):
    """Prepare, filter and DFT every file in ``file_paths``.

    For each path, the first WAV produced by ``prepare_audio_files`` is passed
    through ``cochlear_iir``; each processing type it returns is then passed
    to ``perform_dft``. ``duration`` (default 60) and ``UseDualMono``
    (default True) are read from ``kwargs``; all ``kwargs`` are forwarded
    unchanged to both helpers.

    Returns:
        ``(results, filename_mapping, new_file_paths, outputs)`` where
        ``results`` maps ``"<processing_type>_<normalized_name>"`` to DFT
        results, ``filename_mapping`` maps each input path to its normalized
        name, and ``outputs`` is the filterbank output of the last processed
        file (``{}`` if none was processed).
    """
    clip_seconds = kwargs.get('duration', 60)
    dual_mono = kwargs.get('UseDualMono', True)

    results, filename_mapping, new_file_paths = {}, {}, []
    outputs = {}

    for file_path in file_paths:
        prepared = prepare_audio_files(file_path, clip_seconds, dual_mono)
        if prepared:
            # Only the first prepared file is analysed.
            new_file_path = prepared[0]
            new_file_paths.append(new_file_path)
            normalized_name = normalize_filename(new_file_path)
            filename_mapping[file_path] = normalized_name
            outputs = cochlear_iir(new_file_path, *args, **kwargs)
            for processing_type, frames in outputs.items():
                print(f"now performing discrete fourier transform series for key: {normalized_name}...")
                print(f"performing calculations on {processing_type} version of the original file...")
                results[f"{processing_type}_{normalized_name}"] = perform_dft(frames, **kwargs)

    # The filtered outputs are returned alongside the DFT results.
    return results, filename_mapping, new_file_paths, outputs

def perform_dft(named_arrays, with_lossy_backprop=False, progress_meter=None, backend='pytorch', **kwargs):
    """Compute the DFT of every frame in every named group.

    Args:
        named_arrays: Dict ``name -> sequence of frames`` (a list is converted
            to ``{'bin_0': ..., 'bin_1': ...}``). Within one group, frames are
            padded or trimmed to the group's longest frame.
        with_lossy_backprop: Accepted for call compatibility; not used here.
        progress_meter: Object with ``update(n)`` / ``close()``; a tqdm bar is
            created when ``None``.
        backend: ``'pytorch'``, ``'tensorflow'`` or ``'arrayfire'``.
        **kwargs: ``use_gpu`` (bool, or a string such as ``"true"``/``"False"``)
            selects CUDA for the pytorch backend when available. Default False.

    Returns:
        ``{'real': {name: ndarray}, 'imaginary': {name: ndarray}}``.

    Raises:
        ValueError: For an unsupported backend, or for ``'arrayfire'`` when the
            ``af`` module has not been imported.
    """
    if isinstance(named_arrays, list):
        named_arrays = {f'bin_{i}': named_arrays[i] for i in range(len(named_arrays))}

    if progress_meter is None:
        total_length = sum(len(data) for data in named_arrays.values())
        progress_meter = tqdm(total=total_length, desc='DFTs per second', unit='DFTs', leave=True, position=0)

    # Resolved once, up front: it used to default to the truthy string "False"
    # and was unbound after the loop when named_arrays was empty.
    use_gpu = kwargs.get('use_gpu', False)
    if isinstance(use_gpu, str):
        use_gpu = use_gpu.strip().lower() in ('1', 'true', 'yes', 'on')
    use_gpu = bool(use_gpu)

    dfts_for_network = {'real': {}, 'imaginary': {}}

    def get_target_length(frames):
        return max(len(frame) for frame in frames)

    def pad_or_trim_frame(frame, target_length):
        if len(frame) > target_length:
            return frame[:target_length]
        padding = target_length - len(frame)
        if isinstance(frame, torch.Tensor):
            return torch.cat((frame, torch.zeros(padding, dtype=frame.dtype, device=frame.device)), dim=0)
        return np.pad(frame, (0, padding), 'constant')

    def process_with_pytorch(frames, use_gpu):
        target_length = get_target_length(frames)
        if use_gpu:
            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            device = torch.device('cpu')
        padded_frames = [pad_or_trim_frame(torch.tensor(frame, dtype=torch.complex64, device=device), target_length) for frame in frames]
        stacked_frames = torch.stack(padded_frames)
        dft_results = torch.fft.fft(stacked_frames)
        real_parts = torch.real(dft_results).cpu().numpy()  # back on the CPU for numpy
        imaginary_parts = torch.imag(dft_results).cpu().numpy()

        progress_meter.update(len(frames))
        return real_parts, imaginary_parts

    def process_with_arrayfire(frames):
        real_parts = []
        imaginary_parts = []
        target_length = get_target_length(frames)
        frames = np.array(frames) if not isinstance(frames, np.ndarray) else frames
        for frame in frames:
            frame = pad_or_trim_frame(frame, target_length)
            af_array = af.interop.from_ndarray(frame)
            dft_result = af.dft(af_array)
            real_parts.append(np.real(dft_result.to_ndarray()))
            imaginary_parts.append(np.imag(dft_result.to_ndarray()))
            progress_meter.update(1)
        return np.stack(real_parts), np.stack(imaginary_parts)

    def process_with_tensorflow(frames):
        real_parts = []
        imaginary_parts = []
        target_length = get_target_length(frames)
        frames = np.array(frames) if not isinstance(frames, np.ndarray) else frames
        for frame in frames:
            frame = pad_or_trim_frame(frame, target_length)
            frame_tensor = tf.convert_to_tensor(frame, dtype=tf.complex64)
            dft_result = tf.signal.fft(frame_tensor)
            real_parts.append(tf.math.real(dft_result).numpy())
            imaginary_parts.append(tf.math.imag(dft_result).numpy())
            progress_meter.update(1)
        return np.stack(real_parts), np.stack(imaginary_parts)

    for key, frames in named_arrays.items():
        if backend == 'pytorch':
            real_parts, imaginary_parts = process_with_pytorch(frames, use_gpu=use_gpu)
        elif backend == 'tensorflow':
            real_parts, imaginary_parts = process_with_tensorflow(frames)
        elif backend == 'arrayfire':
            # This branch used to be a second, unreachable 'pytorch' test.
            if 'af' not in globals():
                raise ValueError("Unsupported backend: arrayfire (run `import arrayfire as af` first)")
            real_parts, imaginary_parts = process_with_arrayfire(frames)
        else:
            raise ValueError(f"Unsupported backend: {backend}")
        dfts_for_network['real'][key] = real_parts
        dfts_for_network['imaginary'][key] = imaginary_parts

    progress_meter.close()

    if backend == 'tensorflow':
        tf.keras.backend.clear_session()
        gc.collect()

    if backend == 'pytorch' and use_gpu:
        torch.cuda.empty_cache()  # release cached GPU memory after processing

    return dfts_for_network

def configure_epochs(file_paths, num_bins=None, fft_size=256, ignore_temporality=False, *args, **kwargs):
    """Run ``epochset`` and regroup its DFT results per file.

    Args:
        file_paths: Audio files handed to ``epochset``.
        num_bins: Optional dict ``normalized_name -> limit``; only bins whose
            numeric suffix is below the limit are kept for that file.
        fft_size: Frame size forwarded to ``epochset``.
        ignore_temporality: When true, the bin order of every component dict
            is shuffled.
        *args, **kwargs: Forwarded to ``epochset``.

    Returns:
        ``(epoch_data, outputs)`` where ``epoch_data`` is
        ``normalized_name -> processing_type -> component_type -> bin_key ->
        values`` and ``outputs`` is the fourth element returned by ``epochset``.
    """
    full_results, filename_mapping, new_file_paths, outputs = epochset(file_paths, fft_size=fft_size, *args, **kwargs)
    epoch_data = {}

    for file_path, normalized_name in filename_mapping.items():
        bin_limit = num_bins.get(normalized_name, None) if num_bins else None
        # The processing type is taken from the loop variable rather than a
        # substring test on the key: a file whose name contains "original"
        # used to have its lossy results filed (and overwritten) as 'original'.
        for processing_type in ('original', 'lossy'):
            key = f"{processing_type}_{normalized_name}"
            if key not in full_results:
                continue
            dfts = full_results[key]

            selected = {'real': {}, 'imaginary': {}}
            for component_type in ('real', 'imaginary'):
                if component_type not in dfts:
                    continue
                for bin_key, values in dfts[component_type].items():
                    if 'bin' not in bin_key:
                        continue
                    bin_index = int(bin_key.split('_')[-1])
                    if bin_limit is None or bin_index < bin_limit:
                        selected[component_type][bin_key] = values

            epoch_data.setdefault(normalized_name, {})[processing_type] = selected

    if ignore_temporality:
        # NOTE(review): 'real' and 'imaginary' are shuffled independently, so
        # their bin orders no longer match each other — confirm this is intended.
        for basename in epoch_data:
            for processing_type in epoch_data[basename]:
                for component_type in ['real', 'imaginary']:
                    bin_keys = list(epoch_data[basename][processing_type][component_type].keys())
                    random.shuffle(bin_keys)
                    shuffled_bins = {key: epoch_data[basename][processing_type][component_type][key] for key in bin_keys}
                    epoch_data[basename][processing_type][component_type] = shuffled_bins

    return epoch_data, outputs

def normalize_filename(file_path):
    """Return the file's base name without extension, lower-cased, with spaces and hyphens turned into underscores."""
    stem, _extension = os.path.splitext(os.path.basename(file_path))
    separators_to_underscore = str.maketrans({' ': '_', '-': '_'})
    return stem.translate(separators_to_underscore).lower()

def append_dataset(directory):
    """Return the path of every file found under ``directory`` (recursive, in ``os.walk`` order)."""
    return [
        str(os.path.join(root, name))
        for root, _dirs, names in os.walk(directory)
        for name in names
    ]

def macleodchance(file_path, sample_size):
    """Return up to ``sample_size`` randomly chosen non-blank lines (stripped) from a text file."""
    with open(file_path, 'r') as handle:
        raw_lines = handle.readlines()

    candidates = []
    for raw in raw_lines:
        stripped = raw.strip()
        if stripped:
            candidates.append(stripped)

    how_many = min(sample_size, len(candidates))
    return random.sample(candidates, how_many)

def macleodownload(base_url, paths, directory):
    """Download ``base_url + path`` for every path into ``directory``.

    The directory is created if needed. Each file is saved under its ``path``
    name. Request failures are reported and skipped, so one bad URL does not
    stop the remaining downloads.
    """
    os.makedirs(directory, exist_ok=True)
    for path in paths:
        url = base_url + path
        file_path = os.path.join(directory, path)
        try:
            response = requests.get(url)
            response.raise_for_status()
            with open(file_path, 'wb') as out_file:
                out_file.write(response.content)
        except requests.exceptions.RequestException as e:
            print(f"Failed to download {url}. Error: {e}")
        else:
            print(f"Downloaded {url} to {file_path}")

def nab_fileids(sampled_data):
    """Return the file ids (top-level keys) of ``sampled_data`` as a list, in iteration order."""
    return [file_id for file_id, _processing_types in sampled_data.items()]

def save_svg(fig, filename):
    """Save ``fig`` as ``<filename>.svg`` inside /content/Model/graphics (created if missing)."""
    output_dir = '/content/Model/graphics'
    os.makedirs(output_dir, exist_ok=True)
    svg_path = os.path.join(output_dir, f'{filename}.svg')
    fig.savefig(svg_path, format='svg')

# Stochastic Processes + Visualizations

In [None]:
### Kamada-Kawai helpers using Dijkstra path lengths (for GPU acceleration)
class KamadaKawaiOptimizer(nn.Module):
    """Learnable 2-D node positions with a Kamada-Kawai style energy.

    ``D`` holds the target pairwise distances and ``K`` the per-pair weights.
    ``pos`` is an ``(n, 2)`` parameter with ``n = D.shape[0]``, initialised
    uniformly at random.
    """

    def __init__(self, D, K):
        super().__init__()
        self.D, self.K = D, K
        node_count = D.shape[0]
        self.pos = nn.Parameter(torch.rand(node_count, 2))

    def forward(self):
        """Return ``0.5 * sum(K * (D - pairwise_distance(pos))**2)``."""
        pairwise = torch.cdist(self.pos, self.pos, p=2)
        residual = self.D - pairwise
        return 0.5 * torch.sum(self.K * residual**2)

def compute_kamada_kawai_positions(D, K, num_epochs=1000, learning_rate=0.01):
    """Minimise the KamadaKawaiOptimizer energy with Adam and return the positions.

    Args:
        D: Target pairwise distances (tensor).
        K: Pairwise weights (tensor).
        num_epochs: Number of Adam steps.
        learning_rate: Adam learning rate.

    Returns:
        The optimised positions as a numpy array.
    """
    layout = KamadaKawaiOptimizer(D, K)
    adam = optim.Adam(layout.parameters(), lr=learning_rate)

    for _ in range(num_epochs):
        adam.zero_grad()
        energy = layout()
        energy.backward()
        adam.step()

    return layout.pos.detach().numpy()

def initialize_kamada_kawai_layout(G, num_epochs=1000, learning_rate=0.01):
    """Build the distance and weight matrices for ``G`` and optimise a layout.

    Shortest-path lengths come from ``nx.all_pairs_dijkstra_path_length``.
    For every pair ``i != j`` the target distance is that path length and the
    weight is ``1 / distance**2``. Nodes are looked up by the integers
    ``0 .. n-1``.

    Returns:
        The positions computed by ``compute_kamada_kawai_positions``.
    """
    node_count = G.number_of_nodes()
    matrix_shape = (node_count, node_count)
    target_distance = np.zeros(matrix_shape)
    stiffness = np.zeros(matrix_shape)

    shortest = dict(nx.all_pairs_dijkstra_path_length(G))

    for i, j in np.ndindex(*matrix_shape):
        if i == j:
            continue
        target_distance[i, j] = shortest[i][j]
        stiffness[i, j] = 1 / target_distance[i, j]**2

    distance_tensor = torch.tensor(target_distance, dtype=torch.float32)
    stiffness_tensor = torch.tensor(stiffness, dtype=torch.float32)

    return compute_kamada_kawai_positions(
        distance_tensor,
        stiffness_tensor,
        num_epochs=num_epochs,
        learning_rate=learning_rate,
    )

#### ^^ OPtionAL
### Helpers:
def create_color_mapping(file_ids, cmap_name='tab10'):
    """Map each file id to a colour from the named colormap, cycling when the ids outnumber the colours."""
    palette = cycle(plt.get_cmap(cmap_name).colors)
    return dict(zip(file_ids, palette))

def get_sorted_bin_keys(data):
    """Return the keys of ``data`` ordered by the first integer found in each key."""
    def first_number(name):
        return int(re.search(r'\d+', name).group())

    return sorted(data.keys(), key=first_number)

### sampling functions and logging functions
def print_initial_bin_counts(sampled_data, epoch_data, cross_file_sampling, file_ids):
    """Print how many bins each processing/component type holds before sampling.

    Args:
        sampled_data: Unused; kept for call compatibility.
        epoch_data: ``file_id -> processing_type -> component_type -> bins``.
        cross_file_sampling: When true, counts are summed over all
            ``file_ids``; otherwise they are printed per file.
        file_ids: File ids to report on.

    A processing or component type that is absent for a file (for example
    'lossy' when no lossy pass was run) is reported as 0 bins instead of
    raising KeyError.
    """
    def count_bins(file_id, processing_type, component_type):
        return len(epoch_data[file_id].get(processing_type, {}).get(component_type, {}))

    if cross_file_sampling:
        print("Initial bin counts:")
        for processing_type in ['original', 'lossy']:
            for component_type in ['real', 'imaginary']:
                bins_count = sum(count_bins(file_id, processing_type, component_type) for file_id in file_ids)
                print(f"{processing_type.capitalize()} - {component_type.capitalize()}: {bins_count} bins")
    else:
        for file_id in file_ids:
            print(f"Initial bin counts for file_id: {file_id}")
            for processing_type in ['original', 'lossy']:
                for component_type in ['real', 'imaginary']:
                    bins_count = count_bins(file_id, processing_type, component_type)
                    print(f"{processing_type.capitalize()} - {component_type.capitalize()}: {bins_count} bins")

def print_final_bin_counts(sampled_data, cross_file_sampling, file_ids):
    """Print which bins each processing/component type holds after sampling.

    Args:
        sampled_data: ``group -> processing_type -> component_type -> bins``,
            where the group is ``'total_list'`` for cross-file sampling and a
            file id otherwise.
        cross_file_sampling: Selects between the single ``'total_list'`` report
            and one report per file id.
        file_ids: File ids to report on when not sampling across files.

    A missing processing or component type is reported as an empty bin list
    instead of raising KeyError.
    """
    def bin_names(group, processing_type, component_type):
        return list(sampled_data[group].get(processing_type, {}).get(component_type, {}).keys())

    if cross_file_sampling:
        print("Final bin counts:")
        for processing_type in ['original', 'lossy']:
            for component_type in ['real', 'imaginary']:
                bins = bin_names('total_list', processing_type, component_type)
                print(f"{processing_type.capitalize()} - {component_type.capitalize()}: {len(bins)} unique bins, containing: {bins}")
    else:
        for file_id in file_ids:
            print(f"Final bin counts for file_id: {file_id}")
            for processing_type in ['original', 'lossy']:
                for component_type in ['real', 'imaginary']:
                    bins = bin_names(file_id, processing_type, component_type)
                    print(f"{processing_type.capitalize()} - {component_type.capitalize()}: {len(bins)} unique bins, containing: {bins}")

def stochastic_process(method, unique_bins_list, num_samples, bin_min, bin_max, **params):
    """Dispatch to the sampling routine registered under ``method``.

    Args:
        method: One of 'simple_random_shuffle', 'ornstein-uhlenbeck',
            'maximal_entropy' or 'rossler'.
        unique_bins_list, num_samples, bin_min, bin_max: Forwarded unchanged.
        **params: Extra keyword arguments forwarded to the chosen routine.

    Returns:
        Whatever the chosen routine returns.

    Raises:
        ValueError: If ``method`` is not a registered name.
    """
    dispatch = {
        'simple_random_shuffle': simple_random_shuffle,
        'ornstein-uhlenbeck': ornstein_uhlenbeck_process,
        'maximal_entropy': maximal_entropy_process,
        'rossler': rossler_process
    }

    if method in dispatch:
        sampler = dispatch[method]
        return sampler(unique_bins_list, num_samples, bin_min, bin_max, **params)

    raise ValueError(f"Sampling method '{method}' not recognized")

def simple_random_shuffle(unique_bins_list, num_samples, bin_min, bin_max, **params):
    """Shuffle ``unique_bins_list`` in place and return the leading positions.

    The caller's list is reordered by ``np.random.shuffle``; the returned
    indices are simply ``0 .. k-1`` where ``k`` is the smaller of
    ``num_samples`` and the list length. ``bin_min``, ``bin_max`` and
    ``params`` are accepted for interface parity and are not used.

    Returns:
        ``(indices, None)``; the second slot is where other samplers return a visualisation.
    """
    np.random.shuffle(unique_bins_list)
    available = len(unique_bins_list)
    take = available if available < num_samples else num_samples
    return np.arange(take), None

### OU Process + Animators
def ornstein_uhlenbeck(start, num_samples, mean, scale, theta, dt=1):
    """Generate a mean-reverting random path beginning at ``start``.

    Each step adds a pull toward ``mean`` (``theta * (mean - previous) * dt``)
    plus a standard-normal draw multiplied by ``scale``.

    Args:
        start: Value stored in the first slot of the path.
        num_samples: Length of the returned path; must be positive.
        mean: Level the path is pulled toward.
        scale: Multiplier applied to each normal draw.
        theta: Strength of the pull toward ``mean``.
        dt: Step size applied to the pull term only.

    Returns:
        A float array of length ``num_samples``.

    Raises:
        ValueError: If ``num_samples`` is not positive.
    """
    # NOTE(review): the noise term is not scaled by dt (e.g. sqrt(dt)) -- confirm this is intended.
    if num_samples <= 0:
        raise ValueError("num_samples must be a positive integer")

    trajectory = np.zeros(num_samples)
    trajectory[0] = start
    for step in range(1, num_samples):
        previous = trajectory[step - 1]
        pull = theta * (mean - previous) * dt
        kick = scale * np.random.normal()
        trajectory[step] = previous + pull + kick
    return trajectory

def ornstein_uhlenbeck_process(unique_bins_list, num_samples, bin_min, bin_max, randomize=False, **params):
    """Turn an Ornstein-Uhlenbeck path into integer bin indices.

    Args:
        unique_bins_list: Accepted for interface parity; not read here.
        num_samples: Number of path points (and indices) to generate.
        bin_min, bin_max: Inclusive range the path is clipped to before the
            integer cast.
        randomize: When true, fallback values for the process parameters are
            drawn with ``random.uniform``; otherwise fixed fallbacks are used.
        **params: Optional overrides ``start``, ``mean``, ``scale``, ``theta``,
            ``dt``; ``visualize=True`` additionally renders an animation.

    Returns:
        ``(indices, html_video)`` where ``html_video`` is ``None`` unless
        ``visualize`` was requested.
    """
    def _setting(key, low, high, fixed):
        # The fallback is computed before the lookup, so a random draw is
        # consumed even when the caller supplies an explicit override.
        fallback = random.uniform(low, high) if randomize else fixed
        return params.get(key, fallback)

    start = _setting('start', bin_min, bin_max, 0)
    mean = _setting('mean', bin_min, bin_max, (bin_min + bin_max) / 2)
    scale = _setting('scale', 0.1, 2.0, 1)
    theta = _setting('theta', 0.05, 0.3, 0.15)
    dt = _setting('dt', 0.01, 1.0, 1)

    path = ornstein_uhlenbeck(start, num_samples, mean, scale, theta, dt)
    indices = np.clip(path, bin_min, bin_max).astype(int)

    if not params.get('visualize', False):
        return indices, None

    rendered = animate_ornstein_uhlenbeck(path, num_samples)
    return indices, HTML(rendered.to_html5_video())

def animate_ornstein_uhlenbeck(ou_path, num_samples, fps=30):
    """Animate the Ornstein-Uhlenbeck process using a given path and save the last frame as an SVG.

    Both the MP4 animation and the final-frame SVG are written to the
    ``ornstein-uhlenbeck`` subfolder of the global ``graphics_path``.

    Args:
        ou_path: Sequence of path values, indexed by frame number.
        num_samples: Number of frames to render.
        fps: Frames per second for playback and for the saved video.

    Returns:
        The ``FuncAnimation`` object.
    """
    global graphics_path

    # Create the folder that is actually written to (not just its parent),
    # otherwise the first save fails on a fresh output tree.
    output_dir = f'{graphics_path}/ornstein-uhlenbeck'
    os.makedirs(output_dir, exist_ok=True)

    fig, ax = plt.subplots()
    line, = ax.plot([], [], 'r-', label='OU Process')
    point, = ax.plot([], [], 'ro')

    ax.set_xlim(0, num_samples - 1)
    ax.set_ylim(min(ou_path), max(ou_path))
    ax.legend()

    ax.set_xlabel('bin # in sequence')
    ax.set_ylabel('time (s)')
    ax.set_title('Ornstein-Uhlenbeck Process Example')

    def init():
        line.set_data(range(num_samples), ou_path)
        point.set_data([], [])
        return line, point,

    def animate(i):
        # Line2D.set_data expects sequences; passing bare scalars is deprecated.
        point.set_data([i], [ou_path[i]])
        if i == num_samples - 1:
            fig.savefig(os.path.join(output_dir, 'ou_last_frame.svg'), format='svg', transparent=True, bbox_inches='tight')
        return line, point,

    ani = animation.FuncAnimation(fig, animate, init_func=init, frames=num_samples, interval=1000 / fps, blit=True)
    ani.save(os.path.join(output_dir, 'ou_animation.mp4'), writer='ffmpeg', fps=fps)
    plt.close(fig)
    return ani
### rossler process + animators

def random_color():
    """Return three uniform random values in [0, 1) for use as an RGB colour."""
    rgb = np.random.rand(3)
    return rgb

def sanitize(value, default=0.0):
    """Return ``value`` if it is finite, otherwise ``default`` (used for NaN and +/-inf)."""
    return value if np.isfinite(value) else default

def rossler_attractor(start, num_samples, a, b, c, dt):
    """Integrate the Rossler equations with fixed-step forward Euler updates.

    Args:
        start: Iterable of three values giving the initial ``(x, y, z)``.
        num_samples: Number of points in the returned path, including ``start``.
        a, b, c: Coefficients of the system
            (``dx = -y - z``, ``dy = x + a*y``, ``dz = b + z*(x - c)``).
        dt: Step size applied to each derivative.

    Returns:
        Array of shape ``(num_samples, 3)`` with columns x, y, z.
    """
    x, y, z = start
    trail = [(x, y, z)]
    for _ in range(num_samples - 1):
        # All three derivatives are taken from the same (old) state
        # before any coordinate is advanced.
        dx, dy, dz = -y - z, x + a * y, b + z * (x - c)
        x += dx * dt
        y += dy * dt
        z += dz * dt
        trail.append((x, y, z))
    xs, ys, zs = zip(*trail)
    return np.array([xs, ys, zs]).T

def rossler_process(unique_bins_list, num_samples, bin_min, bin_max, randomize=False, **params):
    """Derive integer bin indices from the z-component of a Rossler trajectory.

    The z-component is min-max normalised onto ``[bin_min, bin_max]``, cast to
    int, and clipped to ``[0, bin_max]``.

    Args:
        unique_bins_list: Accepted for interface parity; not read here.
        num_samples: Number of trajectory points (and indices).
        bin_min, bin_max: Target index range.
        randomize: When true, the start point and coefficients are redrawn on
            every attempt (up to 500); otherwise a fixed configuration is used.
        **params: ``visualize=True`` additionally renders an animation.

    Returns:
        ``(indices, animation)`` where ``animation`` is ``None`` unless
        ``visualize`` was requested.

    Raises:
        ValueError: If no attempt produced a finite trajectory with a
            non-zero z-component.
    """
    # The fixed configuration is deterministic, so repeating it cannot change
    # the outcome; only randomised draws are worth retrying.
    attempts = 500 if randomize else 1

    for _ in range(attempts):
        if randomize:
            start = (
                sanitize(random.uniform(bin_min, bin_max)),
                sanitize(random.uniform(-10, 10)),
                sanitize(random.uniform(-10, 10))
            )
            a = sanitize(random.uniform(0.5, 2.0))
            b = sanitize(random.uniform(0.5, 3.0))
            c = sanitize(random.uniform(0.1, 1.5))
            dt = sanitize(random.uniform(0.01, 0.1))
        else:
            start = (bin_min, 0, 0)
            a = 1.2
            b = 1.8
            c = 0.7
            dt = 0.04

        r_params = sanitize_params({
            'start': start,
            'num_samples': num_samples,
            'a': a,
            'b': b,
            'c': c,
            'dt': dt,
        })

        process_path = rossler_attractor(**r_params)
        z_component = process_path[:, 2]

        # Reject trajectories that blew up or whose z-component never left zero.
        if not (np.all(np.isfinite(process_path)) and np.any(z_component != 0)):
            continue

        z_low = np.min(z_component)
        z_span = np.max(z_component) - z_low
        if z_span == 0:
            # Constant z (e.g. a single sample): avoid 0/0 and map every
            # point to the lowest bin instead of casting NaN to int.
            normalized_indices = np.zeros_like(z_component, dtype=float)
        else:
            normalized_indices = (z_component - z_low) / z_span
        indices = np.clip((normalized_indices * (bin_max - bin_min) + bin_min).astype(int), 0, bin_max)

        if params.get('visualize', False):
            rendered = animate_rossler(process_path, num_samples, indices, r_params)
            return indices, rendered

        return indices, None

    raise ValueError("Failed to generate valid Rossler attractor parameters after multiple retries")

def animate_rossler(process_path, num_samples, indices, r_params):
    """Animate the Rössler attractor and save the last frame as an SVG.

    Both the MP4 animation and the final-frame SVG are written to the
    ``rossler`` subfolder of the global ``graphics_path``.

    Args:
        process_path: Array whose transpose unpacks into x, y, z sequences.
        num_samples: Number of frames to render.
        indices: Per-frame bin indices; multiplied by a fixed time constant to
            label the elapsed time in the on-plot text box.
        r_params: Mapping with keys ``'a'``, ``'b'``, ``'c'``, ``'dt'`` shown in
            the text box.

    Returns:
        The ``FuncAnimation`` object.
    """
    global graphics_path
    time_constant = 0.09287981859410431  # seconds represented by one bin index step
    x, y, z = process_path.T

    # Create the folder that is actually written to (not just its parent),
    # otherwise the first save fails on a fresh output tree.
    output_dir = f'{graphics_path}/rossler'
    os.makedirs(output_dir, exist_ok=True)

    # Sanitize limits to avoid NaN or Inf
    x_min, x_max = sanitize(np.min(x), -1), sanitize(np.max(x), 1)
    y_min, y_max = sanitize(np.min(y), -1), sanitize(np.max(y), 1)
    z_min, z_max = sanitize(np.min(z), -1), sanitize(np.max(z), 1)

    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')

    color_z = random_color()
    color_x = random_color()
    color_y = random_color()

    line_z, = ax.plot([], [], [], lw=2, label='Z-component', color=color_z)
    line_x, = ax.plot([], [], [], lw=1, label='X-component', color=color_x, alpha=0.5)
    line_y, = ax.plot([], [], [], lw=1, label='Y-component', color=color_y, alpha=0.5)

    ax.set_xlim([x_min, x_max])
    ax.set_ylim([y_min, y_max])
    ax.set_zlim([z_min, z_max])

    ax.set_title('Rössler Attractor Animation')

    def _legend_text(elapsed_seconds):
        # Text box listing the attractor parameters plus the current time.
        return (f"Parameters:\n"
                f"a = {r_params['a']:.2f}\n"
                f"b = {r_params['b']:.2f}\n"
                f"c = {r_params['c']:.2f}\n"
                f"dt = {r_params['dt']:.2f}\n"
                f"Time (s): {elapsed_seconds:.2f}")

    param_legend = ax.text2D(0.05, 0.95, _legend_text(0), transform=ax.transAxes, fontsize=10, verticalalignment='top', bbox=dict(facecolor='white', alpha=0.5))

    def init():
        line_z.set_data([], [])
        line_z.set_3d_properties([])
        line_x.set_data([], [])
        line_x.set_3d_properties([])
        line_y.set_data([], [])
        line_y.set_3d_properties([])
        return line_z, line_x, line_y, param_legend

    def animate(i):
        line_z.set_data(x[:i], y[:i])
        line_z.set_3d_properties(z[:i])
        line_x.set_data(x[:i], [0] * i)
        line_x.set_3d_properties(z[:i])
        line_y.set_data([0] * i, y[:i])
        line_y.set_3d_properties(z[:i])

        # Update the legend with the current time value
        param_legend.set_text(_legend_text(indices[i] * time_constant))

        if i == num_samples - 1:  # Save last frame
            fig.savefig(os.path.join(output_dir, 'rossler_last_frame.svg'), format='svg', transparent=True, bbox_inches='tight')
        return line_z, line_x, line_y, param_legend

    ani = animation.FuncAnimation(fig, animate, init_func=init, frames=num_samples, interval=50, blit=True)
    ani.save(os.path.join(output_dir, 'rossler_animation.mp4'), writer='ffmpeg', fps=30)
    # Close only after the frames have been rendered and written.
    plt.close(fig)

    return ani

### MERW + animators
def maximal_entropy_process(unique_bins_list, num_samples, bin_min, bin_max, **params):
    """Sample bin positions with a maximal-entropy random walk over a random graph.

    A graph with one node per unique bin is built with edge probability
    ``params['connectivity']`` (default 0.7) and walked; the visited node
    sequence is returned as the indices.

    Args:
        unique_bins_list: Only its length is used (number of graph nodes).
        num_samples, bin_min, bin_max: Accepted for interface parity; not read here.
        **params: ``connectivity`` (edge probability); ``visualize=True``
            additionally renders an animation.

    Returns:
        ``(indices, animation)`` where ``animation`` is ``None`` unless
        ``visualize`` was requested.
    """
    # NOTE(review): the walk length is the number of unique bins, not
    # num_samples -- confirm this is intended.
    node_count = len(unique_bins_list)
    print(f"Total unique bins: {node_count}")  # Debugging output

    connectivity = params.get('connectivity', 0.7)
    lattice = initialize_irregular_lattice(node_count, connectivity)
    walk = maximal_entropy_random_walk(lattice, node_count)

    if lattice.number_of_edges() == 0:
        print("Warning: No edges in the graph. Check connectivity parameter and node initialization.")
    walk_is_empty = not walk
    if walk_is_empty:
        print("Warning: Random walk path is empty.")
    if walk_is_empty:
        print("Warning: No indices generated from the path.")

    if not params.get('visualize', False):
        return walk, None

    rendered = animate_random_walk(params, lattice, walk, node_count, connectivity)
    return walk, rendered

def maximal_entropy_random_walk(G, num_samples):
    """Walk graph ``G`` using transition probabilities built from its dominant eigenvector.

    For each edge ``i -> j`` the transition weight is
    ``(A[i, j] / lambda_max) * (psi[j] / psi[i])``, where ``lambda_max`` and
    ``psi`` are the largest eigenvalue of the adjacency matrix ``A`` and its
    eigenvector. Rows are then renormalised to sum to one.

    Args:
        G: NetworkX graph to walk.
        num_samples: Number of nodes in the returned path.

    Returns:
        List of visited node positions, starting from a uniformly random node.

    Raises:
        ValueError: If the transition row of a visited node contains NaN
            (for example, a node without edges).
    """
    node_count = G.number_of_nodes()

    # Force a plain float ndarray: depending on the NetworkX/SciPy versions,
    # .todense() may return np.matrix, whose .sum() rejects `keepdims` and
    # whose column slices stay two-dimensional.
    A = np.asarray(nx.adjacency_matrix(G).todense(), dtype=float)

    evals, evecs = np.linalg.eigh(A)
    dominant = int(np.argmax(evals))
    lambda_max = evals[dominant]
    psi = evecs[:, dominant]
    psi_normalized = psi / np.linalg.norm(psi, ord=2)

    P = np.zeros_like(A, dtype=float)
    # Only existing edges receive a transition weight.
    for i, j in zip(*np.nonzero(A > 0)):
        P[i, j] = (A[i, j] / lambda_max) * (psi_normalized[j] / psi_normalized[i])
    P /= P.sum(axis=1, keepdims=True)  # Normalize the transition probability matrix

    path = [np.random.randint(node_count)]
    for _ in range(1, num_samples):
        current = path[-1]
        if np.any(np.isnan(P[current])):
            raise ValueError("Transition probabilities contain NaN.")
        next_node = np.random.choice(node_count, p=P[current])
        path.append(next_node)
    return path

def initialize_irregular_lattice(num_nodes, connectivity):
    """Build an undirected random graph on ``num_nodes`` nodes.

    Every unordered node pair ``(i, j)`` with ``i < j`` is linked when a fresh
    uniform draw falls below ``connectivity``.

    Returns:
        The resulting ``nx.Graph``.
    """
    lattice = nx.Graph()
    lattice.add_nodes_from(range(num_nodes))
    # One draw per pair, visited in the same (i, j) order as a nested loop.
    linked_pairs = [
        (i, j)
        for i in range(num_nodes)
        for j in range(i + 1, num_nodes)
        if np.random.rand() < connectivity
    ]
    lattice.add_edges_from(linked_pairs)
    return lattice

def animate_random_walk(params, G, path, L, q):
    """Animate a random walk over ``G`` and save the video plus the final frame.

    Output goes to the ``maximal_entropy`` subfolder of the global
    ``graphics_path``: ``random_walk_animation.mp4`` and
    ``final_random_walk_frame.svg``.

    Args:
        params: Mapping of options; ``fps`` (default 30) and
            ``total_duration`` in seconds (default 50) are read.
        G: NetworkX graph being walked.
        path: Sequence of visited nodes.
        L: Number of nodes used for the colour array.
        q: Accepted for interface parity; not read here.

    Returns:
        The ``FuncAnimation`` object.
    """
    global graphics_path
    fps = params.get('fps', 30)  # frames per second
    total_duration = params.get('total_duration', 50)  # duration of the animation in seconds
    frames = fps * total_duration  # total frames shown in the animation
    highest_step = min(frames, len(path) - 1)  # Highest step to determine color scale
    # Guard the colour normalisation against a one-node path (highest_step == 0).
    color_span = max(highest_step, 1)

    # Make sure the folder written to below exists.
    output_dir = f'{graphics_path}/maximal_entropy'
    os.makedirs(output_dir, exist_ok=True)

    pos = nx.kamada_kawai_layout(G)  # use gpu here if need
    fig, ax = plt.subplots(figsize=(10, 12), dpi=300)  # Increased DPI for higher resolution
    ax.set_xlim([min(p[0] for p in pos.values()) - 0.2, max(p[0] for p in pos.values()) + 0.2])
    ax.set_ylim([min(p[1] for p in pos.values()) - 0.2, max(p[1] for p in pos.values()) + 0.2])
    ax.set_title('Maximal Entropy Random Walk Visualization', fontsize=16, color='black')
    # Edge and node setup
    edge_list = list(G.edges())
    # Constant-time edge lookup instead of scanning edge_list on every frame.
    edge_index_of = {edge: idx for idx, edge in enumerate(edge_list)}
    edge_colors = [to_rgba('grey', alpha=0.10)] * len(edge_list)
    edges = nx.draw_networkx_edges(G, pos, ax=ax, edgelist=edge_list, edge_color=edge_colors, width=1)
    color_map = plt.get_cmap('plasma')
    nodes = nx.draw_networkx_nodes(G, pos, ax=ax, node_color='grey', alpha=0.4, cmap=color_map, node_size=50)

    visited = {}  # node -> step at which it was first reached
    node_color_values = np.zeros(L)
    current_text = ax.text(0.05, 0.95, '', transform=ax.transAxes, color='black')

    def update(num):
        current_node = path[num]
        if current_node not in visited:
            visited[current_node] = num
        for n in range(L):
            node_color_values[n] = visited.get(n, 0) / color_span
        nodes.set_array(node_color_values)
        nodes.set_alpha(1.0)
        if num < len(path) - 1:
            next_node = path[num + 1]
            # The graph is undirected, so the edge may be stored in either orientation.
            edge_index = edge_index_of.get((current_node, next_node))
            if edge_index is None:
                edge_index = edge_index_of.get((next_node, current_node))
            if edge_index is not None:
                edge_colors[edge_index] = color_map(visited.get(current_node, 0) / color_span)
                edges.set_edgecolor(edge_colors)
        current_bin_id = visited.get(current_node, 0)
        current_text.set_text(f'Time Course (s): {current_bin_id * 0.09287981859410431:.2f}')

        return nodes, edges, current_text

    # Color bar setup
    sm = plt.cm.ScalarMappable(cmap=color_map, norm=plt.Normalize(vmin=0, vmax=highest_step))
    sm.set_array([])
    cbar = plt.colorbar(sm, ax=ax, orientation='vertical', fraction=0.02, pad=0.04)
    cbar.set_label('Step in Random Walk')

    ani = animation.FuncAnimation(fig, update, frames=highest_step, repeat=False, blit=True)
    ani.save(os.path.join(output_dir, 'random_walk_animation.mp4'), writer='ffmpeg', fps=fps, extra_args=['-vcodec', 'libx264', '-pix_fmt', 'yuv420p', '-crf', '18', '-preset', 'veryfast'])

    # Save the last frame as SVG
    nodes.set_array(node_color_values)  # Set final colors for nodes
    nodes.set_alpha(1.0)  # Ensure all nodes are fully opaque in the final frame
    fig.savefig(os.path.join(output_dir, 'final_random_walk_frame.svg'), format='svg', transparent=True, bbox_inches='tight')
    # Close only after the final frame has been written.
    plt.close(fig)

    return ani


### Main sampling and plotting:
def plot_indices(indices, total_bins, method_name, file_id_colors=None, file_ids_at_indices=None):
    """Plot the indices as a time series to visualize the selection sequence and color them by file_id.

    The figure is saved as ``sample_plot_<method_name>.svg`` in the
    ``<method_name>`` subfolder of the global ``graphics_path`` and then shown.

    Args:
        indices: Sequence of sampled bin indices, in sampling order.
        total_bins: Number of bins (int), or an array whose length is used.
        method_name: Sampling method name; used in the title and output path.
        file_id_colors: Optional mapping of file id to colour. ``None`` is
            treated as an empty mapping.
        file_ids_at_indices: Optional list with the file id chosen at each
            step. ``None`` is treated as an empty list.
    """
    global graphics_path

    # Both arguments default to None; normalise them so the len()/.items()
    # calls below do not crash when the defaults are used.
    file_id_colors = file_id_colors if file_id_colors else {}
    if file_ids_at_indices is None:
        file_ids_at_indices = []

    print("Indices length:", len(indices))
    print("File IDs at Indices length:", len(file_ids_at_indices))
    print("Sample indices values:", indices[:10])
    print("Sample file IDs at indices values:", file_ids_at_indices[:10])

    # Check and handle total_bins if it's an np.array
    if isinstance(total_bins, np.ndarray):
        print("total_bins is an np.array:", total_bins)
        total_bins_count = len(total_bins)  # Assuming you want the count of bins if it's an array
    else:
        total_bins_count = total_bins

    print(f'Total bins count: {total_bins_count}')

    label_step = max(1, total_bins_count // 30)
    # max(1, ...) keeps the tick-size heuristic from dividing by zero on empty input.
    label_size = min(10, max(5, 500 // max(1, total_bins_count)))

    plt.figure(figsize=(20, 6))
    if file_id_colors and file_ids_at_indices:
        colors = [file_id_colors.get(file_id, 'grey') for file_id in file_ids_at_indices]
    else:
        colors = ['blue'] * len(indices)
    colors = colors[:len(indices)]
    unique_colors = set(colors)
    color_labels = {color: file_id for file_id, color in file_id_colors.items() if color in unique_colors}
    for color, label in color_labels.items():
        # `colors` was truncated to len(indices) above, so every position is a valid index.
        positions = [i for i, c in enumerate(colors) if c == color]
        values = [indices[i] for i in positions]
        plt.scatter(positions, values, color=color, label=label)
    plt.plot(indices, 'grey', label='sampling sequence', alpha=0.4)  # Using grey to keep focus on points
    tick_positions = np.arange(0, total_bins_count, label_step)
    tick_labels = [str(i) for i in range(0, total_bins_count, label_step)]
    plt.yticks(tick_positions, tick_labels, fontsize=label_size)
    plt.xlabel('Sample Step')
    plt.ylabel('Bin ID')
    plt.title(f'Bin Sampling Sequence for {method_name}')
    plt.legend(title="File IDs")
    plt.grid(True)
    output_dir = f'{graphics_path}/{method_name}'
    os.makedirs(output_dir, exist_ok=True)
    plt.savefig(os.path.join(output_dir, f'sample_plot_{method_name}.svg'), format='svg', transparent=True, bbox_inches='tight')
    plt.show()

# Sampling Procedure / Database Structuring

In [None]:
def sample_data(epoch_data, file_ids, indices_color_mapping=None, cross_file_sampling=False, method='ornstein-uhlenbeck', sample_size=1, clip_bins=False, use_fallback=True, **params):
    """Select bins with a stochastic process and copy their data into a new structure.

    Args:
        epoch_data: Nested mapping indexed as
            ``epoch_data[file_id][processing_type][component_type]`` (bin name -> data).
        file_ids: File identifiers to draw bins from.
        indices_color_mapping: Optional mapping of file id to plot colour.
        cross_file_sampling: When true, each selected bin is taken from one
            randomly chosen file and pooled under ``sampled_data['all']``;
            otherwise the selected bins are copied per file.
        method: Sampling method name understood by ``stochastic_process``.
        sample_size: Requested number of samples (capped at the number of unique bins).
        clip_bins: When true, indices are mapped to bins via ``wrap_indices``;
            otherwise out-of-range indices are dropped.
        use_fallback: Forwarded to ``copy_selected_bins``.
        **params: Forwarded to the sampling method.

    Returns:
        ``(sampled_data, sampleset_timeseries, file_ids_at_indices)``. The
        second item is whatever ``plot_indices`` returns; the third is empty
        unless ``cross_file_sampling`` is true.

    Raises:
        ValueError: If a requested file id is missing from ``epoch_data``.
    """
    processing_types = ['original', 'lossy']
    component_types = ['real', 'imaginary']

    def _empty_tables():
        return {'original': {'real': {}, 'imaginary': {}},
                'lossy': {'real': {}, 'imaginary': {}}}

    if cross_file_sampling:
        sampled_data = {'all': _empty_tables()}
    else:
        sampled_data = {file_id: _empty_tables() for file_id in file_ids}

    print_initial_bin_counts(sampled_data, epoch_data, cross_file_sampling, file_ids)

    # Sorted bin keys of every (file, processing type, component type) table.
    all_bins = {}
    for file_id in file_ids:
        print(f"Processing file_id: {file_id}")
        if file_id not in epoch_data:
            raise ValueError(f"No data found for file_id '{file_id}'.")
        for processing_type in processing_types:
            for component_type in component_types:
                key = (file_id, processing_type, component_type)
                all_bins[key] = get_sorted_bin_keys(epoch_data[file_id][processing_type][component_type])
    # Flatten the list of lists of keys and then sort uniquely
    flat_list_of_keys = [bin_key for sublist in all_bins.values() for bin_key in sublist]
    unique_bins_list = get_sorted_bin_keys({key: None for key in set(flat_list_of_keys)})

    print(f"Unique bins list: {unique_bins_list}")
    num_samples = min(sample_size, len(unique_bins_list))
    bin_min, bin_max = 0, len(unique_bins_list) - 1
    # The second return value (optional visualisation) is not used here.
    indices, _visualisation = stochastic_process(method, unique_bins_list, num_samples, bin_min, bin_max, **params)
    if clip_bins:
        selected_bins = wrap_indices(indices, unique_bins_list)
    else:
        selected_bins = [unique_bins_list[i] for i in indices if 0 <= i < len(unique_bins_list)]
    print(f"Selected bins: {selected_bins}")

    file_ids_at_indices = []
    if cross_file_sampling:
        bin_to_file = copy_selected_bins_cross_file(sampled_data, all_bins, epoch_data, selected_bins)
        file_ids_at_indices = [bin_to_file.get(bin_key, 'Unknown') for bin_key in selected_bins]
    else:
        for file_id in file_ids:
            for processing_type in processing_types:
                for component_type in component_types:
                    target = sampled_data[file_id][processing_type][component_type]
                    source = epoch_data[file_id][processing_type][component_type]
                    copy_selected_bins(target, [source], selected_bins, use_fallback)

    # Pass an empty mapping rather than None so the plot helper can iterate it
    # when no colour mapping was supplied.
    color_mapping = indices_color_mapping if indices_color_mapping else {}
    sampleset_timeseries = plot_indices(indices, len(indices), method, color_mapping, file_ids_at_indices)
    return sampled_data, sampleset_timeseries, file_ids_at_indices if cross_file_sampling else []


In [None]:
# strict typing
def copy_selected_bins_cross_file(
    target: Dict[str, Any],
    all_bins: Dict[Tuple[str, str, str], List[str]],
    epoch_data: Dict[str, Dict[str, Dict[str, Dict[str, Any]]]],
    selected_bins: List[str]
) -> Dict[str, str]:
    """Copy each selected bin from one randomly chosen file into ``target['all']``.

    For every bin, the candidate files are those with at least one table in
    ``all_bins`` containing the bin (a file appears once per such table, so
    files holding the bin in more tables are proportionally more likely to be
    drawn). The bin's data from the chosen file is appended, per processing
    and component type, as ``{'bin_key': ..., 'data': ...}`` records.

    Args:
        target: Destination mapping; ``target['all'][ptype][ctype]`` is created
            as a list on first use, replacing a dict placeholder if present.
        all_bins: Bin keys per ``(file_id, processing_type, component_type)``.
        epoch_data: Source data indexed as ``[file_id][ptype][ctype][bin_key]``.
        selected_bins: Bin keys to copy, in order.

    Returns:
        Mapping of bin key to the file id it was copied from. Bins with no
        candidate file are reported on stdout and omitted.
    """
    bin_to_file = {}
    for bin_key in selected_bins:
        candidates = []
        for (file_id, _p_type, _c_type), bins in all_bins.items():
            if bin_key in bins:
                candidates.append(file_id)

        if not candidates:
            print(f"No available files for bin {bin_key}")
            continue

        chosen_file_id = random.choice(candidates)
        bin_to_file[bin_key] = chosen_file_id

        for processing_type in ('original', 'lossy'):
            for component_type in ('real', 'imaginary'):
                source = epoch_data[chosen_file_id][processing_type][component_type]
                if bin_key not in source:
                    continue
                pooled = target.setdefault('all', {}).setdefault(processing_type, {})
                # Start a fresh record list when the slot is missing or still
                # holds the dict placeholder.
                if component_type not in pooled or isinstance(pooled[component_type], dict):
                    pooled[component_type] = []
                pooled[component_type].append({
                    'bin_key': bin_key,
                    'data': source[bin_key]
                })

    return bin_to_file

def dump_lists_to_directories(lists_dict, base_dir="/content/Model"):
    """Write each list to a JSON file in the graphics folder of its sampling method.

    The method is recognised from the list name's suffix (``_me``,
    ``_rossler``, ``_ou``, ``_srs``), optionally followed by a sample number:

    * ``<name>_ou``  -> ``<base_dir>/graphics/ornstein-uhlenbeck/``
    * ``<name>_ou3`` -> ``<base_dir>/graphics_sample3/ornstein-uhlenbeck/``

    Names without a recognised suffix are reported on stdout and skipped.

    Args:
        lists_dict: Mapping of list name to a JSON-serialisable object.
        base_dir: Root folder for the output tree.
    """
    base_path = os.path.join(base_dir, "graphics")
    dir_map = {
        "_me": "maximal_entropy",
        "_rossler": "rossler",
        "_ou": "ornstein-uhlenbeck",
        "_srs": "simple_random_shuffle"
    }
    # Suffix followed by optional digits at the end of the name. Matching the
    # digits here makes the numbered-folder branch reachable: text taken after
    # an `endswith` suffix is always empty.
    suffix_re = re.compile('(' + '|'.join(re.escape(key) for key in dir_map) + r')(\d*)$')

    for list_name, list_obj in lists_dict.items():
        match = suffix_re.search(list_name)
        if match is None:
            print(f"No valid pattern found in list name: {list_name}")
            continue

        pattern, digits = match.groups()
        if digits:
            target_dir = os.path.join(base_dir, f"graphics_sample{int(digits)}", dir_map[pattern])
        else:
            target_dir = os.path.join(base_path, dir_map[pattern])

        os.makedirs(target_dir, exist_ok=True)

        file_path = os.path.join(target_dir, f"{list_name}_order_of_bins.json")
        with open(file_path, 'w') as f:
            json.dump(list_obj, f)

        print(f"List '{list_name}' dumped to {file_path}")


In [None]:
def save_dict_as_tar_gz(dictionary, filename):
    """Pickle ``dictionary``, gzip it, and store it in ``<filename>.tar.gz``.

    The archive contains a single member named ``dictionary.pkl.gz``, which is
    the layout ``load_dict_from_tar_gz`` expects.

    Args:
        dictionary: Picklable object to store.
        filename: Output path without the ``.tar.gz`` extension.
    """
    import tempfile

    member_name = 'dictionary.pkl.gz'
    tar_gz_filename = f'{filename}.tar.gz'

    # Stage the intermediate file in a private scratch directory so that no
    # fixed-name file in the working directory is overwritten, and so that
    # cleanup happens even if pickling or archiving raises.
    with tempfile.TemporaryDirectory() as scratch_dir:
        gzip_path = os.path.join(scratch_dir, member_name)
        with gzip.open(gzip_path, 'wb') as f_out:
            pickle.dump(dictionary, f_out)
        with tarfile.open(tar_gz_filename, 'w:gz') as tar:
            tar.add(gzip_path, arcname=member_name)

    print(f'Dictionary saved and compressed as {tar_gz_filename}')

def load_dict_from_tar_gz(filename):
    """Load the object stored by ``save_dict_as_tar_gz`` from ``filename``.

    Only the member ``dictionary.pkl.gz`` is read, directly from the archive;
    nothing is extracted to disk.

    Args:
        filename: Path to the ``.tar.gz`` archive.

    Returns:
        The unpickled object.

    Raises:
        FileNotFoundError: If the archive has no ``dictionary.pkl.gz`` member.
    """
    member_name = 'dictionary.pkl.gz'

    with tarfile.open(filename, 'r:gz') as tar:
        # Read just the expected member instead of extractall(), which would
        # write every archive entry into the current working directory.
        try:
            member = tar.extractfile(member_name)
        except KeyError:
            member = None
        if member is None:
            raise FileNotFoundError(f"'{member_name}' not found in {filename}")
        compressed = member.read()

    # SECURITY: unpickling executes code embedded in the payload -- only load
    # archives from a trusted source.
    dictionary = pickle.loads(gzip.decompress(compressed))
    print(f'Dictionary loaded from {filename}')
    return dictionary

In [None]:
def generate_model_path(base_path, epoch, val_loss, stochastic_process_name):
    """Compose the ``.h5`` checkpoint path for a given process, epoch and validation loss.

    The validation loss is rendered with four decimal places.
    """
    # NOTE: a function with this name is defined again in a later cell of this notebook.
    formatted_val_loss = f"{val_loss:.4f}"
    pieces = (
        f"{base_path}_{stochastic_process_name}",
        f"_epoch_{epoch}",
        f"_val_loss_{formatted_val_loss}.h5",
    )
    return "".join(pieces)

def save_model_and_config(model, filepath):
    """Save ``model`` in HDF5 format and write its JSON architecture next to it.

    The JSON file is written to ``<filepath>_config.json``.

    Args:
        model: Object exposing ``save(path, save_format=...)`` and ``to_json()``.
        filepath: Destination of the HDF5 file.
    """
    # NOTE: a function with this name is defined again in a later cell of this notebook.
    config_path = f"{filepath}_config.json"

    ensure_directory_exists(filepath)
    model.save(filepath, save_format='h5')

    architecture = model.to_json()
    with open(config_path, 'w') as config_file:
        config_file.write(architecture)

    print(f"Model and configuration saved to {filepath} and {config_path}")

def extract_data(sample):
    """Return the value stored under the ``'data'`` key of a sample record."""
    payload = sample['data']
    return payload

def prepare_data(real_samples, imag_samples):
    """Stack real and imaginary sample records into one two-channel array.

    Each group's ``'data'`` payloads are stacked vertically and reshaped to
    ``(-1, 100, 1)``; the two results are joined on the last axis.

    Returns:
        Array of shape ``(-1, 100, 2)`` with channel 0 = real, channel 1 = imaginary.
    """
    # Stack both groups first, then reshape both, matching the original order of operations.
    real_rows, imag_rows = (
        np.vstack([extract_data(sample) for sample in group])
        for group in (real_samples, imag_samples)
    )
    channels = [rows.reshape(-1, 100, 1) for rows in (real_rows, imag_rows)]
    return np.concatenate(channels, axis=-1)  # Now shape is (-1, 100, 2)

def prepare_segmented_data(data, samples_per_segment):
    """Group consecutive samples into fixed-size segments.

    Args:
        data: Array whose first axis indexes samples; reshaped assuming each
            sample is ``(100, 2)``.
        samples_per_segment: Number of samples per segment.

    Returns:
        Array of shape ``(-1, samples_per_segment, 100, 2)``.

    Raises:
        ValueError: If the sample count is not a multiple of ``samples_per_segment``.
    """
    sample_count = data.shape[0]
    whole_segments, leftover = divmod(sample_count, samples_per_segment)
    if leftover != 0:
        raise ValueError("The total number of samples is not a multiple of samples per segment.")
    usable = whole_segments * samples_per_segment
    return data[:usable].reshape(-1, samples_per_segment, 100, 2)

def prepare_tensorflow_dataset(prepped_data, batch_size):
    """Build a batched autoencoder dataset (input == target) from segmented data.

    The leading axes of ``prepped_data`` are merged so that each element has
    the shape given by axes 2 and 3 of the input.
    """
    # NOTE: a function with this name is defined again in a later cell of this
    # notebook, without the flattening step.
    rows, channels = prepped_data.shape[2], prepped_data.shape[3]
    flattened_data = prepped_data.reshape(-1, rows, channels)  # Reshape to (-1, 100, 2)
    pairs = (flattened_data, flattened_data)
    return tf.data.Dataset.from_tensor_slices(pairs).batch(batch_size)

# Model Terms / Definitions

In [None]:
# Short aliases for the TensorFlow Probability sub-namespaces referenced by the model code in this cell.
tfd = tfp.distributions
tfb = tfp.bijectors

# Clear previous session
# (presumably `K` is the Keras backend module imported in an earlier cell -- TODO confirm)
K.clear_session()

# Manage GPU memory growth: enable it on every GPU that TensorFlow reports.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        # Best effort: the error is printed and execution continues.
        print(e)

# Define the Sampling layer
class Sampling(layers.Layer):
    """Reparameterisation layer: draws ``z = mean + exp(0.5 * logvar) * epsilon``."""

    def call(self, inputs):
        """Return a sample for the ``(mean, logvar)`` pair in ``inputs``.

        ``epsilon`` is standard-normal noise with the same shape as ``mean``.
        """
        z_mean, z_logvar = inputs
        noise = tf.keras.backend.random_normal(shape=tf.shape(z_mean))
        std_dev = tf.exp(0.5 * z_logvar)
        return z_mean + std_dev * noise

# Define the Encoder
def create_encoder(input_shape, latent_dim):
    """Build the encoder: Flatten -> Dense(128) -> Dense(64) -> (mean, logvar, sample).

    Args:
        input_shape: Shape of one input example (without the batch axis).
        latent_dim: Size of the latent vector.

    Returns:
        A Keras model named ``'encoder'`` with outputs ``[z_mean, z_logvar, z]``.
    """
    inputs = tf.keras.Input(shape=input_shape)

    hidden = layers.Flatten()(inputs)
    for width in (128, 64):
        hidden = layers.Dense(width, activation='relu')(hidden)

    z_mean = layers.Dense(latent_dim)(hidden)
    z_logvar = layers.Dense(latent_dim)(hidden)
    z = Sampling()([z_mean, z_logvar])

    return Model(inputs, [z_mean, z_logvar, z], name='encoder')

# Define the Decoder
def create_decoder(output_shape, latent_dim):
    """Build the decoder: Dense(64) -> Dense(128) -> Dense(prod(output_shape), sigmoid) -> Reshape.

    Args:
        output_shape: Shape of one reconstructed example (without the batch axis).
        latent_dim: Size of the latent vector fed to the decoder.

    Returns:
        A Keras model named ``'decoder'``.
    """
    # Dense expects a plain Python int for its unit count; compute the element
    # count up front instead of handing it a TensorFlow tensor.
    flat_units = int(np.prod(tuple(output_shape)))

    latent_inputs = tf.keras.Input(shape=(latent_dim,))
    x = layers.Dense(64, activation='relu')(latent_inputs)
    x = layers.Dense(128, activation='relu')(x)
    x = layers.Dense(flat_units, activation='sigmoid')(x)
    outputs = layers.Reshape(output_shape)(x)
    decoder = Model(latent_inputs, outputs, name='decoder')
    return decoder

# Define the IAF Bijector Layer
class IAFBijectorLayer(layers.Layer):
    """Keras layer that pushes latent samples through an inverse autoregressive flow.

    The flow is ``tfb.Invert(tfb.MaskedAutoregressiveFlow(...))`` driven by a
    ``tfb.AutoregressiveNetwork`` with two parameters (shift and log-scale)
    per latent dimension and hidden layers of 512 units.
    """

    def __init__(self, latent_dim, **kwargs):
        """Create the flow for latent vectors of size ``latent_dim``."""
        super().__init__(**kwargs)
        self.latent_dim = latent_dim
        # Kept as an attribute so the network's weights are tracked by this layer.
        self.made = tfb.AutoregressiveNetwork(
            params=2, event_shape=[latent_dim], hidden_units=[512, 512])
        # `call` relies on `self.iaf`; it must be assigned here.
        self.iaf = tfb.Invert(tfb.MaskedAutoregressiveFlow(
            shift_and_log_scale_fn=self.made))

    def call(self, inputs):
        """Apply the flow's forward transform to ``inputs``."""
        return self.iaf.forward(inputs)

    def get_config(self):
        """Return the layer configuration, including ``latent_dim``, for serialisation."""
        config = super().get_config()
        config.update({'latent_dim': self.latent_dim})
        return config

# Define the VAE
def create_vae(input_shape, latent_dim):
    """Assemble and compile the VAE: encoder -> IAF layer -> decoder.

    A KL-divergence term computed from the encoder's mean and log-variance is
    added to the model's losses; the reconstruction loss is MSE with Adam.

    Args:
        input_shape: Shape of one input example (without the batch axis).
        latent_dim: Size of the latent vector.

    Returns:
        ``(vae, encoder)``.
    """
    encoder = create_encoder(input_shape, latent_dim)
    decoder = create_decoder(input_shape, latent_dim)

    inputs = tf.keras.Input(shape=input_shape)
    z_mean, z_logvar, z = encoder(inputs)

    # The sampled latent is transformed by the flow before decoding.
    flowed_latent = IAFBijectorLayer(latent_dim)(z)
    reconstruction = decoder(flowed_latent)

    vae = Model(inputs, reconstruction, name='vae')

    kl_terms = 1 + z_logvar - tf.square(z_mean) - tf.exp(z_logvar)
    kl_loss = -0.5 * tf.reduce_mean(kl_terms)
    vae.add_loss(kl_loss)

    optimizer = tf.keras.optimizers.Adam()
    vae.compile(optimizer=optimizer, loss='mse')

    return vae, encoder

# Define the LatentSpaceHeatmapCallback
class LatentSpaceHeatmapCallback(Callback):
    """Save a heatmap of every Dense layer's weights in the encoder after each epoch.

    NOTE: a class with this name is defined again later in this cell (taking
    the VAE and plotting per batch); that later definition replaces this one.
    """

    def __init__(self, encoder, save_dir="heatmaps"):
        """Remember the encoder and create ``save_dir`` if it is missing."""
        super().__init__()
        self.encoder = encoder
        self.save_dir = save_dir
        os.makedirs(self.save_dir, exist_ok=True)

    def on_epoch_end(self, epoch, logs=None):
        """Keras hook: plot the encoder weights for the epoch that just finished."""
        self.plot_weights(epoch)

    def plot_weights(self, epoch):
        """Write one PNG per Dense layer, named ``epoch_<n>_<layer>.png`` (1-based epoch)."""
        dense_layers = [layer for layer in self.encoder.layers if isinstance(layer, layers.Dense)]
        for layer in dense_layers:
            # get_weights() is unpacked as (weights, biases); only the weights are drawn.
            kernel, _biases = layer.get_weights()
            target_png = os.path.join(self.save_dir, f'epoch_{epoch + 1}_{layer.name}.png')

            plt.figure(figsize=(10, 6))
            sns.heatmap(kernel, annot=False, cmap='viridis')
            plt.title(f'Weights of layer {layer.name} at epoch {epoch + 1}')
            plt.xlabel('Output Units')
            plt.ylabel('Input Units')
            plt.savefig(target_png)
            plt.close()

# Define the SaveModelAfterNBatches callback
class SaveModelAfterNBatches(Callback):
    """Save the model every ``save_interval`` batches, counted across epochs.

    NOTE: a class with this name is defined again later in this cell (adding
    a timestamp to the file name); that later definition replaces this one.
    """

    def __init__(self, save_interval, save_path):
        """Store the interval and the path prefix; the batch counter starts at zero."""
        super().__init__()
        self.save_interval = save_interval
        self.save_path = save_path
        self.batch_count = 0

    def on_batch_end(self, batch, logs=None):
        """Keras hook: count the batch and save to ``<save_path>_batch_<count>.h5`` when due."""
        self.batch_count += 1
        if self.batch_count % self.save_interval != 0:
            return
        model_save_path = f"{self.save_path}_batch_{self.batch_count}.h5"
        self.model.save(model_save_path)
        print(f"Model saved after {self.batch_count} batches at {model_save_path}")
            
class LatentSpaceHeatmapCallback(Callback):
    """Save a heatmap of every Dense layer's weights in the VAE's encoder after each batch."""

    def __init__(self, vae, save_dir="heatmaps"):
        """Look up the layer named ``'encoder'`` in ``vae`` and create ``save_dir`` if missing."""
        super().__init__()
        self.encoder = vae.get_layer('encoder')
        self.save_dir = save_dir
        os.makedirs(self.save_dir, exist_ok=True)

    def on_batch_end(self, batch, logs=None):
        """Keras hook: plot the encoder weights for the batch that just finished."""
        self.plot_weights(batch)

    def plot_weights(self, batch):
        """Write one PNG per Dense layer, named ``batch_<n>_<layer>.png`` (1-based batch index)."""
        dense_layers = [layer for layer in self.encoder.layers if isinstance(layer, layers.Dense)]
        for layer in dense_layers:
            # get_weights() is unpacked as (weights, biases); only the weights are drawn.
            kernel, _biases = layer.get_weights()
            target_png = os.path.join(self.save_dir, f'batch_{batch + 1}_{layer.name}.png')

            plt.figure(figsize=(10, 6))
            sns.heatmap(kernel, annot=False, cmap='viridis')
            plt.title(f'Weights of layer {layer.name} at batch {batch + 1}')
            plt.xlabel('Output Units')
            plt.ylabel('Input Units')
            plt.savefig(target_png)
            plt.close()

class SaveModelAfterNBatches(Callback):
    """Save the model every ``save_interval`` batches to a timestamped ``.h5`` file."""

    def __init__(self, save_interval, save_path):
        """Store the interval and the path prefix; the batch counter starts at zero."""
        super().__init__()
        self.save_interval = save_interval
        self.save_path = save_path
        self.batch_count = 0

    def on_batch_end(self, batch, logs=None):
        """Keras hook: count the batch and save when the count hits a multiple of the interval.

        The file name is ``<save_path>_batch_<count>_<YYYYmmdd-HHMMSS>.h5``.
        """
        self.batch_count += 1
        if self.batch_count % self.save_interval != 0:
            return
        timestamp = time.strftime("%Y%m%d-%H%M%S")
        model_save_path = f"{self.save_path}_batch_{self.batch_count}_{timestamp}.h5"
        self.model.save(model_save_path)
        print(f"Model saved after {self.batch_count} batches at {model_save_path}")


# Define the data preparation functions
def generate_model_path(base_path, epoch, val_loss, stochastic_process_name):
    """Compose the ``.h5`` checkpoint path for a given process, epoch and validation loss.

    The validation loss is rendered with four decimal places.
    """
    formatted_val_loss = f"{val_loss:.4f}"
    pieces = (
        f"{base_path}_{stochastic_process_name}",
        f"_epoch_{epoch}",
        f"_val_loss_{formatted_val_loss}.h5",
    )
    return "".join(pieces)

def save_model_and_config(model, filepath):
    """Save ``model`` in HDF5 format and write its JSON architecture next to it.

    The JSON file is written to ``<filepath>_config.json``.

    Args:
        model: Object exposing ``save(path, save_format=...)`` and ``to_json()``.
        filepath: Destination of the HDF5 file.
    """
    config_path = f"{filepath}_config.json"

    ensure_directory_exists(filepath)
    model.save(filepath, save_format='h5')

    architecture = model.to_json()
    with open(config_path, 'w') as config_file:
        config_file.write(architecture)

    print(f"Model and configuration saved to {filepath} and {config_path}")

def extract_data(sample):
    """Return the value stored under the ``'data'`` key of a sample record."""
    payload = sample['data']
    return payload

def prepare_data(real_samples, imag_samples, window_length=100):
    """Stack real and imaginary samples into one two-channel array.

    The ``'data'`` payload of every sample is stacked row-wise, cut into
    windows of ``window_length`` values, and the real and imaginary windows
    are joined along a new last axis.

    Args:
        real_samples: Iterable of samples holding the real components; each
            sample is read through ``extract_data``.
        imag_samples: Iterable of samples holding the imaginary components.
        window_length: Number of values per window. Defaults to 100, the
            length that was previously hard-coded.

    Returns:
        Array of shape ``(-1, window_length, 2)``; channel 0 is the real part
        and channel 1 the imaginary part.

    Raises:
        ValueError: If the stacked data cannot be split into windows of
            ``window_length`` values, or if the real and imaginary parts end
            up with different shapes.
    """
    real = np.vstack([extract_data(r) for r in real_samples]).reshape(-1, window_length, 1)
    imag = np.vstack([extract_data(i) for i in imag_samples]).reshape(-1, window_length, 1)

    if real.shape != imag.shape:
        # np.concatenate would also raise ValueError here; this message just
        # states which inputs disagree.
        raise ValueError(
            f"real and imaginary data have mismatched shapes: {real.shape} vs {imag.shape}"
        )

    # Join along the last axis -> (-1, window_length, 2)
    return np.concatenate([real, imag], axis=-1)

def prepare_tensorflow_dataset(prepped_data, batch_size):
    """Wrap prepared data in a batched ``tf.data.Dataset`` of (input, target) pairs.

    Both members of every pair are the same slice of ``prepped_data``, i.e.
    the dataset is built for reconstruction-style training where the target
    equals the input.

    Args:
        prepped_data: Array-like whose first axis indexes the samples.
        batch_size: Number of samples per batch.

    Returns:
        The dataset produced by ``from_tensor_slices`` followed by
        ``batch(batch_size)``. No shuffling is applied.
    """
    paired = tf.data.Dataset.from_tensor_slices((prepped_data, prepped_data))
    return paired.batch(batch_size)

# Remux + IFFT Output

In [None]:
# Not yet implemented: this step requires reversing the tensor-stacking process (pending time to write it).

# Main Function Calls

In [None]:
### Pre-Processing

# Randomly sample the Kevin MacLeod catalog from online, and download the 10 chosen songs in an organized fashion.
pathinit = append_base_path(workingdirectory, urlstxtdir)[0] + '/urls.txt'
macleodsample = macleodchance(pathinit, 10)
audiopaths = append_base_path(workingdirectory, audiosampledir)[0]
macleodownthemall = macleodownload(macleonline, macleodsample, audiopaths)
audiorandomsample = append_dataset(audiopaths)

# Do the pre-processing (cochlear filtration, lossy and lossless coding, DFT, and dictionary structuring).
experimental_dataset, outputs = configure_epochs(
    audiorandomsample,
    duration=10,
    fft_size=4096,
    downloadfile=False,
    rectify=False,
    min_frequency=20,
    max_frequency=20000,
    num_frequency_bins=400,
    bw=None,
    with_lossy_backprop=True,
    UseDualMono=False,
    backend='pytorch',
    use_gpu=True,
)

# Automatically back up the processed sets of arrays for each file, then record
# which tracks were sampled next to them (same timestamp in the file name).
save_path = append_base_path(workingdirectory, arraydir)
exarrays, extract_timestamp = extract_and_save_data_arrays_with_timestamp(experimental_dataset)
with open(f'{save_path[0]}/sampledfrommacleod_{extract_timestamp}.json', 'w') as file:
    json.dump(macleodsample, file)
    
# reload the entirety of the dataset and unpack into the correct structure in memory if necessary by uncommenting ->
# test = load_and_rebuild_structure('/content/Model/Preprocessed_Arrays/metadata_1715470570.json',
#                                         '/content/Model/Preprocessed_Arrays/saved_arrays_1715470570.npz')


In [None]:
# Run stochastic methods + generate visualizations of each window over time - across an audio file or within audio files.
# Five runs of each stochastic process, using the sample_data function, returning three objects each.
#
# In sample_data, object one is the set of stacked / stochastically shuffled tensor bins,
# object two is the matplotlib plot, and object three is a list of the audio tracks
# corresponding to each bin if cross_file_sampling is true.

all_files = nab_fileids(experimental_dataset)
cmap = create_color_mapping(all_files, 'Dark2')
graphics_path = '/content/Model/graphics_sample'


def _run_sampler(method, **extra_kwargs):
    """Call sample_data with the arguments shared by every run in this cell.

    `cmap` is looked up when the helper is called, so reassigning it before a
    run changes the colour mapping used by that run.

    Args:
        method: Name of the stochastic method passed to sample_data.
        **extra_kwargs: Additional keyword arguments forwarded unchanged
            (e.g. randomize=True).

    Returns:
        Whatever sample_data returns (unpacked below into three objects).
    """
    return sample_data(
        experimental_dataset,
        all_files,
        indices_color_mapping=cmap,
        cross_file_sampling=True,
        method=method,
        clip_bins=False,
        use_fallback=False,
        visualize=True,
        sample_size=108,
        **extra_kwargs,
    )


## One day this will be a loop over the runs; for now each run is spelled out with its own variable names.

### 1/5 runs of stoch methods
# NOTE(review): unlike runs 2-5, this run does not pass randomize=True to the
# Rössler and Ornstein-Uhlenbeck samplers - confirm that this is intentional.
# Maximal Entropy Random Walk:
all_samples_me, timeseries_me, checkcrossfiles_me = _run_sampler("maximal_entropy")
# Rössler Attractor:
all_samples_rossler, timeseries_rossler, checkcrossfiles_rossler = _run_sampler("rossler")
# Ornstein-Uhlenbeck:
all_samples_ou, timeseries_ou, checkcrossfiles_ou = _run_sampler("ornstein-uhlenbeck")
# Simple Random Sample:
all_samples_srs, timeseries_srs, checkcrossfiles_srs = _run_sampler("simple_random_shuffle")
set1 = {
    "checkcrossfiles_me": checkcrossfiles_me,
    "checkcrossfiles_rossler": checkcrossfiles_rossler,
    "checkcrossfiles_ou": checkcrossfiles_ou,
    "checkcrossfiles_srs": checkcrossfiles_srs,
}
# Dump the dictionary of named lists to the filesystem, alongside the generated visualizations.
dump_lists_to_directories(set1)

#### 2/5 runs
graphics_path = '/content/Model/graphics_sample2'
cmap = create_color_mapping(all_files, 'tab20')
all_samples_me2, timeseries_me2, checkcrossfiles_me2 = _run_sampler("maximal_entropy")
all_samples_rossler2, timeseries_rossler2, checkcrossfiles_rossler2 = _run_sampler("rossler", randomize=True)
all_samples_ou2, timeseries_ou2, checkcrossfiles_ou2 = _run_sampler("ornstein-uhlenbeck", randomize=True)
all_samples_srs2, timeseries_srs2, checkcrossfiles_srs2 = _run_sampler("simple_random_shuffle")
# Each set must reference the lists produced by its own run (previously runs
# 2-5 re-dumped the lists of run 1 under new key names).
set2 = {
    "checkcrossfiles_me2": checkcrossfiles_me2,
    "checkcrossfiles_rossler2": checkcrossfiles_rossler2,
    "checkcrossfiles_ou2": checkcrossfiles_ou2,
    "checkcrossfiles_srs2": checkcrossfiles_srs2,
}
dump_lists_to_directories(set2)

#### 3/5 runs
graphics_path = '/content/Model/graphics_sample3'
cmap = create_color_mapping(all_files, 'tab20b')
all_samples_me3, timeseries_me3, checkcrossfiles_me3 = _run_sampler("maximal_entropy")
all_samples_rossler3, timeseries_rossler3, checkcrossfiles_rossler3 = _run_sampler("rossler", randomize=True)
all_samples_ou3, timeseries_ou3, checkcrossfiles_ou3 = _run_sampler("ornstein-uhlenbeck", randomize=True)
all_samples_srs3, timeseries_srs3, checkcrossfiles_srs3 = _run_sampler("simple_random_shuffle")
set3 = {
    "checkcrossfiles_me3": checkcrossfiles_me3,
    "checkcrossfiles_rossler3": checkcrossfiles_rossler3,
    "checkcrossfiles_ou3": checkcrossfiles_ou3,
    "checkcrossfiles_srs3": checkcrossfiles_srs3,
}
dump_lists_to_directories(set3)

#### 4/5 runs
graphics_path = '/content/Model/graphics_sample4'
cmap = create_color_mapping(all_files, 'tab20c')
all_samples_me4, timeseries_me4, checkcrossfiles_me4 = _run_sampler("maximal_entropy")
all_samples_rossler4, timeseries_rossler4, checkcrossfiles_rossler4 = _run_sampler("rossler", randomize=True)
all_samples_ou4, timeseries_ou4, checkcrossfiles_ou4 = _run_sampler("ornstein-uhlenbeck", randomize=True)
all_samples_srs4, timeseries_srs4, checkcrossfiles_srs4 = _run_sampler("simple_random_shuffle")
set4 = {
    "checkcrossfiles_me4": checkcrossfiles_me4,
    "checkcrossfiles_rossler4": checkcrossfiles_rossler4,
    "checkcrossfiles_ou4": checkcrossfiles_ou4,
    "checkcrossfiles_srs4": checkcrossfiles_srs4,
}
dump_lists_to_directories(set4)

#### 5/5 runs
graphics_path = '/content/Model/graphics_sample5'
cmap = create_color_mapping(all_files, 'Set3')
all_samples_me5, timeseries_me5, checkcrossfiles_me5 = _run_sampler("maximal_entropy")
all_samples_rossler5, timeseries_rossler5, checkcrossfiles_rossler5 = _run_sampler("rossler", randomize=True)
all_samples_ou5, timeseries_ou5, checkcrossfiles_ou5 = _run_sampler("ornstein-uhlenbeck", randomize=True)
all_samples_srs5, timeseries_srs5, checkcrossfiles_srs5 = _run_sampler("simple_random_shuffle")
set5 = {
    "checkcrossfiles_me5": checkcrossfiles_me5,
    "checkcrossfiles_rossler5": checkcrossfiles_rossler5,
    "checkcrossfiles_ou5": checkcrossfiles_ou5,
    "checkcrossfiles_srs5": checkcrossfiles_srs5,
}
dump_lists_to_directories(set5)

# You can save and load the actual content of each sampled set of audio arrays as a .pkl if need be with a call like so -->
# save_dict_as_tar_gz(all_samples_me, 'merw_save')

# and reload if need be, like so -->
# all_samples_me = load_dict_from_tar_gz('drive/MyDrive/merw_save.tar.gz')

In [None]:

# Example: separate the arrays held in all_samples_me into real / imaginary parts
# of the 'original' and 'lossy' processing types. Adapt this when using your own sample set.
real_data_error = np.array(
    extract_data_arrays(all_samples_me, processing_type='original', component_type='real')
)
imaginary_data_error = np.array(
    extract_data_arrays(all_samples_me, processing_type='original', component_type='imaginary')
)
real_data_train = np.array(
    extract_data_arrays(all_samples_me, processing_type='lossy', component_type='real')
)
imaginary_data_train = np.array(
    extract_data_arrays(all_samples_me, processing_type='lossy', component_type='imaginary')
)

# Training data: the 'lossy' arrays, batched in groups of 128.
train_prep = prepare_data(real_data_train, imaginary_data_train)
train_dataset = prepare_tensorflow_dataset(train_prep, 128)

# Validation data: the 'original' arrays, batched in groups of 128.
val_prep = prepare_data(real_data_error, imaginary_data_error)
val_dataset = prepare_tensorflow_dataset(val_prep, 128)

# Declare the model: windows of 100 values with 2 channels (real, imaginary).
input_shape = (100, 2)
latent_dim = 16

vae, encoder = create_vae(input_shape, latent_dim)
vae.summary()

# Build the callbacks used during training.
# NOTE(review): the heatmap callback iterates over `.encoder.layers`; confirm that
# passing the full `vae` here (rather than `encoder`) is what is intended.
heatmap_callback = LatentSpaceHeatmapCallback(vae)
save_model_callback = SaveModelAfterNBatches(save_interval=100, save_path='vae_model')
early_stopping_callback = EarlyStopping(
    monitor='val_loss',
    patience=10,
    restore_best_weights=True,
)

# Train the VAE model with the custom callbacks.
history = vae.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=50,
    callbacks=[heatmap_callback, save_model_callback, early_stopping_callback],
)

In [None]:

# Write a top-to-bottom diagram of the VAE (with layer names and shapes) to a PNG file.
plot_model(vae, to_file='minimal_model_plot_model.png', show_shapes=True, show_layer_names=True, rankdir='TB')

In [None]:
# iPython tar for Colab
# !tar -czvf heatmapz.tar.gz /content/heatmaps