# Imports

In [None]:
# Standard library imports
import os
import time

# Third party imports
import numpy as np
import matplotlib.pyplot as plt
from scipy import stats
from torch import as_tensor as pytT
from torch import float32 as pytFl32
from torch import save

from bmctool.bmc_tool import BMCTool
from bmctool.set_params import load_params

%matplotlib inline

# Define parameter and seq files

In [None]:
# YAML configuration that is handed to load_params() further down
config_file = 'library/config_WASABITI.yaml'
# Sequence file that is handed to BMCTool together with the loaded parameters
seq_file = 'library/WASABITI_sim.seq'

# Load sample and experimental settings from configs
Parameters can be printed using `sim_params.print_settings()`.

In [None]:
# Read the settings from config_file; sim_params is updated in place later
# (update_scanner / update_water_pool) while looping over the parameter grid.
sim_params = load_params(config_file)

# Define parameter spaces
### Number of samples for different sizes:
- tiny: 4,096
- small: 65,536
- medium: 589,824
- large: 2,359,296

In [None]:
# Size of the simulated grid; valid keys are 'tiny', 'small', 'medium' and 'large'.
param_size = 'tiny'


# ----------------------------------------------------------------
# B0 shift: normal distribution restricted to the interval b0_lim
# ----------------------------------------------------------------
b0_sizes = dict(tiny=8, small=16, medium=24, large=32)
b0_size = b0_sizes[param_size]

b0_lim = [-1, 1]
b0_dist = stats.norm(loc=0, scale=0.35)
# Take equidistant steps in probability between the CDF values at the two
# limits and map them back through the inverse CDF (ppf), so the samples
# lie denser where the pdf is larger.
b0_bounds = b0_dist.cdf(b0_lim)
b0_pp = np.linspace(b0_bounds[0], b0_bounds[1], num=b0_size)
b0_var = b0_dist.ppf(b0_pp)

# ----------------------------------------------------------------
# Relative B1: normal distribution restricted to the interval b1_lim
# ----------------------------------------------------------------
b1_sizes = dict(tiny=8, small=16, medium=32, large=48)
b1_size = b1_sizes[param_size]

b1_lim = [0.1, 1.9]
b1_dist = stats.norm(loc=1, scale=0.45)
b1_bounds = b1_dist.cdf(b1_lim)
b1_pp = np.linspace(b1_bounds[0], b1_bounds[1], num=b1_size)
b1_var = b1_dist.ppf(b1_pp)

# ----------------------------------------------------------------
# T1: number of samples per grid size
# ----------------------------------------------------------------
t1_sizes = dict(tiny=8, small=16, medium=32, large=48)
t1_size = t1_sizes[param_size]

class sum_gaussians_t1(stats.rv_continuous):
    """Equal-weight mixture of the normal distributions N(3, 2) and N(1.2, 0.6).

    Both the density and the cumulative distribution are given in closed
    form. Supplying ``_cdf`` keeps scipy from falling back to numerical
    integration of ``_pdf`` for every ``cdf`` call (and for every step of
    the root search performed by ``ppf``).
    """

    def _pdf(self, x):
        """Return the mixture density at ``x`` (mean of the two normal pdfs)."""
        return (stats.norm.pdf(x, loc=3, scale=2) + stats.norm.pdf(x, loc=1.2, scale=0.6)) / 2

    def _cdf(self, x):
        """Return the mixture CDF at ``x`` (mean of the two normal cdfs)."""
        return (stats.norm.cdf(x, loc=3, scale=2) + stats.norm.cdf(x, loc=1.2, scale=0.6)) / 2
    
# Sample T1 at equidistant probabilities between the CDF values at the limits.
t1_lim = [0.05, 7]
t1_dist = sum_gaussians_t1()
t1_bounds = t1_dist.cdf(t1_lim)
t1_pp = np.linspace(t1_bounds[0], t1_bounds[1], num=t1_size)
t1_var = t1_dist.ppf(t1_pp)


# ----------------------------------------------------------------
# T2: number of samples per grid size
# ----------------------------------------------------------------
t2_sizes = dict(tiny=8, small=16, medium=24, large=32)
t2_size = t2_sizes[param_size]

class sum_gaussians_t2(stats.rv_continuous):
    """Mixture of N(2, 2) with weight 4/5 and N(0.15, 0.1) with weight 1/5.

    Both the density and the cumulative distribution are given in closed
    form. Supplying ``_cdf`` keeps scipy from falling back to numerical
    integration of ``_pdf``, which is slow and has to resolve the narrow
    second component over an unbounded integration range.
    """

    def _pdf(self, x):
        """Return the mixture density at ``x`` (4:1 weighted normal pdfs)."""
        return (4 * stats.norm.pdf(x, loc=2, scale=2) + stats.norm.pdf(x, loc=0.15, scale=0.1)) / 5

    def _cdf(self, x):
        """Return the mixture CDF at ``x`` (4:1 weighted normal cdfs)."""
        return (4 * stats.norm.cdf(x, loc=2, scale=2) + stats.norm.cdf(x, loc=0.15, scale=0.1)) / 5
    
# Sample T2 at equidistant probabilities between the CDF values at the limits.
t2_lim = [0.01, 5]
t2_dist = sum_gaussians_t2()
t2_bounds = t2_dist.cdf(t2_lim)
t2_pp = np.linspace(t2_bounds[0], t2_bounds[1], num=t2_size)
t2_var = t2_dist.ppf(t2_pp)

