# Repeat Experiment 
## import packages

In [None]:
import warnings
import torch

# Install a blanket "ignore" filter so no Python warning is shown for the rest
# of the session (presumably to keep the long training output uncluttered).
warnings.filterwarnings("ignore")

In [None]:
from tcnn.utils.experiment.repeat import repeat_experiment

In [None]:
from os.path import join

# Dataset identifier; in the visible notebook it is only used to build `experiment_name`.
dataset_name = "speechcommand"
# All tunable settings for this run, grouped by concern. The values are unpacked
# into module-level constants in the next cell.
config = {
    # Free-form run identifier; becomes part of `experiment_name`.
    "order": "03",
    # Read into TASK, which is not referenced again in the visible notebook.
    "task": "multiclass",
    "data": {
        "batch_size": 32,
        # Becomes part of `experiment_name`. Presumably the target sample rate;
        # note that reset_dataloader hard-codes 8000 separately — TODO confirm.
        "size": 8000,
    },
    "network": {
        # input_channels, num_classes and linear_size are forwarded to
        # lenet_models.get_model_dict; their exact meaning is defined there.
        "input_channels": 1,
        "linear_size": 31680,
        "num_classes": 35,
        # NOTE(review): the two kernel sizes are not read anywhere in the visible notebook.
        "first_layer_kernel_size": 80,
        "second_layer_kernel_size": 3,
    },
    "train": {
        "criterion": torch.nn.CrossEntropyLoss(),
        # Root directory for checkpoints; a per-experiment subdirectory is appended later.
        "checkpoint_save_dir": "checkpoints",
        # Read into EPOCHS, which is not referenced again in the visible notebook;
        # the training loop uses the "repeat" epoch count instead.
        "epochs": 100,
        # Use the GPU when one is available, otherwise fall back to the CPU.
        "device": "cuda" if torch.cuda.is_available() else "cpu",
    },
    # NOTE: the key "epochs_per_experiemnt" is misspelled, but the same spelling is
    # used where it is read, so both places must be changed together.
    "repeat": {"num_experiments": 5, "epochs_per_experiemnt": 100, "log_save_dir": "logs"},
}

In [None]:
# Unpack the config into module-level constants used by the cells below.
# NOTE(review): TASK, EPOCHS and DEVICE are not referenced anywhere else in the
# visible notebook.
ORDER = config["order"]
SIZE = config["data"]["size"]
TASK = config["task"]
NUM_CLASSES = config["network"]["num_classes"]
EPOCHS = config["train"]["epochs"]
BATCH_SIZE = config["data"]["batch_size"]
INPUT_CHANNEL = config["network"]["input_channels"]
LINEAR_SIZE = config["network"]["linear_size"]
CRITERION = config["train"]["criterion"]
DEVICE = config["train"]["device"]
NUM_EXPERIMENTS = config["repeat"]["num_experiments"]
# The key is spelled "experiemnt" in the config dict; the spelling here must match it.
EPOCHS_PER_EXPERIMENT = config["repeat"]["epochs_per_experiemnt"]
# Identifier for this run: repeat-<order>-<dataset>-<size>-<num experiments>-<epochs each>.
# It names the checkpoint subdirectory and the saved results.
experiment_name = (
    f"repeat-{ORDER}-{dataset_name}-{SIZE}-{NUM_EXPERIMENTS}-{EPOCHS_PER_EXPERIMENT}"
)

In [None]:
# Per-experiment subdirectory under the configured checkpoint root; passed to
# repeat_experiment as `checkpoint_save_dir`.
checkpoint_save_dir = join(config["train"]["checkpoint_save_dir"], experiment_name)
# NOTE(review): join() with a single argument returns that argument unchanged, and
# `log_save_dir` is not referenced again in the visible notebook — confirm whether
# it was meant to include `experiment_name` and be handed to the logging helpers.
log_save_dir = join(config["repeat"]["log_save_dir"])

## get dataset

In [None]:
import torchaudio

