In [1]:
import json
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt

import torch

# imports to load some data
from fours.utils.data_handling import load_adi_data

from applefy.utils.fake_planets import add_fake_planets
from applefy.utils.photometry import flux_ratio2mag, mag2flux_ratio

from pynpoint.util.analysis import fake_planet

# Load the data

In [2]:
dataset_file = Path("/fast/mbonse/s4/30_data/s4_datasets/HD22049_303_199_C-0065_C_.hdf5")
experiment_root_dir = Path("/fast/mbonse/s4/70_results/01_position_flux/implementation_tests/")
exp_id = str("0123a")
use_rotation_loss = True
lambda_reg = float(850)

In [3]:
science_data, angles, raw_psf_template_data = \
    load_adi_data(
        hdf5_dataset=str(dataset_file),
        data_tag="object_stacked_05",
        psf_template_tag="psf_template",
        para_tag="header_object_stacked_05/PARANG")

psf_template = np.median(raw_psf_template_data, axis=0)

# other parameters
dit_psf_template = 0.0042560
dit_science = 0.08
fwhm = 3.6
pixel_scale = 0.02718

# we cut the image to 91 x 91 pixel to be slightly larger than 1.2 arcsec
cut_off = int((science_data.shape[1] - 91) / 2)
science_data = science_data[:, cut_off:-cut_off, cut_off:-cut_off]

In [4]:
dataset_id = "0115b"

# add the fake planet
fake_planet_config_file = "/fast/mbonse/s4/70_results/x1_fake_planet_experiments/HD22049_303_199_C-0065_C_/configs_cgrid/exp_ID_" + exp_id + ".json"
with open(fake_planet_config_file) as json_file:
    fake_planet_config = json.load(json_file)

In [5]:
fake_planet_config["flux_ratio"] = mag2flux_ratio(12.)
fake_planet_config["separation"] = 30
fake_planet_config["planet_position"] = [0., 0., 30, 20.]

In [6]:
data_with_fake_planet = add_fake_planets(
    input_stack=science_data,
    psf_template=psf_template,
    parang=angles - np.pi/2, # the pi/2 is to correct the parang to the north
    dit_psf_template=dit_psf_template,
    dit_science=dit_science,
    experiment_config=fake_planet_config,
    scaling_factor=1.0)

In [7]:
science_data_torch_with_planet = torch.from_numpy(data_with_fake_planet).float()

# Import the noise models

In [8]:
from fours.position_flux.pca import PCANoiseModel
from fours.position_flux.negfc import NegFC
from fours.models.psf_subtraction import FourS
from fours.position_flux.loss_functions import NegFCLoss

from tqdm import tqdm
from torch import optim

# run a test with PCA

In [9]:
pca_model = PCANoiseModel(
    angles=angles,
    image_shape=science_data.shape[1],
    pca_number=200,
    approx_svd=5000)

In [10]:
# create the NegFC model
negfc = NegFC(
    psf_template=psf_template,
    all_angles=angles,
    input_size=science_data.shape[1],
    init_separation=fake_planet_config["separation"] + 2,
    init_pos_angle=fake_planet_config["planet_position"][-1] - 2,
    init_magnitude=11.5,
    dit_science=dit_science,
    dit_psf_template=dit_psf_template,
    nd_factor=1.0,
    interpolation="bicubic",)

In [11]:
# create the loss function
negfc_loss = NegFCLoss(
    residual_shape=science_data.shape[1:],
    separation_pixel=fake_planet_config["separation"],
    pos_angle_deg=fake_planet_config["planet_position"][-1],
    aperture_radius=10,
    metric_function="hessian")

In [12]:
# move everything to the GPU
science_data_torch_with_planet = science_data_torch_with_planet.to(0)
negfc = negfc.to(0)
pca_model = pca_model.to(0)
negfc_loss = negfc_loss.to(0)

In [36]:
input_data = science_data_torch_with_planet

# create the optimizer
optimizer_kwargs = {
    "max_iter": 20,
    "history_size": 10}

optimizer = optim.LBFGS(
    negfc.parameters(),
    **optimizer_kwargs)

