In [None]:
#!/usr/bin/env python
# coding: utf-8


import os, datetime
import torch, pyro, numpy as np
# Make CUDA float tensors the default for tensors created from here on.
# Later cells switch the default back to torch.FloatTensor.
torch.set_default_tensor_type(torch.cuda.FloatTensor)

import swyft
import click


# Device string used by Mapping.coord_to_map and passed to the swyft estimators.
DEVICE = 'cuda'

# NOTE(review): wildcard import. get_sim_path, get_config, get_prior and
# get_post_path are used below but not defined in this notebook, so they
# presumably come from here -- confirm and consider importing them explicitly.
from utils import *
# from network import CustomTail, CustomHead

from swyft.utils import tensor_to_array, array_to_tensor
from toolz import compose
from pyrofit.lensing.distributions import get_default_shmf

# NOTE(review): torch/numpy are imported and the CUDA default is set a second
# time here; this repeats the statements at the top of the cell.
import torch, numpy as np
torch.set_default_tensor_type(torch.cuda.FloatTensor)
from torch import tensor
import torch.nn as nn
import torchvision.transforms.functional as TF


# @click.command()
# @click.option("--m",    type=int, default = 12,  help="Exponent of subhalo mass.")
# @click.option("--nsub", type=int, default = 1,   help="Number of subhaloes.")
# @click.option("--nsim", type=int, default = 100, help="Number of simulations to run.")

# @click.option("--nmbins",  type=int, default = 2,   help="Number of mass bins.")

# @click.option("--lr",         type=float, default = 1e-3, help="Learning rate.")
# @click.option("--factor",     type=float, default = 1e-1, help = "Factor of Scheduler")
# @click.option("--patience",   type=int,   default = 5,    help = "Patience of Scheduler")
# @click.option("--max_epochs", type=int,   default = 30,   help = "Max number of epochs.")



# Run configuration. These mirror the commented-out click options above,
# whose help strings are quoted below.
m = 0              # "Exponent of subhalo mass."
nsub = 3           # "Number of subhaloes."
nsim = 200         # "Number of simulations to run."

nmbins = 2         # "Number of mass bins."

lr = 1e-3          # "Learning rate."
factor = 1e-1      # "Factor of Scheduler"
patience = 5       # "Patience of Scheduler"
max_epochs = 1     # "Max number of epochs."

In [None]:
# Wall-clock start; the last cell prints the elapsed time relative to this.
time_start = datetime.datetime.now()

# Set definitions (should go to click)
system_name = "ngc4414"

# Set utilities
# get_sim_path is not defined in this notebook (presumably provided by
# `from utils import *`); it yields the simulation name and store path.
sim_name, sim_path = get_sim_path(m, nsub, nsim, system_name)
store = swyft.Store.load(path=sim_path)
print(f'Store has {len(store)} simulations.')

# The config is built while CUDA float tensors are the default, and the
# default is switched back to CPU float tensors immediately afterwards.
torch.set_default_tensor_type(torch.cuda.FloatTensor)  # HACK
CONFIG = get_config(system_name, str(nsub), str(m))
torch.set_default_tensor_type(torch.FloatTensor)

# lows/highs are later handed to Mapping / the tail networks;
# n_pars is used as the number of parameters for the ratio estimator.
prior, n_pars, lows, highs = get_prior(CONFIG)
# L is used below as the side length of the square image (L x L).
L = CONFIG.kwargs["defs"]["nx"]
print(f'Image has L = {L}.')

# Set up posterior
# NOTE(review): the default tensor type was already reset a few lines up;
# this second reset only matters if get_prior changes it -- confirm.
torch.set_default_tensor_type(torch.FloatTensor)
dataset = swyft.Dataset(nsim, prior, store)#, simhook = noise)
# marginals = [i for i in range(L**2)]
# post = swyft.Posteriors(dataset)

In [None]:
# Train
# Derive the posterior name and save path from the simulation name and the
# training hyper-parameters (get_post_path is not defined in this notebook;
# presumably it comes from `utils`).
post_name, post_path = get_post_path(sim_name, nmbins, lr, factor, patience)
print(f'Training {post_name}!')

