# CNN Model Conversion (TensorFlow.js & ONNX)

Convert the trained Keras model `food_classifier_best.keras` to TensorFlow.js and/or ONNX.
Run this notebook whenever you want to update the exported models after retraining.

**Requirements:** `food_classifier_best.keras` must be in `data/models/cnn/` (same location the SmartFood app uses).

You can run the **TensorFlow.js** and **ONNX** conversion cells independently of each other and in any order.

## Paths (optional)

Optional: run this cell to check the resolved paths (the notebook is expected to live in `notebooks/cnn/`). Each conversion cell below sets its own paths, so it can also be run on its own.

In [None]:
import os

# Repository root: two levels above the notebook folder (expected: notebooks/cnn/).
BASE_DIR = os.path.abspath(os.path.join(os.getcwd(), "..", ".."))
# Folder holding the trained Keras model; the conversion cells write their exports here too.
MODEL_DIR = os.path.join(BASE_DIR, "data", "models", "cnn")
model_file = os.path.join(MODEL_DIR, "food_classifier_best.keras")

# Echo both resolved locations so a wrong working directory is easy to spot.
print(f"MODEL_DIR: {MODEL_DIR}\nKeras model: {model_file}")

if not os.path.exists(model_file):
    print("  File not found – train first or copy the model here.")


## Convert to TensorFlow.js

Converts the trained Keras model to TensorFlow.js for deployment in web applications. **Runs independently** – no need to run the ONNX cell or the Paths cell first. Requires: `pip install tensorflowjs`.

In [None]:
# TensorFlow.js conversion – runs independently (sets paths if needed)
#
# Loads data/models/cnn/food_classifier_best.keras and writes the TensorFlow.js
# export into data/models/cnn/tfjs/. Problems are printed, not raised, so the
# rest of the notebook keeps running.
import os
import tensorflow as tf

BASE_DIR = os.path.abspath(os.path.join(os.getcwd(), '..', '..'))
MODEL_DIR = os.path.join(BASE_DIR, 'data', 'models', 'cnn')
TFJS_OUTPUT_DIR = os.path.join(MODEL_DIR, 'tfjs')
model_path = os.path.join(MODEL_DIR, 'food_classifier_best.keras')

print('\n' + '=' * 60)
print('Converting model to TensorFlow.js format...')
print('=' * 60)

# Import the converter on its own: only a missing tensorflowjs package should
# trigger the install hint. ImportErrors raised later (inside TensorFlow or the
# converter) are reported as conversion errors instead of "not installed".
tfjs = None
try:
    import tensorflowjs as tfjs
except ImportError:
    print('⚠ tensorflowjs not installed. Skipping TensorFlow.js conversion.')
    print('  To convert later, install: pip install --break-system-packages --timeout=300 tensorflowjs')
    print('  Then re-run this cell or use: tensorflowjs_converter --input_format=keras model.keras output_dir/')
except Exception as e:
    # The package is present but fails to import for some other reason.
    print(f'⚠ Error converting to TensorFlow.js: {e}')
    print('  The Keras model is still saved and can be converted manually later.')

if tfjs is not None:
    if not os.path.exists(model_path):
        # Check up front instead of creating an empty output folder and then
        # failing with a generic load error.
        print(f'⚠ Keras model not found: {model_path}')
        print('  Train first or copy food_classifier_best.keras into data/models/cnn/.')
    else:
        try:
            os.makedirs(TFJS_OUTPUT_DIR, exist_ok=True)

            print(f'Converting {model_path} to TensorFlow.js format...')
            print(f'Output directory: {TFJS_OUTPUT_DIR}')

            # compile=False: only the architecture and weights are needed for export.
            model_for_tfjs = tf.keras.models.load_model(model_path, compile=False)
            tfjs.converters.save_keras_model(model_for_tfjs, TFJS_OUTPUT_DIR)

            print('✓ Model successfully converted to TensorFlow.js!')
            print(f'  TensorFlow.js model location: {TFJS_OUTPUT_DIR}')
            print('  The model can now be used directly in web browsers')
            print('  No Python servers needed in deployment!')
        except Exception as e:
            print(f'⚠ Error converting to TensorFlow.js: {e}')
            print('  The Keras model is still saved and can be converted manually later.')

print('=' * 60)

## Convert to ONNX

Exports the model to ONNX so the Next.js app can use `onnxruntime-node`, which gives faster inference than TensorFlow.js on the CPU. **Runs independently** – no need to run the TensorFlow.js cell or the Paths cell first. Requires: `pip install tf2onnx onnx`.

