# Class Map Generation
* Generate a file `classes.json` that contains a map, `CLASSES` of all classes by their logit index for use by our tfjs model loader

## Setup

In [2]:
import matplotlib.pyplot as plt
import numpy as np
import PIL
import datetime
import os
import pandas as pd

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.models import Sequential

from tensorflow.keras.preprocessing.image import ImageDataGenerator

import tensorflow_hub as hub
from keras.utils.layer_utils import count_params


## Enumerate Datasets to test

In [3]:
import pathlib

flowers_dataset_url = "https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz"
flowers_data_dir = tf.keras.utils.get_file('flower_photos', origin=flowers_dataset_url, untar=True)
flowers_data_dir = pathlib.Path(flowers_data_dir)

datasets = [
    ('CUB-200-2011', '/mnt/cub/CUB_200_2011/images'),
    ('flowers', flowers_data_dir),
]

## Dataset

In [8]:
# build a dataset object from a directory of images
def build_dataset(
    dataset,
    image_size,
    preprocess_input = None,
    batch_size = 64,
):
   
    train_ds, val_ds = tf.keras.utils.image_dataset_from_directory(
        dataset[1],
        batch_size = batch_size,
        validation_split = 0.2,
        image_size = image_size,
        subset = "both",
        shuffle = True, # default but here for clarity
        seed=42,
        label_mode="categorical" # enables one-hot encoding (use 'int' for sparse_categorical_crossentropy loss)
    )
    
    # Retrieve class names
    # (can't do this after converting to PrefetchDataset?)
    class_names = train_ds.class_names
    
    # Prefetch images
    train_ds = train_ds.cache().prefetch(buffer_size=tf.data.AUTOTUNE)
    val_ds = val_ds.cache().prefetch(buffer_size=tf.data.AUTOTUNE)
    
    if preprocess_input:
        # apply preprocessing function
        train_ds = train_ds.map(
            lambda x, y: (preprocess_input(x), y),
            num_parallel_calls = 16,
        )
        val_ds = val_ds.map(
            lambda x, y: (preprocess_input(x), y),
            num_parallel_calls = 16,
        )

    return (train_ds, val_ds, class_names)

In [10]:
train_ds, val_ds, class_names = build_dataset(
    datasets[0],
    (299, 299),
)

Found 11788 files belonging to 200 classes.
Using 9431 files for training.
Using 2357 files for validation.


In [11]:
class_names

