"""Map DocType fields to database columns and compute schema changes."""

import os
import re

import frappe
from frappe import _
from frappe.utils import cint, cstr, flt

SPECIAL_CHAR_PATTERN = re.compile(r"[\W]", flags=re.UNICODE)
VARCHAR_CAST_PATTERN = re.compile(r"varchar\(([\d]+)\)")

# Column types on which this module never places a UNIQUE constraint or an index.
_TEXT_COLUMN_TYPES = ("text", "longtext")


def is_link_doctype_autoincrement(doctype):
    """Return True if `doctype` uses the ``autoincrement`` naming rule.

    The DocType meta is consulted first. If it cannot be loaded
    (``frappe.DoesNotExistError``), the DocType JSON files of the installed
    apps -- plus ``frappe.flags.in_install``, if set -- are scanned instead.

    :param doctype: name of the DocType to check (a Link field's options).
    :returns: True for autoincrement-named doctypes, otherwise False.
    """
    try:
        link_meta = frappe.get_meta(doctype, False)
    except frappe.DoesNotExistError:
        return _is_autoincrement_in_doc_files(doctype)

    return link_meta.get("autoname") == "autoincrement"


def _is_autoincrement_in_doc_files(doctype):
    """Scan app DocType JSON files for `doctype`; True if it is autoincrement-named."""
    # Kept as function-level imports as originally written (presumably to
    # avoid an import cycle with frappe.model) -- NOTE(review): confirm.
    from frappe.model.sync import get_doc_files
    from frappe.modules.import_file import read_doc_from_file

    # Copy so the list handed back by frappe is never mutated in place.
    apps = list(frappe.get_installed_apps())
    # The app currently being installed (if any) is scanned as well.
    installing_app = frappe.flags.in_install
    if installing_app and installing_app not in apps:
        apps.append(installing_app)

    for app_name in apps:
        for module_name in frappe.local.app_modules.get(app_name) or []:
            module = frappe.get_module(app_name + "." + module_name)
            folder = os.path.dirname(module.__file__)
            for doc_path in get_doc_files(files=[], start_path=folder):
                try:
                    docs = read_doc_from_file(doc_path)
                except OSError:
                    # One unreadable file must not abort the whole search.
                    print(f"{doc_path} missing")
                    continue

                if not docs:
                    continue
                if not isinstance(docs, list):
                    docs = [docs]

                for doc in docs:
                    # .get(): some JSON docs may lack "name" or "autoname".
                    if doc.get("name") == doctype and doc.get("autoname") == "autoincrement":
                        return True
    return False


class InvalidColumnName(frappe.ValidationError):
    pass


