In [None]:
# Demo: train the SUNS CNN and optimize the post-processing parameters

In [1]:
# Imports and GPU detection
from typing import Any
import sys
import os
import random
import time
import glob
import numpy as np
import math
import h5py
from scipy.io import savemat, loadmat
import multiprocessing as mp
import warnings

# Hide the FutureWarning about the deprecated `alpha` argument that is issued
# from sklearn's NMF module; every other warning is still shown.
_nmf_alpha_filter = dict(
    message=r".*`alpha` was deprecated in version 1\.0.*",
    category=FutureWarning,
    module=r"sklearn\.decomposition\._nmf",
)
warnings.filterwarnings("ignore", **_nmf_alpha_filter)

# Ensure the path contains the "suns" folder (prefer repo copy)
def prioritize_repo_path(search_path, repo_root, alt_path):
    """Put ``repo_root`` on ``search_path`` and demote ``alt_path``.

    Parameters
    ----------
    search_path : list of str
        Module search path to modify in place (normally ``sys.path``).
    repo_root : str
        Directory that must be importable; inserted at index 0 when absent.
    alt_path : str
        Directory of an older install; moved to the end of ``search_path``
        when present, unless it is the same directory as ``repo_root``.

    Returns
    -------
    list of str
        The same ``search_path`` object, for convenience.
    """
    if repo_root not in search_path:
        search_path.insert(0, repo_root)
    # When both names point at the same directory, demoting alt_path would
    # remove the entry that was just put first and push it to the very end,
    # which is the opposite of "prefer repo copy".
    same_location = os.path.abspath(alt_path) == os.path.abspath(repo_root)
    if alt_path in search_path and not same_location:
        search_path.remove(alt_path)
        search_path.append(alt_path)
    return search_path


REPO_ROOT = os.path.abspath(os.path.join(os.getcwd(), '../..'))
# Deprioritize any older non-repo install to avoid wrong config
# NOTE(review): the recorded output of this cell reports the config module
# under ALT_PATH, so REPO_ROOT and ALT_PATH may be the same directory on the
# machine where it ran — TODO confirm.
ALT_PATH = '/gpfs/data/shohamlab/nicole/code/SUNS_nicole'
prioritize_repo_path(sys.path, REPO_ROOT, ALT_PATH)

# Backend and device selection (must happen before TensorFlow is imported)
os.environ.update({
    'KERAS_BACKEND': 'tensorflow',
    'CUDA_VISIBLE_DEVICES': '0',  # Which GPU to use. '-1' uses only CPU.
})

# Import config and core SUNS modules (evict any previously cached 'suns' modules)
# Every 'suns' / 'suns.*' entry is dropped from sys.modules first, so the
# imports below are resolved again against the current sys.path instead of
# reusing modules that were loaded earlier in this kernel.
for _k in [k for k in list(sys.modules) if k == 'suns' or k.startswith('suns.')]:
    sys.modules.pop(_k, None)
from suns import config
# Note: this message is printed after the import above has already completed.
print("importing config")
# Report which file was actually imported, so a wrong install is easy to spot.
print("config module path:", getattr(config, '__file__', 'unknown'))

from suns.PreProcessing.preprocessing_functions import preprocess_video, find_dataset
import importlib as _importlib
# The module object itself is imported only to report its file location.
_gm = _importlib.import_module('suns.PreProcessing.generate_masks')
print("generate_masks module path:", getattr(_gm, '__file__', 'unknown'))
from suns.PreProcessing.generate_masks import generate_masks
from suns.train_CNN_params import train_CNN, parameter_optimization_cross_validation

# Patch FISSA save_to_matlab to ignore ragged-array save errors
# Only exceptions raised inside Experiment.save_to_matlab are swallowed by this
# patch; errors raised by any other FISSA call still propagate to the caller.
try:
    import fissa
except Exception as _fissa_import_err:
    # FISSA is treated as optional here, but say why the patch was skipped
    # instead of hiding the failure.
    print(f"[WARN] FISSA import failed; save_to_matlab left unpatched: {_fissa_import_err!r}")
