# Deep Learning for Cone Cells and Light Frequency

In [3]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm


np.set_printoptions(suppress=True)

import sklearn.metrics as metrics

import tensorflow as tf
from tensorflow import keras

import pickle

from simulator_lib import *

## Generating / Loading the Cone Cells

In [None]:
GENERATE_NEW_CONE_CELLS = False

if GENERATE_NEW_CONE_CELLS:

    TOTAL_NUM_CELLS = 10000
    NUM_RED = int(TOTAL_NUM_CELLS * PERCENT_RED)
    NUM_GREEN = int(TOTAL_NUM_CELLS * PERCENT_GREEN)
    NUM_BLUE = int(TOTAL_NUM_CELLS * PERCENT_BLUE)
    assert(NUM_RED + NUM_GREEN + NUM_BLUE == TOTAL_NUM_CELLS)

    # create the cone cells
    cone_cells = (
        generate_cone_cells(NUM_RED, RED_MU, RED_SIGMA) +
        generate_cone_cells(NUM_GREEN, GREEN_MU, GREEN_SIGMA) +
        generate_cone_cells(NUM_BLUE, BLUE_MU, BLUE_SIGMA)
    )

    # save the cone cells
    with open("./data/normal_cone_cells.pkl", "wb") as fp:
        pickle.dump(cone_cells, fp, pickle.HIGHEST_PROTOCOL)

else:

    with open("./data/normal_cone_cells.pkl", mode="rb") as fp:
        cone_cells = pickle.load(fp)
    TOTAL_NUM_CELLS = len(cone_cells)

print(TOTAL_NUM_CELLS)
print(cone_cells)

10000
[<simulator_lib.ConeCell object at 0x7889037627e0>, <simulator_lib.ConeCell object at 0x78888a457860>, <simulator_lib.ConeCell object at 0x78888a49d5e0>, <simulator_lib.ConeCell object at 0x78888a49cf20>, <simulator_lib.ConeCell object at 0x78888a49cd70>, <simulator_lib.ConeCell object at 0x78888a49c770>, <simulator_lib.ConeCell object at 0x78888a49d4f0>, <simulator_lib.ConeCell object at 0x78888a49ce00>, <simulator_lib.ConeCell object at 0x78888a49cf80>, <simulator_lib.ConeCell object at 0x78888a49d0a0>, <simulator_lib.ConeCell object at 0x78888a49d070>, <simulator_lib.ConeCell object at 0x78888a49ce60>, <simulator_lib.ConeCell object at 0x78888a49cdd0>, <simulator_lib.ConeCell object at 0x78888a49d130>, <simulator_lib.ConeCell object at 0x78888a49d160>, <simulator_lib.ConeCell object at 0x78888a49d040>, <simulator_lib.ConeCell object at 0x78888a49e420>, <simulator_lib.ConeCell object at 0x78888a49c0e0>, <simulator_lib.ConeCell object at 0x78888a49d1c0>, <simulator_lib.ConeCell 

## Generating / Loading the Simulation Data

In [10]:
GENERATE_NEW_SIMULATION_DATA = False

if GENERATE_NEW_SIMULATION_DATA:

    NUM_DATA_POINTS = 1000
    MIN_WAVELENGTH = 380
    MAX_WAVELENGTH = 750

    data, colors = sample_wavelengths(
        num_data_points=NUM_DATA_POINTS, 
        cells=cone_cells, 
        min_wl=MIN_WAVELENGTH, 
        max_wl=MAX_WAVELENGTH,
    )

    # save the simulation data
    with open("./data/normal_simulation_data.pkl", "wb") as fp:
        pickle.dump((data, colors), fp, pickle.HIGHEST_PROTOCOL)

else:

    with open("./data/normal_simulation_data.pkl", "rb") as fp:
        data, colors = pickle.load(fp)

print(data)
print(colors)

[[0. 0. 0. ... 0. 0. 0.]
 [0. 0. 1. ... 0. 0. 0.]
 [0. 0. 1. ... 0. 0. 0.]
 ...
 [0. 0. 0. ... 0. 1. 0.]
 [0. 1. 1. ... 0. 0. 0.]
 [0. 0. 0. ... 0. 0. 0.]]
[<Color.RED: 5>, <Color.ORANGE: 4>, <Color.ORANGE: 4>, <Color.VIOLET: 0>, <Color.RED: 5>, <Color.YELLOW: 3>, <Color.RED: 5>, <Color.RED: 5>, <Color.RED: 5>, <Color.RED: 5>, <Color.GREEN: 2>, <Color.RED: 5>, <Color.VIOLET: 0>, <Color.RED: 5>, <Color.VIOLET: 0>, <Color.VIOLET: 0>, <Color.BLUE: 1>, <Color.VIOLET: 0>, <Color.GREEN: 2>, <Color.GREEN: 2>, <Color.RED: 5>, <Color.RED: 5>, <Color.ORANGE: 4>, <Color.VIOLET: 0>, <Color.VIOLET: 0>, <Color.VIOLET: 0>, <Color.GREEN: 2>, <Color.BLUE: 1>, <Color.GREEN: 2>, <Color.BLUE: 1>, <Color.GREEN: 2>, <Color.BLUE: 1>, <Color.VIOLET: 0>, <Color.RED: 5>, <Color.VIOLET: 0>, <Color.RED: 5>, <Color.BLUE: 1>, <Color.GREEN: 2>, <Color.RED: 5>, <Color.VIOLET: 0>, <Color.BLUE: 1>, <Color.GREEN: 2>, <Color.VIOLET: 0>, <Color.YELLOW: 3>, <Color.RED: 5>, <Color.RED: 5>, <Color.BLUE: 1>, <Color.VIOLET: 0>

## Feedforward Network Architecture

### Model Setup

In [None]:
model = keras.Sequential([
    keras.Input(shape=(train_cone_cells.shape[-1], )),
    keras.layers.Dense(units=128, activation="silu"),
    keras.layers.Dropout(0.2),
    keras.layers.Dense(units=128, activation="silu"),
    keras.layers.Dropout(0.2),
    keras.layers.Dense(units=128, activation="silu"),
    keras.layers.Dropout(0.2),
    keras.layers.Dense(units=128, activation="silu"),
    keras.layers.Dropout(0.2),
    keras.layers.Dense(units=WAVELENGTH_BINS, activation="softmax"),
])

### Compile Model

In [None]:
optimizer = keras.optimizers.AdamW(learning_rate=5e-4, weight_decay=0.004)
model.compile(
    optimizer=optimizer,
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)

model.summary()

### Fit the Model

In [None]:
EPOCHS = 30
h = model.fit(train_cone_cells, train_wavelengths, batch=64, epochs=EPOCHS, validation_split=0.1)

### Visualize Loss Trajectory

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(10, 6))
ax.plot(range(EPOCHS), h.history['val_loss'], 'o-', color='maroon', label='Validation')
ax.plot(range(EPOCHS), h.history['loss'], 'o-', color='black', label='Training')
ax.set_xlabel('Epoch')
ax.set_ylabel('Cross-entropy')
ax.legend()
ax.set_title('Loss Trajectory')
sns.despine(ax=ax)

### Evaluate Model on Test Data

In [None]:
test_loss, test_accuracy = model.evaluate(test_cone_cells, test_wavelengths)

In [None]:
preds = model.predict(test_images)
metrics.ConfusionMatrixDisplay.from_predictions(test_labels, preds.argmax(axis=1))