# **Forward Pass for FNN**

In [None]:
import tensorflow as tf
import numpy as np
import pandas as pd
import os

from tensorflow.python.framework.convert_to_constants import (
    convert_variables_to_constants_v2_as_graph
)

# ==========================================================
# USER INPUTS (ONLY THESE)
# ==========================================================

# Path to the saved Keras model whose forward pass is profiled.
MODEL_PATH = r"D:\college\sem-8\models\FNN\asthma_fnn_binary.h5"
# Tabular dataset; load_dataset() in this cell accepts .csv or .npz.
DATASET_PATH = r"D:\college\sem-8\dataset\clean_asthma_dataset.csv"
# Share of batches kept by sample_dataset() (0.02 = 2%).
DATASET_FRACTION = 0.02
# Batch size passed to dataset.batch().
BATCH_SIZE = 32

# ==========================================================
# 1️⃣ LOAD MODEL (ANY KERAS MODEL)
# ==========================================================

# compile=False: no optimizer/loss is needed just to trace the forward pass.
model = tf.keras.models.load_model(MODEL_PATH, compile=False)
print("✔ Model loaded")

# Infer model input shape (remove batch dim)
# Multi-input models report a list of shapes; only the first input is used.
model_input_shape = model.input_shape
if isinstance(model_input_shape, list):
    model_input_shape = model_input_shape[0]

feature_shape = tuple(model_input_shape[1:])
print("✔ Model input shape:", feature_shape)

# ==========================================================
# 2️⃣ LOAD DATASET (MODEL-AWARE)
# ==========================================================

def load_dataset_arrays(path, feature_shape):
    """Load features and labels from *path* as NumPy arrays.

    Supported formats, chosen by file extension:

    * ``.csv`` -- the first ``prod(feature_shape)`` columns are read as
      float32 features and reshaped to ``(rows,) + feature_shape``; labels
      are dummy zeros because only the forward-pass cost is measured.
    * ``.npz`` -- arrays stored under the keys ``"x"`` and ``"y"``.

    Args:
        path: Dataset file path.
        feature_shape: Per-sample input shape expected by the model
            (batch dimension excluded).

    Returns:
        ``(X, y)`` tuple of NumPy arrays.

    Raises:
        ValueError: If the extension is unsupported, or the CSV has fewer
            columns than the model needs.
    """
    ext = os.path.splitext(path)[1].lower()
    feature_shape = tuple(feature_shape)

    # ---------- CSV / TABULAR ----------
    if ext == ".csv":
        df = pd.read_csv(path)

        num_features = int(np.prod(feature_shape))
        if df.shape[1] < num_features:
            # iloc would silently return fewer columns than the model expects.
            raise ValueError(
                f"CSV has {df.shape[1]} columns but the model expects "
                f"{num_features} features"
            )
        X = df.iloc[:, :num_features].to_numpy(dtype=np.float32)
        # Restore the model's per-sample shape (a no-op for flat inputs).
        X = X.reshape((len(X),) + feature_shape)

        # Dummy labels (not used for FLOPs)
        y = np.zeros(len(X), dtype=np.int32)
        return X, y

    # ---------- NUMPY ----------
    if ext == ".npz":
        # The context manager closes the archive's file handle; indexing
        # materialises each array, so they remain valid after closing.
        with np.load(path) as data:
            return data["x"], data["y"]

    raise ValueError("Unsupported dataset format")


def load_dataset(path, feature_shape):
    """Return a ``tf.data.Dataset`` of ``(features, label)`` slices.

    See ``load_dataset_arrays`` for the supported formats and errors.
    """
    X, y = load_dataset_arrays(path, feature_shape)
    return tf.data.Dataset.from_tensor_slices((X, y))

# Build the tf.data pipeline and group rows into BATCH_SIZE batches
# (tf.data keeps a smaller final batch by default).
dataset = load_dataset(DATASET_PATH, feature_shape)
dataset = dataset.batch(BATCH_SIZE)
print("✔ Dataset loaded")

# ==========================================================
# 3️⃣ SAMPLE DATASET (1–2%)
# ==========================================================

def sample_dataset(ds, fraction):
    """Take the first *fraction* of the batches in *ds*.

    Args:
        ds: A batched, finite ``tf.data.Dataset``.
        fraction: Share of batches to keep; at least one batch is kept.

    Returns:
        ``(sampled_ds, total_batches)`` where ``total_batches`` is the
        number of batches in the full dataset.

    Raises:
        ValueError: If *ds* is infinite.
    """
    total_batches = int(tf.data.experimental.cardinality(ds).numpy())
    if total_batches == tf.data.experimental.INFINITE_CARDINALITY:
        raise ValueError("Cannot sample a fraction of an infinite dataset")
    if total_batches == tf.data.experimental.UNKNOWN_CARDINALITY:
        # Not statically known (e.g. after filter()): count by iterating once
        # rather than letting the negative sentinel leak into the totals.
        total_batches = sum(1 for _ in ds)
    sample_batches = max(1, int(total_batches * fraction))
    return ds.take(sample_batches), total_batches

# sample_dataset returns the first DATASET_FRACTION of batches plus the
# batch count of the full dataset.
# NOTE(review): sampled_ds is not consumed anywhere later in this cell.
sampled_ds, total_batches = sample_dataset(dataset, DATASET_FRACTION)

# ==========================================================
# 4️⃣ COMPUTE FLOPs (ROBUST & SAFE)
# ==========================================================

def compute_flops_per_sample(model, feature_shape):
    """
    Computes FLOPs for ONE forward pass of ONE sample

    Traces *model* in inference mode on a ``(1,) + feature_shape`` float32
    input, freezes its variables into constants, and asks the TF1 profiler
    to count the floating-point operations of the resulting graph.

    Args:
        model: Keras model to profile.
        feature_shape: Per-sample input shape (batch dimension excluded).

    Returns:
        Total float ops reported by the profiler for a batch of one.
    """
    input_spec = tf.TensorSpec(
        shape=(1,) + tuple(feature_shape),
        dtype=tf.float32
    )

    # Explicit training=False pins Dropout/BatchNorm to inference behaviour,
    # matching the forward functions of the feasibility-check cells.
    forward = tf.function(lambda x: model(x, training=False))
    concrete_fn = forward.get_concrete_function(input_spec)

    frozen_func, _ = convert_variables_to_constants_v2_as_graph(concrete_fn)

    flops = tf.compat.v1.profiler.profile(
        graph=frozen_func.graph,
        options=tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()
    )

    return flops.total_float_ops

flops_per_sample = compute_flops_per_sample(model, feature_shape)
print("✔ FLOPs per sample:", flops_per_sample)

