In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.applications import MobileNetV3Small
import json
import matplotlib.pyplot as plt
import numpy as np
from io import BytesIO
import sys
import os


# Interactive widgets
from ipywidgets import widgets


# Add the parent directory to the Python path
parent_dir = os.path.abspath(os.path.join(os.getcwd(), ".."))
sys.path.append(parent_dir)

# Now you can import the module
from src.utils import create_model

In [2]:
# --- Configuration -----------------------------------------------------------
# Every image is resized to this (height, width) before it reaches the model.
IMG_SIZE = (224, 224)
# Number of examples grouped into each batch by the input pipelines below.
BATCH_SIZE = 32

# Locations of the processed data splits (relative to the notebook's folder).
train_dir = "../data/processed/train/"
val_dir = "../data/processed/val/"
test_dir = "../data/processed/test/"

In [None]:
def parse_jsonl(jsonl_path):
    """
    Parse a JSONL file and yield ``(image, label)`` pairs.

    Each non-blank line must be a JSON object with at least the keys
    ``"image"`` and ``"label"``; any other keys are ignored.

    Args:
        jsonl_path (str): Path to the JSONL file.

    Yields:
        tuple: ``(item["image"], item["label"])`` for every non-blank line,
        in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
        KeyError: If a record lacks ``"image"`` or ``"label"``.
    """
    # Explicit encoding so parsing does not depend on the platform locale.
    with open(jsonl_path, "r", encoding="utf-8") as f:
        for line in f:
            # Blank lines (e.g. an extra empty line at the end of the file)
            # are not records; json.loads would raise on them.
            if not line.strip():
                continue
            item = json.loads(line)
            yield item["image"], item["label"]


def load_datasets_from_directory(
    data_dir, batch_size=BATCH_SIZE, img_size=IMG_SIZE, shuffle=True
):
    """
    Create a combined tf.data.Dataset from every JSONL file in a directory.

    Each ``*.jsonl`` file directly inside ``data_dir`` is read line by line;
    every record must provide an ``"image"`` path and an integer ``"label"``.
    The records of all files are concatenated (files in sorted name order),
    optionally shuffled, then decoded, resized, batched and prefetched.

    Records whose image file does not exist are skipped with a printed
    warning instead of aborting the pipeline at iteration time.

    Args:
        data_dir (str): Directory containing JSONL files for train, val, or test splits.
        batch_size (int): Batch size for the dataset.
        img_size (tuple): Target size for images (height, width).
        shuffle (bool): Shuffle records (1000-element buffer) before decoding.
            Defaults to True, matching the previous behaviour.

    Returns:
        tf.data.Dataset: Batches of ``(images, labels)``. Images are float32
        tensors of shape ``(batch, height, width, 3)`` as produced by
        ``tf.image.resize`` (pixel values are NOT rescaled); labels are int32.

    Raises:
        ValueError: If ``data_dir`` contains no ``.jsonl`` files.
    """
    # Sorted so the concatenation order is deterministic; os.listdir returns
    # entries in arbitrary order.
    jsonl_files = sorted(
        os.path.join(data_dir, fname)
        for fname in os.listdir(data_dir)
        if fname.endswith(".jsonl")
    )
    if not jsonl_files:
        # Without this check the code below would fail with an unhelpful
        # AttributeError on ``None.map``.
        raise ValueError(f"No .jsonl files found in {data_dir!r}")

    def load_single_jsonl(jsonl_path):
        """Create a dataset of ``(image_path, label)`` pairs from one JSONL file."""

        def generator():
            with open(jsonl_path, "r", encoding="utf-8") as f:
                for line in f:
                    if not line.strip():
                        continue  # tolerate blank lines
                    item = json.loads(line)
                    image_path = item["image"]
                    # Check existence here, in plain Python: inside the
                    # graph-traced map function a try/except cannot catch the
                    # runtime NotFoundError raised by tf.io.read_file.
                    if not tf.io.gfile.exists(image_path):
                        print(f"File not found: {image_path}")
                        continue
                    yield image_path, item["label"]

        return tf.data.Dataset.from_generator(
            generator,
            output_signature=(
                tf.TensorSpec(shape=(), dtype=tf.string),
                tf.TensorSpec(shape=(), dtype=tf.int32),
            ),
        )

    # Chain the per-file datasets end to end.
    combined_dataset = load_single_jsonl(jsonl_files[0])
    for jsonl_path in jsonl_files[1:]:
        combined_dataset = combined_dataset.concatenate(load_single_jsonl(jsonl_path))

    def preprocess(image_path, label):
        """Read, decode (3 channels) and resize one image."""
        image = tf.io.read_file(image_path)
        image = tf.image.decode_png(image, channels=3)
        image = tf.image.resize(image, img_size)
        return image, label

    if shuffle:
        # Shuffle the lightweight (path, label) records before decoding so the
        # buffer holds strings rather than 1000 decoded float images.
        combined_dataset = combined_dataset.shuffle(buffer_size=1000)

    return (
        combined_dataset.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
        .batch(batch_size)
        .prefetch(tf.data.AUTOTUNE)
    )


