In [None]:
import os
import sys
import numpy as np

from multifield_combined import MultifieldDataset

import torch
from torch import nn
from torch.backends import cudnn
from torch.utils.data import DataLoader
from torch.utils.data.dataset import Dataset

import matplotlib.pyplot as plt

In [None]:
# Pick the compute device: the GPU when CUDA is usable, the CPU otherwise
use_cuda = torch.cuda.is_available()
print("CUDA Available" if use_cuda else 'CUDA Not Available')
device = torch.device('cuda' if use_cuda else 'cpu')
#cudnn.benchmark = True      #May train faster but cost more memory

In [None]:
# ---------------------------------------------------------------- data ----
DDIR   = '/home/masterdesky/data/CAMELS/2D_maps/data/'  # Data directory
TARGET = 'SIMBA'  # Target simulations (alternative: 'IllustrisTNG')

# Keep only the map files of the target suite: parameter tables (.txt),
# '_CV_' files and 'Nbody' files are left out.
_EXCLUDED = ('.txt', '_CV_', 'Nbody')
FILE = sorted(
    f for f in os.listdir(DDIR)
    if TARGET in f and not any(tag in f for tag in _EXCLUDED)
)

# Data parameters
fdata      = [FILE[0]]            # only the first matching file is used
fmaps      = [f"{'_'.join(f.split('_')[:2])}_{TARGET}.npy" for f in fdata]
# To normalize the maps with the statistics of another data set, put that
# data set here (mostly used when training on IllustrisTNG and testing on
# SIMBA, or vice versa).
fmaps_norm = [None]
fparams    = os.path.join(DDIR, f"params_{TARGET}.txt")
splits     = 6   # number of maps per simulation (1 - 15)
seed       = 1   # random seed to split maps among training, validation and testing

# ------------------------------------------------------------ training ----
channels   = 1                # Number of fields to consider
params     = list(range(6))   # Omega_m, sigma_8, A_SN1, A_AGN1, A_SN2, A_AGN2
g          = params           # Network outputs holding the posterior means
h          = [i + 6 for i in g]   # Network outputs holding the posterior errors
memmap     = False            # Keep rotations and flippings in memory

# Optimizer parameters
beta1 = 0.5
beta2 = 0.999

# Hyperparameters
batch_size = 128
lr         = 1e-3
wd         = 0.0005  # value of weight decay
dr         = 0.2     # dropout value for fully connected layers
hidden     = 5       # this determines the number of layers in the CNNs; > 1
epochs     = 100     # number of epochs to train the network

# ------------------------------------------------------- output files ----
SUFFIX = '_'.join(f.split('_')[1] for f in fmaps)
floss  = f"loss_{TARGET}_{SUFFIX}.txt"    # training and validation losses for each epoch
fmodel = f"weights_{TARGET}_{SUFFIX}.pt"  # weights of the best model

In [None]:
def prepare_raw_dataset(fdata, splits, *, verbose=True, maps_per_sim=15):
    '''
    Keep the first `splits` maps of every simulation and save them.

    The maps are read from `DDIR/fdata`. Consecutive groups of
    `maps_per_sim` rows are treated as one simulation and the first
    `splits` rows of each group are kept. The selection is written to the
    current working directory as `<first two '_' fields of fdata>_<TARGET>.npy`.

    Parameters
    ----------
    fdata : str
        Name of the file containing the selected dataset.
    splits : int
        Number of maps to select for every simulation. Should be between
        1 and `maps_per_sim`.
    verbose : bool
        Controls verbosity.
    maps_per_sim : int
        Number of consecutive maps in the input file that belong to one
        simulation (15 by default).
    '''
    assert 1 <= splits <= maps_per_sim, \
        f"Value of `splits` should be between 1 and {maps_per_sim}!"

    maps = np.load(os.path.join(DDIR, fdata))
    if verbose:
        print(f"Shape of the maps: {maps.shape}")

    # Boolean mask over the maps; its length follows the file that was
    # actually loaded instead of assuming a fixed number of maps.
    n_maps = maps.shape[0]
    indexes = np.arange(n_maps) % maps_per_sim < splits
    if verbose:
        print(f"Selected {np.sum(indexes):,} maps out of {n_maps:,}")

    # Save these maps to a new file
    np.save(f"{'_'.join(fdata.split('_')[:2])}_{TARGET}.npy", maps[indexes])
    del maps

