# Toolchain for Machine Learning on Microcontrollers (MCU)

![](pictures/Visualization-toolchain.png)


See [README.md](README.md)

[Paper on arXiv](https://arxiv.org/abs/2104.10645).




## Notebooks


1. `01_...` Model selection and analysis
2. `02_...` Conversion and optimization
3. `03_...` Preparation of the files required for deployment on the MCU

- ...

## Imports and helpers

Below are the imports and helper functions required by all the notebooks.

In [None]:
print("Importing libraries and helper functions from 00_README.ipynb")

import tensorflow as tf
from tensorflow.keras import layers
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import ipywidgets as widgets
from ipywidgets import interact, interactive, fixed, interact_manual, Layout
from datetime import datetime, date
import time
import serial
import glob # for querying files within directory
import os # for filesize
from sys import platform
import json
from rocketlogger.data import RocketLoggerData
from keras_flops import get_flops
from scipy.stats import linregress


# Report the versions of the core numerical libraries so notebook output
# records which environment produced it.
for _label, _module in (("Tensorflow", tf), ("Numpy", np), ("Pandas", pd)):
    print(f"\t{_label} Version: ", _module.__version__)
del _label, _module  # keep the loop temporaries out of the shared namespace
print("Imported all modules.")

Helper functions and global variable settings:

In [None]:
# Only let TensorFlow's Python logger emit ERROR-level (and above) messages.
tf.get_logger().setLevel("ERROR")

# Seed NumPy's global RNG so the sample index below is the same on every run.
np.random.seed(42)
# Random integer in [0, 10000); presumably an index into a 10k-image test
# set (MNIST / CIFAR-10) -- confirm against the notebooks that use it.
image_no = np.random.randint(10000)


# Shared loss objects. reduction="sum_over_batch_size" divides the summed
# per-sample loss by the batch size (Keras reduction semantics).
loss_fn_crossentropy = tf.keras.losses.CategoricalCrossentropy(
    reduction="sum_over_batch_size"
)
loss_fn_meansquared = tf.keras.losses.MeanSquaredError(
    reduction="sum_over_batch_size"
)

# Dropdown offering every *.h5 file found in keras-model/ (relative to the
# current working directory), listed in sorted order.
model_selection = widgets.Dropdown(
    options=sorted(glob.glob("keras-model/*.h5")),
    description="Select model:",
)

def get_tf_model_string(tf_model_file):
    start = tf_model_file.find('01')
    end = tf_model_file.find('.h5')
    return tf_model_file[start:end]

def get_tfl_model_string(tfl_model_file):
    start = tfl_model_file.find('01')
    end = tfl_model_file.find('.tflite')
    return tfl_model_file[start:end]


# helper function to get model name
def get_model_name(model_string):
    if 'ResNet' in model_string:
        model_name = 'ResNet20_CIFAR-10'
    elif 'LeNet' in model_string:
        model_name = 'LeNet-MNIST'
    else:
        model_name = 'unknown'
        print('unknown model name')
    return model_name