# ==========================================================
# 5️⃣ EXTRAPOLATE TO FULL DATASET
# ==========================================================

# `total_batches * BATCH_SIZE` over-counts whenever the final batch is
# partial, so add up the real batch sizes instead; this is a single cheap
# pass over the in-memory tabular data.
total_samples = sum(int(x_batch.shape[0]) for x_batch, _ in dataset)
total_flops = flops_per_sample * total_samples
total_gflops = total_flops / 1e9

# ==========================================================
# 6️⃣ FINAL REPORT
# ==========================================================

print("\n========== FLOPs ESTIMATION REPORT ==========")
print(f"Total samples     : {total_samples}")
print(f"FLOPs per sample  : {flops_per_sample:.3e}")
print(f"Total FLOPs       : {total_flops:.3e}")
print(f"Total GFLOPs      : {total_gflops:.3f}")
print("============================================")


# **Forward Pass for CNN**

In [None]:
import tensorflow as tf
import numpy as np
import os
from tensorflow.python.framework.convert_to_constants import (
    convert_variables_to_constants_v2_as_graph
)

# ==========================================================
# USER INPUTS
# ==========================================================

MODEL_PATH = r"D:\college\sem-8\models\CNN\pothole_cnn_model.h5"
IMAGE_DIR  = r"D:\college\sem-8\dataset\pothole"   # class-wise folders
DATASET_FRACTION = 0.02
BATCH_SIZE = 16

# ==========================================================
# 1️⃣ LOAD CNN MODEL
# ==========================================================

# compile=False: only the forward graph is needed to count FLOPs.
model = tf.keras.models.load_model(MODEL_PATH, compile=False)
print("✔ Model loaded")

# Infer input shape (H, W, C); multi-input models report a list of shapes,
# of which only the first is used.
model_input_shape = model.input_shape
if isinstance(model_input_shape, list):
    model_input_shape = model_input_shape[0]

feature_shape = tuple(model_input_shape[1:])
print("✔ Model input shape:", feature_shape)

if len(feature_shape) != 3:
    raise ValueError(
        f"Expected an image model with input (H, W, C), got {feature_shape}"
    )
IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS = feature_shape

# image_dataset_from_directory defaults to color_mode="rgb" (3 channels);
# request the mode matching the model so grayscale/RGBA models get images of
# the right depth. Channel counts not listed keep the RGB default.
COLOR_MODES = {1: "grayscale", 3: "rgb", 4: "rgba"}
IMG_COLOR_MODE = COLOR_MODES.get(IMG_CHANNELS, "rgb")

# ==========================================================
# 2️⃣ LOAD IMAGE DATASET (MODEL-AWARE)
# ==========================================================

dataset = tf.keras.preprocessing.image_dataset_from_directory(
    IMAGE_DIR,
    image_size=(IMG_HEIGHT, IMG_WIDTH),
    color_mode=IMG_COLOR_MODE,
    batch_size=BATCH_SIZE,
    shuffle=False
)

print("✔ Dataset loaded")

# ==========================================================
# 3️⃣ SAMPLE DATASET (1–2%)
# ==========================================================

def sample_dataset(ds, fraction):
    """Take the first *fraction* of the batches in *ds*.

    Args:
        ds: A batched, finite ``tf.data.Dataset``.
        fraction: Share of batches to keep; at least one batch is kept.

    Returns:
        ``(sampled_ds, total_batches)`` where ``total_batches`` is the
        number of batches in the full dataset.

    Raises:
        ValueError: If *ds* is infinite.
    """
    total_batches = int(tf.data.experimental.cardinality(ds).numpy())
    if total_batches == tf.data.experimental.INFINITE_CARDINALITY:
        raise ValueError("Cannot sample a fraction of an infinite dataset")
    if total_batches == tf.data.experimental.UNKNOWN_CARDINALITY:
        # Not statically known (e.g. after filter()): count by iterating once
        # rather than letting the negative sentinel leak into the totals.
        total_batches = sum(1 for _ in ds)
    sample_batches = max(1, int(total_batches * fraction))
    return ds.take(sample_batches), total_batches

# sample_dataset returns the first DATASET_FRACTION of batches plus the
# batch count of the full dataset.
# NOTE(review): sampled_ds is not consumed anywhere later in this cell.
sampled_ds, total_batches = sample_dataset(dataset, DATASET_FRACTION)

# ==========================================================
# 4️⃣ COMPUTE FLOPs PER SAMPLE
# ==========================================================

def compute_flops_per_sample(model, feature_shape):
    """Count float ops for one forward pass of one sample.

    Traces *model* in inference mode on a ``(1,) + feature_shape`` float32
    input, freezes its variables into constants, and asks the TF1 profiler
    to count the floating-point operations of the resulting graph.

    Args:
        model: Keras model to profile.
        feature_shape: Per-sample input shape, e.g. ``(H, W, C)``.

    Returns:
        Total float ops reported by the profiler for a batch of one.
    """
    input_spec = tf.TensorSpec(
        shape=(1,) + tuple(feature_shape),
        dtype=tf.float32
    )

    # Explicit training=False pins Dropout/BatchNorm to inference behaviour,
    # matching the forward functions of the feasibility-check cells.
    forward = tf.function(lambda x: model(x, training=False))
    concrete_fn = forward.get_concrete_function(input_spec)

    frozen_func, _ = convert_variables_to_constants_v2_as_graph(concrete_fn)

    flops = tf.compat.v1.profiler.profile(
        graph=frozen_func.graph,
        options=tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()
    )

    return flops.total_float_ops

flops_per_sample = compute_flops_per_sample(model, feature_shape)
print("✔ FLOPs per sample:", flops_per_sample)

# ==========================================================
# 5️⃣ EXTRAPOLATE TO FULL DATASET
# ==========================================================

# `total_batches * BATCH_SIZE` over-counts whenever the final batch is
# partial. Recent Keras versions attach the discovered image files to the
# dataset as `file_paths`; otherwise fall back to summing the real batch
# sizes (this decodes every image once).
file_paths = getattr(dataset, "file_paths", None)
if file_paths is not None:
    total_samples = len(file_paths)
else:
    total_samples = sum(int(x_batch.shape[0]) for x_batch, _ in dataset)
total_flops = flops_per_sample * total_samples
total_gflops = total_flops / 1e9

# ==========================================================
# 6️⃣ FINAL REPORT
# ==========================================================

print("\n========== CNN FLOPs ESTIMATION REPORT ==========")
print(f"Total samples     : {total_samples}")
print(f"FLOPs per sample  : {flops_per_sample:.3e}")
print(f"Total FLOPs       : {total_flops:.3e}")
print(f"Total GFLOPs      : {total_gflops:.3f}")
print("===============================================")


