In [1]:
import json
import os
import pickle
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.multiprocessing as mp
import torch.nn.functional as F
from tqdm import tqdm

from s4hci.models.noise import S4Ridge
from s4hci.utils.data_handling import load_adi_data
from s4hci.utils.masks import construct_round_rfrr_template, construct_rfrr_mask

## Load the data

In [2]:
# Dataset description (JSON) from the x1_s4_cross_validation results. Its keys give
# the data file path and the tags passed to load_adi_data for the science stack,
# the PSF template and the angles.
dataset_config_file = "../../../../70_results/x1_s4_cross_validation/0100_C-0656-A/dataset.json"

# JSON is UTF-8 by specification; don't depend on the platform's locale encoding.
with open(dataset_config_file, encoding="utf-8") as json_file:
    dataset_config = json.load(json_file)

print("Loading data ... ", end='')
science_data, raw_angles, raw_psf_template_data = load_adi_data(
    dataset_config["file_path"],
    data_tag=dataset_config["stack_key"],
    psf_template_tag=dataset_config["psf_template_key"],
    para_tag=dataset_config["parang_key"])

# Average the PSF template stack along axis 0 into a single template
# (assumes axis 0 indexes individual PSF frames — TODO confirm).
psf_template_data = np.mean(raw_psf_template_data, axis=0)

# Even/odd frame split.
# NOTE(review): the fits below run on the full `science_data`, not on `X_train`;
# confirm this is intended.
X_train = science_data[0::2]
X_test = science_data[1::2]
print("[DONE]")

Loading data ... [DONE]


## Test S4 Ridge

In [3]:
# Hyperparameters of the ridge S4 noise model, gathered in one place so they are easy
# to inspect and vary. The PSF template is passed separately because it is data.
s4_ridge_settings = dict(
    lambda_reg=100,
    convolve=True,
    use_normalization=True,
    available_devices=[0],
    half_precision=False,
    cut_radius_psf=4,
    mask_template_setup=("radius", 6.5),
)
s4_ridge = S4Ridge(psf_template=psf_template_data, **s4_ridge_settings)

## Run a fast test on a sparse pixel grid

In [8]:
# Prepare the model for a partial fit without calling `fit`. The private setup step
# must run first: its printed output shows it builds the mask and the normalization
# frames that `normalize_data` relies on.
s4_ridge._setup_training(science_data)
normalized_science_data = s4_ridge.normalize_data(science_data)
s4_ridge.science_data_norm = normalized_science_data

Creating right reason mask ... [DONE]
Build normalization frames ... [DONE]


In [9]:
# Sparse (y, x) grid of pixels for the quick test: every GRID_STEP-th pixel along each
# axis. The outer loop runs over x, so y varies fastest. Presumably this fixes the row
# order of the betas returned by `_fit_mp` below.
GRID_STEP = 20
positions = [(y, x)
             for x in range(0, s4_ridge.image_size, GRID_STEP)
             for y in range(0, s4_ridge.image_size, GRID_STEP)]

In [10]:
# Fit only the sparse grid via the private `_fit_mp` helper.
# Caution: this replaces `s4_ridge.betas`, so run the full fit below before saving.
sparse_betas = s4_ridge._fit_mp(positions)
s4_ridge.betas = sparse_betas

100%|██████████| 36/36 [00:22<00:00,  1.58it/s]


In [11]:
# Sanity check: presumably one row of betas per fitted position (36 positions -> 36 rows).
assert s4_ridge.betas.shape[0] == len(positions), "expected one row of betas per position"
s4_ridge.betas.shape

torch.Size([36, 11449])

## Run one full fit

In [4]:
# Full fit over every pixel: 11449 iterations, about 1 h 17 min in the recorded run.
# It re-runs the training setup (mask + normalization frames, see output) and
# presumably overwrites any betas left by the fast test above.
# NOTE(review): fits on the full `science_data`, not the `X_train` half; confirm intended.
s4_ridge.fit(science_data)

Creating right reason mask ... [DONE]
Build normalization frames ... [DONE]


100%|██████████| 11449/11449 [1:17:05<00:00,  2.48it/s]


## Save and restore

In [5]:
# Root of the results tree. Set the S4_RESULTS_DIR environment variable to run outside
# the cluster; the default is the original absolute location.
RESULTS_DIR = Path(os.environ.get(
    "S4_RESULTS_DIR", "/cluster/project/quanz/mbonse/2022_S4/70_results"))
# Kept as a str because the types accepted by S4Ridge.save / restore_from_checkpoint
# are not visible here.
save_file = str(
    RESULTS_DIR / "2023_04_13_hyperparameter_search" / "initial_tests" / "test_betas.pkl")

In [6]:
# Create the checkpoint directory if needed, so a missing folder can't make the save
# fail right after the ~1 h fit above.
Path(save_file).parent.mkdir(parents=True, exist_ok=True)
s4_ridge.save(save_file)

In [7]:
# Rebuild a model from the checkpoint written above. Device and precision settings
# match those used when constructing `s4_ridge`.
restore_settings = dict(
    verbose=True,
    available_devices=[0],
    half_precision=False,
)
s4_ridge_new = S4Ridge.restore_from_checkpoint(save_file, **restore_settings)

In [8]:
# Round-trip check: the restored model must carry the same betas as the fitted one.
assert s4_ridge_new.betas.shape == s4_ridge.betas.shape, "restored betas have a different shape"
# Compare only a few rows: the full 11449 x 11449 matrices are large to copy to the CPU.
assert torch.allclose(s4_ridge_new.betas[:8].cpu().float(),
                      s4_ridge.betas[:8].cpu().float()), "restored betas differ"
s4_ridge_new.betas.shape

torch.Size([11449, 11449])