<a href="https://colab.research.google.com/github/epikadith/AI-model-creator-demo/blob/main/arch.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [2]:
from google.colab import userdata
# Gemini API key pulled from the Colab Secrets store; passed to main() at the bottom of the notebook.
key = userdata.get("GOOGLE_API_KEY")

In [3]:
import json
import os
import random
import time
from google import genai
import tensorflow as tf
from google.genai.errors import APIError

def read_code_file(file_path: str) -> str:
    """Return the text content of *file_path*, or ``""`` if the file does not exist.

    The file is decoded as UTF-8 so that it round-trips with files written as
    UTF-8, independent of the platform's locale encoding.
    """
    if not os.path.exists(file_path):
        return ""
    try:
        with open(file_path, "r", encoding="utf-8") as f:
            return f.read()
    except FileNotFoundError:
        # The file vanished between the existence check and open(); treat it
        # the same as "never existed" instead of raising.
        return ""

def write_code_file(file_path: str, content: str):
    """Write *content* to *file_path* as UTF-8 text, replacing any existing file.

    The encoding is explicit because *content* is LLM-generated and may contain
    non-ASCII characters (e.g. in comments) that a non-UTF-8 locale default
    encoding could not represent.
    """
    with open(file_path, "w", encoding="utf-8") as f:
        f.write(content)

def get_dataset():
    """Load MNIST with pixel values scaled to [0, 1] and a trailing channel axis.

    Returns:
        ``(x_train, y_train, x_test, y_test)``; the image arrays gain a new
        last axis via ``tf.newaxis`` and labels are returned as loaded.
    """
    train_split, test_split = tf.keras.datasets.mnist.load_data()
    train_images, train_labels = train_split
    test_images, test_labels = test_split

    # Scale to [0, 1], then append a channel dimension for Conv2D-style inputs.
    train_images = (train_images / 255.0)[..., tf.newaxis]
    test_images = (test_images / 255.0)[..., tf.newaxis]
    return train_images, train_labels, test_images, test_labels

def _extract_model_source(code_content: str) -> str:
    """Return the Python source embedded in an LLM reply.

    Prefers the first fenced Markdown block (```python, ```py or a bare ```)
    that mentions ``MyModel``, then the first fenced block of any kind, and
    finally the whole reply with a single surrounding fence stripped. Prose
    the LLM may add before or after the fence is discarded.
    """
    import re

    blocks = re.findall(r"```[^\n`]*\n(.*?)```", code_content, re.DOTALL)
    if blocks:
        chosen = next((block for block in blocks if "MyModel" in block), blocks[0])
        return chosen.strip()

    cleaned = code_content.strip()
    for fence in ("```python", "```"):
        if cleaned.startswith(fence) and cleaned.endswith("```"):
            return cleaned[len(fence):-len("```")].strip()
    return cleaned


def _training_failure(stage: str, error: Exception, code_content: str) -> dict:
    """Print a diagnostic for *stage* and build the sentinel metrics dict of a failed run."""
    print(f"Error {stage}: {error}")
    return {
        "loss": 99.0,
        "accuracy": 0.0,
        # Include the exception type: e.g. str(KeyError('x')) alone is just "'x'".
        "error": f"{type(error).__name__}: {error}",
        "code_generated": code_content,
    }


def run_real_training(code_content: str, model_name: str) -> dict:
    """Execute LLM-generated model code, train it for one epoch on MNIST and evaluate it.

    Args:
        code_content: Raw LLM reply expected to define a class ``MyModel``,
            optionally inside a Markdown code fence surrounded by prose.
        model_name: Label used only in progress messages.

    Returns:
        On success ``{"loss", "accuracy", "params", "code_generated"}`` with
        plain Python numbers. On any failure (bad code, missing class,
        instantiation, compile, fit or evaluate) the sentinel
        ``{"loss": 99.0, "accuracy": 0.0, "error": ..., "code_generated": ...}``
        so the caller's loop keeps running and the error can be analysed.
    """
    x_train, y_train, x_test, y_test = get_dataset()

    # SECURITY: exec() runs LLM-generated code in this process with its full
    # privileges. Only run this in a disposable environment.
    exec_globals = {}
    try:
        exec(_extract_model_source(code_content), exec_globals)
        model_cls = exec_globals.get("MyModel")
        if model_cls is None:
            raise NameError("Generated code does not define a class named 'MyModel'.")
        model = model_cls()
    except Exception as e:
        return _training_failure("instantiating model", e, code_content)

    # Guarded separately: a generated class that is not a Keras model would
    # otherwise raise here and abort the caller's whole search loop.
    try:
        model.compile(
            optimizer='adam',
            loss='sparse_categorical_crossentropy',
            metrics=['accuracy']
        )
    except Exception as e:
        return _training_failure("compiling model", e, code_content)

    print(f"🛠️ Training model: {model_name}...")
    try:
        model.fit(x_train, y_train, epochs=1, verbose=2)
        print(f"📊 Evaluating model: {model_name}...")
        results = model.evaluate(x_test, y_test, verbose=2)
        # float()/int() guarantee plain Python numbers so the metrics dict is
        # always JSON-serialisable, whatever scalar type the backend returns.
        return {
            "loss": round(float(results[0]), 4),
            "accuracy": round(float(results[1]), 4),
            "params": int(model.count_params()),
            "code_generated": code_content
        }
    except Exception as e:
        return _training_failure("during training or evaluation", e, code_content)