# **Forward Pass for FNN with memory check and inference time for the whole dataset**

In [None]:
"""
FNN FEASIBILITY CHECK (TensorFlow 2.x)

✔ FLOPs → speed estimation
✔ Memory → crash prediction (YES / NO)
✔ User-provided batch size
✔ Optimizer auto-detection
✔ Training cost estimation
✔ Colab GPU support
"""

import tensorflow as tf
import pandas as pd
from tensorflow.python.framework.convert_to_constants import (
    convert_variables_to_constants_v2
)

# ==========================================================
# HARDWARE DATABASE
# ==========================================================
# Each GPU entry: (display name, memory in MB, throughput in GFLOP/s).
# The memory value is compared against the MB estimate in print_memory_info,
# and the throughput divides the GFLOPs figures in the speed/training reports.
# NOTE(review): the source of these throughput figures is not recorded
# (e.g. the A100 entry equals the V100 one) — confirm before relying on them.
GPU_DATABASE = [
    # Local GPUs
    ("RTX 3060", 12 * 1024, 9_000),
    ("RTX 4090", 24 * 1024, 55_000),
    ("RTX 5090", 32 * 1024, 60_000),

    # Google Colab GPUs
    ("Colab T4",   16 * 1024, 4_000),
    ("Colab P100", 16 * 1024, 9_000),
    ("Colab V100", 16 * 1024, 14_000),
    ("Colab A100", 40 * 1024, 14_000),
]

# Each CPU entry: (RAM in GB shown in the menu, same RAM in MB,
# throughput in GFLOP/s).
CPU_DATABASE = [
    (8,  8 * 1024, 70),
    (16, 16 * 1024, 100),
    (32, 32 * 1024, 130),
]

# ==========================================================
# MENU INPUT
# ==========================================================
def select_option(title, options):
    """Print a numbered menu and return the 0-based index the user picks.

    Re-prompts until the reply is an integer between 1 and ``len(options)``.
    Surrounding whitespace is ignored. Anything ``int`` cannot parse is
    rejected rather than crashing, including Unicode digit characters such
    as "²" that ``str.isdigit`` accepts.
    """
    print(f"\n{title}")
    for i, opt in enumerate(options, 1):
        print(f"{i}. {opt}")
    while True:
        c = input("Select option number: ").strip()
        try:
            choice = int(c)
        except ValueError:
            choice = None
        if choice is not None and 1 <= choice <= len(options):
            return choice - 1
        print("❌ Invalid selection. Try again.")

# ==========================================================
# OPTIMIZER AUTO-DETECTION
# ==========================================================
def detect_optimizer(model):
    """Guess the optimizer family from ``model.optimizer``'s class name.

    Returns:
        ``"adamw"``, ``"adam"`` or ``"sgd"`` on a case-insensitive substring
        match of the class name. AdamW is tested first because "adam" is a
        substring of "adamw". Returns ``None`` when the model has no
        optimizer (e.g. loaded uncompiled) or the class is not recognised.
    """
    # getattr with a default replaces the former bare `except:`, which also
    # swallowed unrelated errors such as KeyboardInterrupt.
    optimizer = getattr(model, "optimizer", None)
    if optimizer is None:
        return None
    name = type(optimizer).__name__.lower()
    for family in ("adamw", "adam", "sgd"):
        if family in name:
            return family
    return None

# ==========================================================
# DATASET VALIDATION (CSV)
# ==========================================================
def load_csv_dataset(csv_path):
    """Summarise a tabular CSV whose last column holds the label.

    Args:
        csv_path: Path to the CSV (read with ``pandas.read_csv`` defaults,
            so the first row is treated as a header).

    Returns:
        ``(num_samples, num_features, num_classes)``. Features are every
        column except the last. Classes are the distinct non-missing values
        of the last column.

    Raises:
        ValueError: If the CSV has fewer than two columns (no feature column).
    """
    df = pd.read_csv(csv_path)

    if df.shape[1] < 2:
        raise ValueError(
            "CSV must have at least one feature column plus a label column"
        )

    num_samples = int(df.shape[0])
    num_features = int(df.shape[1] - 1)
    # nunique() ignores NaN. With set() each missing label (NaN != NaN) could
    # be counted as its own class.
    num_classes = int(df.iloc[:, -1].nunique())

    return num_samples, num_features, num_classes

# ==========================================================
# GFLOPs (INFERENCE)
# ==========================================================
def compute_model_gflops(model, input_dim):
    """Return the GFLOPs of one inference on a single ``(1, input_dim)`` row.

    The model is traced with ``training=False``, its variables are frozen
    into constants, and the TF1 profiler counts the float ops of the
    re-imported graph.

    Args:
        model: Keras model taking flat feature vectors. It is not modified.
        input_dim: Number of input features.

    Returns:
        Float ops for one sample, in units of 1e9.
    """
    # training=False alone selects inference behaviour. The model's
    # `trainable` flag is left untouched, because flipping it here would
    # silently freeze every layer of the caller's model.
    @tf.function
    def forward(x):
        return model(x, training=False)

    concrete_func = forward.get_concrete_function(
        tf.TensorSpec([1, input_dim], tf.float32)
    )

    frozen_func = convert_variables_to_constants_v2(concrete_func)
    graph_def = frozen_func.graph.as_graph_def()

    # Profile in a fresh TF1-style graph containing only the frozen ops.
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name="")
        opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()
        flops = tf.compat.v1.profiler.profile(
            graph=graph,
            cmd="op",
            options=opts
        )

    return flops.total_float_ops / 1e9

# ==========================================================
# TRAINING FLOPs (FNN)
# ==========================================================
def estimate_training_gflops(forward_gflops, num_samples, epochs, multiplier=2.0):
    """Estimate total training GFLOPs for an FNN.

    Training is modelled as ``multiplier`` forward-pass costs per sample per
    epoch (covering forward and backward). The default 2.0 is this cell's FNN
    heuristic; pass another value to tune the estimate.

    Args:
        forward_gflops: GFLOPs of a single-sample forward pass.
        num_samples: Samples per epoch.
        epochs: Number of training epochs.
        multiplier: Training-to-forward cost ratio.

    Returns:
        Estimated GFLOPs for the whole training run.
    """
    return forward_gflops * num_samples * epochs * multiplier