# Total number of parameter combinations (full Cartesian grid).
n_total = b0_var.size * b1_var.size * t1_var.size * t2_var.size

# Define output settings

In [None]:
# Create the output folder and a size-specific subfolder for the simulated data.
# os.path.isdir (instead of os.path.exists) makes a regular file that happens to
# sit at the target path fail right here instead of later when saving, and
# exist_ok=True avoids a crash if the folder appears between check and creation.
filepath_output = 'data'
if not os.path.isdir(filepath_output):
    os.makedirs(filepath_output, exist_ok=True)
    print('created a new folder for the data')

subfolder_name = 'example_' + param_size
subfolder_path = os.path.join(filepath_output, subfolder_name)
if not os.path.isdir(subfolder_path):
    os.makedirs(subfolder_path, exist_ok=True)
    print(f'created a new subfolder "{subfolder_name}" in {filepath_output} folder')

# Plot parameter distributions

In [None]:
fig, ax = plt.subplots(2, 2, figsize=(14, 10))
fig.suptitle(f'Parameter Distributions (size = "{param_size}" ({n_total} samples))', fontsize=24)

# One panel per parameter: (axes, title symbol, distribution, limits, samples).
# A single loop replaces four copy-pasted plotting sections and, unlike a bare
# trailing legend() call, leaves no object repr in the cell output.
panels = [
    (ax[0, 0], 'B$_0$', b0_dist, b0_lim, b0_var),
    (ax[0, 1], 'B$_1$', b1_dist, b1_lim, b1_var),
    (ax[1, 0], 'T$_1$', t1_dist, t1_lim, t1_var),
    (ax[1, 1], 'T$_2$', t2_dist, t2_lim, t2_var),
]
for panel, symbol, dist, lim, samples in panels:
    # pdf evaluated on a dense grid spanning the sampling limits
    x = np.linspace(min(lim), max(lim), 1000)
    panel.plot(x, dist.pdf(x), '-', label='pdf')
    # sample positions are drawn at y = -0.1, i.e. below the non-negative pdf curve
    panel.plot(samples, [-0.1] * samples.size, '.', label='samples')
    panel.set_title(f'{symbol} ({samples.size} samples)', fontsize=16)
    panel.legend()

# Run simulation

In [None]:
# Number of points stored per simulated z-spectrum.
# NOTE(review): assumes Sim.get_zspec() returns exactly this many values for
# seq_file — confirm before switching to a different sequence.
n_offsets = 31

# pars collects one [b0, b1, t1, t2] row per simulation, in loop order.
pars = []
# NOTE(review): float16 halves the memory footprint but keeps only about three
# significant decimal digits of each spectrum value — confirm this is intended.
spec_array = np.zeros([len(b0_var), len(b1_var), len(t1_var), len(t2_var), n_offsets], dtype='float16')

Sim = BMCTool(sim_params, seq_file)
print(f'Simulating {n_total} z-spectra.\n')
count = 0  # number of finished (B0, B1, T1) combinations
n_progressbar = len(b0_var)*len(b1_var)*len(t1_var)
loopstart = time.time()
for i, dw_ in enumerate(b0_var):
    sim_params.update_scanner(b0_inhom=dw_)
    for j, rb1_ in enumerate(b1_var):
        sim_params.update_scanner(rel_b1=rb1_)
        for k, t1_ in enumerate(t1_var):
            # relaxation rates are the reciprocals of the sampled relaxation times
            sim_params.update_water_pool(r1=1/t1_)
            for m, t2_ in enumerate(t2_var):
                sim_params.update_water_pool(r2=1/t2_)

                # update parameters and run simulation
                Sim.params = sim_params
                Sim.run()

                # write spectrum and parameters
                _, spec_array[i, j, k, m, :] = Sim.get_zspec()
                pars.append([dw_, rb1_, t1_, t2_])

            # update progress bar and estimated time; count is incremented
            # first so that the bar is completely filled after the last step
            count += 1
            b = int(60 * count / n_progressbar)
            left = 60 - b
            loopremain = (time.time() - loopstart) * (n_progressbar - count) / (count * 60)
            print('[' + '#' * b + '-' * left + ']' +
                  f' Estimated remaining time {loopremain:.1f} minutes.', end='\r')

print(' \n ')
print(f'Simulation took {(time.time()-loopstart)/60:.3f} minutes')

# Reshape & Save

In [None]:
# Flatten the 4D parameter grid into one row per simulated spectrum and convert
# to float32 tensors. The C-order reshape runs over the last grid axis (T2)
# fastest, which is the order in which the nested loops appended rows to pars.
# The spectrum length is taken from the array instead of being hard-coded.
X = pytT(spec_array.reshape([-1, spec_array.shape[-1]]), dtype=pytFl32)
y = pytT(np.array(pars), dtype=pytFl32)

# Guard against saving mismatched data, e.g. after an interrupted simulation run.
assert X.shape[0] == y.shape[0], (
    f'number of spectra ({X.shape[0]}) and parameter sets ({y.shape[0]}) differ'
)

# save data with pytorch; the date stamp is generated once so that both
# file names always carry the same date
date_stamp = time.strftime("%Y%m%d")
filepath_save_X = os.path.join(subfolder_path, f'{date_stamp}_X_discrete_{param_size}_{n_total}_samples.pt')
filepath_save_y = os.path.join(subfolder_path, f'{date_stamp}_y_discrete_{param_size}_{n_total}_samples.pt')

save(X, filepath_save_X)
save(y, filepath_save_y)