In [1]:
#!/usr/bin/env python
# coding: utf-8


import os
import torch, pyro, numpy as np
# Make CUDA float tensors the default at import time; later cells switch back to
# torch.FloatTensor (CPU) around specific calls.
torch.set_default_tensor_type(torch.cuda.FloatTensor)

import swyft
import click
from swyft.networks.normalization import OnlineNormalizationLayer


from torch import tensor
import torch.nn as nn
import torchvision.transforms.functional as TF

# Device used for the posterior networks and the query tensors in the analysis cells.
DEVICE = 'cuda'

from utils import *
from network import CustomTail, CustomHead, Mapping

import sys
# Local plotting helpers live outside this repository. The location can be
# overridden with the ELIAS_UTILS environment variable; otherwise the original
# path is used. Skip the append on re-runs so sys.path does not keep growing.
_UTILS_DIR = os.environ.get('ELIAS_UTILS', '/home/eliasd/lensing/elias_utils')
if _UTILS_DIR not in sys.path:
    sys.path.append(_UTILS_DIR)
from plotting import *

# Keyword arguments for image plots: extent spans -2.5..2.5 on both axes and the
# origin is placed at the lower-left corner.
imkwargs = dict(extent=(-2.5, 2.5, -2.5, 2.5), origin='lower')

# Which simulation store to use; these three values are encoded in the run name.
m, nsub, nsim = 1, 3, 100

# Optimiser learning rate and LR-scheduler (reduction factor, patience) settings.
lr, factor, patience = 1e-3, 1e-1, 5
max_epochs = 30

# Number of mass values (bin centres) used to build query coordinates / maps.
n_m = 2

In [2]:
SYSTEM_NAME = "ngc4414"
# Run tag identifying the simulation store: mass config, number of subhalos, sims.
RUN = f'_M_m{m}_nsub{nsub}_nsim{nsim}'
STORE_DIR = '/nfs/scratch/eliasd'
SIM_PATH = f'{STORE_DIR}/store{RUN}.zarr'
# The store is expected to have a companion `.sync` file; fail loudly (even under
# `python -O`, which strips asserts) if it is missing.
_SYNC_PATH = f'{STORE_DIR}/store{RUN}.sync'
if not os.path.exists(_SYNC_PATH):
    raise FileNotFoundError(f'No simulation store for run {RUN!r}: missing {_SYNC_PATH}')
print('run', RUN)

run _M_m1_nsub3_nsim100


In [3]:
# Set utilities
store = swyft.DirectoryStore(path=SIM_PATH)
print(f'Store has {len(store)} simulations')
L = store[0][0]['image'].shape[1]

torch.set_default_tensor_type(torch.cuda.FloatTensor)  # HACK
CONFIG = get_config(SYSTEM_NAME, str(nsub), str(m))
torch.set_default_tensor_type(torch.FloatTensor)

Loading existing store.
Store has 83 simulations


In [4]:
# HACK: the posterior-predictive trace is built with CUDA default tensors; the
# `finally` guarantees the CPU default is restored even if CONFIG.ppd() raises.
torch.set_default_tensor_type(torch.cuda.FloatTensor)
try:
    ppd = CONFIG.ppd()['model_trace'].nodes
finally:
    torch.set_default_tensor_type(torch.FloatTensor)

In [5]:
# Prior for swyft built from CONFIG; get_prior also returns `uv`.
# NOTE(review): `uv` is not used anywhere in this notebook — confirm it is needed.
prior, uv = get_prior(CONFIG)

In [6]:
torch.set_default_tensor_type(torch.FloatTensor)
# Use `nsim` (the value encoded in RUN) rather than a duplicated literal, so the
# dataset size cannot silently drift from the run name.
# TODO: a `simhook=noise` argument was previously considered here.
dataset = swyft.Dataset(nsim, prior, store)
# One marginal index per image pixel (L**2 of them).
marginals = list(range(L**2))
# NOTE(review): the training cell re-creates `post`; this one is only a placeholder.
post = swyft.Posteriors(dataset)

In [7]:
# Derive the checkpoint file name and path from the run tag and training hyperparameters.
save_name, save_path = get_name(RUN, lr, factor, patience)

In [8]:
# Train
print(f'Training {save_name}!')
torch.set_default_tensor_type(torch.FloatTensor)

# Start from a fresh Posteriors object each time this cell runs.
post = swyft.Posteriors(dataset)
# NOTE(review): passing tail_args=dict(n_m=n_m) produced
# "TypeError: __init__() got an unexpected keyword argument 'n_m'" (saved output),
# so the custom tail is built with its default arguments.
post.add(marginals, device=DEVICE, head=CustomHead, tail=CustomTail)

optimizer_args = dict(lr=lr)
scheduler_args = dict(factor=factor, patience=patience)
post.train(
    marginals,
    max_epochs=max_epochs,
    optimizer_args=optimizer_args,
    scheduler_args=scheduler_args,
)

post.save(save_path)
print('Done!')