# ==========================================================
# MEMORY (CRASH) CHECK — FNN
# ==========================================================
def memory_check(model, batch_size, available_mem_mb, optimizer):
    """Heuristic training-memory estimate for an FNN, in MB.

    Components:
      * parameters  -- ``model.count_params()`` float32 values (4 bytes each);
      * activations -- parameter size x 1.5 x batch size (FNN heuristic);
      * optimizer   -- 1x parameter size for "sgd", 2x for anything else;
      * a fixed 400 MB framework overhead.

    Gradient buffers are not counted as a separate term. ``available_mem_mb``
    is accepted for interface compatibility but does not affect the result.

    Returns:
        ``(param_mb, activation_mb, optimizer_mb, total_mb)``.
    """
    bytes_per_mb = 1024 ** 2
    params_size = model.count_params() * 4 / bytes_per_mb

    # FNN has small activations
    act_size = params_size * 1.5 * batch_size

    if optimizer == "sgd":
        opt_size = params_size * 1
    else:
        opt_size = params_size * 2

    framework_overhead_mb = 400
    grand_total = params_size + act_size + opt_size + framework_overhead_mb

    return params_size, act_size, opt_size, grand_total

# ==========================================================
# PRINT HELPERS (STANDARD FORMAT)
# ==========================================================
def print_dataset_info(num_classes, num_samples, input_shape):
    """Print the DATASET INFO section of the report."""
    lines = (
        "\n=== DATASET INFO ===",
        f"Classes     : {num_classes}",
        f"Samples     : {num_samples}",
        f"Input shape : {input_shape}",
    )
    print(*lines, sep="\n")

def print_flops_info(gflops):
    """Print the INFERENCE FLOPs section with 4 decimal places."""
    header = "\n=== INFERENCE FLOPs ==="
    body = f"GFLOPs per inference : {gflops:.4f}"
    print(header)
    print(body)

def print_memory_info(batch_size, param_mb, activation_mb, optimizer_mb, total_mb, available_mb):
    """Print the MEMORY CHECK section and return whether the estimate fits.

    Returns:
        ``False`` when ``total_mb`` exceeds ``available_mb`` (predicted OOM),
        otherwise ``True``. This matches the CNN cell's ``print_memory_info``,
        so callers can act on the verdict.
    """
    print("\n=== MEMORY CHECK ===")
    print(f"Batch size         : {batch_size}")
    print(f"Parameters        : {param_mb:.2f} MB")
    print(f"Activations       : {activation_mb:.2f} MB")
    print(f"Optimizer state   : {optimizer_mb:.2f} MB")
    print(f"Estimated TOTAL   : {total_mb:.2f} MB")
    print(f"Available memory : {available_mb} MB")

    if total_mb > available_mb:
        print("RESULT            : ❌ WILL CRASH (OOM)")
        return False
    print("RESULT            : ✅ WILL NOT CRASH")
    return True

def print_speed_info(model_gflops, hardware_gflops_s):
    """Print theoretical per-inference latency and throughput.

    Latency is ``model_gflops / hardware_gflops_s`` seconds and FPS its
    reciprocal. When the profiled cost is 0 GFLOPs (the profiler found no
    float ops), ``1 / time`` would raise ZeroDivisionError, so FPS is
    reported as ``inf`` instead.
    """
    time_sec = model_gflops / hardware_gflops_s
    fps = 1 / time_sec if time_sec else float("inf")
    print("\n=== INFERENCE SPEED ===")
    print(f"Time per inference : {time_sec:.6f} sec")
    print(f"Theoretical FPS    : {fps:.1f}")

def print_training_cost(total_train_gflops, hardware_gflops_s):
    """Print total training GFLOPs and the implied wall-clock time in minutes."""
    minutes = (total_train_gflops / hardware_gflops_s) / 60
    report = [
        "\n=== TRAINING COST ===",
        f"Total training FLOPs : {total_train_gflops:.2f} GFLOPs",
        f"Training time       : {minutes:.2f} minutes",
    ]
    for line in report:
        print(line)

# ==========================================================
# MAIN
# ==========================================================
def run_feasibility_check():
    """Interactive FNN feasibility report.

    Prompts for a model path, a CSV dataset path, epochs and batch size. It
    then asks for the target device and, when the optimizer cannot be read
    from the model, the optimizer. Finally it prints the dataset, FLOPs,
    memory, speed and training-cost sections using this cell's helpers.
    """
    def _ask_positive_int(prompt):
        # Re-prompt (like select_option) instead of crashing with ValueError
        # on non-numeric input; zero or negative values would also make every
        # estimate meaningless.
        while True:
            reply = input(prompt).strip()
            try:
                value = int(reply)
            except ValueError:
                value = 0
            if value > 0:
                return value
            print("❌ Please enter a positive whole number.")

    print("\nFNN FEASIBILITY CHECK\n")

    model_path = input("Enter FNN model path (.keras/.h5): ")
    dataset_path = input("Enter CSV dataset path: ")
    epochs = _ask_positive_int("Enter training epochs: ")
    batch_size = _ask_positive_int("Enter batch size (you plan to use): ")

    # Default compile=True, so a saved optimizer (if any) is restored for
    # detect_optimizer below.
    model = tf.keras.models.load_model(model_path)
    model.summary()

    num_samples, num_features, num_classes = load_csv_dataset(dataset_path)
    input_shape = (num_features,)

    print_dataset_info(num_classes, num_samples, input_shape)

    forward_gflops = compute_model_gflops(model, num_features)
    print_flops_info(forward_gflops)

    device_idx = select_option("Select device type:", ["GPU", "CPU"])

    if device_idx == 0:
        gpu_idx = select_option("Select GPU:", [g[0] for g in GPU_DATABASE])
        _, available_mem_mb, hw_gflops = GPU_DATABASE[gpu_idx]
    else:
        cpu_idx = select_option(
            "Select CPU RAM:",
            [f"{c[0]} GB RAM" for c in CPU_DATABASE]
        )
        _, available_mem_mb, hw_gflops = CPU_DATABASE[cpu_idx]

    optimizer = detect_optimizer(model)
    if optimizer:
        print(f"✅ Detected optimizer: {optimizer}")
    else:
        opt_idx = select_option("Select optimizer:", ["SGD", "Adam", "AdamW"])
        optimizer = ["sgd", "adam", "adamw"][opt_idx]

    param_mb, act_mb, opt_mb, total_mb = memory_check(
        model, batch_size, available_mem_mb, optimizer
    )

    print_memory_info(
        batch_size, param_mb, act_mb, opt_mb, total_mb, available_mem_mb
    )

    print_speed_info(forward_gflops, hw_gflops)

    total_train_gflops = estimate_training_gflops(
        forward_gflops, num_samples, epochs
    )
    print_training_cost(total_train_gflops, hw_gflops)

# ==========================================================
# RUN
# ==========================================================
# __name__ is also "__main__" inside a Jupyter cell, so this runs the
# interactive check whenever the cell is executed.
if __name__ == "__main__":
    run_feasibility_check()


