In [1]:
import multiprocessing
import platform
import torch
import numpy as np
import datetime
from torch.utils.data import DataLoader
from torch.optim.lr_scheduler import StepLR
from torch.utils.data import Subset
import itertools
import pandas as pd

import time
import os
# import tqdm
import tqdm.notebook as tqdm

from utils.normalization import normalize_range, un_normalize_range
from utils.seed import set_seed

from sequential_nn.dataset import GluMemDataset_4pool
from sequential_nn.model import Network
from sequential_nn.multi_dict import pkl_2_dat

from torch.utils.tensorboard import SummaryWriter

In [10]:
glu_dict_upper_folder_fn = f'D:\mrf\data\exp\mt_amide_glu_dicts\high_glu_conc_500\glu'
glu_dict_folder_fn = os.path.join(glu_dict_upper_folder_fn, '107a')
memmap_fn = os.path.join(glu_dict_folder_fn, 'M0_dict.dat')
if not os.path.exists(memmap_fn):
    pkl_2_dat(glu_dict_folder_fn, sched_iter=30, add_iter=6, memmap_fn=memmap_fn, M0_flag=True)
glu_dict_folder_fn = os.path.join(glu_dict_upper_folder_fn, '51_Amide')
memmap_fn = os.path.join(glu_dict_folder_fn, 'M0_dict.dat')
if not os.path.exists(memmap_fn):
    pkl_2_dat(glu_dict_folder_fn, sched_iter=30, add_iter=6, memmap_fn=memmap_fn, M0_flag=True)

Single dict case
Total time taken for memmap generation: 4 minutes 28 seconds
Single dict case
Total time taken for memmap generation: 6 minutes 32 seconds


In [13]:
glu_dict_upper_folder_fn = r'D:\mrf\data\exp\mt_amide_glu_dicts\high_glu_conc_500\amide'
glu_dict_folder_fn = os.path.join(glu_dict_upper_folder_fn, '51_Amide')
memmap_fn = os.path.join(glu_dict_folder_fn, 'M0_dict.dat')
if not os.path.exists(memmap_fn):
    pkl_2_dat(glu_dict_folder_fn, sched_iter=30, add_iter=4, memmap_fn=memmap_fn, M0_flag=True)

Single dict case
Total time taken for memmap generation: 0 minutes 2 seconds


In [2]:
%load_ext tensorboard
%tensorboard --logdir=./runs --port 6007

Launching TensorBoard...

In [2]:
def main():
    torch.multiprocessing.freeze_support()
    dict_name_category = 'high_glu_conc_500'
    fp_prtcl_name = '107a'

    # Schedule iterations
    # number of raw images in the CEST-MRF acquisition schedule
    sched_iter = 30
    add_iter = 6

    # Training properties
    learning_rate = 2e-4
    step_size = 1
    gamma = 1
    batch_size = 1024
    num_epochs = 10  # 150
    noise_std = 1e-3  # noise level for training, 1e-2

    min_delta = 0.05  # minimum absolute change in the loss function
    patience = np.inf

    current_dir = os.getcwd()  # Get the current directory
    parent_dir = os.path.dirname(current_dir)  # Navigate up one directory level
    glu_dict_folder_fn = os.path.join(parent_dir, 'data', 'exp', 'mt_amide_glu_dicts', dict_name_category, 'glu',
                                      fp_prtcl_name)  # dict folder directory
    memmap_fn = os.path.join(glu_dict_folder_fn, 'M0_dict.dat')
    glu_dict_fn = os.path.join(glu_dict_folder_fn, 'dict.pkl')

    if not os.path.exists(memmap_fn):
        pkl_2_dat(glu_dict_folder_fn, sched_iter, add_iter, memmap_fn)

    net_name = f'{dict_name_category}_glu_dict_noise_{noise_std}_lr_{learning_rate}_{batch_size}'  # _cosine
    nn_fn = os.path.join(current_dir, 'mouse_nns', 'glu_amide_mt_nns', dict_name_category, 'glu', fp_prtcl_name,
                         f'M0_{net_name}.pt')  # nn directory

    device = initialize_device()
    print(f"Using device: {device}")

    # Load the entire dataset to get its size
    # full_dataset = Dataset_4pool(glu_dict_fn)
    full_dataset = GluMemDataset_4pool(memmap_fn, sched_iter, add_iter)
    # full_dataset = NoShuffleMultiDataset(glu_dict_folder_fn, add_iter)
    
    (min_param_tensor, max_param_tensor,
    min_water_t1t2_tensor, max_water_t1t2_tensor,
    min_mt_param_tensor, max_mt_param_tensor, 
    min_amide_param_tensor, max_amide_param_tensor) = define_min_max(memmap_fn, sched_iter, add_iter, device)

    # Convert tensors to numpy arrays
    min_param_array = min_param_tensor.cpu().numpy()
    max_param_array = max_param_tensor.cpu().numpy()
    min_water_t1t2_array = min_water_t1t2_tensor.cpu().numpy()
    max_water_t1t2_array = max_water_t1t2_tensor.cpu().numpy()
    min_mt_param_array = min_mt_param_tensor.cpu().numpy()
    max_mt_param_array = max_mt_param_tensor.cpu().numpy()
    min_amide_param_array = min_amide_param_tensor.cpu().numpy()
    max_amide_param_array = max_amide_param_tensor.cpu().numpy()
    
    if not os.path.exists(os.path.dirname(nn_fn)):
        os.makedirs(os.path.dirname(nn_fn))
        
    # Save all arrays to a single .npz file
    np.savez(os.path.join(os.path.dirname(nn_fn),'min_max_values.npz'),
             min_param=min_param_array,
             max_param=max_param_array,
             min_water_t1t2=min_water_t1t2_array,
             max_water_t1t2=max_water_t1t2_array,
             min_mt_param=min_mt_param_array,
             max_mt_param=max_mt_param_array,
             min_amide_param=min_amide_param_array,
             max_amide_param=max_amide_param_array)
    
    dataset_size = len(full_dataset)

    # Split indices for training, validation, and test sets
    train_indices, val_indices, test_indices = split_dataset_indices(dataset_size, val_ratio=0.2, test_ratio=0.1)

    # Create subsets
    train_dataset = Subset(full_dataset, train_indices)
    val_dataset = Subset(full_dataset, val_indices)
    test_dataset = Subset(full_dataset, test_indices)

    # Create DataLoaders
    train_loader = DataLoader(dataset=train_dataset,
                              batch_size=batch_size,
                              shuffle=True,
                              num_workers=1)

    val_loader = DataLoader(dataset=val_dataset,
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=1)

    test_loader = DataLoader(dataset=test_dataset,
                             batch_size=batch_size,
                             shuffle=False,
                             num_workers=1)

    train_network(train_loader, val_loader, test_loader, device, sched_iter, add_iter, dict_name_category, learning_rate, num_epochs, noise_std, patience,
                  min_delta, min_param_tensor, max_param_tensor, min_water_t1t2_tensor,
                  max_water_t1t2_tensor, min_mt_param_tensor, max_mt_param_tensor, min_amide_param_tensor, max_amide_param_tensor, nn_fn, step_size, gamma, net_name)