In [None]:
# ONNX conversion - runs independently
#
# Patches config.json inside food_classifier_best.keras (stringified shapes,
# dtype dicts, Keras 3 InputLayer arguments) so the installed Keras can load it,
# then exports the model to ONNX with tf2onnx. Problems are printed, not raised.
import ast
import json
import os
import tempfile
import traceback
import warnings
import zipfile

import tensorflow as tf

BASE_DIR = os.path.abspath(os.path.join(os.getcwd(), '..', '..'))
MODEL_DIR = os.path.join(BASE_DIR, 'data', 'models', 'cnn')
model_path = os.path.join(MODEL_DIR, 'food_classifier_best.keras')
ONNX_OUTPUT_PATH = os.path.join(MODEL_DIR, 'food_classifier.onnx')

# Fallback per-sample input shape, assumed (height, width, channels). Used only
# when the loaded model does not report a fully known input shape.
DEFAULT_INPUT_HWC = (224, 224, 3)
# Name of the graph input in the exported ONNX model. Presumably the app feeds
# tensors by this name — verify before changing.
ONNX_INPUT_NAME = 'input_1'
ONNX_OPSET = 14

# Exceptions ast.literal_eval documents for malformed input.
_LITERAL_EVAL_ERRORS = (ValueError, TypeError, SyntaxError, MemoryError, RecursionError)


def _string_shape_to_list(s):
    """Parse a stringified shape into a list, or return None.

    Accepts 'TensorShape([...])' or any string containing a '[...]' span;
    JSON-style 'null'/'Null' are treated as None.
    """
    if not isinstance(s, str):
        return None
    s = s.strip()
    inner = None
    if s.startswith('TensorShape(') and s.endswith(')'):
        inner = s[len('TensorShape('):-1].strip()
    if inner is None and '[' in s and ']' in s:
        start, end = s.find('['), s.rfind(']')
        if end > start:
            inner = s[start:end + 1]
    if inner is None:
        return None
    inner = inner.replace('null', 'None').replace('Null', 'None')
    try:
        out = ast.literal_eval(inner)
    except _LITERAL_EVAL_ERRORS:
        return None
    return out if isinstance(out, list) else None


def _string_to_list_any(s):
    """Parse a string that is a whole list '[...]' or tuple '(...)' literal into a list, or return None."""
    if not isinstance(s, str):
        return None
    s = s.strip()
    if not ((s.startswith('[') and s.endswith(']')) or (s.startswith('(') and s.endswith(')'))):
        return None
    try:
        out = ast.literal_eval(s)
    except _LITERAL_EVAL_ERRORS:
        return None
    return list(out) if isinstance(out, (list, tuple)) else None


def _parse_list_string(s):
    """Try _string_shape_to_list, then _string_to_list_any; None if neither applies."""
    parsed = _string_shape_to_list(s)
    return parsed if parsed is not None else _string_to_list_any(s)


def _dtype_dict_to_str(dtype_val):
    """Return the dtype name from {'name': ...} or {'config': {'name': ...}}, else None."""
    if not isinstance(dtype_val, dict):
        return None
    inner = dtype_val.get('config') or dtype_val
    if isinstance(inner, dict) and 'name' in inner:
        return inner['name']
    return None


def _fix_config_deep(obj):
    """Recursively patch a deserialized Keras config in place.

    - InputLayer configs: rename 'batch_shape' to 'batch_input_shape' and drop 'optional'.
    - Strings that parse as shapes/lists/tuples are replaced with the parsed list.
    - 'dtype' values that are dicts are collapsed to their name string when one is found.
    """
    if isinstance(obj, dict):
        if obj.get('class_name') == 'InputLayer' and 'config' in obj:
            c = obj['config']
            if 'batch_shape' in c:
                c['batch_input_shape'] = c.pop('batch_shape')
            c.pop('optional', None)
        for k, v in list(obj.items()):
            if isinstance(v, str):
                parsed = _parse_list_string(v)
                if parsed is not None:
                    obj[k] = parsed
            elif k == 'dtype' and isinstance(v, dict):
                dtype_str = _dtype_dict_to_str(v)
                if dtype_str:
                    obj[k] = dtype_str
            else:
                _fix_config_deep(v)
    elif isinstance(obj, list):
        # This branch also covers 'inbound_nodes', so no separate pass is needed.
        for i, item in enumerate(obj):
            if isinstance(item, str):
                parsed = _parse_list_string(item)
                if parsed is not None:
                    obj[i] = parsed
            else:
                _fix_config_deep(item)