class DBTable:
    """Schema helper for the table ``tab<doctype>`` backing a DocType.

    Builds the desired columns from the DocType meta and collects, per
    change kind, the columns that need altering. ``create`` and ``alter``
    are no-ops here; presumably database-specific subclasses implement
    them -- NOTE(review): confirm.
    """

    def __init__(self, doctype, meta=None):
        self.doctype = doctype
        self.table_name = f"tab{doctype}"
        self.meta = meta or frappe.get_meta(doctype, False)
        # desired columns keyed by fieldname
        self.columns: dict[str, DbColumn] = {}
        # current database column descriptions, keyed by lower-cased name
        self.current_columns = {}

        # lists for change
        self.add_column: list[DbColumn] = []
        self.change_type: list[DbColumn] = []
        self.change_name: list[DbColumn] = []
        self.add_unique: list[DbColumn] = []
        self.add_index: list[DbColumn] = []
        self.drop_unique: list[DbColumn] = []
        self.drop_index: list[DbColumn] = []
        self.set_default: list[DbColumn] = []

        # load
        self.get_columns_from_docfields()

    def sync(self):
        """Create the table if it is missing, otherwise alter it.

        Virtual doctypes have no table and are skipped. Before altering, the
        cached "table_columns" entry for this table is dropped.
        """
        if self.meta.get("is_virtual"):
            # no schema to sync for virtual doctypes
            return

        if self.is_new():
            self.create()
        else:
            frappe.cache().hdel("table_columns", self.table_name)
            self.alter()

    def create(self):
        pass

    def get_column_definitions(self):
        """Return "`fieldname` <definition>" strings for non-default columns.

        Columns listed in ``frappe.db.DEFAULT_COLUMNS`` and columns without a
        SQL type are omitted.
        """
        column_list = [] + frappe.db.DEFAULT_COLUMNS
        ret = []
        for k in list(self.columns):
            if k not in column_list:
                d = self.columns[k].get_definition()
                if d:
                    ret.append("`" + k + "` " + d)
                    column_list.append(k)
        return ret

    def get_index_definitions(self):
        """Return index clauses for indexed, non-unique, non-text columns."""
        ret = []
        for key, col in self.columns.items():
            if (
                col.set_index
                and not col.unique
                and col.fieldtype in frappe.db.type_map
                and frappe.db.type_map.get(col.fieldtype)[0] not in _TEXT_COLUMN_TYPES
            ):
                ret.append("index `" + key + "`(`" + key + "`)")
        return ret

    def get_columns_from_docfields(self):
        """Populate ``self.columns`` from docfields and custom fields.

        Non-child tables also get ``frappe.db.OPTIONAL_COLUMNS``. A ``_seen``
        column is added when ``track_seen`` is set. Virtual fields are skipped.
        """
        fields = self.meta.get_fieldnames_with_value(with_field_meta=True)

        # optional fields like _comments
        if not self.meta.get("istable"):
            for fieldname in frappe.db.OPTIONAL_COLUMNS:
                fields.append({"fieldname": fieldname, "fieldtype": "Text"})

            # add _seen column if track_seen
            if self.meta.get("track_seen"):
                fields.append({"fieldname": "_seen", "fieldtype": "Text"})

        for field in fields:
            if field.get("is_virtual"):
                continue

            self.columns[field.get("fieldname")] = DbColumn(
                self,
                field.get("fieldname"),
                field.get("fieldtype"),
                field.get("length"),
                field.get("default"),
                field.get("search_index"),
                field.get("options"),
                field.get("unique"),
                field.get("precision"),
            )

    def validate(self):
        """Validate column names and varchar lengths of an existing table.

        - Rejects fieldnames of 64 or more characters.
        - Requires varchar lengths to lie within 1..1000.
        - If changing a varchar's length would truncate stored data, it
          reverts the desired length to the current one and shows a message.

        New (not yet created) tables are not checked.
        """
        if self.is_new():
            return

        self.setup_table_columns()

        columns = [
            frappe._dict({"fieldname": f, "fieldtype": "Data"})
            for f in frappe.db.STANDARD_VARCHAR_COLUMNS
        ]
        if self.meta.get("istable"):
            columns += [
                frappe._dict({"fieldname": f, "fieldtype": "Data"})
                for f in frappe.db.CHILD_TABLE_COLUMNS
            ]
        columns += self.columns.values()

        for col in columns:
            # NOTE: ">= 64" allows at most 63 characters.
            if len(col.fieldname) >= 64:
                frappe.throw(
                    _("Fieldname is limited to 64 characters ({0})").format(frappe.bold(col.fieldname))
                )

            if "varchar" in frappe.db.type_map.get(col.fieldtype, ()):
                # validate length range
                new_length = cint(col.length) or cint(frappe.db.VARCHAR_LEN)
                if not (1 <= new_length <= 1000):
                    frappe.throw(_("Length of {0} should be between 1 and 1000").format(col.fieldname))

                current_col = self.current_columns.get(col.fieldname, {})
                if not current_col:
                    continue

                current_length = VARCHAR_CAST_PATTERN.findall(current_col["type"])
                if not current_length:
                    # case when the field is no longer a varchar
                    continue

                current_length = current_length[0]
                if cint(current_length) != cint(new_length):
                    try:
                        # check for truncation
                        max_length = frappe.db.sql(
                            """SELECT MAX(CHAR_LENGTH(`{fieldname}`)) FROM `tab{doctype}`""".format(
                                fieldname=col.fieldname, doctype=self.doctype
                            )
                        )
                    except frappe.db.InternalError as e:
                        if frappe.db.is_missing_column(e):
                            # Unknown column 'column_name' in 'field list'
                            continue
                        raise

                    if max_length and max_length[0][0] and max_length[0][0] > new_length:
                        if col.fieldname in self.columns:
                            self.columns[col.fieldname].length = current_length
                        info_message = _(
                            "Reverting length to {0} for '{1}' in '{2}'. Setting the length as {3} will cause truncation of data."
                        ).format(current_length, col.fieldname, self.doctype, new_length)
                        frappe.msgprint(info_message)

    def is_new(self):
        """Return True if the table does not exist in the database yet."""
        return self.table_name not in frappe.db.get_tables()

    def setup_table_columns(self):
        """Load the current column descriptions, keyed by lower-cased name."""
        # TODO: figure out a way to get key data
        for c in frappe.db.get_table_columns_description(self.table_name):
            self.current_columns[c.name.lower()] = c

    def alter(self):
        pass