In [None]:
# Use the configured `splits` so the prepared files always hold the number
# of maps per simulation that the data loaders are told to expect.
for fin in fdata:
    prepare_raw_dataset(fdata=fin, splits=splits, verbose=True)

In [None]:
# This routine returns the data loader needed to train the network
def create_dataset_multifield(mode, fmaps, fmaps_norm, fparams,
                              splits, batch_size, *, memmap=True,
                              shuffle=True, seed=None, verbose=False):
    '''
    Build a `MultifieldDataset` and wrap it in a `DataLoader`.

    Parameters
    ----------
    mode, fmaps, fmaps_norm, fparams, splits
        Forwarded unchanged to `MultifieldDataset`.
    batch_size : int
        Batch size of the returned loader.
    memmap : bool
        Forwarded to `MultifieldDataset`.
    shuffle : bool
        Whether the loader shuffles the dataset.
    seed : int or None
        Forwarded to `MultifieldDataset`.
    verbose : bool
        Forwarded to `MultifieldDataset`.

    Returns
    -------
    torch.utils.data.DataLoader
    '''
    data_set = MultifieldDataset(
        mode, fmaps, fmaps_norm, fparams, splits,
        norm_params=True, memmap=memmap, seed=seed, verbose=verbose
    )
    data_loader = DataLoader(
        dataset=data_set, batch_size=batch_size, shuffle=shuffle
    )
    return data_loader

In [None]:
# Positional arguments shared by the training and validation loaders
loader_args = (fmaps, fmaps_norm, fparams, splits, batch_size)

# Training set: `memmap` follows the configuration cell
print('\nPreparing training set')
train_loader = create_dataset_multifield(
    'train', *loader_args,
    memmap=memmap, shuffle=True, seed=seed, verbose=True
)

# Validation set: `memmap` is always switched on here
print('\nPreparing validation set')
valid_loader = create_dataset_multifield(
    'valid', *loader_args,
    memmap=True, shuffle=True, seed=seed, verbose=True
)

