In [1]:
#!/usr/bin/env python
# coding: utf-8


import os
import torch, pyro, numpy as np
# Make CUDA float tensors the process-wide default tensor type.
# NOTE(review): this runs before `import swyft`; presumably the order is
# intentional (later cells toggle the default again) -- confirm before moving it.
torch.set_default_tensor_type(torch.cuda.FloatTensor)

import swyft
import click


# Device string handed to `post.add(..., device = DEVICE, ...)` in the training cell.
DEVICE = 'cuda'

# NOTE(review): the wildcard import presumably supplies get_sim_path, get_config,
# get_prior and get_post_path used in later cells -- confirm against utils.py.
# Explicit imports would make the origin of these names traceable.
from utils import *
# Custom network components passed to `post.add` as `head` and `tail`.
from network import CustomHead, CustomTail

In [2]:
# --- Simulation selection: forwarded to get_sim_path / get_config ---
m, nsub, nsim = 1, 3, 100

# Passed to the network tail through tail_args["nmbins"].
nmbins = 2

# --- Training settings ---
lr = 1e-3                      # optimizer_args["lr"]
factor, patience = 1e-1, 5     # scheduler_args["factor"], scheduler_args["patience"]
max_epochs = 4                 # post.train(max_epochs=...)

In [3]:
# System identifier handed to get_sim_path and get_config
# (TODO: expose as a click option instead of hard-coding it here).
system_name = 'ngc4414'

In [4]:
# Resolve the simulation store location for this configuration and open it.
sim_name, sim_path = get_sim_path(m, nsub, nsim, system_name)
store = swyft.DirectoryStore(path=sim_path)
print(f'Store has {len(store)} simulations.')

# HACK: get_config is run with CUDA float tensors as the default tensor type.
# The reset to CPU float tensors sits in `finally`, so an exception raised by
# get_config cannot leave the kernel with a CUDA default that would silently
# affect every cell executed afterwards.
torch.set_default_tensor_type(torch.cuda.FloatTensor)
try:
    CONFIG = get_config(system_name, str(nsub), str(m))
finally:
    torch.set_default_tensor_type(torch.FloatTensor)

# Prior plus the `lows` / `highs` values that the training cell passes on via tail_args.
prior, uv, lows, highs = get_prior(CONFIG)
# `nx` entry of the config definitions; the marginal indices later span L**2 values.
L = CONFIG.kwargs["defs"]["nx"]
print(f'Image has L = {L}.')

Store _M_m1_nsub3_nsim100 exists!
Loading existing store.
Store has 118 simulations.
Image has L = 40.


In [5]:
# Build the dataset and the posterior container on CPU-default tensors
torch.set_default_tensor_type(torch.FloatTensor)
dataset = swyft.Dataset(nsim, prior, store)  # optional `simhook = noise` is left disabled
marginals = list(range(L**2))  # indices 0 .. L**2 - 1
post = swyft.Posteriors(dataset)

In [6]:
# Fit the marginal estimators and write the trained posteriors to disk
post_name, post_path = get_post_path(sim_name, nmbins, lr, factor, patience)
print(f'Training {post_name}!')

torch.set_default_tensor_type(torch.FloatTensor)
post = swyft.Posteriors(dataset)

# Register the marginals with the custom head/tail networks; the tail receives
# the bin count together with the lows/highs values returned by get_prior.
post.add(
    marginals,
    device=DEVICE,
    tail_args={'nmbins': nmbins, 'lows': lows, 'highs': highs},
    head=CustomHead,
    tail=CustomTail,
)

# Optimise for at most `max_epochs` epochs with the configured optimizer and
# scheduler settings.
post.train(
    marginals,
    max_epochs=max_epochs,
    optimizer_args={'lr': lr},
    scheduler_args={'factor': factor, 'patience': patience},
)

post.save(post_path)

Training UNet_M_m1_nsub3_nsim100_nmbins2_lr-3.0_fac-1.0_pat5.pt!
tensor([1, 2, 2, 2, 2, 1, 1, 2, 2, 2, 2, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2, 1, 1, 1,
        1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2, 1, 2, 2, 1, 1, 2, 1, 2, 2, 1, 1,
        1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 2, 1, 1, 2, 1, 2, 2, 1, 1, 2,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 2, 1, 2, 1, 1, 1, 2,
        1, 1, 2, 1, 2, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1,
        1, 1, 2, 2, 2, 2, 1, 1, 2, 2, 2, 2, 2, 2, 1, 1, 2, 2, 2, 2, 1, 1, 2, 2,
        2, 2, 1, 2, 2, 2, 2, 2, 1, 2, 2, 2, 1, 2, 1, 1, 2, 2, 1, 2, 1, 1, 2, 2,
        2, 2, 2, 2, 1, 1, 2, 2, 2, 2, 1, 1, 1, 2, 1, 2, 1, 1, 1, 2, 1, 2, 1, 1,
        1, 1, 2, 1, 2, 2, 1, 1, 2, 1, 2, 2, 1, 1, 1, 1, 2, 2, 1, 1, 1, 1, 2, 2,
        2, 1, 1, 2, 2, 2, 2, 1, 1, 2, 2, 2, 2, 2, 1, 1, 2, 2, 2, 2, 1, 1, 2, 2,
        1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 1, 2, 2, 2, 1, 1, 1, 2, 2, 2, 1,
        2, 2, 1, 2, 2, 1, 2, 2, 1, 2, 2, 1, 1, 2, 1, 1,

In [7]:
# Completion marker: reached only when every cell above ran without raising.
print("Done!")

Done!