In [None]:
class Mapping:
    """Convert flat subhalo parameter rows into per-pixel occupancy maps.

    Each input row is read as consecutive ``(x, y, log10_m)`` triplets, one
    per subhalo.

    Args:
        nmbins: Number of mass bins; the map has ``nmbins + 1`` channels.
        L: Side length of the square map in pixels.
        lows: Lower bounds of one ``(x, y, log10_m)`` triplet.
        highs: Upper bounds of one ``(x, y, log10_m)`` triplet.
    """

    def __init__(self, nmbins, L, lows, highs):
        self.nmbins = nmbins
        self.L = L
        self.lows = lows
        self.highs = highs

    def coord_vu(self, coords_v):
        """Rescale coordinates to ``(v - low) / (high - low)``.

        Args:
            coords_v: 2-D array whose row length is a multiple of 3.

        Returns:
            Array of the same shape holding the rescaled values.

        Raises:
            AssertionError: If the row length is not a multiple of 3.
        """
        n_sub, remainder = divmod(len(coords_v[0]), 3)
        assert remainder == 0

        # Repeat the per-triplet bounds once per subhalo so that they line up
        # with every column of ``coords_v``.
        lows = np.full(coords_v.shape, np.tile(self.lows, n_sub))
        highs = np.full(coords_v.shape, np.tile(self.highs, n_sub))

        return (coords_v - lows) / (highs - lows)

    def coord_to_map(self, XY_u):
        """Rasterise rescaled subhalo coordinates into occupancy maps.

        Args:
            XY_u: Tensor of shape ``(n_batch, 3 * n_sub)``. Values are assumed
                to lie in ``[0, 1)`` so that the derived pixel and mass-bin
                indices are in range -- TODO confirm against the prior.

        Returns:
            Tensor of shape ``(n_batch, nmbins + 1, L, L)`` on ``DEVICE``.
            Channel ``1 + mass_bin`` is 1 at pixels containing a subhalo of
            that bin; channel 0 is 1 at pixels containing no subhalo.

        Raises:
            AssertionError: If ``XY_u.shape[1]`` is not a multiple of 3.
        """
        n_batch = XY_u.shape[0]
        n_sub, remainder = divmod(XY_u.shape[1], 3)
        assert remainder == 0

        z = torch.zeros((n_batch, self.nmbins + 1, self.L, self.L), device=DEVICE)

        if n_batch > 0 and n_sub > 0:
            x_sub_u, y_sub_u, log10_m_sub_u = XY_u.view(-1, 3).T.to(DEVICE)

            x_i = torch.floor(x_sub_u * self.L).type(torch.long)
            y_i = torch.floor(y_sub_u * self.L).type(torch.long)
            m_i = torch.floor(log10_m_sub_u * self.nmbins).type(torch.long) + 1

            # Batch index of every triplet: 0,0,..,1,1,.. (n_sub copies each).
            # Built with integer arithmetic; a float-step arange followed by
            # floor can land on the wrong batch through rounding.
            i = torch.arange(n_batch, device=DEVICE).repeat_interleave(n_sub)

            z[i, m_i, y_i, x_i] = 1
            # Channel 0 first records "occupied by any subhalo" ...
            z[i, torch.zeros_like(m_i), y_i, x_i] = 1

        # ... and is then inverted so that it marks the empty pixels.
        z[:, 0] = 1 - z[:, 0]

        return z


class DoubleConv(nn.Module):
    """Two consecutive 3x3 convolution -> BatchNorm -> ReLU stages.

    Padding of 1 with stride 1 keeps the spatial size unchanged.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        for stage_in in (in_channels, out_channels):
            stages += [
                # No bias: the BatchNorm2d that follows has its own shift.
                nn.Conv2d(stage_in, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both stages to ``x`` and return the result."""
        return self.conv(x)