def _write_fixed_keras(src_path, work_dir):
    """Extract the .keras zip into work_dir, patch config.json, repack; return the new path."""
    extract_dir = os.path.join(work_dir, 'extracted')
    with zipfile.ZipFile(src_path, 'r') as zf:
        zf.extractall(extract_dir)

    config_path = os.path.join(extract_dir, 'config.json')
    with open(config_path, 'r', encoding='utf-8') as f:
        config = json.load(f)
    _fix_config_deep(config)
    with open(config_path, 'w', encoding='utf-8') as f:
        json.dump(config, f)

    # The new archive lives outside extract_dir, so the walk below never picks it up.
    fixed_path = os.path.join(work_dir, 'fixed_model.keras')
    with zipfile.ZipFile(fixed_path, 'w', zipfile.ZIP_DEFLATED) as zf:
        for root, _dirs, files in os.walk(extract_dir):
            for name in files:
                file_path = os.path.join(root, name)
                zf.write(file_path, os.path.relpath(file_path, extract_dir))
    return fixed_path


def _load_fixed_model(path):
    """Load the patched .keras file with an InputLayer that tolerates Keras 3 kwargs."""
    from tensorflow.keras.layers import InputLayer as KerasInputLayer

    class InputLayerCompat(KerasInputLayer):
        """InputLayer that maps 'batch_shape' to 'batch_input_shape' and ignores 'optional'."""

        def __init__(self, batch_shape=None, optional=False, **kwargs):
            if batch_shape is not None and 'batch_input_shape' not in kwargs:
                kwargs['batch_input_shape'] = batch_shape
            super().__init__(**kwargs)

    with warnings.catch_warnings():
        warnings.simplefilter('ignore')
        return tf.keras.models.load_model(
            path,
            custom_objects={'InputLayer': InputLayerCompat},
            compile=False,
            # Disables Keras' safe-mode deserialization checks; acceptable only
            # because this is the project's own trusted model file.
            safe_mode=False,
        )


def _input_signature(model):
    """Build the tf2onnx input signature with a dynamic batch dimension.

    The per-sample shape comes from model.input_shape (assumes a single-input
    model). DEFAULT_INPUT_HWC is used if any of its dimensions is unknown.
    """
    shape = model.input_shape
    if shape and isinstance(shape[0], (list, tuple)):
        shape = shape[0]  # list of per-input shapes: take the first input
    sample_shape = tuple(shape[1:]) if shape else ()
    if not sample_shape or any(d is None for d in sample_shape):
        sample_shape = DEFAULT_INPUT_HWC
    return [tf.TensorSpec((None,) + sample_shape, tf.float32, name=ONNX_INPUT_NAME)]


def _convert_keras_to_onnx(src_path, dst_path):
    """Run the four steps (fix config, load, convert, save) and return the loaded Keras model."""
    print("\n[1/4] Extracting and fixing model config...")
    with tempfile.TemporaryDirectory() as tmp_dir:
        fixed_keras_path = _write_fixed_keras(src_path, tmp_dir)
        print("  OK: Config fixed and repacked")

        print("\n[2/4] Loading fixed Keras model...")
        model = _load_fixed_model(fixed_keras_path)
        print("  OK: Model loaded!")
        print(f"  Input: {model.input_shape}")
        print(f"  Output: {model.output_shape}")

        print("\n[3/4] Converting to ONNX...")
        onnx_model, _ = tf2onnx.convert.from_keras(
            model, input_signature=_input_signature(model), opset=ONNX_OPSET
        )

        print("\n[4/4] Saving ONNX model...")
        onnx.save(onnx_model, dst_path)
    return model


print('\n' + '=' * 60)
print('Converting model to ONNX format...')
print('=' * 60)
print(f'Input: {model_path}')
print(f'Output: {ONNX_OUTPUT_PATH}')

# Import the converters up front and in their own try, for two reasons:
# a missing package is reported before the slow Keras load, and ImportErrors
# raised later inside TensorFlow are not misreported as "install tf2onnx".
tf2onnx = None
try:
    import onnx
    import tf2onnx
except ImportError as e:
    print(f'ERROR: {e}')
    print('  Install: pip install tf2onnx onnx')
except Exception as e:
    # The package is installed but fails to import for some other reason.
    traceback.print_exc()
    print(f'\nERROR: {e}')

if tf2onnx is not None:
    if not os.path.exists(model_path):
        print(f'\nERROR: Keras model not found: {model_path}')
        print('  Train first or copy food_classifier_best.keras into data/models/cnn/.')
    else:
        try:
            model_for_onnx = _convert_keras_to_onnx(model_path, ONNX_OUTPUT_PATH)
            print('\nSUCCESS!')
            print(f'  Output: {ONNX_OUTPUT_PATH}')
        except Exception as e:
            traceback.print_exc()
            print(f'\nERROR: {e}')

print('=' * 60)