In [1]:
import os
import shutil
import pickle
import json
import time
from pathlib import Path

import torch
import torch.utils.data
from torcheval.metrics.functional import multiclass_confusion_matrix
from sklearn.model_selection import StratifiedKFold
from datasets_utils import  ds_get_info, ds_load, TrainTestDS, TSCDataset, DSInfo, ArtificialProtos
from autoencoder import PermutingConvAutoencoder, train_autoencoder, RegularConvEncoder
from log import create_logger

from typing import Dict

from train_utils import EarlyStopping

from train import EpochType, ProtoTSCoeffs, train_prototsnet, best_stat_saver, get_verbose_logger, BestModelCheckpointer

import pandas as pd
import numpy as np

import traceback

In [2]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without CUDA instead of failing at the
# first `.to(device)` call.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Directory the datasets are loaded from, relative to the working directory.
DATASETS_PATH = Path('datasets')

In [3]:
def experiment_setup(experiment_subpath):
    """Create an experiment directory and snapshot the project sources into it.

    The directory ``<cwd>/experiments/<experiment_subpath>`` is created if it does
    not exist yet (intermediate directories included). A fixed list of source
    files is then copied from the current working directory into it, so every
    experiment keeps a copy of the code that produced it.

    Args:
        experiment_subpath: Path of the experiment relative to ``experiments``.

    Returns:
        The ``Path`` of the experiment directory.

    Raises:
        Whatever ``shutil.copy`` raises, e.g. when one of the listed source files
        is missing from the current working directory.
    """
    project_root = Path.cwd()
    experiment_dir = project_root / 'experiments' / experiment_subpath
    os.makedirs(experiment_dir, exist_ok=True)

    # Files copied next to the results, in this order.
    snapshot_files = (
        'autoencoder.py',
        'datasets_utils.py',
        'experiments.ipynb',
        'model.py',
        'push.py',
        'train_utils.py',
        'train.py',
    )
    for file_name in snapshot_files:
        shutil.copy(src=project_root / file_name, dst=experiment_dir)

    return experiment_dir

In [4]:
# On-disk cache of all loaded datasets, so the loading step runs only once.
pickled_dses_file = './datasets.pickle'

if not os.path.exists(pickled_dses_file):
    # No cache yet: load every dataset listed by ds_get_info() from DATASETS_PATH
    # and write the result to the cache file.
    all_ds = ds_load(DATASETS_PATH, list(ds_get_info().keys()))
    with open(pickled_dses_file, 'wb') as cache_file:
        pickle.dump(all_ds, cache_file)
else:
    # Cache present: restore the datasets from it instead of loading them again.
    with open(pickled_dses_file, 'rb') as cache_file:
        all_ds: Dict[str, TrainTestDS] = pickle.load(cache_file)


In [None]:
# Hyperparameter search over (reception, prototype length) using stratified
# k-fold cross-validation on the training split of each selected dataset.
# Every (proto_len, reception, fold) combination gets its own experiment
# directory; a combination whose models/last-epoch.pth already exists is skipped,
# so the cell can be re-run after an interruption.
experiment_name = "HyperparameterSearch5Fold"

# Values shared by every run. "reception" and "proto_len" are not read below:
# both come from the grid that is being searched.
default_params = {
    "coeffs": ProtoTSCoeffs(crs_ent=1, clst=0, sep=0, l1=1e-3, l1_addon=3e-4),
    "reception": 0.25,
    "proto_len": 5,
    "protos_per_class": 10,
    "proto_features": 32,
    "features_lr": 1e-3,
    "push_start_epoch": 60,
    "num_last_layer_epochs": 40,
}