else:
    try:
        import functools
        from fissa.core import Experiment as _FissaExperiment

        # Re-running this cell must not wrap the wrapper again, so the wrapper
        # is tagged and the tag is checked before patching.
        if not getattr(_FissaExperiment.save_to_matlab, '_suns_resilient', False):
            _orig_save_to_mat = _FissaExperiment.save_to_matlab

            @functools.wraps(_orig_save_to_mat)
            def _safe_save_to_matlab(self, *args, **kwargs):
                # Forward to the original method; on any error print a warning
                # and return None instead of raising.
                try:
                    return _orig_save_to_mat(self, *args, **kwargs)
                except Exception as _e:
                    print(f"[WARN] Ignoring FISSA save_to_matlab error: {_e}")
                    return None

            _safe_save_to_matlab._suns_resilient = True
            _FissaExperiment.save_to_matlab = _safe_save_to_matlab
        print("FISSA save_to_matlab patched to be resilient.")
    except Exception as _pe:
        print("FISSA patch failed:", _pe)

# TensorFlow GPU setup and sanity check
import tensorflow as tf

# Parse the full major component of the version string (not only its first
# character) so that a multi-digit major version is read correctly.
tf_version = int(tf.__version__.split('.')[0])
if tf_version == 1:
    # TF 1.x: request on-demand GPU memory allocation through the session config.
    tf_config = tf.ConfigProto()
    tf_config.gpu_options.allow_growth = True
    sess = tf.Session(config=tf_config)
else:  # TensorFlow 2.x
    gpus = tf.config.list_physical_devices('GPU')
    for gpu in gpus:
        try:
            tf.config.experimental.set_memory_growth(gpu, True)
        except RuntimeError as growth_err:
            # Memory growth has to be configured before the GPUs are
            # initialized; report the problem and keep going with the
            # current setting instead of aborting the whole cell.
            print(f"[WARN] Could not enable memory growth on {gpu.name}: {growth_err!r}")

    # ---- GPU visibility and quick verification ----
    print(f"TensorFlow version: {tf.__version__}")
    print("Visible GPUs:", gpus)
    if gpus:
        try:
            for g in gpus:
                det = tf.config.experimental.get_device_details(g)
                name = det.get('device_name', 'Unknown GPU')
                cc = det.get('compute_capability', 'n/a')
                print(f"  - {g.name} | {name} | CC={cc}")
        except Exception as details_err:
            # The details are informational only, so a failure is reported
            # but does not stop the setup.
            print("[WARN] Could not query GPU details:", repr(details_err))
        # Small op to confirm /GPU:0 executes
        try:
            with tf.device('/GPU:0'):
                a = tf.random.uniform((1024, 1024))
                b = tf.random.uniform((1024, 1024))
                _ = tf.matmul(a, b)
            print('GPU sanity check: OK (matmul on /GPU:0)')
        except Exception as e:
            print('GPU sanity check failed, training may run on CPU:', repr(e))
    else:
        print('No GPU detected by TensorFlow; training will run on CPU.')

# Optional: enable device placement logging (verbose)
LOG_DEVICE_PLACEMENT = False
if LOG_DEVICE_PLACEMENT and tf_version != 1:
    try:
        tf.debugging.set_log_device_placement(True)
    except Exception as placement_err:
        # Logging is a debugging aid; report the failure and continue.
        print("[WARN] Could not enable device placement logging:", repr(placement_err))


importing config
importing config
config module path: /gpfs/data/shohamlab/nicole/code/SUNS_nicole/suns/config.py
generate_masks module path: /gpfs/data/shohamlab/nicole/code/SUNS_nicole/suns/PreProcessing/generate_masks.py
FISSA save_to_matlab patched to be resilient.
TensorFlow version: 2.12.1
Visible GPUs: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
  - /physical_device:GPU:0 | Tesla V100-SXM2-16GB | CC=(7, 0)
GPU sanity check: OK (matmul on /GPU:0)


In [29]:
# Select experiment set (load repo config explicitly)
import importlib.util
import os
import sys