# Function to split dataset indices
def split_dataset_indices(dataset_size, val_ratio=0.2, test_ratio=0.1):
    indices = np.arange(dataset_size)
    np.random.shuffle(indices)
    test_split = int(test_ratio * dataset_size)
    val_split = int(val_ratio * dataset_size) + test_split
    test_indices = indices[:test_split]
    val_indices = indices[test_split:val_split]
    train_indices = indices[val_split:]
    return train_indices, val_indices, test_indices

# Function to initialize device
def initialize_device():
    return 'cuda' if torch.cuda.is_available() else 'cpu'


# Function to train the network
def train_network(train_loader, val_loader, test_loader, device, sched_iter, add_iter, dict_name, learning_rate, num_epochs, noise_std, patience, min_delta,
                  min_param_tensor, max_param_tensor, min_water_t1t2_tensor, max_water_t1t2_tensor,
                  min_mt_param_fs_ksw, max_mt_param_fs_ksw,  min_amide_param_fs_ksw, max_amide_param_fs_ksw, nn_fn, step_size, gamma, net_name):
    nn_folder = os.path.dirname(nn_fn)  # Navigate up one directory level
    if not os.path.exists(nn_folder):
        os.makedirs(nn_folder)

    # Initializing the reconstruction network
    reco_net = Network(sched_iter, add_iter=add_iter, n_hidden=2, n_neurons=300).to(device)

    # Print amount of parameters
    print('Number of model parameters: ', sum(p.numel() for p in reco_net.parameters() if p.requires_grad))

    # Setting optimizer
    optimizer = torch.optim.Adam(reco_net.parameters(), lr=learning_rate)
    scheduler = StepLR(optimizer, step_size=step_size, gamma=gamma)

    # Storing current time
    t0 = time.time()
    # Get today's date
    today = datetime.datetime.now().strftime('%Y-%m-%d')
    writer = SummaryWriter(log_dir=f'runs/{net_name}')

    loss_per_epoch = []
    val_loss_per_epoch = []
    patience_counter = 0
    min_loss = 100

    reco_net.train()
    cur_val_loss = float('inf')

    pbar = tqdm.tqdm(total=num_epochs)
    for epoch in range(num_epochs):
        # Cumulative loss
        cum_loss = 0
        counter = np.nan
        
        num_steps = len(train_loader)
        inner_pbar = tqdm.tqdm(total=num_steps)
        for counter, dict_params in enumerate(train_loader, 0):
            reco_net, cum_loss = train_step(device, noise_std, reco_net, optimizer, cum_loss, dict_params,
                                            min_param_tensor, max_param_tensor,
                                            min_water_t1t2_tensor, max_water_t1t2_tensor,
                                            min_mt_param_fs_ksw, max_mt_param_fs_ksw, min_amide_param_fs_ksw, max_amide_param_fs_ksw, 
                                            writer, epoch, counter, num_steps)
            inner_pbar.set_description(f'Step: {counter+1}/{num_steps}')
            inner_pbar.update(1)
            
            del dict_params
            torch.cuda.empty_cache()
        inner_pbar.close()

        # Average loss for this epoch
        loss_per_epoch.append(cum_loss / (counter + 1))
        
        # Validate the model
        val_loss = validate(reco_net, val_loader, device, min_param_tensor, max_param_tensor,
                            min_water_t1t2_tensor, max_water_t1t2_tensor,
                            min_mt_param_fs_ksw, max_mt_param_fs_ksw, min_amide_param_fs_ksw, max_amide_param_fs_ksw)
        val_loss_per_epoch.append(val_loss)
        
        writer.add_scalar("Loss/train", loss_per_epoch[-1], epoch)
        writer.add_scalar("Loss/val", val_loss, epoch)

        pbar.set_description(f'Epoch: {epoch + 1}/{num_epochs}, Train Loss = {loss_per_epoch[-1]}, Val Loss = {val_loss_per_epoch[-1]}')
        pbar.update(1)

        # Early stopping logic
        if (min_loss - val_loss_per_epoch[-1]) / min_loss > min_delta:
            min_loss = val_loss_per_epoch[-1]
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter > patience:
            print('Early stopping!')
            break
            
        # Scheduler step
        scheduler.step()

        # Save model checkpoint when val loss gets better
        if val_loss <= cur_val_loss:
            print(f"\nSaved epoch {epoch} model")
            torch.save({
                'model_state_dict': reco_net.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'loss_per_epoch': loss_per_epoch,
                'val_loss_per_epoch': val_loss_per_epoch,
                'noise_std': noise_std,
                'epoch': epoch
            }, nn_fn)
            
            torch.cuda.empty_cache()
            cur_val_loss = val_loss

    pbar.close()
    print(f"Training took {time.time() - t0:.2f} seconds")

    # # Save final model checkpoint
    # torch.save({
    #     'model_state_dict': reco_net.state_dict(),
    #     'optimizer_state_dict': optimizer.state_dict(),
    #     'loss_per_epoch': loss_per_epoch,
    #     'val_loss_per_epoch': val_loss_per_epoch,
    #     'noise_std': noise_std,
    # }, nn_fn)

    writer.flush()
    writer.close()

    # Test the model
    test_loss = test(reco_net, test_loader, device, min_param_tensor, max_param_tensor,
                     min_water_t1t2_tensor, max_water_t1t2_tensor,
                     min_mt_param_fs_ksw, max_mt_param_fs_ksw, min_amide_param_fs_ksw, max_amide_param_fs_ksw)
    print(f"Test Loss: {test_loss}")

    return reco_net