class UNET(nn.Module):
    """U-Net built from ``DoubleConv`` blocks with skip connections.

    Args:
        in_channels: Channels of the input after a channel axis is inserted.
        out_channels: Channels of the output map.
        features: Channel widths of the encoder levels, shallowest first.
            The decoder mirrors them in reverse order. Any sequence is
            accepted; the default is a tuple so that no mutable object is
            shared between calls.
    """

    def __init__(self, in_channels, out_channels, features=(64, 128, 256, 512)):
        super().__init__()
        features = list(features)

        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        # 2x2 max-pooling with stride 2 halves the height and the width.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Down part of UNET
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Up part of UNET: entries alternate (upsample, DoubleConv).
        for feature in reversed(features):
            self.ups.append(
                nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2)
            )
            self.ups.append(DoubleConv(feature * 2, feature))

        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x, target=None):
        """Run the network.

        Args:
            x: Tensor of shape ``(batch, H, W)``; a channel axis is inserted
                at position 1 before the first convolution.
            target: Unused. Kept (now optional) so existing two-argument
                callers keep working.

        Returns:
            Tensor of shape ``(batch, out_channels, H, W)``.
        """
        x = x.unsqueeze(1)
        skip_connections = []

        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)

        # Walk the decoder from the deepest level up, pairing each upsample
        # step with the encoder output of the same depth.
        for level, skip_connection in enumerate(reversed(skip_connections)):
            x = self.ups[2 * level](x)

            # Odd input sizes lose a pixel when pooled; resize so the
            # concatenation below is shape-compatible.
            if x.shape != skip_connection.shape:
                x = TF.resize(x, size=skip_connection.shape[2:])

            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[2 * level + 1](concat_skip)

        return self.final_conv(x)

In [None]:
    
class CustomHead(swyft.Module):
    """Head that flattens the ``"image"`` observation into one feature vector.

    Args:
        obs_shapes: Mapping from observation name to shape; only the
            ``'image'`` entry is read.
    """

    def __init__(self, obs_shapes) -> None:
        super().__init__(obs_shapes=obs_shapes)
        # Total number of pixels, stored as a 0-dim tensor.
        image_shape = tensor(obs_shapes['image'])
        self.n_features = torch.prod(image_shape)
#         self.onl_norm = OnlineNormalizationLayer(torch.Size([self.n_features]))

    def forward(self, obs) -> torch.Tensor:
        """Return ``obs["image"]`` reshaped to ``(batch, n_features)``."""
        image = obs["image"]
        batch_size = len(image)
#         x = self.onl_norm(x)
        return image.view(batch_size, self.n_features)


class CustomTail(swyft.Module):
    """Tail that runs a U-Net on the image and masks it with the target map.

    Args:
        n_features: Number of pixels of the (square) image.
        marginals: Forwarded to ``swyft.Module``.
        **tail_args: Must contain ``'nmbins'``, ``'lows'`` and ``'highs'``.
    """

    def __init__(self, n_features, marginals, **tail_args):
        super().__init__(n_features=n_features, marginals=marginals, **tail_args)

        self.n_features = n_features
        # Side length of the image, recovered from the pixel count.
        side = np.sqrt(n_features).item()
        self.L = int(side)

        self.nmbins = tail_args['nmbins']
        self.lows = tail_args['lows']
        self.highs = tail_args['highs']
        # One channel per mass bin plus one extra channel.
        self.out_channels = 1 + self.nmbins

        self.Map = Mapping(self.nmbins, self.L, self.lows, self.highs)
        self.UNet = UNET(in_channels=1, out_channels=self.out_channels)

    def forward(self, sims, target):
        """Return the U-Net output multiplied by the target map, flattened.

        Args:
            sims: Flattened images; reshaped to ``(-1, L, L)``.
            target: Subhalo parameters passed to ``Mapping.coord_to_map``.

        Returns:
            Tensor of shape ``(-1, n_features * out_channels)``.
        """
        images = sims.view(-1, self.L, self.L)

        unet_out = self.UNet(images, target)
        target_map = self.Map.coord_to_map(target)

        masked = unet_out * target_map
        return masked.view(-1, self.n_features * self.out_channels)

In [None]:
# Make CPU float tensors the default again before the networks below are built.
torch.set_default_tensor_type(torch.FloatTensor)


In [None]:
class CustomHead(swyft.Module):
    """Head that flattens the ``"image"`` observation into one feature vector.

    Args:
        obs_shapes: Mapping from observation name to shape; only the
            ``'image'`` entry is read.
    """

    def __init__(self, obs_shapes) -> None:
        super().__init__(obs_shapes=obs_shapes)
        # Total number of pixels, stored as a 0-dim tensor.
        image_shape = tensor(obs_shapes['image'])
        self.n_features = torch.prod(image_shape)