# Location of the repository copy of suns/config.py. When the SUNS_REPO_CONFIG
# environment variable is set it overrides the hard-coded default, so the
# notebook can be pointed at another checkout without editing this cell.
REPO_CONFIG = os.environ.get(
    'SUNS_REPO_CONFIG',
    '/gpfs/home/bizzin01/nicole/code/SUNS_nicole_git/Shallow-UNet-Neuron-Segmentation_SUNS/suns/config.py',
)
if not os.path.isfile(REPO_CONFIG):
    raise FileNotFoundError(f"SUNS config file not found: {REPO_CONFIG}")

# NOTE(review): the file is executed as a stand-alone module named
# 'suns_repo_config', so `config` below is a different object from the
# `suns.config` module imported earlier. Code inside the suns package that
# imports its own config presumably does not see ACTIVE_EXP_SET set here —
# TODO confirm.
spec = importlib.util.spec_from_file_location('suns_repo_config', REPO_CONFIG)
if spec is None or spec.loader is None:
    raise ImportError(f"Cannot build an import spec for: {REPO_CONFIG}")
config = importlib.util.module_from_spec(spec)
spec.loader.exec_module(config)
print('Loaded config from:', getattr(config, '__file__', 'unknown'))

# Look the tables up once; a table missing from the config counts as empty.
_exp_id_sets = getattr(config, 'EXP_ID_SETS', {})
_datafolder_sets = getattr(config, 'DATAFOLDER_SETS', {})

TARGET_SET = '4video_mouse7'
if (TARGET_SET not in _exp_id_sets) or (TARGET_SET not in _datafolder_sets):
    # Fall back to the legacy key, spelled with a space instead of an underscore.
    LEGACY = '4video mouse7'
    if (LEGACY in _exp_id_sets) and (LEGACY in _datafolder_sets):
        TARGET_SET = LEGACY
    else:
        # Keep going as before, but make the misconfiguration visible here
        # rather than only as a KeyError in a later cell.
        print(f"[WARN] Neither {TARGET_SET!r} nor {LEGACY!r} is present in both "
              "EXP_ID_SETS and DATAFOLDER_SETS.")

config.ACTIVE_EXP_SET = TARGET_SET
print('ACTIVE_EXP_SET set to:', config.ACTIVE_EXP_SET)


importing config
Loaded config from: /gpfs/home/bizzin01/nicole/code/SUNS_nicole_git/Shallow-UNet-Neuron-Segmentation_SUNS/suns/config.py
ACTIVE_EXP_SET set to: 4video_mouse7


In [21]:
# Dataset IDs and directories
# Everything is looked up in the config tables under the active experiment set.
_active_set = config.ACTIVE_EXP_SET
list_Exp_ID = config.EXP_ID_SETS[_active_set]
dir_video = config.DATAFOLDER_SETS[_active_set]
# Path prefix of the ".mat" files storing the GT masks in sparse 2D matrices:
# the 'GT Masks' folder followed by the file-name prefix 'FinalMasks_'.
dir_GTMasks = os.path.join(dir_video, 'GT Masks', 'FinalMasks_')

for _label, _value in (
    ("ACTIVE_EXP_SET", config.ACTIVE_EXP_SET),
    ("list_Exp_ID", list_Exp_ID),
    ("dir_video", dir_video),
    ("dir_GTMasks", dir_GTMasks),
):
    print(f"{_label}:", _value)


ACTIVE_EXP_SET: 4video_mouse7
list_Exp_ID: ['mouse7_773', 'mouse7_774', 'mouse7_775', 'mouse7_776']
dir_video: /gpfs/home/bizzin01/nicole/code/SUNS_nicole/demo/4video mouse7
dir_GTMasks: /gpfs/home/bizzin01/nicole/code/SUNS_nicole/demo/4video mouse7/GT Masks/FinalMasks_


In [22]:
# Video parameters: frame rate and magnification
# Both values come from the config tables, keyed by the active experiment set.
_active_set = config.ACTIVE_EXP_SET
rate_hz = config.RATE_HZ[_active_set]
Mag = config.MAG[_active_set]

for _label, _value in (("rate_hz", rate_hz), ("Mag", Mag)):
    print(f"{_label}:", _value)


rate_hz: 3.56
Mag: 0.399


In [23]:
# Pre-processing parameters
gauss_filt_size = 50 * Mag  # Spatial Gaussian filter size (in pixels)
num_median_approx = 1000  # Number of frames used to calculate median and median-based std

