# INDEX
* [Functions](#Functions)
* [Configuration](#Configuration)
* [Preprocess training data](#Preprocess-training-data)
* [Train](#Train)
* [Test model](#Test-model)
    * [Predict](#Predict)
    * [Explore dataset](#Explore-dataset)

# Functions

In [16]:
import sys
sys.path.append('../') #append a relative path to the top package to the search path

In [17]:
# %load libtrain.py
import datetime
import functools
import json
import pathlib
import shutil
import tempfile
from typing import Dict, List

import dacite
import pandas as pd
import tensorflow as tf
from robotoff.taxonomy import Taxonomy
from tensorflow import keras
from tensorflow.data import Dataset
from tensorflow.keras import callbacks
from tensorflow.python.ops import summary_ops_v2

import settings
from category_classification.data_utils import (
    TFTransformer,
    create_tf_dataset,
    load_dataframe,
)
from category_classification.models import (
    KerasPreprocessing,
    build_model,
    construct_preprocessing,
    to_serving_model,
)

from category_classification.config import Config

from utils.io import (
    copy_category_taxonomy,
    save_category_vocabulary,
    save_config,
    save_json,
)
from utils.metrics import evaluation_report

def create_model(config: Config, preprocess: KerasPreprocessing) -> keras.Model:
    model = build_model(config.model_config, preprocess)
    loss_fn = keras.losses.BinaryCrossentropy(
        label_smoothing=config.train_config.label_smoothing
    )
    optimizer = keras.optimizers.Adam(learning_rate=config.train_config.lr)
    model.compile(
        optimizer=optimizer,
        loss=loss_fn,
        metrics=["binary_accuracy", "Precision", "Recall"],
    )
    return model


class TBCallback(callbacks.TensorBoard):
    """Get around a bug where you cannot use the TensorBoard callback with the StringLookup layers
    - https://github.com/tensorflow/tensorboard/issues/4530#issuecomment-783318292"""

    def _log_weights(self, epoch):
        with self._train_writer.as_default():
            with summary_ops_v2.always_record_summaries():
                for layer in self.model.layers:
                    for weight in layer.weights:
                        if hasattr(weight, "name"):
                            weight_name = weight.name.replace(":", "_")
                            summary_ops_v2.histogram(weight_name, weight, step=epoch)
                            if self.write_images:
                                self._log_weight_as_image(weight, weight_name, epoch)
                self._train_writer.flush()


def train(
    model: keras.Model,
    save_dir: pathlib.Path,
    config: Config,
    category_vocab: List[str],
):
    print("Starting training...")
    temporary_log_dir = pathlib.Path(tempfile.mkdtemp())
    print("Temporary log directory: {}".format(temporary_log_dir))

    tf_transformer = TFTransformer(category_vocab)

    train = create_tf_dataset("train", config.train_config.batch_size, tf_transformer)
    val = create_tf_dataset("val", config.train_config.batch_size, tf_transformer)

    history=model.fit(train,
        epochs= config.train_config.epochs,
        validation_data=val,
        callbacks=[
            callbacks.TerminateOnNaN(),
            callbacks.ModelCheckpoint(
                filepath=str(save_dir / "weights.{epoch:02d}-{val_loss:.4f}"),
                monitor="val_loss",
                save_best_only=True,
                save_format='tf',
            ),
            #TBCallback(log_dir=str(temporary_log_dir), histogram_freq=1),
            callbacks.EarlyStopping(monitor="val_loss", patience=4),
            #callbacks.CSVLogger(str(save_dir / "training.csv")),
            callbacks.History()
        ],
    )
    print("Training ended")
    return history


# Configuration

In [19]:
# load config json
import json
 
# Opening JSON file
with open('../config.json') as json_file:
    json_config = json.load(json_file)
json_config 

config=dacite.from_dict(Config, json_config)
model_config=config.model_config

output_dir:pathlib.Path = pathlib.Path("../models")
replicates = 1

output_dir.mkdir(parents=True, exist_ok=True)

# Preprocess training data

In [20]:
%%time
keras_preprocess = construct_preprocessing(
    model_config.category_min_count,
    model_config.ingredient_min_count,
    model_config.product_name_max_tokens,
    model_config.product_name_max_length,
    load_dataframe("train"),
)
print("Pre-processed training data")

2022-04-06 09:23:20.877531: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-04-06 09:23:21.123901: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:185] None of the MLIR Optimization Passes are enabled (registered 2)


Pre-processed training data
CPU times: user 4min 37s, sys: 6.47 s, total: 4min 43s
Wall time: 4min 30s


# Train

In [7]:
%%time
if replicates == 1:
    save_dirs = [output_dir]
else:
    save_dirs = [output_dir / str(i) for i in range(replicates)]

for i, save_dir in enumerate(save_dirs):
    model = create_model(config, keras_preprocess)
    save_dir.mkdir(exist_ok=True)
    config.train_config.start_datetime = str(datetime.datetime.utcnow())
    print(f"Starting training repeat {i}")

    save_config(config, save_dir)
    copy_category_taxonomy(settings.CATEGORY_TAXONOMY_PATH, save_dir)
    save_category_vocabulary(keras_preprocess.category_vocab, save_dir)

    history=train(
        model,
        save_dir,
        config,
        keras_preprocess.category_vocab,
    )

    config.train_config.end_datetime = str(datetime.datetime.utcnow())
    save_config(config, save_dir)
    config.train_config.start_datetime = None
    config.train_config.end_datetime = None

Starting training repeat 0
Starting training...
Temporary log directory: /var/folders/c7/w4lf4cp91_j_p3dxm9w00rmh0000gn/T/tmp2o2lcebx
Epoch 1/50
INFO:tensorflow:Assets written to: models/weights.01-0.0040/assets
Epoch 2/50
INFO:tensorflow:Assets written to: models/weights.02-0.0023/assets
Epoch 3/50
INFO:tensorflow:Assets written to: models/weights.03-0.0018/assets
Epoch 4/50
INFO:tensorflow:Assets written to: models/weights.04-0.0016/assets
Epoch 5/50
INFO:tensorflow:Assets written to: models/weights.05-0.0015/assets
Epoch 6/50
INFO:tensorflow:Assets written to: models/weights.06-0.0015/assets
Epoch 7/50
INFO:tensorflow:Assets written to: models/weights.07-0.0014/assets
Epoch 8/50
INFO:tensorflow:Assets written to: models/weights.08-0.0014/assets
Epoch 9/50
INFO:tensorflow:Assets written to: models/weights.09-0.0014/assets
Epoch 10/50
INFO:tensorflow:Assets written to: models/weights.10-0.0014/assets
Epoch 11/50
INFO:tensorflow:Assets written to: models/weights.11-0.0014/assets
Epoch 

# Test model

In [27]:
# load saved model
model=keras.models.load_model('../models/base/saved_model')

In [51]:
category_vocab=keras_preprocess.category_vocab
tf_transformer = TFTransformer(category_vocab)

# Train & validation Dataset => investigate on characteristics of the source
traindata = create_tf_dataset("train", config.train_config.batch_size, tf_transformer)
valdata = create_tf_dataset("val", config.train_config.batch_size, tf_transformer)
testdata = create_tf_dataset("test", config.train_config.batch_size, tf_transformer)


In [60]:
valdata

<PaddedBatchDataset shapes: (((None, None), (None,)), (None, 3969)), types: ((tf.string, tf.string), tf.int32)>

In [61]:
testdata

<PaddedBatchDataset shapes: (((None, None), (None,)), (None, 3969)), types: ((tf.string, tf.string), tf.int32)>

## Predict

In [52]:
%%time
y_pred_val = model.predict(valdata)

CPU times: user 1min 33s, sys: 18.7 s, total: 1min 52s
Wall time: 1min 20s


In [53]:
%%time
y_pred_test=model.predict(testdata)

CPU times: user 1min 40s, sys: 17.7 s, total: 1min 58s
Wall time: 1min 25s


In [55]:
y_pred_val.shape, y_pred_test

((87434, 3969),
 array([[1.1422783e-02, 2.9106607e-08, 1.8616578e-06, ..., 1.3573744e-10,
         1.6678333e-09, 1.1400034e-07],
        [1.9434567e-05, 6.0262022e-09, 4.3250511e-06, ..., 9.5710753e-18,
         1.8540440e-14, 8.8534751e-13],
        [9.8198652e-04, 5.3054283e-10, 1.0436243e-08, ..., 1.4927385e-14,
         1.0176477e-12, 1.2061078e-12],
        ...,
        [7.5667924e-01, 1.3700799e-11, 2.1688192e-12, ..., 2.8258068e-18,
         2.3219839e-16, 3.5128806e-13],
        [2.4676323e-04, 4.8964339e-14, 6.0147217e-12, ..., 1.8756643e-16,
         2.5486165e-15, 1.1011587e-15],
        [1.1568367e-03, 1.1596524e-05, 5.4461787e-05, ..., 1.3919157e-11,
         2.4204746e-13, 6.0214539e-10]], dtype=float32))

In [49]:
y_pred_val

array([[1.2839634e-05, 2.3829339e-10, 1.6613468e-09, ..., 1.5564917e-11,
        3.6129602e-10, 2.1100532e-12],
       [1.0495579e-01, 2.0508166e-09, 2.3943665e-09, ..., 1.3911613e-08,
        1.0217517e-08, 3.8321792e-07],
       [1.9471225e-05, 6.9090309e-14, 3.4578873e-15, ..., 1.7461139e-10,
        1.5341870e-09, 3.1657670e-11],
       ...,
       [3.8076937e-03, 1.6157297e-06, 6.6229504e-06, ..., 1.5495643e-06,
        8.5407401e-06, 3.0552085e-06],
       [1.3194382e-03, 1.6692168e-10, 3.1451568e-09, ..., 4.1883336e-14,
        4.0008542e-17, 5.4479357e-14],
       [1.3995171e-04, 4.6972590e-09, 6.7628537e-07, ..., 2.1551901e-16,
        4.0634002e-10, 5.0259403e-15]], dtype=float32)

In [58]:
from sklearn.metrics import accuracy_score
# accuracy_score is for a classification model
accuracy_score(y_pred_test,y_pred_val)

ValueError: continuous-multioutput is not supported

Error occurs due to continuous values

## Explore dataset

In [10]:
import pandas as pd
from robotoff.utils import gzip_jsonl_iter
import pathlib

import settings
from typing import Any, Callable, Dict, Iterable, Optional, List

def create_dataframe(split: str, lang: str) -> pd.DataFrame:
     if split not in ("train", "test", "val"):
         raise ValueError("split must be either 'train', 'test' or 'val'")

     file_name = "category_{}.{}.jsonl.gz".format(lang, split)
     full_path = settings.DATA_DIR / file_name
     return pd.DataFrame(iter_product(full_path))

def count_categories(df: pd.DataFrame) -> Dict:
    categories_count = defaultdict(int)

    for categories in df.categories_tags:
        for category in categories:
            categories_count[category] += 1

    return categories_count

def iter_product(data_path: pathlib.Path):
    for product in gzip_jsonl_iter(data_path):
        product.pop("images", None)

        if "nutriments" in product:
            nutriments = product["nutriments"] or {}

        yield product

In [11]:
training_ds = create_dataframe("train", "xx")
test_ds = create_dataframe("test", "xx")
val_ds = create_dataframe("val", "xx")

In [56]:
val_ds.shape

(87434, 8)

In [12]:
val_ds.sample(10)

Unnamed: 0,code,nutriments,product_name,categories_tags,ingredient_tags,known_ingredient_tags,ingredients_text,lang
47994,852681918989,"{'cholesterol_100g': 0.071, 'iron_unit': 'mg',...",Organic Double Chocolate Cookies,"[en:biscuits-and-cakes, en:biscuits, en:sweet-...","[en:wheat-flour, en:cereal, en:flour, en:wheat...","[en:wheat-flour, en:cereal, en:flour, en:wheat...","Organic wheat flour, organic semi-sweet chocol...",en
54907,761088191812,"{'energy_value': 69, 'energy-kcal_serving': 16...",Basil Chicken Chili With Beans,"[en:stews, en:meals]","[en:chicken-broth, en:poultry, en:chicken, en:...","[en:chicken-broth, en:poultry, en:chicken, en:...","Chicken stock (water, spice, garlic, salt, bla...",en
25595,853163,"{'sodium_unit': 'g', 'sugars_100g': 46.4, 'pro...",Sicilian lemon curd,"[en:fruit-curds, en:spreads, en:lemon-curds, e...",[],[],,fr
22122,74734115330,"{'nutrition-score-fr': 12, 'fat': 13.33, 'ener...",Crackers,[en:biscuits-and-cakes],"[en:wheat-flour, en:cereal, en:flour, en:wheat...","[en:wheat-flour, en:cereal, en:flour, en:wheat...","Enriched bleached wheat flour (wheat flour, ni...",en
77128,5400111272658,"{'nutrition-score-fr_100g': 13, 'sugars_unit':...",Cornet de glace vanille chocolat,"[en:desserts, en:ice-creams, en:ice-creams-and...","[en:sugar, en:skimmed-milk, en:dairy, en:milk,...","[en:sugar, en:skimmed-milk, en:dairy, en:milk,...","Sucre, lait écrémé, farine de blé, crème fraîc...",fr
77191,3274664099282,"{'carbohydrates': 36.5, 'energy-kcal': 261, 's...",Façon Citron Meringué,"[en:ice-creams-and-sorbets, en:desserts, en:fr...",[],[],,fr
41673,77300505016,"{'nova-group_serving': 1, 'carbohydrates_unit'...",Enriched Long Grain White Rice,"[en:seeds, en:plant-based-foods-and-beverages,...","[en:long-grain-enriched-milled-rice, en:ferric...","[en:ferric-orthophosphate, en:minerals, en:iro...","Long Grain Enriched Milled Rice, Ferric Orthop...",en
84576,78000029000,"{'sodium_serving': 0, 'fat_100g': 0, 'carbohyd...","Sparkling water beverage, black cherry","[en:waters, en:beverages]","[en:carbonated-water, en:water, en:natural-fla...","[en:carbonated-water, en:water, en:natural-fla...","Carbonated water, natural flavors.",en
62737,35826089021,"{'proteins_100g': 0, 'energy_serving': 20.9, '...","Food lion, on-the-go drink mix, lemonade","[en:dried-products, en:dried-products-to-be-re...","[en:e330, en:potassium, en:minerals, en:sodium...","[en:e330, en:potassium, en:minerals, en:sodium...","Citric acid, potassium and sodium citrate, nat...",en
63139,5411788038836,"{'nutrition-score-fr_100g': 6, 'sugars_unit': ...",Umeboshi Past,"[en:fruits-and-vegetables-based-foods, en:plan...","[fr:umeboshi, fr:feuilles-de-shiso, en:sea-sal...","[en:sea-salt, en:salt]","Umeboshi (Prunus mume), feuilles de shiso (Per...",fr
