In [1]:
import re
import tempfile
import urllib
import webbrowser
from pathlib import Path

import sqlalchemy as sa


def get_mermaid(constr):
    engine = sa.create_engine(constr)
    meta = sa.MetaData(bind=engine)
    sa.MetaData.reflect(meta)
    nodes = [to_mermaid_node(table) for table in meta.tables.values()]

    edge_fstr = '%s ||--|{ %s : "%s"'
    edges = []
    for table_id, table in meta.tables.items():
        for fk in table.foreign_keys:
            col = fk.column
            edges.append(edge_fstr % (table_id, col.table.fullname, fk.parent.name))
    return "\n  ".join(["erDiagram", *nodes, *edges])


def to_mermaid_node(table: sa.Table):
    pks = set([c.name for c in table.primary_key.columns])
    ind_dic = {i.name: [c.name for c in i.columns] for i in table.indexes}
    fks = set([fk.parent.name for fk in table.foreign_keys])
    cols = []
    for col in table.columns:
        cn = col.name
        extra = "PK" if cn in pks else ("FK" if cn in fks else "")
        in_inds = [k for k, v in ind_dic.items() if cn in v]
        if in_inds:
            comm = f'"in index: {", ".join(in_inds)}"'
        else:
            comm = ""
        cols.append(" ".join(["   ", format_col_type(col), cn, extra, comm]))
    head = table.fullname + " {"
    return "    \n".join([head, *cols]) + "\n  }"


def format_col_type(col):
    try:
        out = col.type.get_col_spec()
    except (AttributeError, NotImplementedError):
        out = str(col.type)
    return re.sub(r"\((\d+)\)", lambda m: "-" + m.groups()[0], out)


def to_file(constr: str, output_fp: str):
    out_path = Path(output_fp)
    try:
        frame = frame_dic[out_path.suffix]
    except KeyError:
        raise ValueError(f"extension not in {frame_dic.keys()}")

    out_path.write_text(frame % get_mermaid(constr))


def open_in_browser(constr):
    output_file = tempfile.NamedTemporaryFile(suffix=".html", delete=False)
    to_file(output_file)
    url = urllib.parse.urlunparse(("file", "", output_file.name, "", "", ""))
    webbrowser.open(url)


html_frame = """
<html>
    <body>
        <script src="https://cdn.jsdelivr.net/npm/mermaid/dist/mermaid.min.js"></script>
        <script>
            mermaid.initialize({ startOnLoad: true });
        </script>

        <h1>EDR</h1>
        <div class="mermaid">
%s
        </div>
    </body>
</html>
"""


md_frame = """

```mermaid
%s
```

"""

frame_dic = {".html": html_frame, ".md": md_frame}

In [2]:
import sqlalchemy as sa

# from sqlmermaid import open_in_browser, get_mermaid


metadata_obj = sa.MetaData()

In [3]:
user = sa.Table(
    "user",
    metadata_obj,
    sa.Column("user_id", sa.Integer, primary_key=True),
    sa.Column("user_name", sa.String(16), nullable=False),
    sa.Column("email_address", sa.String(60)),
    sa.Column("nickname", sa.String(50), nullable=False),
)

In [4]:
departments = sa.Table(
    "departments",
    metadata_obj,
    sa.Column("department_id", sa.Integer, primary_key=True),
    sa.Column("department_name", sa.String(60), nullable=False),
)

In [5]:
employees = sa.Table(
    "employees",
    metadata_obj,
    sa.Column("employee_id", sa.Integer, primary_key=True),
    sa.Column("employee_name", sa.String(60), nullable=False),
    sa.Column("employee_dept", sa.Integer, sa.ForeignKey("departments.department_id")),
)

In [6]:
user_complaints = sa.Table(
    "complaints",
    metadata_obj,
    sa.Column("about_dept", sa.Integer, sa.ForeignKey("departments.department_id")),
    sa.Column("by_user", sa.Integer, sa.ForeignKey("user.user_id")),
    sa.Column("at_time", sa.DateTime),
    sa.Column("text", sa.String(500)),
    sa.Index("comp_ind", "text"),
)

In [9]:
constr = "sqlite:///db.sqlite"  # "sqlite:///:memory:"
engine = sa.create_engine(constr)
metadata_obj.create_all(engine)

In [11]:
print(get_mermaid(constr))

erDiagram
  complaints {    
    INTEGER about_dept FK     
    INTEGER by_user FK     
    DATETIME at_time      
    VARCHAR-500 text  "in index: comp_ind"
  }
  departments {    
    INTEGER department_id PK     
    VARCHAR-60 department_name  
  }
  user {    
    INTEGER user_id PK     
    VARCHAR-16 user_name      
    VARCHAR-60 email_address      
    VARCHAR-50 nickname  
  }
  employees {    
    INTEGER employee_id PK     
    VARCHAR-60 employee_name      
    INTEGER employee_dept FK 
  }
  complaints ||--|{ departments : "about_dept"
  complaints ||--|{ user : "by_user"
  employees ||--|{ departments : "employee_dept"


In [12]:
Path("db.sqlite").unlink()