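# Deep Learning for Cone Cells and Light Wavelength

Simulate a population of red, green, and blue cone cells, then train a feedforward network to classify the wavelength bin of incoming light from the cone-cell data.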

In [12]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm


np.set_printoptions(suppress=True)

import sklearn.metrics as metrics

import tensorflow as tf
from tensorflow import keras

import pickle

from simulator_lib import *

## Generating the Simulation Data

In [None]:
TOTAL_NUM_CELLS = 10000
# Set to True to overwrite the saved dataset with freshly generated cells.
SAVE_CONE_CELLS = False
CONE_CELLS_PATH = "./data/cone_cells_normal.pkl"

# Round instead of truncating: float products can land just below an integer
# (e.g. int(100 * 0.57) == 56), which would make the counts miss the total.
NUM_RED = int(round(TOTAL_NUM_CELLS * PERCENT_RED))
NUM_GREEN = int(round(TOTAL_NUM_CELLS * PERCENT_GREEN))
NUM_BLUE = int(round(TOTAL_NUM_CELLS * PERCENT_BLUE))
assert NUM_RED + NUM_GREEN + NUM_BLUE == TOTAL_NUM_CELLS, (
    f"cone counts {NUM_RED}+{NUM_GREEN}+{NUM_BLUE} != {TOTAL_NUM_CELLS}; "
    "check that PERCENT_RED/GREEN/BLUE sum to 1"
)

# Create the cone cells: red, then green, then blue.
# NOTE(review): `+` assumes generate_cone_cells returns a list (concatenation);
# if it returned numpy arrays this would add element-wise — confirm in simulator_lib.
# TODO: if generate_cone_cells is random, seed it here for reproducibility.
cone_cells = (
    generate_cone_cells(NUM_RED, RED_MU, RED_SIGMA) +
    generate_cone_cells(NUM_GREEN, GREEN_MU, GREEN_SIGMA) +
    generate_cone_cells(NUM_BLUE, BLUE_MU, BLUE_SIGMA)
)

if SAVE_CONE_CELLS:
    with open(CONE_CELLS_PATH, "wb") as fp:
        pickle.dump(cone_cells, fp, pickle.HIGHEST_PROTOCOL)

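## Loading the Simulation Data

**TODO:** this section has no code cell in this version of the notebook. The cells below still use `train_cone_cells`, `train_wavelengths`, `test_cone_cells`, and `test_wavelengths`. Load the saved cone cells and build these train/test arrays here so the notebook runs top to bottom on a fresh kernel.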

## Feedforward Network Architecture

### Model Setup

In [None]:
# Feedforward classifier: NUM_HIDDEN_LAYERS identical (Dense + SiLU, Dropout)
# blocks followed by a softmax over WAVELENGTH_BINS classes.
HIDDEN_UNITS = 128
NUM_HIDDEN_LAYERS = 4
DROPOUT_RATE = 0.2

model = keras.Sequential([
    # One input feature per entry along the last axis of the training array.
    keras.Input(shape=(train_cone_cells.shape[-1], )),
    # Each iteration builds a Dense layer, then its Dropout, in the same
    # order as writing the pairs out by hand.
    *[
        layer
        for _ in range(NUM_HIDDEN_LAYERS)
        for layer in (
            keras.layers.Dense(units=HIDDEN_UNITS, activation="silu"),
            keras.layers.Dropout(DROPOUT_RATE),
        )
    ],
    keras.layers.Dense(units=WAVELENGTH_BINS, activation="softmax"),
])

### Compile Model

In [None]:
# AdamW: Adam with decoupled weight decay.
optimizer = keras.optimizers.AdamW(
    learning_rate=5e-4,
    weight_decay=0.004,
)

# The sparse cross-entropy loss takes integer class indices as targets
# (not one-hot vectors), matched against the softmax output layer.
model.compile(
    loss="sparse_categorical_crossentropy",
    optimizer=optimizer,
    metrics=["accuracy"],
)
model.summary()

### Fit the Model

In [None]:
EPOCHS = 30
BATCH_SIZE = 64

# The mini-batch size keyword of `fit` is `batch_size`.
# NOTE(review): validation_split takes the *last* 10% of the arrays before any
# shuffling; if the training data is ordered (e.g. by cone type or wavelength),
# shuffle it first — confirm against the loading step.
h = model.fit(
    train_cone_cells,
    train_wavelengths,
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    validation_split=0.1,
)

### Visualize Loss Trajectory

In [None]:
# Build the epoch axis from the recorded history rather than EPOCHS, so it always
# matches the number of logged points (e.g. if early stopping is added later).
# Epochs are plotted 1-based, so the first training epoch appears at x = 1.
epoch_axis = range(1, len(h.history['loss']) + 1)

fig, ax = plt.subplots(1, 1, figsize=(10, 6))
ax.plot(epoch_axis, h.history['val_loss'], 'o-', color='maroon', label='Validation')
ax.plot(epoch_axis, h.history['loss'], 'o-', color='black', label='Training')
ax.set_xlabel('Epoch')
ax.set_ylabel('Cross-entropy')
ax.legend()
ax.set_title('Loss Trajectory')
sns.despine(ax=ax)

### Evaluate Model on Test Data

In [None]:
# Score the trained network on the held-out test set; the result unpacks into
# the loss followed by the single compiled metric (accuracy).
test_loss, test_accuracy = model.evaluate(x=test_cone_cells, y=test_wavelengths)

In [None]:
# Class probabilities for each test sample. Use the same test arrays that were
# passed to model.evaluate (`test_images`/`test_labels` are never defined here).
preds = model.predict(test_cone_cells)

# The predicted wavelength bin is the index of the highest softmax probability.
# The trailing `;` suppresses the display object's repr, leaving just the figure.
metrics.ConfusionMatrixDisplay.from_predictions(test_wavelengths, preds.argmax(axis=1));