# The list comprehension restricts the search to one dataset; iterate over
# all_ds.items() instead to search all of them.
for ds_name, whole_dataset in [(k, v) for k, v in all_ds.items() if k == 'ArticularyWordRecognition']: # all_ds.items():

    # StandWalkJump is split into 4 folds, every other dataset into 5.
    if ds_name == 'StandWalkJump':
        kfold = StratifiedKFold(n_splits=4)
    else:
        kfold = StratifiedKFold(n_splits=5)

    # Take the number of features and the series length from the data itself.
    ds_info = ds_get_info(whole_dataset.name)
    ds_info.features = whole_dataset.train.X.shape[1]
    ds_info.ts_len = whole_dataset.train.X.shape[2]

    # Candidate receptions; some are dropped depending on the feature count.
    receptions_to_try = [0.25, 0.5, 0.75, 0.9]
    if ds_info.features < 4:
        receptions_to_try.remove(0.25)
        receptions_to_try.remove(0.75)
    elif ds_info.features == 4:
        receptions_to_try.remove(0.75)
    elif ds_info.features > 500:
        receptions_to_try.remove(0.9)

    # Prototype lengths are expressed as fractions of the series length; the
    # smallest fraction is dropped for series shorter than 100 steps.
    proto_len_factors_to_try = [0.01, 0.1, 0.3, 1]
    if ds_info.ts_len < 100:
        proto_len_factors_to_try.remove(0.01)

    for reception in receptions_to_try:
        for proto_len_factor in proto_len_factors_to_try:
            for fold_idx, (train_ind, test_ind) in enumerate(
                kfold.split(whole_dataset.train.X, whole_dataset.train.y)
            ):
                ds_info = ds_get_info(whole_dataset.name)
                ds_info.features = whole_dataset.train.X.shape[1]
                ds_info.ts_len = whole_dataset.train.X.shape[2]

                # The fold's held-out part of the training split is used for
                # validation; the original test split is passed through as-is.
                dataset = TrainTestDS(
                    ds_name + f"-fold-{fold_idx}",
                    train=TSCDataset(
                        whole_dataset.train.X[train_ind], whole_dataset.train.y[train_ind]
                    ),
                    val=TSCDataset(
                        whole_dataset.train.X[test_ind], whole_dataset.train.y[test_ind]
                    ),
                    test=TSCDataset(
                        whole_dataset.test.X, whole_dataset.test.y
                    ),
                )

                # At least one time step, even for very small factors.
                proto_len = max(int(ds_info.ts_len * proto_len_factor), 1)
                curr_experiment_dir = experiment_setup(
                    f"{experiment_name}/{whole_dataset.name}/proto-len-{proto_len}/reception-{reception}/fold-{fold_idx}"
                )

                log, logclose = create_logger(curr_experiment_dir / "log.txt", display=True)

                try:
                    # Resume support: skip combinations that already have a final
                    # checkpoint (the `finally` below still closes the logger).
                    if os.path.exists(curr_experiment_dir / 'models' / 'last-epoch.pth'):
                        print(f"Skipping training for {dataset.name}, proto len {proto_len}, reception {reception}, already done")
                        continue

                    # Keep `current` and `curr_log.txt` symlinks pointing at the
                    # run in progress.
                    curr_link_path = Path.cwd() / 'experiments' / experiment_name / 'current'
                    if os.path.islink(curr_link_path):
                        os.unlink(curr_link_path)
                    os.symlink(curr_experiment_dir, curr_link_path)

                    curr_log_link_path = Path.cwd() / 'experiments' / experiment_name / 'curr_log.txt'
                    if os.path.islink(curr_log_link_path):
                        os.unlink(curr_log_link_path)
                    os.symlink(curr_experiment_dir / 'log.txt', curr_log_link_path)

                    features_lr = default_params["features_lr"]

                    protos_per_class = default_params["protos_per_class"]
                    proto_features = default_params["proto_features"]
                    # Halve the batch size until it is at most half the fold's
                    # training set size.
                    train_batch_size = 32
                    while train_batch_size > len(dataset.train.X) / 2:
                        train_batch_size //= 2
                    test_batch_size = 128
                    coeffs = default_params["coeffs"]
                    padding = 'same'

                    push_start_epoch = default_params["push_start_epoch"]
                    num_warm_epochs = push_start_epoch
                    num_last_layer_epochs = default_params["num_last_layer_epochs"]

                    early_stopping = EarlyStopping(
                        retrieve_stat="loss_val",
                        mode="min",
                        patience=60,
                        wait=push_start_epoch,
                    )

                    # Written to params.json and to the log. The original literal
                    # listed "protos_per_class" twice; the duplicate is removed,
                    # which leaves the resulting dict unchanged.
                    # NOTE(review): `reception` is only recorded in the directory
                    # name and the log line, not in params.json.
                    params = {
                        "protos_per_class": protos_per_class,
                        "proto_features": proto_features,
                        "proto_len_latent": proto_len,
                        "features_lr": features_lr,
                        "num_classes": ds_info.num_classes,
                        "coeffs": coeffs._asdict(),
                        "num_warm_epochs": num_warm_epochs,
                        "push_start_epoch": push_start_epoch,
                        "num_last_layer_epochs": num_last_layer_epochs,
                    }
                    with open(curr_experiment_dir / "params.json", "w") as f:
                        json.dump(params, f, indent=4)

                    log(
                        f"Training for {dataset.name}, proto len {proto_len}, reception {reception}, features_lr {features_lr}, protos per class {protos_per_class}, l1_addon {coeffs.l1_addon}",
                        flush=True,
                        display=True
                    )
                    log(f'Params: {json.dumps(params, indent=4)}')

                    whole_training_start = time.time()

                    # Stage 1: train the autoencoder; only its encoder is reused.
                    log(f'Training encoder', flush=True, display=True)
                    autoencoder = PermutingConvAutoencoder(num_features=ds_info.features, latent_features=proto_features, reception_percent=reception, padding=padding)
                    train_loader = torch.utils.data.DataLoader(dataset.train, batch_size=train_batch_size, shuffle=True)
                    val_loader = torch.utils.data.DataLoader(dataset.val, batch_size=test_batch_size)
                    train_autoencoder(autoencoder, train_loader, val_loader, device=device, log=log)
                    encoder = autoencoder.encoder

                    # Stage 2: train ProtoTSNet on top of the trained encoder.
                    log(f'Training ProtoTSNet', flush=True, display=True)
                    trainer = train_prototsnet(
                        dataset,
                        curr_experiment_dir,
                        device,
                        encoder,
                        features_lr,
                        coeffs,
                        protos_per_class,
                        proto_features,
                        proto_len,
                        train_batch_size,
                        test_batch_size,
                        num_epochs=1000,
                        num_warm_epochs=num_warm_epochs,
                        push_start_epoch=push_start_epoch,
                        push_epochs=range(0, 1000, 20),
                        ds_info=ds_info,
                        num_last_layer_epochs=num_last_layer_epochs,
                        custom_checkpointers=[
                            get_verbose_logger(dataset.name),
                            best_stat_saver(
                                "loss_val", curr_experiment_dir / "min_loss.json"
                            ) if dataset.val is not None else lambda *_: None,
                        ],
                        early_stopping=early_stopping,
                        log=log,
                    )

                    # Report validation accuracy from the first PUSH epoch on.
                    # If no PUSH epoch was recorded the whole history is used;
                    # the previous for/break version then kept only the last
                    # entry, or read an undefined/stale `i` on an empty history.
                    accu_val = trainer.stats()["accu_val"]
                    first_push_idx = next(
                        (idx for idx, stat in enumerate(accu_val) if stat["epoch_type"] == "PUSH"),
                        0,
                    )
                    accu_val = accu_val[first_push_idx:]
                    best_val_accu = max(s["value"] for s in accu_val)
                    # NOTE(review): "push best" is taken from the *test* accuracy
                    # of PUSH/LAST_LAYER epochs - confirm this is intended.
                    best_push_accu = max(
                        s["value"]
                        for s in trainer.stats()["accu_test"]
                        if s["epoch_type"] in ["PUSH", "LAST_LAYER"]
                    )
                    log(
                        f'Overall best val accu: {best_val_accu*100:.2f}%, push best: {best_push_accu*100:.2f}%',
                        display=True)
                    whole_training_end = time.time()

                    log(f"Done in {trainer.curr_epoch - 1} epochs, {whole_training_end - whole_training_start:.2f}s", display=True)
                except Exception as e:
                    # Best effort: log the failure and carry on with the next
                    # combination instead of aborting the whole search.
                    log(f"Exception ocurred for {ds_name}: {e}", display=True)
                    tb_str = traceback.format_tb(e.__traceback__)
                    log('\n'.join(tb_str), display=True)
                finally:
                    logclose()

