In [38]:
import os

import scipy.io as sio
import matplotlib.pyplot as plt

import numpy as np
import time

from tqdm import tqdm

import pypulseq

from cest_mrf.write_scenario import write_yaml_dict
from cest_mrf.dictionary.generation import generate_mrf_cest_dictionary
from cest_mrf.metrics.dot_product import dot_prod_matching, dot_prod_indexes

from configs import ConfigPreclinical
from sequences import write_sequence_preclinical

In [41]:
def gen_dict(x, cfg):
    # Define output filenames
    yaml_fn = cfg['yaml_fn']
    seq_fn = cfg['seq_fn']
    dict_fn = cfg['dict_fn']

    # Write the .yaml according to the config.py file (inside cest_mrf folder)
    write_yaml_dict(cfg, yaml_fn)

    # Write the seq file for a 2d experiment
    # for more info about the seq file, check out the pulseq-cest repository
    seq_defs = {}
    seq_defs['n_pulses'] = 1  # number of pulses
    seq_defs['tp'] = 3  # pulse duration [s]
    seq_defs['td'] = 0  # interpulse delay [s]
    seq_defs['Trec'] = 1  # delay before readout [s]
    seq_defs['Trec_M0'] = 'NaN'  # delay before m0 readout [s]
    seq_defs['M0_offset'] = 'NaN'  # dummy m0 offset [ppm]
    seq_defs['DCsat'] = seq_defs['tp'] / (seq_defs['tp'] + seq_defs['td'])  # duty cycle
    seq_defs['offsets_ppm'] = [3.0] * len(x)  # offset vector [ppm]
    seq_defs['num_meas'] = len(seq_defs['offsets_ppm'])  # number of repetition
    seq_defs['Tsat'] = seq_defs['n_pulses'] * (seq_defs['tp'] + seq_defs['td']) - seq_defs['td']
    seq_defs['B0'] = cfg['b0']  # B0 [T]

    seqid = os.path.splitext(seq_fn)[1][1:]
    seq_defs['seq_id_string'] = seqid  # unique seq id

    # we vary B1 for the dictionary generation
    seq_defs['B1pa'] = x

    # Create .seq file
    write_sequence_preclinical(seq_defs=seq_defs, seq_fn=seq_fn)

    start = time.perf_counter()
    dictionary = generate_mrf_cest_dictionary(seq_fn=seq_fn, param_fn=yaml_fn, dict_fn=dict_fn, num_workers=cfg['num_workers'],
                                    axes='xy')  # axes can also be 'z' if no readout is simulated
    end = time.perf_counter()
    s = (end - start)
    print(f"Dictionary simulation and preparation took {s:.03f} s.")
    
    dictionary['sig'] = np.array(dictionary['sig']).T
    for key in dictionary.keys():
        if key != 'sig':
            dictionary[key] = np.expand_dims( np.squeeze(np.array(dictionary[key])), 0)
    return dictionary


def dot_prod_wrap(dictionary, acq_sig, params=None, batch_size=256):
    # Calculate dot product
    dp = dot_prod_indexes(dictionary['sig'], acq_sig, batch_size=batch_size)

    quant_maps = {}
    shape = dp['dp'].shape
    quant_maps['dp'] = dp['dp']

    for p in params:
        if p in dictionary.keys():
            quant_maps[p] = dictionary[p][0, dp['dp_indexes'].flatten().astype(int)].reshape(shape)
        else:
            print(f"Parameter {p} not found in dictionary.")

    return dp, quant_maps


In [51]:
cfg = ConfigPreclinical().get_config()

sig_n = 12

b1 = np.random.randn(sig_n)
b1


In [54]:
dictionary = gen_dict(b1, cfg)

No MT pools found in param files! specify with "mt_pool"
Found 12259 different parameter combinations.
Future 0 is finished
Future 1 is finished
Future 2 is finished
Future 3 is finished
Future 4 is finished
Future 5 is finished
Future 6 is finished
Future 7 is finished
Future 8 is finished
Future 9 is finished
Future 10 is finished
Future 12 is finished
Future 11 is finished
Future 13 is finished
Future 14 is finished
Future 15 is finished
Future 17 is finished
Future 16 is finished
Dictionary simulation took 4.268 s.
Dictionary simulation and preparation took 4.289 s.


In [55]:
n_iter = 20
noise_std = 0.05
params = ['fs_0', 'ksw_0']

error = dict([(p, []) for p in params])

for n in tqdm(range(n_iter)):
    noise_signal = dictionary['sig'] + noise_std * np.random.randn(*dictionary['sig'].shape)
    dp, quant_maps = dot_prod_wrap(dictionary, noise_signal, params=params, batch_size=299)

    for p in params:
        error[p].append(np.linalg.norm(quant_maps[p] - dictionary[p][0]) / np.linalg.norm(dictionary[p][0]))

for p in params:
    print(f'MAE of {p} is {np.mean(error[p])*100:.3f} +- {np.std(error[p]):.3f} %')


100%|██████████| 20/20 [00:06<00:00,  2.86it/s]

MAE of fs_0 is 19.512 +- 0.001 %
MAE of ksw_0 is 33.828 +- 0.003 %