def _report_dataset(name, dataset):
    """Print whether ``dataset``'s type is a subclass of tf.data.Dataset."""
    is_tf_dataset = issubclass(type(dataset), tf.data.Dataset)
    print(f"{name} inherits from tf.data.Dataset: {is_tf_dataset}")


# Build the JSONL-backed dataset for each split and confirm its type.
train_dataset = load_datasets_from_directory(train_dir)
_report_dataset("train_dataset", train_dataset)

val_dataset = load_datasets_from_directory(val_dir)
_report_dataset("val_dataset", val_dataset)

test_dataset = load_datasets_from_directory(test_dir)
_report_dataset("test_dataset", test_dataset)

# Peek at a single training batch to sanity-check shapes and label values.
for images, labels in train_dataset.take(1):
    print(f"Batch of images shape: {images.shape}")
    print(f"Batch of labels: {labels}")

In [None]:
def load_custom_image_dataset(directory, img_size=IMG_SIZE, batch_size=BATCH_SIZE):
    """
    Build a labelled image dataset from a directory tree of class sub-folders.

    Each immediate sub-directory of ``directory`` (for this project
    ``charts`` and ``non_charts``) is treated as one class, and labels are
    returned as integer class indices (``label_mode="int"``).

    Args:
        directory (str): Root folder whose sub-folders hold each class's images.
        img_size (tuple): (height, width) every image is resized to.
        batch_size (int): Number of images per batch.

    Returns:
        tuple: ``(dataset, class_names)``, where ``dataset`` is a shuffled,
        batched tf.data.Dataset of ``(images, labels)`` and ``class_names``
        is the list of class names inferred from the sub-folder names.
    """
    loader_options = dict(
        labels="inferred",
        label_mode="int",
        image_size=img_size,
        batch_size=batch_size,
        shuffle=True,
    )
    image_dataset = tf.keras.utils.image_dataset_from_directory(
        directory, **loader_options
    )
    # Read class_names from the object Keras returned, before any further
    # transformation produces a dataset that no longer carries it.
    return image_dataset, image_dataset.class_names


# Load the folder-organised (class sub-directory) images for train and val.
# NOTE(review): these are the same directories the JSONL loader reads; confirm
# the JSONL records and the class sub-folders do not reference the same
# images, or those images are counted twice once the datasets are concatenated.
custom_train_dir = "../data/processed/train"
custom_val_dir = "../data/processed/val"

custom_train_dataset, class_names = load_custom_image_dataset(custom_train_dir)
custom_val_dataset, _ = load_custom_image_dataset(custom_val_dir)

for ds_name, ds in (
    ("custom_train_dataset", custom_train_dataset),
    ("custom_val_dataset", custom_val_dataset),
):
    dataset_type = type(ds)
    print(
        f"{ds_name} inherits from tf.data.Dataset: {issubclass(dataset_type, tf.data.Dataset)}"
    )