# Temporal filter kernel: a single exponential decay, sampled once per frame
# and normalized so that its coefficients sum to one.
# Decay time in seconds (6f: 0.8, 6s: 1.25)
decay = 1.25
leng_tf = np.ceil(rate_hz * decay) + 1
_decay_kernel = np.exp(-np.arange(leng_tf) / rate_hz / decay)
Poisson_filt = (_decay_kernel / _decay_kernel.sum()).astype('float32')

for _label, _value in (
    ("gauss_filt_size", gauss_filt_size),
    ("num_median_approx", num_median_approx),
    ("Poisson_filt length", Poisson_filt.size),
):
    print(f"{_label}:", _value)


gauss_filt_size: 19.950000000000003
num_median_approx: 1000
Poisson_filt length: 6


In [24]:
# Training parameters
thred_std = 3  # SNR threshold to determine when neurons are active
num_train_per = 1200  # Number of frames per video used for training (75% of 1600 frames)
NO_OF_EPOCHS = 120  # Number of epochs used for training
batch_size_eval = 100  # Batch size in CNN inference
list_thred_ratio = [thred_std]  # A list of SNR threshold values

# Echo the chosen values so that they are recorded in the cell output.
for _label, _value in (
    ("thred_std", thred_std),
    ("num_train_per", num_train_per),
    ("NO_OF_EPOCHS", NO_OF_EPOCHS),
    ("batch_size_eval", batch_size_eval),
    ("list_thred_ratio", list_thred_ratio),
):
    print(f"{_label}:", _value)


thred_std: 3
num_train_per: 1200
NO_OF_EPOCHS: 120
batch_size_eval: 100
list_thred_ratio: [3]


In [25]:
# Processing options
useSF = False  # Spatial filtering in pre-processing
useTF = True  # Temporal filtering in pre-processing
useSNR = True  # Pixel-by-pixel SNR normalization filtering in pre-processing
med_subtract = False  # Subtract spatial median before temporal filtering (only if no spatial filtering)
prealloc = False  # Pre-allocate memory (faster, higher memory usage). Not needed for training.
useWT = False  # Use additional watershed
load_exist = False  # Use temp files already saved in folders
use_validation = True  # Use validation set outside the training set
useMP = False  # Use multiprocessing to speed up
BATCH_SIZE = 20  # Batch size for training

# Cross-validation strategy: "leave_one_out", "train_1_test_rest", or "use_all"
cross_validation = "leave_one_out"

# Parameters of the loss function
Params_loss = dict(DL=1, BCE=20, FL=0, gamma=1, alpha=0.25)

# Echo every option in one dictionary so that the cell output records the run setup.
print("Options:", dict(
    useSF=useSF,
    useTF=useTF,
    useSNR=useSNR,
    med_subtract=med_subtract,
    prealloc=prealloc,
    useWT=useWT,
    load_exist=load_exist,
    use_validation=use_validation,
    useMP=useMP,
    BATCH_SIZE=BATCH_SIZE,
    cross_validation=cross_validation,
    Params_loss=Params_loss,
))


Options: {'useSF': False, 'useTF': True, 'useSNR': True, 'med_subtract': False, 'prealloc': False, 'useWT': False, 'load_exist': False, 'use_validation': True, 'useMP': False, 'BATCH_SIZE': 20, 'cross_validation': 'leave_one_out', 'Params_loss': {'DL': 1, 'BCE': 20, 'FL': 0, 'gamma': 1, 'alpha': 0.25}}


In [None]:
# CNN training and parameter optimization
import os
import h5py
import numpy as np

# ---- Output directories ----
# Every result folder lives below dir_video, inside the output folder that the
# config assigns to the active experiment set.
dir_parent = os.path.join(dir_video, config.OUTPUT_FOLDER[config.ACTIVE_EXP_SET])
dir_network_input = os.path.join(dir_parent, 'network_input')
dir_mask = os.path.join(dir_parent, f'temporal_masks({thred_std})')
weights_path = os.path.join(dir_parent, 'Weights')
training_output_path = os.path.join(dir_parent, 'training output')
dir_output = os.path.join(dir_parent, 'output_masks')
dir_temp = os.path.join(dir_parent, 'temp')