['001.Black_footed_Albatross',
 '002.Laysan_Albatross',
 '003.Sooty_Albatross',
 '004.Groove_billed_Ani',
 '005.Crested_Auklet',
 '006.Least_Auklet',
 '007.Parakeet_Auklet',
 '008.Rhinoceros_Auklet',
 '009.Brewer_Blackbird',
 '010.Red_winged_Blackbird',
 '011.Rusty_Blackbird',
 '012.Yellow_headed_Blackbird',
 '013.Bobolink',
 '014.Indigo_Bunting',
 '015.Lazuli_Bunting',
 '016.Painted_Bunting',
 '017.Cardinal',
 '018.Spotted_Catbird',
 '019.Gray_Catbird',
 '020.Yellow_breasted_Chat',
 '021.Eastern_Towhee',
 '022.Chuck_will_Widow',
 '023.Brandt_Cormorant',
 '024.Red_faced_Cormorant',
 '025.Pelagic_Cormorant',
 '026.Bronzed_Cowbird',
 '027.Shiny_Cowbird',
 '028.Brown_Creeper',
 '029.American_Crow',
 '030.Fish_Crow',
 '031.Black_billed_Cuckoo',
 '032.Mangrove_Cuckoo',
 '033.Yellow_billed_Cuckoo',
 '034.Gray_crowned_Rosy_Finch',
 '035.Purple_Finch',
 '036.Northern_Flicker',
 '037.Acadian_Flycatcher',
 '038.Great_Crested_Flycatcher',
 '039.Least_Flycatcher',
 '040.Olive_sided_Flycatcher',
 '

In [18]:
def parse_name(name):
    
    split = name.split('.')
    
    return {
        "index": int(split[0]),
        "name": split[1].replace('_', ' '),
    }

In [26]:
df = pd.DataFrame(list(map(parse_name, class_names)))

In [28]:
df.set_index("index")

Unnamed: 0_level_0,name
index,Unnamed: 1_level_1
1,Black footed Albatross
2,Laysan Albatross
3,Sooty Albatross
4,Groove billed Ani
5,Crested Auklet
...,...
196,House Wren
197,Marsh Wren
198,Rock Wren
199,Winter Wren


In [34]:
import json
data = json.loads(df.to_json(orient='index'))

{'0': {'index': 1, 'name': 'Black footed Albatross'},
 '1': {'index': 2, 'name': 'Laysan Albatross'},
 '2': {'index': 3, 'name': 'Sooty Albatross'},
 '3': {'index': 4, 'name': 'Groove billed Ani'},
 '4': {'index': 5, 'name': 'Crested Auklet'},
 '5': {'index': 6, 'name': 'Least Auklet'},
 '6': {'index': 7, 'name': 'Parakeet Auklet'},
 '7': {'index': 8, 'name': 'Rhinoceros Auklet'},
 '8': {'index': 9, 'name': 'Brewer Blackbird'},
 '9': {'index': 10, 'name': 'Red winged Blackbird'},
 '10': {'index': 11, 'name': 'Rusty Blackbird'},
 '11': {'index': 12, 'name': 'Yellow headed Blackbird'},
 '12': {'index': 13, 'name': 'Bobolink'},
 '13': {'index': 14, 'name': 'Indigo Bunting'},
 '14': {'index': 15, 'name': 'Lazuli Bunting'},
 '15': {'index': 16, 'name': 'Painted Bunting'},
 '16': {'index': 17, 'name': 'Cardinal'},
 '17': {'index': 18, 'name': 'Spotted Catbird'},
 '18': {'index': 19, 'name': 'Gray Catbird'},
 '19': {'index': 20, 'name': 'Yellow breasted Chat'},
 '20': {'index': 21, 'name': 'E

In [36]:
l = list(map(parse_name, class_names))

In [46]:
mapping = {}
for item in l:
    mapping[item['index']] = item['name']

{'index': 1, 'name': 'Black footed Albatross'}
{'index': 2, 'name': 'Laysan Albatross'}
{'index': 3, 'name': 'Sooty Albatross'}
{'index': 4, 'name': 'Groove billed Ani'}
{'index': 5, 'name': 'Crested Auklet'}
{'index': 6, 'name': 'Least Auklet'}
{'index': 7, 'name': 'Parakeet Auklet'}
{'index': 8, 'name': 'Rhinoceros Auklet'}
{'index': 9, 'name': 'Brewer Blackbird'}
{'index': 10, 'name': 'Red winged Blackbird'}
{'index': 11, 'name': 'Rusty Blackbird'}
{'index': 12, 'name': 'Yellow headed Blackbird'}
{'index': 13, 'name': 'Bobolink'}
{'index': 14, 'name': 'Indigo Bunting'}
{'index': 15, 'name': 'Lazuli Bunting'}
{'index': 16, 'name': 'Painted Bunting'}
{'index': 17, 'name': 'Cardinal'}
{'index': 18, 'name': 'Spotted Catbird'}
{'index': 19, 'name': 'Gray Catbird'}
{'index': 20, 'name': 'Yellow breasted Chat'}
{'index': 21, 'name': 'Eastern Towhee'}
{'index': 22, 'name': 'Chuck will Widow'}
{'index': 23, 'name': 'Brandt Cormorant'}
{'index': 24, 'name': 'Red faced Cormorant'}
{'index': 25

In [47]:
mapping

{1: 'Black footed Albatross',
 2: 'Laysan Albatross',
 3: 'Sooty Albatross',
 4: 'Groove billed Ani',
 5: 'Crested Auklet',
 6: 'Least Auklet',
 7: 'Parakeet Auklet',
 8: 'Rhinoceros Auklet',
 9: 'Brewer Blackbird',
 10: 'Red winged Blackbird',
 11: 'Rusty Blackbird',
 12: 'Yellow headed Blackbird',
 13: 'Bobolink',
 14: 'Indigo Bunting',
 15: 'Lazuli Bunting',
 16: 'Painted Bunting',
 17: 'Cardinal',
 18: 'Spotted Catbird',
 19: 'Gray Catbird',
 20: 'Yellow breasted Chat',
 21: 'Eastern Towhee',
 22: 'Chuck will Widow',
 23: 'Brandt Cormorant',
 24: 'Red faced Cormorant',
 25: 'Pelagic Cormorant',
 26: 'Bronzed Cowbird',
 27: 'Shiny Cowbird',
 28: 'Brown Creeper',
 29: 'American Crow',
 30: 'Fish Crow',
 31: 'Black billed Cuckoo',
 32: 'Mangrove Cuckoo',
 33: 'Yellow billed Cuckoo',
 34: 'Gray crowned Rosy Finch',
 35: 'Purple Finch',
 36: 'Northern Flicker',
 37: 'Acadian Flycatcher',
 38: 'Great Crested Flycatcher',
 39: 'Least Flycatcher',
 40: 'Olive sided Flycatcher',
 41: 'Sciss

In [48]:
import json
json.dumps(mapping)

'{"1": "Black footed Albatross", "2": "Laysan Albatross", "3": "Sooty Albatross", "4": "Groove billed Ani", "5": "Crested Auklet", "6": "Least Auklet", "7": "Parakeet Auklet", "8": "Rhinoceros Auklet", "9": "Brewer Blackbird", "10": "Red winged Blackbird", "11": "Rusty Blackbird", "12": "Yellow headed Blackbird", "13": "Bobolink", "14": "Indigo Bunting", "15": "Lazuli Bunting", "16": "Painted Bunting", "17": "Cardinal", "18": "Spotted Catbird", "19": "Gray Catbird", "20": "Yellow breasted Chat", "21": "Eastern Towhee", "22": "Chuck will Widow", "23": "Brandt Cormorant", "24": "Red faced Cormorant", "25": "Pelagic Cormorant", "26": "Bronzed Cowbird", "27": "Shiny Cowbird", "28": "Brown Creeper", "29": "American Crow", "30": "Fish Crow", "31": "Black billed Cuckoo", "32": "Mangrove Cuckoo", "33": "Yellow billed Cuckoo", "34": "Gray crowned Rosy Finch", "35": "Purple Finch", "36": "Northern Flicker", "37": "Acadian Flycatcher", "38": "Great Crested Flycatcher", "39": "Least Flycatcher", "

In [50]:
with open('classes.json', 'w') as out:
    json.dump(mapping, out, indent=2)