# In [None]:
import gc
import os
import sys
from collections import deque
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, as_completed, wait
from pathlib import Path

import numpy as np
import torch
import torch.utils.checkpoint as checkpoint
from flask import Flask, jsonify, request
from hydra import compose, initialize
from omegaconf import OmegaConf

from yolo.config.config import Config
from yolo.model.yolo import YOLO, create_model

# Initialize Flask app
app = Flask(__name__)

# Hydra config location/name shared by every `initialize(...)`/`compose(...)`
# call in this file.
config_path = "../yolo/config"  # Make sure this is relative
config_name = "config"

# NOTE(review): model_cache is not read or written anywhere in the visible
# code -- confirm whether it is still needed.
model_cache = {}

# Model pool served by get_models(). It is empty at import time and is only
# rebound (to prepare_models(...)) by the `if __name__ == "__main__":` block.
models = []

# Project root setup: the directory two levels above this file.
project_root = Path(__file__).resolve().parent.parent
# NOTE(review): this append runs after the `yolo.*` imports at the top of the
# file, so it cannot help those imports resolve -- confirm whether it is needed.
sys.path.append(str(project_root))
print(project_root)

# ----------- TEST FUNCTIONS -----------

def test_build_model_v9c():
    """Smoke test: the un-overridden Hydra config builds a YOLO model with 39 layers."""
    with initialize(config_path=config_path, version_base=None):
        composed: Config = compose(config_name=config_name)
        # Struct mode off so keys missing from the schema may be added to
        # composed.model (presumably needed by YOLO -- confirm).
        OmegaConf.set_struct(composed.model, False)
        composed.weight = None
        built = YOLO(composed.model)
        layer_count = len(built.model)
        assert layer_count == 39
    print("✅ test_build_model_v9c passed")


def test_build_model_v9m():
    """Smoke test: the ``model=v9-m`` override builds a YOLO model with 39 layers."""
    variant_overrides = ["model=v9-m"]
    with initialize(config_path=config_path, version_base=None):
        composed: Config = compose(config_name=config_name, overrides=variant_overrides)
        # Struct mode off so keys missing from the schema may be added to
        # composed.model (presumably needed by YOLO -- confirm).
        OmegaConf.set_struct(composed.model, False)
        composed.weight = None
        built = YOLO(composed.model)
        layer_count = len(built.model)
        assert layer_count == 39
    print("✅ test_build_model_v9m passed")


def test_build_model_v7():
    """Smoke test: the ``model=v7`` override builds a YOLO model with 106 layers."""
    variant_overrides = ["model=v7"]
    with initialize(config_path=config_path, version_base=None):
        composed: Config = compose(config_name=config_name, overrides=variant_overrides)
        # Struct mode off so keys missing from the schema may be added to
        # composed.model (presumably needed by YOLO -- confirm).
        OmegaConf.set_struct(composed.model, False)
        composed.weight = None
        built = YOLO(composed.model)
        layer_count = len(built.model)
        assert layer_count == 106
    print("✅ test_build_model_v7 passed")


def get_cfg() -> Config:
    """Compose the default Hydra config and clear its ``weight`` entry.

    Returns:
        The composed config, with ``weight`` set to ``None``.
    """
    with initialize(config_path=config_path, version_base=None):
        composed: Config = compose(config_name=config_name)
        composed.weight = None
    return composed


def get_model(cfg: Config):
    """Build a model from ``cfg.model`` and move it to CUDA when available, else CPU.

    ``weight_path=None`` is passed to ``create_model`` (presumably meaning no
    checkpoint is loaded -- confirm).
    """
    use_cuda = torch.cuda.is_available()
    target_device = torch.device("cuda" if use_cuda else "cpu")
    built = create_model(cfg.model, weight_path=None)
    return built.to(target_device)


def test_model_basic_status(model):
    """Check that *model* is a ``YOLO`` instance whose ``model`` has 39 entries."""
    assert isinstance(model, YOLO)
    layer_total = len(model.model)
    assert layer_total == 39
    print("✅ test_model_basic_status passed")


