# Fibad Demonstration

## Download a sample HSC dataset

This dataset consists of approximately 1000 cutouts from the Hyper Suprime-Cam (HSC) survey.
The cutouts were requested to be 8 arcseconds on a side.
For consistency, we crop each image to [96, 96] pixels at runtime.
Three bands were acquired for each object: G, R, and I.
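To make the runtime crop concrete, here is a minimal center-crop sketch in NumPy. `center_crop` is a hypothetical helper, and we assume the crop is taken about the image center; fibad's own cropping logic may differ.

```python
import numpy as np


def center_crop(image: np.ndarray, crop_to=(96, 96)) -> np.ndarray:
    """Return the central crop_to[0] x crop_to[1] pixels of a 2D image.

    Hypothetical helper for illustration; not fibad's implementation.
    """
    height, width = image.shape
    crop_h, crop_w = crop_to
    if crop_h > height or crop_w > width:
        raise ValueError(f"Cannot crop {image.shape} to {crop_to}")
    # Offsets that center the crop window (integer division rounds down)
    top = (height - crop_h) // 2
    left = (width - crop_w) // 2
    return image[top : top + crop_h, left : left + crop_w]


# Cutouts whose sizes differ slightly all come out as 96 x 96
for shape in [(96, 96), (97, 96), (100, 99)]:
    print(shape, "->", center_crop(np.zeros(shape)).shape)
```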

Once unzipped, there will be a `.fits` file for each (object ID, band) pair in the `./data/hsc_8asec_1000` directory.
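To check that every object has all three bands, the file names can be grouped by object ID. This sketch assumes names of the form `<object_id>_<filter>.fits` (e.g. `12345_HSC-G.fits`), which matches how the plotting code later in this notebook builds file names; `group_cutouts_by_object` is a hypothetical helper and the IDs below are placeholders.

```python
from collections import defaultdict
from pathlib import Path


def group_cutouts_by_object(file_names):
    """Group FITS file names by object ID.

    Assumes (hypothetically) names like "<object_id>_<filter>.fits".
    Returns {object_id: {filter: file_name}}.
    """
    groups = defaultdict(dict)
    for name in file_names:
        stem = Path(name).stem  # e.g. "12345_HSC-G"
        object_id, band = stem.rsplit("_", 1)  # split on the last underscore
        groups[object_id][band] = name
    return dict(groups)


files = ["12345_HSC-G.fits", "12345_HSC-R.fits", "12345_HSC-I.fits", "67890_HSC-G.fits"]
by_object = group_cutouts_by_object(files)
# Objects missing one or more of the three bands
incomplete = [obj for obj, bands in by_object.items() if len(bands) < 3]
print(incomplete)
```

With the real data, the input list could come from `glob.glob("./data/hsc_8asec_1000/*.fits")`.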


In [None]:
!pip install pooch
import pooch

file_path = pooch.retrieve(
    # DOI for Example HSC dataset
    url="doi:10.5281/zenodo.14498537/hsc_demo_data.zip",
    known_hash="md5:ed18ac315a1f9bbc7c2325fd4f1c683a",
    fname="example_hsc.zip",
    path="./data",
    processor=pooch.Unzip(extract_dir="."),
)

print(file_path[0])
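`known_hash` lets pooch verify the archive after downloading it. To re-check the archive by hand later, the same MD5 comparison can be done with the standard library. `md5_of_file` is our own helper, not part of pooch; the archive path assumes the `path` and `fname` arguments used above.

```python
import hashlib


def md5_of_file(path, chunk_size=1 << 20):
    """Compute the MD5 hex digest of a file, reading it in 1 MiB chunks."""
    digest = hashlib.md5()
    with open(path, "rb") as handle:
        # iter() with a sentinel keeps reading until read() returns b""
        for chunk in iter(lambda: handle.read(chunk_size), b""):
            digest.update(chunk)
    return digest.hexdigest()


# The hash passed to pooch.retrieve above (without the "md5:" prefix)
expected = "ed18ac315a1f9bbc7c2325fd4f1c683a"
# print(md5_of_file("./data/example_hsc.zip") == expected)
```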

## Configuration and training

First, we import `fibad` and create a new `Fibad` object, which is implicitly instantiated with the default configuration file.

In [None]:
import fibad

f = fibad.Fibad()

For this demo, we'll make a few adjustments to the default configuration that the `fibad` object was instantiated with. Any configuration value can be modified through the instance's `.config` attribute.

Here we change the model to train, the dataset class, the data location, and the number of training epochs, as well as a few other parameters.

In [None]:
f.config["general"]["data_dir"] = "./data/hsc_8asec_1000"
f.config["model"]["name"] = "ExampleAutoencoder"
f.config["data_set"]["name"] = "HSCDataSet"
f.config["data_set"]["crop_to"] = [96, 96]
f.config["download"]["filter"] = ["HSC-G", "HSC-R", "HSC-I"]
f.config["train"]["epochs"] = 20
f.config["data_loader"]["batch_size"] = 32

We call the `.train()` method to train the model.

In [None]:
f.train()

The output of training is stored in a timestamped directory under `./results/`. By default, a copy of the final configuration used for training is saved there as `runtime_config.toml`. To run fibad again with the same configuration, reference this `runtime_config.toml` file.

If running in another notebook, instantiate a fibad object like so:
```
new_fibad_instance = fibad.Fibad(config_file='./results/<timestamped_directory>/runtime_config.toml')
```

Or from the command line on an HPC system:
```
>> fibad train --runtime-config ./results/<timestamped_directory>/runtime_config.toml
```

## Create a new model

New models can be written in a notebook for easier development.
Here we write an autoencoder to compare against the built-in `ExampleAutoencoder`.

For reference, the primary difference is that the built-in autoencoder uses `nn.GELU`, whereas this one uses `nn.ReLU`.
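The two activations differ mainly for negative inputs: ReLU maps them to zero, while GELU (the exact form `x * Phi(x)`, where `Phi` is the standard normal CDF, which is what `nn.GELU` computes by default) lets small negative values through and is smooth at zero. A pure-Python sketch for comparison:

```python
import math


def relu(x: float) -> float:
    """Rectified linear unit: max(0, x)."""
    return max(0.0, x)


def gelu(x: float) -> float:
    """Exact GELU: x * Phi(x), with Phi the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


for x in (-2.0, -0.5, 0.0, 0.5, 2.0):
    print(f"x={x:+.1f}  relu={relu(x):+.4f}  gelu={gelu(x):+.4f}")
```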

In [None]:
import torch.nn as nn
import torch.optim as optim
from fibad.models.model_registry import fibad_model


@fibad_model  # This decorator registers the model with the FIBAD framework
class TrialAutoencoder(nn.Module):
    def __init__(self, config, shape):
        super().__init__()
        self.config = config

        # Encoder
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=2, padding=1),  # (16, 48, 48)
            nn.ReLU(),
            nn.Conv2d(16, 32, kernel_size=3, stride=2, padding=1),  # (32, 24, 24)
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=1),  # (64, 12, 12)
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),  # (128, 6, 6)
            nn.ReLU(),
        )

        # Decoder
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),  # (64, 12, 12)
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1),  # (32, 24, 24)
            nn.ReLU(),
            nn.ConvTranspose2d(32, 16, kernel_size=3, stride=2, padding=1, output_padding=1),  # (16, 48, 48)
            nn.ReLU(),
            nn.ConvTranspose2d(16, 3, kernel_size=3, stride=2, padding=1, output_padding=1),  # (3, 96, 96)
            nn.Sigmoid(),  # Normalize output to [0, 1]
        )

    def _eval_encoder(self, x):
        return self.encoder(x)

    def _eval_decoder(self, x):
        return self.decoder(x)

    def forward(self, x):
        return self._eval_encoder(x)

    def train_step(self, x):
        z = self._eval_encoder(x)
        x_hat = self._eval_decoder(z)

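        # The loss function (self.criterion) is defined in the config.
        # With reduction="none" it returns per-pixel losses, which are summed
        # over (channel, height, width) and then averaged over the batch.
        loss = self.criterion(x_hat, x)
        loss = loss.sum(dim=[1, 2, 3]).mean(dim=[0])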
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        return {"loss": loss.item()}

    # The optimizer can be coded directly, or defined in the configuration file
    def _optimizer(self):
        return optim.Adam(self.parameters(), lr=1e-3)
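The shape comments in the encoder and decoder follow from PyTorch's output-size formulas for `Conv2d` and `ConvTranspose2d` (with `dilation=1`). This sketch applies those formulas to check that a 96-pixel input is encoded down to 6 pixels and decoded back to 96:

```python
def conv2d_out(size, kernel_size=3, stride=2, padding=1):
    """Output size of nn.Conv2d along one spatial dimension (dilation=1)."""
    return (size + 2 * padding - kernel_size) // stride + 1


def conv_transpose2d_out(size, kernel_size=3, stride=2, padding=1, output_padding=1):
    """Output size of nn.ConvTranspose2d along one spatial dimension (dilation=1)."""
    return (size - 1) * stride - 2 * padding + kernel_size + output_padding


size = 96
encoder_sizes = [size]
for _ in range(4):  # four stride-2 convolutions in the encoder
    size = conv2d_out(size)
    encoder_sizes.append(size)

decoder_sizes = [size]
for _ in range(4):  # four stride-2 transposed convolutions in the decoder
    size = conv_transpose2d_out(size)
    decoder_sizes.append(size)

print(encoder_sizes)  # [96, 48, 24, 12, 6]
print(decoder_sizes)  # [6, 12, 24, 48, 96]
```

The same arithmetic shows why the crop size matters: a 95-pixel input also encodes to 6 pixels (`conv2d_out(95) == 48`), but the decoder would still produce 96 pixels, so the reconstruction would not match the input shape.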

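With the model defined in the notebook and registered with Fibad by the `@fibad_model` decorator, we update the configuration so that the next call to `f.train()` uses the new model.
We also select `torch.nn.MSELoss` as the loss function and set its `reduction` to `"none"`, so that `train_step` receives per-pixel losses and can reduce them itself.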

In [None]:
f.config["model"]["name"] = "TrialAutoencoder"
f.config["criterion"]["name"] = "torch.nn.MSELoss"
f.config["torch.nn.MSELoss"] = {}
f.config["torch.nn.MSELoss"]["reduction"] = "none"
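With `reduction="none"`, the criterion returns one squared error per element, and `train_step` reduces it with `loss.sum(dim=[1, 2, 3]).mean(dim=[0])`. This NumPy sketch mirrors that reduction on random stand-in data (not real cutouts) to show the shapes involved:

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random((4, 3, 96, 96))      # stand-in batch of 4 three-band images
x_hat = rng.random((4, 3, 96, 96))  # stand-in reconstructions

# MSELoss(reduction="none") keeps one squared error per element
per_element = (x_hat - x) ** 2               # shape (4, 3, 96, 96)
per_image = per_element.sum(axis=(1, 2, 3))  # shape (4,): total error per image
loss = per_image.mean(axis=0)                # scalar: mean over the batch
print(per_image.shape, float(loss))
```

Compared with the default `reduction="mean"`, summing over pixels only rescales the loss by 3 × 96 × 96 = 27,648.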

In [None]:
f.train()

## Compare model performance

Fibad automatically logs training information for model evaluation.
Currently, TensorBoard and MLflow are supported for easy model-to-model comparisons of metrics and parameters.

![MLflow comparison of training loss across runs](mlflow_mpr_training_loss.JPG)
![MLflow comparison of parameter differences between runs](mlflow_mpr_param_diffs.JPG)


In [None]:
%reload_ext tensorboard
%tensorboard --logdir ./results

## Running inference using a trained model

Once a model has been trained, we can use its saved weights file to run inference.

In this example, we'll assume that the built-in `ExampleAutoencoder` outperformed the `TrialAutoencoder` defined in the notebook.

To run inference with a specific trained model, we update the `infer` section of the configuration.
To run inference on the entire dataset, we also update the data set splits.

In [None]:
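# We assume ExampleAutoencoder performed better, so switch the model back to it.
f.config["model"]["name"] = "ExampleAutoencoder"

# Set this to the path of the example_model.pth file created by the first call to f.train()
# (the ExampleAutoencoder run).
# It should be something like `.../results/<timestamp>-train/example_model.pth`.
# f.config["infer"]["model_weights_file"] = ""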

# Update the data set splits to be 100% test data
f.config["data_set"]["test_size"] = 1.0
f.config["data_set"]["train_size"] = 0.0
f.config["data_set"]["validate_size"] = 0.0

# Set the batch size larger for faster inference
f.config["data_loader"]["batch_size"] = 512

The following will run inference on the specified dataset.

In [None]:
f.infer()

## Exploring the results of inference

Fibad saves the output of inference as batched `.npy` files in the `.../results/<timestamp>-infer-xxxx` directory.
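If you want the raw inference output as a single array, the batched files can be concatenated with NumPy. This is a sketch under assumptions we have not confirmed against fibad: that each file holds an array whose first axis is the batch dimension, and that sorting file names gives the intended order (names like `batch_10` sort before `batch_2`, so check the naming first). `load_batched_npy` and its `pattern` parameter are hypothetical.

```python
from pathlib import Path

import numpy as np


def load_batched_npy(results_dir, pattern="*.npy"):
    """Concatenate the .npy files in results_dir along the first (batch) axis.

    Hypothetical helper: assumes every matching file is one batch of outputs
    and that lexicographic file-name order is the desired order.
    """
    files = sorted(Path(results_dir).glob(pattern))
    if not files:
        raise FileNotFoundError(f"No files matching {pattern!r} in {results_dir}")
    return np.concatenate([np.load(path) for path in files], axis=0)
```

Pass it the path of the `<timestamp>-infer-xxxx` directory for your run.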

Optionally, Fibad can populate a vector database for fast approximate similarity search.

In [None]:
import chromadb
import numpy as np
from fibad.config_utils import find_most_recent_results_dir

# Locate the most recent inference results directory
results_dir = find_most_recent_results_dir(f.config, "infer")

# Open a connection to the vector database stored there
client = chromadb.PersistentClient(path=str(results_dir))
collection = client.get_collection("fibad")

In [None]:
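# Retrieve every stored embedding (IDs are always returned alongside)
all_embeddings = collection.get(include=["embeddings"])
# Query the 5 nearest neighbors of every embedding; each object's own
# embedding is typically among them at distance 0
all_nn = collection.query(query_embeddings=all_embeddings["embeddings"], n_results=5)
# Median neighbor distance per object; large values suggest isolated, unusual objects
median_all_nn_dist = np.median(all_nn["distances"], axis=1)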
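Chroma answers these queries with an approximate HNSW index. For a small collection, an exact brute-force version of the same statistic can serve as a cross-check. This sketch assumes the collection uses Chroma's default `l2` space, which to our understanding reports squared Euclidean distances; `median_knn_distance` is our own helper, and it includes each point as its own nearest neighbor, mirroring the query above.

```python
import numpy as np


def median_knn_distance(embeddings, k=5):
    """Median squared-L2 distance from each embedding to its k nearest neighbors.

    Exact search (the point itself counts as a neighbor at distance 0).
    Uses |a - b|^2 = |a|^2 + |b|^2 - 2 a.b, so memory is O(N^2), not O(N^2 * D).
    """
    emb = np.asarray(embeddings, dtype=float)
    sq_norms = (emb ** 2).sum(axis=1)
    sq_dists = sq_norms[:, None] + sq_norms[None, :] - 2.0 * emb @ emb.T
    sq_dists = np.maximum(sq_dists, 0.0)  # clip tiny negatives from round-off
    nearest = np.sort(sq_dists, axis=1)[:, :k]  # k smallest distances per row
    return np.median(nearest, axis=1)
```

For example, `median_knn_distance(all_embeddings["embeddings"], k=5)` should roughly agree with `median_all_nn_dist`; small differences are expected because HNSW search is approximate.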

In [None]:
import matplotlib.pyplot as plt

# Plot a histogram of the median nearest-neighbor distances to help choose a threshold.
# The upper end of the range excludes extreme outliers, which are often caused by instrumental defects.
_ = plt.hist(median_all_nn_dist, bins=100, range=(0, 30_000))

In [None]:
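# Select objects whose median nearest-neighbor distance falls inside the chosen window
indexes = [i for i, x in enumerate(median_all_nn_dist) if 18_000 < x < 30_000]
# Sort the selected indexes by increasing median nearest-neighbor distance
indexes.sort(key=lambda i: median_all_nn_dist[i])
print(f"Number of indexes: {len(indexes)}")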
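The 18,000 to 30,000 window above was picked by eye from the histogram. One alternative, swapped in here rather than taken from fibad, is to express the cut-offs as percentiles of the distribution, so the same rule can be reused on a dataset whose distances have a different scale. `select_by_percentile` is a hypothetical helper, and its default percentiles are arbitrary.

```python
import numpy as np


def select_by_percentile(values, lower_pct=95.0, upper_pct=99.9):
    """Indexes whose value lies strictly between two percentiles, sorted by value.

    The upper cut drops the most extreme tail (e.g. instrumental defects);
    the percentile choices are arbitrary and should be tuned per dataset.
    """
    values = np.asarray(values)
    low, high = np.percentile(values, [lower_pct, upper_pct])
    selected = np.flatnonzero((values > low) & (values < high))
    return selected[np.argsort(values[selected])]  # increasing value order
```

For example, `select_by_percentile(median_all_nn_dist)` returns candidate indexes in the same increasing-distance order as the list above.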

In [None]:
# Look up the object ID of each selected embedding
anom_object_ids = [all_embeddings["ids"][i] for i in indexes]

Next, we find the FITS file names for the "anomalous" objects.

In [None]:
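import glob
import os

# Directory containing the cutouts downloaded at the start of this notebook
data_dir = f.config["general"]["data_dir"]

# Collect the file-name stem (path without the "_HSC-<band>.fits" suffix) of each object.
# A list preserves the order of anom_object_ids; the loop variable no longer shadows `f`.
names = []
for anom_object_id in anom_object_ids:
    for fits_path in sorted(glob.glob(os.path.join(data_dir, f"{anom_object_id}*.fits"))):
        stem = fits_path.rsplit("_", 1)[0]
        if stem not in names:
            names.append(stem)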

In [None]:
from astropy.io import fits


# Function to normalize the data to the range [0, 1]
def normalize(data):
    data_min = np.min(data)
    data_max = np.max(data)
    return (data - data_min) / (data_max - data_min)


# Plot the three filter images of one object as a single RGB image
def plotter(ax, file_name):
    # FITS file for each band, mapped to RGB as I -> red, R -> green, G -> blue
    fits_file_r = file_name + "_HSC-I.fits"
    fits_file_g = file_name + "_HSC-R.fits"
    fits_file_b = file_name + "_HSC-G.fits"

    # Read the image data from each FITS file
    data_r = fits.getdata(fits_file_r)
    data_g = fits.getdata(fits_file_g)
    data_b = fits.getdata(fits_file_b)

    # Normalize the data
    data_r = normalize(data_r)
    data_g = normalize(data_g)
    data_b = normalize(data_b)

    # Combine the data into an RGB image
    rgb_image = np.zeros((data_r.shape[0], data_r.shape[1], 3))
    rgb_image[..., 0] = data_r  # Red channel
    rgb_image[..., 1] = data_g  # Green channel
    rgb_image[..., 2] = data_b  # Blue channel

    # Display the image
    ax.imshow(rgb_image, origin="lower")
    ax.set_title(
        "Obj ID: " + file_name.split("/")[-1][:17], y=1.0, pad=-14, color="white"
    )  # Title each panel with the object ID, drawn just inside the top edge
    ax.axis("off")  # Hide the axis
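The min-max `normalize` above lets a single bright pixel set the stretch for the whole image, and it returns NaN everywhere if a cutout contains any NaN pixels. A common alternative for display is percentile clipping. This NaN-aware sketch could be swapped in for `normalize` inside `plotter`; `normalize_robust` is a hypothetical helper and its percentile defaults are arbitrary.

```python
import numpy as np


def normalize_robust(data, lower_pct=1.0, upper_pct=99.0):
    """Scale an image to [0, 1] using percentile clipping, ignoring NaNs.

    Values outside the chosen percentiles are clipped; NaN pixels become 0.
    """
    data = np.asarray(data, dtype=float)
    low, high = np.nanpercentile(data, [lower_pct, upper_pct])
    if high <= low:  # flat image: avoid division by zero
        return np.zeros_like(data)
    scaled = (data - low) / (high - low)
    return np.nan_to_num(np.clip(scaled, 0.0, 1.0), nan=0.0)
```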

In [None]:
def plot_nx3_grid(file_names):
    """
    Plot RGB composites of several objects in a grid with 3 columns.

    Parameters
    ----------
    file_names : sequence of str
        File-name stems (paths without the "_HSC-<band>.fits" suffix),
        one per object, as expected by `plotter`.
    """
    file_names = list(file_names)
    num_plots = len(file_names)
    if num_plots == 0:
        print("No objects to plot.")
        return
    num_rows = (num_plots + 2) // 3  # Ceiling division: rows needed for 3 columns

    # squeeze=False keeps `axes` 2D even when there is only one row
    fig, axes = plt.subplots(num_rows, 3, figsize=(15, 5 * num_rows), squeeze=False)
    fig.patch.set_facecolor("darkslategrey")  # Dark background around the panels

    for i, file_name in enumerate(file_names):
        ax = axes[i // 3, i % 3]
        plotter(ax, file_name)

    # Hide any unused subplots in the last row
    for j in range(num_plots, num_rows * 3):
        fig.delaxes(axes.flatten()[j])

    plt.tight_layout()
    plt.show()

In [None]:
plot_nx3_grid(names)