In [17]:
# Final training with the selected hyperparameters: every dataset is trained in
# five independent runs (run-1 .. run-5). A run whose models/last-epoch.pth
# already exists is skipped, so the cell can be re-run after an interruption.
experiment_name = "HyperparameterOptimized"

# Per-dataset hyperparameters, indexed by dataset name; the columns read below
# are 'proto_len', 'reception' and 'epochs'.
best_params = pd.read_csv('best_params.csv', index_col=0)

# Values shared by every run. "reception" and "proto_len" are not read below:
# both are taken from best_params.
default_params = {
    "coeffs": ProtoTSCoeffs(crs_ent=1, clst=0, sep=0, l1=1e-3, l1_addon=3e-5),
    "reception": 0.25,
    "proto_len": 5,
    "protos_per_class": 10,
    "proto_features": 32,
    "features_lr": 1e-3,
    "push_start_epoch": 110,
    "num_last_layer_epochs": 40,
}

for run in range(1, 6):
    for ds_name, dataset in all_ds.items():
        ds_info = ds_get_info(dataset.name)

        proto_len = int(best_params.loc[ds_name, 'proto_len'])
        reception = float(best_params.loc[ds_name, 'reception'])
        epochs = int(best_params.loc[ds_name, 'epochs'])
        curr_experiment_dir = experiment_setup(f"{experiment_name}/{dataset.name}/run-{run}")

        log, logclose = create_logger(curr_experiment_dir / "log.txt", display=True)

        try:
            # Resume support: skip runs that already have a final checkpoint
            # (the `finally` below still closes the logger).
            if os.path.exists(curr_experiment_dir / 'models' / 'last-epoch.pth'):
                print(f"Skipping training for {dataset.name}, already done")
                continue

            # Keep `current` and `curr_log.txt` symlinks pointing at the run in
            # progress.
            curr_link_path = Path.cwd() / 'experiments' / experiment_name / 'current'
            if os.path.islink(curr_link_path):
                os.unlink(curr_link_path)
            os.symlink(curr_experiment_dir, curr_link_path)

            curr_log_link_path = Path.cwd() / 'experiments' / experiment_name / 'curr_log.txt'
            if os.path.islink(curr_log_link_path):
                os.unlink(curr_log_link_path)
            os.symlink(curr_experiment_dir / 'log.txt', curr_log_link_path)

            features_lr = default_params["features_lr"]

            protos_per_class = default_params["protos_per_class"]
            proto_features = default_params["proto_features"]
            # Halve the batch size until it is at most half the training set size.
            train_batch_size = 32
            while train_batch_size > len(dataset.train.X) / 2:
                train_batch_size //= 2
            test_batch_size = 128
            coeffs = default_params["coeffs"]
            padding = 'same'

            push_start_epoch = default_params["push_start_epoch"]
            num_warm_epochs = 50
            num_last_layer_epochs = default_params["num_last_layer_epochs"]
            push_epochs = range(push_start_epoch, 1000, 30)

            # Written to params.json and to the log. The original literal listed
            # "protos_per_class" twice; the duplicate is removed, which leaves the
            # resulting dict unchanged.
            params = {
                "protos_per_class": protos_per_class,
                "proto_features": proto_features,
                "proto_len_latent": proto_len,
                "features_lr": features_lr,
                "num_classes": ds_info.num_classes,
                "coeffs": coeffs._asdict(),
                "num_warm_epochs": num_warm_epochs,
                "push_start_epoch": push_start_epoch,
                "num_last_layer_epochs": num_last_layer_epochs,
                "epochs": epochs,
            }
            with open(curr_experiment_dir / "params.json", "w") as f:
                json.dump(params, f, indent=4)

            log(
                f"Training for {dataset.name}, proto len {proto_len}, reception {reception}, features_lr {features_lr}, protos per class {protos_per_class}, l1_addon {coeffs.l1_addon}",
                flush=True,
                display=True
            )
            log(f'Params: {json.dumps(params, indent=4)}')

            whole_training_start = time.time()

            # Stage 1: train the autoencoder; only its encoder is reused.
            # Note that the test split is what gets passed to train_autoencoder
            # as the second loader here.
            log(f'Training encoder', flush=True, display=True)
            autoencoder = PermutingConvAutoencoder(num_features=ds_info.features, latent_features=proto_features, reception_percent=reception, padding=padding)
            # NOTE(review): train_ds is not used anywhere in this cell.
            train_ds = TSCDataset(dataset.train.X, dataset.train.y)
            train_loader = torch.utils.data.DataLoader(dataset.train, batch_size=train_batch_size, shuffle=True)
            test_loader = torch.utils.data.DataLoader(dataset.test, batch_size=test_batch_size)
            train_autoencoder(autoencoder, train_loader, test_loader, device=device, log=log)
            encoder = autoencoder.encoder

            # Stage 2: train ProtoTSNet on top of the trained encoder.
            log(f'Training ProtoTSNet', flush=True, display=True)
            trainer = train_prototsnet(
                dataset,
                curr_experiment_dir,
                device,
                encoder,
                features_lr,
                coeffs,
                protos_per_class,
                proto_features,
                proto_len,
                train_batch_size,
                test_batch_size,
                num_epochs=epochs,
                num_warm_epochs=num_warm_epochs,
                push_start_epoch=push_start_epoch,
                push_epochs=push_epochs,
                ds_info=ds_info,
                num_last_layer_epochs=num_last_layer_epochs,
                custom_checkpointers=[
                    get_verbose_logger(dataset.name),
                ],
                log=log,
            )

            accu_test = trainer.latest_stat("accu_test")
            log(f'Last epoch test accu: {accu_test*100:.2f}%', display=True)
            with open(curr_experiment_dir / "test_accu.json", "w") as f:
                json.dump({"value": accu_test}, f, indent=4)

            # Confusion matrix of the final model on the test split, accumulated
            # batch by batch. This is pure inference: switch to eval mode and
            # disable autograd so no computation graph is built per batch.
            # (assumes trainer.ptsnet is a torch.nn.Module - TODO confirm)
            ptsnet = trainer.ptsnet
            ptsnet.eval()
            confusion_matrix = torch.zeros(ptsnet.num_classes, ptsnet.num_classes)
            with torch.no_grad():
                for image, label in test_loader:
                    output, _ = ptsnet(image.to(device))
                    confusion_matrix += multiclass_confusion_matrix(output.to('cpu'), label, num_classes=output.shape[1])
            np.savetxt(curr_experiment_dir / 'confusion_matrix.txt', confusion_matrix.numpy(), fmt='%4d')

            whole_training_end = time.time()
            log(f"Done in {trainer.curr_epoch - 1} epochs, {whole_training_end - whole_training_start:.2f}s", display=True)
        except Exception as e:
            # Best effort: log the failure and carry on with the next dataset/run
            # instead of aborting the whole loop.
            log(f"Exception ocurred for {ds_name}: {e}", display=True)
            tb_str = traceback.format_tb(e.__traceback__)
            log('\n'.join(tb_str), display=True)
        finally:
            logclose()

