In [1]:
%load_ext autoreload
%autoreload 2
import sys
import torch
from torch import nn
sys.path.append('..')
# sys.path.append('/system/user/beck/pwbeck/projects/regularization/ml_utilities')
from pathlib import Path
from typing import Union
from ml_utilities.torch_models.base_model import BaseModel
from ml_utilities.torch_models.fc import FC
from ml_utilities.torch_models import get_model_class
from omegaconf import OmegaConf
from ml_utilities import utils as ml_util
from erank.utils import load_directions_matrix_from_task_sweep
import matplotlib.pyplot as plt
from hydra.utils import get_original_cwd
import torchvision
from torch.utils import data
from torchvision import transforms
from erank.data import get_dataset_provider
from erank.data.data_utils import random_split_train_tasks
# NOTE(review): presumably the CUDA device index; it is not referenced by any cell visible here
# (the loaded config also carries experiment_data.gpu_id = 0) — confirm before relying on it.
gpu_id = 0

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Debug experiment config; the path is relative to the kernel's working directory.
config_dir = "./../configs/config_debug_local.yaml"


In [3]:
def random_split_train_tasks_debug(dataset: data.Dataset, num_train_tasks: int = 1, train_task_idx: int = 0,
                                   train_val_split: float = 0.8, seed: int = 0,
                                   num_subsplit_tasks: int = 0, subsplit_first_n_train_tasks: int = 0, **kwargs):
    """Split a dataset sample-wise into disjoint training tasks plus one shared validation set.

    The first ``num_train_tasks`` splits each get ``int(train_val_split * len(dataset)) // num_train_tasks``
    samples. All remaining samples, including rounding leftovers, form the validation split, which is
    shared by every task. Optionally, the first ``subsplit_first_n_train_tasks`` train splits are pooled
    and re-split into ``num_subsplit_tasks`` tasks.

    Args:
        dataset (data.Dataset): The dataset to split.
        num_train_tasks (int, optional): Number of training tasks to split. Defaults to 1.
        train_task_idx (int, optional): Index of the current training task. It is only validated against
            the number of resulting train tasks and does not influence the returned splits. Defaults to 0.
        train_val_split (float, optional): Fraction of samples assigned to the train tasks. Defaults to 0.8.
        seed (int, optional): Seed for the main split. The subsplit uses ``seed + 1``. Defaults to 0.
        num_subsplit_tasks (int, optional): Number of tasks that the pooled leading train tasks are re-split
            into. 0 disables subsplitting. Defaults to 0.
        subsplit_first_n_train_tasks (int, optional): How many leading train tasks are pooled for
            subsplitting. Must lie in ``[1, num_train_tasks]`` when ``num_subsplit_tasks > 0``; it is
            ignored otherwise. Defaults to 0.
        **kwargs: Additional keyword arguments are accepted and ignored.

    Returns:
        List[data.Dataset]: ``[subsplit tasks...] + [remaining train tasks...] + [val set]``. Without
        subsplitting this is ``num_train_tasks`` train splits followed by the val set. The validation
        set is always the last element.

    Raises:
        AssertionError: If ``train_task_idx`` is out of range or the subsplit configuration is invalid.
    """
    subsplitting = num_subsplit_tasks > 0
    if subsplitting:
        # pooling zero tasks (ConcatDataset([]) fails) or more tasks than exist is a config error
        assert 0 < subsplit_first_n_train_tasks <= num_train_tasks, \
            'subsplit_first_n_train_tasks must be in [1, num_train_tasks] when num_subsplit_tasks > 0.'
        n_resulting_train_tasks = num_train_tasks - subsplit_first_n_train_tasks + num_subsplit_tasks
    else:
        # without subsplitting every one of the `num_train_tasks` train splits is kept as-is
        n_resulting_train_tasks = num_train_tasks
    assert 0 <= train_task_idx < n_resulting_train_tasks, 'Invalid train_task_idx given.'

    n_train_samples = int(train_val_split * len(dataset))
    n_samples_per_task = n_train_samples // num_train_tasks
    train_split_lengths = num_train_tasks * [n_samples_per_task]

    # the lengths must sum to len(dataset), so the validation split absorbs all rounding leftovers
    # and can be larger than `train_val_split` implies
    n_val_samples = len(dataset) - sum(train_split_lengths)
    data_splits = data.random_split(dataset, train_split_lengths + [n_val_samples],
                                    generator=torch.Generator().manual_seed(seed))

    if subsplitting:
        # pool the first `subsplit_first_n_train_tasks` train splits and re-split them
        subsplit_dataset = data.ConcatDataset(data_splits[:subsplit_first_n_train_tasks])
        remaining_splits = data_splits[subsplit_first_n_train_tasks:]

        base_len, n_extra = divmod(len(subsplit_dataset), num_subsplit_tasks)
        # hand out the rounding remainder one sample at a time, starting with the first subsplit
        subsplit_lengths = [base_len + (1 if i < n_extra else 0) for i in range(num_subsplit_tasks)]

        data_subsplits = data.random_split(subsplit_dataset, subsplit_lengths,
                                           generator=torch.Generator().manual_seed(seed + 1))
        # order: [subsplit sets] + remaining train sets + val set
        data_splits = data_subsplits + remaining_splits
    return data_splits