def get_llm_response(client: genai.Client, prompt: str, model_name: str) -> str:
    """Send *prompt* to *model_name* through *client* and return the reply text.

    The result is always a ``str``. An ``APIError``, any other exception, and
    a reply with no text all produce the same
    ``"Error: Could not generate a response from the LLM."`` sentinel. The
    empty case covers ``response.text`` being ``None`` or ``""``. Callers such
    as ``write_code_file`` therefore never receive ``None``.
    """
    failure = "Error: Could not generate a response from the LLM."
    print(f"Calling LLM with prompt: {prompt[:100]}...")
    try:
        response = client.models.generate_content(model=model_name, contents=prompt)
        text = response.text
    except APIError as e:
        print(f"An API error occurred: {e}")
        return failure
    except Exception as e:
        print(f"An unexpected error occurred: {e}")
        return failure

    # NOTE(review): the google-genai SDK may return None for .text when the
    # reply has no text part (e.g. a blocked prompt) — confirm against SDK docs.
    if not text:
        print("The LLM returned an empty response.")
        return failure
    return text

class Researcher:
    """Agent that asks an LLM to propose the next ``MyModel`` architecture."""

    # Seed architecture offered to the LLM when no previous code exists yet.
    _INITIAL_CODE = """
import tensorflow as tf

class MyModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.conv1 = tf.keras.layers.Conv2D(32, (3, 3), activation='relu', input_shape=(28, 28, 1))
        self.pool1 = tf.keras.layers.MaxPooling2D((2, 2))
        self.flatten = tf.keras.layers.Flatten()

        self.dense1 = tf.keras.layers.Dense(64, activation='relu')
        self.dense2 = tf.keras.layers.Dense(10, activation='softmax')

    def call(self, inputs):
        x = self.conv1(inputs)
        x = self.pool1(x)
        x = self.flatten(x)
        x = self.dense1(x)
        return self.dense2(x)
"""

    def __init__(self, client: genai.Client):
        self.client = client
        # Prompt template; {current_code} and {analysis} are filled by str.format.
        self.base_prompt = """
        You are an AI architect. Your task is to design a Python class for a neural network model.
        The model should be a simple classifier using Keras/TensorFlow.

        IMPORTANT: You MUST output ONLY the Python code for the class, enclosed in triple backticks (```python).
        The class name must be 'MyModel'. Ensure the code is syntactically correct and runnable.

        Current model code: {current_code}
        Analysis of past experiments: {analysis}

        Based on the analysis (especially any error messages), design a new or corrected architecture.
        If there was an error, your primary goal is to fix that error in the new code.

        Example model structure:
        ```python
        import tensorflow as tf

        class MyModel(tf.keras.Model):
            def __init__(self):
                super().__init__()
                self.flatten = tf.keras.layers.Flatten()
                self.dense1 = tf.keras.layers.Dense(128, activation='relu')
                self.dense2 = tf.keras.layers.Dense(10, activation='softmax')

            def call(self, inputs):
                x = self.flatten(inputs)
                x = self.dense1(x)
                return self.dense2(x)
        ```
        """

    def propose_new_architecture(self, context: dict) -> str:
        """Ask the LLM (gemini-2.5-pro) for the next ``MyModel`` implementation.

        Args:
            context: May hold ``"analysis"`` (the last analysis report) and
                ``"parent_code"`` (the previously proposed code). Missing keys
                fall back to a placeholder message and ``get_initial_code()``.

        Returns:
            Whatever ``get_llm_response`` returns: the reply text or its
            error string.
        """
        prompt_fields = {
            "analysis": context.get("analysis", "No previous analysis available."),
            "current_code": context.get("parent_code", self.get_initial_code()),
        }
        prompt = self.base_prompt.format(**prompt_fields)
        return get_llm_response(self.client, prompt, model_name="gemini-2.5-pro")

    def get_initial_code(self) -> str:
        """Return the seed ``MyModel`` source used before any proposal exists."""
        return self._INITIAL_CODE

class Engineer:
    """Agent that trains and scores a proposed architecture.

    The client passed in is stored but not used by this class.
    """

    def __init__(self, client: genai.Client):
        self.client = client

    def train_and_evaluate(self, code_content: str, model_name: str) -> dict:
        """Run ``run_real_training`` on *code_content* and return its metrics dict as-is."""
        metrics = run_real_training(code_content, model_name)
        return metrics