# exist_ok=True keeps the cell re-runnable and removes the gap between
# checking for a folder and creating it.
# NOTE(review): dir_mask is not created here; presumably generate_masks
# creates it — TODO confirm.
for d in [dir_network_input, weights_path, training_output_path, dir_output, dir_temp]:
    os.makedirs(d, exist_ok=True)

# ---- Get and check video dimensions ----
nvideo = len(list_Exp_ID)
if nvideo == 0:
    raise ValueError('list_Exp_ID is empty: at least one training video is required.')

# One row per video holding the three entries of the dataset shape. A 64-bit
# integer type is used because a frame count above 65535 does not fit in uint16.
list_Dimens = np.zeros((nvideo, 3), dtype='int64')
for (eid, Exp_ID) in enumerate(list_Exp_ID):
    h5_video = os.path.join(dir_video, Exp_ID + '.h5')
    # The context manager closes the file even if reading the shape fails.
    with h5py.File(h5_video, 'r') as h5_file:
        dset = find_dataset(h5_file)
        list_Dimens[eid] = h5_file[dset].shape

nframes = np.unique(list_Dimens[:, 0])
Lx = np.unique(list_Dimens[:, 1])
Ly = np.unique(list_Dimens[:, 2])
if len(Lx) * len(Ly) != 1:
    raise ValueError('The lateral dimensions of all the training videos must be the same in this version.')

# Videos may differ in length; the shortest one sets the common frame count.
nframes = nframes.min()
rows = Lx[0]
cols = Ly[0]

# Lateral sizes rounded up to the next multiple of 8.
rowspad = int(np.ceil(rows / 8) * 8)
colspad = int(np.ceil(cols / 8) * 8)
# Frame count reduced by the length of the temporal filter minus one.
num_total = int(nframes - Poisson_filt.size + 1)

# ---- Post-processing hyper-parameters to optimize ----
# Names starting with "list_" hold the candidate values that are searched;
# the remaining names are fixed values.
list_minArea = [*range(5, 25, 5)]
list_avgArea = [30]
list_thresh_pmap = [*range(110, 190, 10)]
thresh_mask = 0.5
thresh_COM0 = 0.8
list_thresh_COM = [*np.arange(2, 5, 0.5)]
list_thresh_IOU = [0.5]
list_cons = [*range(1, 4)]

# Disabled: rescaling of the values above by magnification (Mag).
# # Adjust units according to magnification and frame rate differences
# list_minArea = list(np.round(np.array(list_minArea) * Mag ** 2))
# list_avgArea = list(np.round(np.array(list_avgArea) * Mag ** 2))
# thresh_COM0 = thresh_COM0 * Mag
# list_thresh_COM = list(np.array(list_thresh_COM) * Mag)
# Optionally adjust list_cons for different frame rates (kept as-is, like script)

# ---- Pack parameter dictionaries ----
# Settings handed to preprocess_video.
Params_pre = dict(
    gauss_filt_size=gauss_filt_size,
    num_median_approx=num_median_approx,
    Poisson_filt=Poisson_filt,
)
# Post-processing values handed to parameter_optimization_cross_validation.
Params_set = dict(
    list_minArea=list_minArea,
    list_avgArea=list_avgArea,
    list_thresh_pmap=list_thresh_pmap,
    thresh_COM0=thresh_COM0,
    list_thresh_COM=list_thresh_COM,
    list_thresh_IOU=list_thresh_IOU,
    thresh_mask=thresh_mask,
    list_cons=list_cons,
)
print("Params_set:", Params_set)

# ---- Pre-processing for training ----
# NOTE(review): the recorded output of this cell ends with a ValueError
# ("inhomogeneous shape") that appears right after FISSA reports "Doing signal
# separation"; presumably a NumPy/FISSA version incompatibility inside
# generate_masks — TODO confirm.
for Exp_ID in list_Exp_ID:
    # Pre-process video; the second return value is not used here.
    video_input, _ = preprocess_video(
        dir_video,
        Exp_ID,
        Params_pre,
        dir_network_input,
        useSF=useSF,
        useTF=useTF,
        useSNR=useSNR,
        med_subtract=med_subtract,
        prealloc=prealloc,
    )

    # Determine active neurons in all frames using FISSA (debug and validation inside generate_masks)
    # dir_GTMasks is a path prefix, so the mask file name is built by plain
    # string concatenation rather than os.path.join.
    file_mask = dir_GTMasks + Exp_ID + '.mat'
    generate_masks(video_input, file_mask, list_thred_ratio, dir_parent, Exp_ID, verbose=False)
    # Drop the reference to the pre-processed video before the next iteration.
    del video_input