In [None]:
class model_so3(nn.Module):
    '''
    Convolutional network that maps an image to 12 numbers.

    The network is made of six blocks of three convolutions, each followed
    by batch normalization and LeakyReLU(0.2). The first two convolutions
    of a block use the constructor's `kernel_size`, `stride`, `padding` and
    `padding_mode`; the third is a 2x2 convolution with stride 2 that
    halves the height and width. A final 4x4 convolution with stride 4 and
    two fully connected layers produce the 12 outputs. Outputs 6..11 are
    squared so that they are never negative.

    The shape comments in `__init__` assume 256x256 inputs and
    size-preserving settings for the configurable convolutions
    (e.g. kernel_size=3, stride=1, padding=1).

    NOTE: C03 is a 2x2/stride-2 convolution like C13 ... C53. Checkpoints
    written by a variant of this class whose C03 used `kernel_size` and
    `stride` have a different C03 weight shape and will not load.

    Parameters
    ----------
    n_channels : int
        Number of channels of the input images.
    n_filters : int
        Base number of feature maps; the blocks use 2, 4, 8, 16, 32, 64
        and finally 128 times this value.
    kernel_size, padding, padding_mode, stride
        Settings of the first two convolutions of every block.
        `padding_mode` is passed to every convolution.
    dropout_rate : float
        Dropout probability used around the first fully connected layer.
    '''
    def __init__(self,
                 n_channels, n_filters,
                 kernel_size, padding, padding_mode, stride,
                 dropout_rate=0.2):
        super().__init__()

        def conv(c_in, c_out):
            # Convolution configured by the constructor arguments
            return nn.Conv2d(c_in, c_out,
                             kernel_size=kernel_size,
                             stride=stride,
                             padding=padding,
                             padding_mode=padding_mode,
                             bias=True)

        def halve(c):
            # 2x2 convolution with stride 2: halves height and width
            return nn.Conv2d(c, c,
                             kernel_size=2,
                             stride=2,
                             padding=0,
                             padding_mode=padding_mode,
                             bias=True)

        # Possible activation functions (only LeakyReLU is used in forward)
        self.ReLU      = nn.ReLU()
        self.LeakyReLU = nn.LeakyReLU(0.2)
        self.tanh      = nn.Tanh()

        # Input parameter dependent modules
        self.dropout   = nn.Dropout(p=dropout_rate)

        # input: n_channelsx256x256 -------> output: 2*n_filtersx128x128
        self.C01 = conv(n_channels, 2*n_filters)
        self.C02 = conv(2*n_filters, 2*n_filters)
        self.C03 = halve(2*n_filters)
        # Batch-norm widths follow `n_filters`, not a notebook-level global
        self.B01 = nn.BatchNorm2d(2*n_filters)
        self.B02 = nn.BatchNorm2d(2*n_filters)
        self.B03 = nn.BatchNorm2d(2*n_filters)

        # input: 2*n_filtersx128x128 ------> output: 4*n_filtersx64x64
        self.C11 = conv(2*n_filters, 4*n_filters)
        self.C12 = conv(4*n_filters, 4*n_filters)
        self.C13 = halve(4*n_filters)
        self.B11 = nn.BatchNorm2d(4*n_filters)
        self.B12 = nn.BatchNorm2d(4*n_filters)
        self.B13 = nn.BatchNorm2d(4*n_filters)

        # input: 4*n_filtersx64x64 --------> output: 8*n_filtersx32x32
        self.C21 = conv(4*n_filters, 8*n_filters)
        self.B21 = nn.BatchNorm2d(8*n_filters)
        self.C22 = conv(8*n_filters, 8*n_filters)
        self.B22 = nn.BatchNorm2d(8*n_filters)
        self.C23 = halve(8*n_filters)
        self.B23 = nn.BatchNorm2d(8*n_filters)

        # input: 8*n_filtersx32x32 --------> output: 16*n_filtersx16x16
        self.C31 = conv(8*n_filters, 16*n_filters)
        self.B31 = nn.BatchNorm2d(16*n_filters)
        self.C32 = conv(16*n_filters, 16*n_filters)
        self.B32 = nn.BatchNorm2d(16*n_filters)
        self.C33 = halve(16*n_filters)
        self.B33 = nn.BatchNorm2d(16*n_filters)

        # input: 16*n_filtersx16x16 -------> output: 32*n_filtersx8x8
        self.C41 = conv(16*n_filters, 32*n_filters)
        self.B41 = nn.BatchNorm2d(32*n_filters)
        self.C42 = conv(32*n_filters, 32*n_filters)
        self.B42 = nn.BatchNorm2d(32*n_filters)
        self.C43 = halve(32*n_filters)
        self.B43 = nn.BatchNorm2d(32*n_filters)

        # input: 32*n_filtersx8x8 ---------> output: 64*n_filtersx4x4
        self.C51 = conv(32*n_filters, 64*n_filters)
        self.B51 = nn.BatchNorm2d(64*n_filters)
        self.C52 = conv(64*n_filters, 64*n_filters)
        self.B52 = nn.BatchNorm2d(64*n_filters)
        self.C53 = halve(64*n_filters)
        self.B53 = nn.BatchNorm2d(64*n_filters)

        # input: 64*n_filtersx4x4 ---------> output: 128*n_filtersx1x1
        self.C61 = nn.Conv2d(64*n_filters, 128*n_filters,
                             kernel_size=4,
                             stride=4,
                             padding=0,
                             padding_mode=padding_mode,
                             bias=True)
        self.B61 = nn.BatchNorm2d(128*n_filters)

        # Not used in forward; it has no parameters
        self.P0  = nn.AvgPool2d(kernel_size=2, stride=2, padding=0)

        self.FC1  = nn.Linear(128*n_filters, 64*n_filters)
        self.FC2  = nn.Linear(64*n_filters,  12)

        # Set conventional values for batch normalization modules and
        # Kaiming-normal weights for convolutional and linear layers
        for m in self.modules():
            if isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight)


    def forward(self, image):
        '''
        Run the network on a batch of images.

        Returns a tensor with 12 columns; columns 6..11 are the squares of
        the corresponding raw outputs and therefore non-negative.
        '''
        x = image

        # Six blocks of conv -> batch norm -> LeakyReLU, three layers each
        for stage in range(6):
            for layer in (1, 2, 3):
                conv = getattr(self, f'C{stage}{layer}')
                norm = getattr(self, f'B{stage}{layer}')
                x = self.LeakyReLU(norm(conv(x)))

        x = self.LeakyReLU(self.B61(self.C61(x)))

        # Flatten everything but the batch dimension
        x = x.view(image.shape[0], -1)
        x = self.dropout(x)
        x = self.dropout(self.LeakyReLU(self.FC1(x)))
        x = self.FC2(x)

        # enforce the errors to be positive
        y = torch.clone(x)
        y[:,6:12] = torch.square(x[:,6:12])

        return y