print(f"Class names: {class_names}")  # Expected: ['charts', 'non_charts']

# Peek at one batch of the folder-based training data.
for images, labels in custom_train_dataset.take(1):
    print(f"Batch of images shape: {images.shape}")
    print(f"Batch of labels: {labels}")

In [None]:
# Append the folder-based batches after the JSONL-based ones for train and val.
# NOTE: not idempotent. Re-running this cell concatenates the custom data a
# second time, because the names are rebound to the combined datasets.
train_dataset, val_dataset = (
    train_dataset.concatenate(custom_train_dataset),
    val_dataset.concatenate(custom_val_dataset),
)

# Peek at the first combined training batch.
for batch_images, batch_labels in train_dataset.take(1):
    print(f"Batch of images shape: {batch_images.shape}")
    print(f"Batch of labels: {batch_labels}")

In [None]:
next(iter(train_dataset.take(1)))

In [None]:
# Pull a single (images, labels) batch for the inspection cells below.
first_batch = next(iter(train_dataset.take(1)))
image_batch, label_batch = first_batch

print(f"image batch shape: {image_batch.shape}")
print(f"label batch shape: {label_batch.shape}")

In [8]:


# # Visualize a few samples
# for i in range(5):  # Display the first 5 images in the batch
#     plt.imshow(image_batch[i].numpy().astype("uint8"))
#     plt.title(f"Label: {label_batch[i]}")
#     plt.axis("off")
#     plt.show()

In [None]:
def plot_image_grid(images, labels, class_names, rows=2, cols=4):
    """
    Plot a grid of randomly selected images titled with their class names.

    Args:
        images (numpy.ndarray): Batch of images indexed along axis 0. Pixels
            are cast to uint8 for display, so values are expected in the
            0-255 range (images already rescaled to 0-1 would render black).
        labels (numpy.ndarray): Integer labels aligned with ``images``.
        class_names (list): Class names indexed by label.
        rows (int): Number of rows in the grid.
        cols (int): Number of columns in the grid.
    """
    # A random subset of at most rows * cols distinct images.
    indices = np.random.permutation(len(images))[: rows * cols]
    selected_images = images[indices]
    selected_labels = labels[indices]

    # squeeze=False keeps ``axes`` a 2-D array even for a 1x1 or 1xN grid, so
    # ``axes.flat`` always works (a bare Axes object has no ``.flat``).
    fig, axes = plt.subplots(
        rows, cols, figsize=(cols * 3, rows * 3), squeeze=False
    )

    for i, ax in enumerate(axes.flat):
        # Turn off every cell's frame; cells without an image stay blank
        # instead of showing empty, tick-marked axes.
        ax.axis("off")
        if i >= len(selected_images):
            continue
        ax.imshow(selected_images[i].astype("uint8"))
        ax.set_title(f"Label: {class_names[selected_labels[i]]}")

    # Lay out after the titles exist so tight_layout can make room for them.
    fig.tight_layout(pad=1.0)
    plt.show()


# imshow and the class-name lookup work on NumPy values, so convert here.
image_batch_np, label_batch_np = image_batch.numpy(), label_batch.numpy()

# NOTE(review): this overwrites the `class_names` inferred by
# load_custom_image_dataset with a hard-coded list; the two should agree.
class_names = ["charts", "non_charts"]

plot_image_grid(image_batch_np, label_batch_np, class_names, rows=2, cols=4)

In [None]:
first_image_array = image_batch[0].numpy()  # raw pixels of the first image
print(first_image_array)

In [None]:
label_array = label_batch.numpy()  # integer labels of the whole batch
print(label_array)

In [None]:
# Pixel range of the first image before any rescaling is applied.
first_image_pixels = image_batch[0].numpy()
print(f"max value: {np.max(first_image_pixels)}")
print(f"min value: {np.min(first_image_pixels)}")

In [13]:
# Computes input * scale + offset, mapping pixel values 0..255 onto 0..1.
rescale_layer = tf.keras.layers.Rescaling(scale=1 / 255.0, offset=0.0)