#         self.onl_norm = OnlineNormalizationLayer(torch.Size([self.n_features]))

    def forward(self, obs) -> torch.Tensor:
        """Return ``obs["image"]`` reshaped to ``(batch, n_features)``."""
        image = obs["image"]
        batch_size = len(image)
#         x = self.onl_norm(x)
        return image.view(batch_size, self.n_features)

In [None]:
class Mapping:
    """Convert flat subhalo parameter rows into per-pixel occupancy maps.

    Each input row is read as consecutive ``(x, y, log10_m)`` triplets, one
    per subhalo.

    Args:
        nmbins: Number of mass bins; the map has ``nmbins + 1`` channels.
        L: Side length of the square map in pixels.
        lows: Lower bounds of one ``(x, y, log10_m)`` triplet.
        highs: Upper bounds of one ``(x, y, log10_m)`` triplet.
    """

    def __init__(self, nmbins, L, lows, highs):
        self.nmbins = nmbins
        self.L = L
        self.lows = lows
        self.highs = highs

    def coord_vu(self, coords_v):
        """Rescale coordinates to ``(v - low) / (high - low)``.

        Args:
            coords_v: 2-D array whose row length is a multiple of 3.

        Returns:
            Array of the same shape holding the rescaled values.

        Raises:
            AssertionError: If the row length is not a multiple of 3.
        """
        n_sub, remainder = divmod(len(coords_v[0]), 3)
        assert remainder == 0

        # Repeat the per-triplet bounds once per subhalo so that they line up
        # with every column of ``coords_v``.
        lows = np.full(coords_v.shape, np.tile(self.lows, n_sub))
        highs = np.full(coords_v.shape, np.tile(self.highs, n_sub))

        return (coords_v - lows) / (highs - lows)

    def coord_to_map(self, XY_u):
        """Rasterise rescaled subhalo coordinates into occupancy maps.

        The unconditional ``assert 1 == 2`` debugging guard that used to sit
        at the top of this method has been removed; it made every call fail.

        Args:
            XY_u: Tensor of shape ``(n_batch, 3 * n_sub)``. Values are assumed
                to lie in ``[0, 1)`` so that the derived pixel and mass-bin
                indices are in range -- TODO confirm against the prior.

        Returns:
            Tensor of shape ``(n_batch, nmbins + 1, L, L)`` on ``DEVICE``.
            Channel ``1 + mass_bin`` is 1 at pixels containing a subhalo of
            that bin; channel 0 is 1 at pixels containing no subhalo.

        Raises:
            AssertionError: If ``XY_u.shape[1]`` is not a multiple of 3.
        """
        n_batch = XY_u.shape[0]
        n_sub, remainder = divmod(XY_u.shape[1], 3)
        assert remainder == 0

        z = torch.zeros((n_batch, self.nmbins + 1, self.L, self.L), device=DEVICE)

        if n_batch > 0 and n_sub > 0:
            x_sub_u, y_sub_u, log10_m_sub_u = XY_u.view(-1, 3).T.to(DEVICE)

            x_i = torch.floor(x_sub_u * self.L).type(torch.long)
            y_i = torch.floor(y_sub_u * self.L).type(torch.long)
            m_i = torch.floor(log10_m_sub_u * self.nmbins).type(torch.long) + 1

            # Batch index of every triplet: 0,0,..,1,1,.. (n_sub copies each).
            # Built with integer arithmetic; a float-step arange followed by
            # floor can land on the wrong batch through rounding.
            i = torch.arange(n_batch, device=DEVICE).repeat_interleave(n_sub)

            z[i, m_i, y_i, x_i] = 1
            # Channel 0 first records "occupied by any subhalo" ...
            z[i, torch.zeros_like(m_i), y_i, x_i] = 1

        # ... and is then inverted so that it marks the empty pixels.
        z[:, 0] = 1 - z[:, 0]

        return z