def train_step(device, noise_std, reco_net, optimizer, cum_loss, dict_params, min_param_tensor, max_param_tensor,
               min_water_t1t2_tensor, max_water_t1t2_tensor, min_mt_param_fs_ksw, max_mt_param_fs_ksw, min_amide_param_fs_ksw, max_amide_param_fs_ksw,
               writer, epoch, counter, num_steps):
    cur_fs, cur_ksw, cur_t1w, cur_t2w, cur_mt_fs, cur_mt_ksw, cur_amide_fs, cur_amide_ksw, cur_norm_sig = dict_params

    target = torch.stack((cur_fs, cur_ksw), dim=1).to(device)
    input_water_t1t2 = torch.stack((cur_t1w, cur_t2w), dim=1).to(device)
    input_mt_fs_ksw = torch.stack((cur_mt_fs, cur_mt_ksw), dim=1).to(device)
    input_amide_fs_ksw = torch.stack((cur_amide_fs, cur_amide_ksw), dim=1).to(device)

    # Normalizing the target and input_water_t1t2
    target = normalize_range(original_array=target, original_min=min_param_tensor,
                             original_max=max_param_tensor, new_min=0, new_max=1).to(device)

    input_water_t1t2 = normalize_range(original_array=input_water_t1t2, original_min=min_water_t1t2_tensor,
                                       original_max=max_water_t1t2_tensor, new_min=0, new_max=1).to(device)

    input_mt_fs_ksw = normalize_range(original_array=input_mt_fs_ksw, original_min=min_mt_param_fs_ksw,
                                      original_max=max_mt_param_fs_ksw, new_min=0, new_max=1).to(device)
    
    input_amide_fs_ksw = normalize_range(original_array=input_amide_fs_ksw, original_min=min_amide_param_fs_ksw,
                                  original_max=max_amide_param_fs_ksw, new_min=0, new_max=1).to(device)

    # Adding noise to the input signals (trajectories)
    noised_sig = cur_norm_sig + torch.randn(cur_norm_sig.size()) * noise_std

    # adding the mt_fs_ksw and t1, t2 as additional nn input
    noised_sig = torch.hstack((input_amide_fs_ksw, input_mt_fs_ksw, input_water_t1t2, noised_sig.to(device))).to(device)
    del input_water_t1t2, input_mt_fs_ksw, input_amide_fs_ksw

    # Forward step
    prediction = reco_net(noised_sig.float())
    del noised_sig

    # Batch loss (MSE)
    loss = torch.mean((prediction.float() - target.float()) ** 2)
    del target

    # Backward step
    optimizer.zero_grad()
    loss.backward()

    # Optimization step
    optimizer.step()

    # Storing Cumulative loss
    cum_loss += loss.item()
    
    writer.add_scalar("Loss/train_step", loss.item(), counter+epoch*num_steps)
    
    torch.cuda.empty_cache()
    
    return reco_net, cum_loss


def validate(reco_net, val_loader, device, min_param_tensor, max_param_tensor, min_water_t1t2_tensor, max_water_t1t2_tensor,
             min_mt_param_fs_ksw, max_mt_param_fs_ksw, min_amide_param_fs_ksw, max_amide_param_fs_ksw):
    reco_net.eval()
    val_loss = 0
    with torch.no_grad():
        for dict_params in val_loader:
            cur_fs, cur_ksw, cur_t1w, cur_t2w, cur_mt_fs, cur_mt_ksw, cur_amide_fs, cur_amide_ksw, cur_norm_sig = dict_params

            target = torch.stack((cur_fs, cur_ksw), dim=1).to(device)
            input_water_t1t2 = torch.stack((cur_t1w, cur_t2w), dim=1).to(device)
            input_mt_fs_ksw = torch.stack((cur_mt_fs, cur_mt_ksw), dim=1).to(device)
            input_amide_fs_ksw = torch.stack((cur_amide_fs, cur_amide_ksw), dim=1).to(device)

            target = normalize_range(original_array=target, original_min=min_param_tensor,
                                     original_max=max_param_tensor, new_min=0, new_max=1).to(device)
            input_water_t1t2 = normalize_range(original_array=input_water_t1t2, original_min=min_water_t1t2_tensor,
                                               original_max=max_water_t1t2_tensor, new_min=0, new_max=1).to(device)
            input_mt_fs_ksw = normalize_range(original_array=input_mt_fs_ksw, original_min=min_mt_param_fs_ksw,
                                              original_max=max_mt_param_fs_ksw, new_min=0, new_max=1).to(device)
            input_amide_fs_ksw = normalize_range(original_array=input_amide_fs_ksw, original_min=min_amide_param_fs_ksw,
                              original_max=max_amide_param_fs_ksw, new_min=0, new_max=1).to(device)

            noised_sig = cur_norm_sig

            noised_sig = torch.hstack((input_amide_fs_ksw, input_mt_fs_ksw, input_water_t1t2, noised_sig.to(device))).to(device)

            prediction = reco_net(noised_sig.float())

            loss = torch.mean((prediction.float() - target.float()) ** 2)

            val_loss += loss.item()

    return val_loss / len(val_loader)


