In [1]:
import sys
from pathlib import Path

def find_project_root(start_path: "Path | None" = None, marker: str = 'pyproject.toml') -> Path:
    """Return the closest directory, at or above ``start_path``, that contains ``marker``.

    Args:
        start_path: Where the upward search starts. ``None`` (the default) means
            the current working directory at the time of the call.
        marker: Name of the file or directory that identifies the project root.

    Returns:
        The resolved path of the first directory (starting with ``start_path``
        itself, then each parent in turn) that contains ``marker``.

    Raises:
        FileNotFoundError: If neither ``start_path`` nor any of its parents
            contains ``marker``.
    """
    # Resolve the default here rather than in the signature: a `Path.cwd()`
    # default is evaluated once, when the function is defined, and would go
    # stale if the working directory changes afterwards.
    if start_path is None:
        start_path = Path.cwd()
    current_path = start_path.resolve()
    for candidate in (current_path, *current_path.parents):
        if (candidate / marker).exists():
            return candidate
    # Fail loudly instead of falling off the end and returning None, which
    # would violate the declared return type.
    raise FileNotFoundError(
        f"Could not find {marker!r} in {current_path} or any of its parent directories"
    )
        
def add_project_root_to_sys_path(marker: str = 'pyproject.toml'):
    """Put the project root at the front of ``sys.path`` if it is not already listed.

    The root is located with ``find_project_root(marker=marker)``.

    Args:
        marker: Name of the file or directory that identifies the project root.

    Raises:
        FileNotFoundError: If no project root could be determined.
    """
    project_root = find_project_root(marker=marker)
    # Guard against a missing root: str(None) would silently put the literal
    # entry 'None' on sys.path and the later project imports would fail with a
    # much less helpful error.
    if project_root is None:
        raise FileNotFoundError(
            f"Could not locate a project root containing {marker!r}"
        )
    root_entry = str(project_root)
    if root_entry not in sys.path:
        # Index 0 so the project's modules take precedence over other entries.
        sys.path.insert(0, root_entry)

# Run once at notebook start so the `src.asym_ensembles...` imports in the next cell can resolve.
add_project_root_to_sys_path()


# Imports

In [2]:
import torch
import torch.nn as nn
import wandb
from tqdm import tqdm
from torch.utils.data import DataLoader
from src.asym_ensembles.data_loaders import load_dataset
from src.asym_ensembles.modeling.training import (
    set_global_seed,
    train_one_model,
    evaluate_model,
    evaluate_ensemble,
    average_pairwise_distance
)
from src.asym_ensembles.modeling.models import MLP, WMLP
import numpy as np
import copy

# Config

In [3]:
# Hyperparameters and sweep settings for the experiment; handed to wandb.init as the run config.
cfg = dict(
    # Optimisation settings.
    batch_size=64,
    max_epochs=200,
    patience=16,
    learning_rate=1e-3,
    weight_decay=3e-2,
    # Sweep settings.
    hidden_dims=[64, 128, 256],
    ensemble_sizes=[2, 4, 8, 16, 32, 64],
    total_models=64,  # kept equal to max(ensemble_sizes)
    repeats=10,  # each repetition uses different seeds
    mask_type="random_subsets",
    base_seed=1234,
    # Use the GPU when one is available, otherwise fall back to the CPU.
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
)

In [4]:
# Start the Weights & Biases run with `cfg` as its config.
wandb.init(
    project="DeepEnsembleProject",
    name="Extended_Experiments",
    config=cfg,
    settings=wandb.Settings(start_method="fork"),
)
# Later cells read hyperparameters through this handle rather than from `cfg` directly.
config = wandb.config

