# Improved ColabNAS with Block Variants
This notebook implements an enhanced version of ColabNAS with:
* Multiple block variants (basic, residual, dense)
* Robust search strategy with epsilon threshold
* Enhanced training callbacks (early stopping, learning rate reduction)
* Better regularization techniques
* More efficient exploration algorithm

## Preliminaries

In [None]:
!pip install -q tensorflow-model-optimization

In [None]:
!git clone https://github.com/AndreaMattiaGaravagno/ColabNAS

In [None]:
# Make the stm32tflm tool from the cloned repository executable and move it to
# /content/, the default path_to_stm32tflm used by ImprovedColabNAS.
# Both lines use the explicit `!` shell escape so they do not depend on
# IPython automagic aliases being enabled.
!chmod +x /content/ColabNAS/stm32tflm
!mv /content/ColabNAS/stm32tflm /content/

## Improved ColabNAS Implementation

In [None]:
from tensorflow_model_optimization.python.core.keras.compat import keras
from pathlib import Path
import tensorflow as tf
import numpy as np
import subprocess
import datetime
import shutil
import glob
import re
import os

class ImprovedColabNAS:
    """Hardware-aware neural architecture search over several block variants.

    For every block variant ('basic', 'residual', 'dense') the search grows the
    number of cells ``c`` for a fixed number of first-layer kernels ``k`` while
    validation accuracy improves, then doubles (or halves) ``k`` and repeats.
    A candidate is trained in full only if its estimated MACC count and the
    Flash / peak RAM reported by ``stm32tflm`` for its uint8-quantized TFLite
    model are within the given upper bounds.
    """

    architecture_name = 'resulting_architecture'

    def __init__(self, max_RAM, max_Flash, max_MACC, path_to_training_set, val_split,
                 cache=False, input_shape=(50, 50, 3), save_path='.',
                 path_to_stm32tflm='/content/stm32tflm'):
        """Store the search budget and load the image dataset.

        Args:
            max_RAM: upper bound on the peak RAM reported by stm32tflm.
            max_Flash: upper bound on the Flash reported by stm32tflm.
            max_MACC: upper bound on the estimated multiply-accumulate count.
            path_to_training_set: directory with one sub-directory per class.
            val_split: fraction of the images held out for validation.
            cache: if True, cache the loaded datasets with ``tf.data`` ``cache()``.
            input_shape: (height, width, channels) of the network input.
            save_path: directory where trained models and the result are written.
            path_to_stm32tflm: path to the stm32tflm executable.
        """
        self.learning_rate = 1e-3
        self.batch_size = 128
        self.epochs = 100

        self.max_MACC = max_MACC
        self.max_Flash = max_Flash
        self.max_RAM = max_RAM
        self.path_to_training_set = path_to_training_set
        # Every sub-directory of the training set is one class.
        self.num_classes = len(next(os.walk(path_to_training_set))[1])
        self.val_split = val_split
        self.cache = cache
        self.input_shape = input_shape
        self.save_path = Path(save_path)
        self.current_block_type = 'basic'

        self.path_to_trained_models = self.save_path / "trained_models"
        self.path_to_trained_models.mkdir(parents=True, exist_ok=True)

        self.path_to_stm32tflm = Path(path_to_stm32tflm)

        self.load_training_set()

    @staticmethod
    def _cell_macs(block_type, c_in, n, height, width, kernel_area=9):
        """Return the multiply-accumulates of the convolutions in one cell.

        ``c_in`` is the number of channels entering the cell (after pooling),
        ``n`` the number of output channels of the cell and ``height`` x
        ``width`` its spatial size, which the 'same'-padded stride-1
        convolutions preserve. BatchNormalization, ReLU, Add and Concatenate
        are not counted.
        """
        spatial = height * width
        if block_type == 'dense':
            growth = n // 2
            # 3x3 conv producing n // 2 new channels, then a 1x1 conv over the
            # concatenation of the cell input and those new channels.
            return c_in * kernel_area * spatial * growth + (c_in + growth) * spatial * n
        macs = c_in * kernel_area * spatial * n
        if block_type == 'residual' and c_in != n:
            macs += c_in * spatial * n  # 1x1 projection of the shortcut
        return macs

    @staticmethod
    def _classifier_macs(features, hidden, num_classes):
        """Return the multiply-accumulates of the classifier's two Dense layers."""
        return features * hidden + hidden * num_classes

    @staticmethod
    def _conv_bn_relu(x, filters, kernel_size):
        """Apply Conv2D ('same' padding) -> BatchNormalization -> ReLU to ``x``."""
        x = keras.layers.Conv2D(filters, kernel_size, padding='same')(x)
        x = keras.layers.BatchNormalization()(x)
        return keras.layers.ReLU()(x)

    def Model(self, k, c):
        """Build and compile a candidate network and estimate its MACC count.

        Args:
            k: number of kernels of the first convolution (converted with int()).
            c: number of cells (2x2 max pooling + block) stacked after it.

        Returns:
            (model, number_of_mac, number_of_cells_limited). The flag is True
            when the feature map shrank to height or width 1 before all ``c``
            cells could be added.
        """
        kernel_size = (3, 3)
        kernel_area = kernel_size[0] * kernel_size[1]
        pool_size = (2, 2)
        pool_strides = (2, 2)

        number_of_cells_limited = False
        number_of_mac = 0

        inputs = keras.Input(shape=self.input_shape)

        # Convolutional base: the first convolution has k kernels.
        n = int(k)
        multiplier = 2
        c_in = self.input_shape[2]
        x = self._conv_bn_relu(inputs, n, kernel_size)
        number_of_mac += c_in * kernel_area * x.shape[1] * x.shape[2] * n

        for i in range(1, c + 1):
            # Another 2x2 pooling would leave no spatial extent.
            if x.shape[1] <= 1 or x.shape[2] <= 1:
                number_of_cells_limited = True
                break

            # Successive cells widen the network by 2x, 1.5x, 1.25x, ...
            n = int(np.ceil(n * multiplier))
            multiplier = multiplier - 2**-i

            x = keras.layers.MaxPooling2D(pool_size=pool_size, strides=pool_strides, padding='valid')(x)
            c_in = x.shape[3]
            number_of_mac += self._cell_macs(self.current_block_type, c_in, n,
                                             x.shape[1], x.shape[2], kernel_area)

            if self.current_block_type == 'residual':
                shortcut = x
                x = self._conv_bn_relu(x, n, kernel_size)
                if c_in != n:
                    # Cells always widen the feature map, so the identity cannot be
                    # added directly: project the shortcut to n channels.
                    shortcut = keras.layers.Conv2D(n, (1, 1), padding='same')(shortcut)
                x = keras.layers.Add()([x, shortcut])
            elif self.current_block_type == 'dense':
                # Concatenate n // 2 new feature maps with the input, then mix them
                # back down to n channels with a 1x1 convolution.
                x1 = self._conv_bn_relu(x, n // 2, kernel_size)
                x = keras.layers.Concatenate()([x, x1])
                x = self._conv_bn_relu(x, n, (1, 1))
            else:
                x = self._conv_bn_relu(x, n, kernel_size)

        # Classifier with dropout for better generalization.
        x = keras.layers.GlobalAveragePooling2D()(x)
        features = x.shape[1]
        x = keras.layers.Dropout(0.2)(x)
        x = keras.layers.Dense(n)(x)
        x = keras.layers.BatchNormalization()(x)
        x = keras.layers.ReLU()(x)
        x = keras.layers.Dropout(0.1)(x)
        x = keras.layers.Dense(self.num_classes)(x)
        x = keras.layers.BatchNormalization()(x)
        outputs = keras.layers.Softmax()(x)
        number_of_mac += self._classifier_macs(features, n, self.num_classes)

        model = keras.Model(inputs=inputs, outputs=outputs)

        # Use the optimizer from the same Keras package as the model. The legacy
        # `decay=` argument is rejected by the Keras >= 2.11 optimizers; the
        # learning rate is instead reduced by ReduceLROnPlateau during training.
        opt = keras.optimizers.Adam(learning_rate=self.learning_rate)
        model.compile(optimizer=opt,
                      loss='categorical_crossentropy',
                      metrics=['accuracy'])

        model.summary()

        return model, number_of_mac, number_of_cells_limited

    def load_training_set(self):
        """Load the training/validation split and set up the input pipelines.

        Sets ``self.train_ds`` (augmented), ``self.validation_ds`` and
        ``self.calibration_ds`` (un-augmented training images, used as the
        representative dataset for quantization).

        Raises:
            ValueError: if the number of input channels is not 1, 3 or 4.
        """
        color_modes = {1: 'grayscale', 3: 'rgb', 4: 'rgba'}
        channels = self.input_shape[2]
        if channels not in color_modes:
            raise ValueError(
                f"Unsupported number of input channels: {channels} (expected 1, 3 or 4)")
        color_mode = color_modes[channels]

        def load_subset(subset):
            # Same seed for both subsets so the split is consistent.
            return tf.keras.utils.image_dataset_from_directory(
                directory=self.path_to_training_set,
                labels='inferred',
                label_mode='categorical',
                color_mode=color_mode,
                batch_size=self.batch_size,
                image_size=self.input_shape[0:2],
                shuffle=True,
                seed=11,
                validation_split=self.val_split,
                subset=subset
            )

        train_ds = load_subset('training')
        validation_ds = load_subset('validation')

        # Enhanced data augmentation
        data_augmentation = tf.keras.Sequential([
            tf.keras.layers.RandomFlip("horizontal"),
            tf.keras.layers.RandomRotation(0.2, fill_mode='constant', interpolation='bilinear'),
            tf.keras.layers.RandomZoom(0.1),
            tf.keras.layers.RandomContrast(0.1)
        ])

        if self.cache:
            # Cache the decoded images *before* augmenting: caching after the map
            # would replay the first epoch's random augmentations forever.
            train_ds = train_ds.cache()
            validation_ds = validation_ds.cache()

        self.calibration_ds = train_ds
        self.train_ds = train_ds.map(lambda x, y: (data_augmentation(x, training=True), y),
                                     num_parallel_calls=tf.data.AUTOTUNE).prefetch(buffer_size=tf.data.AUTOTUNE)
        self.validation_ds = validation_ds.prefetch(buffer_size=tf.data.AUTOTUNE)

    def quantize_model_uint8(self):
        """Convert ``<model_name>.h5`` to a uint8 TFLite model and delete the .h5 file."""
        h5_path = self.path_to_trained_models / f"{self.model_name}.h5"

        def representative_dataset():
            # Calibrate on un-augmented images, one image per sample.
            for data in self.calibration_ds.rebatch(1).take(150):
                yield [tf.dtypes.cast(data[0], tf.float32)]

        model = keras.models.load_model(h5_path)
        converter = tf.lite.TFLiteConverter.from_keras_model(model)
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
        converter.representative_dataset = representative_dataset
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
        converter.inference_input_type = tf.uint8
        converter.inference_output_type = tf.uint8
        tflite_quant_model = converter.convert()

        (self.path_to_trained_models / f"{self.model_name}.tflite").write_bytes(tflite_quant_model)
        h5_path.unlink()

    def evaluate_flash_and_peak_RAM_occupancy(self):
        """Quantize the current model and measure it with stm32tflm.

        Returns:
            (Flash, RAM) as the two integers printed by stm32tflm, in that order.

        Raises:
            RuntimeError: if stm32tflm times out or its output does not contain
                exactly two integers.
        """
        self.quantize_model_uint8()

        tflite_path = self.path_to_trained_models / f"{self.model_name}.tflite"
        try:
            result = subprocess.run([str(self.path_to_stm32tflm), str(tflite_path)],
                                    stdout=subprocess.PIPE, timeout=15, check=False)
        except subprocess.TimeoutExpired as err:
            # subprocess.run has already killed the child process.
            raise RuntimeError("stm32tflm error: timed out after 15 s") from err

        # Decode first: digits inside escape sequences of a bytes repr would be
        # picked up by the regex otherwise.
        numbers = re.findall(r'\d+', result.stdout.decode(errors='replace'))
        if len(numbers) != 2:
            raise RuntimeError(f"stm32tflm error: unexpected output {result.stdout!r}")
        flash, ram = numbers
        return int(flash), int(ram)

    def evaluate_model_process(self, k, c):
        """Build, measure and (if within budget) train the architecture (k, c).

        Returns:
            A dict with 'k', 'c', 'RAM', 'Flash', 'MACC' and 'max_val_acc'.
            Budget violations are reported as strings; 'max_val_acc' is -3 for
            architectures that were not fully trained. For k <= 0 only
            'k' == 'unfeasible', 'c' and 'max_val_acc' are returned.
        """
        if k <= 0:
            return {'k': 'unfeasible', 'c': c, 'max_val_acc': -3}

        self.model_name = f"k_{k}_c_{c}_{self.current_block_type}"
        print(f"\n{self.model_name}\n")
        h5_path = self.path_to_trained_models / f"{self.model_name}.h5"

        model, MACC, number_of_cells_limited = self.Model(k, c)

        # Train one epoch, then quantize and measure Flash/RAM with stm32tflm.
        model.fit(self.train_ds, epochs=1, validation_data=self.validation_ds, validation_freq=1)
        model.save(h5_path)
        Flash, RAM = self.evaluate_flash_and_peak_RAM_occupancy()
        print(f"\nRAM: {RAM},\t Flash: {Flash},\t MACC: {MACC}\n")

        max_val_acc = -3
        within_budget = (MACC <= self.max_MACC and Flash <= self.max_Flash
                         and RAM <= self.max_RAM and not number_of_cells_limited)
        if within_budget:
            callbacks = [
                # The best checkpoint is what gets quantized afterwards.
                keras.callbacks.ModelCheckpoint(
                    str(h5_path), monitor='val_accuracy', verbose=1,
                    save_best_only=True, save_weights_only=False, mode='max'),
                keras.callbacks.EarlyStopping(
                    monitor='val_accuracy', patience=15, restore_best_weights=True),
                keras.callbacks.ReduceLROnPlateau(
                    monitor='val_loss', factor=0.5, patience=8, min_lr=1e-7),
            ]
            hist = model.fit(self.train_ds, epochs=self.epochs - 1, validation_data=self.validation_ds,
                             validation_freq=1, callbacks=callbacks)
            self.quantize_model_uint8()
            max_val_acc = np.around(np.amax(hist.history['val_accuracy']), decimals=3)

        return {'k': k,
                'c': c if not number_of_cells_limited else "Not feasible",
                'RAM': RAM if RAM <= self.max_RAM else "Outside the upper bound",
                'Flash': Flash if Flash <= self.max_Flash else "Outside the upper bound",
                'MACC': MACC if MACC <= self.max_MACC else "Outside the upper bound",
                'max_val_acc': max_val_acc}

    @staticmethod
    def _violates_constraints(architecture):
        """Return True if an evaluated architecture is unfeasible or over budget."""
        return (architecture['k'] == 'unfeasible'
                or architecture['c'] == "Not feasible"
                or any(architecture.get(key) == "Outside the upper bound"
                       for key in ('RAM', 'Flash', 'MACC')))

    def explore_num_cells(self, k):
        """Find the best number of cells for ``int(k)`` first-layer kernels.

        Adds cells while validation accuracy improves by more than epsilon and
        stops at the first unfeasible or over-budget candidate.

        Returns:
            The last architecture before the stop. If no architecture was
            accepted, this is the placeholder with 'max_val_acc' == -1.
        """
        previous_architecture = {'k': -1, 'c': -1, 'max_val_acc': -2}
        current_architecture = {'k': -1, 'c': -1, 'max_val_acc': -1}
        c = -1
        k = int(k)
        epsilon = 0.005

        while current_architecture['max_val_acc'] > previous_architecture['max_val_acc'] + epsilon:
            previous_architecture = current_architecture
            c += 1
            self.model_counter += 1
            current_architecture = self.evaluate_model_process(k, c)
            print(f"\n\n\n{current_architecture}\n\n\n")

            # Early termination if constraints are violated
            if self._violates_constraints(current_architecture):
                break

        return previous_architecture

    def _search_block(self, block_type, k0, epsilon):
        """Run the k-doubling / k-halving search for one block variant.

        Returns:
            The best architecture found for ``block_type``, tagged with 'block'.
        """
        self.current_block_type = block_type

        def explore(k):
            architecture = self.explore_num_cells(k)
            architecture['block'] = block_type
            return architecture

        k = k0
        previous_architecture = explore(k)
        k = 2 * k
        current_architecture = explore(k)

        if current_architecture['max_val_acc'] > previous_architecture['max_val_acc'] + epsilon:
            # Expanding search - keep doubling k while improvement continues
            while current_architecture['max_val_acc'] > previous_architecture['max_val_acc'] + epsilon:
                previous_architecture = current_architecture
                k = 2 * k
                current_architecture = explore(k)
                if current_architecture['k'] == 'unfeasible':
                    break
        else:
            # Contracting search - keep halving k while accuracy does not drop
            k = k0 / 2
            current_architecture = explore(k)
            while current_architecture['max_val_acc'] >= previous_architecture['max_val_acc'] and k >= 1:
                previous_architecture = current_architecture
                k = k / 2
                current_architecture = explore(k)
                if current_architecture['k'] == 'unfeasible' or k < 1:
                    break

        return previous_architecture

    def search(self):
        """Search all block variants and keep the most accurate architecture.

        On success, the winning .tflite model is moved to
        ``save_path/resulting_architecture_k_<k>_c_<c>_<block>.tflite`` and the
        trained_models directory is deleted.

        Returns:
            The path to the resulting .tflite model, or None if no architecture
            within budget was found.
        """
        self.model_counter = 0
        epsilon = 0.005
        k0 = 4
        block_variants = ['basic', 'residual', 'dense']

        start = datetime.datetime.now()

        best_architecture = {'k': -1, 'c': -1, 'block': 'basic', 'max_val_acc': -1}
        for block_type in block_variants:
            print(f"\nExploring block variant: {block_type}")
            candidate = self._search_block(block_type, k0, epsilon)
            if candidate['max_val_acc'] > best_architecture['max_val_acc']:
                best_architecture = candidate

        end = datetime.datetime.now()

        if best_architecture['max_val_acc'] <= 0:
            print(f"\nNo feasible architecture found\n")
            print(f"Elapsed time (search): {end-start}\n")
            return None

        resulting_architecture_name = (f"k_{best_architecture['k']}_c_{best_architecture['c']}"
                                       f"_{best_architecture['block']}.tflite")
        self.path_to_resulting_architecture = self.save_path / f"resulting_architecture_{resulting_architecture_name}"
        # shutil.move also works when save_path is on a different filesystem.
        shutil.move(str(self.path_to_trained_models / resulting_architecture_name),
                    str(self.path_to_resulting_architecture))
        shutil.rmtree(self.path_to_trained_models)
        print(f"\nResulting architecture: {best_architecture}\n")
        print(f"Elapsed time (search): {end-start}\n")

        return self.path_to_resulting_architecture

## Try Improved ColabNAS!

In [None]:
# Download the flowers dataset and point data_dir at the extracted folder
# (the downloaded path with its suffix stripped).
# NOTE(review): newer Keras releases may return the extraction directory from
# get_file(extract=True) rather than the archive path — confirm that data_dir
# contains one sub-folder per class before starting the search.
dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
archive_path = Path(
    tf.keras.utils.get_file('flower_photos.tar', origin=dataset_url, extract=True)
)
data_dir = archive_path.with_suffix('')

In [None]:
import numpy as np
import tensorflow as tf

input_shape = (50,50,3)

# Target: STM32L412KBU3
# 273 CoreMark, 40 kiB RAM, 128 kiB Flash
peak_RAM_upper_bound = 40960
Flash_upper_bound = 131072
MACC_upper_bound = 2730000  # CoreMark * 1e4

path_to_training_set = data_dir
val_split = 0.3
cache = True
save_path = '/content/'

# Show GPU info
!nvidia-smi

# Initialize improved ColabNAS
improved_nas = ImprovedColabNAS(
    peak_RAM_upper_bound, Flash_upper_bound, MACC_upper_bound, 
    path_to_training_set, val_split, cache, input_shape, save_path=save_path
)

# Run enhanced search
path_to_tflite_model = improved_nas.search()

## Key Improvements

### 1. Block Variants
- **Basic Block**: Standard convolution + batch norm + ReLU
- **Residual Block**: Adds skip connections for better gradient flow
- **Dense Block**: Concatenates features for better information flow

### 2. Enhanced Search Strategy
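- Epsilon threshold (0.005): when growing the number of cells or doubling the number of kernels, a larger architecture is accepted only if it raises validation accuracy by more than 0.005, so small run-to-run fluctuations do not keep the search going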
- Early termination when constraints are violated
- Systematic exploration of multiple block types

### 3. Better Training
- Early stopping to prevent overfitting
- Learning rate reduction on plateau
- Enhanced data augmentation (zoom, contrast)
- Dropout layers for regularization

### 4. Improved Robustness
- Better error handling
- More comprehensive constraint checking
- Adaptive learning rate with decay

This implementation follows the robust search strategy from the pseudocode and adds practical improvements for performance and reliability.