# TODO:
1. Прокинуть random seed в саму генерацию масок, иначе на каждом слое будет генерироваться одна и та же маска. Или это нормально?
2. Привести логирование в wandb в порядок: настроить, чтобы графики нормально отображались и чтобы сохранялось всё, в т. ч. модели.

In [1]:
import sys
from pathlib import Path

def find_project_root(start_path: "Path | None" = None, marker: str = 'pyproject.toml') -> Path:
    """Return the nearest directory at or above ``start_path`` that contains ``marker``.

    Args:
        start_path: Directory (``Path`` or ``str``) to start searching from.
            Defaults to the current working directory, looked up at call time.
        marker: File or directory name that identifies the project root.

    Returns:
        The first directory, walking upward from ``start_path`` (inclusive),
        that contains an entry named ``marker``.

    Raises:
        FileNotFoundError: If neither ``start_path`` nor any of its parents
            contains ``marker``.
    """
    # Resolve the default lazily: a ``Path.cwd()`` default argument would be
    # frozen at definition time and ignore later ``os.chdir`` calls.
    current_path = (Path.cwd() if start_path is None else Path(start_path)).resolve()
    for parent in (current_path, *current_path.parents):
        if (parent / marker).exists():
            return parent
    # Fail loudly rather than returning None, which callers would otherwise
    # stringify into a bogus path such as "None".
    raise FileNotFoundError(
        f"No '{marker}' found in {current_path} or any of its parent directories"
    )
        
def add_project_root_to_sys_path(marker: str = 'pyproject.toml'):
    """Prepend the project root to ``sys.path`` so ``src.*`` imports resolve.

    The root is the nearest ancestor of the current working directory that
    contains ``marker``. Nothing is inserted if it is already on ``sys.path``.

    Args:
        marker: File or directory name that identifies the project root.

    Raises:
        FileNotFoundError: If no project root can be located.
    """
    project_root = find_project_root(marker=marker)
    if project_root is None:
        # A lookup that returns None instead of raising would otherwise make
        # us insert the literal string "None" into sys.path.
        raise FileNotFoundError(f"Could not locate the project root (no '{marker}' found)")
    root_str = str(project_root)
    if root_str not in sys.path:
        sys.path.insert(0, root_str)

add_project_root_to_sys_path()


# Imports

In [2]:
import torch
import torch.nn as nn
import wandb
from tqdm import tqdm
from torch.utils.data import DataLoader
from src.asym_ensembles.data_loaders import load_california_housing, load_wine_quality, load_covertype_dataset
from src.asym_ensembles.modeling.training import (
    set_global_seed,
    train_one_model,
    evaluate_model,
    evaluate_ensemble
)
from src.asym_ensembles.modeling.models import MLP, WMLP
import pickle
import io
import tempfile

# Config

In [3]:

# Experiment configuration: passed to wandb.init(config=...) and read by the
# data-loading, training and ensembling cells.
_ensemble_sizes = [2, 4, 8, 16, 32, 64]
_num_layers = 4

config = {
    "task_type": "classification",  # "regression" or "classification"
    "batch_size": 64,
    "epochs": 25,
    "lr": 1e-3,
    "weight_decay": 3e-2,
    "hidden_dim": 64,
    "num_layers": _num_layers,
    "ensemble_sizes": _ensemble_sizes,
    # Train exactly as many models as the largest ensemble needs; derived so it
    # cannot drift out of sync with "ensemble_sizes".
    "total_models": max(_ensemble_sizes),
    "device": "cuda" if torch.cuda.is_available() else "cpu",
    # WMLP mask settings keyed by layer index (presumably one entry per layer is
    # expected by WMLP — confirm). Layer 0 uses num_fixed=2, the rest use 4;
    # author's note: the first layer is the most influential.
    "mask_params": {
        layer: {
            "mask_constant": 1,
            "mask_type": "random_subsets",
            "do_normal_mask": True,
            "num_fixed": 2 if layer == 0 else 4,
        }
        for layer in range(_num_layers)
    },
}

In [4]:
# Open the wandb run for this experiment and attach the full config to it.
wandb.init(
    project="DeepEnsembleProject",
    name="DeepEnsembles",
    config=config,
)

