# Plot Confusion Matrix

In [5]:
from kfp.components import (
    create_component_from_func,
    InputPath,
    OutputPath
)
from typing import List

# Container image passed as `base_image` when the component is built from
# `plot_confusion_matrix` further down in this file.
# NOTE(review): the ":latest" tag is mutable, so component builds are not
# pinned to a fixed image version - consider pinning a tag or digest.
BASE_IMAGE = "quay.io/ibm/kubeflow-notebook-image-ppc64le:latest"


def plot_confusion_matrix(
        input_column: str,
        label_column: str,
        prep_dataset_dir: InputPath(str),
        model_dir: InputPath(str),
        labels: List[str],
        mlpipeline_ui_metadata_path: OutputPath(),
        dataset_split: str = "test",
        batch_size: int = 20
):
    '''
    Plots a confusion matrix based on a Huggingface Dataset with a test split and a model trained via Keras.

            Parameters:
                    input_column: Input column for the model. Currently only 1 column is supported. Examples: "mel_spectrogram", "pixel_values".
                    label_column: Column with labels to be predicted. Examples: "genre", "labels".
                    prep_dataset_dir: Directory where to load test data from. Example: "/blackboard/prep_dataset".
                    model_dir: Directory where to load the model from. Example: "/blackboard/model".
                    labels: List of possible labels, ordered so that position i names class index i. Example: ["Blues", "Rock", "Country"]
                    dataset_split: Optional name of a dataset's split. Defaults to "test".
                    batch_size: Optional batch size when processing the input dataset. Example: 20.
            Returns:
                    mlpipeline_ui_metadata_path: Data to plot a confusion matrix. The plotted confusion matrix can be viewed via Kubeflow UI's Visualization for this component inside a pipeline run.
            Raises:
                    ValueError: If a true or predicted class index has no entry in labels.
    '''
    # All imports stay inside the function: the component is created from this
    # function alone, so it must not depend on module-level names.
    from datasets import load_from_disk
    import json
    import logging
    import numpy as np
    import pandas as pd
    from sklearn.metrics import confusion_matrix
    import sys
    import tensorflow as tf
    from transformers import DefaultDataCollator

    logging.basicConfig(
        stream=sys.stdout,
        level=logging.INFO,
        format='%(levelname)s %(asctime)s: %(message)s'
    )
    logger = logging.getLogger()

    dataset = load_from_disk(prep_dataset_dir)
    data_collator = DefaultDataCollator(return_tensors="tf")

    # NOTE(review): this tf.data pipeline (and therefore `batch_size`) is built
    # and round-tripped through disk but is not consumed by the prediction
    # below, which reads the raw column instead. Presumably the intent was
    # `model.predict(test_dataset)` - confirm before switching or removing.
    test_dataset = dataset[dataset_split].to_tf_dataset(
        columns=[input_column],
        label_cols=[label_column],
        shuffle=False,
        batch_size=batch_size,
        collate_fn=data_collator
    )
    # see: https://github.com/huggingface/datasets/issues/4478
    tf.data.experimental.save(test_dataset, "./test")
    test_dataset = tf.data.experimental.load("./test")

    model = tf.keras.models.load_model(model_dir)

    # see: https://github.com/huggingface/datasets/issues/4772
    # This rename must happen after to_tf_dataset() above, which still needs
    # the caller-supplied column name.
    if label_column == "labels":
        label_column = "label"

    # Both the stored labels and the model output are reduced to class indices
    # by taking the arg-max along axis 1.
    y_true = np.argmax(
        dataset[dataset_split][label_column],
        axis=1
    )
    y_pred = np.argmax(
        model.predict(dataset[dataset_split][input_column]),
        axis=1
    )

    # Fail loudly on indices that cannot be named; with an explicit `labels=`
    # argument scikit-learn would otherwise silently drop such samples.
    num_classes = len(labels)
    highest_index = int(max(y_true.max(), y_pred.max()))
    if highest_index >= num_classes:
        raise ValueError(
            "Found class index {} but only {} labels were provided.".format(
                highest_index, num_classes
            )
        )

    # Pin the class order to 0..len(labels)-1. Without `labels=` the matrix
    # only contains classes that actually occur, so a class missing from both
    # y_true and y_pred would shift every following row/column onto the wrong
    # label name.
    matrix = confusion_matrix(
        y_true,
        y_pred,
        labels=list(range(num_classes))
    )

    # Flatten the matrix into (target, predicted, count) rows as expected by
    # the Kubeflow confusion-matrix viewer.
    rows = [
        (labels[target_index], labels[predicted_index], int(count))
        for target_index, target_row in enumerate(matrix)
        for predicted_index, count in enumerate(target_row)
    ]

    df = pd.DataFrame(
        rows,
        columns=['target', 'predicted', 'count']
    )

    metadata = {
      'outputs': [{
        'type': 'confusion_matrix',
        'format': 'csv',
        'schema': [
          {'name': 'target', 'type': 'CATEGORY'},
          {'name': 'predicted', 'type': 'CATEGORY'},
          {'name': 'count', 'type': 'NUMBER'},
        ],
        "storage": "inline",
        'source': df.to_csv(
            columns=['target', 'predicted', 'count'],
            header=False,
            index=False),
        'labels': labels,
      }]
    }

    logger.info("Dumping mlpipeline_ui_metadata...")
    with open(mlpipeline_ui_metadata_path, 'w') as metadata_file:
        json.dump(metadata, metadata_file)

    logger.info("Finished.")


# Build the pipeline component from `plot_confusion_matrix` and write its
# specification to `component.yaml`.
plot_confusion_matrix_comp = create_component_from_func(
    func=plot_confusion_matrix,
    output_component_file='component.yaml',
    base_image=BASE_IMAGE
)

# Backward-compatible alias: the component was originally bound to this name,
# which does not describe what it does. Prefer `plot_confusion_matrix_comp`.
load_dataset_comp = plot_confusion_matrix_comp