Training for ArticularyWordRecognition, proto len 144, reception 0.25, features_lr 0.001, protos per class 10, l1_addon 3e-05
Params: {
    "protos_per_class": 10,
    "proto_features": 32,
    "proto_len_latent": 144,
    "features_lr": 0.001,
    "num_classes": 25,
    "coeffs": {
        "crs_ent": 1,
        "clst": 0,
        "sep": 0,
        "l1": 0.001,
        "l1_addon": 3e-05
    },
    "num_warm_epochs": 50,
    "push_start_epoch": 110,
    "num_last_layer_epochs": 40,
    "epochs": 200
}
Training encoder
epoch:   10/300 mse loss: 0.0437
epoch:   20/300 mse loss: 0.0360
epoch:   30/300 mse loss: 0.0391
epoch:   40/300 mse loss: 0.0538
epoch:   50/300 mse loss: 0.0527
epoch:   60/300 mse loss: 0.0597
epoch:   70/300 mse loss: 0.0694
epoch:   80/300 mse loss: 0.0751
epoch:   90/300 mse loss: 0.0763
epoch:  100/300 mse loss: 0.0742
epoch:  110/300 mse loss: 0.0753
epoch:  120/300 mse loss: 0.0751
epoch:  130/300 mse loss: 0.0750
epoch:  140/300 mse loss: 0.0729
epoch:  150/300

