In [1]:
import sys
from pathlib import Path
from copy import deepcopy
def find_project_root(start_path: "Path | None" = None, marker: str = 'pyproject.toml') -> Path:
    """Return the nearest directory at or above ``start_path`` that contains ``marker``.

    Args:
        start_path: Directory to start searching from. When omitted, the
            current working directory is looked up at call time (a
            ``Path.cwd()`` default in the signature would be frozen at the
            moment the function is defined).
        marker: Name of the file or directory that identifies the project root.

    Returns:
        The resolved path of the first directory, walking upwards from
        ``start_path``, in which ``marker`` exists.

    Raises:
        FileNotFoundError: If neither ``start_path`` nor any of its parents
            contains ``marker``.
    """
    if start_path is None:
        start_path = Path.cwd()
    current_path = Path(start_path).resolve()
    # Check the starting directory first, then each ancestor up to the filesystem root.
    for candidate in (current_path, *current_path.parents):
        if (candidate / marker).exists():
            return candidate
    # Fail loudly instead of implicitly returning None, which callers would
    # otherwise stringify and use as if it were a real path.
    raise FileNotFoundError(
        f"Could not find {marker!r} in {current_path} or any of its parent directories"
    )
        
def add_project_root_to_sys_path(marker: str = 'pyproject.toml'):
    """Put the project root at the front of ``sys.path`` unless it is already listed.

    The root is whatever ``find_project_root(marker=marker)`` returns; its
    string form is the entry that gets inserted.
    """
    root_entry = str(find_project_root(marker=marker))
    if root_entry in sys.path:
        # Already importable from the project root; nothing to do.
        return
    sys.path.insert(0, root_entry)

# Run before the `src.*` imports in the next cell so that the directory
# containing pyproject.toml is on sys.path when they execute.
add_project_root_to_sys_path()


In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, random_split
import numpy as np
import copy

from src.asym_ensembles.data_loaders import load_california_housing
from src.asym_ensembles.modeling.training import set_global_seed, train_one_model
from src.asym_ensembles.modeling.models import WMLP

In [3]:
# Load the three dataset splits provided by the project loader.
train_ds, val_ds, test_ds = load_california_housing()

# Train on a small random subset. This cell runs before any global seed is
# set, so the split gets its own seeded generator; without it the subset
# would differ on every fresh kernel.
subset_size = 2000
split_seed = 123
split_generator = torch.Generator().manual_seed(split_seed)
train_ds_small, _ = random_split(
    train_ds,
    [subset_size, len(train_ds) - subset_size],
    generator=split_generator,
)

