In [1]:
# Detect whether the notebook runs inside Google Colab: the `google.colab`
# package is only importable there. Only ImportError is caught so that genuine
# failures (or a KeyboardInterrupt) are not silently treated as "not Colab".
try:
    import google.colab  # noqa: F401
    IN_COLAB = True
except ImportError:
    IN_COLAB = False


In [2]:
# Git branch of the project repository that is checked out when running in Colab.
# NOTE(review): 'origina_pipeline' looks like a typo of 'original_pipeline' — confirm it matches the remote branch name.
git_branch = 'origina_pipeline'


import os
if IN_COLAB:
    # Colab-only setup: authenticate to GitHub, clone the project, check out
    # `git_branch`, install extra dependencies and point paths at Google Drive.
    from google.colab import runtime, userdata
    !wget -q https://raw.githubusercontent.com/tsunrise/colab-github/main/colab_github.py
    import colab_github
    colab_github.github_auth(persistent_key=True)
    # https://github.com/tsunrise/colab-github/tree/main

    # Start from a fresh clone on every run, then switch to `git_branch`.
    ! rm -rf submovement_detector
    ! git clone "REMOVED_FOR_ANONYMITY"
    %cd submovement_detector
    !git checkout {git_branch}
    # %cd ..

    # Make the cloned project's modules (data, models, utils) importable.
    import sys
    sys.path.append('/content/submovement_detector')

    # NOTE(review): unpinned installs; consider `%pip install -q pkg==x.y` for reproducibility.
    !pip install torchviz
    !pip install fastkde

    # NOTE(review): assumes Google Drive is already mounted at /content/drive — no mount call is made here.
    root_dir = f'/content/drive/MyDrive/submov_nn'
    wandb_key = userdata.get('WANDB_KEY')  # W&B API key from Colab secrets
    code_dir = '/content/submovement_detector'  # sources uploaded to W&B via log_code
else:
    # Local run: paths relative to the working directory; W&B key from the environment.
    root_dir = f'./submov_nn'
    wandb_key = os.getenv('WANDB_KEY')
    code_dir = '.'

# Experiment configuration file and run-level switches.
CONFIG_PATH = f'{root_dir}/config/config-0426-ModGaussian_shorter_kernel.yaml'
USE_WANDB = True
VIZUALIZE_BAD_AMPLITUDES = False  # print/plot batches with exploding amplitude loss

# Organic datasets: one '<task>_tangential_velocity_data.csv' per task under <root_dir>/data/.
datasets_dir = f'{root_dir}/data/'
dataset2path = {
    task: os.path.join(datasets_dir, f'{task}_tangential_velocity_data.csv')
    for task in (
        'steering',
        'crank',
        'Fitts',
        'whacamole',
        'object_moving',
        'pointing',
        'tablet_writing',
    )
}
dataset_names = dataset2path.keys()

# Evaluation grid (presumably SNR levels, with inf meaning no added noise)
# and refractory-condition ranges.
noise_conditions = [float('inf'), 20, 10]
refractory_conditions = [(0., 0.5), (0.5, 1.), (1, 1.5)]


In [3]:
# %%
import os
import math
import time
import random
import copy
from collections import defaultdict

from itertools import product

import numpy as np
import matplotlib.pyplot as plt

import torch
from torch.utils import data
import torch.nn as nn
import torch.optim as optim
# from torch.utils.tensorboard import SummaryWriter  # TensorBoard writer

import wandb

from data import SyntheticDataset, OrganicDataset, CombinedSyntheticDataset
from models import TDNNDetector, STEContinuousReconstructor, STEBinarizer
from utils import onset_prediction_metrics_on_masks, Config, evaluate_on_organic_data, evaluate_on_synthetic_data
from sklearn.metrics import r2_score, mean_absolute_error





def log(message, file):
    """Append ``message`` as a new line to the text file ``file`` and echo it to stdout.

    Args:
        message: text to record (a trailing newline is added in the file only).
        file: path of the log file; it is created if it does not exist.
    """
    line = message + '\n'
    with open(file, 'a') as handle:
        handle.write(line)
    print(message)