In [18]:
# Prototype experiments: train on the datasets listed in `selected_dses` with few
# prototypes per class, using the per-dataset values from best_params.csv.
experiment_name = "PrototypesTests"

selected_dses = ["Libras"]

# Per-dataset hyperparameters, indexed by dataset name; the columns read below
# are 'proto_len', 'reception' and 'epochs'.
best_params = pd.read_csv('best_params.csv', index_col=0)

# Values shared by every run. "reception" and "proto_len" are not read below:
# both are taken from best_params.
default_params = {
    "coeffs": ProtoTSCoeffs(crs_ent=1, clst=0, sep=0, l1=1e-3, l1_addon=1e-4),
    "reception": 0.25,
    "proto_len": 5,
    "protos_per_class": 3,
    "proto_features": 32,
    "features_lr": 1e-4,
    "push_start_epoch": 110,
    "num_last_layer_epochs": 40,
}

for ds_name, dataset in all_ds.items():
    if ds_name not in selected_dses:
        continue

    ds_info = ds_get_info(dataset.name)

    proto_len = int(best_params.loc[ds_name, 'proto_len'])
    reception = float(best_params.loc[ds_name, 'reception'])
    epochs = int(best_params.loc[ds_name, 'epochs'])
    curr_experiment_dir = experiment_setup(f"{experiment_name}/{dataset.name}")

    log, logclose = create_logger(curr_experiment_dir / "log.txt", display=True)

    try:
        # Keep `current` and `curr_log.txt` symlinks pointing at the run in
        # progress.
        curr_link_path = Path.cwd() / 'experiments' / experiment_name / 'current'
        if os.path.islink(curr_link_path):
            os.unlink(curr_link_path)
        os.symlink(curr_experiment_dir, curr_link_path)

        curr_log_link_path = Path.cwd() / 'experiments' / experiment_name / 'curr_log.txt'
        if os.path.islink(curr_log_link_path):
            os.unlink(curr_log_link_path)
        os.symlink(curr_experiment_dir / 'log.txt', curr_log_link_path)

        features_lr = default_params["features_lr"]

        protos_per_class = default_params["protos_per_class"]
        proto_features = default_params["proto_features"]
        # Halve the batch size until it is at most half the training set size.
        train_batch_size = 32
        while train_batch_size > len(dataset.train.X) / 2:
            train_batch_size //= 2
        test_batch_size = 128
        # EigenWorms gets smaller batches than the other datasets.
        if ds_name == 'EigenWorms':
            train_batch_size //= 2
            test_batch_size = 64
        coeffs = default_params["coeffs"]
        padding = 'same'

        push_start_epoch = default_params["push_start_epoch"]
        num_warm_epochs = push_start_epoch - 60
        num_last_layer_epochs = default_params["num_last_layer_epochs"]
        push_epochs = range(push_start_epoch, 1000, 30)

        # Written to params.json and to the log. The original literal listed
        # "protos_per_class" twice; the duplicate is removed, which leaves the
        # resulting dict unchanged.
        params = {
            "protos_per_class": protos_per_class,
            "proto_features": proto_features,
            "proto_len_latent": proto_len,
            "features_lr": features_lr,
            "num_classes": ds_info.num_classes,
            "coeffs": coeffs._asdict(),
            "num_warm_epochs": num_warm_epochs,
            "push_start_epoch": push_start_epoch,
            "num_last_layer_epochs": num_last_layer_epochs,
            "epochs": epochs,
        }
        with open(curr_experiment_dir / "params.json", "w") as f:
            json.dump(params, f, indent=4)

        log(
            f"Training for {dataset.name}, proto len {proto_len}, reception {reception}, features_lr {features_lr}, protos per class {protos_per_class}, l1_addon {coeffs.l1_addon}",
            flush=True,
            display=True
        )
        log(f'Params: {json.dumps(params, indent=4)}')

        whole_training_start = time.time()

        # Stage 1: train the autoencoder; only its encoder is reused. Note that
        # the test split is what gets passed to train_autoencoder as the second
        # loader here.
        log(f'Training encoder', flush=True, display=True)
        autoencoder = PermutingConvAutoencoder(num_features=ds_info.features, latent_features=proto_features, reception_percent=reception, padding=padding)
        # NOTE(review): train_ds is not used anywhere in this cell.
        train_ds = TSCDataset(dataset.train.X, dataset.train.y)
        train_loader = torch.utils.data.DataLoader(dataset.train, batch_size=train_batch_size, shuffle=True)
        test_loader = torch.utils.data.DataLoader(dataset.test, batch_size=test_batch_size)
        train_autoencoder(autoencoder, train_loader, test_loader, device=device, log=log)
        encoder = autoencoder.encoder

        # Stage 2: train ProtoTSNet on top of the trained encoder.
        log(f'Training ProtoTSNet', flush=True, display=True)
        trainer = train_prototsnet(
            dataset,
            curr_experiment_dir,
            device,
            encoder,
            features_lr,
            coeffs,
            protos_per_class,
            proto_features,
            proto_len,
            train_batch_size,
            test_batch_size,
            num_epochs=epochs,
            num_warm_epochs=num_warm_epochs,
            push_start_epoch=push_start_epoch,
            push_epochs=push_epochs,
            ds_info=ds_info,
            num_last_layer_epochs=num_last_layer_epochs,
            custom_checkpointers=[
                get_verbose_logger(dataset.name)
            ],
            log=log,
        )

        accu_test = trainer.latest_stat("accu_test")
        log(f'Last epoch test accu: {accu_test*100:.2f}%', display=True)
        with open(curr_experiment_dir / "test_accu.json", "w") as f:
            json.dump({"value": accu_test}, f, indent=4)

        whole_training_end = time.time()
        log(f"Done in {trainer.curr_epoch - 1} epochs, {whole_training_end - whole_training_start:.2f}s", display=True)
    except Exception as e:
        # Best effort: log the failure and carry on with the next dataset
        # instead of aborting the whole loop.
        log(f"Exception ocurred for {ds_name}: {e}", display=True)
        tb_str = traceback.format_tb(e.__traceback__)
        log('\n'.join(tb_str), display=True)
    finally:
        logclose()