class DoubleConv(nn.Module):
    """Two consecutive 3x3 convolution -> BatchNorm -> ReLU stages.

    Padding of 1 with stride 1 keeps the spatial size unchanged.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        for stage_in in (in_channels, out_channels):
            stages += [
                # No bias: the BatchNorm2d that follows has its own shift.
                nn.Conv2d(stage_in, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Apply both stages to ``x`` and return the result."""
        return self.conv(x)

class UNET(nn.Module):
    """U-Net built from ``DoubleConv`` blocks with skip connections.

    Args:
        in_channels: Channels of the input after a channel axis is inserted.
        out_channels: Channels of the output map.
        features: Channel widths of the encoder levels, shallowest first.
            The decoder mirrors them in reverse order. Any sequence is
            accepted; the default is a tuple so that no mutable object is
            shared between calls.
    """

    def __init__(self, in_channels, out_channels, features=(64, 128, 256, 512)):
        super().__init__()
        features = list(features)

        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()
        # 2x2 max-pooling with stride 2 halves the height and the width.
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)

        # Down part of UNET
        for feature in features:
            self.downs.append(DoubleConv(in_channels, feature))
            in_channels = feature

        # Up part of UNET: entries alternate (upsample, DoubleConv).
        for feature in reversed(features):
            self.ups.append(
                nn.ConvTranspose2d(feature * 2, feature, kernel_size=2, stride=2)
            )
            self.ups.append(DoubleConv(feature * 2, feature))

        self.bottleneck = DoubleConv(features[-1], features[-1] * 2)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def forward(self, x, target=None):
        """Run the network.

        Args:
            x: Tensor of shape ``(batch, H, W)``; a channel axis is inserted
                at position 1 before the first convolution.
            target: Unused. Kept (now optional) so existing two-argument
                callers keep working.

        Returns:
            Tensor of shape ``(batch, out_channels, H, W)``.
        """
        x = x.unsqueeze(1)
        skip_connections = []

        for down in self.downs:
            x = down(x)
            skip_connections.append(x)
            x = self.pool(x)

        x = self.bottleneck(x)

        # Walk the decoder from the deepest level up, pairing each upsample
        # step with the encoder output of the same depth.
        for level, skip_connection in enumerate(reversed(skip_connections)):
            x = self.ups[2 * level](x)

            # Odd input sizes lose a pixel when pooled; resize so the
            # concatenation below is shape-compatible.
            if x.shape != skip_connection.shape:
                x = TF.resize(x, size=skip_connection.shape[2:])

            concat_skip = torch.cat((skip_connection, x), dim=1)
            x = self.ups[2 * level + 1](concat_skip)

        return self.final_conv(x)

class CustomObservationTransform(torch.nn.Module):
    """Select one observation from a dict and flatten it per sample.

    Args:
        observation_key: Key of the observation to extract.
        observation_shapes: Mapping from observation key to its shape; only
            the entry for ``observation_key`` is read.
    """

    def __init__(self, observation_key: str, observation_shapes: dict):
        super().__init__()
        self.observation_key = observation_key
        # Number of elements of a single observation, kept as a 0-dim tensor.
        self.n_features = torch.prod(tensor(observation_shapes[observation_key]))
#         self.online_z_score = swyft.networks.OnlineDictStandardizingLayer(observation_shapes)

    def forward(self, obs: dict) -> torch.Tensor:
        """Return ``obs[observation_key]`` reshaped to ``(batch, n_features)``.

        The tensor is no longer printed on every call; that was a debugging
        leftover which flooded the cell output during training.
        """
#         x = self.online_z_score(obs)
        x = obs[self.observation_key]
        return x.view(len(x), self.n_features)
    
