# TensorFlow Local Dataset Tutorial
In this tutorial, you will see how to use flox to run federated learning (FL) experiments with TensorFlow on a custom dataset stored locally on real physical endpoints. We will train our model on the [Animals-10 dataset from Kaggle](https://www.kaggle.com/datasets/alessiocorrado99/animals10).

```python
import pickle
import logging

import numpy as np
import tensorflow as tf

from flox.clients.TensorflowClient import TensorflowClient
from flox.controllers.TensorflowController import TensorflowController
from flox.model_trainers.TensorflowTrainer import TensorflowTrainer

# show INFO-level messages (e.g. the endpoint list logged below);
# the root logger's default WARNING level would otherwise hide them
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
```

### Load & Process Data

First, let's load and preprocess the test split of the dataset on the *Controller*, which uses it to evaluate the global model. We will use the Animals-10 dataset, which you can get from [Kaggle](https://www.kaggle.com/datasets/alessiocorrado99/animals10) or download from this [Google Drive directory](https://drive.google.com/drive/u/0/folders/1nGkoNIuwslvfCyFq4eIVDPGcqCYYCmCu).

```python
def process_data(images, labels, num_samples=None):
    """Optionally resample, reshape to (N, height, width, channels), and scale to [0, 1]."""
    depth = 3
    image_size_y = 32  # height
    image_size_x = 32  # width

    if num_samples:
        # draw num_samples random examples (with replacement)
        idx = np.random.choice(np.arange(len(images)), num_samples, replace=True)
        images = images[idx]
        labels = labels[idx]

    # Keras's default channels-last layout: (batch, height, width, channels)
    images = images.reshape(len(images), image_size_y, image_size_x, depth)
    # scale pixel values from [0, 255] to [0, 1]
    images = images / 255.0

    return images, labels


# the pickle file holds an (images, labels) tuple for the test split
with open("../../../data/test_data_animal10_32.pkl", "rb") as file:
    x_test, y_test = pickle.load(file)

x_test, y_test = process_data(x_test, y_test)
```
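If you want to check the loading and preprocessing steps before downloading the dataset, the sketch below writes a small synthetic file in the same layout that `pickle.load` unpacks above: a pickled `(images, labels)` tuple. The random arrays, the file name `test_data_fake_32.pkl`, and the choice of flat 3072-value rows per image are assumptions for illustration only; they are not the real Animals-10 data.

```python
import os
import pickle
import tempfile

import numpy as np

# Synthetic stand-in (hypothetical) for the Animals-10 test split:
# 8 random 32x32 RGB images stored as flat uint8 rows, plus integer labels in [0, 10).
rng = np.random.default_rng(0)
x_fake = rng.integers(0, 256, size=(8, 32 * 32 * 3), dtype=np.uint8)
y_fake = rng.integers(0, 10, size=8)

# Write it as a pickled (images, labels) tuple in a temporary directory.
path = os.path.join(tempfile.mkdtemp(), "test_data_fake_32.pkl")
with open(path, "wb") as f:
    pickle.dump((x_fake, y_fake), f)

with open(path, "rb") as f:
    x_loaded, y_loaded = pickle.load(f)

# The same reshape and scaling that process_data applies.
x_ready = x_loaded.reshape(len(x_loaded), 32, 32, 3) / 255.0
print(x_ready.shape, x_ready.min() >= 0.0, x_ready.max() <= 1.0)
```

Any array whose rows contain 32 × 32 × 3 = 3072 values per image reshapes the same way, so images saved already shaped as `(N, 32, 32, 3)` work too.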

### Defining the Model

Next, let's define our TensorFlow (Keras) model architecture, a small convolutional neural network, and compile it.

```python
input_shape = (32, 32, 3)
# there are 10 classes in the dataset
num_classes = 10

# define the model architecture
global_model = tf.keras.Sequential(
    [
        tf.keras.Input(shape=input_shape),
        tf.keras.layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
        tf.keras.layers.Conv2D(32, kernel_size=(3, 3), activation="relu"),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
        tf.keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
        tf.keras.layers.Conv2D(64, kernel_size=(3, 3), activation="relu"),
        tf.keras.layers.MaxPooling2D(pool_size=(2, 2)),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dropout(0.5),
        tf.keras.layers.Dense(128, activation="relu"),
        tf.keras.layers.Dense(num_classes, activation="softmax"),
    ]
)

# compile the model
global_model.compile(
    loss="sparse_categorical_crossentropy", optimizer="adam", metrics=["accuracy"]
)
```
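To see how a 32×32×3 image flows through this architecture, the plain-Python sketch below works out the feature-map sizes and parameter counts by hand. It assumes the Keras defaults used above: `padding="valid"` and stride 1 for `Conv2D`, and a pooling stride equal to `pool_size` for `MaxPooling2D`. `Dropout` and `Flatten` add no parameters. You can compare the result with the total reported by `global_model.summary()`.

```python
# Output size of a 'valid' (unpadded) convolution with stride 1.
def conv_valid(size, kernel=3):
    return size - kernel + 1

# Output size of max-pooling whose stride equals the pool size.
def pool(size, p=2):
    return size // p

def conv_params(in_ch, out_ch, kernel=3):
    # kernel weights plus one bias per filter
    return kernel * kernel * in_ch * out_ch + out_ch

def dense_params(n_in, n_out):
    # weight matrix plus one bias per unit
    return n_in * n_out + n_out

size = 32
size = conv_valid(conv_valid(size))  # 32 -> 30 -> 28
size = pool(size)                    # 28 -> 14
size = conv_valid(conv_valid(size))  # 14 -> 12 -> 10
size = pool(size)                    # 10 -> 5
flat = size * size * 64              # Flatten: 5 * 5 * 64 features

params = (
    conv_params(3, 32) + conv_params(32, 32)
    + conv_params(32, 64) + conv_params(64, 64)
    + dense_params(flat, 128) + dense_params(128, 10)
)
print(size, flat, params)  # -> 5 1600 271786
```

Most of the parameters (204,928) sit in the `Dense(128)` layer after `Flatten`, which is typical for small CNNs of this shape.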

### Instantiating the Model Trainer and Client

Next, we will initialize a TensorFlow Model Trainer and Client. Note that we set the trainer's loss to "sparse_categorical_crossentropy", the same loss the model was compiled with above: the model outputs softmax probabilities over the 10 classes, and this loss expects the labels as integer class indices rather than one-hot vectors. You can find the implementations of these classes under ``flox/model_trainers`` and ``flox/clients``, respectively, and extend or modify them to fit your needs.
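As a rough illustration of what this loss computes, here is a NumPy sketch: for each example it takes the predicted probability of the true (integer) class and averages the negative logs. This is a simplified version for intuition. Keras's own implementation differs in numerical details such as how it clips probabilities, and the toy probabilities below are made up.

```python
import numpy as np

def sparse_categorical_crossentropy(labels, probs, eps=1e-7):
    """Mean negative log-probability of the true class.

    labels: integer class indices, shape (N,)
    probs:  softmax outputs, shape (N, num_classes)
    """
    probs = np.clip(probs, eps, 1.0)  # avoid log(0)
    picked = probs[np.arange(len(labels)), labels]
    return float(-np.mean(np.log(picked)))

probs = np.array([[0.7, 0.2, 0.1],
                  [0.1, 0.1, 0.8]])
labels = np.array([0, 2])
loss = sparse_categorical_crossentropy(labels, probs)

# Same value as categorical cross-entropy on the equivalent one-hot targets.
one_hot = np.eye(3)[labels]
dense = float(-np.mean(np.sum(one_hot * np.log(probs), axis=1)))
print(round(loss, 4), np.isclose(loss, dense))  # -> 0.2899 True
```

The "sparse" variant saves converting labels to one-hot vectors. The two losses agree, so the choice depends only on how the labels are stored.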

In [13]:
tf_trainer = TensorflowTrainer(loss="sparse_categorical_crossentropy")
tf_client = TensorflowClient()

### Instantiating the Controller (funcX Execution)

Now, let's define our endpoints and initialize the TensorFlow *Controller*, which does the heavy lifting of deploying tasks to the endpoints. We will run three rounds of FL on two endpoints, with 200 samples and 1 training epoch on each device. Note that we set ``executor_type`` to "funcx" and provide actual funcX endpoint UUIDs. Because ``data_source`` is "local", we also provide the path to the folder where the data is stored on each device (``path_dir``) and the names of the training data files. Finally, we launch the experiment.
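With ``data_source="local"``, each device needs the two training files in ``path_dir``. The sketch below prepares such files with `np.save` and then loads and subsamples them. Several details here are assumptions rather than the documented flox format: that the client reads the files with `np.load`, that images are stored as `(N, 32, 32, 3)` uint8 arrays with integer labels, and that it draws ``num_samples`` examples the way `process_data` does. Check `TensorflowClient` for the actual loading logic. The data is synthetic, and a temporary directory stands in for `/home/pi/datasets`.

```python
import os
import tempfile

import numpy as np

# Hypothetical stand-in for path_dir on a device; a temporary directory
# keeps the sketch self-contained.
path_dir = tempfile.mkdtemp()

# Synthetic training data: 50 random 32x32 RGB images and integer labels.
rng = np.random.default_rng(1)
x_train = rng.integers(0, 256, size=(50, 32, 32, 3), dtype=np.uint8)
y_train = rng.integers(0, 10, size=50)

# Same file names as x_train_filename / y_train_filename in the controller.
np.save(os.path.join(path_dir, "x_animal10_32.npy"), x_train)
np.save(os.path.join(path_dir, "y_animal10_32.npy"), y_train)

# Assumed endpoint-side step: load both files and draw num_samples
# examples with replacement, as process_data does.
x = np.load(os.path.join(path_dir, "x_animal10_32.npy"))
y = np.load(os.path.join(path_dir, "y_animal10_32.npy"))
idx = rng.choice(len(x), size=200, replace=True)
x_round, y_round = x[idx], y[idx]
print(x_round.shape, y_round.shape)  # -> (200, 32, 32, 3) (200,)
```

Sampling with replacement lets ``num_samples`` (200 here) exceed the number of examples stored on a small device.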

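```python
# funcX endpoint UUIDs of the devices; replace these with your own endpoints
eps = ["fb93a1c2-a8d7-49f3-ad59-375f4e298784", "c7487b2b-b129-47e2-989b-5a9ac361befc"]
logger.info(f"Endpoints: {eps}")

flox_controller = TensorflowController(
    endpoint_ids=eps,
    num_samples=200,  # training samples on each endpoint
    epochs=1,  # local training epochs per round
    rounds=3,  # federated learning rounds
    client_logic=tf_client,
    global_model=global_model,
    model_trainer=tf_trainer,
    executor_type="funcx",
    data_source="local",  # read training data from files stored on each device
    path_dir="/home/pi/datasets",  # folder containing those files on each device
    x_train_filename="x_animal10_32.npy",
    y_train_filename="y_animal10_32.npy",
    input_shape=(32, 32, 32, 3),
    x_test=x_test,  # test data the controller uses to evaluate the global model
    y_test=y_test,
)

logger.info("STARTING FL FLOW...")
flox_controller.run_federated_learning()
```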
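Between rounds, the controller has to combine the models trained on each endpoint into a new global model. A common way to do this is federated averaging (FedAvg), which takes a per-layer average of client weights weighted by each client's number of training samples. The NumPy sketch below shows that idea. It assumes FedAvg and is not taken from flox; the controller's actual aggregation may differ, so see `flox/controllers` for the real logic. The helper `federated_average` and the toy weights are hypothetical.

```python
import numpy as np

def federated_average(client_weights, client_sizes):
    """Sample-weighted average of per-client model weights (FedAvg-style).

    client_weights: one list of np.ndarray per client, in the same layer
                    order as a Keras model's get_weights()
    client_sizes:   number of training samples each client used
    """
    total = float(sum(client_sizes))
    averaged = []
    # zip(*...) groups the same layer across all clients
    for layer_weights in zip(*client_weights):
        layer = sum(w * (n / total) for w, n in zip(layer_weights, client_sizes))
        averaged.append(layer)
    return averaged

# Two toy "clients", each with a 2x2 kernel and a bias vector.
client_a = [np.ones((2, 2)), np.zeros(2)]
client_b = [np.full((2, 2), 3.0), np.ones(2)]
new_global = federated_average([client_a, client_b], [200, 200])
print(new_global[0])  # every entry is 2.0
print(new_global[1])  # every entry is 0.5
```

With a real Keras model, the client weight lists would come from `model.get_weights()`, and the averaged list could be loaded back with `global_model.set_weights(new_global)`. Weighting by sample count means that endpoints holding more data pull the global model further toward their local update.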