In [37]:
for j in range(30):
    print(negfc.separation.item(), negfc.pos_angle.item(), negfc.magnitude.item())
    after_negfc = negfc(science_data_torch_with_planet)
    pca_model.update_noise_model(after_negfc)

    # iterate for 100 steps
    for i in range(20):
        def closure():        
            optimizer.zero_grad()
            after_negfc = negfc(science_data_torch_with_planet)
            after_pca_neg_fc = pca_model(after_negfc)
            
            # apply the loss function
            loss = negfc_loss(after_pca_neg_fc)
            loss.backward()
            return loss

        optimizer.step(closure)

31.97163200378418 18.002296447753906 11.571259498596191
31.327606201171875 19.144691467285156 11.924256324768066
30.768407821655273 19.97575569152832 12.091241836547852


KeyboardInterrupt: 

In [15]:
x_shift_res = negfc.separation.item() * np.cos(np.deg2rad(negfc.pos_angle.item()))
y_shift_res = negfc.separation.item() * np.sin(np.deg2rad(negfc.pos_angle.item()))

x_shift = 30 * np.cos(np.deg2rad(20))
y_shift = 30 * np.sin(np.deg2rad(20))

print(np.abs(y_shift - y_shift_res))
print(np.abs(x_shift - x_shift_res))
print(np.abs(negfc.magnitude.item() - 12))

0.0932798870667213
0.04365737891630772
0.020438194274902344


In [39]:
import torch
import torch.nn.functional as F


from fours.utils.masks import create_aperture_mask


# code needed for the calculation of the hessian
def gaussian_kernel1d(sigma, order, radius):
    """Create a 1D Gaussian kernel of a given order and sigma."""
    x = torch.arange(-radius, radius + 1, dtype=torch.float32)
    if order == 0:
        kernel = torch.exp(-0.5 * (x / sigma) ** 2)
    elif order == 1:
        kernel = -x / sigma ** 2 * torch.exp(-0.5 * (x / sigma) ** 2)
    elif order == 2:
        kernel = (x ** 2 / sigma ** 4 - 1 / sigma ** 2) * torch.exp(
            -0.5 * (x / sigma) ** 2)
    else:
        raise ValueError('Order must be 0, 1, or 2')
    kernel /= kernel.abs().sum()
    return kernel


def gaussian_kernel2d(sigma, order_x, order_y, radius):
    """Create a 2D Gaussian kernel by computing the outer product of two 1D Gaussian kernels."""
    kernel_x = gaussian_kernel1d(sigma, order_x, radius)
    kernel_y = gaussian_kernel1d(sigma, order_y, radius)
    kernel_2d = torch.outer(kernel_y, kernel_x)
    return kernel_2d


def hessian_matrix(
        image,
        hxx_kernel,
        hxy_kernel,
        hyy_kernel):

    # Reshape kernels for convolution
    hxx_kernel = hxx_kernel.view(1, 1, *hxx_kernel.shape)
    hxy_kernel = hxy_kernel.view(1, 1, *hxy_kernel.shape)
    hyy_kernel = hyy_kernel.view(1, 1, *hyy_kernel.shape)

    # Add batch and channel dimensions to the image
    image = image.unsqueeze(0).unsqueeze(0)

    # Convolve the image with the kernels
    padding = (hxx_kernel.shape[-1] - 1) // 2
    hxx = F.conv2d(image, hxx_kernel, padding=padding)
    hxy = F.conv2d(image, hxy_kernel, padding=padding)
    hyy = F.conv2d(image, hyy_kernel, padding=padding)

    # Remove the batch and channel dimensions
    hxx = hxx.squeeze(0).squeeze(0)
    hxy = hxy.squeeze(0).squeeze(0)
    hyy = hyy.squeeze(0).squeeze(0)

    return hxx, hxy, hyy