class CustomTail(nn.Module):
    """Tail that runs a U-Net on the image and masks it with the target map.

    Args:
        n_features: Number of pixels of the (square) image.
        marginals: Accepted for signature compatibility; not used here.
        **tail_args: Must contain ``'nmbins'``, ``'lows'`` and ``'highs'``.

    Raises:
        KeyError: If one of the required ``tail_args`` entries is missing.
    """

    def __init__(self, n_features, marginals, **tail_args):
        # nn.Module.__init__ takes no keyword arguments; forwarding
        # n_features/marginals/tail_args to it raised a TypeError.
        super().__init__()

        self.n_features = n_features
        # Side length of the image, recovered from the pixel count.
        self.L = int(np.sqrt(n_features).item())
        self.nmbins = tail_args['nmbins']
        self.lows = tail_args['lows']
        self.highs = tail_args['highs']
        # One channel per mass bin plus one extra channel.
        self.out_channels = self.nmbins + 1

        self.Map = Mapping(self.nmbins, self.L, self.lows, self.highs)
        self.UNet = UNET(in_channels=1, out_channels=self.out_channels)

    def forward(self, sims, target):
        """Return the U-Net output multiplied by the target map, flattened.

        Args:
            sims: Flattened images; reshaped to ``(-1, L, L)``.
            target: Subhalo parameters passed to ``Mapping.coord_to_map``.

        Returns:
            Tensor of shape ``(-1, n_features * out_channels)``.
        """
        sims = sims.view(-1, self.L, self.L)

        x = self.UNet(sims, target)
        z = self.Map.coord_to_map(target)

        x = x * z
        return x.view(-1, self.n_features * self.out_channels)

class CustomMarginalClassifier(torch.nn.Module):
    """Marginal classifier that masks a U-Net output with the target map.

    The class used to define ``forward`` twice. The second definition
    silently replaced the first and dropped the reshape of the flattened
    features to ``(batch, L, L)`` that ``UNET.forward`` needs. There is now a
    single ``forward`` that keeps the effective signature and restores the
    reshape.

    Args:
        n_marginals: Number of marginals (stored only).
        n_combined_features: Combined feature count (stored only).
    """

    def __init__(self, n_marginals: int, n_combined_features: int):
        super().__init__()
        self.n_marginals = n_marginals
        self.n_combined_features = n_combined_features

        # NOTE(review): the values below are hard-coded or read from
        # notebook-level globals instead of being passed in; the commented
        # names show where they came from in the tail-based version.
        self.n_features = 1600 #n_features
        self.L = int(np.sqrt(self.n_features).item())
        self.nmbins = 2 #tail_args['nmbins']
        self.lows = lows #tail_args['lows']
        self.highs = highs #tail_args['highs']
        self.out_channels = self.nmbins + 1

        self.Map = Mapping(self.nmbins, self.L, self.lows, self.highs)
        self.UNet = UNET(in_channels=1, out_channels=self.out_channels)

    def forward(
        self, features: torch.Tensor, marginal_block: torch.Tensor
    ) -> torch.Tensor:
        """Return the U-Net output multiplied by the target map, flattened.

        Args:
            features: Flattened images; reshaped to ``(-1, L, L)``.
            marginal_block: Parameters passed to ``Mapping.coord_to_map``.
                NOTE(review): ``coord_to_map`` reads a 2-D
                ``(batch, 3 * n_sub)`` layout -- confirm that this matches
                what the parameter transform produces.

        Returns:
            Tensor of shape ``(-1, n_features * out_channels)``.
        """
        sims = features.view(-1, self.L, self.L)
        target = marginal_block

        x = self.UNet(sims, target)
        z = self.Map.coord_to_map(target)

        x = x * z
        return x.view(-1, self.n_features * self.out_channels)
        
        
#         fb = features.unsqueeze(1).expand(-1, self.n_marginals, -1)  # B, M, O
#         combined = torch.cat([fb, marginal_block], dim=2)  # B, M, O + P
#         return self.net(combined).squeeze(-1)  # B, M
        
        
        
#         print(n_marginals, n_combined_features)

# class CustomMarginalClassifier(torch.nn.Module):
#     def __init__(
#         self,
#         n_marginals: int,
#         n_combined_features: int,
#         hidden_features: int,
#         num_blocks=2,
#     ) -> None:
#         super().__init__()
#         self.n_marginals = n_marginals
#         self.n_combined_features = n_combined_features

#         blocks = [
#             LinearWithChannel(self.n_marginals, self.n_combined_features, hidden_features),
#             torch.nn.ReLU(),
#             BatchNorm1dWithChannel(self.n_marginals, hidden_features),
#         ]
#         for _ in range(num_blocks - 1):
#             blocks.append(LinearWithChannel(self.n_marginals, hidden_features, hidden_features))
#             blocks.append(torch.nn.ReLU())
#             blocks.append(BatchNorm1dWithChannel(self.n_marginals, hidden_features))