# **Forward Pass for CNN with memory check and inference time for the whole dataset**

In [None]:
"""
CNN FEASIBILITY CHECK (TensorFlow 2.x)

✔ FLOPs → speed estimation
✔ Memory → crash prediction (YES / NO)
✔ User-provided batch size
✔ Optimizer auto-detection
✔ Training cost estimation
✔ Colab GPU support
"""

import tensorflow as tf
import os
from tensorflow.python.framework.convert_to_constants import (
    convert_variables_to_constants_v2
)

# ==========================================================
# HARDWARE DATABASE
# ==========================================================
# Each GPU entry: (display name, memory in MB, throughput in GFLOP/s).
# The memory value is compared against the MB estimate in print_memory_info,
# and the throughput divides the GFLOPs figures in the speed/training reports.
# NOTE(review): the source of these throughput figures is not recorded
# (e.g. the A100 entry equals the V100 one) — confirm before relying on them.
GPU_DATABASE = [
    # Local GPUs
    ("RTX 3060", 12 * 1024, 9_000),
    ("RTX 4090", 24 * 1024, 55_000),
    ("RTX 5090", 32 * 1024, 60_000),

    # Google Colab GPUs
    ("Colab T4",   16 * 1024, 4_000),
    ("Colab P100", 16 * 1024, 9_000),
    ("Colab V100", 16 * 1024, 14_000),
    ("Colab A100", 40 * 1024, 14_000),
]

# Each CPU entry: (RAM in GB shown in the menu, same RAM in MB,
# throughput in GFLOP/s).
CPU_DATABASE = [
    (8,  8 * 1024, 70),
    (16, 16 * 1024, 100),
    (32, 32 * 1024, 130),
]

# ==========================================================
# MENU INPUT
# ==========================================================
def select_option(title, options):
    """Print a numbered menu and return the 0-based index the user picks.

    Re-prompts until the reply is an integer between 1 and ``len(options)``.
    Surrounding whitespace is ignored. Anything ``int`` cannot parse is
    rejected rather than crashing, including Unicode digit characters such
    as "²" that ``str.isdigit`` accepts.
    """
    print(f"\n{title}")
    for i, opt in enumerate(options, 1):
        print(f"{i}. {opt}")
    while True:
        c = input("Select option number: ").strip()
        try:
            choice = int(c)
        except ValueError:
            choice = None
        if choice is not None and 1 <= choice <= len(options):
            return choice - 1
        print("❌ Invalid selection. Try again.")

# ==========================================================
# OPTIMIZER AUTO-DETECTION
# ==========================================================
def detect_optimizer(model):
    """Guess the optimizer family from ``model.optimizer``'s class name.

    Returns:
        ``"adamw"``, ``"adam"`` or ``"sgd"`` on a case-insensitive substring
        match of the class name. AdamW is tested first because "adam" is a
        substring of "adamw". Returns ``None`` when the model has no
        optimizer (e.g. loaded uncompiled) or the class is not recognised.
    """
    # getattr with a default replaces the former bare `except:`, which also
    # swallowed unrelated errors such as KeyboardInterrupt.
    optimizer = getattr(model, "optimizer", None)
    if optimizer is None:
        return None
    name = type(optimizer).__name__.lower()
    for family in ("adamw", "adam", "sgd"):
        if family in name:
            return family
    return None

# ==========================================================
# DATASET VALIDATION
# ==========================================================
def validate_image_dataset(
    dataset_dir,
    image_extensions=(".bmp", ".gif", ".jpeg", ".jpg", ".png"),
):
    """Check a class-per-folder image dataset and count its images.

    Args:
        dataset_dir: Root directory containing one sub-folder per class.
        image_extensions: File suffixes, matched case-insensitively, that
            count as images. The default mirrors the formats accepted by
            Keras ``image_dataset_from_directory``.

    Returns:
        ``(num_classes, num_samples)``.

    Raises:
        ValueError: If *dataset_dir* is not a directory or has fewer than two
            class sub-folders.
    """
    if not os.path.isdir(dataset_dir):
        raise ValueError("❌ Dataset path is not a directory")

    # Sorted for a deterministic order (Keras also sorts class names).
    classes = sorted(
        d for d in os.listdir(dataset_dir)
        if os.path.isdir(os.path.join(dataset_dir, d))
    )

    if len(classes) < 2:
        raise ValueError("❌ Dataset must have ≥2 class folders")

    suffixes = tuple(ext.lower() for ext in image_extensions)
    # Only image files inside class folders are samples. Stray files such as
    # a root README or .DS_Store/Thumbs.db would otherwise inflate the count
    # and every FLOP/time estimate derived from it.
    num_samples = sum(
        1
        for cls in classes
        for _, _, files in os.walk(os.path.join(dataset_dir, cls))
        for name in files
        if name.lower().endswith(suffixes)
    )

    return len(classes), num_samples

# ==========================================================
# GFLOPs (INFERENCE)
# ==========================================================
def compute_model_gflops(model, input_shape):
    """Return the GFLOPs of one inference on a single input of *input_shape*.

    The model is traced on a ``[1] + input_shape`` float32 spec with
    ``training=False``, its variables are frozen into constants, and the TF1
    profiler counts the float ops of the re-imported graph.

    Args:
        model: Keras model to profile. It is not modified.
        input_shape: Per-sample input shape (batch dimension excluded),
            e.g. ``model.input_shape[1:]``.

    Returns:
        Float ops for one sample, in units of 1e9.
    """
    # training=False alone selects inference behaviour. The model's
    # `trainable` flag is left untouched, because flipping it here would
    # silently freeze every layer of the caller's model.
    @tf.function
    def forward(x):
        return model(x, training=False)

    concrete_func = forward.get_concrete_function(
        tf.TensorSpec([1] + list(input_shape), tf.float32)
    )

    frozen_func = convert_variables_to_constants_v2(concrete_func)
    graph_def = frozen_func.graph.as_graph_def()

    # Profile in a fresh TF1-style graph containing only the frozen ops.
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name="")
        opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()
        flops = tf.compat.v1.profiler.profile(
            graph=graph,
            cmd="op",
            options=opts
        )

    return flops.total_float_ops / 1e9

# ==========================================================
# TRAINING FLOPs (CNN)
# ==========================================================
def estimate_training_gflops(forward_gflops, num_samples, epochs, multiplier=2.5):
    """Estimate total training GFLOPs for a CNN.

    Training is modelled as ``multiplier`` forward-pass costs per sample per
    epoch (covering forward and backward). The default 2.5 is this cell's CNN
    heuristic; pass another value to tune the estimate.

    Args:
        forward_gflops: GFLOPs of a single-sample forward pass.
        num_samples: Samples per epoch.
        epochs: Number of training epochs.
        multiplier: Training-to-forward cost ratio.

    Returns:
        Estimated GFLOPs for the whole training run.
    """
    return forward_gflops * num_samples * epochs * multiplier