def test(reco_net, test_loader, device, min_param_tensor, max_param_tensor, min_water_t1t2_tensor, max_water_t1t2_tensor,
         min_mt_param_fs_ksw, max_mt_param_fs_ksw, min_amide_param_fs_ksw, max_amide_param_fs_ksw):
    reco_net.eval()
    test_loss = 0
    with torch.no_grad():
        for dict_params in test_loader:
            cur_fs, cur_ksw, cur_t1w, cur_t2w, cur_mt_fs, cur_mt_ksw, cur_amide_fs, cur_amide_ksw, cur_norm_sig = dict_params

            target = torch.stack((cur_fs, cur_ksw), dim=1).to(device)
            input_water_t1t2 = torch.stack((cur_t1w, cur_t2w), dim=1).to(device)
            input_mt_fs_ksw = torch.stack((cur_mt_fs, cur_mt_ksw), dim=1).to(device)
            input_amide_fs_ksw = torch.stack((cur_amide_fs, cur_amide_ksw), dim=1).to(device)

            target = normalize_range(original_array=target, original_min=min_param_tensor,
                                     original_max=max_param_tensor, new_min=0, new_max=1).to(device)
            input_water_t1t2 = normalize_range(original_array=input_water_t1t2, original_min=min_water_t1t2_tensor,
                                               original_max=max_water_t1t2_tensor, new_min=0, new_max=1).to(device)
            input_mt_fs_ksw = normalize_range(original_array=input_mt_fs_ksw, original_min=min_mt_param_fs_ksw,
                                              original_max=max_mt_param_fs_ksw, new_min=0, new_max=1).to(device)
            input_amide_fs_ksw = normalize_range(original_array=input_amide_fs_ksw, original_min=min_amide_param_fs_ksw,
                  original_max=max_amide_param_fs_ksw, new_min=0, new_max=1).to(device)

            noised_sig = cur_norm_sig

            noised_sig = torch.hstack((input_amide_fs_ksw, input_mt_fs_ksw, input_water_t1t2, noised_sig.to(device))).to(device)

            prediction = reco_net(noised_sig.float())

            loss = torch.mean((prediction.float() - target.float()) ** 2)

            test_loss += loss.item()

    return test_loss / len(test_loader)


def define_min_max(memmap_fn, sched_iter, add_iter, device):
    num_columns = sched_iter + add_iter + 2
    memmap_array = np.memmap(memmap_fn, dtype=np.float64, mode='r')
    num_rows = memmap_array.size // num_columns  # Calculate the number of rows
    memmap_array.shape = (num_rows, num_columns)  # [#, 30+6]

    min_fs = np.min(memmap_array[:, 4])  # uncomment if non-zero minimum limit is required
    min_ksw = np.min(memmap_array[:, 5].transpose().astype(float))  # uncomment if non-zero minimum limit needed
    max_fs = np.max(memmap_array[:, 4])
    max_ksw = np.max(memmap_array[:, 5].transpose().astype(float))

    min_t1w = np.min(memmap_array[:, 2])
    min_t2w = np.min(memmap_array[:, 3].transpose().astype(float))
    max_t1w = np.max(memmap_array[:, 2])
    max_t2w = np.max(memmap_array[:, 3].transpose().astype(float))

    min_mt_fs = np.min(memmap_array[:, 6])
    min_mt_ksw = np.min(memmap_array[:, 7].transpose().astype(float))
    max_mt_fs = np.max(memmap_array[:, 6])
    max_mt_ksw = np.max(memmap_array[:, 7].transpose().astype(float))

    min_amine_fs = np.min(memmap_array[:, 0])
    min_amine_ksw = np.min(memmap_array[:, 1].transpose().astype(float))
    max_amine_fs = np.max(memmap_array[:, 0])
    max_amine_ksw = np.max(memmap_array[:, 1].transpose().astype(float))

    min_param_tensor = torch.tensor(np.hstack((min_fs, min_ksw)), requires_grad=False).to(device)  # can be switched to  min_fs, min_ksw
    max_param_tensor = torch.tensor(np.hstack((max_fs, max_ksw)), requires_grad=False).to(device)

    min_water_t1t2_tensor = torch.tensor(np.hstack((min_t1w, min_t2w)), requires_grad=False).to(device)
    max_water_t1t2_tensor = torch.tensor(np.hstack((max_t1w, max_t2w)), requires_grad=False).to(device)

    min_mt_param_tensor = torch.tensor(np.hstack((min_mt_fs, min_mt_ksw)), requires_grad=False).to(device)  # can be switched to  min_fs, min_ksw
    max_mt_param_tensor = torch.tensor(np.hstack((max_mt_fs, max_mt_ksw)), requires_grad=False).to(device)  # can be switched to  min_fs, min_ksw

    min_amine_param_tensor = torch.tensor(np.hstack((min_amine_fs, min_amine_ksw)), requires_grad=False).to(device)  # can be switched to  min_fs, min_ksw
    max_amine_param_tensor = torch.tensor(np.hstack((max_amine_fs, max_amine_ksw)),
                                       requires_grad=False).to(device)  # can be switched to  min_fs, min_ksw

    return (min_param_tensor, max_param_tensor, min_water_t1t2_tensor, max_water_t1t2_tensor,
            min_mt_param_tensor, max_mt_param_tensor, min_amine_param_tensor, max_amine_param_tensor)


if __name__ == '__main__':
    if platform.system() == 'Windows':
        multiprocessing.set_start_method('spawn', force=True)
    # os.chdir(os.path.dirname(os.path.realpath(__file__)))
    set_seed(2024)

    main()


Random seed set as 2024
Using device: cuda
There are 24393600 entries in the training dictionary
Number of model parameters:  193502


  0%|          | 0/10 [00:00<?, ?it/s]

  0%|          | 0/16676 [00:00<?, ?it/s]