In [None]:
# Build the network from the values of the configuration cell so that
# changing `channels`, `hidden` or `dr` there actually takes effect.
model = model_so3(
    n_channels=channels,
    n_filters=hidden,
    kernel_size=3,
    padding=1,
    padding_mode='circular',
    stride=1,
    dropout_rate=dr
)
model.to(device=device)

In [None]:
from torchinfo import summary

In [None]:
# Count every element of every parameter tensor of the model
network_total_params = sum(map(torch.numel, model.parameters()))
print(f"total number of parameters in the model = {network_total_params:,}")

In [None]:
# AdamW with the learning rate, weight decay and betas of the config cell
optimizer = torch.optim.AdamW(
    model.parameters(), lr=lr, betas=(beta1, beta2), weight_decay=wd
)
# Multiply the learning rate by 0.3 when the monitored value has not
# decreased for 10 scheduler steps
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.3, patience=10
)

In [None]:
# Validation loss of the untrained network; it is the value a later epoch
# has to beat before its weights are saved.
print('Computing initial validation loss')
model.eval()
valid_loss1 = torch.zeros(len(g)).to(device)
valid_loss2 = torch.zeros(len(g)).to(device)
points = 0
with torch.no_grad():
    for x, y in valid_loader:
        bs   = x.shape[0]                #batch size
        x    = x.to(device=device)       #maps
        y    = y.to(device=device)[:,g]  #parameters
        p    = model(x)                  #NN output
        y_NN = p[:,g]                    #posterior mean
        e_NN = p[:,h]                    #posterior std
        loss1 = torch.mean((y_NN - y)**2,                axis=0)
        loss2 = torch.mean(((y_NN - y)**2 - e_NN**2)**2, axis=0)
        # batch means are weighted by the batch size so that a smaller
        # last batch does not bias the average over the whole set
        valid_loss1 += loss1*bs
        valid_loss2 += loss2*bs
        points += bs
min_valid_loss = torch.log(valid_loss1/points) + torch.log(valid_loss2/points)
min_valid_loss = torch.mean(min_valid_loss).item()
print('Initial valid loss = %.3e'%min_valid_loss)

