In [1]:
%load_ext autoreload
%autoreload 2

import torch, click, numpy as np

import swyft
from utils import *
from plot import *

# Device string used by the `.to(DEVICE, ...)` calls below and passed as `device=`
# when the trained estimator is loaded.
DEVICE = 'cuda'


from torch import tensor

# Extra keyword arguments forwarded to `plt_imshow` (via `**imkwargs`) for the final
# posterior figure; presumably handed on to matplotlib's imshow -- TODO confirm in plot.py.
imkwargs = dict(extent=(-2.5, 2.5, -2.5, 2.5), origin='lower')


In [2]:
# Run configuration.
# `m`, `nsub`, `nsim` and `system_name` select the simulation store and config;
# `nmbins`, `lr`, `factor` and `patience` select the trained estimator file.
m = 1
nsub = 3

# Only one value of `nsim` is ever effective. An alternative `nsim = 200` used to be
# assigned here and then immediately overwritten; it is kept only as a comment so
# the cell has a single source of truth.
# nsim = 200
nsim = 25000

nmbins = 4

lr = 1e-3
factor = 1e-1
patience = 5

system_name = "ngc4414"

In [None]:
# Load the simulation store, the configuration and the prior for the selected run.
# `get_sim_path`, `get_config` and `get_prior` are not defined in this notebook
# (presumably they come from the `utils` star import).
sim_name, sim_path = get_sim_path(m, nsub, nsim, system_name)
store = swyft.Store.load(path=sim_path)
print(f'Store has {len(store)} simulations.')

# Note that `nsub` and `m` are passed as strings here.
config = get_config(system_name, str(nsub), str(m))

# `lows` / `highs` are indexed per parameter further down; presumably they are the
# lower / upper prior bounds -- TODO confirm against utils.get_prior.
prior, n_pars, lows, highs = get_prior(config)
# `L` is the configured `nx`; later cells use it as the side length of (L, L) images.
L = config.kwargs["defs"]["nx"]
print(f'Image has L = {L}.')

# NOTE(review): the recorded output reports 24938 stored simulations while
# nsim = 25000 is requested here -- confirm swyft.Dataset handles that as intended.
dataset = swyft.Dataset(nsim, prior, store)#, simhook = noise)

Store _M_m1_nsub3_nsim25000 exists!
Loading existing store.
Store has 24938 simulations.
Image has L = 40.


In [None]:
# Load the trained marginal ratio estimator matching the configuration above.
mre_name, mre_path = get_mre_path(sim_name, nmbins, lr, factor, patience)
print(f'Loading {mre_name}!')

# HACK: while the network is built and the checkpoint is loaded, CUDA float tensors
# are made the process-wide default. The try/finally guarantees that the CPU default
# is restored even when construction or loading raises, so a failed cell cannot
# leave the kernel silently creating CUDA tensors everywhere.
torch.set_default_tensor_type(torch.cuda.FloatTensor)
try:
    marginal_indices, _ = swyft.utils.get_corner_marginal_indices(n_pars)

    network = get_custom_marginal_classifier(
        observation_transform = CustomObservationTransform('image', {'image': (L, L)}),
        marginal_indices = marginal_indices,
        L = L,
        nmbins = nmbins,
        lows = lows,
        highs = highs,
        marginal_classifier = CustomMarginalClassifier,
    )

    mre = swyft.MarginalRatioEstimator.load(
        network=network,
        device=DEVICE,
        filename=mre_path,
    )
finally:
    torch.set_default_tensor_type(torch.FloatTensor)

In [None]:
# Plot the losses of the loaded estimator. `plot_losses` is not defined in this
# notebook (presumably it comes from the `plot` / `utils` star imports).
plot_losses(mre)

In [None]:
def get_coords(nmbins, L):
    """Build the flattened coordinate tensors the network is evaluated on.

    Reads the notebook globals `lows`, `highs` and `DEVICE`.

    The spatial grid takes the first `L` of `L + 1` evenly spaced points between
    `lows[0]` and `highs[0]`; the same 1-D grid is used for both axes. The mass
    axis is `2 * nmbins + 1` log-spaced points between `10**lows[-1]` and
    `10**highs[-1]`: even-indexed points are returned as `m_edges`, odd-indexed
    points as `m_centers`.

    Returns:
        coords: `[coord_empty] + coord_full`.
        coord_empty: float tensor of shape (1, 0) on `DEVICE`.
        coord_full: one float tensor of shape (1, 3 * L * L) on `DEVICE` per mass
            centre, holding (x, y, m) triples for every grid point, flattened.
        m_centers: the `nmbins` odd-indexed mass points.
        m_edges: the `nmbins + 1` even-indexed mass points.
    """
    axis = torch.linspace(lows[0], highs[0], L+1)[:-1]
    x, y = torch.meshgrid(axis, axis, indexing = 'xy')
    x_flat, y_flat = x.flatten(), y.flatten()

    mass_nodes = torch.logspace(lows[-1], highs[-1], 2*nmbins+1)
    m_centers, m_edges = mass_nodes[1::2], mass_nodes[0::2]
    mass_columns = [torch.full((L*L,), m_i) for m_i in m_centers]

    coord_empty = torch.tensor((), device = DEVICE, dtype = torch.float).reshape(1, -1)

    # Stack to (3, L*L), transpose to (L*L, 3), then flatten to a single row.
    coord_full = [
        torch.stack((x_flat, y_flat, mass_column))
        .transpose(0, 1)
        .reshape(1, -1)
        .to(DEVICE, dtype = torch.float)
        for mass_column in mass_columns
    ]

    coords = [coord_empty] + coord_full
    return coords, coord_empty, coord_full, m_centers, m_edges