Saved epoch 0 model


  0%|          | 0/16676 [00:00<?, ?it/s]


Saved epoch 1 model


  0%|          | 0/16676 [00:00<?, ?it/s]

  0%|          | 0/16676 [00:00<?, ?it/s]


Saved epoch 3 model


  0%|          | 0/16676 [00:00<?, ?it/s]

  0%|          | 0/16676 [00:00<?, ?it/s]

  0%|          | 0/16676 [00:00<?, ?it/s]


Saved epoch 6 model


  0%|          | 0/16676 [00:00<?, ?it/s]


Saved epoch 7 model


  0%|          | 0/16676 [00:00<?, ?it/s]

  0%|          | 0/16676 [00:00<?, ?it/s]

Training took 6425.42 seconds
Test Loss: 0.05101222234597939


In [ ]:
# def main():
#     torch.multiprocessing.freeze_support()
#     dict_name_category = 'high_glu_conc_100'
#     fp_prtcl_name = '107a'
# 
#     # Schedule iterations
#     # number of raw images in the CEST-MRF acquisition schedule
#     sched_iter = 30
#     add_iter = 6
# 
#     # Training properties
#     learning_rate = 2e-4
#     step_size = 1
#     gamma = 1
#     batch_size = 1024
#     num_epochs = 20  # 150
#     noise_std = 1e-3  # noise level for training, 1e-2
# 
#     min_delta = 0.05  # minimum absolute change in the loss function
#     patience = np.inf
# 
#     current_dir = os.getcwd()  # Get the current directory
#     parent_dir = os.path.dirname(current_dir)  # Navigate up one directory level
#     glu_dict_folder_fn = os.path.join(parent_dir, 'data', 'exp', 'mt_amide_glu_dicts', dict_name_category, 'glu',
#                                       fp_prtcl_name)  # dict folder directory
#     memmap_fn = os.path.join(glu_dict_folder_fn, 'dict.dat')
#     glu_dict_fn = os.path.join(glu_dict_folder_fn, 'dict.pkl')
# 
#     if not os.path.exists(memmap_fn):
#         pkl_2_dat(glu_dict_folder_fn, sched_iter, add_iter, memmap_fn)
# 
#     net_name = f'{dict_name_category}_glu_dict_noise_{noise_std}_lr_{learning_rate}_{batch_size}'  # _cosine
#     nn_fn = os.path.join(current_dir, 'mouse_nns', 'glu_amide_mt_nns', dict_name_category, 'glu', fp_prtcl_name,
#                          f'{net_name}.pt')  # nn directory
# 
#     device = initialize_device()
#     print(f"Using device: {device}")
# 
#     # Load the entire dataset to get its size
#     # full_dataset = Dataset_4pool(glu_dict_fn)
#     full_dataset = GluMemDataset_4pool(memmap_fn, sched_iter, add_iter)
#     # full_dataset = NoShuffleMultiDataset(glu_dict_folder_fn, add_iter)
#     
#     (min_param_tensor, max_param_tensor,
#     min_water_t1t2_tensor, max_water_t1t2_tensor,
#     min_mt_param_tensor, max_mt_param_tensor, 
#     min_amide_param_tensor, max_amide_param_tensor) = define_min_max(memmap_fn, sched_iter, add_iter, device)
# 
#     # Convert tensors to numpy arrays
#     min_param_array = min_param_tensor.cpu().numpy()
#     max_param_array = max_param_tensor.cpu().numpy()
#     min_water_t1t2_array = min_water_t1t2_tensor.cpu().numpy()
#     max_water_t1t2_array = max_water_t1t2_tensor.cpu().numpy()
#     min_mt_param_array = min_mt_param_tensor.cpu().numpy()
#     max_mt_param_array = max_mt_param_tensor.cpu().numpy()
#     min_amide_param_array = min_amide_param_tensor.cpu().numpy()
#     max_amide_param_array = max_amide_param_tensor.cpu().numpy()
#     
#     if not os.path.exists(os.path.dirname(nn_fn)):
#         os.makedirs(os.path.dirname(nn_fn))
#         
#     # Save all arrays to a single .npz file
#     np.savez(os.path.join(os.path.dirname(nn_fn),'min_max_values.npz'),
#              min_param=min_param_array,
#              max_param=max_param_array,
#              min_water_t1t2=min_water_t1t2_array,
#              max_water_t1t2=max_water_t1t2_array,
#              min_mt_param=min_mt_param_array,
#              max_mt_param=max_mt_param_array,
#              min_amide_param=min_amide_param_array,
#              max_amide_param=max_amide_param_array)
#     
#     dataset_size = len(full_dataset)
# 
#     # Split indices for training, validation, and test sets
#     train_indices, val_indices, test_indices = split_dataset_indices(dataset_size, val_ratio=0.2, test_ratio=0.1)
# 
#     # Create subsets
#     train_dataset = Subset(full_dataset, train_indices)
#     val_dataset = Subset(full_dataset, val_indices)
#     test_dataset = Subset(full_dataset, test_indices)
# 
#     # Create DataLoaders
#     train_loader = DataLoader(dataset=train_dataset,
#                               batch_size=batch_size,
#                               shuffle=True,
#                               num_workers=1)
# 
#     val_loader = DataLoader(dataset=val_dataset,
#                             batch_size=batch_size,
#                             shuffle=False,
#                             num_workers=1)
# 
#     test_loader = DataLoader(dataset=test_dataset,
#                              batch_size=batch_size,
#                              shuffle=False,
#                              num_workers=1)
# 
#     train_network(train_loader, val_loader, test_loader, device, sched_iter, add_iter, dict_name_category, learning_rate, num_epochs, noise_std, patience,
#                   min_delta, min_param_tensor, max_param_tensor, min_water_t1t2_tensor,
#                   max_water_t1t2_tensor, min_mt_param_tensor, max_mt_param_tensor, min_amide_param_tensor, max_amide_param_tensor, nn_fn, step_size, gamma, net_name)
# 
# 
# # Function to split dataset indices
# def split_dataset_indices(dataset_size, val_ratio=0.2, test_ratio=0.1):
#     indices = np.arange(dataset_size)
#     np.random.shuffle(indices)
#     test_split = int(test_ratio * dataset_size)
#     val_split = int(val_ratio * dataset_size) + test_split
#     test_indices = indices[:test_split]
#     val_indices = indices[test_split:val_split]
#     train_indices = indices[val_split:]
#     return train_indices, val_indices, test_indices
# 
# # Function to initialize device
# def initialize_device():
#     return 'cuda' if torch.cuda.is_available() else 'cpu'
# 
# 
# # Function to train the network
# def train_network(train_loader, val_loader, test_loader, device, sched_iter, add_iter, dict_name, learning_rate, num_epochs, noise_std, patience, min_delta,
#                   min_param_tensor, max_param_tensor, min_water_t1t2_tensor, max_water_t1t2_tensor,
#                   min_mt_param_fs_ksw, max_mt_param_fs_ksw,  min_amide_param_fs_ksw, max_amide_param_fs_ksw, nn_fn, step_size, gamma, net_name):
#     nn_folder = os.path.dirname(nn_fn)  # Navigate up one directory level
#     if not os.path.exists(nn_folder):
#         os.makedirs(nn_folder)
# 
#     # Initializing the reconstruction network
#     reco_net = Network(sched_iter, add_iter=add_iter, n_hidden=2, n_neurons=300).to(device)
# 
#     # Print amount of parameters
#     print('Number of model parameters: ', sum(p.numel() for p in reco_net.parameters() if p.requires_grad))
# 
#     # Setting optimizer
#     optimizer = torch.optim.Adam(reco_net.parameters(), lr=learning_rate)
#     scheduler = StepLR(optimizer, step_size=step_size, gamma=gamma)
# 
#     # Storing current time
#     t0 = time.time()
#     # Get today's date
#     today = datetime.datetime.now().strftime('%Y-%m-%d')
#     writer = SummaryWriter(log_dir=f'runs/{net_name}')
# 
#     loss_per_epoch = []
#     val_loss_per_epoch = []
#     patience_counter = 0
#     min_loss = 100
# 
#     reco_net.train()
#     cur_val_loss = float('inf')
# 
#     pbar = tqdm.tqdm(total=num_epochs)
#     for epoch in range(num_epochs):
#         # Cumulative loss
#         cum_loss = 0
#         counter = np.nan
#         
#         num_steps = len(train_loader)
#         inner_pbar = tqdm.tqdm(total=num_steps)
#         for counter, dict_params in enumerate(train_loader, 0):
#             reco_net, cum_loss = train_step(device, noise_std, reco_net, optimizer, cum_loss, dict_params,
#                                             min_param_tensor, max_param_tensor,
#                                             min_water_t1t2_tensor, max_water_t1t2_tensor,
#                                             min_mt_param_fs_ksw, max_mt_param_fs_ksw, min_amide_param_fs_ksw, max_amide_param_fs_ksw, 
#                                             writer, epoch, counter, num_steps)
#             inner_pbar.set_description(f'Step: {counter+1}/{num_steps}')
#             inner_pbar.update(1)
#             
#             del dict_params
#             torch.cuda.empty_cache()
#         inner_pbar.close()
# 
#         # Average loss for this epoch
#         loss_per_epoch.append(cum_loss / (counter + 1))
#         
#         # Validate the model
#         val_loss = validate(reco_net, val_loader, device, min_param_tensor, max_param_tensor,
#                             min_water_t1t2_tensor, max_water_t1t2_tensor,
#                             min_mt_param_fs_ksw, max_mt_param_fs_ksw, min_amide_param_fs_ksw, max_amide_param_fs_ksw)
#         val_loss_per_epoch.append(val_loss)
#         
#         writer.add_scalar("Loss/train", loss_per_epoch[-1], epoch)
#         writer.add_scalar("Loss/val", val_loss, epoch)
# 
#         pbar.set_description(f'Epoch: {epoch + 1}/{num_epochs}, Train Loss = {loss_per_epoch[-1]}, Val Loss = {val_loss_per_epoch[-1]}')
#         pbar.update(1)
# 
#         # Early stopping logic
#         if (min_loss - val_loss_per_epoch[-1]) / min_loss > min_delta:
#             min_loss = val_loss_per_epoch[-1]
#             patience_counter = 0
#         else:
#             patience_counter += 1
# 
#         if patience_counter > patience:
#             print('Early stopping!')
#             break
#             
#         # Scheduler step
#         scheduler.step()
# 
#         # Save model checkpoint when val loss gets better
#         if val_loss <= cur_val_loss:
#             print(f"\nSaved epoch {epoch} model")
#             torch.save({
#                 'model_state_dict': reco_net.state_dict(),
#                 'optimizer_state_dict': optimizer.state_dict(),
#                 'loss_per_epoch': loss_per_epoch,
#                 'val_loss_per_epoch': val_loss_per_epoch,
#                 'noise_std': noise_std,
#                 'epoch': epoch
#             }, nn_fn)
#             
#             torch.cuda.empty_cache()
#             cur_val_loss = val_loss
# 
#     pbar.close()
#     print(f"Training took {time.time() - t0:.2f} seconds")
# 
#     # # Save final model checkpoint
#     # torch.save({
#     #     'model_state_dict': reco_net.state_dict(),
#     #     'optimizer_state_dict': optimizer.state_dict(),
#     #     'loss_per_epoch': loss_per_epoch,
#     #     'val_loss_per_epoch': val_loss_per_epoch,
#     #     'noise_std': noise_std,
#     # }, nn_fn)
# 
#     writer.flush()
#     writer.close()
# 
#     # Test the model
#     test_loss = test(reco_net, test_loader, device, min_param_tensor, max_param_tensor,
#                      min_water_t1t2_tensor, max_water_t1t2_tensor,
#                      min_mt_param_fs_ksw, max_mt_param_fs_ksw, min_amide_param_fs_ksw, max_amide_param_fs_ksw)
#     print(f"Test Loss: {test_loss}")
# 
#     return reco_net
# 
# def train_step(device, noise_std, reco_net, optimizer, cum_loss, dict_params, min_param_tensor, max_param_tensor,
#                min_water_t1t2_tensor, max_water_t1t2_tensor, min_mt_param_fs_ksw, max_mt_param_fs_ksw, min_amide_param_fs_ksw, max_amide_param_fs_ksw,
#                writer, epoch, counter, num_steps):
#     cur_fs, cur_ksw, cur_t1w, cur_t2w, cur_mt_fs, cur_mt_ksw, cur_amide_fs, cur_amide_ksw, cur_norm_sig = dict_params
# 
#     target = torch.stack((cur_fs, cur_ksw), dim=1).to(device)
#     input_water_t1t2 = torch.stack((cur_t1w, cur_t2w), dim=1).to(device)
#     input_mt_fs_ksw = torch.stack((cur_mt_fs, cur_mt_ksw), dim=1).to(device)
#     input_amide_fs_ksw = torch.stack((cur_amide_fs, cur_amide_ksw), dim=1).to(device)
# 
#     # Normalizing the target and input_water_t1t2
#     target = normalize_range(original_array=target, original_min=min_param_tensor,
#                              original_max=max_param_tensor, new_min=0, new_max=1).to(device)
# 
#     input_water_t1t2 = normalize_range(original_array=input_water_t1t2, original_min=min_water_t1t2_tensor,
#                                        original_max=max_water_t1t2_tensor, new_min=0, new_max=1).to(device)
# 
#     input_mt_fs_ksw = normalize_range(original_array=input_mt_fs_ksw, original_min=min_mt_param_fs_ksw,
#                                       original_max=max_mt_param_fs_ksw, new_min=0, new_max=1).to(device)
#     
#     input_amide_fs_ksw = normalize_range(original_array=input_amide_fs_ksw, original_min=min_amide_param_fs_ksw,
#                                   original_max=max_amide_param_fs_ksw, new_min=0, new_max=1).to(device)
# 
#     # Adding noise to the input signals (trajectories)
#     noised_sig = cur_norm_sig + torch.randn(cur_norm_sig.size()) * noise_std
# 
#     # adding the mt_fs_ksw and t1, t2 as additional nn input
#     noised_sig = torch.hstack((input_amide_fs_ksw, input_mt_fs_ksw, input_water_t1t2, noised_sig.to(device))).to(device)
#     del input_water_t1t2, input_mt_fs_ksw, input_amide_fs_ksw
# 
#     # Forward step
#     prediction = reco_net(noised_sig.float())
#     del noised_sig
# 
#     # Batch loss (MSE)
#     loss = torch.mean((prediction.float() - target.float()) ** 2)
#     del target
# 
#     # Backward step
#     optimizer.zero_grad()
#     loss.backward()
# 
#     # Optimization step
#     optimizer.step()
# 
#     # Storing Cumulative loss
#     cum_loss += loss.item()
#     
#     writer.add_scalar("Loss/train_step", loss.item(), counter+epoch*num_steps)
#     
#     torch.cuda.empty_cache()
#     
#     return reco_net, cum_loss
# 
# 
# def validate(reco_net, val_loader, device, min_param_tensor, max_param_tensor, min_water_t1t2_tensor, max_water_t1t2_tensor,
#              min_mt_param_fs_ksw, max_mt_param_fs_ksw, min_amide_param_fs_ksw, max_amide_param_fs_ksw):
#     reco_net.eval()
#     val_loss = 0
#     with torch.no_grad():
#         for dict_params in val_loader:
#             cur_fs, cur_ksw, cur_t1w, cur_t2w, cur_mt_fs, cur_mt_ksw, cur_amide_fs, cur_amide_ksw, cur_norm_sig = dict_params
# 
#             target = torch.stack((cur_fs, cur_ksw), dim=1).to(device)
#             input_water_t1t2 = torch.stack((cur_t1w, cur_t2w), dim=1).to(device)
#             input_mt_fs_ksw = torch.stack((cur_mt_fs, cur_mt_ksw), dim=1).to(device)
#             input_amide_fs_ksw = torch.stack((cur_amide_fs, cur_amide_ksw), dim=1).to(device)
# 
#             target = normalize_range(original_array=target, original_min=min_param_tensor,
#                                      original_max=max_param_tensor, new_min=0, new_max=1).to(device)
#             input_water_t1t2 = normalize_range(original_array=input_water_t1t2, original_min=min_water_t1t2_tensor,
#                                                original_max=max_water_t1t2_tensor, new_min=0, new_max=1).to(device)
#             input_mt_fs_ksw = normalize_range(original_array=input_mt_fs_ksw, original_min=min_mt_param_fs_ksw,
#                                               original_max=max_mt_param_fs_ksw, new_min=0, new_max=1).to(device)
#             input_amide_fs_ksw = normalize_range(original_array=input_amide_fs_ksw, original_min=min_amide_param_fs_ksw,
#                               original_max=max_amide_param_fs_ksw, new_min=0, new_max=1).to(device)
# 
#             noised_sig = cur_norm_sig
# 
#             noised_sig = torch.hstack((input_amide_fs_ksw, input_mt_fs_ksw, input_water_t1t2, noised_sig.to(device))).to(device)
# 
#             prediction = reco_net(noised_sig.float())
# 
#             loss = torch.mean((prediction.float() - target.float()) ** 2)
# 
#             val_loss += loss.item()
# 
#     return val_loss / len(val_loader)
# 
# 
# def test(reco_net, test_loader, device, min_param_tensor, max_param_tensor, min_water_t1t2_tensor, max_water_t1t2_tensor,
#          min_mt_param_fs_ksw, max_mt_param_fs_ksw, min_amide_param_fs_ksw, max_amide_param_fs_ksw):
#     reco_net.eval()
#     test_loss = 0
#     with torch.no_grad():
#         for dict_params in test_loader:
#             cur_fs, cur_ksw, cur_t1w, cur_t2w, cur_mt_fs, cur_mt_ksw, cur_amide_fs, cur_amide_ksw, cur_norm_sig = dict_params
# 
#             target = torch.stack((cur_fs, cur_ksw), dim=1).to(device)
#             input_water_t1t2 = torch.stack((cur_t1w, cur_t2w), dim=1).to(device)
#             input_mt_fs_ksw = torch.stack((cur_mt_fs, cur_mt_ksw), dim=1).to(device)
#             input_amide_fs_ksw = torch.stack((cur_amide_fs, cur_amide_ksw), dim=1).to(device)
# 
#             target = normalize_range(original_array=target, original_min=min_param_tensor,
#                                      original_max=max_param_tensor, new_min=0, new_max=1).to(device)
#             input_water_t1t2 = normalize_range(original_array=input_water_t1t2, original_min=min_water_t1t2_tensor,
#                                                original_max=max_water_t1t2_tensor, new_min=0, new_max=1).to(device)
#             input_mt_fs_ksw = normalize_range(original_array=input_mt_fs_ksw, original_min=min_mt_param_fs_ksw,
#                                               original_max=max_mt_param_fs_ksw, new_min=0, new_max=1).to(device)
#             input_amide_fs_ksw = normalize_range(original_array=input_amide_fs_ksw, original_min=min_amide_param_fs_ksw,
#                   original_max=max_amide_param_fs_ksw, new_min=0, new_max=1).to(device)
# 
#             noised_sig = cur_norm_sig
# 
#             noised_sig = torch.hstack((input_amide_fs_ksw, input_mt_fs_ksw, input_water_t1t2, noised_sig.to(device))).to(device)
# 
#             prediction = reco_net(noised_sig.float())
# 
#             loss = torch.mean((prediction.float() - target.float()) ** 2)
# 
#             test_loss += loss.item()
# 
#     return test_loss / len(test_loader)
# 
# 
# def define_min_max(memmap_fn, sched_iter, add_iter, device):
#     num_columns = sched_iter + add_iter + 2
#     memmap_array = np.memmap(memmap_fn, dtype=np.float64, mode='r')
#     num_rows = memmap_array.size // num_columns  # Calculate the number of rows
#     memmap_array.shape = (num_rows, num_columns)  # [#, 30+6]
# 
#     min_fs = np.min(memmap_array[:, 4])  # uncomment if non-zero minimum limit is required
#     min_ksw = np.min(memmap_array[:, 5].transpose().astype(float))  # uncomment if non-zero minimum limit needed
#     max_fs = np.max(memmap_array[:, 4])
#     max_ksw = np.max(memmap_array[:, 5].transpose().astype(float))
# 
#     min_t1w = np.min(memmap_array[:, 2])
#     min_t2w = np.min(memmap_array[:, 3].transpose().astype(float))
#     max_t1w = np.max(memmap_array[:, 2])
#     max_t2w = np.max(memmap_array[:, 3].transpose().astype(float))
# 
#     min_mt_fs = np.min(memmap_array[:, 6])
#     min_mt_ksw = np.min(memmap_array[:, 7].transpose().astype(float))
#     max_mt_fs = np.max(memmap_array[:, 6])
#     max_mt_ksw = np.max(memmap_array[:, 7].transpose().astype(float))
# 
#     min_amine_fs = np.min(memmap_array[:, 0])
#     min_amine_ksw = np.min(memmap_array[:, 1].transpose().astype(float))
#     max_amine_fs = np.max(memmap_array[:, 0])
#     max_amine_ksw = np.max(memmap_array[:, 1].transpose().astype(float))
# 
#     min_param_tensor = torch.tensor(np.hstack((min_fs, min_ksw)), requires_grad=False).to(device)  # can be switched to  min_fs, min_ksw
#     max_param_tensor = torch.tensor(np.hstack((max_fs, max_ksw)), requires_grad=False).to(device)
# 
#     min_water_t1t2_tensor = torch.tensor(np.hstack((min_t1w, min_t2w)), requires_grad=False).to(device)
#     max_water_t1t2_tensor = torch.tensor(np.hstack((max_t1w, max_t2w)), requires_grad=False).to(device)
# 
#     min_mt_param_tensor = torch.tensor(np.hstack((min_mt_fs, min_mt_ksw)), requires_grad=False).to(device)  # can be switched to  min_fs, min_ksw
#     max_mt_param_tensor = torch.tensor(np.hstack((max_mt_fs, max_mt_ksw)), requires_grad=False).to(device)  # can be switched to  min_fs, min_ksw
# 
#     min_amine_param_tensor = torch.tensor(np.hstack((min_amine_fs, min_amine_ksw)), requires_grad=False).to(device)  # can be switched to  min_fs, min_ksw
#     max_amine_param_tensor = torch.tensor(np.hstack((max_amine_fs, max_amine_ksw)),
#                                        requires_grad=False).to(device)  # can be switched to  min_fs, min_ksw
# 
#     return (min_param_tensor, max_param_tensor, min_water_t1t2_tensor, max_water_t1t2_tensor,
#             min_mt_param_tensor, max_mt_param_tensor, min_amine_param_tensor, max_amine_param_tensor)
# 
# 
# if __name__ == '__main__':
#     if platform.system() == 'Windows':
#         multiprocessing.set_start_method('spawn', force=True)
#     # os.chdir(os.path.dirname(os.path.realpath(__file__)))
#     set_seed(2024)
# 
#     main()