Training for Libras, proto len 13, reception 0.9, features_lr 0.0001, protos per class 3, l1_addon 0.0001
Params: {
    "protos_per_class": 3,
    "proto_features": 32,
    "proto_len_latent": 13,
    "features_lr": 0.0001,
    "num_classes": 15,
    "coeffs": {
        "crs_ent": 1,
        "clst": 0,
        "sep": 0,
        "l1": 0.001,
        "l1_addon": 0.0001
    },
    "num_warm_epochs": 50,
    "push_start_epoch": 110,
    "num_last_layer_epochs": 40,
    "epochs": 200
}
Training encoder
epoch:   10/300 mse loss: 0.0065
epoch:   20/300 mse loss: 0.0016
epoch:   30/300 mse loss: 0.0008
epoch:   40/300 mse loss: 0.0009
epoch:   50/300 mse loss: 0.0007
epoch:   60/300 mse loss: 0.0006
epoch:   70/300 mse loss: 0.0007
epoch:   80/300 mse loss: 0.0007
epoch:   90/300 mse loss: 0.0009
epoch:  100/300 mse loss: 0.0010
epoch:  110/300 mse loss: 0.0006
epoch:  120/300 mse loss: 0.0007
epoch:  130/300 mse loss: 0.0007
epoch:  140/300 mse loss: 0.0007
epoch:  150/300 mse loss: 0.0007
ep

