In [1]:
#---------------------------------------------------------------------------------
#                                 _             _      
#                                | |_  ___ _ __(_)__ _ 
#                                | ' \/ -_) '_ \ / _` |
#                                |_||_\___| .__/_\__,_|
#                                         |_|          
#
#---------------------------------------------------------------------------------
#
# Company: HEPIA // HES-SO
# Engineer: Hugo Varenne <hugo.varenne@master.hes-so.ch>
# 
# Project Name: Unleashing the Full Potential of 
#               High-Performance Cherenkov Telescopes
#               with Fully-Digital Solid-State Sensors Camera
#
# File: 5.2_optimize_models.ipynb
# Description: Notebook for optimizing ctlearn models
#
# Last update: 2025-10-02
#
#--------------------------------------------------------------------------------

In [1]:
import sys
import os
import importlib
import glob
import shutil
import hdf5plugin, h5py
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
from ctapipe.io import EventSource
from sklearn import metrics
import importlib
import numpy as np
import time
import tensorflow_model_optimization as tfmot
import json
from ctlearn.tools.predict_model import MonoPredictCTLearnModel
from ctlearn.utils import validate_trait_dict
from tensorflow.keras import Input, Model
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2_as_graph
from tools.train_model import TrainCTLearnModel
from ctlearn.core.model import CTLearnModel
from ctapipe.core.traits import ComponentName
from traitlets.config import Config
import yaml
# Custom tools
tools_path = os.path.join("../tools")
if tools_path not in sys.path:
    sys.path.append(tools_path)

2025-11-26 09:51:42.438397: I tensorflow/core/util/port.cc:111] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-11-26 09:51:42.666132: E tensorflow/compiler/xla/stream_executor/cuda/cuda_dnn.cc:9342] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2025-11-26 09:51:42.666168: E tensorflow/compiler/xla/stream_executor/cuda/cuda_fft.cc:609] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-11-26 09:51:42.667209: E tensorflow/compiler/xla/stream_executor/cuda/cuda_blas.cc:1518] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2025-11-26 09:51:42.747830: I tensorflow/core/platform/cpu_feature_g

2025-11-26 09:51:48,125 | INFO | Logging initialized. All stdout/stderr will go to SLURM log.


In [2]:
# Set paths shortcuts (configurable in a yaml file)
import tools.CTLearnMgrConfig as CTLearnMgrConfig
importlib.reload(CTLearnMgrConfig)

ctlearn_mgr_config = CTLearnMgrConfig.CTLearnMgrConfig()
ctlearn_mgr_config.load_config('../config/ctlearnmgr_config.yml')

In [3]:
# Type of model you wanna create : ["type", "energy", "direction"]
RECO = "energy"

# Name of the model (should match config name) ["ResNet", "SimpleCNN", "LoadedModel"] are the types of models
loading = "ResNet"

# Custom model config path 
config_filename = f"{RECO}_{loading}_optimize.yaml"

In [4]:
# Load config file
config_path = os.path.join(ctlearn_mgr_config.workspace_path, "models", "configs", config_filename)

def recursive_config(d):
    """Recursively convert nested dicts into traitlets Configs."""
    if isinstance(d, dict):
        cfg = Config()
        for k, v in d.items():
            cfg[k] = recursive_config(v)
        return cfg
    return d
    
with open(config_path) as f:
    yaml_config = yaml.safe_load(f)
c = recursive_config(yaml_config)

In [5]:
# check model CTLearn
prepare_config = c.prepare_model
model = CTLearnModel.from_name(
    prepare_config.model_type, 
    input_shape=tuple(prepare_config.input_shape), 
    tasks=prepare_config.tasks, 
    config=prepare_config
    ).model

model.summary()

Model: "CTLearn_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input (InputLayer)          [(None, 96, 96, 2)]       0         
                                                                 
 ThinResNet_block (Function  (None, 1024)              5357600   
 al)                                                             
                                                                 
 fc_energy_1 (Dense)         (None, 512)               524800    
                                                                 
 fc_energy_2 (Dense)         (None, 256)               131328    
                                                                 
 energy (Dense)              (None, 1)                 257       
                                                                 
Total params: 6013985 (22.94 MB)
Trainable params: 6013985 (22.94 MB)
Non-trainable params: 0 (0.00 Byte)
_____________

In [6]:
submodel = model.get_layer("ThinResNet_block")
submodel.summary()