Training UNet_M_m1_nsub3_nsim100_lr-3.0_fac-1.0_pat5.pt!


TypeError: __init__() got an unexpected keyword argument 'n_m'

In [None]:
# NOTE(review): the loaded posterior is not assigned to anything, so this cell only
# displays its repr; the next cell loads `save_path` again (without `dataset`) into `post`.
swyft.Posteriors.load(save_path, dataset = dataset)

In [None]:
# Reload the trained posterior from disk and move it to DEVICE, so the analysis
# below does not require re-running the training cell.
print(f'Loading {save_name}!')
post = swyft.Posteriors.load(save_path).to(DEVICE)


In [None]:
# Plot the training/validation losses of `post` (helper from a star-imported local module).
plot_losses(post)

# Analyze

In [None]:
# torch.set_default_tensor_type(torch.cuda.DoubleTensor)


In [None]:

def get_coords(n_m, L):
    """Build flattened (x, y, m) query coordinates on an ``L x L`` pixel grid.

    Args:
        n_m: number of mass values; they are the centres of ``n_m`` equal-width
            bins on ``[0, 1]``.
        L: grid side length; x and y take ``L`` evenly spaced values on
            ``[0, (L - 1) / L]``.

    Returns:
        A list of ``n_m + 1`` tensors. The first is an empty ``(1, 0)`` tensor
        (no coordinates). Each following tensor has shape ``(1, 3 * L * L)`` and
        holds every grid pixel as an ``(x, y, m)`` triple for one mass value,
        with x varying fastest.
    """
    axis = torch.linspace(0, (L - 1) / L, L)
    xs, ys = torch.meshgrid(axis, axis, indexing='xy')
    xs, ys = xs.flatten(), ys.flatten()

    # Odd entries of a (2*n_m + 1)-point grid on [0, 1] are the n_m bin centres.
    bin_centres = np.linspace(0, 1, 2 * n_m + 1)[1::2]

    coords = [tensor(()).view(1, -1)]
    for centre in bin_centres:
        masses = torch.full((L * L,), centre)
        triples = torch.stack((xs, ys, masses), dim=1)  # (L*L, 3): one row per pixel
        coords.append(triples.reshape(1, -1))
    return coords

# Query coordinates: the empty set first, then one full-grid set per mass value.
coords = get_coords(n_m, L)

In [None]:
# Sanity check: render the map that `Mapping` produces for each coordinate set
# (the empty set first, then one per mass value) and show them side by side.
Map = Mapping(n_m, L)
plots = torch.cat([
    Map.coord_to_map(coord.to(DEVICE)).cpu().squeeze()
    for coord in coords
])
plt_imshow(plots, 3, cbar=True)

In [None]:
def get_net(post):
    """Return ``(head, tail)`` of the first ratio estimator stored in ``post``.

    Both modules are switched to evaluation mode via ``.eval()`` before being
    returned. Only the first entry of ``post._ratios`` is used; an empty
    ``_ratios`` raises ``IndexError``.
    """
    estimators = list(post._ratios.values())
    ratio_estimator = estimators[0]
    return ratio_estimator.head.eval(), ratio_estimator.tail.eval()

# Head/tail networks of the trained estimator, in eval mode.
head, tail = get_net(post)

In [None]:
# For a few stored simulations, compare the true subhalo map with the learned
# ratio maps (one per coordinate set in `coords`).

# NOTE(review): lows/highs look like (x, y, mass) prior bounds. x/y match the
# +-2.5 plot extent in `imkwargs`; confirm the mass range (10, 12) against CONFIG.
lows = [-2.5, -2.5, 10]
highs = [2.5, 2.5, 12]
n_examples = 3
rng = np.random.default_rng(0)  # fixed seed: the same examples are plotted on every re-run
example_ids = rng.choice(len(store), size=min(n_examples, len(store)), replace=False)

# Inference only: no autograd graph, which also makes the .numpy() calls below legal.
with torch.no_grad():
    for idx in example_ids:
        sample = store[idx]
        obs_raw, v_raw = sample[0], sample[1]
        # Build a new observation dict with a batch axis on the image instead of
        # mutating the dict returned by the store.
        obs0 = {**obs_raw, 'image': tensor(obs_raw['image']).unsqueeze(0)}
        v0 = tensor(v_raw).unsqueeze(0)

        f = head(obs0).to(DEVICE, dtype=torch.float)

        ratios = np.zeros((n_m + 1, L, L))
        for j, coord in enumerate(coords):
            params = coord.to(DEVICE, dtype=torch.float)
            logratio = tail(f, params)
            # logratio may live on the GPU: move it to the host before converting.
            ratio = np.exp(logratio.detach().cpu().numpy()).reshape(n_m + 1, L, L)
            # Keep only the map that belongs to coordinate set j.
            ratios[j] = ratio[j]

        u0 = Map.coord_vu(v0, lows, highs)
        target = Map.coord_to_map(u0).squeeze().detach().cpu().numpy()
        plots = np.concatenate((target, ratios))
        plt_imshow(plots, 2, cbar=True)
        print()