In [2]:
#!/usr/bin/env python
# coding: utf-8


import os
import torch, pyro, numpy as np
torch.set_default_tensor_type(torch.cuda.FloatTensor)

import swyft
import click


DEVICE = 'cuda'

from utils import *
from network import UNET, CustomHead

In [3]:
import optuna
import joblib

In [3]:
# @click.command()
# @click.option("--m",    type=int, default = 12,  help="Exponent of subhalo mass.")
# @click.option("--nsub", type=int, default = 1,   help="Number of subhaloes.")
# @click.option("--nsim", type=int, default = 100, help="Number of simulations to run.")

# @click.option("--lr",         type=float, default = 1e-3, help="Learning rate.")
# @click.option("--factor",     type=float, default = 1e-1, help = "Factor of Scheduler")
# @click.option("--patience",   type=int,   default = 5,    help = "Patience of Scheduler")
# @click.option("--max_epochs", type=int,   default = 30,   help = "Max number of epochs.")

In [4]:
m = 9
nsub = 3
nsim = 10000

lr = 1e-3
factor = 1e-1
patience = 5
max_epochs = 3

In [5]:
SYSTEM_NAME = "ngc4414"
RUN = f'_m{m}_nsub{nsub}_nsim{nsim}'
assert os.path.exists(f'/nfs/scratch/eliasd/store{RUN}.sync')
SIM_PATH = f'/nfs/scratch/eliasd/store{RUN}.zarr' 
print('run', RUN)

run _m9_nsub3_nsim10000


In [12]:
RUN

'_m9_nsub3_nsim10000'

In [6]:
# Set utilities
store = swyft.DirectoryStore(path=SIM_PATH)
print(f'Store has {len(store)} simulations')

torch.set_default_tensor_type(torch.cuda.FloatTensor)  # HACK
CONFIG = get_config(SYSTEM_NAME, str(nsub), str(m))
torch.set_default_tensor_type(torch.FloatTensor)

prior, uv = get_prior(CONFIG)

Loading existing store.
Store has 9907 simulations


In [7]:
# Set up posterior
idx = 0
img_0 = store[idx][0]['image']
L1, L2 = torch.tensor(img_0.shape)
assert L1 == L2
L = L1.item()
print(f'L = {L}')

torch.set_default_tensor_type(torch.FloatTensor)
dataset = swyft.Dataset(100, prior, store)#, simhook = noise)
marginals = [i for i in range(L**2)]
post = swyft.Posteriors(dataset)

L = 40


In [None]:
L

In [8]:
# Train

def objective(trail):
    lr       = trail.suggest_float('lr', 1e-5, 1e-1, log = True)
    factor   = trail.suggest_float('factor', 1e-4, 1e-1, log = True)
    patience = trail.suggest_int('patience', 2, 5)

    save_name, save_path = get_name(RUN, lr, factor, patience, 'posts_gridsearch')
    print(f'Training {save_name}!')

    torch.set_default_tensor_type(torch.FloatTensor)
    post = swyft.Posteriors(dataset)
    post.add(marginals, device = DEVICE, head = CustomHead, tail = UNET)
    post.train(marginals, max_epochs = max_epochs,
               optimizer_args = dict(lr=lr),
               scheduler_args = dict(factor = factor, patience = patience)
              )

    epoch, tl, vl = get_losses(post)
    post.save(save_path)
    
    return vl[-1]

In [9]:
study = optuna.create_study(direction="minimize")
study.optimize(objective, n_trials=3)

[32m[I 2021-12-06 12:38:37,388][0m A new study created in memory with name: no-name-aabb64df-a044-47bc-b338-127202a2f5e5[0m


Training UNet_m9_nsub3_nsim10000_lr-1.1029432224148172_fac-2.410522606112229_pat4.pt!
Training: lr=0.079, Epoch=3, VL=3.837e+15


[32m[I 2021-12-06 12:38:39,961][0m Trial 0 finished with value: 3836944198926336.0 and parameters: {'lr': 0.07889632561623462, 'factor': 0.0038857727101072157, 'patience': 4}. Best is trial 0 with value: 3836944198926336.0.[0m


Training UNet_m9_nsub3_nsim10000_lr-1.767950070640385_fac-3.759193575348488_pat4.pt!
Training: lr=0.017, Epoch=3, VL=3.081e+12


[32m[I 2021-12-06 12:38:42,297][0m Trial 1 finished with value: 62066950144.0 and parameters: {'lr': 0.017062785427686106, 'factor': 0.00017410306817631252, 'patience': 4}. Best is trial 1 with value: 62066950144.0.[0m


Training UNet_m9_nsub3_nsim10000_lr-2.401718425175746_fac-1.2027074914660374_pat5.pt!
Training: lr=0.004, Epoch=3, VL=1.847e+06


[32m[I 2021-12-06 12:38:44,592][0m Trial 2 finished with value: 2377.765625 and parameters: {'lr': 0.003965350444233987, 'factor': 0.06270360474300139, 'patience': 5}. Best is trial 2 with value: 2377.765625.[0m


In [10]:
joblib.dump(study, 'studies/study.pkl')

['studies/study.pkl']