In [4]:
import os
import sys
import importlib

# Configure JAX through environment variables. They must be in place before
# JAX is first imported (presumably via `src` in a later cell), so keep this
# cell ahead of any such import.
os.environ.update({
    "JAX_ENABLE_X64": "true",                 # use 64-bit floats/ints
    "XLA_PYTHON_CLIENT_PREALLOCATE": "false",  # don't pre-allocate accelerator memory
    "ENABLE_PJRT_COMPATIBILITY": "false",
    "JAX_PLATFORMS": "cpu",                   # run on CPU only
})

In [5]:
import numpy as np

In [None]:
import pkgutil

import src

# Pick up on-disk edits to the local `src` package without restarting the
# kernel. Enumerating the package's directory (instead of `dir(src)`) finds
# submodules even on a fresh kernel, and never tries to import non-module
# attributes that `src/__init__.py` might expose (e.g. `src.np`).
for module_info in pkgutil.iter_modules(src.__path__):
    print(module_info.name)
    full_name = f'src.{module_info.name}'
    if full_name in sys.modules:
        # Already loaded: re-execute it so edits take effect.
        importlib.reload(sys.modules[full_name])
    else:
        # First import already runs the latest code; no reload needed.
        importlib.import_module(full_name)

from src.experiments import GratingImages
from src.experiments import experiment_grating
from src.theory_utils_spectral import SVR_th_spectral

In [None]:
import timm
# Build a GratingImages instance with image_size=(64, 64), then create the
# grating manifold and run its two sampling steps.
max_img_num, random_seed = 100, 42

grating_manifold = GratingImages(
    max_img_num=max_img_num,
    image_size=(64, 64),
    max_sample_size=100,
    random_seed=random_seed,
)

angles, manifold = grating_manifold.create_grating_manifold()

grating_manifold.sample_digits()
grating_manifold.sample_manifold()

In [8]:
def _correct_outliers(all_results_th, jump_tol=2e-1, etr_tol=1e-5, start=6,
                      keys=('C', 'Etr', 'alpha_eff', 'epsilon_eff',
                            'epsilon_th', 'lamb_eff')):
    """Return a deep copy of `all_results_th` with outlier points forward-filled.

    Only row 0 of each curve is inspected and modified. A point with index
    ``i >= start`` is an outlier when the first row of 'C' jumps by more than
    `jump_tol` from index ``i - 1`` (differences taken on the uncorrected
    curve) or when the first row of 'Etr' exceeds `etr_tol`. For every key in
    `keys`, an outlier is replaced by the already-corrected value at
    ``i - 1``, so a run of consecutive outliers takes the last good value.
    The input dict is not modified.
    """
    from copy import deepcopy

    corrected = deepcopy(all_results_th)
    C = np.asarray(all_results_th['C'][0])
    Etr = np.asarray(all_results_th['Etr'][0])
    is_outlier = (np.abs(np.diff(C, prepend=0)) > jump_tol) | (Etr > etr_tol)

    for i in range(start, len(C)):
        if is_outlier[i]:
            for key in keys:
                corrected[key][0][i] = corrected[key][0][i - 1]
    return corrected