class Analyst:
    """Agent that asks an LLM to critique a trained model and suggest next changes."""

    def __init__(self, client: genai.Client):
        self.client = client
        # Prompt template; {code_content} and {metrics} are filled by str.format.
        self.base_prompt = """
        You are a seasoned AI researcher. Analyze the following model's performance and code.

        Model code: {code_content}
        Performance metrics: {metrics}

        IMPORTANT: If the metrics contain an 'error' key, your primary task is to analyze that error.
        Provide concrete, line-by-line suggestions for fixing the code, clearly pointing out the problematic parts.
        Your analysis should be directly actionable for the Researcher to generate corrected code.

        If there's no error, provide an insightful summary of the model's strengths and weaknesses,
        and suggest concrete architectural changes for the next iteration to improve performance (e.g., add more layers, change activation, etc.).
        """

    def analyze_results(self, code_content: str, metrics: dict) -> str:
        """Ask the LLM (gemini-2.5-flash) to analyse *code_content* given its *metrics*.

        Args:
            code_content: The model source that was trained.
            metrics: The metrics dict for that run (may contain an ``"error"`` key).

        Returns:
            Whatever ``get_llm_response`` returns: the analysis text or its
            error string.
        """
        # The code is already in the prompt as {code_content}; drop the
        # "code_generated" copy so it is not sent to the LLM twice.
        metrics_for_prompt = {k: v for k, v in metrics.items() if k != "code_generated"}
        # default=str keeps the prompt buildable if a value is not JSON-native
        # (e.g. a NumPy scalar); the dump is only displayed to the LLM.
        prompt = self.base_prompt.format(
            code_content=code_content,
            metrics=json.dumps(metrics_for_prompt, indent=2, default=str),
        )
        return get_llm_response(self.client, prompt, model_name="gemini-2.5-flash")

def main(api_key: str, num_iterations: int = 3) -> list:
    """Run the Researcher -> Engineer -> Analyst loop and return the experiment log.

    Each round the Researcher proposes code. The raw reply is saved to
    ``model_iter_<n>.py``. The Engineer trains it and the Analyst reviews the
    metrics. That round's analysis and code become the Researcher's context
    for the next round.

    Args:
        api_key: Gemini API key used to build the shared ``genai.Client``.
        num_iterations: Number of propose/train/analyse rounds (default 3).

    Returns:
        One dict per round with keys ``id``, ``name``, ``code``, ``metrics``
        and ``analysis``.
    """
    client = genai.Client(api_key=api_key)

    researcher = Researcher(client)
    engineer = Engineer(client)
    analyst = Analyst(client)

    database = []
    current_context = {}

    for iteration in range(1, num_iterations + 1):
        print(f"\n--- 🚀 Starting Iteration {iteration} ---")

        print("💡 Researcher is proposing a new architecture...")
        new_code = researcher.propose_new_architecture(current_context)
        model_name = f"model_iter_{iteration}"
        file_path = f"{model_name}.py"

        # NOTE: this saves the raw LLM reply, which may still be wrapped in
        # Markdown fences, so the .py file may not be importable as-is.
        write_code_file(file_path, new_code)

        print("⚙️ Engineer is training the model...")
        performance_metrics = engineer.train_and_evaluate(new_code, model_name)

        print("📊 Analyst is evaluating the results...")
        analysis_report = analyst.analyze_results(new_code, performance_metrics)

        database.append({
            "id": iteration,
            "name": model_name,
            "code": new_code,
            "metrics": performance_metrics,
            "analysis": analysis_report,
        })

        # Feed this round's outcome back to the Researcher for the next proposal.
        current_context["analysis"] = analysis_report
        current_context["parent_code"] = new_code

        print(f"\n✅ Iteration {iteration} complete. Metrics: {performance_metrics}")

    return database

if __name__ == "__main__":
    # `key` is the Gemini API key loaded via userdata.get('GOOGLE_API_KEY') in the earlier cell.
    main(key)


--- 🚀 Starting Iteration 1 ---
💡 Researcher is proposing a new architecture...
Calling LLM with prompt: 
        You are an AI architect. Your task is to design a Python class for a neural network model.
...
⚙️ Engineer is training the model...
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
[1m11490434/11490434[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step
🛠️ Training model: model_iter_1...
1875/1875 - 26s - 14ms/step - accuracy: 0.9508 - loss: 0.1656
📊 Evaluating model: model_iter_1...
313/313 - 2s - 6ms/step - accuracy: 0.9809 - loss: 0.0615
📊 Analyst is evaluating the results...
Calling LLM with prompt: 
        You are a seasoned AI researcher. Analyze the following model's performance and code.
     ...

✅ Iteration 1 complete. Metrics: {'loss': 0.0615, 'accuracy': 0.9809, 'params': 347146, 'code_generated': "```python\nimport tensorflow as tf\n\nclass MyModel(tf.keras.Model):\n    def __init__(self):\n        super()