# In [None]:
from pathlib import Path

import typer
from loguru import logger
from tqdm import tqdm

from lung_cancer_detection.config import METADATA_DIR, FIGURES_DIR
import tensorflow as tf
from tf_dataset_loader import load_datasets
import pandas as pd
from collections import Counter
import matplotlib.pyplot as plt


# Typer application object; commands are registered on it with @app.command().
app = typer.Typer()


def _parse_image_size(image_size: str) -> tuple:
    """Convert a comma-separated size string such as ``"224,224"`` to a pair of ints.

    Args:
        image_size: Text containing exactly two comma-separated integers.

    Returns:
        A 2-tuple of positive integers, in the order they appear in the text.

    Raises:
        typer.BadParameter: If the text is not two positive integers.
    """
    try:
        dims = tuple(int(part) for part in image_size.split(","))
    except ValueError as err:
        raise typer.BadParameter(
            f"image_size must look like '224,224', got {image_size!r}"
        ) from err
    if len(dims) != 2 or any(dim <= 0 for dim in dims):
        raise typer.BadParameter(
            f"image_size must be two positive integers, got {image_size!r}"
        )
    return dims


def _save_label_pie_chart(labels, figure_path: Path) -> None:
    """Save a pie chart showing how often each value occurs in *labels*."""
    label_counts = Counter(labels)

    plt.figure(figsize=(8, 6))
    plt.pie(
        list(label_counts.values()),
        labels=list(label_counts.keys()),
        autopct='%1.1f%%',
        startangle=90,
        colors=plt.cm.tab20c.colors,
        wedgeprops={'alpha': 0.7},
    )
    plt.title("Distribution of Labels in Dataset")
    # No file extension is given, so matplotlib picks its default output format.
    plt.savefig(figure_path)
    # Close the figure so it does not linger in pyplot's global state.
    plt.close()


def _save_sample_image_grid(dataset, class_names, figure_path: Path) -> None:
    """Save a grid (up to 4x4) of images taken from the first batch of *dataset*.

    NOTE(review): assumes each batch is an ``(images, labels)`` pair with integer
    labels and pixel values in the 0-255 range -- confirm against ``load_datasets``.
    """
    first_batch = next(iter(dataset), None)
    if first_batch is None:
        logger.warning("Training dataset is empty; skipping sample image grid.")
        return
    images, labels = first_batch

    # A batch may hold fewer than 16 images (small batch_size or small dataset).
    num_images = min(16, int(images.shape[0]))

    plt.figure(figsize=(10, 10))
    for i in range(num_images):
        plt.subplot(4, 4, i + 1)  # 4x4 grid for up to 16 images
        image = images[i].numpy().astype("uint8")  # Convert tensor to image format
        label = class_names[labels[i].numpy()]  # Map integer label to class name
        plt.imshow(image)
        plt.title(label, fontsize=8)
        plt.axis('off')
    plt.savefig(figure_path)
    plt.close()


@app.command()
def main(
    # ---- REPLACE DEFAULT PATHS AS APPROPRIATE ----
    input_path: Path = METADATA_DIR,
    output_path: Path = FIGURES_DIR,
    image_size: str = "224,224",
    batch_size: int = 32,
    # -----------------------------------------
):
    """Load the datasets and write exploratory figures to ``output_path``.

    Two figures are produced: a pie chart of the label distribution found in
    ``metadata.csv`` and a grid of sample images from the first training batch.

    Args:
        input_path: Directory passed to ``load_datasets``; it must also contain
            ``metadata.csv`` with a ``label`` column.
        output_path: Directory the figures are written to (created if missing).
        image_size: Two comma-separated integers, e.g. ``"224,224"``.
        batch_size: Batch size forwarded to ``load_datasets``.
    """
    size = _parse_image_size(image_size)
    output_path.mkdir(parents=True, exist_ok=True)

    # ---- Generate Keras Dataset ----
    logger.info("Generating Keras Dataset...")
    train_dataset, _test_dataset, _val_dataset = load_datasets(
        input_path, image_size=size, batch_size=batch_size
    )
    logger.success("Dataset generation complete.")

    # Import full dataset metadata
    full_dataset = pd.read_csv(input_path / "metadata.csv")

    # ---- Generate Pie Chart of Labels ----
    _save_label_pie_chart(
        full_dataset['label'], output_path / "Label_Distribution_Pie_Plot"
    )

    # ---- Plot Images Before Processing ----
    # NOTE(review): this assumes the integer labels produced by load_datasets
    # index the classes in order of first appearance in metadata.csv -- confirm.
    class_names = list(full_dataset['label'].unique())
    _save_sample_image_grid(
        train_dataset, class_names, output_path / "Sample_Training_Images"
    )
    logger.success(f"Figures saved to {output_path}")


# Run the Typer CLI only when this file is executed as a script.
if __name__ == "__main__":
    app()