class NegFCLoss(torch.nn.Module):

    def __init__(
            self,
            residual_shape,
            separation_pixel,
            pos_angle_deg,
            aperture_radius,
            metric_function="mse",
            sigma_hessian=1.8
    ):
        super(NegFCLoss, self).__init__()

        # Create the mask and transform it to a tensor
        self.loss_mask = create_aperture_mask(
            image_shape=residual_shape,
            separation_pixel=separation_pixel,
            pos_angle_deg=pos_angle_deg,
            aperture_radius=aperture_radius)

        # define the metric function
        self.metric_function = metric_function

        # if the metric function is hessian create the kernels
        if self.metric_function == "hessian":
            radius = int(3 * sigma_hessian + 0.5)
            hxx_kernel = gaussian_kernel2d(
                sigma_hessian, 2, 0, radius)
            hxy_kernel = gaussian_kernel2d(
                sigma_hessian, 1, 1, radius)
            hyy_kernel = gaussian_kernel2d(
                sigma_hessian, 0, 2, radius)

            # register the kernels
            self.register_buffer("hxx_kernel", hxx_kernel)
            self.register_buffer("hxy_kernel", hxy_kernel)
            self.register_buffer("hyy_kernel", hyy_kernel)

    def forward(self, residual_image):
        if self.metric_function == "hessian":
            hxx, hxy, hyy = hessian_matrix(
                residual_image,
                self.hxx_kernel,
                self.hxy_kernel,
                self.hyy_kernel)

            # calculate the determinant of the hessian
            hes_det = (hxx * hyy) - (hxy * hxy)

            # apply the mask
            selected_pixel = hes_det[self.loss_mask]

        elif self.metric_function == "mse":
            # apply the mask
            selected_pixel = residual_image[self.loss_mask]

        else:
            raise ValueError("The metric function is not implemented.")

        # calculate the loss
        loss = torch.sum(selected_pixel ** 2)
        print(loss)

        return loss


# Run the same for the FourS model

In [10]:
experiment_root_dir = Path("/fast/mbonse/s4/70_results/01_position_flux/implementation_tests/4s_test_dir")
experiment_root_dir.mkdir(exist_ok=True)

In [48]:
fours_model = FourS(
    science_cube=data_with_fake_planet,
    adi_angles=angles,
    psf_template=psf_template,
    noise_model_lambda=100,
    normalization_type="dynamic",
    psf_fwhm=3.6,
    right_reason_mask_factor=1.5,
    rotation_grid_subsample=1,
    device=0,
    work_dir=experiment_root_dir,
    verbose=True)

In [49]:
fours_model.fit_noise_model(
        num_epochs=50,
        training_name="First_training",
        logging_interval=10)

S4 model: Fit noise model ... 

  0%|          | 0/50 [00:00<?, ?it/s]

[DONE]


In [50]:
# create the NegFC model
negfc = NegFC(
    psf_template=psf_template,
    all_angles=angles,
    input_size=science_data.shape[1],
    init_separation=fake_planet_config["separation"] + 2,
    init_pos_angle=fake_planet_config["planet_position"][-1] - 2,
    init_magnitude=11.5,
    dit_science=dit_science,
    dit_psf_template=dit_psf_template,
    nd_factor=1.0,
    interpolation="bicubic",)

In [51]:
# create the loss function
negfc_loss = NegFCLoss(
    residual_shape=science_data.shape[1:],
    separation_pixel=fake_planet_config["separation"],
    pos_angle_deg=fake_planet_config["planet_position"][-1],
    aperture_radius=15,
    metric_function="hessian")
negfc_loss = negfc_loss.to(0)

In [52]:
negfc = negfc.to(0)
science_data_torch_with_planet = science_data_torch_with_planet.to(0)

# create the optimizer
optimizer_kwargs = {
    "max_iter": 20,
    "history_size": 10}

optimizer = optim.LBFGS(
    negfc.parameters(),
    **optimizer_kwargs)

In [53]:
fours_model.noise_model = fours_model.noise_model.cpu()
fours_model.rotation_model = fours_model.rotation_model.cpu()
fours_model.normalization_model = fours_model.normalization_model.cpu()

In [None]:
print(negfc.separation.item(), negfc.pos_angle.item(), negfc.magnitude.item())

