In [1]:
import xarray as xr
import numpy as np

# Start from an empty Dataset. The three data variables O, EE and TMI share
# the same four dimensions, and every coordinate is zero-length, so the
# Dataset holds no values until results are merged in.
ds = xr.Dataset(
    data_vars={
        name: (['L', 'p_ctrl', 'p_proj', 'trial'], np.empty((0, 0, 0, 0)))
        for name in ('O', 'EE', 'TMI')
    },
    coords={dim: [] for dim in ('L', 'p_ctrl', 'p_proj', 'trial')},
)


In [2]:
# Show the Dataset's dims, coords and data variables (all empty right after the init cell).
ds

In [3]:
def _to_numpy(x):
    """Return ``x`` as a NumPy array; torch tensors are detached and moved to CPU first."""
    if hasattr(x, 'detach'):  # torch.Tensor (the original code called .numpy() on these)
        x = x.detach().cpu().numpy()
    return np.asarray(x)


def _compact_trials(stacked, grid, dims):
    """Build a Dataset from ``stacked`` with each grid point's non-NaN trials moved to the front.

    ``stacked`` maps variable name -> array shaped like ``dims`` (trial last);
    ``grid`` supplies the coordinates for every dim except ``'trial'``.
    Trial slots that are NaN in every variable are treated as empty padding,
    and trailing slots that are empty at every grid point are dropped.
    """
    first = next(iter(stacked.values()))
    valid = np.zeros(first.shape, dtype=bool)
    for arr in stacked.values():
        valid |= ~np.isnan(arr)
    # One shared, stable permutation: trial order is preserved and the values
    # of O/EE/TMI that belong to the same trial stay at the same index.
    order = np.argsort(~valid, axis=-1, kind='stable')
    n_keep = int(valid.sum(axis=-1).max()) if valid.size else 0
    data_vars = {
        key: (dims, np.take_along_axis(arr, order, axis=-1)[..., :n_keep])
        for key, arr in stacked.items()
    }
    coords = {dim: grid[dim].values for dim in dims[:-1]}
    coords['trial'] = np.arange(n_keep)
    return xr.Dataset(data_vars, coords=coords)


def add_new_data(ds, d):
    """Merge one loaded result dict ``d`` into the Dataset ``ds`` and return the result.

    ``d`` must provide:

    * ``d['args']``: an object with attributes ``L`` (arguments for
      ``np.arange``) and ``p_ctrl`` / ``p_proj`` (arguments for
      ``np.linspace``). Together they describe the parameter grid.
    * ``d['O']``, ``d['EE']``, ``d['TMI']``: arrays or torch tensors laid out
      as ``(len(L), len(p_ctrl), len(p_proj), n_trial)``.
      NOTE(review): this layout is inferred from the Dataset's dim order;
      confirm it against the code that writes the pickle.

    Grid points not yet in ``ds`` are added. At points already present, the
    new trials are appended after the existing ones. Because the grid is
    rectangular, points with fewer trials than the maximum are NaN-padded
    along ``trial``.

    ``ds`` itself is not modified; a new Dataset is returned. Calling this
    twice with the same ``d`` appends the same trials twice.

    Raises:
        ValueError: if a variable's leading shape does not match the grid in
            ``d['args']``.
    """
    dims = ('L', 'p_ctrl', 'p_proj', 'trial')
    keys = ('O', 'EE', 'TMI')

    args = d['args']
    # Round the float grids so values that come out of different linspace
    # calls (e.g. 0.1*3 vs 0.05*6) still compare equal when aligning.
    grid = {
        'L': np.arange(*args.L),
        'p_ctrl': np.round(np.linspace(*args.p_ctrl), 12),
        'p_proj': np.round(np.linspace(*args.p_proj), 12),
    }
    grid_shape = tuple(len(grid[dim]) for dim in dims[:-1])

    new_vars = {}
    for key in keys:
        arr = _to_numpy(d[key])
        if arr.ndim != 4 or arr.shape[:3] != grid_shape:
            raise ValueError(
                f"{key!r} has shape {arr.shape}; expected {grid_shape} + (n_trial,) "
                "to match the (L, p_ctrl, p_proj) grid in d['args']"
            )
        new_vars[key] = (dims, arr)
    n_trial = new_vars[keys[0]][1].shape[-1]
    new_ds = xr.Dataset(new_vars, coords={**grid, 'trial': np.arange(n_trial)})

    # Nothing stored yet (e.g. the freshly initialised, zero-sized Dataset).
    if any(ds.sizes.get(dim, 0) == 0 for dim in dims):
        return new_ds

    # Put both datasets on the union of their parameter grids. Points missing
    # on one side are NaN-filled. 'trial' is excluded so each side keeps its
    # own trial axis, which is then concatenated.
    old, new = xr.align(ds, new_ds, join='outer', exclude=['trial'])
    stacked = {
        key: np.concatenate(
            [old[key].transpose(*dims).values, new[key].transpose(*dims).values],
            axis=-1,
        )
        for key in keys
    }
    return _compact_trials(stacked, old, dims)

In [5]:
import pickle

In [6]:
# Results file to load. The file name appears to encode the sweep parameters
# (pctrl, pproj, L, ...); the grid itself is read from data['args'].
# NOTE: pickle.load can execute arbitrary code stored in the file, so only
# load pickles you produced yourself or otherwise trust.
PICKLE_PATH = 'CT_En2000_pctrl(0.00,1.00,11)_pproj(0.00,0.00,1)_L(10,14,2)_xj(1-3,2-3)_seed0_es2000_64.pickle'
with open(PICKLE_PATH, 'rb') as f:
    data = pickle.load(f)

In [None]:
# add_new_data is meant to append the loaded trials to ds, so running this
# cell twice without re-running the cells above would add them twice.
ds = add_new_data(ds, data)