def test_yolo_forward_output_shape(model):
    """Check the per-level shapes of ``model``'s "Main" output for a (2, 3, 640, 640) batch.

    Expects three levels with spatial sizes 80, 40 and 20, each a tuple
    ``(cls, anc, box)`` whose channel dims are 80 (presumably the class count),
    16x4 and 4 respectively.

    Args:
        model: A module with ``.parameters()`` whose forward returns a dict
            containing a "Main" list of 3-tuples of tensors.
    """
    # Build the dummy batch on the device holding the model's weights so the
    # check also works for a model placed off the default device.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    dummy_input = torch.rand(2, 3, 640, 640, device=device)

    # Only shapes are inspected, so skip recording the autograd graph.
    with torch.no_grad():
        output = model(dummy_input)
    output_shape = [(cls.shape, anc.shape, box.shape) for cls, anc, box in output["Main"]]
    assert output_shape == [
        (torch.Size([2, 80, 80, 80]), torch.Size([2, 16, 4, 80, 80]), torch.Size([2, 4, 80, 80])),
        (torch.Size([2, 80, 40, 40]), torch.Size([2, 16, 4, 40, 40]), torch.Size([2, 4, 40, 40])),
        (torch.Size([2, 80, 20, 20]), torch.Size([2, 16, 4, 20, 20]), torch.Size([2, 4, 20, 20])),
    ]
    print("✅ test_yolo_forward_output_shape passed")

def load_model(model_type: str, config: Config):
    """Build a YOLO model for one of the supported variants.

    Args:
        model_type: One of "v9c", "v9m" or "v7".
        config: Composed Hydra config. Only used for "v9c", whose
            ``config.model`` is passed to ``YOLO`` as-is. It is never mutated.

    Returns:
        A freshly constructed ``YOLO`` model.

    Raises:
        ValueError: If ``model_type`` is not a supported variant.
    """
    if model_type == "v9c":
        return YOLO(config.model)  # Load model v9c

    # Hydra `model=` override names for the other variants (the same names the
    # test_build_model_* helpers use).
    override_names = {"v9m": "v9-m", "v7": "v7"}
    if model_type not in override_names:
        raise ValueError("Unknown model type!")

    # YOLO is built from a composed model config section, not from the override
    # name, so compose a fresh config for the variant instead of overwriting
    # the caller's config.model with a plain string.
    with initialize(config_path=config_path, version_base=None):
        variant_cfg: Config = compose(
            config_name=config_name,
            overrides=[f"model={override_names[model_type]}"],
        )
        # Mirrors the test helpers, which disable struct mode before YOLO(...).
        OmegaConf.set_struct(variant_cfg.model, False)
        return YOLO(variant_cfg.model)

import torch

def unload_model(model):
    """Drop this function's reference to *model* and release cached CUDA memory.

    ``del model`` only removes the local name. The model's tensors are freed
    only once every other holder (including the caller) has dropped its
    references too.

    Args:
        model: The model to release. May be any object.
    """
    del model
    # Reclaim unreachable objects held only by reference cycles, so their
    # tensors are freed before the CUDA caching allocator returns its blocks.
    gc.collect()
    torch.cuda.empty_cache()

def save_model_to_disk(model, path):
    """Write ``model.state_dict()`` to *path* via ``torch.save``.

    Args:
        model: Object exposing ``state_dict()`` (e.g. an ``nn.Module``).
        path: Destination file path.
    """
    weights = model.state_dict()
    torch.save(weights, path)

def load_model_from_disk(model, path):
    """Load a state dict saved with ``save_model_to_disk`` into *model*.

    Args:
        model: Module whose parameters are overwritten in place.
        path: File written by ``torch.save(model.state_dict(), ...)``.

    Returns:
        The same *model* object, for chaining.
    """
    # map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only
    # host; load_state_dict then copies each tensor onto the device of the
    # matching parameter.
    # NOTE(security): torch.load unpickles the file -- only load trusted files.
    state_dict = torch.load(path, map_location="cpu")
    model.load_state_dict(state_dict)
    return model


