#### ERD Agent

In [3]:
#pip install openai python-dotenv

In [5]:
import os
from dotenv import find_dotenv
from dotenv import load_dotenv
from openai import OpenAI

In [6]:
# Load OPENAI_API_KEY from key.env into the process environment, then build
# the OpenAI client that the rest of the notebook uses.
env_file = find_dotenv(filename='key.env')
load_dotenv(dotenv_path=env_file)
api_key = os.environ.get('OPENAI_API_KEY')
client = OpenAI(api_key=api_key)

In [9]:
pip install graphviz

Note: you may need to restart the kernel to use updated packages.


In [10]:
pip install eralchemy[graphviz]

zsh:1: no matches found: eralchemy[graphviz]
Note: you may need to restart the kernel to use updated packages.


In [11]:
import tempfile
from eralchemy import render_er 

In [15]:
# Generate SQLAlchemy model source code from the user's database description.
def _strip_code_fence(text):
    """Return ``text`` without a surrounding Markdown code fence, if it has one."""
    stripped = text.strip()
    if not stripped.startswith("```"):
        return text
    # Drop the opening fence line (which may carry a language tag such as
    # "```python") and the closing fence line when present.
    lines = stripped.splitlines()[1:]
    if lines and lines[-1].strip() == "```":
        lines = lines[:-1]
    return "\n".join(lines)


def build_sqlalchemy_model(user_description):
    """Ask the chat model to write SQLAlchemy models for a described database.

    Args:
        user_description: Free-text description of the database to model.

    Returns:
        The reply text as a string, with any surrounding Markdown code fence
        removed. An empty string is returned when the reply has no content.
    """
    # Complete sentences here to ensure the LLM outputs a complete model.
    # The sentences are joined with an explicit space: adjacent string
    # literals concatenate with nothing in between, which glued the
    # sentences together ("models.Given ...").
    system_prompt = " ".join([
        "You are an expert at building SQLAlchemy models.",
        "Given a database description, write SQLAlchemy models using declarative_base.",
        "Only output the models. Do not include explanations, code fences, or comments.",
    ])

    user_input = f"Create SQLAlchemy models for this database: {user_description}"

    response = client.chat.completions.create(
        model="gpt-4o",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_input},
        ],
        max_tokens=800,
        temperature=0.2,
    )

    # The reply is later written to a .py file and imported, so a stray code
    # fence (despite the instructions) would make that file invalid Python.
    generated = response.choices[0].message.content or ""
    return _strip_code_fence(generated)

In [17]:
# Save the generated model source to a temporary .py file.
def save_model(erd_model):
    """Write ``erd_model`` to a new temporary ``.py`` file and return its path.

    The file is created with ``delete=False``, so it stays on disk after this
    function returns; removing it is the caller's responsibility.

    Args:
        erd_model: Source text to write (UTF-8 encoded).

    Returns:
        The path of the temporary file as a string.
    """
    # The context manager closes the handle even if write() raises.
    with tempfile.NamedTemporaryFile(
        delete=False, suffix='.py', mode='w', encoding='utf-8'
    ) as temp:
        temp.write(erd_model)
    return temp.name

In [78]:
# Pick a numbered ERD image filename that does not overwrite an earlier diagram.
import datetime

def unique_name(description, ext="png"):
    """Return the first unused ``erd_model_<n>.<ext>`` name in the working directory.

    Args:
        description: Not used when building the name; kept so existing callers
            keep working.
        ext: File extension without the leading dot.

    Returns:
        A filename of the form ``erd_model_<n>.<ext>`` (n from 1 to 9999) for
        which no file currently exists.

    Raises:
        FileExistsError: If every candidate from 1 to 9999 is already taken.
    """
    for i in range(1, 10000):
        erd_filename = f"erd_model_{i}.{ext}"
        # Stop at the first free slot instead of running through every
        # candidate and always ending up with the last one.
        if not os.path.exists(erd_filename):
            return erd_filename
    raise FileExistsError(f"No unused erd_model_<n>.{ext} filename is available.")