In [6]:
# Artificial train/test datasets built with identical settings; only the first
# positional argument differs (1000 for train, 300 for test). The train set is
# constructed first, as before.
artifTrainDS, artifTestDS = (
    ArtificialProtos(first_arg, feature_noise_power=0.05, randomize_right_side=False)
    for first_arg in (1000, 300)
)

In [19]:
# Single training run on the artificial dataset (artifTrainDS / artifTestDS).
# Unlike the loops over real datasets, a failure here is logged and re-raised.
experiment_name = 'ArtificialDataset'
experiment_dir = experiment_setup(experiment_name)

default_params = {
    "coeffs": ProtoTSCoeffs(crs_ent=1, clst=0, sep=0, l1=1e-3, l1_addon=3e-4),
    "reception": 0.75,
    "proto_len": 20,
    "protos_per_class": 1,
    "proto_features": 32,
    "features_lr": 1e-3,
    "push_start_epoch": 110,
    "num_last_layer_epochs": 40,
}

dataset = TrainTestDS('ArtificialDataset', artifTrainDS, artifTestDS)
# Dataset description is given by hand here rather than looked up via ds_get_info.
ds_info = DSInfo(dataset.name, features=3, ts_len=100, num_classes=4)

log, logclose = create_logger(experiment_dir / "log.txt", display=True)

