From 81247a779114d1e21c82bbbf4ad89a24628485a3 Mon Sep 17 00:00:00 2001 From: Abdul Mannan Saeed Date: Tue, 4 Jun 2024 12:37:10 +0500 Subject: [PATCH] Tenant variables optimized to use parent instances as default values --- backend/globaleaks/models/config.py | 133 ++++++++++++++++++++++++---- 1 file changed, 116 insertions(+), 17 deletions(-) diff --git a/backend/globaleaks/models/config.py b/backend/globaleaks/models/config.py index ac04eb1358..3611191870 100644 --- a/backend/globaleaks/models/config.py +++ b/backend/globaleaks/models/config.py @@ -1,5 +1,6 @@ # -*- coding: utf-8 -*- from sqlalchemy import not_ +from sqlalchemy import or_ from globaleaks.models import Config, ConfigL10N, EnabledLanguage from globaleaks.models.properties import * from globaleaks.models.config_desc import ConfigDescriptor, ConfigFilters, ConfigL10NFilters @@ -9,6 +10,7 @@ from globaleaks.utils.utility import datetime_null +import copy # List of variables that on creation are set with the value # they have on the root tenant inherit_from_root_tenant = [ @@ -39,31 +41,84 @@ def db_get_configs(session, filter_name): class ConfigFactory(object): def __init__(self, session, tid): self.session = session + self.pid = 1 self.tid = tid - def get_all(self, filter_name): - return {c.var_name: c for c in self.session.query(Config).filter(Config.tid == self.tid, Config.var_name.in_(ConfigFilters[filter_name]))} + def get_all(self, filter_name, reference_pointers=False): + combined_values = self.session.query(Config).filter( + Config.var_name.in_(ConfigFilters[filter_name]), + or_( + Config.tid == self.pid, + Config.tid == self.tid + ) + ).all() + + result = {} + p_result = {} + t_result = {} + for item in combined_values: + if item.tid == self.pid and item.var_name not in result: + result[item.var_name] = item + p_result[item.var_name] = item + if item.tid == self.tid: + result[item.var_name] = item + t_result[item.var_name] = item + + if reference_pointers: + return result, p_result, t_result + else: + return result def update(self, filter_name, data): - for k, v in self.get_all(filter_name).items(): + result, p_result, t_result = self.get_all(filter_name, True) + for k, v in result.items(): if k in data: - v.set_v(data[k]) + if self.tid != self.pid: + if k in t_result: + if data[k] == p_result[k].value: + self.remove_val(k) + else: + v.set_v(data[k]) + elif data[k] != p_result[k].value: + self.session.add(Config({'tid': self.tid, 'var_name': k, 'value': data[k]})) + else: + v.set_v(data[k]) def get_cfg(self, var_name): - return self.session.query(Config).filter(Config.tid == self.tid, Config.var_name == var_name).one_or_none() + subquery = self.session.query(Config) \ + .filter(Config.var_name == var_name) \ + .filter(or_(Config.tid == self.tid, Config.tid == 1)) \ + .order_by(Config.tid.desc()) \ + .limit(2).subquery() + + config = self.session.query(Config) \ + .select_entity_from(subquery) \ + .order_by(subquery.c.tid.desc()) \ + .first() + + return config def get_val(self, var_name): - v = self.get_cfg(var_name) - if v is None: + config = self.get_cfg(var_name) + if config is None: return get_default(ConfigDescriptor[var_name].default) - return v.value + new_config = copy.deepcopy(config) + new_config.tid = self.tid + return new_config.value def set_val(self, var_name, value): v = self.get_cfg(var_name) if v: v.set_v(value) + def remove_val(self, var_name): + v = self.session.query(Config).filter(Config.tid == self.tid, Config.var_name == var_name).one_or_none() + + if v: + self.session.delete(v) + self.session.commit() + def serialize(self, filter_name): return {k: v.value for k, v in self.get_all(filter_name).items()} @@ -83,25 +138,61 @@ def update_defaults(self): class ConfigL10NFactory(object): def __init__(self, session, tid): self.session = session + self.pid = 1 self.tid = tid def initialize(self, keys, lang, data): - for key in keys: - value = data[key][lang] if key in data else '' - self.session.add(ConfigL10N({'tid': self.tid, 'lang': lang, 'var_name': key, 'value': value})) - - def get_all(self, filter_name, lang): - return list(self.session.query(ConfigL10N).filter(ConfigL10N.tid == self.tid, ConfigL10N.lang == lang, ConfigL10N.var_name.in_(ConfigL10NFilters[filter_name]))) + if self.tid == self.pid: + for key in keys: + value = data[key][lang] if key in data else '' + self.session.add(ConfigL10N({'tid': self.tid, 'lang': lang, 'var_name': key, 'value': value})) + + def get_all(self, filter_name, lang, reference_pointers = False): + combined_values = self.session.query(ConfigL10N).filter( + ConfigL10N.lang == lang, + ConfigL10N.var_name.in_(ConfigL10NFilters[filter_name]), + or_( + ConfigL10N.tid == self.pid, + ConfigL10N.tid == self.tid + ) + ).all() + + result = {} + p_result = {} + t_result = {} + for item in combined_values: + if item.tid == self.pid and item.var_name not in result: + result[item.var_name] = item + p_result[item.var_name] = item + if item.tid == self.tid: + result[item.var_name] = item + t_result[item.var_name] = item + + if reference_pointers: + return list(result.values()), p_result, t_result + else: + return list(result.values()) def serialize(self, filter_name, lang): rows = self.get_all(filter_name, lang) return {c.var_name: c.value for c in rows if c.var_name in ConfigL10NFilters[filter_name]} def update(self, filter_name, data, lang): - c_map = {c.var_name: c for c in self.get_all(filter_name, lang)} + result, p_result, t_result = self.get_all(filter_name, lang, True) + c_map = {c.var_name: c for c in result} for key in (x for x in ConfigL10NFilters[filter_name] if x in data): - c_map[key].set_v(data[key]) + if key in c_map: + if self.tid != self.pid: + if key in t_result: + if data[key] == p_result[key].value: + self.remove_val(key, lang) + else: + c_map[key].set_v(data[key]) + elif data[key] != p_result[key].value: + self.session.add(ConfigL10N({'tid': self.tid, 'lang': lang, 'var_name': key, 'value': data[key]})) + else: + c_map[key].set_v(data[key]) def update_defaults(self, filter_name, langs, data, reset=False): null = datetime_null() @@ -124,6 +215,12 @@ def get_val(self, var_name, lang): return v.value + def remove_val(self, var_name, lang): + v = self.session.query(ConfigL10N).filter(ConfigL10N.tid == self.tid, ConfigL10N.lang == lang, ConfigL10N.var_name == var_name).one_or_none() + if v: + self.session.delete(v) + self.session.commit() + def set_val(self, var_name, lang, value): v = self.session.query(ConfigL10N).filter(ConfigL10N.tid == self.tid, ConfigL10N.lang == lang, ConfigL10N.var_name == var_name).one_or_none() if v: @@ -143,6 +240,7 @@ def db_set_config_variable(session, tid, var, val): def initialize_config(session, tid, mode): + default_tenant_keys = ["subdomain", "onionservice", "https_admin", "https_analyst", "https_cert", "wizard_done", "uuid","default_language","name"] variables = {} # Initialization valid for any tenant @@ -162,7 +260,8 @@ def initialize_config(session, tid, mode): variables[name] = root_tenant_node[name] for name, value in variables.items(): - session.add(Config({'tid': tid, 'var_name': name, 'value': value})) + if tid == 1 or name in default_tenant_keys: + session.add(Config({'tid': tid, 'var_name': name, 'value': value})) def add_new_lang(session, tid, lang, appdata_dict):