# Intro to Training and Configurations

First we import fibad and create a new fibad object. Here we pass a configuration file explicitly; any setting that file does not override keeps its default value.

In [None]:
import fibad

fibad_instance = fibad.Fibad(config_file="/Users/drew/code/fibad/drews_config.toml")

For this demo, we'll adjust a few of the configuration settings that the `fibad` object was instantiated with. By accessing the `.config` attribute of the fibad instance, we can modify any configuration value. Here we change which built-in model to use, the dataset, the batch size, and the number of training epochs.

In [None]:
fibad_instance.config["model"]["name"] = "ExampleAutoencoder"
fibad_instance.config["data_set"]["name"] = "HSCDataSet"
fibad_instance.config["data_loader"]["batch_size"] = 64
fibad_instance.config["train"]["epochs"] = 20
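When several values change at once, the same item assignments can be grouped into one call. The following is a minimal sketch, assuming only that the config behaves like a nested dictionary, as the assignments above suggest. `apply_overrides` is a hypothetical helper (not part of fibad), and a plain dict with placeholder values stands in for `fibad_instance.config`.

```python
def apply_overrides(config, overrides):
    """Copy each overrides[section][key] into config[section][key] in place."""
    for section, values in overrides.items():
        for key, value in values.items():
            config[section][key] = value
    return config


# A plain dict stands in for fibad_instance.config (assumption for illustration;
# the starting values and the "shuffle" key are placeholders).
demo_config = {
    "model": {"name": "SomeModel"},
    "data_loader": {"batch_size": 32, "shuffle": True},
    "train": {"epochs": 10},
}

apply_overrides(
    demo_config,
    {
        "model": {"name": "ExampleAutoencoder"},
        "data_loader": {"batch_size": 64},
        "train": {"epochs": 20},
    },
)
print(demo_config["data_loader"])
```

With a plain dict, a misspelled section name raises a `KeyError` instead of silently creating a new section, while keys not mentioned in the overrides are left untouched.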

We call the `.train()` method to train the model.

In [None]:
fibad_instance.train()

The output of training is stored in a time-stamped directory under `./results/`. By default, a copy of the final configuration used in training is persisted as `runtime_config.toml`. To run fibad again with the same configuration, you can reference the `runtime_config.toml` file.

If running in another notebook, instantiate a fibad object like so:
```
new_fibad_instance = fibad.Fibad(config_file='./results/<timestamped_directory>/runtime_config.toml')
```

Or from the command line:
```
>> fibad train --runtime-config ./results/<timestamped_directory>/runtime_config.toml
```
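
To avoid typing the time-stamped directory name by hand, the most recent run can be located programmatically. This is a sketch, assuming result directories are named like `20241216-110118-train` (the pattern seen in the paths used later in this notebook), so that sorting by name is chronological. `latest_runtime_config` is a hypothetical helper, not part of fibad.

```python
from pathlib import Path


def latest_runtime_config(results_root="./results", suffix="-train"):
    """Return the runtime_config.toml of the newest matching run, or None."""
    root = Path(results_root)
    # Keep only run directories that actually contain a runtime config.
    runs = sorted(
        d for d in root.glob(f"*{suffix}")
        if (d / "runtime_config.toml").is_file()
    )
    return runs[-1] / "runtime_config.toml" if runs else None


config_path = latest_runtime_config()
print(config_path)  # None if no training run has been saved yet
```

The returned path can then be passed as `config_file` when instantiating `fibad.Fibad`.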

In [None]:
%load_ext tensorboard

In [None]:
%tensorboard --logdir ./results

# If running on a remote server and tunnelling a connection,
# pass the --bind_all flag:
# %tensorboard --logdir ./results --bind_all
# and then forward the selected port to your local machine.

In [None]:
fibad_instance.config["predict"][
    "model_weights_file"
] = "/Users/drew/code/fibad/docs/notebooks/results/20241216-110118-train/example_model.pth"
fibad_instance.config["predict"]["split"] = "test"
fibad_instance.config["data_set"]["test_size"] = 1.0
fibad_instance.config["data_set"]["train_size"] = 0.0
fibad_instance.config["data_set"]["validate_size"] = 0.0
fibad_instance.config["data_loader"]["batch_size"] = 128

In [None]:
fibad_instance.predict()

In [None]:
prepped_output = fibad_instance.prepare()

In [None]:
import chromadb

client = chromadb.PersistentClient(path="/Users/drew/code/fibad/docs/notebooks/results/vdb")

In [None]:
collection = client.get_collection("fibad_collection")

In [None]:
import numpy as np

a = np.load("/Users/drew/code/fibad/docs/notebooks/results/20241216-155404-predict/0.npy")

In [None]:
# Query with the embedding at index 67 (index 97 is also a cool example)

query_results = collection.query(
    query_embeddings=[a[67]],
    n_results=10,
)

print(query_results["distances"])
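
For intuition about what such a query returns, here is a minimal NumPy sketch of an exhaustive nearest-neighbour search. It assumes squared Euclidean distance; the metric actually used depends on how the collection was created, and a vector database typically relies on an approximate index, not a full scan. `nearest_neighbors` and the toy data are hypothetical, for illustration only.

```python
import numpy as np


def nearest_neighbors(embeddings, query, n_results=10):
    """Return (indices, squared distances) of the rows closest to query."""
    diffs = embeddings - query
    # Row-wise squared Euclidean distance.
    dists = np.einsum("ij,ij->i", diffs, diffs)
    order = np.argsort(dists)[:n_results]
    return order, dists[order]


# Toy embeddings standing in for the saved prediction outputs.
toy_embeddings = np.array([[0.0, 0.0], [1.0, 0.0], [3.0, 4.0]])
indices, distances = nearest_neighbors(toy_embeddings, toy_embeddings[0], n_results=2)
print(indices, distances)
```

Because the query vector is itself one of the stored rows, the closest match is that same row at distance zero.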

In [None]:
metadatas = query_results["metadatas"]

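# Collect the base file name of each nearest neighbour for plotting
files_to_plot = []
for m in metadatas[0]:
    files = prepped_output.container.files[int(m["filename"])]
    g_file = files["HSC-G"]
    # Drop the trailing "_HSC-G.fits" (11 characters) to get the base name
    files_to_plot.append(g_file[:-11])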

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from astropy.io import fits


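# Function to normalize the data to the range [0, 1]
def normalize(data):
    data_min = np.min(data)
    data_max = np.max(data)
    if data_max == data_min:
        # Constant image: avoid dividing by zero
        return np.zeros_like(data, dtype=float)
    return (data - data_min) / (data_max - data_min)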


def plotter(file_name):
    # Read the FITS files, mapping the HSC I, R, and G bands
    # to the red, green, and blue channels respectively
    base_path = "/Users/drew/code/fibad/docs/notebooks/data/hsc_example/hsc_8asec_1000/"
    fits_file_r = base_path + file_name + "_HSC-I.fits"
    fits_file_g = base_path + file_name + "_HSC-R.fits"
    fits_file_b = base_path + file_name + "_HSC-G.fits"

    data_r = fits.getdata(fits_file_r)
    data_g = fits.getdata(fits_file_g)
    data_b = fits.getdata(fits_file_b)

    # Normalize the data
    data_r = normalize(data_r)
    data_g = normalize(data_g)
    data_b = normalize(data_b)

    # Combine the data into an RGB image
    rgb_image = np.zeros((data_r.shape[0], data_r.shape[1], 3))
    rgb_image[..., 0] = data_r  # Red channel
    rgb_image[..., 1] = data_g  # Green channel
    rgb_image[..., 2] = data_b  # Blue channel

    # Display the image
    plt.imshow(rgb_image, origin="lower")
    plt.axis("off")  # Hide the axis
    plt.show()


for file_name in files_to_plot:
    plotter(file_name)