In [4]:
# Load the debug config and keep only its `config` node (holds experiment_data, wandb, model, trainer, ... and `data`).
cfg = ml_util.get_config(config_dir).config
# display the loaded config (interpolations such as ${config.model.out_channels} are shown unresolved)
cfg

{'experiment_data': {'project_name': 'erank_supervised', 'experiment_name': 'f_mnist-erank-DEBUG', 'experiment_dir': None, 'seed': 0, 'gpu_id': 0}, 'wandb': {'tags': ['DEBUG'], 'notes': 'Trying different things.', 'watch': {'log': 'all', 'log_freq': 100}}, 'model': {'name': 'cnn2d', 'out_channels': 128, 'model_kwargs': {'image_size': 28, 'input_channels': 1, 'act_fn': 'relu', 'layer_configs': [{'out_channels': '${config.model.out_channels}', 'kernel_size': 3, 'batch_norm': True, 'stride': 1, 'padding': 0, 'max_pool_kernel_size': 2}, {'out_channels': '${config.model.out_channels}', 'kernel_size': 3, 'batch_norm': True, 'stride': 1, 'padding': 0, 'max_pool_kernel_size': 2}, {'out_channels': '${config.model.out_channels}', 'kernel_size': 3, 'batch_norm': True, 'stride': 1, 'padding': 0, 'max_pool_kernel_size': 2}], 'linear_output_units': [10]}}, 'trainer': {'n_epochs': 300, 'val_every': 1, 'save_every': 50, 'early_stopping_patience': 20, 'batch_size': 512, 'optimizer_scheduler': {'optimiz

In [5]:
# Build the full training dataset named in the config; the cells below split it into tasks
# via `**data_cfg.dataset_split`.
data_cfg = cfg.data
provide_dataset = get_dataset_provider(dataset_name=data_cfg.dataset)
train_dataset = provide_dataset(data_cfg.dataset_kwargs)

Files already downloaded and verified


In [9]:
# Reference implementation (erank.data.data_utils.random_split_train_tasks) on the config's split settings;
# print the number of samples in every returned split.
data_splits = random_split_train_tasks(train_dataset, **data_cfg.dataset_split)
for ds in data_splits:
    print(len(ds))

100
10012


In [10]:
# Same call against the local debug copy defined above, to compare split sizes with the reference.
# NOTE(review): the config's dataset_split passes `restrict_n_samples_train_task`; the debug function
# accepts it via **kwargs but never reads it, so these sizes are not directly comparable to the cell above.
# NOTE(review): the saved TypeError output cannot come from the current definition (it accepts **kwargs);
# it looks stale — re-run this cell.
data_splits = random_split_train_tasks_debug(train_dataset, **data_cfg.dataset_split)
for ds in data_splits:
    print(len(ds))

TypeError: random_split_train_tasks_debug() got an unexpected keyword argument 'restrict_n_samples_train_task'