In [None]:
%%time
# do a loop over all epochs
for epoch in range(epochs):

    # do training
    train_loss1 = torch.zeros(len(g)).to(device)
    train_loss2 = torch.zeros(len(g)).to(device)
    points = 0
    model.train()
    for x, y in train_loader:
        bs   = x.shape[0]         #batch size
        x    = x.to(device)       #maps
        y    = y.to(device)[:,g]  #parameters
        p    = model(x)           #NN output
        y_NN = p[:,g]             #posterior mean
        e_NN = p[:,h]             #posterior std
        loss1 = torch.mean((y_NN - y)**2,                axis=0)
        loss2 = torch.mean(((y_NN - y)**2 - e_NN**2)**2, axis=0)
        loss  = torch.mean(torch.log(loss1) + torch.log(loss2))
        # batch means are weighted by the batch size for the epoch average
        train_loss1 += loss1*bs
        train_loss2 += loss2*bs
        points      += bs
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        #if points>18000:  break
    train_loss = torch.log(train_loss1/points) + torch.log(train_loss2/points)
    train_loss = torch.mean(train_loss).item()

    # do validation: cosmo alone & all params
    valid_loss1 = torch.zeros(len(g)).to(device)
    valid_loss2 = torch.zeros(len(g)).to(device)
    points = 0
    model.eval()
    with torch.no_grad():
        for x, y in valid_loader:
            bs    = x.shape[0]         #batch size
            x     = x.to(device)       #maps
            y     = y.to(device)[:,g]  #parameters
            p     = model(x)           #NN output
            y_NN  = p[:,g]             #posterior mean
            e_NN  = p[:,h]             #posterior std
            loss1 = torch.mean((y_NN - y)**2,                axis=0)
            loss2 = torch.mean(((y_NN - y)**2 - e_NN**2)**2, axis=0)
            valid_loss1 += loss1*bs
            valid_loss2 += loss2*bs
            points      += bs
    valid_loss = torch.log(valid_loss1/points) + torch.log(valid_loss2/points)
    valid_loss = torch.mean(valid_loss).item()

    # the scheduler monitors the validation loss
    scheduler.step(valid_loss)

    # verbose
    print('%03d %.3e %.3e '%(epoch, train_loss, valid_loss), end='')

    # save model if it is better
    if valid_loss<min_valid_loss:
        torch.save(model.state_dict(), fmodel)
        min_valid_loss = valid_loss
        print('(C) ', end='')
    print('')

    # save losses to file; the file is opened in append mode and the
    # context manager closes it even if the write fails
    with open(floss, 'a') as f:
        f.write('%d %.5e %.5e\n'%(epoch, train_loss, valid_loss))

In [None]:
# Restore the saved weights when the weights file is present
if os.path.exists(fmodel):
    saved_state = torch.load(fmodel, map_location=torch.device(device))
    model.load_state_dict(saved_state)
    print('Weights loaded')

In [None]:
# load test set
# The loader is NOT shuffled: the results are stored in loader order and
# the plots further down pick maps by their position in these arrays.
test_loader  = create_dataset_multifield('test', fmaps, fmaps_norm, fparams,
                                         splits, batch_size, memmap=False,
                                         shuffle=False, seed=seed, verbose=True)

# get the number of maps in the test set (no pass over the data needed)
num_maps = len(test_loader.dataset)
print('\nNumber of maps in the test set: %d'%num_maps)

# define the arrays containing the value of the parameters
params_true = np.zeros((num_maps,6), dtype=np.float32)
params_NN   = np.zeros((num_maps,6), dtype=np.float32)
errors_NN   = np.zeros((num_maps,6), dtype=np.float32)

# get test loss
test_loss1 = torch.zeros(len(g)).to(device)
test_loss2 = torch.zeros(len(g)).to(device)
points = 0
model.eval()
with torch.no_grad():
    for x, y in test_loader:
        bs    = x.shape[0]    #batch size
        x     = x.to(device)  #send data to device
        y     = y.to(device)  #send data to device
        p     = model(x)      #prediction for mean and variance
        y_NN  = p[:,:6]       #prediction for mean
        e_NN  = p[:,6:]       #prediction for error
        loss1 = torch.mean((y_NN[:,g] - y[:,g])**2,                     axis=0)
        loss2 = torch.mean(((y_NN[:,g] - y[:,g])**2 - e_NN[:,g]**2)**2, axis=0)
        test_loss1 += loss1*bs
        test_loss2 += loss2*bs

        # save results to their corresponding arrays
        params_true[points:points+bs] = y.cpu().numpy()
        params_NN[points:points+bs]   = y_NN.cpu().numpy()
        errors_NN[points:points+bs]   = e_NN.cpu().numpy()
        points += bs