# ==========================================================
# MEMORY (CRASH) CHECK — CNN
# ==========================================================
def memory_check(model, batch_size, available_mem_mb, optimizer):
    """Heuristic training-memory estimate for a CNN, in MB.

    The estimate sums float32 parameters, activations (4x the parameter size
    per sample in the batch), optimizer state (1x parameters for "sgd", 2x
    otherwise) and a fixed 600 MB overhead. ``available_mem_mb`` is accepted
    for interface compatibility but does not affect the result.

    Returns:
        ``(param_mb, activation_mb, optimizer_mb, total_mb)``.
    """
    weights_mb = model.count_params() * 4 / (1024 ** 2)
    # CNN multiplier: activations modelled as 4x the weights per sample.
    acts_mb = weights_mb * 4 * batch_size
    state_copies = 1 if optimizer == "sgd" else 2
    state_mb = weights_mb * state_copies

    overall_mb = sum((weights_mb, acts_mb, state_mb, 600))

    return weights_mb, acts_mb, state_mb, overall_mb

# ==========================================================
# PRINT HELPERS (STANDARD FORMAT)
# ==========================================================
def print_dataset_info(num_classes, num_samples, input_shape):
    """Print the DATASET INFO section of the report."""
    summary = "\n".join([
        "\n=== DATASET INFO ===",
        "Classes     : {}".format(num_classes),
        "Samples     : {}".format(num_samples),
        "Input shape : {}".format(input_shape),
    ])
    print(summary)

def print_flops_info(gflops):
    """Print the INFERENCE FLOPs section with 4 decimal places."""
    print("\n=== INFERENCE FLOPs ===\nGFLOPs per inference : " + format(gflops, ".4f"))

def print_memory_info(batch_size, param_mb, activation_mb, optimizer_mb, total_mb, available_mb):
    """Print the MEMORY CHECK section.

    Returns:
        ``False`` when ``total_mb`` exceeds ``available_mb`` (predicted OOM),
        otherwise ``True``.
    """
    rows = [
        "\n=== MEMORY CHECK ===",
        f"Batch size         : {batch_size}",
        f"Parameters        : {param_mb:.2f} MB",
        f"Activations       : {activation_mb:.2f} MB",
        f"Optimizer state   : {optimizer_mb:.2f} MB",
        f"Estimated TOTAL   : {total_mb:.2f} MB",
        f"Available memory : {available_mb} MB",
    ]
    print("\n".join(rows))

    will_crash = total_mb > available_mb
    if will_crash:
        verdict = "RESULT            : ❌ WILL CRASH (OOM)"
    else:
        verdict = "RESULT            : ✅ WILL NOT CRASH"
    print(verdict)
    return not will_crash

def print_speed_info(model_gflops, hardware_gflops_s):
    """Print theoretical per-inference latency and throughput.

    Latency is ``model_gflops / hardware_gflops_s`` seconds and FPS its
    reciprocal. When the profiled cost is 0 GFLOPs (the profiler found no
    float ops), ``1 / time`` would raise ZeroDivisionError, so FPS is
    reported as ``inf`` instead.
    """
    time_sec = model_gflops / hardware_gflops_s
    fps = 1 / time_sec if time_sec else float("inf")
    print("\n=== INFERENCE SPEED ===")
    print(f"Time per inference : {time_sec:.6f} sec")
    print(f"Theoretical FPS    : {fps:.1f}")

def print_training_cost(total_train_gflops, hardware_gflops_s):
    """Print total training GFLOPs and the implied wall-clock time in minutes."""
    seconds = total_train_gflops / hardware_gflops_s
    print("\n=== TRAINING COST ===")
    print("Total training FLOPs : {:.2f} GFLOPs".format(total_train_gflops))
    print("Training time       : {:.2f} minutes".format(seconds / 60))

# ==========================================================
# MAIN
# ==========================================================
def run_feasibility_check():
    """Interactive CNN feasibility report.

    Prompts for a model path, an image-dataset directory, epochs and batch
    size. It then asks for the target device and, when the optimizer cannot
    be read from the model, the optimizer. Finally it prints the dataset,
    FLOPs, memory, speed and training-cost sections using this cell's
    helpers.
    """
    def _ask_positive_int(prompt):
        # Re-prompt (like select_option) instead of crashing with ValueError
        # on non-numeric input; zero or negative values would also make every
        # estimate meaningless.
        while True:
            reply = input(prompt).strip()
            try:
                value = int(reply)
            except ValueError:
                value = 0
            if value > 0:
                return value
            print("❌ Please enter a positive whole number.")

    print("\nCNN FEASIBILITY CHECK\n")

    model_path = input("Enter CNN model path (.keras/.h5): ")
    dataset_path = input("Enter image dataset directory: ")
    epochs = _ask_positive_int("Enter training epochs: ")
    batch_size = _ask_positive_int("Enter batch size (you plan to use): ")

    # Default compile=True, so a saved optimizer (if any) is restored for
    # detect_optimizer below.
    model = tf.keras.models.load_model(model_path)
    model.summary()

    num_classes, num_samples = validate_image_dataset(dataset_path)
    input_shape = model.input_shape[1:]

    print_dataset_info(num_classes, num_samples, input_shape)

    # FLOPs
    forward_gflops = compute_model_gflops(model, input_shape)
    print_flops_info(forward_gflops)

    # Device selection
    device_idx = select_option("Select device type:", ["GPU", "CPU"])

    if device_idx == 0:
        gpu_idx = select_option("Select GPU:", [g[0] for g in GPU_DATABASE])
        _, available_mem_mb, hw_gflops = GPU_DATABASE[gpu_idx]
    else:
        cpu_idx = select_option(
            "Select CPU RAM:",
            [f"{c[0]} GB RAM" for c in CPU_DATABASE]
        )
        _, available_mem_mb, hw_gflops = CPU_DATABASE[cpu_idx]

    # Optimizer
    optimizer = detect_optimizer(model)
    if optimizer:
        print(f"✅ Detected optimizer: {optimizer}")
    else:
        opt_idx = select_option("Select optimizer:", ["SGD", "Adam", "AdamW"])
        optimizer = ["sgd", "adam", "adamw"][opt_idx]

    # Memory check
    param_mb, act_mb, opt_mb, total_mb = memory_check(
        model, batch_size, available_mem_mb, optimizer
    )

    print_memory_info(
        batch_size, param_mb, act_mb, opt_mb, total_mb, available_mem_mb
    )

    # Speed
    print_speed_info(forward_gflops, hw_gflops)

    # Training cost
    total_train_gflops = estimate_training_gflops(
        forward_gflops, num_samples, epochs
    )
    print_training_cost(total_train_gflops, hw_gflops)