# Build the Speech Commands training and testing subsets rooted at ./data/.
# `download=True` asks torchaudio to fetch the archive when it is not already
# present under the root directory.
train_dataset = torchaudio.datasets.SPEECHCOMMANDS(
    "./data/",
    download=True,
    subset="training",
)
test_dataset = torchaudio.datasets.SPEECHCOMMANDS(
    "./data/",
    download=True,
    subset="testing",
)

## reset dataloader

In [None]:
def label_to_index(word, labels):
    """
    Look up the position of a label word in the label list.

    Args:
        word (str): The label word to look up.
        labels (list): The ordered list of all label words.

    Returns:
        torch.Tensor: A zero-dimensional tensor holding the position of
            ``word`` within ``labels``.

    Raises:
        ValueError: If ``word`` does not occur in ``labels``.

    """
    position = labels.index(word)
    return torch.tensor(position)


def pad_sequence(batch):
    """
    Zero-pad a list of 2-D tensors to a common length and stack them.

    Each item is transposed so that its last axis comes first, the items are
    padded with zeros along that axis up to the longest one, and the two
    trailing axes of the stacked result are swapped back.

    Args:
        batch (list): 2-D tensors that agree on their first dimension and may
            differ on their last one (presumably ``(channels, time)``
            waveforms; confirm against the caller).

    Returns:
        torch.Tensor: A 3-D tensor of shape
            ``(len(batch), first_dim, longest_last_dim)``.

    """
    length_first = [sequence.t() for sequence in batch]
    stacked = torch.nn.utils.rnn.pad_sequence(
        length_first,
        batch_first=True,
        padding_value=0.0,
    )
    # Undo the per-item transpose: move the padded axis back to the end.
    return stacked.permute(0, 2, 1)


def collate_fn_outside(transform, labels):
    """
    Build a collate function bound to a waveform transform and a label list.

    Args:
        transform (callable): Applied to every waveform before batching.
        labels (list): Label words; a sample's target is the position of its
            label word in this list.

    Returns:
        callable: A function that turns a list of samples into a
            ``(tensors, targets)`` pair.

    """

    def collate_fn_inside(batch):
        """
        Transform, pad and label-encode one batch of samples.

        Args:
            batch (list): Samples whose first element is the waveform and
                whose third element is the label word; all other elements
                are ignored.

        Returns:
            tuple: The zero-padded waveform batch and the stacked label
                indices, in that order.

        """
        waveforms = []
        label_indices = []

        for sample in batch:
            raw_waveform, _, word, *_ = sample
            waveforms.append(transform(raw_waveform))
            label_indices.append(label_to_index(word, labels))

        return pad_sequence(waveforms), torch.stack(label_indices)

    return collate_fn_inside


def reset_dataloader(dataset, batch_size, shuffle):
    """
    Build a DataLoader that resamples, pads and label-encodes a dataset.

    Args:
        dataset: Indexable and iterable dataset whose items are tuples with
            the waveform first, the sample rate second and the label word
            third.
        batch_size (int): Number of samples per batch.
        shuffle (bool): Forwarded to the DataLoader's ``shuffle`` argument.

    Returns:
        torch.utils.data.DataLoader: Loader yielding ``(tensors, targets)``
            batches produced by the collate function.

    """
    # The source sample rate is taken from the first item only; all items are
    # assumed to share it.
    _, sample_rate, *_ = dataset[0]
    new_sample_rate = 8000
    transform = torchaudio.transforms.Resample(
        orig_freq=sample_rate, new_freq=new_sample_rate
    )

    # Sorted so that label indices are deterministic.
    # NOTE: this walks the whole dataset once, and the vocabulary is derived
    # from *this* dataset, so train and test loaders only agree on indices if
    # both subsets contain the same set of labels.
    labels = sorted({datapoint[2] for datapoint in dataset})
    collate_fn = collate_fn_outside(transform, labels)

    return torch.utils.data.DataLoader(
        dataset,
        # Honour the caller's batch size instead of the module-level BATCH_SIZE.
        batch_size=batch_size,
        shuffle=shuffle,
        collate_fn=collate_fn,
    )