test_loss = torch.log(test_loss1/points) + torch.log(test_loss2/points)
test_loss = torch.mean(test_loss).item()
print('Test loss = %.3e\n'%test_loss)

# Names of the six parameters, in column order
PARAM_NAMES = ['Omega_m', 'sigma_8', 'A_SN1', 'A_AGN1', 'A_SN2', 'A_AGN2']

def print_per_parameter(label, values):
    '''Print one "<label> <name> = <value>" line per parameter, then a blank line.'''
    for name, value in zip(PARAM_NAMES, values):
        print('%s %-7s = %.3f'%(label, name, value))
    print('')

# RMS error in normalized parameter units
Norm_error = np.sqrt(np.mean((params_true - params_NN)**2, axis=0))
print_per_parameter('Normalized Error', Norm_error)

# de-normalize
minimum = np.array([0.1, 0.6, 0.25, 0.25, 0.5, 0.5])
maximum = np.array([0.5, 1.0, 4.00, 4.00, 2.0, 2.0])
params_true = params_true*(maximum - minimum) + minimum
params_NN   = params_NN*(maximum - minimum) + minimum
errors_NN   = errors_NN*(maximum - minimum)   # errors scale but do not shift

# RMS error in de-normalized units
error = np.sqrt(np.mean((params_true - params_NN)**2, axis=0))
print_per_parameter('Error', error)

# mean of the errors predicted by the network
mean_error = np.absolute(np.mean(errors_NN, axis=0))
print_per_parameter('Bayesian error', mean_error)

# RMS of the error relative to the true value
rel_error = np.sqrt(np.mean((params_true - params_NN)**2/params_true**2, axis=0))
print_per_parameter('Relative error', rel_error)

# save results to file
#dataset = np.zeros((num_maps,18), dtype=np.float32)
#dataset[:,:6]   = params_true
#dataset[:,6:12] = params_NN
#dataset[:,12:]  = errors_NN
#np.savetxt(fresults,  dataset)
#np.savetxt(fresults1, Norm_error)

In [None]:
# Positions 0, splits, 2*splits, ... for 50 simulations: the first map of
# every simulation when each simulation occupies `splits` consecutive rows
indexes = splits * np.arange(50)

In [None]:
# Omega_m: inferred value with its predicted error bar against the truth
fig, ax = plt.subplots(figsize=(7,7))
ax.set_xlabel(r'${\rm Truth}$')
ax.set_ylabel(r'${\rm Inference}$')
ax.text(0.1, 0.45, r'$\Omega_{\rm m}$', fontsize=18)

ax.errorbar(params_true[indexes,0], params_NN[indexes,0],
            yerr=errors_NN[indexes,0],
            linestyle='None', lw=1, fmt='o', ms=2, elinewidth=1, capsize=0, c='r')
# black diagonal: inference equal to truth
ax.plot([0.1,0.5], [0.1,0.5], color='k')

plt.show()

In [None]:
# sigma_8: inferred value with its predicted error bar against the truth
fig, ax = plt.subplots(figsize=(7,7))
ax.set_xlabel(r'${\rm Truth}$')
ax.set_ylabel(r'${\rm Inference}$')
ax.text(0.6, 0.95, r'$\sigma_8$', fontsize=18)

ax.errorbar(params_true[indexes,1], params_NN[indexes,1],
            yerr=errors_NN[indexes,1],
            linestyle='None', lw=1, fmt='o', ms=2, elinewidth=1, capsize=0, c='r')
# black diagonal: inference equal to truth
ax.plot([0.6,1.0], [0.6,1.0], color='k')

plt.show()

In [None]:
# A_SN1: inferred value with its predicted error bar against the truth
fig, ax = plt.subplots(figsize=(7,7))
ax.set_xlabel(r'${\rm Truth}$')
ax.set_ylabel(r'${\rm Inference}$')
ax.text(0.25, 4.0, r'$A_{\rm SN1}$', fontsize=18)