import torch.utils.checkpoint as checkpoint

def checkpoint_forward(model, input_data, use_reentrant: bool = False):
    """Run ``model(input_data)`` under activation checkpointing.

    Activations are recomputed during backward instead of being stored.

    Args:
        model: Callable module to run.
        input_data: Its input.
        use_reentrant: Passed through to ``torch.utils.checkpoint.checkpoint``.
            Defaults to ``False``. The reentrant variant gives parameters no
            gradient when no input requires grad, and does not track tensors
            nested in dict/list outputs such as this model's ``{"Main": ...}``.

    Returns:
        Whatever ``model(input_data)`` returns.
    """
    return checkpoint.checkpoint(model, input_data, use_reentrant=use_reentrant)


def run_inference(model, input_data):
    """
    Run one forward pass of *model* with autograd disabled.

    Args:
        model: Callable returning a dict with a "Main" entry, an iterable of
            iterables of tensors.
        input_data: Input passed straight to ``model``.

    Returns:
        ``output["Main"]`` converted to plain nested Python lists via
        ``Tensor.tolist()``, one inner list per "Main" element.
    """
    with torch.no_grad():
        predictions = model(input_data)

    converted = []
    for level in predictions["Main"]:
        converted.append([tensor.tolist() for tensor in level])
    return converted

def balance_workload(models, input_data_list, infer_fn=None):
    """
    Run every input through the models, giving each model one task at a time.

    The inputs form a task queue. Whenever a model finishes, it is handed the
    next pending input. Each *distinct* model object runs at most one task at a
    time. Models are used as dict keys, so a list that repeats the same model
    several times still executes on that model serially.

    Args:
        models: Non-empty list of models. Repeated references count once.
        input_data_list: Sequence of inputs supporting ``len()`` and integer
            indexing (e.g. a list of tensors). Each element is one task.
        infer_fn: Callable ``(model, input) -> output`` run per task. Defaults
            to ``run_inference``.

    Returns:
        Outputs in the same order as ``input_data_list``. Tasks whose inference
        raised are reported on stdout and omitted.

    Raises:
        ValueError: If ``models`` is empty.
    """
    if infer_fn is None:
        infer_fn = run_inference

    # Unique models in first-seen order; these are the worker slots.
    idle_models = deque(dict.fromkeys(models))
    if not idle_models:
        raise ValueError("balance_workload requires at least one model")

    total = len(input_data_list)
    results = {}     # task index -> output
    in_flight = {}   # future -> (model, task index)
    next_task = 0

    with ThreadPoolExecutor(max_workers=len(idle_models)) as executor:
        while next_task < total or in_flight:
            # Give the next pending inputs to every idle model.
            while idle_models and next_task < total:
                model = idle_models.popleft()
                future = executor.submit(infer_fn, model, input_data_list[next_task])
                in_flight[future] = (model, next_task)
                next_task += 1

            # Block until at least one task finishes, then recycle its model.
            # in_flight is non-empty here: if nothing was running, every model
            # was idle and at least one task was just submitted.
            done, _ = wait(list(in_flight), return_when=FIRST_COMPLETED)
            for future in done:
                model, task_idx = in_flight.pop(future)
                idle_models.append(model)
                try:
                    results[task_idx] = future.result()
                except Exception as e:
                    print(f"Model {model} generated an exception: {e}")

            # Free cached GPU memory left behind by finished tasks.
            torch.cuda.empty_cache()

    # Free any remaining memory after all tasks are completed
    torch.cuda.empty_cache()

    return [results[idx] for idx in sorted(results)]


def prepare_models_and_inputs(device, num_models=8, num_inputs=2):
    """
    Prepare a model pool and random dummy inputs for exercising inference.

    Args:
        device: torch device for both the model and the inputs.
        num_models: Length of the model list. Every entry is the same model
            instance; see ``prepare_models``.
        num_inputs: Number of random ``(2, 3, 640, 640)`` input batches.

    Returns:
        A tuple ``(models, input_data_list)``.
    """
    model_pool = prepare_models(device, num_models=num_models)
    input_data_list = [
        torch.rand(2, 3, 640, 640, device=device) for _ in range(num_inputs)
    ]
    return model_pool, input_data_list