def get_obs(store, i = -1):
    """Fetch one stored simulation as batched tensors on `DEVICE`.

    Args:
        store: indexable container whose items are `(observation, parameters)`
            pairs; the observation is a mapping with an 'image' entry.
        i: index of the simulation to fetch; `-1` (the default) draws a random
            index in `[0, len(store))`.

    Returns:
        obs0: a new dict copied from the stored observation, with 'image'
            replaced by a float tensor on `DEVICE` with a leading batch axis.
        v0: the stored parameters as a tensor on `DEVICE` with a leading batch axis.
        obs0_i: the index that was actually used.
    """
    obs0_i = np.random.randint(0, len(store)) if i == -1 else i
    print(f'i = {obs0_i}')

    # Index the store once: each lookup may be a separate read of the backing data.
    stored_obs, v0 = store[obs0_i][:2]

    # Work on a shallow copy so the object handed out by the store is never mutated.
    obs0 = dict(stored_obs)
    obs0['image'] = tensor(obs0['image']).unsqueeze(0).to(DEVICE, dtype = torch.float)
    v0 = tensor(v0).unsqueeze(0).to(DEVICE)

    return obs0, v0, obs0_i

coords, coord_empty, coord_full, m_centers, m_edges = get_coords(nmbins, L)

# Panel titles: one per mass bin, repeated for the 'no halo' and 'halo' classes
# (all 'no halo' titles first, then all 'halo' titles).
# NOTE(review): the bin edges here read index 2 of `lows` / `highs`, while
# `get_coords` reads index -1 -- confirm both refer to the same (mass) parameter.
mbins = np.linspace(lows[2], highs[2], nmbins + 1)
title_mbins = [f'mass {lo} - {hi}' for lo, hi in zip(mbins[:-1], mbins[1:])]
title_halos = [
    f'{label} {mass_title}'
    for label in ['no halo', 'halo']
    for mass_title in title_mbins
]

In [None]:
# Pick a fixed stored simulation (index 7768) so the figures are repeatable.
obs0, v0, obs0_i = get_obs(store, i = 7768)

# Show the network's parameter transform of the true parameters, one panel per
# entry of `title_halos`.
print('target')
target_plots = mre.network.parameter_transform(v0).squeeze()
plt_imshow(target_plots, nrows = 2, titles = title_halos, cbar = True)

# `plots` collects the panels produced in this cell.
plots = [target_plots]

# print('v0')
# v0_plots = mre.network(obs0, v0).view(nmbins*2, L, L)
# plt_imshow(v0_plots, nrows = 2, titles = title_halos, cbar = True)
# plots.append(v0_plots)

# for coord, p in zip(coords, ['empty'] + title_mbins):
#     print(p)
#     coords_plots = mre.network(obs0, coord).view(nmbins*2, L, L)
#     plots.append(coords_plots)
#     plt_imshow(coords_plots, nrows = 2, titles = title_halos, cbar = True)
    
# plots = torch.cat(plots)
# print(plots.shape)

In [None]:
# Evaluate the network for the observation chosen in the previous cell and assemble
# one ratio map per panel (the first nmbins panels come from the empty coordinate
# set, the remaining nmbins from the per-mass-bin coordinate sets).
obs0, v0, obs0_i = get_obs(store, i = obs0_i)

# Inference only: no_grad() keeps the outputs free of autograd history and .cpu()
# moves them off the GPU. Both are required before a torch tensor can be written
# into a NumPy array; without them the conversion raises when DEVICE is 'cuda'.
with torch.no_grad():
    # Kept as a (CPU) torch tensor because a later cell calls `target[...].numpy()`.
    target = mre.network.parameter_transform(v0).squeeze().cpu()

    ratios = np.zeros((nmbins*2, L, L))

    plots = mre.network(obs0, coord_empty).view(nmbins*2, L, L)
    ratios[:nmbins] = plots[:nmbins].cpu().numpy()

    # For mass bin i, evaluate at that bin's coordinates and keep only panel
    # i + nmbins of the output.
    for i, coord in enumerate(coord_full):
        ratio = mre.network(obs0, coord).view(nmbins*2, L, L)
        ratios[i+nmbins] = ratio[i+nmbins].cpu().numpy()