Model: "ThinResNet_block"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input (InputLayer)          [(None, 96, 96, 2)]          0         []                            
                                                                                                  
 ThinResNet_block_conv2_blo  (None, 96, 96, 48)           144       ['input[0][0]']               
 ck1_1_conv (Conv2D)                                                                              
                                                                                                  
 ThinResNet_block_conv2_blo  (None, 96, 96, 48)           20784     ['ThinResNet_block_conv2_block
 ck1_2_conv (Conv2D)                                                1_1_conv[0][0]']              
                                                                                   

In [7]:
# Save pruned model
# model_for_pruning.save(prepare_config.temp_dir)
model.save(prepare_config.temp_dir)

INFO:tensorflow:Assets written to: /home/hugo/TM/ml/models/energy/optimize/temp/assets
2025-11-19 08:44:32,318 | INFO | Assets written to: /home/hugo/TM/ml/models/energy/optimize/temp/assets


### Quantization - (old)


def flatten_functional_model(model):
    """
    Rebuild a Keras Functional model in a flat way,
    preserving skip connections, branching, Add/Concat,
    and handling unhashable KerasTensors.
    """
    if not isinstance(model, tf.keras.Model):
        raise ValueError("Model must be a Keras Model")
    
    # 1) Create new input tensors
    if isinstance(model.input, list):
        new_inputs = [tf.keras.Input(shape=t.shape[1:], name=f"flat_input_{i}") for i, t in enumerate(model.input)]
        tensor_map = {t.ref(): new for t, new in zip(model.input, new_inputs)}
    else:
        new_inputs = tf.keras.Input(shape=model.input.shape[1:], name="flat_input")
        tensor_map = {model.input.ref(): new_inputs}

    # 2) Iterate layers in order
    for layer in model.layers:
        if isinstance(layer, tf.keras.layers.InputLayer):
            continue  # already handled

        # Map inputs
        layer_inputs = layer.input
        if isinstance(layer_inputs, list):
            mapped_inputs = [tensor_map[t.ref()] for t in layer_inputs]
        else:
            mapped_inputs = tensor_map[layer_inputs.ref()]

        # Call layer
        layer_outputs = layer(mapped_inputs)

        # Map outputs
        if isinstance(layer_outputs, list):
            for old, new in zip(layer.output, layer_outputs):
                tensor_map[old.ref()] = new
        else:
            tensor_map[layer.output.ref()] = layer_outputs

    # 3) Map outputs
    if isinstance(model.output, list):
        new_outputs = [tensor_map[t.ref()] for t in model.output]
    else:
        new_outputs = tensor_map[model.output.ref()]

    # 4) Build new model
    return tf.keras.Model(inputs=new_inputs, outputs=new_outputs)


def flatten_functional_model(model):
    """
    Rebuild a Keras Functional model in a flat way,
    preserving skip connections, branching, Add/Concat,
    and handling unhashable KerasTensors.
    """
    if not isinstance(model, tf.keras.Model):
        raise ValueError("Model must be a Keras Model")
    
    # 1) Create new input tensors
    if isinstance(model.input, list):
        new_inputs = [tf.keras.Input(shape=t.shape[1:], name=f"flat_input_{i}") for i, t in enumerate(model.input)]
        tensor_map = {t.ref(): new for t, new in zip(model.input, new_inputs)}
    else:
        new_inputs = tf.keras.Input(shape=model.input.shape[1:], name="flat_input")
        tensor_map = {model.input.ref(): new_inputs}

    # 2) Iterate layers in order
    for layer in model.layers:
        if isinstance(layer, tf.keras.layers.InputLayer):
            continue

        # Map inputs
        layer_inputs = layer.input
        if isinstance(layer_inputs, list):
            mapped_inputs = [tensor_map[t.ref()] for t in layer_inputs]
        else:
            mapped_inputs = tensor_map[layer_inputs.ref()]

        # Call layer
        layer_outputs = layer(mapped_inputs)

        # Map outputs
        if isinstance(layer_outputs, list):
            for old, new in zip(layer.output, layer_outputs):
                tensor_map[old.ref()] = new
        else:
            tensor_map[layer.output.ref()] = layer_outputs

    # 3) Map outputs
    if isinstance(model.output, list):
        new_outputs = [tensor_map[t.ref()] for t in model.output]
    else:
        new_outputs = tensor_map[model.output.ref()]

    # 4) Build new model
    return tf.keras.Model(inputs=new_inputs, outputs=new_outputs)
    
def iterate(layer):
    """
    Flatten a layer if it's a nested Functional model.
    Otherwise, return the layer as-is.
    """
    if isinstance(layer, tf.keras.Model) and not isinstance(layer, tf.keras.Sequential):
        return flatten_functional_model(layer)
    return layer

from tensorflow_model_optimization.quantization.keras import QuantizeWrapper

flatten_model = auto_flatten_nested_models(model.get_layer("ThinResNet_block"))

x = flatten_model(model.input)
x = model.get_layer("fc_type_1")(x)
x = model.get_layer("fc_type_2")(x)
x = model.get_layer("type")(x)
outputs = model.get_layer("softmax")(x)

new_model = Model(inputs=model.input, outputs=outputs)
import tensorflow_model_optimization as tfmot
quantized_flat_model = tfmot.quantization.keras.quantize_model(new_model)

quantized_flat_model.summary()

import tensorflow_model_optimization as tfmot
quantize_model = tfmot.quantization.keras.quantize_model
qmodel = quantize_model(flat_model)
qmodel.compile(optimizer="adam", loss="mse", metrics=["accuracy"])
qmodel.summary()

### Quantization - (with AutoQKeras)

In [7]:

from qkeras.autoqkeras import *
from qkeras import *
from qkeras.utils import model_quantize
from qkeras.qtools import run_qtools
from qkeras.qtools import settings as qtools_settings

cur_strategy = tf.distribute.get_strategy()
custom_objects = {}
quantization_config = {
        "kernel": {
                "binary": 1,
                "stochastic_binary": 1,
                "ternary": 2,
                "stochastic_ternary": 2,
                "quantized_bits(2,1,1,alpha=1.0)": 2,
                "quantized_bits(4,0,1,alpha=1.0)": 4,
                "quantized_bits(8,0,1,alpha=1.0)": 8,
                "quantized_po2(4,1)": 4
        },
        "bias": {
                "quantized_bits(4,0,1)": 4,
                "quantized_bits(8,3,1)": 8,
                "quantized_po2(4,8)": 4
        },
        "activation": {
                "binary": 1,
                "ternary": 2,
                "quantized_relu_po2(4,4)": 4,
                "quantized_relu(3,1)": 3,
                "quantized_relu(4,2)": 4,
                "quantized_relu(8,2)": 8,
                "quantized_relu(8,4)": 8,
                "quantized_relu(16,8)": 16
        },
        "linear": {
                "binary": 1,
                "ternary": 2,
                "quantized_bits(4,1)": 4,
                "quantized_bits(8,2)": 8,
                "quantized_bits(16,10)": 16
        }
}


limit = {
    "Dense": [8, 8, 4],
    "Conv2D": [4, 8, 4],
    "DepthwiseConv2D": [4, 8, 4],
    "Activation": [4],
    "BatchNormalization": []
}

goal = {
    "type": "energy",
    "params": {
        "delta_p": 8.0,
        "delta_n": 8.0,
        "rate": 2.0,
        "stress": 1.0,
        "process": "horowitz",
        "parameters_on_memory": ["sram", "sram"],
        "activations_on_memory": ["sram", "sram"],
        "rd_wr_on_io": [False, False],
        "min_sram_size": [0, 0],
        "source_quantizers": ["int8"],
        "reference_internal": "int8",
        "reference_accumulator": "int32"
        }
}

run_config = {
  "output_dir": "../temp/",
  "goal": goal,
  "quantization_config": quantization_config,
  "learning_rate_optimizer": False,
  "transfer_weights": False,
  "mode": "random",
  "seed": 42,
  "limit": limit,
  "tune_filters": "layer",
  "tune_filters_exceptions": "^dense",
  "distribution_strategy": cur_strategy,
  # first layer is input, layer two layers are softmax and flatten
  "layer_indexes": range(1, len(model.layers) - 1),
  "max_trials": 20
}

print("quantizing layers:", [model.layers[i].name for i in run_config["layer_indexes"]])

quantizing layers: ['ThinResNet_block', 'fc_energy_1', 'fc_energy_2']


In [21]:

from dl1_data_handler.reader import DLDataReader
from ctlearn.core.loader import DLDataLoader
from ctlearn.core.model import CTLearnModel
from pathlib import Path


training_config = c.training_model

input_url_signal = []
input_dir_signal = Path(training_config.TrainCTLearnModel.input_dir_signal)
input_url_background = []

file_pattern_signal = training_config.TrainCTLearnModel.file_pattern_signal

for signal_pattern in file_pattern_signal:
    input_url_signal.extend(input_dir_signal.glob(signal_pattern))

dl1dh_reader = DLDataReader.from_name(
    "DLImageReader",
    input_url_signal=sorted(input_url_signal),
    input_url_background=sorted(input_url_background)
)
indices = list(range(dl1dh_reader._get_n_events()))
np.random.shuffle(indices)
n_validation_examples = int(0.2 * dl1dh_reader._get_n_events())
training_indices = indices[n_validation_examples:]
validation_indices = indices[:n_validation_examples]

training_loader = DLDataLoader(
    dl1dh_reader,
    training_indices,
    tasks=[training_config.TrainCTLearnModel.reco_tasks],
    batch_size=training_config.TrainCTLearnModel.batch_size,
    random_seed=0,
    sort_by_intensity=False,
    stack_telescope_images=False,
)
validation_loader = DLDataLoader(
    dl1dh_reader,
    validation_indices,
    tasks=[training_config.TrainCTLearnModel.reco_tasks],
    batch_size=training_config.TrainCTLearnModel.batch_size,
    random_seed=0,
    sort_by_intensity=False,
    stack_telescope_images=False,
)

[PosixPath('/home/hugo/TM/data/samples/gamma/train/gamma_200_800E3GeV_20_20deg_ATM52_100596.corsika.gz.NSBmed4.simtel.h5'), PosixPath('/home/hugo/TM/data/samples/gamma/train/gamma_200_800E3GeV_20_20deg_ATM52_100576.corsika.gz.NSBmed4.simtel.h5'), PosixPath('/home/hugo/TM/data/samples/gamma/train/gamma_200_800E3GeV_20_20deg_ATM52_100591.corsika.gz.NSBmed4.simtel.h5'), PosixPath('/home/hugo/TM/data/samples/gamma/train/gamma_200_800E3GeV_20_20deg_ATM52_100518.corsika.gz.NSBmed4.simtel.h5'), PosixPath('/home/hugo/TM/data/samples/gamma/train/gamma_200_800E3GeV_20_20deg_ATM52_100599.corsika.gz.NSBmed4.simtel.h5'), PosixPath('/home/hugo/TM/data/samples/gamma/train/gamma_200_800E3GeV_20_20deg_ATM52_100555.corsika.gz.NSBmed4.simtel.h5'), PosixPath('/home/hugo/TM/data/samples/gamma/train/gamma_200_800E3GeV_20_20deg_ATM52_110056.corsika.gz.NSBmed4.simtel.h5'), PosixPath('/home/hugo/TM/data/samples/gamma/train/gamma_200_800E3GeV_20_20deg_ATM52_100589.corsika.gz.NSBmed4.simtel.h5'), PosixPath('/hom



























In [32]:
autoqk = AutoQKeras(qmodel, metrics=["acc"], custom_objects=custom_objects, **run_config)
autoqk.fit(training_loader, validation_data=validation_loader, batch_size=1024, epochs=20)

IndexError: list index out of range

### Quantization - (small tests)

In [39]:
quantized_model = model.quantize("int4")

AttributeError: 'Functional' object has no attribute 'quantize'

In [40]:
import tensorflow.keras as keras
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model
from qkeras import *
from qkeras.quantizers import quantized_bits, quantized_relu
from qkeras.utils import model_quantize, _add_supported_quantized_objects


# Prepare custom_objects
custom_objects = {}
_add_supported_quantized_objects(custom_objects)

# Register all QKeras objects for Keras serialization
for name, obj in custom_objects.items():
    keras.saving.register_keras_serializable(package="QKeras", name=name)(obj)

x = x_in = Input((96, 96, 2))
x = Conv2D(18, (3, 3), name="conv2d_1")(x)
x = Activation("relu", name="act_1")(x)
x = Conv2D(32, (3, 3), name="conv2d_2")(x)
x = Activation("relu", name="act_2")(x)
x = Flatten(name="flatten")(x)
x = Dense(1, name="dense")(x)
x = Activation("softmax", name="softmax")(x)

model = Model(inputs=x_in, outputs=x)

print(type(model))
model.summary()

default_config = {
    "QConv2D": {
        "conv2d_1": {
            "kernel_quantizer": "quantized_bits(4,0,1)",
            "bias_quantizer": "quantized_bits(4,0,1)"
        },
        "conv2d_2": {
            "kernel_quantizer": "quantized_bits(4,0,1)",
            "bias_quantizer": "quantized_bits(4,0,1)"
        }
    },
    "QDense": {
        "dense": {
            "kernel_quantizer": "quantized_bits(4,0,1)",
            "bias_quantizer": "quantized_bits(4)"
        }
    },
    "default": {
        "kernel_quantizer": "quantized_bits(4,0,1)",
        "bias_quantizer": "quantized_bits(4,0,1)"
    }
}

print("TensorFlow version:", tf.__version__)

qmodel = model_quantize(
    model,
    default_config,
    activation_bits=4,
    transfer_weights=True,
    custom_objects=custom_objects
)
qmodel.summary()

<class 'keras.src.engine.functional.Functional'>
Model: "model_17"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_18 (InputLayer)       [(None, 96, 96, 2)]       0         
                                                                 
 conv2d_1 (Conv2D)           (None, 94, 94, 18)        342       
                                                                 
 act_1 (Activation)          (None, 94, 94, 18)        0         
                                                                 
 conv2d_2 (Conv2D)           (None, 92, 92, 32)        5216      
                                                                 
 act_2 (Activation)          (None, 92, 92, 32)        0         
                                                                 
 flatten (Flatten)           (None, 270848)            0         
                                                                 
 dense (D

In [41]:
for layer in qmodel.layers:
    print(layer.name, layer.__class__.__name__)
    if hasattr(layer, "kernel_quantizer"):
        print("  kernel quantizer:", layer.kernel_quantizer)
    if hasattr(layer, "bias_quantizer"):
        print("  bias quantizer:", layer.bias_quantizer)
    if hasattr(layer, "activation"):
        print("  activation:", layer.activation)

input_18 InputLayer
conv2d_1 Conv2D
  activation: <function linear at 0x7645fca4a440>
act_1 Activation
  activation: <function relu at 0x7645fca49a20>
conv2d_2 Conv2D
  activation: <function linear at 0x7645fca4a440>
act_2 Activation
  activation: <function relu at 0x7645fca49a20>
flatten Flatten
dense Dense
  activation: <function linear at 0x7645fca4a440>
softmax Activation
  activation: <function softmax at 0x7645fca49000>


In [43]:
for layer in qmodel.layers:
  try:
    if layer.get_quantizers():
      q_w_pairs = zip(layer.get_quantizers(), layer.get_weights())
      for _, (quantizer, weight) in enumerate(q_w_pairs):
        qweight = K.eval(quantizer(weight))
        print("quantized weight")
        print(qweight)
  except AttributeError:
    print("warning, the weight is not quantized in the layer %s", layer.name)



In [42]:
inp_shape = qmodel.input_shape  # or model.inputs[0].shape
dummy = np.zeros([1] + list(inp_shape[1:]), dtype=np.float32)
_ = qmodel.predict(dummy, verbose=0)

for i, L in enumerate(qmodel.layers):
    has = hasattr(L, "get_quantizers")
    q = None
    if has:
        try:
            q = L.get_quantizers()
        except Exception as e:
            q = f"get_quantizers() raised {e!r}"
    print(i, L.name, type(L).__name__, "has_get_quantizers:", has, "->", q)

0 input_18 InputLayer has_get_quantizers: False -> None
1 conv2d_1 Conv2D has_get_quantizers: False -> None
2 act_1 Activation has_get_quantizers: False -> None
3 conv2d_2 Conv2D has_get_quantizers: False -> None
4 act_2 Activation has_get_quantizers: False -> None
5 flatten Flatten has_get_quantizers: False -> None
6 dense Dense has_get_quantizers: False -> None
7 softmax Activation has_get_quantizers: False -> None


### Quantization - last steps

In [15]:
import tensorflow.keras as keras
from tensorflow.keras.layers import *
from tensorflow.keras.models import Model

x = x_in = Input((96, 96, 2))
x = Conv2D(18, (3, 3), name="conv2d_1")(x)
x = Activation("relu", name="act_1")(x)
x = Conv2D(32, (3, 3), name="conv2d_2")(x)
x = Activation("relu", name="act_2")(x)
x = Flatten(name="flatten")(x)
x = Dense(1, name="dense")(x)
x = Activation("softmax", name="softmax")(x)

model = Model(inputs=x_in, outputs=x)

print(type(model))
model.summary()


<class 'keras.src.engine.functional.Functional'>
Model: "model_1"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input_2 (InputLayer)        [(None, 96, 96, 2)]       0         
                                                                 
 conv2d_1 (Conv2D)           (None, 94, 94, 18)        342       
                                                                 
 act_1 (Activation)          (None, 94, 94, 18)        0         
                                                                 
 conv2d_2 (Conv2D)           (None, 92, 92, 32)        5216      
                                                                 
 act_2 (Activation)          (None, 92, 92, 32)        0         
                                                                 
 flatten (Flatten)           (None, 270848)            0         
                                                                 
 dense (De

In [30]:
from tensorflow.keras.models import clone_model


def quantize_model(model):
    editing_model = model
    for i, layer in enumerate(editing_model.layers):
        editing_model.layers[i] = annotate_layer(layer)
    with tfmot.quantization.keras.quantize_scope():
        q_aware_model = tfmot.quantization.keras.quantize_apply(editing_model, quantized_layer_name_prefix='quant_')
    return q_aware_model
    
def annotate_layer(layer):
    # If the layer is a nested model → recurse
    if isinstance(layer, tf.keras.Model): 
        try:
            nested_model = quantize_model(layer)
            nested_model.summary()
            return nested_model
        except Exception as e:
            print(f"Skipping model {layer.name}: {e}")
            return layer

    # Skip layers that cannot be quantized
    if isinstance(layer, tf.keras.layers.InputLayer):
        return layer

    # Wrap quantizable layers
    try:
        return tfmot.quantization.keras.quantize_annotate_layer(
            layer
        )
    except Exception as e:
        print(f"Skipping layer {layer.name}: {e}")
        return layer

import tensorflow as tf
import tensorflow_model_optimization as tfmot
    
annotated_model = quantize_model(model)
annotated_model.summary()

KeyError: <Reference wrapping <KerasTensor: shape=(None, 1024) dtype=float32 (created by layer 'ThinResNet_block')>>

In [81]:
def print_quantization_tree(layer, indent=0):
    prefix = "  " * indent
    print(f"{prefix}- {layer.__class__.__name__}  (name: {layer.name})")

    # If it's a wrapper → it's quantized
    if "QuantizeWrapper" in layer.__class__.__name__:
        print(f"{prefix}    [QUANTIZED]")

    # Recurse for nested models
    if isinstance(layer, tf.keras.Model):
        for sub in layer.layers:
            if sub is not layer:  # avoid self-loop
                print_quantization_tree(sub, indent + 1)
print_quantization_tree(annotated_model)

- Functional  (name: CTLearn_model)
  - InputLayer  (name: input)
  - Functional  (name: ThinResNet_block)
    - InputLayer  (name: input)
    - Conv2D  (name: ThinResNet_block_conv2_block1_1_conv)
    - Conv2D  (name: ThinResNet_block_conv2_block1_2_conv)
    - Conv2D  (name: ThinResNet_block_conv2_block1_0_conv)
    - Conv2D  (name: ThinResNet_block_conv2_block1_3_conv)
    - Add  (name: ThinResNet_block_conv2_block1_add)
    - ReLU  (name: ThinResNet_block_conv2_block1_out)
    - Conv2D  (name: ThinResNet_block_conv2_block2_1_conv)
    - Conv2D  (name: ThinResNet_block_conv2_block2_2_conv)
    - Conv2D  (name: ThinResNet_block_conv2_block2_3_conv)
    - Add  (name: ThinResNet_block_conv2_block2_add)
    - ReLU  (name: ThinResNet_block_conv2_block2_out)
    - Conv2D  (name: ThinResNet_block_conv3_block1_1_conv)
    - Conv2D  (name: ThinResNet_block_conv3_block1_2_conv)
    - Conv2D  (name: ThinResNet_block_conv3_block1_0_conv)
    - Conv2D  (name: ThinResNet_block_conv3_block1_3_conv

In [73]:
with tfmot.quantization.keras.quantize_scope():
    q_aware_model = tfmot.quantization.keras.quantize_apply(annotated_model)
q_aware_model.summary()

Model: "CTLearn_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input (InputLayer)          [(None, 96, 96, 2)]       0         
                                                                 
 ThinResNet_block (Function  (None, 1024)              5357600   
 al)                                                             
                                                                 
 quant_fc_energy_1 (Quantiz  (None, 512)               524805    
 eWrapperV2)                                                     
                                                                 
 quant_fc_energy_2 (Quantiz  (None, 256)               131333    
 eWrapperV2)                                                     
                                                                 
 quant_energy (QuantizeWrap  (None, 1)                 262       
 perV2)                                              

### Pruning

In [34]:
# Load model as Custom one...
config_training = c.training_model
if os.path.exists(config_training.TrainCTLearnModel.output_dir):
    shutil.rmtree(config_training.TrainCTLearnModel.output_dir)
    
model = TrainCTLearnModel(config=config_training)

start = time.time()
try:
    model.run()
except SystemExit as e:
    print(f"Caught SystemExit ({e.code}, continuing...)")
end = time.time()
training_time = (end - start) * 1000 # ms
training_events = model.dl1dh_reader._get_n_events()

--- Logging error ---
Traceback (most recent call last):
  File "/home/hugo/miniforge3/envs/ctlearn/lib/python3.10/logging/__init__.py", line 440, in format
    return self._format(record)
  File "/home/hugo/miniforge3/envs/ctlearn/lib/python3.10/logging/__init__.py", line 436, in _format
    return self._fmt % values
KeyError: 'highlevel'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/home/hugo/miniforge3/envs/ctlearn/lib/python3.10/logging/__init__.py", line 1100, in emit
    msg = self.format(record)
  File "/home/hugo/miniforge3/envs/ctlearn/lib/python3.10/logging/__init__.py", line 943, in format
    return fmt.format(record)
  File "/home/hugo/miniforge3/envs/ctlearn/lib/python3.10/site-packages/ctapipe/core/logging.py", line 52, in format
    s = super().format(record)
  File "/home/hugo/miniforge3/envs/ctlearn/lib/python3.10/logging/__init__.py", line 681, in format
    s = self.formatMessage(record)
  File "/ho

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)


INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)
2025-11-12 19:26:10,002 [1;32mINFO[0m [tools.ctlearn-train-model] (train_model.setup): Number of devices: 1
2025-11-12 19:26:10,004 [1;32mINFO[0m [tools.ctlearn-train-model] (train_model.setup): Loading data:
2025-11-12 19:26:10,005 [1;32mINFO[0m [tools.ctlearn-train-model] (train_model.setup): For a large dataset, this may take a while...
2025-11-12 19:26:11,756 [1;32mINFO[0m [tools.ctlearn-train-model] (train_model.setup): Number of events loaded: 3236
2025-11-12 19:26:11,757 [1;34mDEBUG[0m [tools.ctlearn-train-model] (tool.run): CONFIG: {'TrainCTLearnModel': {'batch_size': 128, 'config_files': [], 'dl1dh_reader_type': 'DLImageReader', 'early_stopping': None, 'file_pattern_background': ['*.h5'], 'file_pattern_signal': ['gamma_*.h5'], 'input_dir_background': None, 'input_dir_signal': PosixPath('/home/hugo/TM/data/samples/gamma/train'), 'log_config': {}, 'log_datefmt': '%Y-%m



2025-11-12 19:26:12,838 [1;32mINFO[0m [tools.ctlearn-train-model] (train_model.start): Pruning CTLearn model.
2025-11-12 19:26:12,839 [1;32mINFO[0m [tools.ctlearn-train-model] (train_model.start): Pruning: steps_per_epoch=22, end_step=22 (n_epochs=1)
2025-11-12 19:26:12,840 [1;32mINFO[0m [tools.ctlearn-train-model] (train_model.start): Parameters for pruning: initial_sparsity=0.5, final_sparsity=0.9, begin_step=0
2025-11-12 19:26:13,571 [1;32mINFO[0m [tools.ctlearn-train-model] (train_model.start): Compiling CTLearn model.
2025-11-12 19:26:13,580 [1;32mINFO[0m [tools.ctlearn-train-model] (train_model.start): Training and evaluating...
2025-11-12 19:26:13.750325: W tensorflow/core/framework/dataset.cc:959] Input of GeneratorDatasetOp::Dataset will not be optimized because the dataset does not implement the AsGraphDefInternal() method needed to apply optimizations.
2025-11-12 19:29:06.036328: W tensorflow/core/framework/dataset.cc:959] Input of GeneratorDatasetOp::Dataset will 


Epoch 1: val_loss improved from inf to 0.62816, saving model to /home/hugo/TM/ml/models/energy/optimize/v2/ctlearn_model.cpk
INFO:tensorflow:Assets written to: /home/hugo/TM/ml/models/energy/optimize/v2/ctlearn_model.cpk/assets


INFO:tensorflow:Assets written to: /home/hugo/TM/ml/models/energy/optimize/v2/ctlearn_model.cpk/assets


22/22 - 191s - loss: 0.6659 - mae_energy: 0.6659 - val_loss: 0.6282 - val_mae_energy: 0.6282 - lr: 1.0000e-04 - 191s/epoch - 9s/step


2025-11-12 19:29:24,843 [1;32mINFO[0m [tools.ctlearn-train-model] (train_model.start): Training and evaluating finished succesfully!
2025-11-12 19:29:24,844 [1;32mINFO[0m [tools.ctlearn-train-model] (train_model.finish): Tool is shutting down
2025-11-12 19:29:24,846 [1;32mINFO[0m [tools.ctlearn-train-model] (tool.write_provenance): Output: /home/hugo/TM/ml/models/type/optimize/v2/predict/gamma_200_800E3GeV_20_20deg_ATM52_100505.h5
2025-11-12 19:29:24,846 [1;32mINFO[0m [tools.ctlearn-train-model] (tool.write_provenance): Output: /home/hugo/TM/ml/models/type/optimize/v2/predict/gamma_200_800E3GeV_20_20deg_ATM52_110055.h5
2025-11-12 19:29:24,847 [1;32mINFO[0m [tools.ctlearn-train-model] (tool.write_provenance): Output: /home/hugo/TM/ml/models/type/optimize/v2/predict/gamma_200_800E3GeV_20_20deg_ATM52_100575.h5
2025-11-12 19:29:24,847 [1;32mINFO[0m [tools.ctlearn-train-model] (tool.write_provenance): Output: /home/hugo/TM/ml/models/type/optimize/v2/predict/proton_400_1300E3GeV_

Caught SystemExit (0, continuing...)


In [35]:
model_pruned_path = os.path.join(config_training.TrainCTLearnModel.output_dir, "ctlearn_model.cpk")
with tfmot.sparsity.keras.prune_scope():
    model_pruned = tf.keras.models.load_model(model_pruned_path)
model_pruned.summary()





Model: "CTLearn_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input (InputLayer)          [(None, 96, 96, 2)]       0         
                                                                 
 ThinResNet_block (Function  (None, 1024)              10703969  
 al)                                                             
                                                                 
 prune_low_magnitude_fc_ene  (None, 512)               1049090   
 rgy_1 (PruneLowMagnitude)                                       
                                                                 
 prune_low_magnitude_fc_ene  (None, 256)               262402    
 rgy_2 (PruneLowMagnitude)                                       
                                                                 
 prune_low_magnitude_energy  (None, 1)                 515       
  (PruneLowMagnitude)                                

In [36]:
model_for_export = tfmot.sparsity.keras.strip_pruning(model_pruned)
model_for_export.summary()

Model: "CTLearn_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 input (InputLayer)          [(None, 96, 96, 2)]       0         
                                                                 
 ThinResNet_block (Function  (None, 1024)              5357600   
 al)                                                             
                                                                 
 fc_energy_1 (Dense)         (None, 512)               524800    
                                                                 
 fc_energy_2 (Dense)         (None, 256)               131328    
                                                                 
 energy (Dense)              (None, 1)                 257       
                                                                 
Total params: 6013985 (22.94 MB)
Trainable params: 6013985 (22.94 MB)
Non-trainable params: 0 (0.00 Byte)
_____________

In [None]:
model_for_export.save(os.path.join(config_training.TrainCTLearnModel.output_dir, "ctlearn_model.cpk"))

In [8]:
# Perform predictions
config_training = c.training_model
# Prepare Prediction model
model = os.path.join(config_training.TrainCTLearnModel.output_dir, "ctlearn_model.cpk")
 

# Predict on every test file
particles = ["gamma", "proton"]
testing_events = 0
inference_time_global = 0
# Create result folder (clean if already existing)
shutil.rmtree(os.path.join(config_training.TrainCTLearnModel.output_dir, "predict"), ignore_errors=True)
os.makedirs(os.path.join(config_training.TrainCTLearnModel.output_dir, "predict"), exist_ok=True)
for particle in particles: 
        directory = os.path.join(ctlearn_mgr_config.training_samples_path, particle, "test")
        for filename in os.listdir(directory):
            if filename.endswith(".h5"):
                # Prepare new filename as output
                predict_file = os.path.basename(filename).split(".", 1)[0]
                input_url = os.path.join(directory, filename)
                # Create results file
                output_url = os.path.join(config_training.TrainCTLearnModel.output_dir, "predict", f"{predict_file}.h5")
                # Launch the prediction
                match config_training.TrainCTLearnModel.reco_tasks:
                    case "type":
                        model_predict = MonoPredictCTLearnModel(input_url=input_url, load_type_model_from=model, output_path=output_url)
                    case "energy":
                        model_predict = MonoPredictCTLearnModel(input_url=input_url, load_energy_model_from=model, output_path=output_url)
                    case "direction":
                        model_predict = MonoPredictCTLearnModel(input_url=input_url, load_direction_model_from=model, output_path=output_url)
                    case _:
                        print("ERROR")
                start = time.time()
                try:
                    model_predict.run()
                except SystemExit as e:
                    print(f"Caught SystemExit ({e.code}, continuing...)")
                stop = time.time()
                testing_events += model_predict.dl1dh_reader._get_n_events()
                inference_time_global += (stop - start) * 1000 # ms



INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)








Caught SystemExit (0, continuing...)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)








Caught SystemExit (0, continuing...)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)








Caught SystemExit (0, continuing...)




INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)








Caught SystemExit (0, continuing...)




INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)








Caught SystemExit (0, continuing...)




INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)




Caught SystemExit (0, continuing...)




In [12]:
inference_per_events_pruned = inference_time_global / testing_events
print(f"Inference per event : {inference_per_events_pruned:.4f} ms")
print(f"Total time to test : {(inference_time_global / 1000):.4f} s for {testing_events} events")

Inference per event : 35.0832 ms
Total time to test : 97.9172 s for 2791 events


In [15]:
# Perform predictions

# Prepare Prediction model
model = os.path.join(c.TrainCTLearnModel.output_dir, "ctlearn_model.cpk")
 

# Predict on every test file
particles = ["gamma", "proton"]
testing_events = 0
inference_time_global = 0
# Create result folder (clean if already existing)
shutil.rmtree(os.path.join(c.TrainCTLearnModel.output_dir, "predict"), ignore_errors=True)
os.makedirs(os.path.join(c.TrainCTLearnModel.output_dir, "predict"), exist_ok=True)
for particle in particles: 
        directory = os.path.join(ctlearn_mgr_config.training_samples_path, particle, "test")
        for filename in os.listdir(directory):
            if filename.endswith(".h5"):
                # Prepare new filename as output
                predict_file = os.path.basename(filename).split(".", 1)[0]
                input_url = os.path.join(directory, filename)
                # Create results file
                output_url = os.path.join(c.TrainCTLearnModel.output_dir, "predict", f"{predict_file}.h5")
                # Launch the prediction
                match c.TrainCTLearnModel.reco_tasks:
                    case "type":
                        model_predict = MonoPredictCTLearnModel(input_url=input_url, load_type_model_from=model, output_path=output_url)
                    case "energy":
                        model_predict = MonoPredictCTLearnModel(input_url=input_url, load_energy_model_from=model, output_path=output_url)
                    case "direction":
                        model_predict = MonoPredictCTLearnModel(input_url=input_url, load_direction_model_from=model, output_path=output_url)
                    case _:
                        print("ERROR")
                start = time.time()
                try:
                    model_predict.run()
                except SystemExit as e:
                    print(f"Caught SystemExit ({e.code}, continuing...)")
                stop = time.time()
                testing_events += model_predict.dl1dh_reader._get_n_events()
                inference_time_global += (stop - start) * 1000 # ms



INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)


INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)






Caught SystemExit (0, continuing...)




INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)


INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)






Caught SystemExit (0, continuing...)




INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)


INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)






Caught SystemExit (0, continuing...)




INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)


INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)






Caught SystemExit (0, continuing...)




INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)


INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)






Caught SystemExit (0, continuing...)




INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)


INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:CPU:0',)






Caught SystemExit (0, continuing...)


In [16]:
inference_per_events_pruned = inference_time_global / testing_events
print(inference_per_events_pruned)
print(inference_time_global)
print(testing_events)

33.611548925035535
93809.83304977417
2791