ax.errorbar(params_true[indexes,2], params_NN[indexes,2],
            yerr=errors_NN[indexes,2],
            linestyle='None', lw=1, fmt='o', ms=2, elinewidth=1, capsize=0, c='r')
# black diagonal: inference equal to truth
ax.plot([0.25,4.0], [0.25,4.0], color='k')

plt.show()

In [None]:
# A_SN2 (column 4): inferred value with its predicted error bar against the truth
fig, ax = plt.subplots(figsize=(7,7))
ax.set_xlabel(r'${\rm Truth}$')
ax.set_ylabel(r'${\rm Inference}$')
ax.text(0.5, 2.0, r'$A_{\rm SN2}$', fontsize=18)

ax.errorbar(params_true[indexes,4], params_NN[indexes,4],
            yerr=errors_NN[indexes,4],
            linestyle='None', lw=1, fmt='o', ms=2, elinewidth=1, capsize=0, c='r')
# black diagonal: inference equal to truth
ax.plot([0.5,2.0], [0.5,2.0], color='k')

plt.show()

In [None]:
# A_AGN1 (column 3): inferred value with its predicted error bar against the truth
fig, ax = plt.subplots(figsize=(7,7))
ax.set_xlabel(r'${\rm Truth}$')
ax.set_ylabel(r'${\rm Inference}$')
ax.text(0.25, 4.0, r'$A_{\rm AGN1}$', fontsize=18)

ax.errorbar(params_true[indexes,3], params_NN[indexes,3],
            yerr=errors_NN[indexes,3],
            linestyle='None', lw=1, fmt='o', ms=2, elinewidth=1, capsize=0, c='r')
# black diagonal: inference equal to truth
ax.plot([0.25,4.0], [0.25,4.0], color='k')

plt.show()

In [None]:
# A_AGN2: inferred value with its predicted error bar against the truth
fig, ax = plt.subplots(figsize=(7,7))
ax.set_xlabel(r'${\rm Truth}$')
ax.set_ylabel(r'${\rm Inference}$')
ax.text(0.5, 2.0, r'$A_{\rm AGN2}$', fontsize=18)

ax.errorbar(params_true[indexes,5], params_NN[indexes,5],
            yerr=errors_NN[indexes,5],
            linestyle='None', lw=1, fmt='o', ms=2, elinewidth=1, capsize=0, c='r')
# black diagonal: inference equal to truth
ax.plot([0.5,2.0], [0.5,2.0], color='k')

plt.show()

In [None]:
# Sizes used to rebuild a split by hand.
# `splits` is assigned first so this cell does not depend on a value left
# in the kernel by an earlier cell.
splits = 6                     # maps per simulation
n_sims = 600                   # number of simulations
# the selected simulations are the last 5% of the shuffled list
offset, size_sims = int(0.95*n_sims), int(0.05*n_sims)
size_maps = size_sims*splits   # number of maps in the selection

In [None]:
# Shuffle the simulations (indices from 0 to n_sims-1); whole simulations
# are shuffled, not individual maps.
# The configured `seed` (the one also passed to the data loaders) is used
# so that re-running the notebook selects the same simulations.
np.random.seed(seed)
sim_numbers = np.arange(n_sims)
np.random.shuffle(sim_numbers)
sim_numbers = sim_numbers[offset:offset+size_sims] #select indexes of mode

In [None]:
# Map indices of the selected simulations: simulation s owns the rows
# s*splits ... s*splits + splits - 1. Built by broadcasting the first row
# of every simulation against the offsets 0 ... splits-1.
first_map = sim_numbers * splits
indexes = (first_map[:, None] + np.arange(splits)).ravel().astype(np.int32)

In [None]:
# Same map indices as the vectorised cell, filled one entry at a time
indexes = np.zeros(size_maps, dtype=np.int32)
count = 0
for i in sim_numbers:
    first_row = i*splits          # first map of simulation i
    for j in range(splits):
        indexes[count] = first_row + j
        count += 1

In [None]:
# Display the current contents of `indexes`
indexes