In [120]:
import pandas as pd
import numpy as np
import json
import yaml
import os
from dotenv import load_dotenv
from typing import List, Dict

In [121]:
# Pull any settings defined in a local .env file into the process environment
load_dotenv()

# Either value is None when the corresponding environment variable is not set
final_model_dir = os.environ.get('MODEL_DIR')
data_dir = os.environ.get('DATA_DIR')


In [122]:

def get_yaml_file_paths(data_dir: str) -> List[str]:
    """
    Retrieve the paths to all 'data.yaml' files within subdirectories of a given directory.

    Only the immediate subdirectories of ``data_dir`` are inspected; the search is
    not recursive and a 'data.yaml' placed directly in ``data_dir`` is ignored.
    The returned list is ordered by subdirectory name so repeated runs give the
    same result regardless of the order the filesystem lists entries in.

    Args:
        data_dir (str): The root directory containing subdirectories to search for 'data.yaml' files.

    Returns:
        List[str]: File paths to the 'data.yaml' files found, sorted by subdirectory name.

    Raises:
        FileNotFoundError: If ``data_dir`` does not exist.
        ValueError: If ``data_dir`` exists but is not a directory.
    """
    # Verify the directory exists
    if not os.path.exists(data_dir):
        raise FileNotFoundError(f"The directory '{data_dir}' does not exist.")

    # Ensure the input path is a directory
    if not os.path.isdir(data_dir):
        raise ValueError(f"The path '{data_dir}' is not a directory.")

    yaml_files = []

    # os.listdir returns entries in arbitrary order; sort for reproducible output
    for dir_name in sorted(os.listdir(data_dir)):
        candidate = os.path.join(data_dir, dir_name, 'data.yaml')
        # isfile is False when dir_name is not a directory, when the file is
        # missing, and when 'data.yaml' is itself a directory
        if os.path.isfile(candidate):
            yaml_files.append(candidate)

    return yaml_files

In [123]:
def get_yaml_data(yaml_files: List[str]) -> List[str]:
    """
    Collect the unique class names declared across a set of YAML files.

    Each file is expected to hold a mapping with a 'names' entry. That entry may
    be a list of class names or a mapping of index -> class name; in the mapping
    form only the values (the names) are collected.

    Args:
        yaml_files (List[str]): Paths of the YAML files to read.

    Returns:
        List[str]: Sorted list of the unique class names found in all files.
        Empty when ``yaml_files`` is empty.

    Raises:
        KeyError: If a file is empty, is not a mapping, or has no 'names' entry.
    """
    unique_classes = set()

    for file_path in yaml_files:
        # Read as UTF-8 explicitly so the result does not depend on the
        # platform's default encoding
        with open(file_path, 'r', encoding='utf-8') as file:
            yaml_data = yaml.safe_load(file)

        # safe_load returns None for an empty document; report which file is at fault
        if not isinstance(yaml_data, dict) or 'names' not in yaml_data:
            raise KeyError(f"No 'names' entry found in '{file_path}'")

        names = yaml_data['names']
        # Mapping form: iterating the dict directly would yield the indices, not the names
        if isinstance(names, dict):
            names = names.values()

        unique_classes.update(names)

    return sorted(unique_classes)

In [124]:
def create_class_dict(classes_list: List[str]) -> Dict[int, str]:
    """
    Build an index -> class name mapping from a list of class names.

    Indices start at 0 and follow the order of ``classes_list``.

    Args:
        classes_list (List[str]): Class names to index.

    Returns:
        Dict[int, str]: Mapping from each position in the list to the class name at that position.
    """
    return {position: label for position, label in enumerate(classes_list)}

