# Generating 🦕Dinosaur names on microcontrollers using CircuitPython and Markov Chains

This notebook contains code for generating random dinosaur names in [CircuitPython](https://docs.circuitpython.org/en/latest/README.html) using [Markov Chains](https://en.wikipedia.org/wiki/Markov_chain).

For deploying code to microcontrollers, refer to my GitHub repo: https://github.com/code2k13/circuitpython_textgen

## Why Markov Chains?

- C++ is hard, and setting up a development environment can be challenging for newbies.
- CircuitPython has many useful libraries available out of the box.
- Loading TFLite models is currently not supported by CircuitPython. Check out the open-source project https://github.com/mocleiri/tensorflow-micropython-examples for work toward making this possible (it doesn't work on all boards as of now).
- RNNs are more complicated and use more memory.
- Markov Chains are simpler and lightweight!

## What does this notebook contain?

This notebook is divided into two parts. **Part 1** contains code that takes a text file with dinosaur names, builds a Markov Chain model from it, and saves the model as JSON. You can run this code anywhere (on a microcontroller or a normal computer). **Part 2** contains code written for CircuitPython that runs on a microcontroller. It reads the model from the JSON file and generates some random samples.

> Note: The Part 2 code depends only on the **json** and **random** modules. CircuitPython does not implement `random.choices`, so I wrote my own replacement called '**custom_random_choices**', which you will find in Part 2.

## Part 1: Creating the Markov Chain from a text file

Let us download a text file containing dinosaur names (one name per line) from GitHub. We are using the file from this GitHub project: https://github.com/junosuarez/dinosaurs. Feel free to use any other file if you want (say, names of elves, cities, animals or anything else).

In [1]:
!wget https://raw.githubusercontent.com/junosuarez/dinosaurs/master/dinosaurs.csv

--2022-08-19 07:35:15--  https://raw.githubusercontent.com/junosuarez/dinosaurs/master/dinosaurs.csv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.109.133, 185.199.110.133, 185.199.108.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.109.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 19656 (19K) [text/plain]
Saving to: ‘dinosaurs.csv’


2022-08-19 07:35:15 (58.3 MB/s) - ‘dinosaurs.csv’ saved [19656/19656]



Let us open the file and read all the text into a single variable. Then we get the set of unique characters and, for each character, count which characters follow it in the text. These counts form the Markov Chain.

In [2]:
with open("dinosaurs.csv") as f:
    txt = f.read()
unique_letters = set(txt)

In [3]:
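import re
from collections import Counter

chain = {}
for l in unique_letters:
    next_chars = []
    # re.escape() makes characters such as '.' or '(' match literally;
    # without it they would be interpreted as regex syntax
    for m in re.finditer(re.escape(l), txt):
        start_idx = m.start()
        if start_idx + 1 < len(txt):
            next_chars.append(txt[start_idx + 1])
    # Map each following character to the number of times it appears after l
    chain[l] = dict(Counter(next_chars))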
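The same counts can also be built without regular expressions by walking over every pair of adjacent characters once. The sketch below is an alternative, not the method used above; `build_chain` is a hypothetical helper name. One small difference: a character that never has a successor (for example, a last character that appears nowhere else) gets no entry at all instead of an empty dictionary.

```python
from collections import Counter

def build_chain(txt):
    """Hypothetical helper: count, for every character, which characters follow it."""
    # zip(txt, txt[1:]) yields every adjacent (current, following) pair
    pair_counts = Counter(zip(txt, txt[1:]))
    chain = {}
    for (current, following), count in pair_counts.items():
        chain.setdefault(current, {})[following] = count
    return chain

print(build_chain("abab\n"))  # {'a': {'b': 2}, 'b': {'a': 1, '\n': 1}}
```

This makes a single pass over the text instead of one regex scan per unique character.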

Given below is the chain for the character 'a'. It is a dictionary whose keys are the characters that come after 'a' in our file and whose values count how often each one appears right after 'a'. For example, 'a' follows 'a' 13 times and 'c' follows 'a' 103 times. The Markov Chain object as a whole is just a dictionary of dictionaries, one per character.

In [4]:
chain['a']

{'a': 13,
 'c': 103,
 'u': 744,
 'r': 125,
 'b': 21,
 'n': 342,
 'p': 89,
 't': 214,
 'v': 30,
 'd': 35,
 'm': 70,
 's': 160,
 'e': 42,
 'f': 5,
 'g': 41,
 '\n': 143,
 'h': 17,
 'i': 24,
 'k': 20,
 'j': 6,
 'l': 140,
 'x': 11,
 'z': 8,
 'o': 23,
 'q': 4,
 'y': 15,
 'w': 7,
 '_': 1}
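To sample from a row, its counts are turned into transition probabilities by dividing each count by the row total; this is what `generate_text` in Part 2 does before every step. For the 'a' row above the counts sum to 2453, so 'u' is picked with probability 744/2453, roughly 0.30. A minimal sketch, where `row_to_probabilities` is a hypothetical helper name and the row is toy data:

```python
def row_to_probabilities(row):
    """Hypothetical helper: turn a {next_char: count} row into {next_char: probability}."""
    total = sum(row.values())
    return {ch: count / total for ch, count in row.items()}

# Toy row, not from the real data: 'x' seen 3 times, 'y' seen once
print(row_to_probabilities({'x': 3, 'y': 1}))  # {'x': 0.75, 'y': 0.25}
```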

In [5]:
import json
with open("dino_chain.json", "w") as f:
    json.dump(chain, f)
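JSON works well as the model format because the chain only contains string keys and integer counts. Keys such as the newline character are written as the escape sequence `\n`, so the file stays valid JSON and loads back unchanged. A quick round-trip check on a toy chain (named `toy_chain` so it does not overwrite the real `chain`):

```python
import json

# A tiny chain that, like the real one, uses '\n' as a key
toy_chain = {'a': {'b': 2}, '\n': {'a': 1}}
text = json.dumps(toy_chain)
print(text)  # {"a": {"b": 2}, "\n": {"a": 1}}

# Loading the text back gives an identical dictionary
restored = json.loads(text)
print(restored == toy_chain)  # True
```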

## Part 2: CircuitPython code to make inferences

The code in the following blocks needs to be run on a device that supports CircuitPython. It uses the '**custom_random_choices**' function mentioned earlier (a pure-Python replacement for `random.choices`).

In [6]:
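import json
import random

def custom_random_choices(arr, weights):
    """Pick one item from arr; weights are probabilities that must sum to 1."""
    # Build the cumulative distribution, e.g. [0.2, 0.5, 0.3] -> [0.2, 0.7, 1.0]
    cum_weights = []
    weights_sum = 0
    for i in range(0, len(weights)):
        weights_sum = weights_sum + weights[i]
        cum_weights.append(weights_sum)
    assert abs(1 - weights_sum) < 0.0001
    # Draw a uniform random number and locate it in the cumulative distribution
    r_weight = random.uniform(0, 1)

    # Binary search: narrow [start_idx, end_idx] down to at most three candidates
    start_idx = 0
    end_idx = len(arr) - 1
    middle_idx = end_idx // 2

    while end_idx - start_idx > 2:
        if cum_weights[middle_idx] <= r_weight:
            start_idx = middle_idx
        else:
            end_idx = middle_idx
        middle_idx = start_idx + ((end_idx - start_idx) // 2)

    # Pick the first remaining candidate whose cumulative weight reaches r_weight
    if cum_weights[middle_idx] < r_weight:
        return arr[end_idx]
    if cum_weights[start_idx] > r_weight:
        return arr[start_idx]
    return arr[middle_idx]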
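`custom_random_choices` is a form of inverse-CDF sampling: draw a uniform number and find where it lands in the cumulative weights. Since each row has at most a few dozen successors (28 for 'a'), a plain linear scan does the same job with less code. The sketch below is an alternative, not the notebook's implementation; `linear_random_choice` is a hypothetical name and is not part of CircuitPython.

```python
import random

def linear_random_choice(arr, weights):
    """Hypothetical helper: return one item of arr, chosen with probabilities summing to ~1."""
    r = random.uniform(0, 1)
    running = 0.0
    for item, w in zip(arr, weights):
        running += w
        # The first item whose cumulative weight exceeds r is the sample
        if r < running:
            return item
    # Floating-point round-off can leave the total just below r; fall back to the last item
    return arr[-1]

# Rough check: 'x' should come up about 75% of the time
random.seed(1)
draws = [linear_random_choice(['x', 'y'], [0.75, 0.25]) for _ in range(10000)]
print(draws.count('x') / len(draws))
```

The binary search in `custom_random_choices` only pays off for much larger rows than a character-level chain produces.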


Let's load dino_chain.json. This file **MUST be present on the microcontroller device**.

In [7]:
with open('dino_chain.json', 'r') as f:
    data=f.read()
dino_chain = json.loads(data)
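Because the file must already be on the device, a clear message when it is missing can save some debugging at the serial console. A minimal sketch, assuming your CircuitPython build provides `json.load` (reading straight from the file object) and reports a missing file as `OSError`; `load_chain` is a hypothetical helper name:

```python
import json

def load_chain(path="dino_chain.json"):
    """Hypothetical helper: load the chain from path, or return None if the file is missing."""
    try:
        with open(path, "r") as f:
            return json.load(f)
    except OSError:
        # On CPython, FileNotFoundError is a subclass of OSError, so it is caught here too
        print("Could not open", path, "- copy it to the device first")
        return None
```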

In [8]:
def generate_text(n, chain):
    """Walk the chain for n steps from a random starting character."""
    op = []
    all_states = list(chain.keys())
    initial_state = random.choice(all_states)
    op.append(initial_state)
    for i in range(0, n):
        possible_next_states = list(chain[initial_state].keys())
        possible_next_state_weights = list(chain[initial_state].values())
        # Convert raw counts into probabilities that sum to 1
        total = sum(possible_next_state_weights)
        possible_next_state_weights = [w / total for w in possible_next_state_weights]
        # Use this line for normal Python:
        # next_state = random.choices(possible_next_states, k=1, weights=possible_next_state_weights)[0]
        # Use this line for CircuitPython:
        next_state = custom_random_choices(possible_next_states, possible_next_state_weights)
        op.append(next_state)
        initial_state = next_state
    return ''.join(op)
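`generate_text` produces one long string that is split on newlines afterwards, so the first and last pieces are usually fragments of names. An alternative sketch generates one name at a time: start from the newline state (whose successors are first letters, assuming one name per line) and stop at the next newline. `generate_name` is a hypothetical helper, and it uses `random.choices`, so it runs on normal Python only; on CircuitPython you would swap in `custom_random_choices` with normalized weights.

```python
import random

def generate_name(chain, max_len=30):
    """Hypothetical helper: walk the chain from the newline state to the next newline."""
    state = '\n'  # assumes names are newline-separated
    name = []
    while len(name) < max_len:
        row = chain[state]
        # random.choices accepts raw counts as weights, no normalization needed
        next_state = random.choices(list(row.keys()), weights=list(row.values()), k=1)[0]
        if next_state == '\n':
            break  # reached the end of a name
        name.append(next_state)
        state = next_state
    return ''.join(name)

toy_chain = {'\n': {'r': 1}, 'r': {'e': 1}, 'e': {'x': 1}, 'x': {'\n': 1}}
print(generate_name(toy_chain))  # rex
```

The `max_len` cap keeps a chain that loops without ever reaching a newline from running forever.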

Let us generate some random dinosaur names. We filter out names that are too short or too long, keeping only those with 5 to 29 characters.

In [9]:
a = generate_text(100, dino_chain)
names = a.split("\n")
for n in names:
    # Keep only names that are 5 to 29 characters long
    if 4 < len(n) < 30:
        print(n)

olerrasaaterullocaururusaqur
ermintimondrroecr
latos
ururanysataura
craus
ausanglaus