In [4]:
# Load the configuration early so the W&B run can be initialised before training setup.
config = Config(CONFIG_PATH, root_dir=root_dir)
# Use the same experiment name as the training-setup cell (derived from the config
# file name), so the W&B run is named like the log/checkpoint/stats files of this run
# instead of after the `experiment_name` stored inside the YAML.
config.experiment_name = CONFIG_PATH.split('/')[-1].replace('.yaml', '')

# os.environ['WANDB_NOTEBOOK_NAME'] = 'Train.ipynb'


if USE_WANDB:
    wandb.login(key=wandb_key)
    wandb.init(project='submovement_detector', name=config.experiment_name, config=config.to_dict(), save_code=True)

    # Upload the project's Python sources and notebooks alongside the run.
    wandb.run.log_code(code_dir, include_fn=lambda path: path.endswith(".py") or path.endswith(".ipynb"))


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33meugenartemovich[0m ([33mNAME_REMOVED[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [5]:

# %%
# Training setup: config, model (optionally resumed from a checkpoint), loss
# functions, training dataset/dataloader, optimizer and learning-rate schedule.
if __name__ == '__main__':

    # Design notes
    # Model:
    #   dilation of 2 for most layers, 1 for the first and last
    #   additional output vector for masking (is not trained in the beginning)
    #   split train datasets into train and test
    #   from the test dataset, select only a fraction of it for testing
    #
    # Steps:
    # 1. Start training from synthetic data (~5 epochs)
    #   length (5 to 60 steps)
    #   amplitude (-1 to 1) multiplied by duration
    #   add jitter noise (10 to 50?)
    #   refractory (0.1 to 1.5, clip to 3 steps minimum)
    #   standardizing the trial afterwards to have std of 1
    #   no reconstruction loss,
    #   what about dropout and batchnorm?? should it start without it
    #   measuring performance on the train and test datasets the whole time
    # 2. Train on organic data (from epoch ~5 onwards)
    #   without reconstruction loss
    #   generate 3 different classes of data from 3 different domains
    # 3. Add reconstruction loss
    # 4. Remove dropout and batchnorm?

    config = Config(CONFIG_PATH, root_dir=root_dir)
    # Name the experiment after the config file so that logs, checkpoints and
    # pulled-stats files of different configs do not collide.
    config.experiment_name = CONFIG_PATH.split('/')[-1].replace('.yaml', '')

    os.makedirs(os.path.dirname(config.log_file), exist_ok=True)
    print(config.experiment_name)

    # ---- Model ----
    model = TDNNDetector(
        batchnorm=config.batchnorm,
        dilations=config.dilations,
        channels=config.channels,
        kernel_sizes=config.kernel_sizes,
        num_layers=config.num_layers,
        dropout_rate=config.dropout_rate,
    ).to(config.device, config.dtype)
    model.eval()

    # ---- Optional resume from a checkpoint ----
    # config.start_with_weights:
    #   str (not 'Xavier') -> file name inside the weights directory
    #   True               -> latest '<weights prefix>_<epoch>.pth' checkpoint
    #   int                -> the '<weights prefix>_<epoch>.pth' checkpoint of that epoch
    #   falsy / 'Xavier'   -> no checkpoint is loaded
    latest_epoch = -1
    if config.start_with_weights and config.start_with_weights != 'Xavier':
        weights_file = None
        weights_dir = os.path.dirname(config.weights_file)

        if isinstance(config.start_with_weights, str):
            weights_file = os.path.join(weights_dir, config.start_with_weights)
        elif isinstance(config.start_with_weights, int):  # also covers bool (True)
            # Epoch numbers of checkpoints saved as '<prefix>_<epoch>.pth' (the naming
            # used when saving at the end of every epoch). Files that merely share the
            # prefix, e.g. '<prefix>.pth' or '<prefix>_v2_3.pth', are ignored instead
            # of crashing int() or being mis-parsed.
            weights_prefix = os.path.basename(config.weights_file).replace('.pth', '')
            search_dir = weights_dir or '.'
            epoch_numbers = []
            if os.path.isdir(search_dir):
                for file_name in os.listdir(search_dir):
                    stem, ext = os.path.splitext(file_name)
                    epoch_suffix = stem[len(weights_prefix) + 1:]
                    if (ext == '.pth' and stem.startswith(weights_prefix + '_')
                            and epoch_suffix.isdecimal()):
                        epoch_numbers.append(int(epoch_suffix))
            if epoch_numbers:
                if isinstance(config.start_with_weights, bool):
                    latest_epoch = max(epoch_numbers)
                else:
                    latest_epoch = config.start_with_weights
                weights_file = config.weights_file.replace('.pth', f'_{latest_epoch}.pth')
        else:
            raise ValueError(f'Invalid start_with_weights value: {config.start_with_weights}')

        if weights_file is not None:
            model.load_state_dict(torch.load(weights_file, map_location=config.device))
            print(f'Loaded weights from {weights_file}, continuing from epoch {latest_epoch}')
        else:
            print(f'No weights file found, starting from scratch')

    # ---- Losses ----
    criterion_entropy = nn.BCELoss()
    criterion_entropy_wo_reduction = nn.BCELoss(reduction='none')  # per-element, for weighted losses
    criterion_mse = nn.MSELoss()

    # ---- Training data ----
    basic_dataset = SyntheticDataset(**config.get_dataset_parameters())

    # The synthetic dataset's reconstruction model is reused for the reconstruction loss.
    reconstructor = basic_dataset.reconstruction_model

    # Noise condition under which per-dataset statistics are pulled from the organic
    # training data; (10, 50) unless the config provides `stat_snr_distribution`.
    if 'stat_snr_distribution' in config.to_dict():
        train_noise_condition = config.stat_snr_distribution
    else:
        train_noise_condition = (10, 50)
    # NOTE(review): stats are saved under `config.datasets_dir` but read from the
    # notebook-level `datasets_dir` — confirm both point to the same directory.
    dataset2stats_path = {
        dataset_name: os.path.join(datasets_dir, f'{config.experiment_name}-{dataset_name}-{train_noise_condition}-train-pulled_stats.csv')
        for dataset_name in dataset_names
    }

    noise_conditions_train = [train_noise_condition]
    if config.combined_dataset:
        dataset2path_train = {k: v for k, v in dataset2path.items() if k in config.datasets}

        # Pull per-dataset statistics from the organic training data with the current
        # model and save them; stats-driven synthetic datasets are then mixed with the
        # basic synthetic dataset.
        evaluate_on_organic_data(
            model=model,
            dataset2path=dataset2path_train,
            noise_conditions=noise_conditions_train,
            config=config,
            reconstructor=reconstructor,
            step=0,
            purpose='train',
            low_pass_filter=np.inf,
            save_pulled_stats=config.datasets_dir
        )

        stats_datasets = [
            SyntheticDataset(joint_distribution=dataset2stats_path[dataset_name], **config.get_dataset_parameters())
            for dataset_name in config.datasets
        ]
        dataset = CombinedSyntheticDataset(stats_datasets + [basic_dataset], proportions=config.proportions,
                                           total_duration_distribution=config.total_duration_distribution,
                                           batch_size=config.batch_size, dtype=config.dtype, device=config.device)
    else:
        dataset = basic_dataset

    # The dataset presumably yields whole batches itself (batch_size is passed to
    # CombinedSyntheticDataset), hence batch_size=1 here; the training loop
    # squeezes the resulting leading dimension.
    dataloader = data.DataLoader(dataset, batch_size=1, shuffle=False)

    # ---- Optimizer and learning-rate schedule ----
    optimizer = optim.Adam(model.parameters(), lr=config.learning_rate)

    # Exponential decay applied once per optimizer step during epochs
    # [lr_decay_start, lr_decay_end); over that window the learning rate is
    # multiplied by lr_decay_total_change in total.
    scheduler_start = config.lr_decay_start
    scheduler_end = config.lr_decay_end
    scheduler_total_decay = config.lr_decay_total_change
    decay_steps = (scheduler_end - scheduler_start) * len(dataloader)
    # An empty or inverted window never steps the scheduler; use gamma=1 instead of
    # dividing by zero.
    step_decay = scheduler_total_decay ** (1 / decay_steps) if decay_steps > 0 else 1.0
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, gamma=step_decay)

config-0426-ModGaussian_abs_velo




In [6]:
# Optional sanity check: evaluate the (possibly resumed) model before any training step.
if 'start_test' in config.to_dict() and config.start_test:
    model.eval()

    # NOTE(review): this call spells the keyword `reconsturctor` (as does the training
    # loop), whereas evaluate_on_organic_data takes `reconstructor` — confirm against
    # the signature in utils.
    evaluate_on_synthetic_data(
        model=model,
        noise_conditions=noise_conditions,
        refractory_conditions=refractory_conditions,
        config=config,
        reconsturctor=reconstructor,
        step=0,
    )
    evaluate_on_organic_data(
        model=model,
        dataset2path=dataset2path,
        noise_conditions=noise_conditions,
        config=config,
        reconstructor=reconstructor,
        step=0,
    )

    model.train()


In [None]:
# Training loop: per-step losses, logging and plots; per-epoch checkpointing,
# evaluation and (for combined-dataset training) a refresh of the training data.

# Evaluation / dataset-refresh cadence, in epochs.
SYNTHETIC_EVAL_EVERY = 1
ORGANIC_EVAL_EVERY = 5
STATS_REFRESH_EVERY = 1


def _running_mean(mean, value, count):
    """Return ``mean`` updated with the ``count``-th (0-based) observation ``value``."""
    return mean * (count / (count + 1)) + value / (count + 1)


def _report_bad_amplitudes(epoch, i, x, mask, amplitude, amplitude_pred, amplitude_loss, mse_threshold=5000):
    """Debug helper: print (and plot) batch elements whose onset amplitudes are badly predicted.

    Assumes ``x`` is (batch, channels, T) and ``mask``/``amplitude``/``amplitude_pred``
    are (batch, T) — TODO confirm against the dataset.
    """
    print('='*100)
    print(f'Epoch: {epoch}, Iteration: {i}')
    print(f'Amplitude loss is too high: {amplitude_loss}')
    print(f'Amplitude true: {amplitude[mask == 1]}')
    print(f'Amplitude pred: {amplitude_pred[mask == 1]}')
    for bt_el in range(amplitude.shape[0]):
        onset_indices = np.where(mask[bt_el].cpu().numpy() == 1)[0]
        bt_el_with_anomaly = False
        for onset_index in onset_indices:
            mse = (amplitude[bt_el, onset_index] - amplitude_pred[bt_el, onset_index])**2
            if mse > mse_threshold:
                bt_el_with_anomaly = True
                print(f'MSE: {mse}, Amplitude true: {amplitude[bt_el, onset_index]}, Amplitude pred: {amplitude_pred[bt_el, onset_index]}')
        if bt_el_with_anomaly:
            print(f'BT element {bt_el} has anomalies')
            for series, title in ((amplitude[bt_el], 'Amplitude'),
                                  (amplitude_pred[bt_el], 'Amplitude pred'),
                                  (x[bt_el][0], 'Signal')):
                fig = plt.figure(figsize=(10, 5), dpi=200)
                plt.plot(series.detach().cpu().numpy())
                plt.title(title)
                plt.show()
                plt.close(fig)
        # Root-mean-square of the input (a std estimate assuming ~zero mean).
        x_rms = torch.sqrt((x[bt_el] ** 2).mean()).item()
        print(f'x mean: {x[bt_el].mean().item()}, x ala std {x_rms}')
        print(f'amplitude mean: {amplitude[bt_el].mean().item()}')
        print(f'amplitude pred mean: {amplitude_pred[bt_el].mean().item()}')

    print('='*100)


def _plot_reconstruction_examples(x, x_clean, reconstructed_x, n_examples, title, wandb_step=None):
    """Plot noisy input, reconstruction and clean signal for random batch elements.

    The time axis divides sample indices by 60, i.e. it assumes 60 samples per second.
    If ``wandb_step`` is given, the figure is also logged to W&B at that step.
    """
    # random.sample raises if asked for more elements than the batch holds.
    n_examples = min(n_examples, len(x))
    num_rows = 2
    num_cols = max(1, math.ceil(n_examples / 2))
    fig = plt.figure(figsize=(10, 5), dpi=200)
    for j, id_ in enumerate(random.sample(range(len(x)), n_examples)):
        plt.subplot(num_rows, num_cols, j + 1)
        original_signal = x[id_].squeeze().detach().cpu().numpy()
        clean_signal = x_clean[id_].squeeze().detach().cpu().numpy()
        reconstructed_signal = reconstructed_x[id_].squeeze().detach().cpu().numpy()
        ts = np.arange(len(original_signal)) / 60
        plt.plot(ts, original_signal)
        plt.plot(ts, reconstructed_signal, linestyle='--')
        plt.plot(ts, clean_signal, linestyle=':')
        if j == 0:
            plt.legend(['Original', 'Reconstructed', 'Clean'])
        if j % num_cols == 0:
            plt.ylabel('Amplitude, a.u.')
        if j >= num_cols * (num_rows - 1):
            plt.xlabel('Time (s)')
        plt.grid()
    plt.suptitle(title)

    if wandb_step is not None:
        wandb.log({'Reconstructions': wandb.Image(plt)}, step=wandb_step)

    plt.show()
    plt.close(fig)  # avoid accumulating open figures over a long run


if __name__ == '__main__':

    model.train()

    # Xavier initialisation of the layer weights. nn.init has to be applied to the
    # `.weight` of Conv1d/Linear modules; parameters themselves are plain tensors and
    # never pass an isinstance(..., nn.Module) check.
    if config.start_with_weights == 'Xavier':
        for module in model.modules():
            if isinstance(module, (nn.Conv1d, nn.Linear)):
                nn.init.xavier_normal_(module.weight)
    tm = time.time()

    for epoch in range(latest_epoch + 1, config.num_epochs):
        dataset.seed = epoch  # Update seed for each epoch
        detection_loss_mean = 0
        duration_loss_mean = 0
        amplitude_loss_mean = 0
        reconstruction_loss_mean = 0
        reconstruction_detection_loss_mean = 0
        mean_onset_precision = 0
        mean_onset_recall = 0
        mean_onset_distance = 0

        # Switch on the reconstruction loss from the configured epoch onwards.
        if epoch >= config.reconstruction_loss_start:
            config.use_reconstruction_loss = True

        # From the configured epoch onwards keep batchnorm/dropout in eval mode.
        # model.train() at the end of every epoch resets them, so redo it each epoch.
        if epoch >= config.bn_dropout_freeze_start:
            if config.batchnorm:
                for batch_norm in model.batchnorm_layers:
                    batch_norm.eval()
            if config.dropout_rate > 0:
                model.dropout.eval()

        for i, data_ in enumerate(dataloader):
            global_step = epoch * len(dataloader) + i

            # Move every tensor (including the clean target) to the training
            # device/dtype — a no-op if already there — and drop the DataLoader's
            # leading dimension of size 1.
            x, x_clean, y = (t.to(config.device, config.dtype).squeeze(0) for t in data_)

            optimizer.zero_grad()
            y_pred = model(x)

            # Channels: 0 = onset mask, 1 = amplitude, 2 = duration,
            # optional 3 (predictions only) = reconstruction mask.
            mask = y[:, 0]
            amplitude = y[:, 1]
            duration = y[:, 2]

            mask_pred = y_pred[:, 0]
            amplitude_pred = y_pred[:, 1]
            duration_pred = y_pred[:, 2]
            reconstruction_mask_pred = y_pred[:, 3] if y_pred.shape[1] == 4 else None

            positives = mask == 1
            negatives = mask == 0

            # ---- Detection loss on positives ----
            if hasattr(config, 'weight_with_amplitude') and config.weight_with_amplitude:
                # Positives with larger |amplitude| count more (weight sqrt|amplitude|),
                # normalised by the mean weight.
                per_onset_loss = criterion_entropy_wo_reduction(mask_pred[positives], mask[positives])
                amplitude_weights = torch.sqrt(torch.abs(amplitude[positives])).detach()
                detection_loss_positives = (per_onset_loss * amplitude_weights).mean() / amplitude_weights.mean()
            else:
                detection_loss_positives = criterion_entropy(mask_pred[positives], mask[positives])

            # ---- Detection loss on negatives ----
            if hasattr(config, 'weight_up_to_n_neighbors') and config.weight_up_to_n_neighbors > 0:
                # A negative at distance j <= n from its nearest onset gets weight
                # 0.5 ** (n - j + 1); negatives farther away keep weight 1.
                weight_up_to_n_neighbors = config.weight_up_to_n_neighbors
                weighting_factor = 0.5
                scaling_exponent_tensor = torch.zeros_like(mask)
                for j in range(1, weight_up_to_n_neighbors + 1):
                    scaling_exponent_tensor[:, j:] = torch.max(scaling_exponent_tensor[:, j:], mask[:, :-j] * (weight_up_to_n_neighbors - j + 1))
                    scaling_exponent_tensor[:, :-j] = torch.max(scaling_exponent_tensor[:, :-j], mask[:, j:] * (weight_up_to_n_neighbors - j + 1))
                weight_tensor = torch.pow(weighting_factor, scaling_exponent_tensor)
                weight_tensor[positives] = 0
                weight_tensor = weight_tensor.detach()
                detection_loss_negatives = criterion_entropy_wo_reduction(mask_pred[negatives], mask[negatives]) * weight_tensor[negatives]
                detection_loss_negatives = detection_loss_negatives.mean() / weight_tensor[negatives].mean()
            else:
                detection_loss_negatives = criterion_entropy(mask_pred[negatives], mask[negatives])

            # ---- Optional re-balancing of positives vs. negatives ----
            if hasattr(config, 'negative_loss_multiplier'):
                if config.negative_loss_multiplier in ('adaptive', 'balanced'):
                    binarized_mask_pred = reconstructor.binarizer.apply(mask_pred, False, True).detach()
                    n_pred_onsets = binarized_mask_pred.sum()
                    pred_to_true_ratio = n_pred_onsets / mask.sum()
                    if config.negative_loss_multiplier == 'balanced' and pred_to_true_ratio < 1.0:
                        # When there are not enough positives, we punish for false
                        # negatives more, but not as much as for false positives.
                        # At least one predicted onset is assumed so that a batch with
                        # none does not divide by zero (infinite loss / NaN gradients).
                        safe_ratio = torch.clamp(n_pred_onsets, min=1.0) / mask.sum()
                        detection_loss_positives = detection_loss_positives / safe_ratio ** 0.5
                    # Over-detection: scale the negative loss by the ratio, capped at 10x.
                    pred_to_true_ratio = torch.clamp(pred_to_true_ratio, 1.0, 10.0)
                    detection_loss_negatives = detection_loss_negatives * pred_to_true_ratio
                elif isinstance(config.negative_loss_multiplier, (int, float)):
                    detection_loss_negatives = detection_loss_negatives * config.negative_loss_multiplier
            detection_loss = detection_loss_positives + detection_loss_negatives

            # ---- Regression losses at true onsets ----
            original_duration_loss = criterion_mse(duration_pred[positives], duration[positives])
            original_amplitude_loss = criterion_mse(amplitude_pred[positives], amplitude[positives])

            if VIZUALIZE_BAD_AMPLITUDES and original_amplitude_loss.item() > 1000:
                _report_bad_amplitudes(epoch, i, x, mask, amplitude, amplitude_pred, original_amplitude_loss.item())

            # Running means over the epoch, used to rescale every auxiliary loss to the
            # magnitude of the detection loss.
            detection_loss_mean = _running_mean(detection_loss_mean, detection_loss.item(), i)
            duration_loss_mean = _running_mean(duration_loss_mean, original_duration_loss.item(), i)
            amplitude_loss_mean = _running_mean(amplitude_loss_mean, original_amplitude_loss.item(), i)

            duration_loss = original_duration_loss / duration_loss_mean * detection_loss_mean
            amplitude_loss = original_amplitude_loss / amplitude_loss_mean * detection_loss_mean

            if reconstruction_mask_pred is not None:
                reconstruction_detection_loss_positives = criterion_entropy(reconstruction_mask_pred[positives], mask[positives])
                reconstruction_detection_loss_negatives = criterion_entropy(reconstruction_mask_pred[negatives], mask[negatives])
                original_reconstruction_detection_loss = reconstruction_detection_loss_positives + reconstruction_detection_loss_negatives

                reconstruction_detection_loss_mean = _running_mean(
                    reconstruction_detection_loss_mean, original_reconstruction_detection_loss.item(), i)
                # Normalised with the current detection-loss mean (like the other
                # auxiliary losses), then down-weighted.
                reconstruction_detection_loss = original_reconstruction_detection_loss / reconstruction_detection_loss_mean * detection_loss_mean
                reconstruction_detection_loss = reconstruction_detection_loss * 0.1

                if USE_WANDB:
                    wandb.log({'Loss/ReconstructionDetection': original_reconstruction_detection_loss.item()}, step=global_step)

            # ---- Reconstruction loss (monitored only until it is switched on) ----
            reconstructed_x, _ = reconstructor(y_pred)
            if not config.use_reconstruction_loss:
                reconstructed_x = reconstructed_x.detach()
            original_reconstruction_loss = criterion_mse(reconstructed_x, x_clean)
            reconstruction_loss_mean = _running_mean(reconstruction_loss_mean, original_reconstruction_loss.item(), i)
            reconstruction_loss = original_reconstruction_loss / reconstruction_loss_mean * detection_loss_mean

            # ---- Total loss, backpropagation, optimizer and LR-scheduler step ----
            loss = detection_loss + duration_loss + amplitude_loss
            if config.use_reconstruction_loss:
                loss = loss + reconstruction_loss
            elif reconstruction_mask_pred is not None:
                loss = loss + reconstruction_detection_loss

            loss.backward()
            optimizer.step()
            # The scheduler must step after the optimizer (PyTorch >= 1.1); stepping
            # first skips the initial learning-rate value of the schedule.
            if scheduler_start <= epoch < scheduler_end:
                scheduler.step()

            if USE_WANDB:
                wandb.log({'Loss/Total': loss.item(),
                           'Loss/Detection': detection_loss.item(),
                           'Loss/Duration': original_duration_loss.item(),
                           'Loss/Amplitude': original_amplitude_loss.item(),
                           'Loss/Reconstruction': original_reconstruction_loss.item()}, step=global_step)

            # ---- Onset metrics and progress log ----
            if i % config.log_interval == 0:
                _, _, _, precision, recall, _, distance = onset_prediction_metrics_on_masks(mask, mask_pred)
                log_index = i // config.log_interval
                mean_onset_precision = _running_mean(mean_onset_precision, precision, log_index)
                mean_onset_recall = _running_mean(mean_onset_recall, recall, log_index)
                mean_onset_distance = _running_mean(mean_onset_distance, distance, log_index)

                if USE_WANDB:
                    wandb.log({'Onset/Precision': precision,
                               'Onset/Recall': recall,
                               'Onset/Distance': distance}, step=global_step)
                    wandb.log({'Params/LearningRate': scheduler.get_last_lr()[0]}, step=global_step)

                message = (f'Epoch {epoch}, Iteration {i}, Loss: {loss.item()},'
                           f'\nDetection Loss: {detection_loss.item()},'
                           f'\nDuration Loss: {original_duration_loss.item()}, Amplitude Loss: {original_amplitude_loss.item()},'
                           f'\nReconstruction Loss: {original_reconstruction_loss.item()}'
                           f'\nOnset Precision: {precision}, Onset Recall: {recall}, Onset Distance: {distance}'
                           f'\nTime: {time.time() - tm}')
                log(message, config.log_file)

            # ---- Reconstruction plots ----
            if i % config.plot_interval == 0:
                _plot_reconstruction_examples(
                    x, x_clean, reconstructed_x, config.reconstructions_to_plot,
                    title=f'Original vs Reconstructed Signal Examples, Epoch {epoch}, Step {i}',
                    wandb_step=global_step if USE_WANDB else None,
                )

        # ---- End of epoch: epoch means, checkpoint, evaluation ----
        epoch_end_step = len(dataloader) * (epoch + 1)
        if USE_WANDB:
            wandb.log({'Loss_Epoch/Detection_Mean': detection_loss_mean,
                       'Loss_Epoch/Duration_Mean': duration_loss_mean,
                       'Loss_Epoch/Amplitude_Mean': amplitude_loss_mean,
                       'Loss_Epoch/Reconstruction_Mean': reconstruction_loss_mean,
                       'Onset_Epoch/Precision_Mean': mean_onset_precision,
                       'Onset_Epoch/Recall_Mean': mean_onset_recall,
                       'Onset_Epoch/Distance_Mean': mean_onset_distance}, step=epoch_end_step)

        # Checkpoint name '<weights prefix>_<epoch>.pth' is what the resume logic looks for.
        os.makedirs(os.path.dirname(config.weights_file), exist_ok=True)
        torch.save(model.state_dict(), config.weights_file.replace('.pth', f'_{epoch}.pth'))

        message = (f'Epoch {epoch} finished,'
                   f'\nDetection Loss: {detection_loss_mean},'
                   f'\nDuration Loss: {duration_loss_mean}, Amplitude Loss: {amplitude_loss_mean},'
                   f'\nReconstruction Loss: {reconstruction_loss_mean}'
                   f'\nOnset Precision: {mean_onset_precision}, Onset Recall: {mean_onset_recall}, Onset Distance: {mean_onset_distance}'
                   f'\nTime: {time.time() - tm}')
        log(message, config.log_file)

        model.eval()
        if (epoch + 1) % SYNTHETIC_EVAL_EVERY == 0:
            # NOTE(review): keyword spelled `reconsturctor` — confirm against utils' signature.
            evaluate_on_synthetic_data(
                model=model,
                noise_conditions=noise_conditions,
                refractory_conditions=refractory_conditions,
                config=config,
                reconsturctor=reconstructor,
                step=epoch_end_step
            )
        if (epoch + 1) % ORGANIC_EVAL_EVERY == 0:
            evaluate_on_organic_data(
                model=model,
                dataset2path=dataset2path,
                noise_conditions=noise_conditions,
                config=config,
                reconstructor=reconstructor,
                step=epoch_end_step
            )

        # Re-pull the organic-data statistics with the current model and rebuild the
        # mixed training set. Only done for combined-dataset training; otherwise the
        # training set is the basic synthetic dataset and has nothing to rebuild from.
        if config.combined_dataset and (epoch + 1) % STATS_REFRESH_EVERY == 0:
            evaluate_on_organic_data(
                model=model,
                dataset2path=dataset2path_train,
                noise_conditions=noise_conditions_train,
                config=config,
                reconstructor=reconstructor,
                step=epoch_end_step,
                purpose='train',
                low_pass_filter=np.inf,
                save_pulled_stats=config.datasets_dir
            )
            stats_datasets = [
                SyntheticDataset(joint_distribution=dataset2stats_path[dataset_name], **config.get_dataset_parameters())
                for dataset_name in config.datasets
            ]
            dataset = CombinedSyntheticDataset(stats_datasets + [basic_dataset], proportions=config.proportions,
                                               total_duration_distribution=config.total_duration_distribution,
                                               batch_size=config.batch_size, dtype=config.dtype, device=config.device)
            dataloader = data.DataLoader(dataset, batch_size=1, shuffle=False)

        model.train()

# %%


In [45]:
# Display the detector's module structure.
# NOTE(review): the saved output below is a tensor (with grad_fn), not the repr of the
# model built in the setup cell — it looks like stale output from an out-of-order
# execution (In [45]); re-run to refresh.
model

tensor([[0.1862, 0.0931, 0.0466,  ..., 0.7450, 0.7450, 0.7450],
        [0.7450, 0.7450, 0.7450,  ..., 0.7450, 0.7450, 0.7450],
        [0.7450, 0.7450, 0.7450,  ..., 0.7450, 0.7450, 0.7450],
        ...,
        [0.3725, 0.1862, 0.0931,  ..., 0.7450, 0.7450, 0.7450],
        [0.7450, 0.7450, 0.7450,  ..., 0.7450, 0.7450, 0.7450],
        [0.7450, 0.7450, 0.7450,  ..., 0.7450, 0.7450, 0.7450]],
       grad_fn=<MulBackward0>)

In [None]:
# Clean-up: close the W&B run, then (on Colab) release the runtime.
# try/finally ensures the Colab runtime is released even if finishing the W&B run fails.
try:
    if USE_WANDB:
        wandb.finish()
finally:
    if IN_COLAB:
        runtime.unassign()