def prepare_models(device, num_models=8):
    """
    Build one model from the default config and return a pool of references to it.

    Args:
        device: torch device the model is moved to.
        num_models: Length of the returned list.

    Returns:
        A list of ``num_models`` entries. All entries are the *same* model
        instance, not independent copies, so code that keys on model identity
        sees a single model.
    """
    cfg = get_cfg()

    # get_model() already places the model on CUDA-if-available; .to(device)
    # then moves it to the requested device (a no-op when they match).
    model = get_model(cfg).to(device)

    return [model] * num_models


def get_models():
    """Return the module-level ``models`` list.

    The list is empty unless the ``__main__`` block has rebound it.
    """
    current_pool = models
    return current_pool


@app.route('/infer', methods=['POST'])
def infer():
    """
    Endpoint to accept input for various models and run inference.

    Expected JSON body: ``{"inputs": [array, ...]}``. Each array is a nested
    numeric list of rank 3 ``(C, H, W)``, which gets a leading batch dimension,
    or rank 4 ``(N, C, H, W)``. Every array becomes one task for
    ``balance_workload``. Channel count and spatial size are not validated
    here (the model presumably expects 3 channels -- confirm).

    Returns:
        200 with ``{"outputs": [{"model": i, "output": ...}, ...]}``, where
        ``i`` is the position in the returned result list, not a model index.
        400 for a missing, empty or malformed payload.
        503 when no models have been loaded.
        500 for any other failure.
    """
    try:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # silent=True: a non-JSON body yields None (-> 400) instead of raising.
        data = request.get_json(silent=True)
        if not isinstance(data, dict) or 'inputs' not in data:
            return jsonify({"error": "No input data provided"}), 400

        models = get_models()
        if not models:
            # get_models() serves an empty list unless the __main__ block ran.
            return jsonify({"error": "No models loaded"}), 503

        # Convert each item independently, so a request may carry several
        # inputs; each becomes one 4-D batch.
        input_data_list = []
        try:
            for item in data['inputs']:
                tensor = torch.tensor(item, device=device).float()
                if tensor.dim() == 3:
                    tensor = tensor.unsqueeze(0)  # (C, H, W) -> (1, C, H, W)
                if tensor.dim() != 4:
                    return jsonify({
                        "error": f"Each input must be 3-D or 4-D, got shape {list(tensor.shape)}"
                    }), 400
                input_data_list.append(tensor)
        except (TypeError, ValueError, RuntimeError) as e:
            return jsonify({"error": f"Invalid input data: {e}"}), 400

        if not input_data_list:
            return jsonify({"error": "No input data provided"}), 400

        # Balance workload and run inference in parallel
        outputs = balance_workload(models, input_data_list)

        # Format output for response
        output_data = [{"model": i, "output": output} for i, output in enumerate(outputs)]
        return jsonify({"outputs": output_data})

    except Exception as e:
        return jsonify({"error": str(e)}), 500

@app.route('/')
def home():
    """Root route: return a fixed plain-text greeting."""
    greeting = "Welcome to the model inference API!"
    return greeting

if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Module-level assignment: rebinds the `models` list served by get_models().
    models = prepare_models(device)
    print('MODEL')

    # The server listens one port above $PORT (default 5000 -> 5001). Print
    # the port actually bound, not the pre-increment value.
    port = int(os.getenv("PORT", 5000)) + 1
    print(port)

    # Debug mode stays on by default; set FLASK_DEBUG=0 to turn it off. The
    # interactive debugger can execute arbitrary code -- never expose it publicly.
    debug = os.getenv("FLASK_DEBUG", "1") != "0"
    # The reloader re-runs this script in a child process, which would load a
    # second copy of every model, so keep it off.
    app.run(debug=debug, port=port, use_reloader=False)
