In [1]:
import xmltodict
import json
import re

In [2]:
def read_file(path="read_txt.txt", encoding=None):
    """Return the entire contents of the text file at *path*.

    Args:
        path: File to read.
        encoding: Optional text encoding passed to ``open``; ``None`` keeps the
            platform default (the previous behaviour).

    The ``with`` block closes the file, so no explicit ``close()`` is needed.
    """
    with open(path, "r", encoding=encoding) as f:
        return f.read()


def write_to_file(str_, path="writequery.txt", append=False, encoding=None):
    """Write the string *str_* to *path*.

    Args:
        str_: Text to write.
        path: Destination file.
        append: When true, keep the existing contents and add a newline
            separator followed by *str_*; otherwise create/truncate the file.
        encoding: Optional text encoding passed to ``open``; ``None`` keeps the
            platform default (the previous behaviour).

    The ``with`` block closes the file, so no explicit ``close()`` is needed.
    """
    mode = "a" if append else "w"
    with open(path, mode, encoding=encoding) as f:
        if append:
            # Separate the appended text from whatever is already in the file.
            f.write("\n")
        f.write(str_)


# Matches one markup tag: "<", then one or more non-">" characters, then ">".
TAG_RE = re.compile(r"<[^>]+>")


def remove_tags(text):
    """Strip every ``<...>`` tag from *text* and return what remains."""
    without_tags = TAG_RE.sub("", text)
    return without_tags

In [3]:
# draw.io ERD exports to process (paths relative to the working directory).
xml_file_path2 = "HEIMS_ERD-Main.drawio.xml"
xml_file_path = "TestERD.xml"

# Raw XML text of the first export.
xml_data = read_file(path=xml_file_path)

In [4]:
# Parse the first export from the XML text already loaded into `xml_data`
# (no need to re-read the file). The JSON round-trip converts xmltodict's
# output into plain dicts/lists.
data_dict = xmltodict.parse(xml_data)
json_data = json.dumps(data_dict)
erd_json = json.loads(json_data)
    

In [5]:
# Second export: parse the XML, then round-trip it through JSON so the result
# consists of plain dicts/lists.
with open(xml_file_path2) as xml_file:
    data_dict = xmltodict.parse(xml_file.read())
json_data2 = json.dumps(data_dict)
erd_json2 = json.loads(json_data2)

