# Training the CNN Model for the 2D Spatial Components (Keras Version)

This notebook demonstrates how to train, with the Keras API, the CNN model that CaImAn uses to evaluate the shape of two-photon (2p) spatial components.

The main function is `caiman.train.train_cnn_model_helper.cnn_model_keras()`, which takes the input shape and the number of classes and builds a CNN model (based on a tutorial on the CIFAR dataset). The function `caiman.train.train_cnn_model.data_generation()` takes the model, the augmentation flag, the training and validation datasets, and the training parameters, and trains the model; because augmentation is currently not working, this notebook calls `model.fit()` directly instead. `save_model_keras()` and `load_model_keras()`, from the same helper module, save and retrieve the model and its weights.

In [1]:
import os
os.environ["KERAS_BACKEND"] = "torch"  # must be set before keras is imported to take effect

import numpy as np
import keras
from keras.layers import Input, Conv2D, Activation, MaxPooling2D, Dropout, Flatten, Dense
from keras.models import save_model, load_model
from sklearn.model_selection import train_test_split
from sklearn.utils import class_weight as cw

import caiman as cm
from caiman.paths import caiman_datadir
from caiman.train.train_cnn_model_helper import cnn_model_keras, save_model_keras, load_model_keras



## Initializing the Parameters for the Model

In [2]:
batch_size = 128
num_classes = 2
epochs = 1000  # can be increased to 5000
test_fraction = 0.25
augmentation = False  # note: augmentation is currently not working
img_rows, img_cols = 50, 50  # input image dimensions
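
`batch_size` sets how many masks Keras processes per gradient update, so one epoch runs ceil(n_train / batch_size) steps. A minimal check, using the 6771 training samples that the split cell reports later in this notebook, reproduces the `53/53` progress counter in the training log (the helper name is ours, not part of CaImAn or Keras):

```python
import math

def steps_per_epoch(n_samples, batch_size):
    """Number of batches per epoch; the last batch may be smaller than batch_size."""
    return math.ceil(n_samples / batch_size)

n_train = 6771  # training samples reported after the train/test split
steps = steps_per_epoch(n_train, batch_size=128)
print(steps)         # 53
print(steps * 1000)  # 53000 gradient updates for epochs = 1000
```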

## Loading the Dataset

In [3]:
with np.load('/mnt/ceph/data/neuro/caiman/data_minions/ground_truth_components_curated_minions.npz') as ld:
    all_masks_gt = ld['all_masks_gt']
    labels_gt = ld['labels_gt_cur']
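
The curated ground-truth file sits on a cluster path that may not be reachable from other machines. To smoke-test the rest of the notebook without it, a minimal sketch can write a stand-in `.npz` with the same two keys the cell reads, `all_masks_gt` and `labels_gt_cur`. The layout is an assumption inferred from the code (50 by 50 masks with a trailing channel axis, integer labels 0/1), the file name and helper are hypothetical, and the random contents are placeholders, not real components.

```python
import os
import tempfile

import numpy as np

def write_synthetic_ground_truth(path, n_components=200, size=50, seed=0):
    """Write a placeholder .npz using the key names this notebook reads.

    Assumed layout: float32 masks of shape (N, size, size, 1) and integer
    labels in {0, 1}. The contents are random and only useful for smoke tests.
    """
    rng = np.random.default_rng(seed)
    masks = rng.random((n_components, size, size, 1), dtype=np.float32)
    labels = rng.integers(0, 2, size=n_components)
    np.savez(path, all_masks_gt=masks, labels_gt_cur=labels)

# Hypothetical stand-in file in a temporary directory
path = os.path.join(tempfile.mkdtemp(), "synthetic_ground_truth.npz")
write_synthetic_ground_truth(path)

with np.load(path) as ld:
    all_masks_gt = ld['all_masks_gt']
    labels_gt = ld['labels_gt_cur']

print(all_masks_gt.shape, labels_gt.shape)  # (200, 50, 50, 1) (200,)
```

A model trained on this file learns nothing meaningful; it only checks that shapes and keys line up.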

## Constructing the Training and Validation Set for the Model 

In [4]:
x_train, x_test, y_train, y_test = train_test_split(
    all_masks_gt, labels_gt, test_size=test_fraction)

# class_weight = cw.compute_class_weight(class_weight='balanced', classes=np.unique(y_train), y=y_train)

if keras.config.image_data_format() == 'channels_first':
    x_train = x_train.reshape(x_train.shape[0], 1, img_rows, img_cols)
    x_test = x_test.reshape(x_test.shape[0], 1, img_rows, img_cols)
    input_shape = (1, img_rows, img_cols)
else:
    x_train = x_train.reshape(x_train.shape[0], img_rows, img_cols, 1)
    x_test = x_test.reshape(x_test.shape[0], img_rows, img_cols, 1)
    input_shape = (img_rows, img_cols, 1)
    
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# convert class vectors to binary class matrices
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

x_train shape: (6771, 50, 50, 1)
6771 train samples
2257 test samples
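
The reshape adds a single channel axis where Keras expects it, and `keras.utils.to_categorical` turns integer labels into one-hot rows. The NumPy sketch below mirrors both steps for the `channels_last` case so the shapes can be checked without Keras; indexing an identity matrix is an equivalent way to one-hot encode integer labels, not the Keras implementation itself. The arrays are placeholders.

```python
import numpy as np

num_classes = 2
img_rows, img_cols = 50, 50

# Placeholder data: 4 masks without a channel axis, labels in {0, 1}
example_x = np.zeros((4, img_rows, img_cols))
example_y = np.array([0, 1, 1, 0])

# channels_last: (N, rows, cols) -> (N, rows, cols, 1)
example_x = example_x.reshape(example_x.shape[0], img_rows, img_cols, 1).astype('float32')

# One-hot encoding: row i of the identity matrix encodes class i
example_y_onehot = np.eye(num_classes, dtype='float32')[example_y]

print(example_x.shape)            # (4, 50, 50, 1)
print(example_y_onehot.tolist())  # [[1.0, 0.0], [0.0, 1.0], [0.0, 1.0], [1.0, 0.0]]
```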


## Build and Evaluate the Model 

In [7]:
model = cnn_model_keras(input_shape, num_classes)

model.compile(loss=keras.losses.categorical_crossentropy,
              optimizer=keras.optimizers.Adam(learning_rate=0.01),
              metrics=['accuracy'])

# Augmentation does not work yet, so data_generation() is not used:
# cnn_model_cifar = data_generation(cnn_model_cifar, augmentation, x_train, x_test, y_train, y_test, batch_size, epochs, class_weight)
model.fit(x_train, y_train,
          batch_size=batch_size,
          epochs=epochs,
          verbose=1,
          validation_data=(x_test, y_test))

score = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', score[0])
print('Test accuracy:', score[1])
# TODO: the loss is nan from the first epoch; this needs to be fixed.

Epoch 1/1000
53/53 ━━━━━━━━━━━━━━━━━━━━ 27s 481ms/step - accuracy: 0.5579 - loss: nan - val_accuracy: 0.5950 - val_loss: nan
Epoch 2/1000
53/53 ━━━━━━━━━━━━━━━━━━━━ 40s 460ms/step - accuracy: 0.5808 - loss: nan - val_accuracy: 0.5950 - val_loss: nan
Epoch 3/1000
53/53 ━━━━━━━━━━━━━━━━━━━━ 41s 460ms/step - accuracy: 0.5753 - loss: nan - val_accuracy: 0.5950 - val_loss: nan
Epoch 4/1000
53/53 ━━━━━━━━━━━━━━━━━━━━ 41s 467ms/step - accuracy: 0.5778 - loss: nan - val_accuracy: 0.5950 - val_loss: nan
Epoch 5/1000
53/53 ━━━━━━━━━━━━━━━━━━━━ 43s 505ms/step - accuracy: 0.5852 - loss: nan - val_accuracy: 0.5950 - val_loss: nan
Epoch 6/1000
53/53 ━━━━━━━━━━━━━━━━━━━━ 39s 468ms/step - accuracy: 0.5814 - loss: nan - val_accuracy: 0.5950 - val_loss: nan
Epoch 7/1000
53/53 ━━━━━━━━━━━━━━

KeyboardInterrupt: 
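
The loss is `nan` from the first epoch and `val_accuracy` stays at 0.5950, so the network does not appear to be learning; the cause is not established here. Two quick checks worth running before retraining are sketched below: whether the inputs contain non-finite values (a single `nan` in `x_train` propagates into the loss), and how imbalanced the labels are. The weights use n_samples / (n_classes * count), the same heuristic as scikit-learn's `class_weight='balanced'` that is commented out above, returned as the `{class: weight}` dict that `model.fit(class_weight=...)` accepts. Lowering the Adam learning rate from 0.01 toward the Keras default of 0.001 is another thing to try. The helper names are hypothetical and the arrays are placeholders for `x_train` and the integer training labels.

```python
import numpy as np

def count_non_finite(x):
    """Return the number of NaN or infinite entries in an array."""
    return int(np.size(x) - np.count_nonzero(np.isfinite(x)))

def balanced_class_weights(labels, num_classes):
    """Weights n_samples / (num_classes * count_c), returned as {class: weight}.

    Assumes every class occurs at least once in `labels`.
    """
    counts = np.bincount(labels, minlength=num_classes)
    weights = len(labels) / (num_classes * counts)
    return {i: float(w) for i, w in enumerate(weights)}

# Placeholder inputs standing in for x_train and the integer training labels
example_x = np.ones((4, 50, 50, 1), dtype='float32')
example_x[2, 10, 10, 0] = np.nan          # one corrupted pixel
example_labels = np.array([0, 0, 0, 1])   # e.g. y_train.argmax(axis=1) after to_categorical

print(count_non_finite(example_x))  # 1

# One option: replace NaNs with 0 before training (infs become large finite values)
example_x_clean = np.nan_to_num(example_x, nan=0.0)
print(count_non_finite(example_x_clean))  # 0

weights = balanced_class_weights(example_labels, num_classes=2)
print(weights)  # {0: 0.6666666666666666, 1: 2.0}
```

With real data, the resulting dict would be passed as `model.fit(..., class_weight=weights)`.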

## Save the Model and its weights

In [11]:
save_model_path = save_model_keras(model, name='cnn_model_test')

Saved trained model at /mnt/home/mpaez/caiman_data/model/cnn_model_test.keras 


## Visualize Results

In [12]:
predictions = model.predict(all_masks_gt, batch_size=32, verbose=1)
cm.movie(np.squeeze(all_masks_gt[np.where(predictions[:, 0] >= 0.5)[0]])).play(
    gain=3., magnification=5, fr=10)

[1m283/283[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 25ms/step


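
For a two-class softmax output each row of `predictions` sums to 1, so `predictions[:, 0] >= 0.5` selects the same components as keeping those whose argmax is class 0 (up to floating-point rounding; `np.argmax` also returns index 0 on an exact 0.5/0.5 tie). The sketch below shows that selection on placeholder arrays. It uses `np.squeeze(..., axis=-1)` rather than a bare `np.squeeze`, because a bare squeeze removes every length-1 axis: if exactly one component is selected, the frame axis passed to `cm.movie` disappears as well.

```python
import numpy as np

# Placeholder softmax outputs for 4 components (each row sums to 1)
example_predictions = np.array([[0.9, 0.1],
                                [0.2, 0.8],
                                [0.5, 0.5],
                                [0.3, 0.7]])
example_masks = np.arange(4 * 3 * 3, dtype=float).reshape(4, 3, 3, 1)

# Indices of components whose class-0 probability is at least 0.5
idx = np.where(example_predictions[:, 0] >= 0.5)[0]
print(idx)  # [0 2]

# Same selection via argmax
print(np.flatnonzero(example_predictions.argmax(axis=1) == 0))  # [0 2]

# Drop only the trailing channel axis: (k, 3, 3, 1) -> (k, 3, 3)
selected = np.squeeze(example_masks[idx], axis=-1)
print(selected.shape)  # (2, 3, 3)

# A bare squeeze on a single selected mask also drops the frame axis
print(np.squeeze(example_masks[[1]]).shape)           # (3, 3)
print(np.squeeze(example_masks[[1]], axis=-1).shape)  # (1, 3, 3)
```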


## Retrieve the Model and its weights

In [14]:
loaded_model = load_model_keras(save_model_path)
loaded_model.summary()

Load trained model at /mnt/home/mpaez/caiman_data/model/cnn_model_test.keras 


## Visualize Results of the Loaded Model

In [15]:
predictions = loaded_model.predict(all_masks_gt, batch_size=32, verbose=1)
cm.movie(np.squeeze(all_masks_gt[np.where(predictions[:, 0] >= 0.5)[0]])).play(
    gain=3., magnification=5, fr=10)

[1m283/283[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m7s[0m 25ms/step