# ---- CNN training (cross-validation) ----
# "use_all" runs a single round whose index is nvideo; every other strategy
# runs one round per video.
list_CV = [nvideo] if cross_validation == "use_all" else list(range(0, nvideo))

for CV in list_CV:
    # Split the experiment IDs into training and validation lists for this round.
    if cross_validation == "use_all":
        use_validation = False
        list_Exp_ID_train = list_Exp_ID.copy()
    elif cross_validation == "leave_one_out":
        list_Exp_ID_train = list_Exp_ID.copy()
        list_Exp_ID_val = [list_Exp_ID_train.pop(CV)]
    elif cross_validation == "train_1_test_rest":
        list_Exp_ID_val = list_Exp_ID.copy()
        list_Exp_ID_train = [list_Exp_ID_val.pop(CV)]
    else:
        raise RuntimeError('wrong "cross_validation"')

    if not use_validation:
        list_Exp_ID_val = None

    file_CNN = os.path.join(weights_path, f'Model_CV{CV}.h5')
    results = train_CNN(
        dir_network_input, dir_mask, file_CNN,
        list_Exp_ID_train, list_Exp_ID_val,
        BATCH_SIZE, NO_OF_EPOCHS, num_train_per, num_total,
        (rowspad, colspad), Params_loss,
    )

    # Save training and validation loss after each epoch
    _history_keys = ["loss", "dice_loss"]
    if use_validation:
        _history_keys += ["val_loss", "val_dice_loss"]
    _history_file = os.path.join(training_output_path, f"training_output_CV{CV}.h5")
    with h5py.File(_history_file, "w") as f:
        for _key in _history_keys:
            f.create_dataset(_key, data=results.history[_key])

# ---- Parameter optimization across CV ----
# The unpadded lateral size (rows, cols) is passed here, not the padded one.
_lateral_dims = (rows, cols)
_optimization_flags = dict(useWT=useWT, useMP=useMP, load_exist=load_exist)
parameter_optimization_cross_validation(
    cross_validation, list_Exp_ID, Params_set, _lateral_dims,
    dir_network_input, weights_path, dir_GTMasks, dir_temp, dir_output,
    batch_size_eval,
    **_optimization_flags,
)

print("Pipeline complete.")


Params_set: {'list_minArea': [5, 10, 15, 20], 'list_avgArea': [30], 'list_thresh_pmap': [110, 120, 130, 140, 150, 160, 170, 180], 'thresh_COM0': 0.8, 'list_thresh_COM': [2.0, 2.5, 3.0, 3.5, 4.0, 4.5], 'list_thresh_IOU': [0.5], 'thresh_mask': 0.5, 'list_cons': [1, 2, 3]}
Initialization: 0.0003428459167480469 s
data loading: 161.70623874664307 s
temporal filtering: 1.2380571365356445 s
median computation: 3.7171056270599365 s
normalization: 0.0973351001739502 s
total per frame: 3.157958686351776 ms
Network_input saving: 0.1064610481262207 s
[DEBUG] FISSA import OK; version=0.7.2
[DEBUG] Calling generate_masks for Exp_ID=mouse7_773 with masks=/gpfs/home/bizzin01/nicole/code/SUNS_nicole/demo/4video mouse7/GT Masks/FinalMasks_mouse7_773.mat
Reloading previously prepared data...
Reloading previously separated data...
Doing region growing and data extraction....
Reloading previously prepared data...
Doing signal separation....
NMF converged after 649 iterations.NMF converged after 702 iterati

ValueError: setting an array element with a sequence. The requested array has an inhomogeneous shape after 2 dimensions. The detected shape was (5, 1) + inhomogeneous part.