for j in range(5):    
    # move everything to the GPU
    fours_model.noise_model = fours_model.noise_model.to(0)
    fours_model.rotation_model = fours_model.rotation_model.to(0)
    fours_model.normalization_model = fours_model.normalization_model.to(0)
    fours_model.noise_model.betas_raw.requires_grad = False
    fours_model.noise_model.intercept.requires_grad = False
    fours_model.noise_model.compute_betas()

    # iterate for 200 steps
    negfc = negfc.to(0)
    # create the loss function
    negfc_loss = NegFCLoss(
        residual_shape=science_data.shape[1:],
        separation_pixel=fake_planet_config["separation"],
        pos_angle_deg=fake_planet_config["planet_position"][-1],
        aperture_radius=500,
        metric_function="hessian")
    
    negfc_loss = negfc_loss.to(0)
    
    for i in tqdm(range(5)):
        def closure():        
            optimizer.zero_grad()
            after_negfc = negfc(science_data_torch_with_planet)
            after_4s_neg_fc = fours_model.forward(after_negfc)[0]

            # apply the loss function
            loss = negfc_loss(after_4s_neg_fc)
            loss.backward()
            return loss

        optimizer.step(closure)
    
    print(negfc.separation.item(), negfc.pos_angle.item(), negfc.magnitude.item())
    with torch.no_grad():
        after_negfc_for_train = negfc(science_data_torch_with_planet).detach()
        
    fours_model.noise_model.betas_raw.requires_grad = True
    fours_model.noise_model.intercept.requires_grad = True

    fours_model.update_noise_model(
        images=after_negfc_for_train,
        num_epochs=20,
        training_name="update_" + str(j+1),
        logging_interval=10)

29.994550704956055 19.821147918701172 11.972750663757324


 40%|██████████████████████████▊                                        | 2/5 [00:00<00:00,  7.79it/s]

tensor(2.9743e-06, device='cuda:0', grad_fn=<SumBackward0>)
tensor(2.9743e-06, device='cuda:0', grad_fn=<SumBackward0>)


 80%|█████████████████████████████████████████████████████▌             | 4/5 [00:00<00:00,  8.03it/s]

tensor(2.9743e-06, device='cuda:0', grad_fn=<SumBackward0>)
tensor(2.9743e-06, device='cuda:0', grad_fn=<SumBackward0>)


100%|███████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00,  7.96it/s]

tensor(2.9743e-06, device='cuda:0', grad_fn=<SumBackward0>)
29.994550704956055 19.821147918701172 11.972750663757324
S4 model: Fit noise model ... 




  0%|          | 0/20 [00:00<?, ?it/s]

[DONE]


 40%|██████████████████████████▊                                        | 2/5 [00:00<00:00,  8.13it/s]

tensor(2.7367e-06, device='cuda:0', grad_fn=<SumBackward0>)
tensor(2.7367e-06, device='cuda:0', grad_fn=<SumBackward0>)


 80%|█████████████████████████████████████████████████████▌             | 4/5 [00:00<00:00,  8.16it/s]

tensor(2.7367e-06, device='cuda:0', grad_fn=<SumBackward0>)
tensor(2.7367e-06, device='cuda:0', grad_fn=<SumBackward0>)


100%|███████████████████████████████████████████████████████████████████| 5/5 [00:00<00:00,  8.15it/s]

tensor(2.7367e-06, device='cuda:0', grad_fn=<SumBackward0>)
29.994550704956055 19.821147918701172 11.972750663757324
S4 model: Fit noise model ... 




  0%|          | 0/20 [00:00<?, ?it/s]

In [None]:
x_shift_res = negfc.separation.item() * np.cos(np.deg2rad(negfc.pos_angle.item()))
y_shift_res = negfc.separation.item() * np.sin(np.deg2rad(negfc.pos_angle.item()))

x_shift = 30 * np.cos(np.deg2rad(20))
y_shift = 30 * np.sin(np.deg2rad(20))

print(np.abs(y_shift - y_shift_res))
print(np.abs(x_shift - x_shift_res))
print(np.abs(negfc.magnitude.item() - 12))