# The maps are exponentiated before use (presumably the network returns
# log-ratios -- TODO confirm).
ratios = np.exp(ratios)

titles = [f'{i} {j}' for i in ['target', 'ratio '] for j in title_halos ]

plots = np.concatenate((target.numpy(), ratios))
# plt_imshow(plots, nrows = 4, titles = titles, cbar = True, tl = True, y = 18, **imkwargs)
print()

In [None]:
# Old stuff about weighing the frequency of halos in mass bins

# from pyrofit.lensing.distributions import get_default_shmf
# low, high = lows[-1], highs[-1]
# z_lens = config.kwargs['defs']['z_lens']
# shmf = get_default_shmf(z_lens = z_lens, log_range = (low, high))

# m_weight = tensor([torch.diff(
#     shmf.cdf(torch.pow(10, tensor([mbins[i], mbins[i+1]])))
#     ) for i in range(len(mbins)-1)])

# mbins, m_weight

# m_weights = np.tile(m_weight, 2)

# plt.plot(torch.logspace(low, high), shmf.cdf(torch.logspace(low, high)))
# plt.plot(torch.logspace(low, high), 1 - shmf.cdf(torch.logspace(low, high)))
# plt.xscale('log')

In [None]:
# Prior weights of the two classes: `nsub` divided by the number of pixels (L*L) and
# by the number of mass bins gives the 'halo' weight, its complement the 'no halo'
# weight (presumably the chance that one pixel / mass-bin cell hosts a subhalo --
# TODO confirm the intended model).
# The scalar gets its own name so it does not overwrite the global `prior` object
# returned by `get_prior` and passed to `swyft.Dataset` earlier in the notebook.
halo_prior = nsub/(L*L)/nmbins
prior0 = 1 - halo_prior
prior1 = halo_prior

prior0, prior1

In [None]:
# One prior weight per panel: `prior0` for the first nmbins panels, `prior1` for the
# last nmbins panels.
priors = np.repeat([prior0, prior1], nmbins)

# Broadcast the per-panel weight over both image axes.
posts = ratios * priors[:, None, None]

In [None]:
# plots = [np.sum(posts, axis = 0)]
# plt_imshow(plots, cbar = True, y = 4)

In [None]:
# Split the panels into (class, mass bin, y, x) and add the two classes together,
# giving one total per mass bin and pixel ...
posts_sum = posts.reshape(2, nmbins, L, L).sum(axis = 0)
# ... then repeat that total so it lines up with the 2 * nmbins panels of `posts`.
posts_sum = np.concatenate((posts_sum, posts_sum))
# plt_imshow(posts_sum, 2, cbar = True, tl = True)

In [None]:
# Divide every panel of `posts` by the matching entry of `posts_sum`, element-wise.
posts_norm = posts / posts_sum
# plt_imshow( posts_norm , 2, cbar = True, tl = True)
# posts_norm = posts_norm[nmbins:]

## Final results

In [None]:
# Plot the prior-weighted ratios `posts` (nrows = 2), forwarding the shared `imkwargs`.
plt_imshow(posts, nrows = 2, cbar = True, tl = True, **imkwargs)

In [None]:
# Stack the last nmbins target panels (the 'halo' panels in `title_halos` order) on
# top of the matching normalised posteriors.
# `target` is a torch tensor that may live on the GPU or carry autograd history;
# `.detach().cpu()` makes the NumPy conversion valid in either case.
plots = np.concatenate((target[nmbins:].detach().cpu().numpy(), posts_norm[nmbins:]))
plt_imshow(plots, 2, cbar = True, tl = True)

In [None]:
# Plot only the last nmbins panels of `posts_norm` (the 'halo' panels in
# `title_halos` order), with `ylog = True`.
plt_imshow(posts_norm[nmbins:], cbar = True, ylog = True, tl = True)

# Trying some stuff

In [None]:
# posts_norm = posts/np.sum(posts, axis = 0)

# posts0 = posts_norm[:nmbins]
# posts1 = posts_norm[nmbins:]

# posts0_norm = posts0/np.sum(posts0, axis = 0)
# posts1_norm = posts1/np.sum(posts1, axis = 0)


In [None]:
# print('All')
# plt_imshow(posts_norm, nrows = 2, cbar = True, tl = True)
# print("Only 'no halo' (0)")
# plt_imshow(posts0_norm, cbar = True, tl = True)
# print("Only 'halo' (1)")
# plt_imshow(posts1_norm, cbar = True, tl = True)

In [None]:
# print('All')
# plt_imshow(posts_norm, nrows = 2, cbar = True, tl = True, ylog = True)
# print("Only 'no halo' (0)")
# plt_imshow(posts0_norm, cbar = True, tl = True, ylog = True)
# print("Only 'halo' (1)")
# plt_imshow(posts1_norm, cbar = True, tl = True, ylog = True)