# Intro to Training and Configurations

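First we import fibad and create a new fibad object, instantiated with the configuration file `./train_model_config.toml`. (If no `config_file` is given, the object is instantiated implicitly with the default configuration file.)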

In [None]:
import fibad

# Build the Fibad object from an explicit config file; the path is relative to the
# current working directory.
fibad_instance = fibad.Fibad(config_file="./train_model_config.toml")

For this demo, we'll make a few adjustments to the default configuration settings that the `fibad` object was instantiated with. By accessing the `.config` attribute of the fibad instance, we can modify any configuration value. Here we change which built-in model to use, the dataset, and the number of epochs for training. (The batch size is adjusted later, before running inference.)

In [None]:
# Select the built-in model to train.
fibad_instance.config["model"]["name"] = "ExampleAutoencoder"
# Select the dataset class to load.
fibad_instance.config["data_set"]["name"] = "HSCDataSet"
# Number of training epochs.
fibad_instance.config["train"]["epochs"] = 20

We call the `.train()` method to train the model.

In [None]:
# Train using the configuration currently held in `fibad_instance.config`.
fibad_instance.train()

The output of the training will be stored in a time-stamped directory under the `./results/` directory. By default, a copy of the final configuration used in training is persisted as `runtime_config.toml`. To run fibad again with the same configuration, you can reference the `runtime_config.toml` file.

If running in another notebook, instantiate a fibad object like so:
```
new_fibad_instance = fibad.Fibad(config_file='./results/<timestamped_directory>/runtime_config.toml')
```

Or from the command line:
```
>> fibad train --runtime-config ./results/<timestamped_directory>/runtime_config.toml
```

Fibad automatically records training metrics so that they can be examined using Tensorboard.

In [None]:
%reload_ext tensorboard
%tensorboard --logdir ./results

# If you are running on a remote server and want to run tensorboard there,
# you need to pass additional arguments for port forwarding to work.
# For example
# %tensorboard --logdir ./results --bind_all --port 8888
# will start tensorboard on port 8888 and you may need to forward that
# port to your local machine using
# ssh -N -L 8888:<name_of_machine>:8888 <username@server.com>

Once a model has been trained, we can use the model weights file to run inference. Here we update the configuration in the `fibad_instance` object to specify that we want to use a specific model weights file, and that we want our dataset to be 100% test data.

If you are running this locally, you'll need to update the path to your local model weights file.

In [None]:
# Set this to the path of the model weights (.pth) file that was created by the call to fibad_instance.train().
# It should be something like `.../results/<timestamp>-train/example_model.pth`.
# fibad_instance.config["infer"]["model_weights_file"] = ""

# Use the whole dataset as test data: no training or validation split.
fibad_instance.config["data_set"]["test_size"] = 1.0
fibad_instance.config["data_set"]["train_size"] = 0.0
fibad_instance.config["data_set"]["validate_size"] = 0.0
# Batch size used by the data loader.
fibad_instance.config["data_loader"]["batch_size"] = 128

The following will run inference on the specified dataset.

In [None]:
# Run this cell after setting model_weights_file in the previous cell.
# NOTE(review): the earlier wording said to "uncomment" this call, but it is already active -- confirm which is intended.
fibad_instance.infer()

TODO: Add a cell to plot a confusion matrix.

In [None]:
import chromadb
import numpy as np

# Helper used below to locate the results directory of the latest "infer" run
# (presumably the newest time-stamped directory under ./results -- confirm against fibad docs).
from fibad.config_utils import find_most_recent_results_dir

results_dir = find_most_recent_results_dir(fibad_instance.config, "infer")

# Open a connection to the vector database persisted in that results directory
# and fetch the collection named "fibad".
client = chromadb.PersistentClient(path=str(results_dir))
collection = client.get_collection("fibad")

In [None]:
print(f"Total records in the vdb: {collection.count()}")

# Fetch one record by its ID, asking for the stored embedding to be included.
example_record = collection.get(ids=["40005764936385106"], include=["embeddings"])
print(f"First record: {example_record}")

# The inner quotes must differ from the f-string's own quotes: reusing double
# quotes here is a SyntaxError on Python versions before 3.12.
# np.shape works whether the embeddings come back as nested lists or as an array.
print(f"Shape of embedding: {np.shape(example_record['embeddings'])}")

In [None]:
# Query the collection for the 10 stored embeddings closest to the example record.
nearest_neighbors = collection.query(query_embeddings=example_record["embeddings"], n_results=10)
print(f"Distance to nearest neighbors: {nearest_neighbors['distances']}")
# Summarise those distances with a single number: their median.
median_dist_to_neighbors = np.median(nearest_neighbors["distances"])
print(f"Median distance to neighbors: {median_dist_to_neighbors}")

In [None]:
# No `ids` filter is passed, so this requests every record in the collection with its embedding.
all_embeddings = collection.get(include=["embeddings"])

In [None]:
# For every stored embedding, look up its 10 nearest neighbors in the same collection.
# NOTE(review): each query embedding is itself stored in the collection, so presumably
# its own entry is among the 10 results -- confirm.
all_nn = collection.query(query_embeddings=all_embeddings["embeddings"], n_results=10)
# axis=1 takes the median across each query's neighbor distances, giving one value per record.
median_all_nn_dist = np.median(all_nn["distances"], axis=1)

In [None]:
import matplotlib.pyplot as plt

# Histogram of each record's median distance to its 10 nearest neighbors.
# Axis labels are set so the figure can be read on its own.
fig, ax = plt.subplots()
ax.hist(median_all_nn_dist, bins=50)
ax.set_xlabel("Median distance to 10 nearest neighbors")
ax.set_ylabel("Number of records")
plt.show()

In [None]:
# Records whose median nearest-neighbor distance exceeds this cutoff are kept as
# candidate anomalies. NOTE(review): the value 4 was presumably chosen from the
# histogram of median distances -- adjust it for other datasets or models.
ANOMALY_MEDIAN_DISTANCE_CUTOFF = 4

# Row positions (into `median_all_nn_dist`) of the records above the cutoff.
indexes = [
    row
    for row, median_dist in enumerate(median_all_nn_dist)
    if median_dist > ANOMALY_MEDIAN_DISTANCE_CUTOFF
]
print(f"Number of indexes: {len(indexes)}")

In [None]:
# Translate each selected row position into the ID stored at that position.
anom_object_ids = [all_embeddings["ids"][position] for position in indexes]

In [None]:
import glob

# Directory holding the example HSC cutouts, relative to the working directory,
# in place of a machine-specific absolute path.
# NOTE(review): assumes the notebook is run from the directory containing `data/` -- adjust if not.
HSC_CUTOUT_DIR = "./data/hsc_example/hsc_8asec_1000"

names = set()
for anom_object_id in anom_object_ids:
    # Every FITS file whose name starts with this object ID.
    found_files = glob.glob(f"{HSC_CUTOUT_DIR}/{anom_object_id}*.fits")
    for f in found_files:
        print(f)
        # Keep everything before the trailing "_HSC-<filter>.fits", so the files of
        # one object collapse to a single entry regardless of the filter name's length.
        names.add(f.rsplit("_HSC-", 1)[0])

print(names)

In [None]:
from astropy.io import fits


# Function to normalize the data to the range [0, 1]
def normalize(data):
    """Linearly rescale ``data`` so its minimum maps to 0 and its maximum to 1.

    Parameters
    ----------
    data : numpy.ndarray
        Array of pixel values.

    Returns
    -------
    numpy.ndarray
        Array of the same shape as ``data`` with values in [0, 1]. If every
        element of ``data`` is equal, an all-zero float array is returned
        rather than dividing by zero.
    """
    data_min = np.min(data)
    value_range = np.max(data) - data_min

    # A constant image has zero range; dividing by it would fill the result with NaN.
    if value_range == 0:
        return np.zeros_like(data, dtype=float)

    return (data - data_min) / value_range


# Plot our 3 filter images
def plotter(file_name, base_path=""):
    """Display an RGB composite built from three HSC filter images.

    Parameters
    ----------
    file_name : str
        Path prefix shared by the three FITS files, i.e. everything that
        precedes the ``_HSC-<filter>.fits`` suffix.
    base_path : str, optional
        String prepended to ``file_name``. Defaults to the empty string, so
        ``file_name`` is used as given.

    Notes
    -----
    Each filter image is normalized independently before being combined.
    Assumes each FITS file holds a 2-D image and that all three share the
    same shape -- TODO confirm for other datasets.
    """
    prefix = base_path + file_name

    # Filter-to-channel mapping: HSC-I -> red, HSC-R -> green, HSC-G -> blue.
    channel_files = (
        prefix + "_HSC-I.fits",
        prefix + "_HSC-R.fits",
        prefix + "_HSC-G.fits",
    )

    # Read each filter image and rescale it to [0, 1].
    channels = [normalize(fits.getdata(path)) for path in channel_files]

    # Stack along a new last axis so the array is (rows, cols, 3), as imshow expects for RGB.
    rgb_image = np.stack(channels, axis=-1)

    # Display the image
    plt.imshow(rgb_image, origin="lower")
    plt.axis("off")  # Hide the axis
    plt.show()

In [None]:
# `names` is a set, whose iteration order can change between runs; sorting keeps
# the figures in the same order every time the notebook is executed.
for cutout_name in sorted(names):
    plotter(cutout_name)

In [None]:
# Fetch the embedding of one specific object, identified by its ID.
odd_one = collection.get(ids=["39921222800117211"], include=["embeddings"])

In [None]:
# Look up the 5 stored embeddings closest to the selected object.
nearest_to_odd_one = collection.query(query_embeddings=odd_one["embeddings"], n_results=5)
# The value printed is the list of neighbor IDs, so the label says so.
print(f"IDs of nearest neighbors: {nearest_to_odd_one['ids']}")

In [None]:
# Directory holding the example HSC cutouts, relative to the working directory,
# in place of a machine-specific absolute path.
# NOTE(review): assumes the notebook is run from the directory containing `data/` -- adjust if not.
HSC_CUTOUT_DIR = "./data/hsc_example/hsc_8asec_1000"

names = set()
# The query was made with a single embedding, so its neighbor IDs are in the first (only) result list.
for neighbor_object_id in nearest_to_odd_one["ids"][0]:
    # Every FITS file whose name starts with this object ID.
    found_files = glob.glob(f"{HSC_CUTOUT_DIR}/{neighbor_object_id}*.fits")
    for f in found_files:
        print(f)
        # Keep everything before the trailing "_HSC-<filter>.fits", so the files of
        # one object collapse to a single entry regardless of the filter name's length.
        names.add(f.rsplit("_HSC-", 1)[0])

print(names)

In [None]:
# `names` is a set, whose iteration order can change between runs; sorting keeps
# the figures in the same order every time the notebook is executed.
for cutout_name in sorted(names):
    plotter(cutout_name)