def process_data(all_data, save_keys,
                 model_list=('resnet18', 'resnet34', 'resnet50'),
                 trained_list=('random', 'trained'),
                 layer_list=('conv1', 'layer1', 'layer2', 'layer3', 'layer4'),
                 box_shape=(4, 100)):
    """Clean the theory curves of every (trained, model, layer) run and stack them.

    Parameters
    ----------
    all_data : dict
        Nested mapping ``all_data[trained][model][layer]['all_results_th']``.
        That inner dict must hold at least the arrays 'C', 'Etr', 'alpha_eff',
        'epsilon_eff', 'epsilon_th', 'lamb_eff' and 'E_inf'. Row 0 of each
        array is the curve that gets cleaned (arrays presumably shaped
        (1, N) or (k, N) -- TODO confirm against the producer). `all_data`
        itself is not modified.
    save_keys : sequence of str
        Keys of the cleaned dict to export.
    model_list, trained_list, layer_list : sequence of str
        Runs to collect. The defaults are the previously hard-coded lists.
    box_shape : tuple of int
        Each exported array is broadcast to this shape (default (4, 100)).

    Returns
    -------
    dict
        ``{key: array}`` where each array has shape
        ``(len(model_list), len(layer_list), len(trained_list), *box_shape)``.

    Raises
    ------
    AssertionError
        If 'E_inf' of any run contains more than one distinct value.
    """
    all_box_data = []
    for trained in trained_list:
        per_model = []
        for model in model_list:
            per_layer = []
            for layer in layer_list:
                # _correct_outliers returns a deep copy, so the in-place edits
                # below never touch `all_data`.
                d = _correct_outliers(all_data[trained][model][layer]['all_results_th'])

                assert len(np.unique(d['E_inf'])) == 1

                # Zero out values at or below tiny thresholds (this also
                # zeroes any negative C entries).
                d['Etr'] = d['Etr'] * (d['Etr'] > 1e-10)
                d['C'] = d['C'] * (d['C'] > 1e-15)
                # Zero C wherever the training error is non-zero, and the
                # epsilon curves wherever it reaches 1e-5.
                d['C'][0] *= (d['Etr'][0] < 1e-15)
                d['epsilon_th'][0] *= (d['Etr'][0] < 1e-5)
                d['epsilon_eff'][0] *= (d['Etr'][0] < 1e-5)
                # Record the large-Etr points BEFORE Etr is zeroed. Computing
                # this mask after the zeroing (as before) always gave an empty
                # mask, so C was never marked as NaN.
                idx_wrong = d['Etr'][0] > 1e-5
                d['Etr'][0] *= (d['Etr'][0] < 1e-5)
                d['C'][0][idx_wrong] = np.nan

                per_layer.append([d[key] * np.ones(box_shape) for key in save_keys])
            per_model.append(per_layer)
        all_box_data.append(per_model)

    # (trained, model, layer, key, *box_shape) -> (key, model, layer, trained, *box_shape)
    stacked = np.array(all_box_data).swapaxes(0, 3)
    return dict(zip(save_keys, stacked))

In [None]:
# Sweep settings. With `overwrite = False`, results saved by a previous run at
# `results_path` are reused when they were produced with the same sweep lists.
# Set `overwrite = True` to force a full (expensive) recompute.
overwrite = False
results_path = './results/combined_data_grating_dataset.npz'

rand_proj_dim_list = [25, 50, 75, 100, 500, 1000]
sample_size_list = [100, 300]
max_img_num_list = [100, 300]
random_seed_list = [12, 42, 123, 187, 345]

from copy import deepcopy

sweep_lists = {
    'rand_proj_dim_list': rand_proj_dim_list,
    'max_img_num_list': max_img_num_list,
    'sample_size_list': sample_size_list,
    'random_seed_list': random_seed_list,
}

combined_data = None
if not overwrite and os.path.exists(results_path):
    # allow_pickle is required because np.savez stored the dict as a 0-d
    # object array. Only load files this notebook wrote itself.
    with np.load(results_path, allow_pickle=True) as archive:
        cached = archive['combined_data'].item()
    if all(cached.get(name) == values for name, values in sweep_lists.items()):
        combined_data = cached

if combined_data is None:
    # Layout: combined_data[rand_proj_dim][max_img_num][sample_size][seed] = result
    combined_data = {
        rand_proj_dim: {
            img_num: {size: {} for size in sample_size_list}
            for img_num in max_img_num_list
        }
        for rand_proj_dim in rand_proj_dim_list
    }
    # The loop variable is `img_num` rather than `max_img_num` so this cell
    # does not clobber the `max_img_num` set in the GratingImages cell.
    for rand_proj_dim in rand_proj_dim_list:
        for img_num in max_img_num_list:
            for sample_size in sample_size_list:
                for seed in random_seed_list:
                    data = experiment_grating(sample_size=sample_size,
                                              max_img_num=img_num,
                                              rand_proj_dim=rand_proj_dim,
                                              random_seed=seed,
                                              SVR_th=SVR_th_spectral,
                                              plot_fig=False)
                    combined_data[rand_proj_dim][img_num][sample_size][seed] = deepcopy(data)

    # Store the sweep axes alongside the results so the file is self-describing.
    combined_data.update(sweep_lists)

In [None]:
# Persist the sweep. `combined_data` is a dict, so np.savez stores it as a
# 0-d object (pickled) array; read it back with
#   np.load(path, allow_pickle=True)['combined_data'].item()
results_dir = './results'
os.makedirs(results_dir, exist_ok=True)  # np.savez fails if the folder is missing
np.savez(os.path.join(results_dir, 'combined_data_grating_dataset.npz'),
         combined_data=combined_data)