# ==========================================================
# RUN
# ==========================================================
# __name__ is also "__main__" inside a Jupyter cell, so this runs the
# interactive check whenever the cell is executed.
if __name__ == "__main__":
    run_feasibility_check()


# **Forward Pass for ResNet with memory check and inference time for the whole dataset**

In [None]:
"""
RESNET FEASIBILITY CHECK (TensorFlow 2.x)

✔ FLOPs → speed estimation
✔ Memory → crash prediction (YES / NO)
✔ User-provided batch size
✔ Optimizer auto-detection
✔ Training cost estimation
✔ Colab GPU support
"""

import tensorflow as tf
import os
from tensorflow.python.framework.convert_to_constants import (
    convert_variables_to_constants_v2
)

# ==========================================================
# HARDWARE DATABASE
# ==========================================================
# Each GPU entry: (display name, memory in MB, throughput in GFLOP/s).
# The memory value is compared against the MB estimate in crash_check,
# and the throughput divides the GFLOPs figures in the speed/training reports.
# NOTE(review): the source of these throughput figures is not recorded
# (e.g. the A100 entry equals the V100 one) — confirm before relying on them.
GPU_DATABASE = [
    # Local GPUs
    ("RTX 3060", 12 * 1024, 9_000),
    ("RTX 4090", 24 * 1024, 55_000),
    ("RTX 5090", 32 * 1024, 60_000),

    # Google Colab GPUs
    ("Colab T4",   16 * 1024, 4_000),
    ("Colab P100", 16 * 1024, 9_000),
    ("Colab V100", 16 * 1024, 14_000),
    ("Colab A100", 40 * 1024, 14_000),
]

# Each CPU entry: (RAM in GB shown in the menu, same RAM in MB,
# throughput in GFLOP/s).
CPU_DATABASE = [
    (8,  8 * 1024, 70),
    (16, 16 * 1024, 100),
    (32, 32 * 1024, 130),
]

# ==========================================================
# MENU INPUT
# ==========================================================
def select_option(title, options):
    """Print a numbered menu and return the 0-based index the user picks.

    Re-prompts until the reply is an integer between 1 and ``len(options)``.
    Surrounding whitespace is ignored. Anything ``int`` cannot parse is
    rejected rather than crashing, including Unicode digit characters such
    as "²" that ``str.isdigit`` accepts.
    """
    print(f"\n{title}")
    for i, opt in enumerate(options, 1):
        print(f"{i}. {opt}")

    while True:
        choice = input("Select option number: ").strip()
        try:
            number = int(choice)
        except ValueError:
            number = None
        if number is not None and 1 <= number <= len(options):
            return number - 1
        print("❌ Invalid selection. Try again.")

# ==========================================================
# OPTIMIZER AUTO-DETECTION
# ==========================================================
def detect_optimizer(model):
    """Guess the optimizer family from ``model.optimizer``'s class name.

    Returns:
        ``"adamw"``, ``"adam"`` or ``"sgd"`` on a case-insensitive substring
        match of the class name. AdamW is tested first because "adam" is a
        substring of "adamw". Returns ``None`` when the model has no
        optimizer (e.g. loaded uncompiled) or the class is not recognised.
    """
    # getattr with a default replaces the former bare `except:`, which also
    # swallowed unrelated errors such as KeyboardInterrupt.
    opt = getattr(model, "optimizer", None)
    if opt is None:
        return None
    name = type(opt).__name__.lower()
    for family in ("adamw", "adam", "sgd"):
        if family in name:
            return family
    return None

# ==========================================================
# DATASET VALIDATION
# ==========================================================
def validate_image_dataset(
    dataset_dir,
    image_extensions=(".bmp", ".gif", ".jpeg", ".jpg", ".png"),
):
    """Check a class-per-folder image dataset, print a summary, count images.

    Args:
        dataset_dir: Root directory containing one sub-folder per class.
        image_extensions: File suffixes, matched case-insensitively, that
            count as images. The default mirrors the formats accepted by
            Keras ``image_dataset_from_directory``.

    Returns:
        Number of image files found inside the class folders.

    Raises:
        ValueError: If *dataset_dir* is not a directory or has fewer than two
            class sub-folders.
    """
    if not os.path.isdir(dataset_dir):
        raise ValueError("❌ Dataset path is not a directory")

    # Sorted for a deterministic order (Keras also sorts class names).
    classes = sorted(
        d for d in os.listdir(dataset_dir)
        if os.path.isdir(os.path.join(dataset_dir, d))
    )

    if len(classes) < 2:
        raise ValueError("❌ Dataset must have ≥2 class folders")

    suffixes = tuple(ext.lower() for ext in image_extensions)
    # Only image files inside class folders are samples. Stray files such as
    # a root README or .DS_Store/Thumbs.db would otherwise inflate the count
    # and every FLOP/time estimate derived from it.
    num_samples = sum(
        1
        for cls in classes
        for _, _, files in os.walk(os.path.join(dataset_dir, cls))
        for name in files
        if name.lower().endswith(suffixes)
    )

    print("\n=== DATASET INFO ===")
    print(f"Classes     : {len(classes)}")
    print(f"Samples     : {num_samples}")
    print(f"Class names : {classes}")

    return num_samples

# ==========================================================
# GFLOPs (INFERENCE)
# ==========================================================
def compute_model_gflops(model, input_shape):
    """Return (and print) the GFLOPs of one inference on one input.

    The model is traced on a ``[1] + input_shape`` float32 spec with
    ``training=False``, its variables are frozen into constants, and the TF1
    profiler counts the float ops of the re-imported graph.

    Args:
        model: Keras model to profile. It is not modified.
        input_shape: Per-sample input shape (batch dimension excluded),
            e.g. ``model.input_shape[1:]``.

    Returns:
        Float ops for one sample, in units of 1e9.
    """
    # training=False alone selects inference behaviour. The model's
    # `trainable` flag is left untouched, because flipping it here would
    # silently freeze every layer of the caller's model.
    @tf.function
    def forward(x):
        return model(x, training=False)

    concrete_func = forward.get_concrete_function(
        tf.TensorSpec([1] + list(input_shape), tf.float32)
    )

    frozen_func = convert_variables_to_constants_v2(concrete_func)
    graph_def = frozen_func.graph.as_graph_def()

    # Profile in a fresh TF1-style graph containing only the frozen ops.
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name="")
        opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()
        flops = tf.compat.v1.profiler.profile(
            graph=graph, cmd="op", options=opts
        )

    gflops = flops.total_float_ops / 1e9

    print("\n=== INFERENCE FLOPs ===")
    print(f"GFLOPs per inference : {gflops:.4f}")

    return gflops