batch_size = 64
train_loader = DataLoader(train_ds_small, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

# Model dimensions. The input width is read from the second axis of the
# first tensor stored in the training dataset.
in_dim = train_ds.tensors[0].shape[1]
hidden_dim = 16
out_dim = 1
num_layers = 4
num_layers = 4

In [4]:
seed_value = 123

# One settings dict per layer index (0..3). Every layer uses the same mask
# options; only "num_fixed" varies: 2 for layer 0 and 3 for layers 1-3.
mask_params = {
    layer_idx: {
        "mask_constant": 1,
        "mask_type": "random_subsets",
        "do_normal_mask": True,
        "num_fixed": num_fixed,
    }
    for layer_idx, num_fixed in enumerate((2, 3, 3, 3))
}

# Seed first, then build the first model.
set_global_seed(seed_value)
wmlp1 = WMLP(in_dim, hidden_dim, out_dim, num_layers, mask_params, norm=None)



In [5]:
# Second model: same constructor arguments (including mask_params) as wmlp1,
# but built after re-seeding with seed_value + 1.
set_global_seed(seed_value + 1)
wmlp2 = WMLP(in_dim, hidden_dim, out_dim, num_layers, mask_params, norm=None)

In [6]:
# Independent deep copies of each model's first layer, taken before the
# training cell runs, so later changes to wmlp1 / wmlp2 do not affect them.
first_layer10 = deepcopy(wmlp1.lins[0])  # SparseLinear
first_layer20 = deepcopy(wmlp2.lins[0])

In [7]:
# Element-wise equality of the two pre-training first-layer masks
# (model seeded with seed_value vs. model seeded with seed_value + 1).
# The output recorded below is all True, i.e. the two masks are identical.
first_layer10.mask == first_layer20.mask

tensor([[True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True]])

In [8]:
# Compare, between the two pre-training snapshots, the weight entries kept by
# (1 - mask), i.e. the positions where mask == 0
# (assumes mask holds only 0/1 values -- TODO confirm against SparseLinear).
# This is exact element-wise `==`, not a tolerance-based comparison.
first_layer10.weight * (1 - first_layer10.mask) == first_layer20.weight * (1 - first_layer20.mask)

tensor([[True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True]])

In [9]:
# Same comparison for the entries kept by mask itself (positions where mask == 1,
# assuming a 0/1 mask -- TODO confirm). Under that assumption, positions with
# mask == 0 are zeroed on both sides and therefore show True for finite weights.
first_layer10.weight * first_layer10.mask == first_layer20.weight * first_layer20.mask

tensor([[False,  True, False, False, False, False, False,  True],
        [False,  True, False,  True, False, False, False, False],
        [ True, False, False, False, False, False,  True, False],
        [False, False,  True, False, False, False, False,  True],
        [False, False, False,  True, False,  True, False, False],
        [False,  True, False, False, False, False, False,  True],
        [ True, False, False, False, False, False, False,  True],
        [False, False, False,  True, False, False, False,  True],
        [False,  True, False, False, False, False,  True, False],
        [ True, False, False,  True, False, False, False, False],
        [False, False, False, False, False,  True, False,  True],
        [False, False, False, False,  True, False,  True, False],
        [ True, False, False, False,  True, False, False, False],
        [False,  True, False, False, False, False, False,  True],
        [False, False,  True, False,  True, False, False, False],
        [F

In [10]:
# Mean-squared-error loss; train_for_epochs (defined in this cell) reads this
# notebook-level name.
criterion = nn.MSELoss()

# One AdamW optimizer per model with identical hyper-parameters;
# weight_decay=0.0 means no weight decay is applied.
optimizer1 = torch.optim.AdamW(wmlp1.parameters(), lr=1e-3, weight_decay=0.0)
optimizer2 = torch.optim.AdamW(wmlp2.parameters(), lr=1e-3, weight_decay=0.0)

def train_for_epochs(model, optimizer, train_loader, epochs=5, device="cpu", loss_fn=None):
    """Train ``model`` in place for ``epochs`` passes over ``train_loader``.

    Args:
        model: Module to optimize; it is moved to ``device`` and updated in place.
        optimizer: Optimizer already bound to ``model``'s parameters.
        train_loader: Iterable yielding ``(inputs, targets)`` batches.
        epochs: Number of full passes over ``train_loader``.
        device: Device the model and each batch are moved to. Defaults to
            ``"cpu"``, the value that was previously hard-coded.
        loss_fn: Callable ``loss_fn(predictions, targets)`` returning a scalar
            loss. When omitted, the notebook-level ``criterion`` is used, which
            matches the previous behavior.

    Returns:
        The same ``model`` object that was passed in.
    """
    if loss_fn is None:
        # Backward-compatible default: fall back to the loss defined at notebook level.
        loss_fn = criterion
    model.to(device)
    for _ in range(epochs):
        # Re-enter training mode at the start of every epoch.
        model.train()
        for Xb, yb in train_loader:
            Xb, yb = Xb.to(device), yb.to(device)
            optimizer.zero_grad()
            preds = model(Xb)
            loss = loss_fn(preds, yb)
            loss.backward()
            optimizer.step()
    return model

# Train both models in place on the same loader, each with its own optimizer.
# The second call is the cell's last expression, so the model it returns is
# what the cell displays.
train_for_epochs(wmlp1, optimizer1, train_loader, epochs=5)
train_for_epochs(wmlp2, optimizer2, train_loader, epochs=5)

WMLP(
  (lins): ModuleList(
    (0-3): 4 x SparseLinear()
  )
  (activation): GELU(approximate='none')
  (flatten): Flatten(start_dim=1, end_dim=-1)
)

In [11]:
# Direct references (not copies) to the first layer of each model, taken after
# the training cell; compare with the first_layer10 / first_layer20 deep copies
# made before training.
first_layer1 = wmlp1.lins[0]  # SparseLinear
first_layer2 = wmlp2.lins[0]

In [12]:
# Element-wise equality of the two first-layer masks after training.
# The output recorded below is all True.
first_layer1.mask == first_layer2.mask

tensor([[True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True]])

In [13]:
# After training: compare, between the two models, the weight entries kept by
# (1 - mask), i.e. positions where mask == 0 (assuming a 0/1 mask -- TODO confirm).
# Exact element-wise `==`; the output recorded below is all True.
first_layer1.weight * (1 - first_layer1.mask) == first_layer2.weight * (1 - first_layer2.mask)

tensor([[True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True]])

In [14]:
# After training: compare, between the two models, the weight entries kept by
# mask itself (positions where mask == 1, assuming a 0/1 mask -- TODO confirm).
first_layer1.weight * first_layer1.mask == first_layer2.weight * first_layer2.mask

tensor([[False,  True, False, False, False, False, False,  True],
        [False,  True, False,  True, False, False, False, False],
        [ True, False, False, False, False, False,  True, False],
        [False, False,  True, False, False, False, False,  True],
        [False, False, False,  True, False,  True, False, False],
        [False,  True, False, False, False, False, False,  True],
        [ True, False, False, False, False, False, False,  True],
        [False, False, False,  True, False, False, False,  True],
        [False,  True, False, False, False, False,  True, False],
        [ True, False, False,  True, False, False, False, False],
        [False, False, False, False, False,  True, False,  True],
        [False, False, False, False,  True, False,  True, False],
        [ True, False, False, False,  True, False, False, False],
        [False,  True, False, False, False, False, False,  True],
        [False, False,  True, False,  True, False, False, False],
        [F

In [15]:
# Model 1, trained layer vs. its pre-training deep copy: entries kept by
# (1 - mask), i.e. positions where mask == 0 (assuming a 0/1 mask -- TODO confirm).
# The output recorded below is all True, so these products are element-wise
# identical before and after training.
first_layer1.weight * (1 - first_layer1.mask) == first_layer10.weight * (1 - first_layer10.mask)

tensor([[True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True]])

In [16]:
# Model 1, trained layer vs. its pre-training deep copy: entries kept by mask
# (positions where mask == 1, assuming a 0/1 mask -- TODO confirm).
# False marks positions whose product changed during training.
first_layer1.weight * (first_layer1.mask) == first_layer10.weight * (first_layer10.mask)

tensor([[False,  True, False, False, False, False, False,  True],
        [False,  True, False,  True, False, False, False, False],
        [ True, False, False, False, False, False,  True, False],
        [False, False,  True, False, False, False, False,  True],
        [False, False, False,  True, False,  True, False, False],
        [False,  True, False, False, False, False, False,  True],
        [ True, False, False, False, False, False, False,  True],
        [False, False, False,  True, False, False, False,  True],
        [False,  True, False, False, False, False,  True, False],
        [ True, False, False,  True, False, False, False, False],
        [False, False, False, False, False,  True, False,  True],
        [False, False, False, False,  True, False,  True, False],
        [ True, False, False, False,  True, False, False, False],
        [False,  True, False, False, False, False, False,  True],
        [False, False,  True, False,  True, False, False, False],
        [F

In [17]:
# Model 2, trained layer vs. its pre-training deep copy: entries kept by
# (1 - mask), i.e. positions where mask == 0 (assuming a 0/1 mask -- TODO confirm).
# The output recorded below is all True, so these products are element-wise
# identical before and after training.
first_layer2.weight * (1 - first_layer2.mask) == first_layer20.weight * (1 - first_layer20.mask)

tensor([[True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True],
        [True, True, True, True, True, True, True, True]])

In [18]:
# Model 2, trained layer vs. its pre-training deep copy: entries kept by mask
# (positions where mask == 1, assuming a 0/1 mask -- TODO confirm).
# False marks positions whose product changed during training.
first_layer2.weight * (first_layer2.mask) == first_layer20.weight * (first_layer20.mask)

tensor([[False,  True, False, False, False, False, False,  True],
        [False,  True, False,  True, False, False, False, False],
        [ True, False, False, False, False, False,  True, False],
        [False, False,  True, False, False, False, False,  True],
        [False, False, False,  True, False,  True, False, False],
        [False,  True, False, False, False, False, False,  True],
        [ True, False, False, False, False, False, False,  True],
        [False, False, False,  True, False, False, False,  True],
        [False,  True, False, False, False, False,  True, False],
        [ True, False, False,  True, False, False, False, False],
        [False, False, False, False, False,  True, False,  True],
        [False, False, False, False,  True, False,  True, False],
        [ True, False, False, False,  True, False, False, False],
        [False,  True, False, False, False, False, False,  True],
        [False, False,  True, False,  True, False, False, False],
        [F