In [80]:
def get_base_from_models(model_file_path):
    """Import a Python file by path and return its module-level ``Base``.

    NOTE(review): this executes the file as Python code. In this notebook the
    file holds text generated by a language model, so treat it as untrusted
    and only run it in an environment where that is acceptable.

    Args:
        model_file_path: Path to the Python source file to import.

    Returns:
        The object bound to the name ``Base`` in the imported module.

    Raises:
        ValueError: If the file cannot be loaded as a Python module, or the
            module defines no ``Base``.
    """
    import importlib.util

    spec = importlib.util.spec_from_file_location("models", model_file_path)
    # No spec/loader is produced for paths Python cannot import (for example a
    # file without a recognised source suffix); report that clearly instead
    # of failing later with an AttributeError on None.
    if spec is None or spec.loader is None:
        raise ValueError(f"Could not load a Python module from {model_file_path!r}.")

    models = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(models)

    if not hasattr(models, 'Base'):
        raise ValueError("Could not find SQLAlchemy Base in generated models.")
    return models.Base

In [86]:
# Save the schema of this database to a separate file.
from sqlalchemy import create_engine
from sqlalchemy.schema import CreateTable

def _next_schema_name(ext):
    """Return the first ``schema_<n>.<ext>`` name (n from 1 to 9999) with no existing file."""
    for i in range(1, 10000):
        candidate = f"schema_{i}.{ext}"
        if not os.path.exists(candidate):
            return candidate
    raise FileExistsError(f"No unused schema_<n>.{ext} filename is available.")


def save_schema(base, schema_filename, ext="sql"):
    """Write a CREATE TABLE statement for every table in ``base.metadata`` to a file.

    Args:
        base: Object whose ``metadata`` holds the tables to export.
        schema_filename: Path of the file to write. When empty or ``None``,
            the first unused ``schema_<n>.<ext>`` name is used instead.
        ext: Extension (without the dot) for the generated fallback name.

    Returns:
        The path of the file that was written.

    Raises:
        FileExistsError: If a fallback name is needed and every candidate
            from 1 to 9999 is already taken.
    """
    engine = create_engine('sqlite:///:memory:')
    # Create the tables in a throwaway in-memory SQLite database first, so a
    # failure here is raised before any file is opened for writing.
    base.metadata.create_all(engine)

    # Honour the caller's filename; previously it was always overwritten by
    # the last candidate of a numbering loop that never checked for files.
    if not schema_filename:
        schema_filename = _next_schema_name(ext)

    with open(schema_filename, "w", encoding='utf-8') as f:
        # sorted_tables yields tables in dependency order; the DDL is
        # compiled against the SQLite engine created above.
        for table in base.metadata.sorted_tables:
            ddl = str(CreateTable(table).compile(engine))
            f.write(ddl + ";\n\n")

    return schema_filename

In [88]:
# Interactive driver: description -> generated models -> ERD image + SQL schema.
if __name__=="__main__":
    print("Describe your database :")
    user_description = input("> ")
    print("\nGenerating SQLAlchemy models with GPT...")
    erd_model = build_sqlalchemy_model(user_description)
    print("\nGenerated SQLAlchemy models:\n")
    print(erd_model)
    print("\nSaving models and generating ERD image...")

    model_file_path = save_model(erd_model)
    try:
        base = get_base_from_models(model_file_path)
    finally:
        # save_model creates the file with delete=False; it is only needed
        # while it is being imported, so remove it whether or not that worked.
        os.remove(model_file_path)
    output_file = unique_name(user_description)

    render_er(base, output_file)
    # Report the path save_schema actually wrote rather than a hard-coded name.
    schema_file = save_schema(base, "schema_output.sql")
    print(f"Database schema saved as {schema_file}")
    print(f"\nERD image saved as {output_file} (open this file to view your diagram.)")

Describe your database :


>  An ecommerce website with users, products, orders, and reviews. Users can place orders for products and leave reviews for products they purchased



Generating SQLAlchemy models with GPT...

Generated SQLAlchemy models:

from sqlalchemy import Column, Integer, String, ForeignKey, Float, DateTime, Text
from sqlalchemy.orm import relationship
from sqlalchemy.ext.declarative import declarative_base

Base = declarative_base()

class User(Base):
    __tablename__ = 'users'
    
    id = Column(Integer, primary_key=True)
    username = Column(String, nullable=False, unique=True)
    email = Column(String, nullable=False, unique=True)
    password = Column(String, nullable=False)
    
    orders = relationship('Order', back_populates='user')
    reviews = relationship('Review', back_populates='user')

class Product(Base):
    __tablename__ = 'products'
    
    id = Column(Integer, primary_key=True)
    name = Column(String, nullable=False)
    description = Column(Text)
    price = Column(Float, nullable=False)
    stock = Column(Integer, nullable=False)
    
    reviews = relationship('Review', back_populates='product')
    order_items