## import models

In [None]:
import lenet_models

# Candidate models for the comparison. The result is iterated further down with
# `.items()` as (model name, model) pairs; the models themselves are defined in
# the project-local `lenet_models` module.
model_dict = lenet_models.get_model_dict(INPUT_CHANNEL, NUM_CLASSES, LINEAR_SIZE)

In [None]:

def def_optimizer(model):
    """
    Create an Adam optimizer over the model's parameters.

    Args:
        model: Object exposing ``parameters()`` (e.g. a ``torch.nn.Module``).

    Returns:
        torch.optim.Adam: Optimizer with learning rate 1e-3 and weight
            decay 1e-4.

    """
    trainable = model.parameters()
    return torch.optim.Adam(trainable, weight_decay=1e-4, lr=1e-3)

def def_scheduler(optimizer):
    """
    Create a step learning-rate scheduler for the given optimizer.

    Args:
        optimizer (torch.optim.Optimizer): Optimizer whose learning rate is
            scheduled.

    Returns:
        torch.optim.lr_scheduler.StepLR: Scheduler that multiplies the
            learning rate by 0.1 every 10 scheduler steps.

    """
    step_lr = torch.optim.lr_scheduler.StepLR
    return step_lr(optimizer, gamma=0.1, step_size=10)

## define result dict

In [None]:
import torch

# Maps each model name to the value returned by repeat_experiment for it.
results = {}

## repeat experiment

In [None]:
# Run the repeated-training protocol once per candidate model and store each
# outcome under the model's name. The positional arguments follow the signature
# of repeat_experiment, which lives in tcnn.utils.experiment.repeat and is not
# visible here.
for model_name, model in model_dict.items():
    print("Traing model: ", model_name)
    results[model_name] = repeat_experiment(
        model,
        [train_dataset, test_dataset],
        reset_dataloader,
        BATCH_SIZE,
        CRITERION,
        def_optimizer,
        def_scheduler,
        NUM_EXPERIMENTS,
        EPOCHS_PER_EXPERIMENT,
        experiment_name + "-" + model_name,
        checkpoint_save=True,
        checkpoint_save_dir=checkpoint_save_dir,
    )
    # Release cached, unused GPU memory before moving on to the next model.
    # empty_cache() performs no autograd-tracked computation, so it does not
    # need to be wrapped in torch.no_grad().
    torch.cuda.empty_cache()
    print("***" * 10 + f"{model_name} done" + "***" * 10)

In [None]:
from tcnn.utils.experiment.log import show_repeat_result, save_result

# Report the collected results, then persist them under the experiment's name.
# Both helpers come from tcnn.utils.experiment.log and are not visible here, so
# the output format and save location are defined there — verify before relying
# on them.
show_repeat_result(results)
save_result(results, save_name=experiment_name)

In [None]:
# from tcnn.utils.experiment.plot import plot_experiment_errorbar

# plot_experiment_errorbar(
#     results, metric_key="accuracy", baseline_key="lenet", ylabel="Accuracy"
# )

In [None]:
# plot_experiment_errorbar(
#     results, metric_key="accuracy", baseline_key="lenet_relu", ylabel="Accuracy"
# )

In [None]:
# plot_experiment_errorbar(
#     results, metric_key="auc_score", baseline_key="lenet", ylabel="AUC Score"
# )

In [None]:
# plot_experiment_errorbar(
#     results, metric_key="auc_score", baseline_key="lenet_relu", ylabel="AUC Score"
# )

In [None]:
# plot_experiment_errorbar(
#     results, metric_key="f1", baseline_key="lenet", ylabel="F1 Score"
# )

In [None]:
# plot_experiment_errorbar(
#     results, metric_key="f1", baseline_key="lenet_relu", ylabel="F1 Score"
# )

In [None]:
import torch

# Final cleanup: release cached, unused GPU memory held by PyTorch's allocator.
# empty_cache() performs no autograd-tracked computation, so it does not need
# to be wrapped in torch.no_grad().
torch.cuda.empty_cache()