In [None]:
# Rescale one sample image from the batch. The index is clamped so the cell
# still works when the batch holds fewer than 21 images (a smaller BATCH_SIZE
# or a short batch), where image_batch[20] would raise an IndexError.
sample_index = min(20, image_batch.shape[0] - 1)
image_scaled = rescale_layer(image_batch[sample_index]).numpy()

print(image_scaled)

In [None]:
for stat_name, stat_fn in (("max", np.max), ("min", np.min)):
    print(f"{stat_name} value: {stat_fn(image_scaled)}")

In [16]:
def normalize_dataset(dataset):
    """
    Pass the images of an ``(image, label)`` dataset through ``rescale_layer``.

    Args:
        dataset (tf.data.Dataset): Dataset yielding ``(image, label)`` pairs.

    Returns:
        tf.data.Dataset: Same structure with images rescaled and labels left
        untouched, mapped in parallel and prefetched.
    """

    def _rescale(image, label):
        return rescale_layer(image), label

    rescaled = dataset.map(_rescale, num_parallel_calls=tf.data.AUTOTUNE)
    return rescaled.prefetch(tf.data.AUTOTUNE)


# Apply the rescaling from normalize_dataset to every split (train, val, test).
train_dataset_scaled, val_dataset_scaled, test_dataset_scaled = (
    normalize_dataset(split) for split in (train_dataset, val_dataset, test_dataset)
)

In [None]:
# Take one batch from the rescaled training pipeline.
sample_batch = next(iter(train_dataset_scaled.take(1)))
sample_images = sample_batch[0]

# Clamp the index: sample_images[10] would fail on a batch of 10 or fewer images.
sample_index = min(10, sample_images.shape[0] - 1)
image_scaled = sample_images[sample_index].numpy()

# Check the range of values for this image.
print(f"max value: {np.max(image_scaled)}")
print(f"min value: {np.min(image_scaled)}")

In [None]:
# Variant key passed to src.utils.create_model. Other keys tried previously:
# "mobile", "custom-1", "resnet", "efficientnet", "densenet".
MODEL_VARIANT = "mobile_large"
model, model_file_name = create_model(MODEL_VARIANT)

model.summary()

In [19]:
SHUFFLE_BUFFER_SIZE = 1000
PREFETCH_BUFFER_SIZE = tf.data.AUTOTUNE

# Training: cache the rescaled batches, shuffle the batch order, then prefetch.
# NOTE(review): cache() stores the first pass, so the per-example shuffling
# done upstream (before batching) is not repeated on later epochs; this
# shuffle only reorders whole cached batches. Consider unbatch() -> shuffle()
# -> batch() if per-example reshuffling is wanted (mind the buffer's memory).
cached_train = train_dataset_scaled.cache()
shuffled_train = cached_train.shuffle(SHUFFLE_BUFFER_SIZE)
train_dataset_final = shuffled_train.prefetch(PREFETCH_BUFFER_SIZE)

# Validation / test: cache + prefetch only.
validation_dataset_final, test_dataset_final = (
    split.cache().prefetch(PREFETCH_BUFFER_SIZE)
    for split in (val_dataset_scaled, test_dataset_scaled)
)

In [None]:
# Save the model to model_file_name only when val_loss beats the best so far.
save_checkpoint = tf.keras.callbacks.ModelCheckpoint(
    model_file_name,
    monitor="val_loss",
    verbose=1,
    save_best_only=True,
)

In [None]:
EPOCHS = 20
EARLY_STOPPING_PATIENCE = 3

# Stop after EARLY_STOPPING_PATIENCE epochs without improvement of the
# monitored metric and roll back to the best weights seen.
early_stopping = tf.keras.callbacks.EarlyStopping(
    patience=EARLY_STOPPING_PATIENCE, restore_best_weights=True
)