# ==========================================================
# TRAINING FLOPs (ResNet)
# ==========================================================
def estimate_training_gflops(forward_gflops, num_samples, epochs, multiplier=3.0):
    """Estimate total training GFLOPs for a ResNet.

    Training is modelled as ``multiplier`` forward-pass costs per sample per
    epoch (covering forward and backward). The default 3.0 is this cell's
    ResNet heuristic; pass another value to tune the estimate.

    Args:
        forward_gflops: GFLOPs of a single-sample forward pass.
        num_samples: Samples per epoch.
        epochs: Number of training epochs.
        multiplier: Training-to-forward cost ratio.

    Returns:
        Estimated GFLOPs for the whole training run.
    """
    return forward_gflops * num_samples * epochs * multiplier

# ==========================================================
# MEMORY (CRASH) CHECK — USER BATCH SIZE
# ==========================================================
def crash_check(model, available_mem_mb, batch_size, optimizer="adam"):
    """Estimate ResNet training memory, print the breakdown, return the verdict.

    Heuristic components, in MB:
      * float32 parameters (4 bytes each);
      * activations at 8x the parameter size per sample in the batch;
      * optimizer state: 1x parameters for "sgd", 2x for "adam"/"adamw" and,
        like the FNN/CNN checks, 2x for any other name (previously a KeyError);
      * a fixed 800 MB framework overhead.

    Returns:
        ``True`` if the estimate fits in ``available_mem_mb``, else ``False``.
    """
    param_mb = model.count_params() * 4 / (1024 ** 2)

    # ResNet activation cost
    activation_mb = param_mb * 8 * batch_size

    optimizer_multiplier = {
        "sgd": 1.0,
        "adam": 2.0,
        "adamw": 2.0
    }.get(optimizer, 2.0)

    optimizer_mb = param_mb * optimizer_multiplier

    total = param_mb + activation_mb + optimizer_mb + 800

    print("\n=== MEMORY CHECK ===")
    print(f"Batch size         : {batch_size}")
    print(f"Parameters        : {param_mb:.2f} MB")
    print(f"Activations       : {activation_mb:.2f} MB")
    print(f"Optimizer state   : {optimizer_mb:.2f} MB")
    print(f"Estimated TOTAL   : {total:.2f} MB")
    print(f"Available memory : {available_mem_mb} MB")

    if total > available_mem_mb:
        print("❌ RESULT: MODEL WILL CRASH (OOM)")
        return False
    print("✅ RESULT: MODEL WILL NOT CRASH")
    return True

# ==========================================================
# SPEED (INFERENCE)
# ==========================================================
def speed_check(model_gflops, hardware_gflops_s):
    """Print theoretical per-inference latency and throughput.

    Latency is ``model_gflops / hardware_gflops_s`` seconds and FPS its
    reciprocal. When the profiled cost is 0 GFLOPs (the profiler found no
    float ops), ``1 / time`` would raise ZeroDivisionError, so FPS is
    reported as ``inf`` instead.
    """
    time_sec = model_gflops / hardware_gflops_s
    fps = 1 / time_sec if time_sec else float("inf")

    print("\n=== INFERENCE SPEED ===")
    print(f"Time per inference : {time_sec:.8f} sec")
    print(f"Theoretical FPS    : {fps:.1f}")

# ==========================================================
# MAIN
# ==========================================================
def run_feasibility_check():
    """Interactive ResNet feasibility report.

    Prompts for a model path, an image-dataset directory, epochs and batch
    size. It then asks for the target device and, when the optimizer cannot
    be read from the model, the optimizer. Finally it prints the dataset,
    FLOPs, memory, speed and training-cost sections using this cell's
    helpers.
    """
    def _ask_positive_int(prompt):
        # Re-prompt (like select_option) instead of crashing with ValueError
        # on non-numeric input; zero or negative values would also make every
        # estimate meaningless.
        while True:
            reply = input(prompt).strip()
            try:
                value = int(reply)
            except ValueError:
                value = 0
            if value > 0:
                return value
            print("❌ Please enter a positive whole number.")

    print("\nRESNET FEASIBILITY CHECK\n")

    model_path = input("Enter ResNet model path (.keras/.h5): ")
    dataset_path = input("Enter image dataset directory: ")
    epochs = _ask_positive_int("Enter training epochs: ")
    batch_size = _ask_positive_int("Enter batch size (you plan to use): ")

    # Default compile=True, so a saved optimizer (if any) is restored for
    # detect_optimizer below.
    model = tf.keras.models.load_model(model_path)
    model.summary()

    num_samples = validate_image_dataset(dataset_path)

    input_shape = model.input_shape[1:]
    print(f"✅ Input shape: {input_shape}")

    # FLOPs
    forward_gflops = compute_model_gflops(model, input_shape)

    # Hardware
    device_idx = select_option("Select device type:", ["GPU", "CPU"])

    if device_idx == 0:
        gpu_idx = select_option("Select GPU:", [g[0] for g in GPU_DATABASE])
        _, mem_mb, gflops_s = GPU_DATABASE[gpu_idx]
    else:
        cpu_idx = select_option(
            "Select CPU RAM:",
            [f"{c[0]} GB RAM" for c in CPU_DATABASE]
        )
        _, mem_mb, gflops_s = CPU_DATABASE[cpu_idx]

    # Optimizer
    optimizer = detect_optimizer(model)
    if optimizer:
        print(f"✅ Detected optimizer: {optimizer}")
    else:
        opt_idx = select_option("Select optimizer:", ["SGD", "Adam", "AdamW"])
        optimizer = ["sgd", "adam", "adamw"][opt_idx]

    # Crash check
    crash_check(model, mem_mb, batch_size, optimizer)

    # Speed
    speed_check(forward_gflops, gflops_s)

    # Training cost
    total_train_gflops = estimate_training_gflops(
        forward_gflops, num_samples, epochs
    )
    train_time_sec = total_train_gflops / gflops_s

    print("\n=== TRAINING COST ===")
    print(f"Total training FLOPs : {total_train_gflops:.2f} GFLOPs")
    print(f"Training time       : {train_time_sec/60:.2f} minutes")

# ==========================================================
# RUN
# ==========================================================
# __name__ is also "__main__" inside a Jupyter cell, so this runs the
# interactive check whenever the cell is executed.
if __name__ == "__main__":
    run_feasibility_check()