try:
    features_lr = default_params["features_lr"]

    proto_len = default_params["proto_len"]
    reception = default_params["reception"]
    epochs = 200

    protos_per_class = default_params["protos_per_class"]
    proto_features = default_params["proto_features"]
    train_batch_size = 32
    test_batch_size = 128
    coeffs = default_params["coeffs"]
    padding = 'same'

    push_start_epoch = default_params["push_start_epoch"]
    num_warm_epochs = push_start_epoch - 60
    num_last_layer_epochs = default_params["num_last_layer_epochs"]
    push_epochs = range(push_start_epoch, 1000, 30)

    # Written to params.json and to the log. The original literal listed
    # "protos_per_class" twice; the duplicate is removed, which leaves the
    # resulting dict unchanged.
    params = {
        "protos_per_class": protos_per_class,
        "proto_features": proto_features,
        "proto_len_latent": proto_len,
        "features_lr": features_lr,
        "num_classes": ds_info.num_classes,
        "coeffs": coeffs._asdict(),
        "num_warm_epochs": num_warm_epochs,
        "push_start_epoch": push_start_epoch,
        "num_last_layer_epochs": num_last_layer_epochs,
        "epochs": epochs,
    }
    with open(experiment_dir / "params.json", "w") as f:
        json.dump(params, f, indent=4)

    log(
        f"Training for {dataset.name}, proto len {proto_len}, reception {reception}, features_lr {features_lr}, protos per class {protos_per_class}, l1_addon {coeffs.l1_addon}",
        flush=True,
        display=True
    )
    log(f'Params: {json.dumps(params, indent=4)}')

    whole_training_start = time.time()

    # Stage 1: train the autoencoder; only its encoder is reused. Note that the
    # test split is what gets passed to train_autoencoder as the second loader.
    log(f'Training encoder', flush=True, display=True)
    autoencoder = PermutingConvAutoencoder(num_features=ds_info.features, latent_features=proto_features, reception_percent=reception, padding=padding)
    train_loader = torch.utils.data.DataLoader(dataset.train, batch_size=train_batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(dataset.test, batch_size=test_batch_size)
    train_autoencoder(autoencoder, train_loader, test_loader, device=device, log=log)
    encoder = autoencoder.encoder

    # Stage 2: train ProtoTSNet on top of the trained encoder.
    log(f'Training ProtoTSNet', flush=True, display=True)
    trainer = train_prototsnet(
        dataset,
        experiment_dir,
        device,
        encoder,
        features_lr,
        coeffs,
        protos_per_class,
        proto_features,
        proto_len,
        train_batch_size,
        test_batch_size,
        num_epochs=epochs,
        num_warm_epochs=num_warm_epochs,
        push_start_epoch=push_start_epoch,
        push_epochs=push_epochs,
        ds_info=ds_info,
        num_last_layer_epochs=num_last_layer_epochs,
        custom_checkpointers=[
            get_verbose_logger(dataset.name)
        ],
        log=log,
    )

    accu_test = trainer.latest_stat("accu_test")
    log(f'Last epoch test accu: {accu_test*100:.2f}%', display=True)
    with open(experiment_dir / "test_accu.json", "w") as f:
        json.dump({"value": accu_test}, f, indent=4)

    whole_training_end = time.time()
    log(f"Done in {trainer.curr_epoch - 1} epochs, {whole_training_end - whole_training_start:.2f}s", display=True)
except Exception as e:
    # Record the failure in the experiment log, then let it surface in the notebook.
    log(f"Exception ocurred for {ds_info.name}: {e}", display=True)
    tb_str = traceback.format_tb(e.__traceback__)
    log('\n'.join(tb_str), display=True)
    raise
finally:
    logclose()

Training for ArtificialDataset, proto len 20, reception 0.75, features_lr 0.001, protos per class 1, l1_addon 0.0003
Params: {
    "protos_per_class": 1,
    "proto_features": 32,
    "proto_len_latent": 20,
    "features_lr": 0.001,
    "num_classes": 4,
    "coeffs": {
        "crs_ent": 1,
        "clst": 0,
        "sep": 0,
        "l1": 0.001,
        "l1_addon": 0.0003
    },
    "num_warm_epochs": 50,
    "push_start_epoch": 110,
    "num_last_layer_epochs": 40,
    "epochs": 200
}
Training encoder
epoch:   10/300 mse loss: 0.0104
epoch:   20/300 mse loss: 0.0164
epoch:   30/300 mse loss: 0.0157
epoch:   40/300 mse loss: 0.0163
epoch:   50/300 mse loss: 0.0161
epoch:   60/300 mse loss: 0.0136
epoch:   70/300 mse loss: 0.0111
epoch:   80/300 mse loss: 0.0150
epoch:   90/300 mse loss: 0.0146
epoch:  100/300 mse loss: 0.0102
epoch:  110/300 mse loss: 0.0129
epoch:  120/300 mse loss: 0.0138
epoch:  130/300 mse loss: 0.0146
epoch:  140/300 mse loss: 0.0133
epoch:  150/300 mse loss: 