[34m[1mwandb[0m: Currently logged in as: [33mnovitsk-oleg[0m ([33moanovi[0m). Use [1m`wandb login --relogin`[0m to force relogin
[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.


# Loading Data

In [5]:
# Load train/val/test splits for the configured task and pick the matching loss.
if config["task_type"] == "regression":
    train_ds, val_ds, test_ds = load_california_housing()
    criterion = nn.MSELoss()
elif config["task_type"] == "classification":
    train_ds, val_ds, test_ds = load_covertype_dataset()
    criterion = nn.CrossEntropyLoss()
else:
    # Fail fast on a typo instead of silently falling back to classification.
    raise ValueError(
        f"Unknown task_type {config['task_type']!r}; "
        "expected 'regression' or 'classification'"
    )

# Only the training loader shuffles; val/test keep a fixed order.
train_loader = DataLoader(train_ds, batch_size=config["batch_size"], shuffle=True)
val_loader = DataLoader(val_ds, batch_size=config["batch_size"], shuffle=False)
test_loader = DataLoader(test_ds, batch_size=config["batch_size"], shuffle=False)

# NOTE(review): assumes the loaders return TensorDatasets of (features, targets)
# with 2-D features — confirm against src.asym_ensembles.data_loaders.
in_dim = train_ds.tensors[0].shape[1]
if config["task_type"] == "regression":
    out_dim = 1
else:
    # CrossEntropyLoss needs integer class targets in [0, out_dim). Sizing the
    # output as max_label + 1 (instead of counting distinct training labels)
    # stays valid when a class is absent from the training split or labels do
    # not start at 0. It equals the distinct count when labels are exactly 0..K-1.
    out_dim = int(train_ds.tensors[1].max().item()) + 1

# Train base estimators (MLP and WMLP)

In [6]:
# Per-model accumulators, one pair (MLP, WMLP) per quantity: trained models,
# test metrics, and training times.
mlp_models, wmlp_models = [], []
mlp_test_metrics, wmlp_test_metrics = [], []
mlp_times, wmlp_times = [], []

## MLP

In [8]:
# Train config["total_models"] independent MLPs. Before building model i, the
# global seed is set to 1000 + i, so each run is reproducible.
for i in tqdm(range(config["total_models"]), desc="Train MLP"):
    seed_value = 1000 + i
    set_global_seed(seed_value)

    mlp = MLP(in_dim, config["hidden_dim"], out_dim, config["num_layers"], norm=None)
    # Register the freshly built model with wandb (log="all", log_freq=100).
    wandb.watch(mlp, log="all", log_freq=100)

    optimizer = torch.optim.AdamW(
        mlp.parameters(),
        lr=config["lr"],
        weight_decay=config["weight_decay"],
    )
    mlp, train_time, train_losses, val_losses = train_one_model(
        mlp,
        train_loader,
        val_loader,
        criterion,
        optimizer,
        device=config["device"],
        epochs=config["epochs"],
    )
    mlp_models.append(mlp)
    mlp_times.append(train_time)

    # Score this single model on the held-out test split.
    metric_mlp = evaluate_model(
        mlp,
        test_loader,
        criterion,
        config["device"],
        task_type=config["task_type"],
    )
    mlp_test_metrics.append(metric_mlp)

    wandb.log(
        dict(
            model_idx=i,
            MLP_test_metric=metric_mlp,
            MLP_train_time=train_time,
        )
    )

Train MLP:   0%|          | 0/64 [00:00<?, ?it/s]

Train MLP:   2%|▏         | 1/64 [02:47<2:55:25, 167.07s/it]


KeyboardInterrupt: 

## WMLP

In [None]:
# Train config["total_models"] independent masked MLPs (WMLP). Before building
# model i, the global seed is set to 2000 + i. This seed range is disjoint from
# the MLP loop's 1000 + i.
for i in tqdm(range(config["total_models"]), desc="Train WMLP"):
    seed_value_wmlp = 2000 + i
    set_global_seed(seed_value_wmlp)

    wmlp = WMLP(
        in_dim,
        config["hidden_dim"],
        out_dim,
        config["num_layers"],
        config["mask_params"],
        norm=None,
    )
    # Register the freshly built model with wandb (log="all", log_freq=100).
    wandb.watch(wmlp, log="all", log_freq=100)

    optimizer_wmlp = torch.optim.AdamW(
        wmlp.parameters(),
        lr=config["lr"],
        weight_decay=config["weight_decay"],
    )
    wmlp, wmlp_train_time, train_losses_w, val_losses_w = train_one_model(
        wmlp,
        train_loader,
        val_loader,
        criterion,
        optimizer_wmlp,
        device=config["device"],
        epochs=config["epochs"],
    )
    wmlp_models.append(wmlp)
    wmlp_times.append(wmlp_train_time)

    # Score this single model on the held-out test split.
    metric_wmlp = evaluate_model(
        wmlp,
        test_loader,
        criterion,
        config["device"],
        task_type=config["task_type"],
    )
    wmlp_test_metrics.append(metric_wmlp)

    wandb.log(
        dict(
            model_idx=i,
            WMLP_test_metric=metric_wmlp,
            WMLP_train_time=wmlp_train_time,
        )
    )

Train WMLP:   0%|          | 0/64 [00:00<?, ?it/s]

In [9]:
# Summarize per-model results in two wandb tables: test metric and train time.
# Rows are built from the collected result lists themselves, not from
# range(config["total_models"]). A partially completed run (e.g. one interrupted
# mid-training) therefore logs the models it has instead of raising IndexError.
metrics_table = wandb.Table(columns=["model_idx", "MLP_test_metric", "WMLP_test_metric"])
time_table = wandb.Table(columns=["model_idx", "MLP_train_time", "WMLP_train_time"])

for i, (mlp_metric, wmlp_metric, mlp_time, wmlp_time) in enumerate(
    zip(mlp_test_metrics, wmlp_test_metrics, mlp_times, wmlp_times)
):
    metrics_table.add_data(i, mlp_metric, wmlp_metric)
    time_table.add_data(i, mlp_time, wmlp_time)

wandb.log({"test_metrics_table": metrics_table})
wandb.log({"train_times_table": time_table})

# Building ensembles

In [10]:
# Evaluate ensembles of increasing size, built from the first `ensemble_size`
# trained models of each kind. Results are logged both as scalar metrics and as
# a summary table.
import warnings

ensemble_table = wandb.Table(columns=["ensemble_size", "mlp_ens_metric", "wmlp_ens_metric"])

# If training stopped early, slicing would silently return fewer models than
# requested and log a smaller ensemble under the wrong size, so such sizes are skipped.
n_available = min(len(mlp_models), len(wmlp_models))

for ensemble_size in config["ensemble_sizes"]:
    if ensemble_size > n_available:
        warnings.warn(
            f"Skipping ensemble_size={ensemble_size}: only {len(mlp_models)} MLP and "
            f"{len(wmlp_models)} WMLP models are available"
        )
        continue

    # MLP
    mlp_sub = mlp_models[:ensemble_size]
    mlp_ens_metric = evaluate_ensemble(mlp_sub, test_loader, config["device"], config["task_type"])

    # WMLP
    wmlp_sub = wmlp_models[:ensemble_size]
    wmlp_ens_metric = evaluate_ensemble(wmlp_sub, test_loader, config["device"], config["task_type"])

    wandb.log({
        "ensemble_size": ensemble_size,
        "mlp_ens_metric": mlp_ens_metric,
        "wmlp_ens_metric": wmlp_ens_metric
    })

    ensemble_table.add_data(ensemble_size, mlp_ens_metric, wmlp_ens_metric)

wandb.log({"ensemble_metrics_table": ensemble_table})

In [11]:
# Persist the trained models as one wandb artifact. It holds two pickles (MLP and
# WMLP), and each pickle holds a list of state_dicts in training order.


def _cpu_state_dict(model):
    """Return ``model.state_dict()`` with every tensor copied to CPU.

    Pickled CUDA tensors cannot be unpickled on a machine without a GPU, so
    tensors are moved to CPU before saving. The state_dict's ``_metadata``
    attribute is preserved, because ``load_state_dict`` may read it.
    """
    state = model.state_dict()
    cpu_state = type(state)(
        (key, value.detach().cpu() if isinstance(value, torch.Tensor) else value)
        for key, value in state.items()
    )
    metadata = getattr(state, "_metadata", None)
    if metadata is not None:
        cpu_state._metadata = metadata
    return cpu_state


mlp_state_dicts = [_cpu_state_dict(m) for m in mlp_models]
wmlp_state_dicts = [_cpu_state_dict(m) for m in wmlp_models]

# NOTE(review): delete=False leaves these temp files on disk after the run;
# remove them once the artifact upload is confirmed complete.
with tempfile.NamedTemporaryFile(delete=False, suffix=".pkl") as mlp_temp:
    pickle.dump(mlp_state_dicts, mlp_temp)
    mlp_temp_path = mlp_temp.name

with tempfile.NamedTemporaryFile(delete=False, suffix=".pkl") as wmlp_temp:
    pickle.dump(wmlp_state_dicts, wmlp_temp)
    wmlp_temp_path = wmlp_temp.name

artifact = wandb.Artifact("model_lists", type="model_collection")
artifact.add_file(mlp_temp_path, "mlp_models_state_dicts.pkl")
artifact.add_file(wmlp_temp_path, "wmlp_models_state_dicts.pkl")
wandb.run.log_artifact(artifact)

wandb.finish()

0,1
MLP_test_metric,▄▄▄▆▆▆▃▄▄▄▆▆▆▃▆▃▁▄▄█▃▆▄▄▄▆▆▆▆▆▆▃▃▃▄▄▃▃▆▆
MLP_train_time,▂▃▂▂▃▆▄▄▂▆▂▁▆▁▃▂▂▂█▂▂▂▅▂▃▄▁▁▃▂▁▂▁▂▁▁▁▁▆▂
WMLP_test_metric,▇▄▅▅▄▄▅▂▂▇▄▅▅▁▂▂▄▄▇▅▅▅█▄▁▂▅█▇▅▄█▇▂▇▅▂▅▄▅
WMLP_train_time,▄▁▁▁▁▁▂▁▁▂▂▁▁█▂▁▂▂▁▁▃▁▁▃▁▁▃▁▁▁▁▁▁▁▁▂▁▃▁▄
ensemble_size,▁▁▂▃▄█
mlp_ens_metric,▁▁▁▁▁▁
model_idx,▁▂▂▃▃▄▄▄▅▅▅▅▆▆▇█▁▁▁▂▂▂▂▃▃▄▄▄▅▅▆▆▆▆▇▇▇▇██
wmlp_ens_metric,▁▁█▅▅▅

0,1
MLP_test_metric,0.97222
MLP_train_time,0.01651
WMLP_test_metric,0.91667
WMLP_train_time,0.03329
ensemble_size,64.0
mlp_ens_metric,0.97222
model_idx,63.0
wmlp_ens_metric,0.97222