In [6]:
class ModelBuilder:
    """Build table/field models from a draw.io ERD parsed into nested dicts.

    ``json_data`` is expected to be ``xmltodict.parse``-style output of an
    uncompressed draw.io export, i.e. it must contain
    ``["mxfile"]["diagram"]["mxGraphModel"]["root"]["mxCell"]``.

    Cell layout this class relies on:
      * tables    - cells whose ``@parent`` is ``"1"`` with a non-empty
                    ``@value``, or the children of such a cell whose value is
                    the empty string;
      * rows      - ("field groups") cells whose ``@parent`` is a table id;
      * row cells - children of a row, in document order: key marker
                    (``"PK"``/``"FK"``), field name, field type;
      * edges     - cells whose ``@target`` is a row id; ``@source`` is the
                    referenced row and the style's ``startArrow``/``endArrow``
                    are recorded as ``fk_type``.
    """

    def __init__(self, json_data: dict) -> None:
        self.json_data = json_data
        self.cells = []             # cache for get_cells()
        self.tables = []            # cache for get_tables()
        self.model_tables = []      # per-table models built by get_model_tables()
        self.new_model_tables = []  # model_tables after FK remapping

    def render_to_model_to_code(self, filename=None):
        """Serialise the table models to JSON and write them to a file.

        Args:
            filename: Output path; defaults to
                ``"<diagram name>-model_tables.json"``.

        The models are built via get_model_tables() on first use; the cached
        result is reused on later calls.
        """
        table_name = self.get_tab_name()
        _filename = filename if filename else f'{table_name}-model_tables.json'
        if not self.new_model_tables:
            self.new_model_tables = self.get_model_tables()
        write_to_file(json.dumps(self.new_model_tables), _filename)

    def get_model_tables(self):
        """Build one model dict per table, remap FKs and return the list.

        Each model is ``{"table_id", "name", "fields"}``, where ``fields`` holds
        the non-empty ``field_data`` dicts produced by get_fields(). The list is
        rebuilt from scratch on every call, so re-running this (e.g. re-running
        a notebook cell) does not duplicate tables.
        """
        self.cells = self.cells if self.cells else self.get_cells()
        self.tables = self.tables if self.tables else self.get_tables()
        self.model_tables = []

        for table in self.tables:
            table_id = table.get('@id')
            table_data = {
                'table_id': table_id,
                'name': table.get('@value'),
                'fields': [],
            }

            for field_group in self.get_field_groups(table_id):
                field_data, _ = self.get_fields(field_group.get('@id'))
                # Rows without a usable name cell come back with an empty field.
                if field_data['field']:
                    table_data['fields'].append(field_data)

            self.model_tables.append(table_data)

        self.new_model_tables = self.remap_fk_to_pk_uuid()
        return self.new_model_tables

    def get_tab_name(self) -> str:
        """Return the diagram's ``@name`` attribute (the draw.io tab name)."""
        return self.json_data["mxfile"]["diagram"]["@name"]

    def get_tables(self) -> list:
        """Return (and cache) the table cells with markup stripped from ``@value``.

        A cell is a table when its parent is ``"1"`` and its value is non-empty.
        A parent-``"1"`` cell whose value is the empty string is treated as a
        container: each of its direct children is taken as a table instead.
        """
        self.cells = self.cells if self.cells else self.get_cells()
        if self.tables:
            return self.tables

        for cell in self.cells:
            cell_value = cell.get("@value")
            if str(cell.get("@parent")) != "1" or cell_value is None:
                continue

            if cell_value:
                self.tables.append({**cell, "@value": remove_tags(cell_value)})
            else:
                # Empty-valued top-level cell (presumably a group): use its
                # children. A child without a value gets an empty name rather
                # than crashing remove_tags on None.
                cell_id = cell.get("@id")
                for child in self.cells:
                    if child.get("@parent") == cell_id:
                        self.tables.append(
                            {**child,
                             "@value": remove_tags(child.get("@value") or "")}
                        )

        return self.tables

    def get_field_groups(self, table_id: str) -> list:
        """Return the cells whose parent is *table_id* (the table's rows)."""
        return [cell for cell in self.cells if cell.get('@parent') == table_id]

    def get_fields(self, field_group_id: str) -> tuple:
        """Describe the field stored in the row *field_group_id*.

        Returns:
            ``(field_data, fields)``: ``fields`` is the list of the row's child
            cells, and ``field_data`` is ``{"field_group_id": ..., "field": {...}}``.
            ``field`` stays empty when the row has fewer than two child cells
            (no name cell), so callers can skip it. A row without a third
            (type) cell yields ``"type": None``.
        """
        fields = [cell for cell in self.cells
                  if cell.get('@parent') == field_group_id]
        field_data = {
            "field_group_id": field_group_id,
            "field": {},
        }

        if len(fields) >= 2:
            key_marker = fields[0].get('@value')
            fk_to_pk_mapping = self.fk_to_pk_mapping(field_group_id)

            field_data['field'] = {
                "id": fields[1].get('@id'),
                "name": fields[1].get('@value'),
                "type": fields[2].get('@value') if len(fields) > 2 else None,
                "is_pk": key_marker == 'PK',
                "is_fk": key_marker == 'FK',
                "fk_type": fk_to_pk_mapping.get('fk_type') or '',
                "foreign_key": fk_to_pk_mapping,
            }

        return field_data, fields

    def get_cells(self) -> list:
        """Return the diagram's ``mxCell`` entries as a list.

        xmltodict returns a single dict, not a list, when there is exactly one
        ``mxCell`` element; that case is wrapped so callers can always iterate.
        """
        cells = self.json_data["mxfile"]["diagram"]["mxGraphModel"]["root"]["mxCell"]
        return [cells] if isinstance(cells, dict) else cells

    def fk_to_pk_mapping(self, field_group_id):
        """Return the relationship edge that targets row *field_group_id*.

        Returns:
            ``{"source_id", "target_id", "fk_type"}``, or ``{}`` when no edge
            targets the row. ``fk_type`` is the ``(startArrow, endArrow)`` pair
            from the edge style, with ``None`` for missing entries. If several
            edges target the row, the last one in document order wins.
        """
        mapping_field = {}

        for cell in self.cells:
            if cell.get('@target') != field_group_id:
                continue
            # style_parser returns None for a missing/empty style.
            cell_style = self.style_parser(cell.get('@style')) or {}
            mapping_field['source_id'] = cell.get('@source')
            mapping_field['target_id'] = field_group_id
            mapping_field['fk_type'] = (cell_style.get('startArrow'),
                                        cell_style.get('endArrow'))

        return mapping_field

    def style_parser(self, styles):
        """Parse a draw.io style string ``"k1=v1;k2=v2;..."`` into a dict.

        Returns None for an empty or missing style. Entries without ``=``
        (bare tokens) are skipped, and empty values become None. A value may
        itself contain ``=`` (e.g. a data URI) and is kept whole.
        """
        if not styles:
            return None

        style_dict = {}
        for entry in styles.split(';'):
            key, sep, value = entry.partition('=')
            if sep:
                style_dict[key] = value if value else None

        return style_dict

    def is_source_has__id_name(self, source_id):
        """Check whether the PK row *source_id* has a non-empty name.

        Returns:
            ``(False, first_field_of_that_table)`` when *source_id* is a PK row
            whose name is empty; otherwise ``(True, None)``. The first field is
            presumably the table's UUID key (TODO confirm against the ERD
            convention).
        """
        for table in self.model_tables:
            for fields in table['fields']:
                field = fields['field']
                if (field.get('is_pk')
                        and fields['field_group_id'] == source_id
                        and not field.get('name')):
                    return False, table['fields'][0]

        return True, None

    def re_map_source_id(self, field, table_idx, field_idx):
        """Compute a replacement FK mapping for *field*, or ``{}`` if none is needed.

        Only FK fields whose source is an unnamed PK row are remapped. The
        new mapping points at the first field of the source table and records
        the original source as ``prev_source_id``. ``table_idx``/``field_idx``
        are unused and kept for interface compatibility.
        """
        field_info = field['field']
        if not field_info.get('is_fk'):
            return {}

        foreign_key = field_info['foreign_key']
        source_id = foreign_key.get('source_id')
        has_name, uuid_field = self.is_source_has__id_name(source_id)
        if has_name:
            return {}

        return {
            'source_id': uuid_field['field_group_id'],
            'prev_source_id': source_id,
            'target_id': foreign_key.get('target_id'),
            'prev_value': foreign_key.get('value'),
        }

    def remap_fk_to_pk_uuid(self):
        """Redirect FKs that reference an unnamed PK row; return model_tables.

        Only each affected field's ``foreign_key`` entry is updated: it is
        merged with the new mapping, so ``fk_type`` is kept. The field's id,
        name, type and key flags are left intact.
        """
        for table_idx, table in enumerate(self.model_tables):
            for field_idx, field in enumerate(table['fields']):
                new_mapping_field = self.re_map_source_id(
                    field, table_idx, field_idx)
                if new_mapping_field:
                    field['field']['foreign_key'] = {
                        **field['field']['foreign_key'],
                        **new_mapping_field,
                    }

        return self.model_tables

In [7]:
# Build the table models from the HEIMS main ERD export (second file).
model_builder = ModelBuilder(json_data=erd_json2)
model_tables_dict = model_builder.get_model_tables()  # a list of table dicts

In [8]:

# Write the models to "<diagram name>-model_tables.json" (default filename).
model_builder.render_to_model_to_code(filename=None)