history = model.fit(
    train_dataset_final,
    validation_data=validation_dataset_final,
    epochs=EPOCHS,
    callbacks=[early_stopping, save_checkpoint],
    verbose=1,
)

In [21]:
# Save the model --  See save_checkpoint
# model.save(model_file_name)

# model.export("../models/mobilenetv3_classifier_serving")

In [None]:
# Plot training vs. validation accuracy and loss for each completed epoch.
acc = history.history["accuracy"]
val_acc = history.history["val_accuracy"]
loss = history.history["loss"]
val_loss = history.history["val_loss"]

# Number epochs from 1 so the x-axis matches Keras' "Epoch N/M" progress log.
epochs = range(1, len(acc) + 1)

fig, (acc_ax, loss_ax) = plt.subplots(1, 2, figsize=(12, 4))

acc_ax.plot(epochs, acc, "r", label="Training accuracy")
acc_ax.plot(epochs, val_acc, "b", label="Validation accuracy")
acc_ax.set(title="Training and validation accuracy", xlabel="Epoch", ylabel="Accuracy")
acc_ax.legend(loc=0)

# Loss curves: these values were extracted before but never plotted.
loss_ax.plot(epochs, loss, "r", label="Training loss")
loss_ax.plot(epochs, val_loss, "b", label="Validation loss")
loss_ax.set(title="Training and validation loss", xlabel="Epoch", ylabel="Loss")
loss_ax.legend(loc=0)

fig.tight_layout()
plt.show()

In [None]:
# NOTE(review): the 3-way unpacking assumes create_model compiles the model
# with exactly two metrics (accuracy, then AUC) after the loss; confirm.
test_loss, test_accuracy, test_auc = model.evaluate(test_dataset_final)

for metric_label, metric_value in (
    ("Loss", test_loss),
    ("Accuracy", test_accuracy),
    ("AUC", test_auc),
):
    print(f"Test {metric_label}: {metric_value}")

In [None]:


# Interactive demo: upload one or more images and print the model's verdict.
uploader = widgets.FileUpload(accept="image/*", multiple=True)
display(uploader)
out = widgets.Output()
display(out)


def file_predict(filename, file, out, threshold=0.5):
    """
    Classify a single image and print the verdict into an Output widget.

    Args:
        filename (str): Name shown in the printed message.
        file: Path or file-like object accepted by ``tf.keras.utils.load_img``.
        out (widgets.Output): Output widget the message is printed into.
        threshold (float): Scores ``<= threshold`` are reported as a chart.
            Defaults to 0.5, the previously hard-coded cut-off.
    """
    image = tf.keras.utils.load_img(file, target_size=IMG_SIZE)
    image = tf.keras.utils.img_to_array(image)
    # Apply the same rescale_layer the training pipelines used.
    image = rescale_layer(image)
    image = np.expand_dims(image, axis=0)  # add a batch dimension of size 1

    prediction = model.predict(image, verbose=0)[0][0]

    with out:
        # NOTE(review): a low score means "chart", which matches class index
        # 0 == "charts" if the model outputs P(class 1); confirm in create_model.
        if prediction <= threshold:
            print(filename + " is a chart")
        else:
            print(filename + " is not a chart")


def on_upload_change(change):
    """Classify every file in the uploader's new value, reporting errors in ``out``."""
    for item in change.new:  # one entry per uploaded file
        try:
            file_predict(item.name, BytesIO(item.content), out)
        except Exception as exc:  # report and continue with the remaining files
            # Errors raised in widget callbacks may not appear in the cell
            # output (e.g. in JupyterLab), so surface them in the Output widget.
            with out:
                print(f"Could not classify {item.name}: {exc}")


# Run the interactive widget.
# Note: it may take a bit after you select the image to upload and process before you see the output.
uploader.observe(on_upload_change, names="value")

In [None]:
# Shutdown the kernel to free up resources.
# Note: You can expect a pop-up when you run this cell. You can safely ignore that and just press `Ok`.

# from IPython import get_ipython

# k = get_ipython().kernel

# k.do_shutdown(restart=False)