#         self.net = torch.nn.Sequential(
#             *blocks,
#             LinearWithChannel(self.n_marginals, hidden_features, 1)
#         )

#     def forward(
#         self, features: torch.Tensor, marginal_block: torch.Tensor
#     ) -> torch.Tensor:
#         fb = features.unsqueeze(1).expand(-1, self.n_marginals, -1)  # B, M, O
#         combined = torch.cat([fb, marginal_block], dim=2)  # B, M, O + P
#         return self.net(combined).squeeze(-1)  # B, M

    
def get_custom_marginal_classifier(
    observation_transform,
    marginal_indices: tuple,
    n_parameters: int,
    marginal_classifier,
    parameter_online_z_score: bool = False
) -> torch.nn.Module:
    """Assemble a ``swyft.networks.Network`` around a custom classifier.

    Args:
        observation_transform: Module exposing an ``n_features`` attribute.
        marginal_indices: Marginal indices for the parameter transform.
        n_parameters: Total number of parameters.
        marginal_classifier: Callable invoked as
            ``marginal_classifier(n_marginals, n_combined_features)``.
        parameter_online_z_score: Forwarded to the parameter transform as
            ``online_z_score``.

    Returns:
        The network combining the observation transform, the parameter
        transform and the instantiated classifier.
    """
    obs_feature_count = observation_transform.n_features

    param_transform = swyft.networks.ParameterTransform(
        n_parameters,
        marginal_indices,
        online_z_score=parameter_online_z_score,
    )
    block_shape = param_transform.marginal_block_shape
    n_marginals, n_block_parameters = block_shape

    # The classifier is told the observation and parameter feature counts
    # added together.
    classifier = marginal_classifier(
        n_marginals, obs_feature_count + n_block_parameters
    )

    return swyft.networks.Network(observation_transform, param_transform, classifier)
    

# Inputs for the custom network: one marginal index per pixel of the L x L image.
observation_key = 'image'
marginal_indices = list(range(L**2))
observation_shapes = {"image": (L, L)}
n_parameters = n_pars

observation_transform = CustomObservationTransform(observation_key, observation_shapes)
# The class itself is passed on; the factory below instantiates it.
marginal_classifier = CustomMarginalClassifier

network = get_custom_marginal_classifier(
    observation_transform=observation_transform,
    marginal_indices=marginal_indices,
    n_parameters=n_parameters,
    marginal_classifier=marginal_classifier,
)

mre = swyft.MarginalRatioEstimator(
    marginal_indices=marginal_indices,
    network=network,
    device=DEVICE,
)

# NOTE(review): the epoch count is hard-coded here and does not use the
# `max_epochs` value from the configuration cell -- confirm this is intended.
mre.train(dataset, max_epochs=2)

In [None]:
# Bare expression so that the notebook displays the class. The attribute was
# misspelled ("Posterios"), which raised AttributeError and stopped Run All;
# the following cell uses the same `swyft.Posteriors` name.
swyft.Posteriors

In [None]:
# Train
post_name, post_path = get_post_path(sim_name, nmbins, lr, factor, patience)
print(f'Training {post_name}!')

torch.set_default_tensor_type(torch.FloatTensor)

# `marginals` was only ever defined in a commented-out line of an earlier
# cell, so this cell failed with NameError. It is defined here with the same
# value as in that comment: one marginal per pixel of the L x L image.
marginals = list(range(L**2))

post = swyft.Posteriors(dataset)
post.add(marginals, device = DEVICE,
         tail_args = dict(nmbins = nmbins, lows = lows, highs = highs),
         head = CustomHead, tail = CustomTail)
post.train(marginals, max_epochs = max_epochs,
           optimizer_args = dict(lr=lr),
           scheduler_args = dict(factor = factor, patience = patience)
          )
post.save(post_path)

print('Done!')
# Elapsed wall-clock time since `time_start`, with the fractional seconds cut off.
print(f"Total training time is {str(datetime.datetime.now() - time_start).split('.')[0]}!")