In [1]:
import sys

# Relative to the kernel's working directory: makes the local `models` and `utils`
# modules importable by the `from models import ...` / `from utils import ...` lines.
PROJECT_ROOT = "../"
sys.path.append(PROJECT_ROOT)

In [2]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
import scipy.sparse as sp
import seaborn as sns
import torch
import wandb

from models import SDCI
from utils import create_intervention_dataset

In [3]:
# Gene-filtered AnnData for the Perturb-CITE-seq (SCP1064) data set.
# The default is a machine-specific absolute path; set the PERTURB_CITE_SEQ_ADATA
# environment variable to point at the file on another machine instead of editing it.
ADATA_PATH = os.environ.get(
    "PERTURB_CITE_SEQ_ADATA",
    "/home/justinhong/dcdfg_preprocess/perturb-cite-seq/SCP1064/ready/control/gene_filtered_adata.h5ad",
)
adata = sc.read(ADATA_PATH)
adata

AnnData object with n_obs × n_vars = 57523 × 1657
    obs: 'library_preparation_protocol', 'condition', 'MOI', 'sgRNA', 'UMI_count', 'sgRNAs', 'n_genes', 'targets', 'regimes'
    var: 'n_cells', 'highly_variable', 'highly_variable_rank', 'means', 'variances', 'variances_norm', 'targeted'
    uns: 'hvg', 'log1p'
    layers: 'counts'

In [4]:
# Build a cells x genes expression frame and attach each cell's perturbation target
# label from `adata.obs["targets"]`, then turn it into the training dataset.
# `adata.X` may be a scipy sparse matrix (as here) or already dense; only densify when
# needed, and use `toarray()` so the result is an ndarray rather than an np.matrix.
expr_matrix = adata.X.toarray() if sp.issparse(adata.X) else np.asarray(adata.X)
X_df = pd.DataFrame(expr_matrix, index=adata.obs_names, columns=adata.var_names)
# Assignment aligns on the index; X_df's index is adata.obs_names, so labels line up per cell.
X_df["perturbation_label"] = adata.obs["targets"]
dataset = create_intervention_dataset(X_df)

In [5]:
def run_sdci(val_fraction=0.1, split_seed=0):
    """Train and evaluate one SDCI model for a single W&B sweep trial.

    Meant to be passed as ``function=`` to ``wandb.agent``, which calls it with no
    arguments. It opens a W&B run, reads the sampled hyperparameters from the run
    config, trains on a train/validation split of the module-level ``dataset``, and
    logs the summary metrics used by the sweep.

    Args:
        val_fraction: Fraction of ``dataset`` held out for validation (0.1 reproduces
            the original 90/10 split).
        split_seed: Seed for the train/validation split. A fixed seed makes every
            trial score ``val_rec_loss`` on the same held-out cells, so differences
            between trials reflect the hyperparameters rather than the split. Pass
            ``None`` to draw a fresh split from torch's global RNG (old behavior).

    Logged metrics:
        val_rec_loss: ``reconstruction_loss`` of the trained model on the validation split.
        min_dag_threshold: value returned by ``model.compute_min_dag_threshold()``.
        n_edges_min_dag: sum of ``model.get_adjacency_matrix()`` after that call.
    """
    # Using the run as a context manager finishes it even if training raises.
    with wandb.init() as run:
        config = run.config
        mv_flavor = config.mv_flavor
        s1_alpha = config.s1_alpha
        s2_alpha = config.s2_alpha
        s1_beta = config.s1_beta
        s2_beta = config.s2_beta
        max_gamma = config.max_gamma

        split_generator = (
            torch.default_generator
            if split_seed is None
            else torch.Generator().manual_seed(split_seed)
        )
        train_dataset, val_dataset = torch.utils.data.random_split(
            dataset, [1 - val_fraction, val_fraction], generator=split_generator
        )

        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        print(f"Using {device}")
        model = SDCI(model_variance_flavor=mv_flavor)
        model.train(
            train_dataset,
            device=device,
            log_wandb=True,
            verbose=False,
            stage1_kwargs={"n_epochs": 100, "alpha": s1_alpha, "beta": s1_beta, "n_epochs_check": 5},
            stage2_kwargs={"n_epochs": 100, "alpha": s2_alpha, "beta": s2_beta, "max_gamma": max_gamma, "n_epochs_check": 5},
        )

        # Materialize the validation subset once (each `[:]` rebuilds it) and evaluate
        # without autograd: only the scalar loss value is needed.
        val_batch = val_dataset[:]
        val_x, val_mask = val_batch[0], val_batch[1]
        with torch.no_grad():
            val_rec_loss = model._model.reconstruction_loss(
                val_x.to(device), mask_interventions_oh=val_mask.to(device)
            ).item()

        # NOTE(review): presumably compute_min_dag_threshold() also sets the threshold
        # used by get_adjacency_matrix(); keep this call order — confirm in SDCI.
        min_dag_threshold = model.compute_min_dag_threshold()
        n_edges_min_dag = model.get_adjacency_matrix().sum()

        run.log(dict(val_rec_loss=val_rec_loss, min_dag_threshold=min_dag_threshold, n_edges_min_dag=n_edges_min_dag))

In [6]:
def _log_uniform(low, high):
    """Return a W&B sweep parameter spec sampled log-uniformly between low and high."""
    return {"distribution": "log_uniform_values", "min": low, "max": high}


# Bayesian W&B sweep over the stage-1 / stage-2 `alpha` and `beta` coefficients and
# stage 2's `max_gamma` (all consumed by `run_sdci`), selecting the trial with the
# lowest validation reconstruction loss.
# NOTE(review): the sweep name mentions "sim", but this notebook sweeps Perturb-CITE-seq
# data — confirm the name is intended.
# NOTE(review): hyperband early termination needs intermediate values of the metric;
# `run_sdci` logs `val_rec_loss` only once at the end, so it likely never triggers.
sweep_configuration = {
    "method": "bayes",
    "name": "sdci_new_sim_sweep",
    "metric": {"name": "val_rec_loss", "goal": "minimize"},
    "parameters": {
        **{
            param: _log_uniform(1e-5, 1e-1)
            for param in ("s1_alpha", "s2_alpha", "s1_beta", "s2_beta")
        },
        "max_gamma": _log_uniform(100, 1000),
        "mv_flavor": {"values": ["nn"]},
    },
    "early_terminate": {"type": "hyperband", "min_iter": 11},
}

In [7]:
import wandb

# W&B project holding the sweep; passed to both calls so the agent looks the sweep up
# in the same project it was created in.
SWEEP_PROJECT = "SDCI_perturb_cite_seq"
# Maximum number of trials this agent runs. None keeps the original behavior (run until
# the sweep is stopped or the kernel is interrupted); set an int so the cell can finish.
N_SWEEP_RUNS = None

wandb.login()

sweep_id = wandb.sweep(sweep=sweep_configuration, project=SWEEP_PROJECT)
wandb.agent(sweep_id, function=run_sdci, project=SWEEP_PROJECT, count=N_SWEEP_RUNS)

[34m[1mwandb[0m: Currently logged in as: [33mjustinhong[0m ([33mazizi-causal-perturb[0m). Use [1m`wandb login --relogin`[0m to force relogin


Create sweep with ID: c55u19ta
Sweep URL: https://wandb.ai/azizi-causal-perturb/SDCI_perturb_cite_seq/sweeps/c55u19ta


[34m[1mwandb[0m: Agent Starting Run: efid05rs with config:
[34m[1mwandb[0m: 	max_gamma: 379.8147342065447
[34m[1mwandb[0m: 	mv_flavor: nn
[34m[1mwandb[0m: 	s1_alpha: 0.01710698860588679
[34m[1mwandb[0m: 	s1_beta: 9.746829410362016e-05
[34m[1mwandb[0m: 	s2_alpha: 0.0001275049488058736
[34m[1mwandb[0m: 	s2_beta: 8.220218668276032e-05


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668109150001935, max=1.0…

Using cpu




VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668120450049173, max=1.0…

Epoch 0: loss=2778.63, gamma=0.00
Epoch 5: loss=1980.28, gamma=0.00
Epoch 10: loss=1968.14, gamma=0.00
Epoch 15: loss=1952.55, gamma=0.00
Epoch 20: loss=1950.08, gamma=0.00
Epoch 25: loss=1947.80, gamma=0.00
Epoch 30: loss=1947.20, gamma=0.00
Epoch 35: loss=1941.39, gamma=0.00
Epoch 40: loss=1936.75, gamma=0.00
Epoch 45: loss=1944.45, gamma=0.00
Epoch 50: loss=1942.87, gamma=0.00
Epoch 55: loss=1942.63, gamma=0.00
Epoch 60: loss=1954.55, gamma=0.00
Epoch 65: loss=1941.45, gamma=0.00
Epoch 70: loss=1936.16, gamma=0.00
Epoch 75: loss=1934.88, gamma=0.00
Epoch 80: loss=1953.16, gamma=0.00
Epoch 85: loss=1942.72, gamma=0.00
Epoch 90: loss=1938.32, gamma=0.00
Epoch 95: loss=1932.26, gamma=0.00
Fraction of possible edges in mask: 0.0005230093140091833
Epoch 0: loss=2073.02, gamma=0.00
Epoch 5: loss=1748.83, gamma=19.18
Epoch 10: loss=1743.25, gamma=38.37
Epoch 15: loss=1734.76, gamma=57.55
Epoch 20: loss=1732.12, gamma=57.55
Epoch 25: loss=1740.84, gamma=57.55
Epoch 30: loss=1730.30, gamma=5

VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
alpha,████████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
dag,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
epoch_loss,█▃▃▃▃▃▃▂▂▃▃▃▃▂▂▂▃▃▂▂▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
gamma,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▆█████████████████
is_prescreen,████████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
l1,▇▇▇▇████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
l2,▁▂▂▃▃▄▄▅▅▅▆▆▆▇▇▇▇▇▇▇▁▂▃▃▄▄▅▅▅▆▆▆▆▇▇▇▇███
min_dag_threshold,▁
n_edges_min_dag,▁

0,1
alpha,0.00013
dag,0.04872
epoch,195.0
epoch_loss,1718.26027
gamma,57.54769
is_prescreen,0.0
l1,1.07072
l2,4.47317
min_dag_threshold,0.30581
n_edges_min_dag,1297.0


[34m[1mwandb[0m: Agent Starting Run: vpkla5x8 with config:
[34m[1mwandb[0m: 	max_gamma: 130.23940569670566
[34m[1mwandb[0m: 	mv_flavor: nn
[34m[1mwandb[0m: 	s1_alpha: 0.0002663329330034641
[34m[1mwandb[0m: 	s1_beta: 5.045031732000572e-05
[34m[1mwandb[0m: 	s2_alpha: 0.0002603114167775845
[34m[1mwandb[0m: 	s2_beta: 0.012439464244132533


Using cpu




VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668133566660494, max=1.0…

Epoch 0: loss=2550.01, gamma=0.00
Epoch 5: loss=1570.41, gamma=0.00
Epoch 10: loss=1483.57, gamma=0.00
Epoch 15: loss=1430.48, gamma=0.00
Epoch 20: loss=1391.57, gamma=0.00
Epoch 25: loss=1365.28, gamma=0.00
Epoch 30: loss=1344.74, gamma=0.00
Epoch 35: loss=1330.04, gamma=0.00
Epoch 40: loss=1319.62, gamma=0.00
Epoch 45: loss=1306.83, gamma=0.00
Epoch 50: loss=1293.52, gamma=0.00
Epoch 55: loss=1294.62, gamma=0.00
Epoch 60: loss=1293.94, gamma=0.00
Epoch 65: loss=1284.24, gamma=0.00
Epoch 70: loss=1288.93, gamma=0.00
Epoch 75: loss=1276.15, gamma=0.00
Epoch 80: loss=1288.57, gamma=0.00
Epoch 85: loss=1276.28, gamma=0.00
Epoch 90: loss=1279.99, gamma=0.00
Epoch 95: loss=1275.00, gamma=0.00
Fraction of possible edges in mask: 0.2421984747504142
Epoch 0: loss=2201.03, gamma=0.00
Epoch 5: loss=1823.24, gamma=6.58
Epoch 10: loss=1831.31, gamma=13.16
Epoch 15: loss=1834.33, gamma=19.73
Epoch 20: loss=1839.01, gamma=26.31
Epoch 25: loss=1836.32, gamma=32.89
Epoch 30: loss=1836.07, gamma=39.47

VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
alpha,████████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
dag,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
epoch_loss,█▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▆▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄
gamma,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
is_prescreen,████████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
l1,▂▄▅▅▆▆▆▆▇▇▇▇▇▇▇█████▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
l2,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁██████▇▇▇▇▇▇▇▇▇▇▇▇▇▇
min_dag_threshold,▁
n_edges_min_dag,▁

0,1
alpha,0.00026
dag,0.48508
epoch,195.0
epoch_loss,1868.49944
gamma,124.97721
is_prescreen,0.0
l1,19.73564
l2,127.34352
min_dag_threshold,0.5906
n_edges_min_dag,543.0


[34m[1mwandb[0m: Agent Starting Run: tneiaqfx with config:
[34m[1mwandb[0m: 	max_gamma: 435.4514118905002
[34m[1mwandb[0m: 	mv_flavor: nn
[34m[1mwandb[0m: 	s1_alpha: 0.0012401475545467755
[34m[1mwandb[0m: 	s1_beta: 2.4257638326644822e-05
[34m[1mwandb[0m: 	s2_alpha: 0.0001775998403766347
[34m[1mwandb[0m: 	s2_beta: 0.006201787704172663


Using cpu




VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668163766613966, max=1.0…

Epoch 0: loss=2553.19, gamma=0.00
Epoch 5: loss=1674.78, gamma=0.00
Epoch 10: loss=1634.72, gamma=0.00
Epoch 15: loss=1618.27, gamma=0.00
Epoch 20: loss=1604.25, gamma=0.00
Epoch 25: loss=1600.20, gamma=0.00
Epoch 30: loss=1703.05, gamma=0.00
Epoch 35: loss=1594.31, gamma=0.00
Epoch 40: loss=1598.14, gamma=0.00
Epoch 45: loss=1596.26, gamma=0.00
Epoch 50: loss=1604.13, gamma=0.00
Epoch 55: loss=1608.84, gamma=0.00
Epoch 60: loss=1607.75, gamma=0.00
Epoch 65: loss=1614.50, gamma=0.00
Epoch 70: loss=1617.80, gamma=0.00
Epoch 75: loss=1618.21, gamma=0.00
Epoch 80: loss=1620.32, gamma=0.00
Epoch 85: loss=1622.02, gamma=0.00
Epoch 90: loss=1624.92, gamma=0.00
Epoch 95: loss=1624.68, gamma=0.00
Fraction of possible edges in mask: 0.03090198346547574
Epoch 0: loss=2037.76, gamma=0.00
Epoch 5: loss=1718.55, gamma=21.99
Epoch 10: loss=1736.34, gamma=43.98
Epoch 15: loss=1738.05, gamma=65.98
Epoch 20: loss=1741.99, gamma=87.97
Epoch 25: loss=1747.53, gamma=109.96
Epoch 30: loss=1742.29, gamma=13

VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
alpha,████████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
dag,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
epoch_loss,█▂▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▄▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂
gamma,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▃▅▆▇██████████████
is_prescreen,████████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
l1,▄▆▆▇▇▇▇█████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
l2,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄▇██████████████████
min_dag_threshold,▁
n_edges_min_dag,▁

0,1
alpha,0.00018
dag,0.16524
epoch,195.0
epoch_loss,1735.083
gamma,131.95497
is_prescreen,0.0
l1,7.40486
l2,101.56779
min_dag_threshold,0.47793
n_edges_min_dag,9879.0


[34m[1mwandb[0m: Agent Starting Run: kpxjqe0x with config:
[34m[1mwandb[0m: 	max_gamma: 279.05772338714326
[34m[1mwandb[0m: 	mv_flavor: nn
[34m[1mwandb[0m: 	s1_alpha: 0.03453908158454036
[34m[1mwandb[0m: 	s1_beta: 0.001251086873833919
[34m[1mwandb[0m: 	s2_alpha: 0.03910128891660832
[34m[1mwandb[0m: 	s2_beta: 0.010209036556436284


Using cpu




VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668166983314827, max=1.0…

Epoch 0: loss=2926.79, gamma=0.00
Epoch 5: loss=2131.68, gamma=0.00
Epoch 10: loss=2113.41, gamma=0.00
Epoch 15: loss=2098.95, gamma=0.00
Epoch 20: loss=2122.46, gamma=0.00
Epoch 25: loss=2098.26, gamma=0.00
Epoch 30: loss=2092.30, gamma=0.00
Epoch 35: loss=2087.49, gamma=0.00
Epoch 40: loss=2106.17, gamma=0.00
Epoch 45: loss=2098.28, gamma=0.00
Epoch 50: loss=2113.09, gamma=0.00
Epoch 55: loss=2095.40, gamma=0.00
Epoch 60: loss=2083.15, gamma=0.00
Epoch 65: loss=2092.84, gamma=0.00
Epoch 70: loss=2089.26, gamma=0.00
Epoch 75: loss=2085.43, gamma=0.00
Epoch 80: loss=2092.74, gamma=0.00
Epoch 85: loss=2082.61, gamma=0.00
Epoch 90: loss=2084.40, gamma=0.00
Epoch 95: loss=2081.63, gamma=0.00
Fraction of possible edges in mask: 0.00027789422464415516
Epoch 0: loss=2152.67, gamma=0.00
Epoch 5: loss=1832.58, gamma=14.09
Epoch 10: loss=1826.91, gamma=28.19
Epoch 15: loss=1834.69, gamma=42.28
Epoch 20: loss=1826.69, gamma=42.28
Epoch 25: loss=1834.73, gamma=42.28
Epoch 30: loss=1827.25, gamma=

0,1
alpha,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁████████████████████
dag,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
epoch_loss,█▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
gamma,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▆█████████████████
is_prescreen,████████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
l1,█▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
l2,▁▂▂▂▃▃▃▃▄▄▄▄▄▄▄▄▄▄▄▄████████████████████
min_dag_threshold,▁
n_edges_min_dag,▁

0,1
alpha,0.0391
dag,0.0214
epoch,195.0
epoch_loss,1829.33844
gamma,42.28147
is_prescreen,0.0
l1,17.40967
l2,40.66154
min_dag_threshold,0.10552
n_edges_min_dag,665.0


[34m[1mwandb[0m: Agent Starting Run: tjxqb361 with config:
[34m[1mwandb[0m: 	max_gamma: 139.51198383007883
[34m[1mwandb[0m: 	mv_flavor: nn
[34m[1mwandb[0m: 	s1_alpha: 0.00015560863357465352
[34m[1mwandb[0m: 	s1_beta: 0.0004654182153203256
[34m[1mwandb[0m: 	s2_alpha: 0.00030504021624210465
[34m[1mwandb[0m: 	s2_beta: 0.006583412518639233


Using cpu




VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01666828768332683, max=1.0)…

Epoch 0: loss=2589.15, gamma=0.00
Epoch 5: loss=1563.41, gamma=0.00
Epoch 10: loss=1465.60, gamma=0.00
Epoch 15: loss=1411.89, gamma=0.00
Epoch 20: loss=1380.32, gamma=0.00
Epoch 25: loss=1353.72, gamma=0.00
Epoch 30: loss=1327.13, gamma=0.00
Epoch 35: loss=1317.29, gamma=0.00
Epoch 40: loss=1300.80, gamma=0.00
Epoch 45: loss=1292.46, gamma=0.00
Epoch 50: loss=1285.78, gamma=0.00
Epoch 55: loss=1288.70, gamma=0.00
Epoch 60: loss=1280.83, gamma=0.00
Epoch 65: loss=1278.38, gamma=0.00
Epoch 70: loss=1280.84, gamma=0.00
Epoch 75: loss=1269.92, gamma=0.00
Epoch 80: loss=1263.24, gamma=0.00
Epoch 85: loss=1268.02, gamma=0.00
Epoch 90: loss=1262.80, gamma=0.00
Epoch 95: loss=1250.19, gamma=0.00
Fraction of possible edges in mask: 0.2360531153108063
Epoch 0: loss=2187.48, gamma=0.00
Epoch 5: loss=1794.01, gamma=7.05
Epoch 10: loss=1802.49, gamma=14.09
Epoch 15: loss=1816.17, gamma=21.14
Epoch 20: loss=1814.59, gamma=28.18
Epoch 25: loss=1814.34, gamma=35.23
Epoch 30: loss=1815.49, gamma=42.28

VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
alpha,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁████████████████████
dag,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
epoch_loss,█▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▆▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄
gamma,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
is_prescreen,████████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
l1,▁▄▅▅▆▆▆▇▇▇▇▇▇███████▅▃▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁
l2,▁▂▂▂▃▃▃▃▃▄▄▄▄▄▄▄▄▄▄▅▇████▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇
min_dag_threshold,▁
n_edges_min_dag,▁

0,1
alpha,0.00031
dag,0.54188
epoch,195.0
epoch_loss,1833.07584
gamma,133.87514
is_prescreen,0.0
l1,25.43314
l2,100.86334
min_dag_threshold,0.5976
n_edges_min_dag,1583.0


[34m[1mwandb[0m: Agent Starting Run: sqt4mc6c with config:
[34m[1mwandb[0m: 	max_gamma: 225.2755333518242
[34m[1mwandb[0m: 	mv_flavor: nn
[34m[1mwandb[0m: 	s1_alpha: 0.00020716085277656376
[34m[1mwandb[0m: 	s1_beta: 0.0002551681871473278
[34m[1mwandb[0m: 	s2_alpha: 0.0009053005841410048
[34m[1mwandb[0m: 	s2_beta: 0.05678023992100078


Using cpu




VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01666816943337229, max=1.0)…

Epoch 0: loss=2561.49, gamma=0.00
Epoch 5: loss=1556.95, gamma=0.00
Epoch 10: loss=1461.53, gamma=0.00
Epoch 15: loss=1418.14, gamma=0.00
Epoch 20: loss=1377.34, gamma=0.00
Epoch 25: loss=1351.44, gamma=0.00
Epoch 30: loss=1328.47, gamma=0.00
Epoch 35: loss=1308.78, gamma=0.00
Epoch 40: loss=1306.45, gamma=0.00
Epoch 45: loss=1292.78, gamma=0.00
Epoch 50: loss=1290.79, gamma=0.00
Epoch 55: loss=1290.82, gamma=0.00
Epoch 60: loss=1284.36, gamma=0.00
Epoch 65: loss=1313.62, gamma=0.00
Epoch 70: loss=1291.10, gamma=0.00
Epoch 75: loss=1289.76, gamma=0.00
Epoch 80: loss=1282.52, gamma=0.00
Epoch 85: loss=1281.45, gamma=0.00
Epoch 90: loss=1281.43, gamma=0.00
Epoch 95: loss=1280.46, gamma=0.00
Fraction of possible edges in mask: 0.23714028996423067
Epoch 0: loss=2369.95, gamma=0.00
Epoch 5: loss=1954.25, gamma=11.38
Epoch 10: loss=1957.94, gamma=22.76
Epoch 15: loss=1964.77, gamma=34.13
Epoch 20: loss=1963.83, gamma=45.51
Epoch 25: loss=1967.67, gamma=56.89
Epoch 30: loss=1963.98, gamma=68.

VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
alpha,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁████████████████████
dag,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
epoch_loss,█▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▇▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅
gamma,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▃▄▄▅▅▆▇▇█████████
is_prescreen,████████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
l1,▁▄▅▅▆▆▆▆▇▇▇▇▇▇██████▅▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂
l2,▁▁▁▁▁▁▁▂▂▂▂▂▂▂▂▂▂▂▂▂█▅▅▆▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅
min_dag_threshold,▁
n_edges_min_dag,▁

0,1
alpha,0.00091
dag,0.22284
epoch,195.0
epoch_loss,1969.40806
gamma,125.15307
is_prescreen,0.0
l1,40.80511
l2,173.98794
min_dag_threshold,0.46096
n_edges_min_dag,191.0


[34m[1mwandb[0m: Agent Starting Run: os91d8xw with config:
[34m[1mwandb[0m: 	max_gamma: 102.72307343650188
[34m[1mwandb[0m: 	mv_flavor: nn
[34m[1mwandb[0m: 	s1_alpha: 0.00019687184369677663
[34m[1mwandb[0m: 	s1_beta: 0.0011462184680038135
[34m[1mwandb[0m: 	s2_alpha: 0.0009892934072085974
[34m[1mwandb[0m: 	s2_beta: 0.008875165091681516


Using cpu




VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668259633297565, max=1.0…

Epoch 0: loss=2577.03, gamma=0.00
Epoch 5: loss=1577.68, gamma=0.00
Epoch 10: loss=1496.56, gamma=0.00
Epoch 15: loss=1457.74, gamma=0.00
Epoch 20: loss=1428.85, gamma=0.00
Epoch 25: loss=1411.11, gamma=0.00
Epoch 30: loss=1393.25, gamma=0.00
Epoch 35: loss=1382.27, gamma=0.00
Epoch 40: loss=1383.72, gamma=0.00
Epoch 45: loss=1369.47, gamma=0.00
Epoch 50: loss=1369.80, gamma=0.00
Epoch 55: loss=1363.78, gamma=0.00
Epoch 60: loss=1364.93, gamma=0.00
Epoch 65: loss=1376.82, gamma=0.00
Epoch 70: loss=1355.52, gamma=0.00
Epoch 75: loss=1360.35, gamma=0.00
Epoch 80: loss=1371.29, gamma=0.00
Epoch 85: loss=1368.18, gamma=0.00
Epoch 90: loss=1372.14, gamma=0.00
Epoch 95: loss=1379.08, gamma=0.00
Fraction of possible edges in mask: 0.1187988705038408
Epoch 0: loss=2148.53, gamma=0.00
Epoch 5: loss=1805.59, gamma=5.19
Epoch 10: loss=1805.35, gamma=10.38
Epoch 15: loss=1809.15, gamma=15.56
Epoch 20: loss=1814.70, gamma=20.75
Epoch 25: loss=1826.87, gamma=25.94
Epoch 30: loss=1824.13, gamma=31.13

0,1
alpha,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁████████████████████
dag,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▄▄▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
epoch_loss,█▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▆▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄
gamma,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
is_prescreen,████████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
l1,▁▄▅▆▆▆▇▇▇▇▇▇▇▇███████▇▆▆▆▆▆▆▅▅▅▅▅▅▅▅▅▅▅▅
l2,▁▂▃▃▄▄▄▅▅▅▅▅▅▆▆▆▆▆▆▆▅██████████████████▇
min_dag_threshold,▁
n_edges_min_dag,▁

0,1
alpha,0.00099
dag,0.33971
epoch,195.0
epoch_loss,1824.5044
gamma,98.57265
is_prescreen,0.0
l1,56.11456
l2,113.02623
min_dag_threshold,0.67066
n_edges_min_dag,789.0


[34m[1mwandb[0m: Agent Starting Run: 5u77nju9 with config:
[34m[1mwandb[0m: 	max_gamma: 106.3564641948426
[34m[1mwandb[0m: 	mv_flavor: nn
[34m[1mwandb[0m: 	s1_alpha: 0.00021468080608109185
[34m[1mwandb[0m: 	s1_beta: 0.0008547274228438926
[34m[1mwandb[0m: 	s2_alpha: 0.00012380540325343658
[34m[1mwandb[0m: 	s2_beta: 0.0054504142203411895


Using cpu




VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01666837064998011, max=1.0)…

Epoch 0: loss=2615.78, gamma=0.00
Epoch 5: loss=1586.64, gamma=0.00
Epoch 10: loss=1495.07, gamma=0.00
Epoch 15: loss=1450.55, gamma=0.00
Epoch 20: loss=1421.10, gamma=0.00
Epoch 25: loss=1402.96, gamma=0.00
Epoch 30: loss=1382.26, gamma=0.00
Epoch 35: loss=1364.73, gamma=0.00
Epoch 40: loss=1360.98, gamma=0.00
Epoch 45: loss=1357.91, gamma=0.00
Epoch 50: loss=1362.20, gamma=0.00
Epoch 55: loss=1357.41, gamma=0.00
Epoch 60: loss=1352.27, gamma=0.00
Epoch 65: loss=1348.07, gamma=0.00
Epoch 70: loss=1338.70, gamma=0.00
Epoch 75: loss=1337.86, gamma=0.00
Epoch 80: loss=1344.20, gamma=0.00
Epoch 85: loss=1351.06, gamma=0.00
Epoch 90: loss=1345.47, gamma=0.00
Epoch 95: loss=1353.59, gamma=0.00
Fraction of possible edges in mask: 0.1396398447143098
Epoch 0: loss=2100.80, gamma=0.00
Epoch 5: loss=1731.75, gamma=5.37
Epoch 10: loss=1750.60, gamma=10.74
Epoch 15: loss=1754.45, gamma=16.11
Epoch 20: loss=1748.50, gamma=21.49
Epoch 25: loss=1756.40, gamma=26.86
Epoch 30: loss=1757.21, gamma=32.23

VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
alpha,████████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
dag,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▅▄▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
epoch_loss,█▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▅▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃
gamma,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
is_prescreen,████████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
l1,▃▅▆▆▇▇▇▇▇▇██████████▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
l2,▁▂▂▃▃▄▄▄▄▄▅▅▅▅▅▅▅▅▅▅▅███████▇▇█▇▇▇▇▇▇▇▇▇
min_dag_threshold,▁
n_edges_min_dag,▁

0,1
alpha,0.00012
dag,0.45998
epoch,195.0
epoch_loss,1763.15594
gamma,102.05923
is_prescreen,0.0
l1,9.96977
l2,102.09747
min_dag_threshold,0.62888
n_edges_min_dag,2619.0


[34m[1mwandb[0m: Agent Starting Run: 8fi9weq2 with config:
[34m[1mwandb[0m: 	max_gamma: 120.75269762012086
[34m[1mwandb[0m: 	mv_flavor: nn
[34m[1mwandb[0m: 	s1_alpha: 0.00042407463918453834
[34m[1mwandb[0m: 	s1_beta: 0.002997180887759907
[34m[1mwandb[0m: 	s2_alpha: 1.5526992189246795e-05
[34m[1mwandb[0m: 	s2_beta: 0.01703884804368981


Using cpu




VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668720149997776, max=1.0…

Epoch 0: loss=2572.28, gamma=0.00
Epoch 5: loss=1631.64, gamma=0.00
Epoch 10: loss=1582.15, gamma=0.00
Epoch 15: loss=1559.00, gamma=0.00
Epoch 20: loss=1551.29, gamma=0.00
Epoch 25: loss=1549.66, gamma=0.00
Epoch 30: loss=1544.68, gamma=0.00
Epoch 35: loss=1534.40, gamma=0.00
Epoch 40: loss=1541.82, gamma=0.00
Epoch 45: loss=1532.73, gamma=0.00
Epoch 50: loss=1538.09, gamma=0.00
Epoch 55: loss=1529.92, gamma=0.00
Epoch 60: loss=1527.95, gamma=0.00
Epoch 65: loss=1530.62, gamma=0.00
Epoch 70: loss=1543.21, gamma=0.00
Epoch 75: loss=1531.39, gamma=0.00
Epoch 80: loss=1533.19, gamma=0.00
Epoch 85: loss=1530.02, gamma=0.00
Epoch 90: loss=1555.74, gamma=0.00
Epoch 95: loss=1552.58, gamma=0.00
Fraction of possible edges in mask: 0.034702906307397634
Epoch 0: loss=2078.35, gamma=0.00
Epoch 5: loss=1780.68, gamma=6.10
Epoch 10: loss=1783.75, gamma=12.20
Epoch 15: loss=1790.34, gamma=18.30
Epoch 20: loss=1788.37, gamma=24.39
Epoch 25: loss=1793.09, gamma=30.49
Epoch 30: loss=1787.78, gamma=36.

VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
alpha,████████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
dag,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
epoch_loss,█▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▅▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃
gamma,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▃▃▄▄▅▅▆▆▇▇███████
is_prescreen,████████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
l1,▄▇▇▇▇███████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
l2,▁▃▃▄▄▅▅▅▅▆▆▆▆▆▆▆▆▆▆▆▆███████████████████
min_dag_threshold,▁
n_edges_min_dag,▁

0,1
alpha,2e-05
dag,0.20352
epoch,195.0
epoch_loss,1790.58923
gamma,79.28207
is_prescreen,0.0
l1,0.51874
l2,134.78114
min_dag_threshold,0.4625
n_edges_min_dag,2559.0


[34m[1mwandb[0m: Agent Starting Run: 8jvv8779 with config:
[34m[1mwandb[0m: 	max_gamma: 107.03193215397135
[34m[1mwandb[0m: 	mv_flavor: nn
[34m[1mwandb[0m: 	s1_alpha: 0.00022654299265432448
[34m[1mwandb[0m: 	s1_beta: 0.02032616201510188
[34m[1mwandb[0m: 	s2_alpha: 0.00030328065024813107
[34m[1mwandb[0m: 	s2_beta: 0.0009538140205485206


Using cpu




VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668624200004464, max=1.0…

Epoch 0: loss=2632.38, gamma=0.00
Epoch 5: loss=1722.10, gamma=0.00
Epoch 10: loss=1697.53, gamma=0.00
Epoch 15: loss=1695.07, gamma=0.00
Epoch 20: loss=1693.33, gamma=0.00
Epoch 25: loss=1689.44, gamma=0.00
Epoch 30: loss=1684.57, gamma=0.00
Epoch 35: loss=1687.72, gamma=0.00
Epoch 40: loss=1694.49, gamma=0.00
Epoch 45: loss=1683.55, gamma=0.00
Epoch 50: loss=1682.05, gamma=0.00
Epoch 55: loss=1690.23, gamma=0.00
Epoch 60: loss=1686.29, gamma=0.00
Epoch 65: loss=1682.82, gamma=0.00
Epoch 70: loss=1687.15, gamma=0.00
Epoch 75: loss=1685.35, gamma=0.00
Epoch 80: loss=1684.01, gamma=0.00
Epoch 85: loss=1688.65, gamma=0.00
Epoch 90: loss=1694.01, gamma=0.00
Epoch 95: loss=1688.31, gamma=0.00
Fraction of possible edges in mask: 0.0013479508852005482
Epoch 0: loss=2052.80, gamma=0.00
Epoch 5: loss=1721.61, gamma=5.41
Epoch 10: loss=1715.13, gamma=10.81
Epoch 15: loss=1719.62, gamma=16.22
Epoch 20: loss=1720.04, gamma=21.62
Epoch 25: loss=1717.83, gamma=27.03
Epoch 30: loss=1716.93, gamma=32

VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
alpha,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁████████████████████
dag,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▃▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
epoch_loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
gamma,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▃▃▄▅▆▆▇███████████
is_prescreen,████████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
l1,▆███████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
l2,▃▆▇▇▇▇▇▇▇███████████▁▁▁▁▁▁▁▁▁▁▂▂▂▂▂▂▂▂▂▂
min_dag_threshold,▁
n_edges_min_dag,▁

0,1
alpha,0.0003
dag,0.10889
epoch,195.0
epoch_loss,1708.81241
gamma,48.65088
is_prescreen,0.0
l1,3.14913
l2,21.80109
min_dag_threshold,0.46398
n_edges_min_dag,2940.0


[34m[1mwandb[0m: Agent Starting Run: 2u72m9ry with config:
[34m[1mwandb[0m: 	max_gamma: 133.8864402612173
[34m[1mwandb[0m: 	mv_flavor: nn
[34m[1mwandb[0m: 	s1_alpha: 0.00013073105914233614
[34m[1mwandb[0m: 	s1_beta: 0.002292925114470436
[34m[1mwandb[0m: 	s2_alpha: 0.0002300575209519635
[34m[1mwandb[0m: 	s2_beta: 0.07937219329446436


Using cpu




VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668031733327857, max=1.0…

Epoch 0: loss=2566.64, gamma=0.00
Epoch 5: loss=1584.02, gamma=0.00
Epoch 10: loss=1508.69, gamma=0.00
Epoch 15: loss=1475.39, gamma=0.00
Epoch 20: loss=1450.77, gamma=0.00
Epoch 25: loss=1436.32, gamma=0.00
Epoch 30: loss=1421.97, gamma=0.00
Epoch 35: loss=1412.82, gamma=0.00
Epoch 40: loss=1414.28, gamma=0.00
Epoch 45: loss=1404.27, gamma=0.00
Epoch 50: loss=1402.81, gamma=0.00
Epoch 55: loss=1393.42, gamma=0.00
Epoch 60: loss=1399.76, gamma=0.00
Epoch 65: loss=1390.40, gamma=0.00
Epoch 70: loss=1402.21, gamma=0.00
Epoch 75: loss=1414.54, gamma=0.00
Epoch 80: loss=1400.86, gamma=0.00
Epoch 85: loss=1404.57, gamma=0.00
Epoch 90: loss=1405.99, gamma=0.00
Epoch 95: loss=1412.31, gamma=0.00
Fraction of possible edges in mask: 0.07373375110948267
Epoch 0: loss=2274.58, gamma=0.00
Epoch 5: loss=1938.25, gamma=6.76
Epoch 10: loss=1942.49, gamma=13.52
Epoch 15: loss=1950.65, gamma=20.29
Epoch 20: loss=1950.14, gamma=27.05
Epoch 25: loss=1953.56, gamma=33.81
Epoch 30: loss=1950.26, gamma=33.8

0,1
alpha,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁████████████████████
dag,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
epoch_loss,█▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▆▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄
gamma,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▄▅▇███████████████
is_prescreen,████████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
l1,▃▆▇▇▇▇▇▇▇███████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
l2,▁▂▂▃▃▃▃▄▄▄▄▄▄▄▄▄▄▄▄▄█▆▆▆▆▆▆▆▆▆▆▇▆▆▆▆▇▆▆▆
min_dag_threshold,▁
n_edges_min_dag,▁

0,1
alpha,0.00023
dag,0.23619
epoch,195.0
epoch_loss,1955.87453
gamma,33.80971
is_prescreen,0.0
l1,5.99562
l2,202.52316
min_dag_threshold,0.50207
n_edges_min_dag,32.0


[34m[1mwandb[0m: Agent Starting Run: 2ccgby68 with config:
[34m[1mwandb[0m: 	max_gamma: 101.96183624461476
[34m[1mwandb[0m: 	mv_flavor: nn
[34m[1mwandb[0m: 	s1_alpha: 0.0025333012997127164
[34m[1mwandb[0m: 	s1_beta: 0.002126573629660055
[34m[1mwandb[0m: 	s2_alpha: 3.5997661638930846e-05
[34m[1mwandb[0m: 	s2_beta: 0.06075662201613709


Using cpu




VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668091150010392, max=1.0…

Epoch 0: loss=2600.80, gamma=0.00
Epoch 5: loss=1756.94, gamma=0.00
Epoch 10: loss=1733.99, gamma=0.00
Epoch 15: loss=1727.65, gamma=0.00
Epoch 20: loss=1727.18, gamma=0.00
Epoch 25: loss=1730.59, gamma=0.00
Epoch 30: loss=1725.35, gamma=0.00
Epoch 35: loss=1738.02, gamma=0.00
Epoch 40: loss=1733.18, gamma=0.00
Epoch 45: loss=1740.41, gamma=0.00
Epoch 50: loss=1736.98, gamma=0.00
Epoch 55: loss=1741.81, gamma=0.00
Epoch 60: loss=1746.40, gamma=0.00
Epoch 65: loss=1744.42, gamma=0.00
Epoch 70: loss=1743.62, gamma=0.00
Epoch 75: loss=1753.68, gamma=0.00
Epoch 80: loss=1749.82, gamma=0.00
Epoch 85: loss=1746.45, gamma=0.00
Epoch 90: loss=1744.39, gamma=0.00
Epoch 95: loss=1751.34, gamma=0.00
Fraction of possible edges in mask: 0.005511993703492326
Epoch 0: loss=2191.73, gamma=0.00
Epoch 5: loss=1900.52, gamma=5.15
Epoch 10: loss=1906.48, gamma=10.30
Epoch 15: loss=1896.60, gamma=15.45
Epoch 20: loss=1898.45, gamma=20.60
Epoch 25: loss=1906.33, gamma=25.75
Epoch 30: loss=1900.59, gamma=30.

VBox(children=(Label(value='0.009 MB of 0.021 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.453204…

0,1
alpha,████████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
dag,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
epoch_loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▅▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂
gamma,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▃▅▆▇██████████████
is_prescreen,████████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
l1,▅▇▇▇████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
l2,▁▁▂▂▂▂▂▂▂▂▂▂▂▂▂▃▃▃▃▃█▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇
min_dag_threshold,▁
n_edges_min_dag,▁

0,1
alpha,4e-05
dag,0.16699
epoch,195.0
epoch_loss,1911.9174
gamma,30.89753
is_prescreen,0.0
l1,0.26012
l2,157.60912
min_dag_threshold,0.46325
n_edges_min_dag,402.0


[34m[1mwandb[0m: Agent Starting Run: sh4gxyj6 with config:
[34m[1mwandb[0m: 	max_gamma: 103.15908236229876
[34m[1mwandb[0m: 	mv_flavor: nn
[34m[1mwandb[0m: 	s1_alpha: 0.0004328072242515362
[34m[1mwandb[0m: 	s1_beta: 0.0006352128434378249
[34m[1mwandb[0m: 	s2_alpha: 0.00021341337848999553
[34m[1mwandb[0m: 	s2_beta: 0.03264345270552534


Using cpu




VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01666804704994623, max=1.0)…

Epoch 0: loss=2615.25, gamma=0.00
Epoch 5: loss=1600.11, gamma=0.00
Epoch 10: loss=1536.13, gamma=0.00
Epoch 15: loss=1502.67, gamma=0.00
Epoch 20: loss=1484.73, gamma=0.00
Epoch 25: loss=1465.67, gamma=0.00
Epoch 30: loss=1451.77, gamma=0.00
Epoch 35: loss=1437.43, gamma=0.00
Epoch 40: loss=1434.73, gamma=0.00
Epoch 45: loss=1427.95, gamma=0.00
Epoch 50: loss=1426.78, gamma=0.00
Epoch 55: loss=1424.35, gamma=0.00
Epoch 60: loss=1435.16, gamma=0.00
Epoch 65: loss=1437.00, gamma=0.00
Epoch 70: loss=1435.68, gamma=0.00
Epoch 75: loss=1429.33, gamma=0.00
Epoch 80: loss=1432.68, gamma=0.00
Epoch 85: loss=1442.64, gamma=0.00
Epoch 90: loss=1433.44, gamma=0.00
Epoch 95: loss=1434.61, gamma=0.00
Fraction of possible edges in mask: 0.09504965856888481
Epoch 0: loss=2175.39, gamma=0.00
Epoch 5: loss=1854.22, gamma=5.21
Epoch 10: loss=1859.42, gamma=10.42
Epoch 15: loss=1865.47, gamma=15.63
Epoch 20: loss=1867.39, gamma=20.84
Epoch 25: loss=1855.72, gamma=26.05
Epoch 30: loss=1866.80, gamma=31.2

VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
alpha,████████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
dag,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▃▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
epoch_loss,█▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▅▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄
gamma,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▃▃▄▄▅▅▅▆▆▇▇█████
is_prescreen,████████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
l1,▃▅▆▆▇▇▇▇▇▇██████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
l2,▁▁▂▂▂▂▂▂▂▂▂▂▂▂▃▃▃▃▃▃████████▇███▇▇████▇█
min_dag_threshold,▁
n_edges_min_dag,▁

0,1
alpha,0.00021
dag,0.24412
epoch,195.0
epoch_loss,1882.07366
gamma,83.36087
is_prescreen,0.0
l1,8.98689
l2,174.02378
min_dag_threshold,0.52845
n_edges_min_dag,371.0


[34m[1mwandb[0m: Agent Starting Run: ieyqyoj9 with config:
[34m[1mwandb[0m: 	max_gamma: 126.95383967625128
[34m[1mwandb[0m: 	mv_flavor: nn
[34m[1mwandb[0m: 	s1_alpha: 5.835678069858703e-05
[34m[1mwandb[0m: 	s1_beta: 0.0001032920428387671
[34m[1mwandb[0m: 	s2_alpha: 2.202722223954116e-05
[34m[1mwandb[0m: 	s2_beta: 0.0046034151334148335


Using cpu




VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.0166682409166242, max=1.0))…

Epoch 0: loss=2580.85, gamma=0.00
Epoch 5: loss=1554.35, gamma=0.00
Epoch 10: loss=1414.77, gamma=0.00
Epoch 15: loss=1333.41, gamma=0.00
Epoch 20: loss=1277.69, gamma=0.00
Epoch 25: loss=1235.76, gamma=0.00
Epoch 30: loss=1209.81, gamma=0.00
Epoch 35: loss=1178.35, gamma=0.00
Epoch 40: loss=1157.46, gamma=0.00
Epoch 45: loss=1139.81, gamma=0.00
Epoch 50: loss=1128.66, gamma=0.00
Epoch 55: loss=1126.91, gamma=0.00
Epoch 60: loss=1114.51, gamma=0.00
Epoch 65: loss=1110.07, gamma=0.00
Epoch 70: loss=1093.72, gamma=0.00
Epoch 75: loss=1095.33, gamma=0.00
Epoch 80: loss=1089.21, gamma=0.00
Epoch 85: loss=1087.74, gamma=0.00
Epoch 90: loss=1087.12, gamma=0.00
Epoch 95: loss=1081.49, gamma=0.00
Fraction of possible edges in mask: 0.606907510756109
Epoch 0: loss=2306.69, gamma=0.00
Epoch 5: loss=1853.86, gamma=6.41
Epoch 10: loss=1890.29, gamma=12.82
Epoch 15: loss=1910.94, gamma=19.24
Epoch 20: loss=1936.87, gamma=25.65
Epoch 25: loss=1973.50, gamma=32.06
Epoch 30: loss=1966.67, gamma=38.47


VBox(children=(Label(value='0.009 MB of 0.017 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.551835…

0,1
alpha,████████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
dag,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▂▂▂▂▂▂▂▂▂▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
epoch_loss,█▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▇▅▅▅▅▅▅▅▅▅▅▅▅▆▅▅▅▅▅▅
gamma,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▃▃▄▄▅▅▆▆▇▇███████
is_prescreen,████████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
l1,▂▃▄▅▅▆▆▆▆▇▇▇▇▇██████▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
l2,▁▁▁▁▁▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂█▄▄▄▄▄▄▄▄▃▃▃▃▃▃▃▃▃▃▃
min_dag_threshold,▁
n_edges_min_dag,▁

0,1
alpha,2e-05
dag,2.04418
epoch,195.0
epoch_loss,1980.51526
gamma,83.35353
is_prescreen,0.0
l1,3.57488
l2,78.0489
min_dag_threshold,0.51083
n_edges_min_dag,2790.0


[34m[1mwandb[0m: Agent Starting Run: kgmbzgq1 with config:
[34m[1mwandb[0m: 	max_gamma: 101.87861367094712
[34m[1mwandb[0m: 	mv_flavor: nn
[34m[1mwandb[0m: 	s1_alpha: 0.0010301548064181407
[34m[1mwandb[0m: 	s1_beta: 7.35481246512472e-05
[34m[1mwandb[0m: 	s2_alpha: 0.004998822131971627
[34m[1mwandb[0m: 	s2_beta: 0.009976953901426344


Using cpu




VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668406933361742, max=1.0…

Epoch 0: loss=2601.66, gamma=0.00
Epoch 5: loss=1655.70, gamma=0.00
Epoch 10: loss=1610.69, gamma=0.00
Epoch 15: loss=1590.68, gamma=0.00
Epoch 20: loss=1574.90, gamma=0.00
Epoch 25: loss=1564.65, gamma=0.00
Epoch 30: loss=1566.53, gamma=0.00
Epoch 35: loss=1567.80, gamma=0.00
Epoch 40: loss=1567.65, gamma=0.00
Epoch 45: loss=1609.04, gamma=0.00
Epoch 50: loss=1562.84, gamma=0.00
Epoch 55: loss=1566.16, gamma=0.00
Epoch 60: loss=1573.08, gamma=0.00
Epoch 65: loss=1575.85, gamma=0.00
Epoch 70: loss=1577.76, gamma=0.00
Epoch 75: loss=1576.85, gamma=0.00
Epoch 80: loss=1575.05, gamma=0.00
Epoch 85: loss=1576.22, gamma=0.00
Epoch 90: loss=1584.49, gamma=0.00
Epoch 95: loss=1579.78, gamma=0.00
Fraction of possible edges in mask: 0.0426299938557332
Epoch 0: loss=2143.92, gamma=0.00
Epoch 5: loss=1865.08, gamma=5.15
Epoch 10: loss=1873.24, gamma=10.29
Epoch 15: loss=1859.44, gamma=15.44
Epoch 20: loss=1858.41, gamma=20.58
Epoch 25: loss=1847.63, gamma=25.73
Epoch 30: loss=1853.77, gamma=30.87

VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
alpha,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁████████████████████
dag,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▁▁▁▁
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
epoch_loss,█▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▅▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃
gamma,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
is_prescreen,████████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
l1,▁▄▅▆▆▆▇▇▇▇▇█████████▄▄▄▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃
l2,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▅▇▇▇█▇██████████████
min_dag_threshold,▁
n_edges_min_dag,▁

0,1
alpha,0.005
dag,0.14803
epoch,195.0
epoch_loss,1857.37604
gamma,97.76231
is_prescreen,0.0
l1,111.41081
l2,93.07453
min_dag_threshold,0.53306
n_edges_min_dag,1484.0


[34m[1mwandb[0m: Agent Starting Run: zq4co3aw with config:
[34m[1mwandb[0m: 	max_gamma: 120.48080208449136
[34m[1mwandb[0m: 	mv_flavor: nn
[34m[1mwandb[0m: 	s1_alpha: 0.0003952111630139574
[34m[1mwandb[0m: 	s1_beta: 0.0003620010894622518
[34m[1mwandb[0m: 	s2_alpha: 0.002125893840061103
[34m[1mwandb[0m: 	s2_beta: 0.015735280309879657


Using cpu




VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668024633448415, max=1.0…

Epoch 0: loss=2584.48, gamma=0.00
Epoch 5: loss=1593.17, gamma=0.00
Epoch 10: loss=1520.76, gamma=0.00
Epoch 15: loss=1486.68, gamma=0.00
Epoch 20: loss=1454.89, gamma=0.00
Epoch 25: loss=1433.42, gamma=0.00
Epoch 30: loss=1414.09, gamma=0.00
Epoch 35: loss=1412.27, gamma=0.00
Epoch 40: loss=1403.00, gamma=0.00
Epoch 45: loss=1393.54, gamma=0.00
Epoch 50: loss=1388.29, gamma=0.00
Epoch 55: loss=1395.14, gamma=0.00
Epoch 60: loss=1393.99, gamma=0.00
Epoch 65: loss=1400.05, gamma=0.00
Epoch 70: loss=1393.62, gamma=0.00
Epoch 75: loss=1404.63, gamma=0.00
Epoch 80: loss=1396.86, gamma=0.00
Epoch 85: loss=1393.71, gamma=0.00
Epoch 90: loss=1399.23, gamma=0.00
Epoch 95: loss=1405.43, gamma=0.00
Fraction of possible edges in mask: 0.12305688017659941
Epoch 0: loss=2199.49, gamma=0.00
Epoch 5: loss=1868.65, gamma=6.08
Epoch 10: loss=1873.22, gamma=12.17
Epoch 15: loss=1877.63, gamma=18.25
Epoch 20: loss=1875.52, gamma=24.34
Epoch 25: loss=1888.58, gamma=30.42
Epoch 30: loss=1889.54, gamma=36.5

VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
alpha,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁████████████████████
dag,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
epoch_loss,█▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▆▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄
gamma,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▂▂▃▃▄▄▄▅▅▅▆▆▇▇▇██
is_prescreen,████████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
l1,▁▄▅▅▆▆▆▇▇▇▇▇▇███████▇▅▅▅▅▅▅▅▅▅▄▄▄▄▄▄▄▄▄▄
l2,▁▁▁▂▂▂▂▂▂▂▂▂▂▂▂▃▃▃▃▃▆██████████████████▇
min_dag_threshold,▁
n_edges_min_dag,▁

0,1
alpha,0.00213
dag,0.23951
epoch,195.0
epoch_loss,1891.0105
gamma,115.61289
is_prescreen,0.0
l1,88.38353
l2,119.43993
min_dag_threshold,0.56503
n_edges_min_dag,578.0


[34m[1mwandb[0m: Agent Starting Run: pxyd9m2k with config:
[34m[1mwandb[0m: 	max_gamma: 110.670357371518
[34m[1mwandb[0m: 	mv_flavor: nn
[34m[1mwandb[0m: 	s1_alpha: 5.386983639847875e-05
[34m[1mwandb[0m: 	s1_beta: 7.162855030717716e-05
[34m[1mwandb[0m: 	s2_alpha: 0.00093529850114538
[34m[1mwandb[0m: 	s2_beta: 0.07154275455708342


Using cpu




VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01666810599999735, max=1.0)…

Epoch 0: loss=2580.45, gamma=0.00
Epoch 5: loss=1546.45, gamma=0.00
Epoch 10: loss=1407.87, gamma=0.00
Epoch 15: loss=1325.37, gamma=0.00
Epoch 20: loss=1271.29, gamma=0.00
Epoch 25: loss=1226.90, gamma=0.00
Epoch 30: loss=1199.70, gamma=0.00
Epoch 35: loss=1179.52, gamma=0.00
Epoch 40: loss=1143.42, gamma=0.00
Epoch 45: loss=1126.81, gamma=0.00
Epoch 50: loss=1121.90, gamma=0.00
Epoch 55: loss=1108.98, gamma=0.00
Epoch 60: loss=1098.55, gamma=0.00
Epoch 65: loss=1091.80, gamma=0.00
Epoch 70: loss=1087.27, gamma=0.00
Epoch 75: loss=1081.18, gamma=0.00
Epoch 80: loss=1084.16, gamma=0.00
Epoch 85: loss=1074.35, gamma=0.00
Epoch 90: loss=1072.08, gamma=0.00
Epoch 95: loss=1071.82, gamma=0.00
Fraction of possible edges in mask: 0.6491813046751423
Epoch 0: loss=2937.17, gamma=0.00
Epoch 5: loss=2072.26, gamma=5.59
Epoch 10: loss=2099.33, gamma=11.18
Epoch 15: loss=2120.43, gamma=16.77
Epoch 20: loss=2113.87, gamma=16.77
Epoch 25: loss=2135.69, gamma=16.77
Epoch 30: loss=2125.43, gamma=16.77

VBox(children=(Label(value='0.009 MB of 0.021 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.453402…

0,1
alpha,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁████████████████████
dag,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
epoch_loss,▇▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅▅
gamma,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▆█████████████████
is_prescreen,████████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
l1,▁▂▂▂▂▂▃▃▃▃▃▃▃▃▃▃▃▃▃▃█▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄
l2,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃
min_dag_threshold,▁
n_edges_min_dag,▁

0,1
alpha,0.00094
dag,1.12101
epoch,195.0
epoch_loss,2137.95302
gamma,16.76824
is_prescreen,0.0
l1,91.94756
l2,240.69692
min_dag_threshold,0.39198
n_edges_min_dag,197.0


[34m[1mwandb[0m: Agent Starting Run: s0ihn4zy with config:
[34m[1mwandb[0m: 	max_gamma: 244.17190722070103
[34m[1mwandb[0m: 	mv_flavor: nn
[34m[1mwandb[0m: 	s1_alpha: 0.001310898718854395
[34m[1mwandb[0m: 	s1_beta: 0.0006761695449736915
[34m[1mwandb[0m: 	s2_alpha: 1.0917198489429712e-05
[34m[1mwandb[0m: 	s2_beta: 0.04376851827936573


Using cpu




VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668164633301785, max=1.0…

Epoch 0: loss=2606.94, gamma=0.00
Epoch 5: loss=1684.70, gamma=0.00
Epoch 10: loss=1649.38, gamma=0.00
Epoch 15: loss=1631.66, gamma=0.00
Epoch 20: loss=1629.08, gamma=0.00
Epoch 25: loss=1624.53, gamma=0.00
Epoch 30: loss=1620.06, gamma=0.00
Epoch 35: loss=1624.01, gamma=0.00
Epoch 40: loss=1630.21, gamma=0.00
Epoch 45: loss=1626.39, gamma=0.00
Epoch 50: loss=1649.46, gamma=0.00
Epoch 55: loss=1641.67, gamma=0.00
Epoch 60: loss=1646.85, gamma=0.00
Epoch 65: loss=1655.19, gamma=0.00
Epoch 70: loss=1646.30, gamma=0.00
Epoch 75: loss=1692.83, gamma=0.00
Epoch 80: loss=1645.84, gamma=0.00
Epoch 85: loss=1648.26, gamma=0.00
Epoch 90: loss=1651.80, gamma=0.00
Epoch 95: loss=1648.92, gamma=0.00
Fraction of possible edges in mask: 0.022554958772953135
Epoch 0: loss=2160.30, gamma=0.00
Epoch 5: loss=1857.56, gamma=12.33
Epoch 10: loss=1874.27, gamma=24.66
Epoch 15: loss=1870.95, gamma=37.00
Epoch 20: loss=1867.08, gamma=49.33
Epoch 25: loss=1863.60, gamma=49.33
Epoch 30: loss=1868.88, gamma=49

VBox(children=(Label(value='0.009 MB of 0.021 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.453287…

0,1
alpha,████████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
dag,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
epoch_loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▅▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃
gamma,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▅▆████████████████
is_prescreen,████████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
l1,▄▆▇▇▇▇██████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
l2,▁▁▁▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂███████▇████████████
min_dag_threshold,▁
n_edges_min_dag,▁

0,1
alpha,1e-05
dag,0.19211
epoch,195.0
epoch_loss,1864.38936
gamma,49.32766
is_prescreen,0.0
l1,0.20534
l2,159.74644
min_dag_threshold,0.50423
n_edges_min_dag,319.0


[34m[1mwandb[0m: Agent Starting Run: rz8l857l with config:
[34m[1mwandb[0m: 	max_gamma: 402.92698514993464
[34m[1mwandb[0m: 	mv_flavor: nn
[34m[1mwandb[0m: 	s1_alpha: 0.0005495843786141105
[34m[1mwandb[0m: 	s1_beta: 2.5234937831259743e-05
[34m[1mwandb[0m: 	s2_alpha: 2.3734338091786897e-05
[34m[1mwandb[0m: 	s2_beta: 0.03171255980780661


Using cpu




VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668393300157427, max=1.0…

Epoch 0: loss=2589.56, gamma=0.00
Epoch 5: loss=1605.49, gamma=0.00
Epoch 10: loss=1548.21, gamma=0.00
Epoch 15: loss=1507.41, gamma=0.00
Epoch 20: loss=1486.27, gamma=0.00
Epoch 25: loss=1459.36, gamma=0.00
Epoch 30: loss=1453.04, gamma=0.00
Epoch 35: loss=1449.24, gamma=0.00
Epoch 40: loss=1443.42, gamma=0.00
Epoch 45: loss=1527.43, gamma=0.00
Epoch 50: loss=1445.20, gamma=0.00
Epoch 55: loss=1436.06, gamma=0.00
Epoch 60: loss=1438.46, gamma=0.00
Epoch 65: loss=1424.27, gamma=0.00
Epoch 70: loss=1435.56, gamma=0.00
Epoch 75: loss=1427.65, gamma=0.00
Epoch 80: loss=1427.57, gamma=0.00
Epoch 85: loss=1438.80, gamma=0.00
Epoch 90: loss=1441.27, gamma=0.00
Epoch 95: loss=1447.01, gamma=0.00
Fraction of possible edges in mask: 0.10453885401957788
Epoch 0: loss=2163.67, gamma=0.00
Epoch 5: loss=1854.95, gamma=20.35
Epoch 10: loss=1857.24, gamma=40.70
Epoch 15: loss=1861.98, gamma=61.05
Epoch 20: loss=1874.94, gamma=81.40
Epoch 25: loss=1863.21, gamma=81.40
Epoch 30: loss=1869.61, gamma=81.

VBox(children=(Label(value='0.009 MB of 0.016 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.597530…

0,1
alpha,████████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
dag,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
epoch_loss,█▂▂▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁▁▅▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄
gamma,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▅▆████████████████
is_prescreen,████████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
l1,▃▅▆▆▆▇▇▇▇▇▇▇████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
l2,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁████▇███▇██████▇████
min_dag_threshold,▁
n_edges_min_dag,▁

0,1
alpha,2e-05
dag,0.2666
epoch,195.0
epoch_loss,1861.03064
gamma,81.39939
is_prescreen,0.0
l1,1.09282
l2,171.37743
min_dag_threshold,0.46468
n_edges_min_dag,616.0


[34m[1mwandb[0m: Agent Starting Run: jac469dp with config:
[34m[1mwandb[0m: 	max_gamma: 113.28231662351838
[34m[1mwandb[0m: 	mv_flavor: nn
[34m[1mwandb[0m: 	s1_alpha: 0.0010838231341493013
[34m[1mwandb[0m: 	s1_beta: 0.0006353580069591618
[34m[1mwandb[0m: 	s2_alpha: 3.850717548064476e-05
[34m[1mwandb[0m: 	s2_beta: 0.0144955224084971


Using cpu




VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668226166560392, max=1.0…

Epoch 0: loss=2568.74, gamma=0.00
Epoch 5: loss=1660.76, gamma=0.00
Epoch 10: loss=1623.69, gamma=0.00
Epoch 15: loss=1613.62, gamma=0.00
Epoch 20: loss=1602.34, gamma=0.00
Epoch 25: loss=1594.21, gamma=0.00
Epoch 30: loss=1598.80, gamma=0.00
Epoch 35: loss=1594.77, gamma=0.00
Epoch 40: loss=1595.27, gamma=0.00
Epoch 45: loss=1612.90, gamma=0.00
Epoch 50: loss=1599.04, gamma=0.00
Epoch 55: loss=1599.20, gamma=0.00
Epoch 60: loss=1601.09, gamma=0.00
Epoch 65: loss=1608.46, gamma=0.00
Epoch 70: loss=1603.92, gamma=0.00
Epoch 75: loss=1609.50, gamma=0.00
Epoch 80: loss=1604.40, gamma=0.00
Epoch 85: loss=1612.56, gamma=0.00
Epoch 90: loss=1603.24, gamma=0.00
Epoch 95: loss=1609.56, gamma=0.00
Fraction of possible edges in mask: 0.031473797269789404
Epoch 0: loss=2052.94, gamma=0.00
Epoch 5: loss=1765.24, gamma=5.72
Epoch 10: loss=1770.46, gamma=11.44
Epoch 15: loss=1773.95, gamma=17.16
Epoch 20: loss=1773.76, gamma=22.89
Epoch 25: loss=1766.56, gamma=28.61
Epoch 30: loss=1777.89, gamma=34.

VBox(children=(Label(value='0.009 MB of 0.017 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.551988…

0,1
alpha,████████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
dag,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▄▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
epoch_loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂
gamma,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▃▃▃▄▄▅▅▆▆▇███████
is_prescreen,████████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
l1,▄▆▇▇▇▇▇█████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
l2,▁▁▁▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▆███████████████████
min_dag_threshold,▁
n_edges_min_dag,▁

0,1
alpha,4e-05
dag,0.20837
epoch,195.0
epoch_loss,1780.42182
gamma,80.09861
is_prescreen,0.0
l1,1.28189
l2,130.37527
min_dag_threshold,0.47715
n_edges_min_dag,3665.0


[34m[1mwandb[0m: Sweep Agent: Waiting for job.
[34m[1mwandb[0m: Job received.
[34m[1mwandb[0m: Agent Starting Run: 02to0jjh with config:
[34m[1mwandb[0m: 	max_gamma: 550.6643734227673
[34m[1mwandb[0m: 	mv_flavor: nn
[34m[1mwandb[0m: 	s1_alpha: 0.00036202833126873935
[34m[1mwandb[0m: 	s1_beta: 0.00013244359440484923
[34m[1mwandb[0m: 	s2_alpha: 0.0001021447314594023
[34m[1mwandb[0m: 	s2_beta: 0.06238442335414219
[34m[1mwandb[0m: W&B API key is configured. Use [1m`wandb login --relogin`[0m to force relogin


Using cpu




VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668131550007577, max=1.0…

Epoch 0: loss=2602.07, gamma=0.00
Epoch 5: loss=1585.85, gamma=0.00
Epoch 10: loss=1506.59, gamma=0.00
Epoch 15: loss=1467.04, gamma=0.00
Epoch 20: loss=1430.43, gamma=0.00
Epoch 25: loss=1412.22, gamma=0.00
Epoch 30: loss=1393.75, gamma=0.00
Epoch 35: loss=1385.10, gamma=0.00
Epoch 40: loss=1378.11, gamma=0.00
Epoch 45: loss=1374.35, gamma=0.00
Epoch 50: loss=1364.65, gamma=0.00
Epoch 55: loss=1360.43, gamma=0.00
Epoch 60: loss=1361.58, gamma=0.00
Epoch 65: loss=1357.47, gamma=0.00
Epoch 70: loss=1359.63, gamma=0.00
Epoch 75: loss=1360.56, gamma=0.00
Epoch 80: loss=1358.10, gamma=0.00
Epoch 85: loss=1355.00, gamma=0.00
Epoch 90: loss=1363.38, gamma=0.00
Epoch 95: loss=1350.32, gamma=0.00
Fraction of possible edges in mask: 0.16222321207117152
Epoch 0: loss=2286.29, gamma=0.00
Epoch 5: loss=1934.20, gamma=27.81
Epoch 10: loss=1933.88, gamma=55.62
Epoch 15: loss=1931.74, gamma=83.43
Epoch 20: loss=1936.61, gamma=111.25
Epoch 25: loss=1937.27, gamma=111.25
Epoch 30: loss=1939.69, gamma=1

VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
alpha,████████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
dag,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
epoch_loss,█▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▆▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄▄
gamma,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▅▆████████████████
is_prescreen,████████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
l1,▃▅▅▆▆▆▇▇▇▇▇▇▇███████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
l2,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆▆
min_dag_threshold,▁
n_edges_min_dag,▁

0,1
alpha,0.0001
dag,0.1909
epoch,195.0
epoch_loss,1939.88243
gamma,111.24533
is_prescreen,0.0
l1,4.2258
l2,188.47032
min_dag_threshold,0.42756
n_edges_min_dag,95.0


[34m[1mwandb[0m: Agent Starting Run: mwhkda91 with config:
[34m[1mwandb[0m: 	max_gamma: 823.4816832901289
[34m[1mwandb[0m: 	mv_flavor: nn
[34m[1mwandb[0m: 	s1_alpha: 0.0011316913410376202
[34m[1mwandb[0m: 	s1_beta: 2.050114940872295e-05
[34m[1mwandb[0m: 	s2_alpha: 1.2243951558967965e-05
[34m[1mwandb[0m: 	s2_beta: 0.009309430673225172


Using cpu




VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01666808683318474, max=1.0)…

Epoch 0: loss=2586.72, gamma=0.00
Epoch 5: loss=1666.00, gamma=0.00
Epoch 10: loss=1625.23, gamma=0.00
Epoch 15: loss=1604.40, gamma=0.00
Epoch 20: loss=1604.15, gamma=0.00
Epoch 25: loss=1588.51, gamma=0.00
Epoch 30: loss=1586.90, gamma=0.00
Epoch 35: loss=1582.39, gamma=0.00
Epoch 40: loss=1583.18, gamma=0.00
Epoch 45: loss=1587.13, gamma=0.00
Epoch 50: loss=1593.56, gamma=0.00
Epoch 55: loss=1594.81, gamma=0.00
Epoch 60: loss=1596.90, gamma=0.00
Epoch 65: loss=1599.18, gamma=0.00
Epoch 70: loss=1599.68, gamma=0.00
Epoch 75: loss=1614.14, gamma=0.00
Epoch 80: loss=1607.74, gamma=0.00
Epoch 85: loss=1615.36, gamma=0.00
Epoch 90: loss=1615.92, gamma=0.00
Epoch 95: loss=1620.73, gamma=0.00
Fraction of possible edges in mask: 0.03505947045671169
Epoch 0: loss=2042.12, gamma=0.00
Epoch 5: loss=1749.83, gamma=41.59
Epoch 10: loss=1753.82, gamma=83.18
Epoch 15: loss=1767.75, gamma=124.77
Epoch 20: loss=1773.66, gamma=166.36
Epoch 25: loss=1763.92, gamma=166.36
Epoch 30: loss=1776.25, gamma=

VBox(children=(Label(value='0.009 MB of 0.009 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
alpha,████████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
dag,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
epoch_loss,█▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▄▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂
gamma,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▅▆████████████████
is_prescreen,████████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
l1,▄▆▆▇▇▇▇▇████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
l2,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▅███████████████████
min_dag_threshold,▁
n_edges_min_dag,▁

0,1
alpha,1e-05
dag,0.13707
epoch,195.0
epoch_loss,1760.15598
gamma,166.35994
is_prescreen,0.0
l1,0.48338
l2,115.12058
min_dag_threshold,0.41357
n_edges_min_dag,9034.0


[34m[1mwandb[0m: Agent Starting Run: jtot6d1x with config:
[34m[1mwandb[0m: 	max_gamma: 781.8636584170466
[34m[1mwandb[0m: 	mv_flavor: nn
[34m[1mwandb[0m: 	s1_alpha: 0.00250379396522531
[34m[1mwandb[0m: 	s1_beta: 1.5190388810681324e-05
[34m[1mwandb[0m: 	s2_alpha: 6.195452509393702e-05
[34m[1mwandb[0m: 	s2_beta: 0.07308963748069325


Using cpu




VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.01666868783334697, max=1.0)…

Epoch 0: loss=2603.43, gamma=0.00
Epoch 5: loss=1745.48, gamma=0.00
Epoch 10: loss=1739.02, gamma=0.00
Epoch 15: loss=1710.35, gamma=0.00
Epoch 20: loss=1709.45, gamma=0.00
Epoch 25: loss=1707.27, gamma=0.00
Epoch 30: loss=1708.27, gamma=0.00
Epoch 35: loss=1718.61, gamma=0.00
Epoch 40: loss=1711.08, gamma=0.00
Epoch 45: loss=1717.64, gamma=0.00
Epoch 50: loss=1733.25, gamma=0.00
Epoch 55: loss=1730.48, gamma=0.00
Epoch 60: loss=1732.08, gamma=0.00
Epoch 65: loss=1732.07, gamma=0.00
Epoch 70: loss=1729.36, gamma=0.00
Epoch 75: loss=1733.86, gamma=0.00
Epoch 80: loss=1727.45, gamma=0.00
Epoch 85: loss=1733.35, gamma=0.00
Epoch 90: loss=1739.07, gamma=0.00
Epoch 95: loss=1736.67, gamma=0.00
Fraction of possible edges in mask: 0.007663761828259913
Epoch 0: loss=2222.37, gamma=0.00
Epoch 5: loss=1932.78, gamma=39.49
Epoch 10: loss=1936.42, gamma=78.98
Epoch 15: loss=1932.02, gamma=118.46
Epoch 20: loss=1934.17, gamma=118.46
Epoch 25: loss=1936.67, gamma=118.46
Epoch 30: loss=1935.70, gamma

VBox(children=(Label(value='0.009 MB of 0.021 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=0.453287…

0,1
alpha,████████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
dag,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
epoch,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
epoch_loss,█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▅▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃▃
gamma,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▆█████████████████
is_prescreen,████████████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
l1,▅▆▇▇▇███████████████▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
l2,▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇▇
min_dag_threshold,▁
n_edges_min_dag,▁

0,1
alpha,6e-05
dag,0.05561
epoch,195.0
epoch_loss,1942.08581
gamma,118.46419
is_prescreen,0.0
l1,0.48184
l2,175.24546
min_dag_threshold,0.25195
n_edges_min_dag,2429.0


[34m[1mwandb[0m: Agent Starting Run: v0evv4rl with config:
[34m[1mwandb[0m: 	max_gamma: 788.5012912260496
[34m[1mwandb[0m: 	mv_flavor: nn
[34m[1mwandb[0m: 	s1_alpha: 0.0001521174327874298
[34m[1mwandb[0m: 	s1_beta: 1.9629115357268193e-05
[34m[1mwandb[0m: 	s2_alpha: 4.915065785584007e-05
[34m[1mwandb[0m: 	s2_beta: 0.0058967113188155545


Using cpu




VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668109583163945, max=1.0…

Epoch 0: loss=2562.72, gamma=0.00
Epoch 5: loss=1551.51, gamma=0.00
Epoch 10: loss=1447.75, gamma=0.00
Epoch 15: loss=1385.96, gamma=0.00
Epoch 20: loss=1337.06, gamma=0.00
Epoch 25: loss=1303.06, gamma=0.00
Epoch 30: loss=1271.99, gamma=0.00
Epoch 35: loss=1254.63, gamma=0.00
Epoch 40: loss=1236.33, gamma=0.00
Epoch 45: loss=1222.99, gamma=0.00
Epoch 50: loss=1207.96, gamma=0.00
Epoch 55: loss=1204.39, gamma=0.00
Epoch 60: loss=1191.77, gamma=0.00
Epoch 65: loss=1190.78, gamma=0.00
Epoch 70: loss=1182.83, gamma=0.00
Epoch 75: loss=1185.46, gamma=0.00
Epoch 80: loss=1170.36, gamma=0.00
Epoch 85: loss=1168.90, gamma=0.00
Epoch 90: loss=1160.68, gamma=0.00
Epoch 95: loss=1157.12, gamma=0.00
Fraction of possible edges in mask: 0.4162181691833151
Epoch 0: loss=2241.35, gamma=0.00
Epoch 5: loss=1866.27, gamma=39.82
Epoch 10: loss=1893.37, gamma=79.65


[34m[1mwandb[0m: Ctrl + C detected. Stopping sweep.


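## Final Run

Refit SDCI on a 90/10 train/validation split with a single fixed hyperparameter setting. Stage 2 is trained for 200 epochs here, versus 100 in the sweep. Then compute the DAG threshold and export the learned adjacency matrices to `results/`.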

In [7]:
# Hyperparameters for the final SDCI fit. Unlike the sweep's `run_sdci`,
# stage 2 is trained for 200 epochs here instead of 100.
mv_flavor = "nn"
s1_alpha = 3e-4
s2_alpha = 2e-4
s1_beta = 1e-3
s2_beta = 1e-2
max_gamma = 150
SEED = 0  # seeds the train/val split and the torch/numpy global RNGs

# Without seeding, every re-run draws a different 90/10 split (and presumably a
# different model initialisation), so the exported adjacency matrices would not
# be reproducible. The split gets its own seeded generator so it does not depend
# on how much global RNG state was consumed before this cell.
np.random.seed(SEED)
torch.manual_seed(SEED)
split_generator = torch.Generator().manual_seed(SEED)
train_dataset, val_dataset = torch.utils.data.random_split(
    dataset, [0.9, 0.1], generator=split_generator
)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Using {device}")
model = SDCI(model_variance_flavor=mv_flavor)
model.train(
    train_dataset,
    device=device,
    log_wandb=True,
    verbose=False,
    stage1_kwargs={"n_epochs": 100, "alpha": s1_alpha, "beta": s1_beta, "n_epochs_check": 5},
    stage2_kwargs={"n_epochs": 200, "alpha": s2_alpha, "beta": s2_beta, "max_gamma": max_gamma, "n_epochs_check": 5},
)

# Held-out reconstruction loss. Materialise the validation subset once. Element 0
# is passed as the data and element 1 as `mask_interventions_oh`. Evaluate under
# no_grad so no autograd graph is built for this evaluation-only forward pass.
val_batch = val_dataset[:]
with torch.no_grad():
    val_rec_loss = model._model.reconstruction_loss(
        val_batch[0].to(device), mask_interventions_oh=val_batch[1].to(device)
    ).item()

min_dag_threshold = model.compute_min_dag_threshold()
# Sum of the (default-thresholded) adjacency matrix entries, i.e. the edge count
# if the matrix is binary.
n_edges_min_dag = model.get_adjacency_matrix().sum()


Using cpu


[34m[1mwandb[0m: Currently logged in as: [33mjustinhong[0m ([33mazizi-causal-perturb[0m). Use [1m`wandb login --relogin`[0m to force relogin


Epoch 0: loss=2568.90, gamma=0.00
Epoch 5: loss=1589.86, gamma=0.00
Epoch 10: loss=1518.75, gamma=0.00
Epoch 15: loss=1484.07, gamma=0.00
Epoch 20: loss=1456.98, gamma=0.00
Epoch 25: loss=1440.65, gamma=0.00
Epoch 30: loss=1425.02, gamma=0.00
Epoch 35: loss=1412.95, gamma=0.00
Epoch 40: loss=1400.56, gamma=0.00
Epoch 45: loss=1393.00, gamma=0.00
Epoch 50: loss=1391.28, gamma=0.00
Epoch 55: loss=1388.51, gamma=0.00
Epoch 60: loss=1393.43, gamma=0.00
Epoch 65: loss=1388.32, gamma=0.00
Epoch 70: loss=1406.35, gamma=0.00
Epoch 75: loss=1396.85, gamma=0.00
Epoch 80: loss=1399.29, gamma=0.00
Epoch 85: loss=1401.64, gamma=0.00
Epoch 90: loss=1396.68, gamma=0.00
Epoch 95: loss=1416.35, gamma=0.00
Fraction of possible edges in mask: 0.10504401691549066
Epoch 0: loss=2104.48, gamma=0.00
Epoch 5: loss=1770.53, gamma=3.77
Epoch 10: loss=1775.59, gamma=7.54
Epoch 15: loss=1783.65, gamma=11.31
Epoch 20: loss=1774.82, gamma=15.08
Epoch 25: loss=1786.33, gamma=18.84
Epoch 30: loss=1782.88, gamma=22.61

In [9]:
# Three views of the learned graph, exported below:
#   dag_thresh_adj_mtx   - model.get_adjacency_matrix() with its default thresholding
#   no_thresh_adj_mtx    - the raw matrix returned with threshold=False
#   fixed_thresh_adj_mtx - 0/1 matrix keeping entries strictly above FIXED_EDGE_THRESHOLD
# NOTE(review): 0.3 is a hand-picked cutoff and lies below the min_dag_threshold
# displayed later (~0.56), so this graph is presumably not guaranteed to be acyclic.
# Confirm this is intended.
FIXED_EDGE_THRESHOLD = 0.3

dag_thresh_adj_mtx = model.get_adjacency_matrix()
no_thresh_adj_mtx = model.get_adjacency_matrix(threshold=False)
fixed_thresh_adj_mtx = (no_thresh_adj_mtx > FIXED_EDGE_THRESHOLD).astype(int)

In [10]:
# Value returned by model.compute_min_dag_threshold() in the Final Run cell.
min_dag_threshold

0.5586318969726562

In [11]:
# Sum of model.get_adjacency_matrix() entries (default thresholding) from the
# Final Run cell, i.e. the edge count if that matrix is binary.
n_edges_min_dag

1574

In [13]:
from pathlib import Path

# Write each adjacency matrix as comma-separated text under results/.
# np.savetxt does not create missing parent directories, so make sure the
# directory exists first; otherwise a fresh checkout fails here.
RESULTS_DIR = Path("results")
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

adjacency_exports = {
    "dag_thresh_adj_mtx.csv": dag_thresh_adj_mtx,
    "no_thresh_adj_mtx.csv": no_thresh_adj_mtx,
    "fixed_thresh_adj_mtx.csv": fixed_thresh_adj_mtx,
}
for filename, adj_mtx in adjacency_exports.items():
    np.savetxt(RESULTS_DIR / filename, adj_mtx, delimiter=",")