class DbColumn:
    """Desired state of one table column, derived from a DocField.

    ``build_for_alter_table`` compares it with the column's current database
    description and queues the needed changes on the owning ``table``.
    """

    def __init__(
        self, table, fieldname, fieldtype, length, default, set_index, options, unique, precision
    ):
        self.table = table
        self.fieldname = fieldname
        self.fieldtype = fieldtype
        self.length = length
        self.set_index = set_index
        self.default = default
        self.options = options
        self.unique = unique
        self.precision = precision

        # Links to autoincrement-named doctypes use the "Long Int" column
        # type instead of the Link type. Empty options are skipped: the
        # lookup could only scan every doc file and then return False.
        self.linkcolumndatatype = None
        if fieldtype == "Link" and options and options != "[Select]":
            if is_link_doctype_autoincrement(options):
                self.linkcolumndatatype = "Long Int"

    def get_definition(self, for_modification=False):
        """Return the SQL column definition (type, default, unique) or a falsy value.

        :param for_modification: when True, no " unique" suffix is added.
        """
        column_type = get_definition(
            self.fieldtype,
            precision=self.precision,
            length=self.length,
            linkcolumndatatype=self.linkcolumndatatype,
        )

        if not column_type:
            return column_type

        column_def = column_type
        if self.fieldtype in ("Check", "Int"):
            default_value = cint(self.default) or 0
            column_def += f" not null default {default_value}"

        elif self.fieldtype in ("Currency", "Float", "Percent"):
            default_value = flt(self.default) or 0
            column_def += f" not null default {default_value}"

        elif (
            self.default
            and (self.default not in frappe.db.DEFAULT_SHORTCUTS)
            and not cstr(self.default).startswith(":")
        ):
            column_def += f" default {frappe.db.escape(self.default)}"

        # Compare the bare type: column_def may already carry a default
        # clause, which would defeat the text/longtext check.
        if self.unique and not for_modification and column_type not in _TEXT_COLUMN_TYPES:
            column_def += " unique"

        return column_def

    def build_for_alter_table(self, current_def):
        """Queue this column's pending changes on ``self.table``.

        :param current_def: the column's current description from the database
            (keys "type", "unique", "index", "default" are read), or falsy if
            the column does not exist yet.
        """
        column_type = get_definition(
            self.fieldtype, self.precision, self.length, linkcolumndatatype=self.linkcolumndatatype
        )

        # no columns
        if not column_type:
            return

        is_text = column_type in _TEXT_COLUMN_TYPES

        # to add?
        if not current_def:
            self.fieldname = validate_column_name(self.fieldname)
            self.table.add_column.append(self)

            if not is_text:
                if self.unique:
                    self.table.add_unique.append(self)
                if self.set_index:
                    self.table.add_index.append(self)
            return

        # type
        if current_def["type"] != column_type:
            self.table.change_type.append(self)

        # unique
        if not is_text:
            if self.unique and not current_def["unique"]:
                self.table.add_unique.append(self)
            elif current_def["unique"] and not self.unique:
                self.table.drop_unique.append(self)

        # default
        if (
            self.default_changed(current_def)
            and (self.default not in frappe.db.DEFAULT_SHORTCUTS)
            and not cstr(self.default).startswith(":")
        ):
            self.table.set_default.append(self)

        # index should be applied or dropped irrespective of type change
        if not is_text:
            if current_def["index"] and not self.set_index:
                self.table.drop_index.append(self)
            elif not current_def["index"] and self.set_index:
                self.table.add_index.append(self)

    def default_changed(self, current_def):
        """Return True if the desired default differs from ``current_def["default"]``.

        Decimal columns are compared numerically. Int/Check and
        Currency/Float/Percent defaults are normalised with cint/flt before
        comparing.
        """
        if "decimal" in current_def["type"]:
            return self.default_changed_for_decimal(current_def)

        cur_default = current_def["default"]
        new_default = self.default
        if cur_default == "NULL" or cur_default is None:
            cur_default = None
        else:
            # Strip quotes from default value
            # eg. database returns default value as "'System Manager'"
            cur_default = cur_default.lstrip("'").rstrip("'")

        fieldtype = self.fieldtype
        if fieldtype in ["Int", "Check"]:
            cur_default = cint(cur_default)
            new_default = cint(new_default)
        elif fieldtype in ["Currency", "Float", "Percent"]:
            cur_default = flt(cur_default)
            new_default = flt(new_default)
        return cur_default != new_default

    def default_changed_for_decimal(self, current_def):
        """Compare decimal defaults as floats; "" and None count as "no default".

        An invalid (non-numeric) new default on a column without a default
        is not treated as a change. A TypeError during comparison is.
        """
        try:
            if current_def["default"] in ("", None) and self.default in ("", None):
                # both none, empty
                return False

            elif current_def["default"] in ("", None):
                try:
                    # check if new default value is valid
                    float(self.default)
                    return True
                except ValueError:
                    return False

            elif self.default in ("", None):
                # new default value is empty
                return True

            else:
                # NOTE float() raise ValueError when "" or None is passed
                return float(current_def["default"]) != float(self.default)
        except TypeError:
            return True