In [None]:
def create_yaml_file(classes_dict: Dict[int, str], output_path: str) -> None:
    """
    Creates a YAML file with the specified structure, where the 'nc' and 'names'
    fields are derived from the provided classes_dict.

    The 'train', 'val' and 'test' entries are fixed relative image paths.

    Args:
        classes_dict (Dict[int, str]): A dictionary where keys are indices and values are class names.
            Keys must be exactly 0..len(classes_dict)-1, because the position in the
            written 'names' list is taken from the key.
        output_path (str): Where to save the YAML file. If this is an existing
            directory, the file is written as 'data.yaml' inside it. Missing
            parent directories of a file path are created.

    Returns:
        None

    Raises:
        KeyError: If an index in 0..len(classes_dict)-1 is missing from ``classes_dict``.
    """
    nc = len(classes_dict)

    # Build the payload before touching the filesystem so bad input leaves nothing behind
    names = [classes_dict[i] for i in range(nc)]

    data = {
        'train': '../train/images',
        'val': '../valid/images',
        'test': '../test/images',
        'nc': nc,
        'names': names
    }

    if os.path.isdir(output_path):
        # A directory cannot be opened for writing; place the file inside it instead
        output_path = os.path.join(output_path, 'data.yaml')
    else:
        parent_dir = os.path.dirname(output_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)

    # Write UTF-8 and keep non-ASCII class names readable instead of escaped
    with open(output_path, 'w', encoding='utf-8') as yaml_file:
        yaml.dump(data, yaml_file, default_flow_style=False, allow_unicode=True)

    print(f"YAML file created at: {output_path}")





In [125]:
def main(data_dir: str) -> None:
    """
    Run the whole pipeline: find the data.yaml files, merge their class names,
    index them, write the combined YAML file, and print the class dictionary.

    The output location comes from the module-level ``final_model_dir`` (read
    from the MODEL_DIR environment variable). When that path is an existing
    directory, the file is written as 'data.yaml' inside it.

    Args:
        data_dir (str): Root directory whose subdirectories hold the data.yaml files.

    Returns:
        None

    Raises:
        ValueError: If ``final_model_dir`` is None (MODEL_DIR not set).
    """
    # Fail before doing any work if there is nowhere to write the result
    if final_model_dir is None:
        raise ValueError("MODEL_DIR is not set; cannot determine where to write the YAML file.")

    yaml_files = get_yaml_file_paths(data_dir=data_dir)
    classes_list = get_yaml_data(yaml_files)
    classes_dict = create_class_dict(classes_list)

    # A directory cannot be opened as a file, so target a file inside it
    output_path = final_model_dir
    if os.path.isdir(output_path):
        output_path = os.path.join(output_path, 'data.yaml')

    create_yaml_file(classes_dict, output_path)
    print(classes_dict)

In [127]:
# Run the full pipeline on the dataset root read from the DATA_DIR environment variable.
main(data_dir=data_dir)

{0: 'almond butter', 1: 'apple', 2: 'avocado', 3: 'bacon', 4: 'baking soda', 5: 'balsamic vinaigrette', 6: 'balsamic vinegar', 7: 'barbecue sauce', 8: 'basil', 9: 'basil pesto', 10: 'beans', 11: 'bitter gaurd', 12: 'black beans', 13: 'black pepper', 14: 'bread', 15: 'bringal', 16: 'brown onion', 17: 'buffalo sauce', 18: 'butter', 19: 'cabbage', 20: 'cajun', 21: 'cake', 22: 'candy', 23: 'canola oil', 24: 'capsicum', 25: 'carrots', 26: 'cauliflower', 27: 'cayenne pepper', 28: 'cereal', 29: 'cheese', 30: 'chicken', 31: 'chicken stock', 32: 'chickpeas', 33: 'chillies', 34: 'chips', 35: 'chocolate', 36: 'cinnamon', 37: 'coffee', 38: 'coriander', 39: 'corn', 40: 'cucumber', 41: 'cumin', 42: 'egg', 43: 'fish', 44: 'flour', 45: 'garlic', 46: 'ginger', 47: 'gnocchi', 48: 'grapes', 49: 'hoison sauce', 50: 'honey', 51: 'hot sauce', 52: 'italian herbs', 53: 'jalapeno', 54: 'jam', 55: 'juice', 56: 'ketchup', 57: 'kiwi', 58: 'kumara', 59: 'laksa paste', 60: 'lemon', 61: 'lettuce', 62: 'lime juice', 