[34m[1mwandb[0m: Currently logged in as: [33mnovitsk-oleg[0m ([33moanovi[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


# Loading Data and Running Experiments

In [13]:
# Quick-run switch. When True, only the first entry of config.hidden_dims and
# the first repetition are executed for each dataset (the dataset loop itself
# is never cut short). Set to False to sweep every hidden size and run all
# config.repeats repetitions.
SINGLE_RUN_PER_DATASET = True

# (dataset name understood by load_dataset, task type) pairs to sweep over.
all_datasets = [
    ("california", "regression"),
    ("otto", "classification"),
    ("telcom", "classification"),
    ("mnist", "classification"),
]

# Ensembles are built by slicing the first `ens_size` trained models, so the
# pool has to be at least as large as the biggest requested ensemble; otherwise
# a smaller ensemble would be logged under a larger size's key.
if config.total_models < max(config.ensemble_sizes):
    raise ValueError(
        f"total_models ({config.total_models}) must be >= "
        f"max(ensemble_sizes) ({max(config.ensemble_sizes)})"
    )


def _fit_and_score(model, train_loader, val_loader, test_loader, criterion, task_type):
    """Train one freshly built model and score it on the test loader.

    The caller is responsible for seeding and for constructing ``model``; this
    helper creates the AdamW optimizer, runs ``train_one_model``, moves the
    result to the CPU, takes a deep copy, and only then calls ``evaluate_model``.

    Returns:
        ``(trained_model, cpu_snapshot, test_metric)`` where ``cpu_snapshot`` is
        the deep copy taken before evaluation.
    """
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=config.learning_rate, weight_decay=config.weight_decay
    )
    # train_one_model also returns the training time and the loss histories;
    # they are not used in this notebook.
    model, _train_time, _train_losses, _val_losses = train_one_model(
        model, train_loader, val_loader, criterion, optimizer,
        device=config.device, max_epochs=config.max_epochs, patience=config.patience
    )
    model.to("cpu")
    cpu_snapshot = copy.deepcopy(model)
    test_metric = evaluate_model(model, test_loader, criterion, config.device, task_type=task_type)
    return model, cpu_snapshot, test_metric


for dataset_name, task_type in all_datasets:
    train_ds, val_ds, test_ds = load_dataset(dataset_name=dataset_name)
    train_loader = DataLoader(train_ds, batch_size=config.batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=config.batch_size, shuffle=False)
    test_loader = DataLoader(test_ds, batch_size=config.batch_size, shuffle=False)

    # Input/output sizes and the loss depend only on the dataset and task type,
    # so they are computed once per dataset instead of once per hidden size.
    in_dim = train_ds.tensors[0].shape[1]
    if task_type == "regression":
        out_dim = 1
        criterion = nn.MSELoss()
    else:
        # One output per distinct target value found in the training split.
        out_dim = len(torch.unique(train_ds.tensors[1]))
        criterion = nn.CrossEntropyLoss()

    for hidden_dim in config.hidden_dims:
        print(f"\nDataset: {dataset_name}, Hidden_dim: {hidden_dim}")

        # Mask settings keyed 0..3: entry 0 always uses num_fixed=2, entries
        # 1-3 share a value that depends on the hidden size.
        second_nfix = 3 if hidden_dim in [64, 128] else 4
        fixed_per_entry = [2, second_nfix, second_nfix, second_nfix]
        mask_params = {
            entry: {
                'mask_constant': 1,
                'mask_type': config.mask_type,
                'do_normal_mask': True,
                'num_fixed': num_fixed,
            }
            for entry, num_fixed in enumerate(fixed_per_entry)
        }

        for rep_i in range(config.repeats):
            print(f"\nRepetition {rep_i + 1}/{config.repeats}")
            # Repetitions are spaced 10000 apart; within one repetition the MLPs
            # use offsets [0, total_models) and the WMLPs use 2000 + the same
            # offsets.
            current_seed = config.base_seed + rep_i * 10000

            mlp_metrics = []
            wmlp_metrics = []
            wmlp_masked_ratios = []
            mlp_models = []
            wmlp_models = []

            for i in tqdm(range(config.total_models), desc="Training MLP"):
                # Seed before building the model, as in every other member.
                set_global_seed(current_seed + i)
                mlp = MLP(in_dim, hidden_dim, out_dim, num_layers=4, norm=None)
                _, mlp_snapshot, test_metric = _fit_and_score(
                    mlp, train_loader, val_loader, test_loader, criterion, task_type
                )
                mlp_models.append(mlp_snapshot)
                mlp_metrics.append(test_metric)

            for i in tqdm(range(config.total_models), desc="Training WMLP"):
                set_global_seed(current_seed + 2000 + i)
                wmlp = WMLP(in_dim, hidden_dim, out_dim, num_layers=4, mask_params=mask_params, norm=None)
                wmlp, wmlp_snapshot, test_metric_wmlp = _fit_and_score(
                    wmlp, train_loader, val_loader, test_loader, criterion, task_type
                )
                wmlp_models.append(wmlp_snapshot)
                wmlp_metrics.append(test_metric_wmlp)

                # Only the ratio is logged; the second returned value is unused.
                ratio, _masked = wmlp.report_masked_ratio()
                wmlp_masked_ratios.append(ratio)

            avg_dist_mlp = average_pairwise_distance(mlp_models)
            avg_dist_wmlp = average_pairwise_distance(wmlp_models)

            avg_wmlp_masked_ratio = float(np.mean(wmlp_masked_ratios)) if wmlp_masked_ratios else 0.0

            # Each ensemble is the first `ens_size` models of the pool, so the
            # ensembles of different sizes are nested.
            ensemble_results_mlp = {}
            ensemble_results_wmlp = {}
            for ens_size in config.ensemble_sizes:
                ensemble_results_mlp[ens_size] = evaluate_ensemble(
                    mlp_models[:ens_size], test_loader, config.device, task_type=task_type
                )
                ensemble_results_wmlp[ens_size] = evaluate_ensemble(
                    wmlp_models[:ens_size], test_loader, config.device, task_type=task_type
                )

            log_dict = {
                "dataset_name": dataset_name,
                "task_type": task_type,
                "hidden_dim": hidden_dim,
                "repeat_index": rep_i + 1,
                "avg_dist_mlp": avg_dist_mlp,
                "avg_dist_wmlp": avg_dist_wmlp,
                "avg_wmlp_masked_ratio": avg_wmlp_masked_ratio,
                "mean_mlp_metric": float(np.mean(mlp_metrics)),
                "mean_wmlp_metric": float(np.mean(wmlp_metrics)),
            }
            for ens_size in config.ensemble_sizes:
                log_dict[f"mlp_ens_{ens_size}"] = ensemble_results_mlp[ens_size]
                log_dict[f"wmlp_ens_{ens_size}"] = ensemble_results_wmlp[ens_size]

            wandb.log(log_dict)

            print(f"Repetition {rep_i + 1}/{config.repeats} finished.")
            if SINGLE_RUN_PER_DATASET:
                break  # quick run: skip the remaining repetitions
        if SINGLE_RUN_PER_DATASET:
            break  # quick run: skip the remaining hidden sizes


Dataset: california, Hidden_dim: 64

Repetition 1/10


Epoch:  50%|█████     | 101/200 [00:21<00:20,  4.73it/s]
Training MLP:   2%|▏         | 1/64 [00:22<23:15, 22.15s/it]

Early stopping at epoch 102


Epoch:  60%|█████▉    | 119/200 [00:25<00:17,  4.72it/s]
Training MLP:   3%|▎         | 2/64 [00:47<24:46, 23.98s/it]

Early stopping at epoch 120


Epoch:  56%|█████▌    | 111/200 [00:24<00:20,  4.45it/s]
Training MLP:   5%|▍         | 3/64 [01:12<24:50, 24.44s/it]

Early stopping at epoch 112


Epoch:  51%|█████     | 102/200 [00:21<00:20,  4.75it/s]
Training MLP:   6%|▋         | 4/64 [01:33<23:16, 23.28s/it]

Early stopping at epoch 103


Epoch:  66%|██████▌   | 131/200 [00:27<00:14,  4.78it/s]
Training MLP:   8%|▊         | 5/64 [02:01<24:21, 24.78s/it]

Early stopping at epoch 132


Epoch:  50%|█████     | 100/200 [00:21<00:21,  4.55it/s]
Training MLP:   9%|▉         | 6/64 [02:23<23:02, 23.84s/it]

Early stopping at epoch 101


Epoch:  59%|█████▉    | 118/200 [00:25<00:17,  4.71it/s]
Training MLP:  11%|█         | 7/64 [02:48<23:02, 24.25s/it]

Early stopping at epoch 119


Epoch:  62%|██████▎   | 125/200 [00:27<00:16,  4.62it/s]
Training MLP:  12%|█▎        | 8/64 [03:15<23:28, 25.15s/it]

Early stopping at epoch 126


Epoch:  64%|██████▎   | 127/200 [00:27<00:15,  4.59it/s]
Training MLP:  14%|█▍        | 9/64 [03:43<23:46, 25.94s/it]

Early stopping at epoch 128


Epoch:  31%|███       | 62/200 [00:13<00:29,  4.70it/s]
Training MLP:  16%|█▌        | 10/64 [03:56<19:48, 22.01s/it]

Early stopping at epoch 63


Epoch:  66%|██████▌   | 132/200 [00:27<00:14,  4.82it/s]
Training MLP:  17%|█▋        | 11/64 [04:23<20:54, 23.67s/it]

Early stopping at epoch 133


Epoch:  41%|████      | 82/200 [00:17<00:25,  4.68it/s]
Training MLP:  19%|█▉        | 12/64 [04:41<18:54, 21.81s/it]

Early stopping at epoch 83


Epoch:  63%|██████▎   | 126/200 [00:26<00:15,  4.71it/s]
Training MLP:  20%|██        | 13/64 [05:08<19:48, 23.31s/it]

Early stopping at epoch 127


Epoch:  58%|█████▊    | 116/200 [00:26<00:18,  4.45it/s]
Training MLP:  22%|██▏       | 14/64 [05:34<20:07, 24.15s/it]

Early stopping at epoch 117


Epoch:  72%|███████▏  | 143/200 [00:30<00:12,  4.62it/s]
Training MLP:  23%|██▎       | 15/64 [06:05<21:23, 26.20s/it]

Early stopping at epoch 144


Epoch:  84%|████████▍ | 169/200 [00:35<00:06,  4.77it/s]
Training MLP:  25%|██▌       | 16/64 [06:40<23:11, 28.98s/it]

Early stopping at epoch 170


Epoch:  48%|████▊     | 96/200 [00:21<00:23,  4.42it/s]
Training MLP:  27%|██▋       | 17/64 [07:02<21:00, 26.81s/it]

Early stopping at epoch 97


Epoch:  67%|██████▋   | 134/200 [00:28<00:13,  4.75it/s]
Training MLP:  28%|██▊       | 18/64 [07:30<20:53, 27.25s/it]

Early stopping at epoch 135


Epoch:  64%|██████▍   | 128/200 [00:26<00:14,  4.82it/s]
Training MLP:  30%|██▉       | 19/64 [07:57<20:16, 27.04s/it]

Early stopping at epoch 129


Epoch:  51%|█████     | 102/200 [00:21<00:20,  4.77it/s]
Training MLP:  31%|███▏      | 20/64 [08:18<18:35, 25.35s/it]

Early stopping at epoch 103


Epoch:  57%|█████▋    | 114/200 [00:23<00:17,  4.88it/s]
Training MLP:  33%|███▎      | 21/64 [08:42<17:44, 24.76s/it]

Early stopping at epoch 115


Epoch:  36%|███▌      | 71/200 [00:15<00:28,  4.60it/s]
Training MLP:  34%|███▍      | 22/64 [08:57<15:22, 21.96s/it]

Early stopping at epoch 72


Epoch:  68%|██████▊   | 137/200 [00:28<00:12,  4.87it/s]
Training MLP:  36%|███▌      | 23/64 [09:25<16:16, 23.82s/it]

Early stopping at epoch 138


Epoch:  54%|█████▎    | 107/200 [00:22<00:19,  4.65it/s]
Training MLP:  38%|███▊      | 24/64 [09:48<15:43, 23.58s/it]

Early stopping at epoch 108


Epoch:  76%|███████▌  | 151/200 [00:31<00:10,  4.73it/s]
Training MLP:  39%|███▉      | 25/64 [10:20<16:57, 26.10s/it]

Early stopping at epoch 152


Epoch:  61%|██████    | 122/200 [00:25<00:16,  4.84it/s]
Training MLP:  41%|████      | 26/64 [10:45<16:21, 25.84s/it]

Early stopping at epoch 123


Epoch:  62%|██████▏   | 123/200 [00:24<00:15,  4.98it/s]
Training MLP:  42%|████▏     | 27/64 [11:10<15:43, 25.51s/it]

Early stopping at epoch 124


Epoch:  49%|████▉     | 98/200 [00:20<00:21,  4.67it/s]
Training MLP:  44%|████▍     | 28/64 [11:31<14:29, 24.16s/it]

Early stopping at epoch 99


Epoch:  55%|█████▌    | 110/200 [00:24<00:19,  4.52it/s]
Training MLP:  45%|████▌     | 29/64 [11:55<14:07, 24.22s/it]

Early stopping at epoch 111


Epoch:  56%|█████▌    | 111/200 [00:23<00:18,  4.75it/s]
Training MLP:  47%|████▋     | 30/64 [12:19<13:35, 23.98s/it]

Early stopping at epoch 112


Epoch:  68%|██████▊   | 136/200 [00:30<00:14,  4.51it/s]
Training MLP:  48%|████▊     | 31/64 [12:49<14:12, 25.84s/it]

Early stopping at epoch 137


Epoch:  52%|█████▏    | 103/200 [00:22<00:20,  4.64it/s]
Training MLP:  50%|█████     | 32/64 [13:11<13:11, 24.75s/it]

Early stopping at epoch 104


Epoch:  54%|█████▍    | 108/200 [00:23<00:19,  4.66it/s]
Training MLP:  52%|█████▏    | 33/64 [13:34<12:32, 24.29s/it]

Early stopping at epoch 109


Epoch:  59%|█████▉    | 118/200 [00:24<00:16,  4.87it/s]
Training MLP:  53%|█████▎    | 34/64 [13:59<12:08, 24.28s/it]

Early stopping at epoch 119


Epoch:  52%|█████▎    | 105/200 [00:23<00:21,  4.51it/s]
Training MLP:  55%|█████▍    | 35/64 [14:22<11:35, 23.99s/it]

Early stopping at epoch 106


Epoch:  65%|██████▌   | 130/200 [00:27<00:14,  4.68it/s]
Training MLP:  56%|█████▋    | 36/64 [14:50<11:43, 25.14s/it]

Early stopping at epoch 131


Epoch:  76%|███████▌  | 152/200 [00:31<00:09,  4.89it/s]
Training MLP:  58%|█████▊    | 37/64 [15:21<12:07, 26.93s/it]

Early stopping at epoch 153


Epoch:  50%|████▉     | 99/200 [00:20<00:21,  4.75it/s]
Training MLP:  59%|█████▉    | 38/64 [15:42<10:53, 25.12s/it]

Early stopping at epoch 100


Epoch:  50%|█████     | 101/200 [00:20<00:20,  4.86it/s]
Training MLP:  61%|██████    | 39/64 [16:03<09:55, 23.82s/it]

Early stopping at epoch 102


Epoch:  39%|███▉      | 78/200 [00:16<00:25,  4.85it/s]
Training MLP:  62%|██████▎   | 40/64 [16:19<08:36, 21.51s/it]

Early stopping at epoch 79


Epoch:  50%|█████     | 100/200 [00:21<00:21,  4.67it/s]
Training MLP:  64%|██████▍   | 41/64 [16:40<08:14, 21.48s/it]

Early stopping at epoch 101


Epoch:  55%|█████▍    | 109/200 [00:22<00:18,  4.80it/s]
Training MLP:  66%|██████▌   | 42/64 [17:03<08:00, 21.85s/it]

Early stopping at epoch 110


Epoch:  46%|████▌     | 92/200 [00:19<00:23,  4.64it/s]
Training MLP:  67%|██████▋   | 43/64 [17:23<07:26, 21.26s/it]

Early stopping at epoch 93


Epoch:  66%|██████▋   | 133/200 [00:27<00:13,  4.89it/s]
Training MLP:  69%|██████▉   | 44/64 [17:50<07:40, 23.05s/it]

Early stopping at epoch 134


Epoch:  69%|██████▉   | 138/200 [00:27<00:12,  4.96it/s]
Training MLP:  70%|███████   | 45/64 [18:18<07:45, 24.49s/it]

Early stopping at epoch 139


Epoch:  61%|██████    | 122/200 [00:25<00:16,  4.72it/s]
Training MLP:  72%|███████▏  | 46/64 [18:44<07:28, 24.90s/it]

Early stopping at epoch 123


Epoch:  63%|██████▎   | 126/200 [00:27<00:15,  4.63it/s]
Training MLP:  73%|███████▎  | 47/64 [19:11<07:15, 25.60s/it]

Early stopping at epoch 127


Epoch:  80%|███████▉  | 159/200 [00:32<00:08,  4.83it/s]
Training MLP:  75%|███████▌  | 48/64 [19:44<07:24, 27.80s/it]

Early stopping at epoch 160


Epoch:  65%|██████▌   | 130/200 [00:26<00:14,  4.86it/s]
Training MLP:  77%|███████▋  | 49/64 [20:11<06:52, 27.49s/it]

Early stopping at epoch 131


Epoch:  60%|█████▉    | 119/200 [00:23<00:16,  4.97it/s]
Training MLP:  78%|███████▊  | 50/64 [20:35<06:10, 26.45s/it]

Early stopping at epoch 120


Epoch:  53%|█████▎    | 106/200 [00:22<00:19,  4.71it/s]
Training MLP:  80%|███████▉  | 51/64 [20:57<05:28, 25.27s/it]

Early stopping at epoch 107


Epoch:  54%|█████▎    | 107/200 [00:23<00:20,  4.46it/s]
Training MLP:  81%|████████▏ | 52/64 [21:21<04:58, 24.90s/it]

Early stopping at epoch 108


Epoch:  61%|██████    | 122/200 [00:27<00:17,  4.49it/s]
Training MLP:  83%|████████▎ | 53/64 [21:48<04:41, 25.60s/it]

Early stopping at epoch 123




In [None]:
# End the W&B run that was opened with wandb.init earlier in the notebook.
wandb.finish()