def validate_column_name(n):
    """Return `n` unchanged, or throw InvalidColumnName if it has non-word characters."""
    if special_characters := SPECIAL_CHAR_PATTERN.findall(n):
        special_characters = ", ".join(f'"{c}"' for c in special_characters)
        frappe.throw(
            _("Fieldname {0} cannot have special characters like {1}").format(
                frappe.bold(cstr(n)), special_characters
            ),
            frappe.db.InvalidColumnName,
        )
    return n


def validate_column_length(fieldname):
    """Throw if `fieldname` is longer than ``frappe.db.MAX_COLUMN_LENGTH``."""
    if len(fieldname) > frappe.db.MAX_COLUMN_LENGTH:
        frappe.throw(_("Fieldname is limited to 64 characters ({0})").format(fieldname))


def get_definition(fieldtype, precision=None, length=None, linkcolumndatatype=None):
    """Return the SQL column type for a fieldtype, or None if it has no column.

    :param fieldtype: frappe fieldtype, looked up in ``frappe.db.type_map``.
    :param precision: for Float/Currency/Percent, a precision above 6 widens
        the size to "21,9".
    :param length: varchar length, or int display width when below 11. An
        Int length above 11 switches to the "Long Int" type.
    :param linkcolumndatatype: if truthy, the type_map key used instead of
        `fieldtype` (e.g. "Long Int" for links to autoincrement doctypes).
    """
    d = frappe.db.type_map.get(linkcolumndatatype or fieldtype)

    if not d:
        return

    if fieldtype == "Int" and length and length > 11:
        # convert int to long int if the length of the int is greater than 11
        d = frappe.db.type_map.get("Long Int")

    coltype = d[0]
    size = d[1] if d[1] else None

    if size:
        # This check needs to exist for backward compatibility.
        # Till V13, default size used for float, currency and percent are (18, 6).
        if fieldtype in ["Float", "Currency", "Percent"] and cint(precision) > 6:
            size = "21,9"

        if length:
            if coltype == "varchar":
                size = length
            elif coltype == "int" and length < 11:
                # allow setting custom length for int if length provided is less than 11
                # NOTE: this will only be applicable for mariadb as frappe implements int
                # in postgres as bigint (as seen in type_map)
                size = length

    if size is not None:
        coltype = f"{coltype}({size})"

    return coltype


def add_column(
    doctype, column_name, fieldtype, precision=None, length=None, default=None, not_null=False
):
    """Add `column_name` to ``tab<doctype>`` unless it already exists.

    The current transaction is committed before the ALTER TABLE runs.
    A truthy `default` is emitted as an escaped SQL string literal.
    """
    if column_name in frappe.db.get_table_columns(doctype):
        # already exists
        return

    frappe.db.commit()

    query = "alter table `tab{}` add column {} {}".format(
        doctype,
        column_name,
        get_definition(fieldtype, precision, length),
    )

    if not_null:
        query += " not null"
    if default:
        # Escape instead of interpolating raw: a quote in the value would
        # otherwise break (or inject into) the statement. percent=False
        # keeps "%" literal because frappe.db.sql is called without values.
        query += f" default {frappe.db.escape(cstr(default), percent=False)}"

    frappe.db.sql(query)