Skip to content

Commit

Permalink
Tenant variables optimized to use parent instances as default values
Browse files Browse the repository at this point in the history
  • Loading branch information
msmannan00 committed Jun 4, 2024
1 parent c752b2f commit 81247a7
Showing 1 changed file with 116 additions and 17 deletions.
133 changes: 116 additions & 17 deletions backend/globaleaks/models/config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
from sqlalchemy import not_
from sqlalchemy import or_
from globaleaks.models import Config, ConfigL10N, EnabledLanguage
from globaleaks.models.properties import *
from globaleaks.models.config_desc import ConfigDescriptor, ConfigFilters, ConfigL10NFilters
Expand All @@ -9,6 +10,7 @@
from globaleaks.utils.utility import datetime_null


import copy
# List of variables that on creation are set with the value
# they have on the root tenant
inherit_from_root_tenant = [
Expand Down Expand Up @@ -39,31 +41,84 @@ def db_get_configs(session, filter_name):
class ConfigFactory(object):
def __init__(self, session, tid):
self.session = session
self.pid = 1
self.tid = tid

def get_all(self, filter_name):
return {c.var_name: c for c in self.session.query(Config).filter(Config.tid == self.tid, Config.var_name.in_(ConfigFilters[filter_name]))}
def get_all(self, filter_name, reference_pointers=False):
combined_values = self.session.query(Config).filter(
Config.var_name.in_(ConfigFilters[filter_name]),
or_(
Config.tid == self.pid,
Config.tid == self.tid
)
).all()

result = {}
p_result = {}
t_result = {}
for item in combined_values:
if item.tid == self.pid and item.var_name not in result:
result[item.var_name] = item
p_result[item.var_name] = item
if item.tid == self.tid:
result[item.var_name] = item
t_result[item.var_name] = item

if reference_pointers:
return result, p_result, t_result
else:
return result

def update(self, filter_name, data):
for k, v in self.get_all(filter_name).items():
result, p_result, t_result = self.get_all(filter_name, True)
for k, v in result.items():
if k in data:
v.set_v(data[k])
if self.tid != self.pid:
if k in t_result:
if data[k] == p_result[k].value:
self.remove_val(k)
else:
v.set_v(data[k])
elif data[k] != p_result[k].value:
self.session.add(Config({'tid': self.tid, 'var_name': k, 'value': data[k]}))
else:
v.set_v(data[k])

def get_cfg(self, var_name):
return self.session.query(Config).filter(Config.tid == self.tid, Config.var_name == var_name).one_or_none()
subquery = self.session.query(Config) \
.filter(Config.var_name == var_name) \
.filter(or_(Config.tid == self.tid, Config.tid == 1)) \
.order_by(Config.tid.desc()) \
.limit(2).subquery()

config = self.session.query(Config) \
.select_entity_from(subquery) \
.order_by(subquery.c.tid.desc()) \
.first()

return config

def get_val(self, var_name):
v = self.get_cfg(var_name)
if v is None:
config = self.get_cfg(var_name)
if config is None:
return get_default(ConfigDescriptor[var_name].default)

return v.value
new_config = copy.deepcopy(config)
new_config.tid = self.tid
return new_config.value

def set_val(self, var_name, value):
v = self.get_cfg(var_name)
if v:
v.set_v(value)

def remove_val(self, var_name):
v = self.session.query(Config).filter(Config.tid == self.tid, Config.var_name == var_name).one_or_none()

if v:
self.session.delete(v)
self.session.commit()

def serialize(self, filter_name):
return {k: v.value for k, v in self.get_all(filter_name).items()}

Expand All @@ -83,25 +138,61 @@ def update_defaults(self):
class ConfigL10NFactory(object):
def __init__(self, session, tid):
self.session = session
self.pid = 1
self.tid = tid

def initialize(self, keys, lang, data):
for key in keys:
value = data[key][lang] if key in data else ''
self.session.add(ConfigL10N({'tid': self.tid, 'lang': lang, 'var_name': key, 'value': value}))

def get_all(self, filter_name, lang):
return list(self.session.query(ConfigL10N).filter(ConfigL10N.tid == self.tid, ConfigL10N.lang == lang, ConfigL10N.var_name.in_(ConfigL10NFilters[filter_name])))
if self.tid == self.pid:
for key in keys:
value = data[key][lang] if key in data else ''
self.session.add(ConfigL10N({'tid': self.tid, 'lang': lang, 'var_name': key, 'value': value}))

def get_all(self, filter_name, lang, reference_pointers = False):
combined_values = self.session.query(ConfigL10N).filter(
ConfigL10N.lang == lang,
ConfigL10N.var_name.in_(ConfigL10NFilters[filter_name]),
or_(
ConfigL10N.tid == self.pid,
ConfigL10N.tid == self.tid
)
).all()

result = {}
p_result = {}
t_result = {}
for item in combined_values:
if item.tid == self.pid and item.var_name not in result:
result[item.var_name] = item
p_result[item.var_name] = item
if item.tid == self.tid:
result[item.var_name] = item
t_result[item.var_name] = item

if reference_pointers:
return list(result.values()), p_result, t_result
else:
return list(result.values())

def serialize(self, filter_name, lang):
rows = self.get_all(filter_name, lang)
return {c.var_name: c.value for c in rows if c.var_name in ConfigL10NFilters[filter_name]}

def update(self, filter_name, data, lang):
c_map = {c.var_name: c for c in self.get_all(filter_name, lang)}
result, p_result, t_result = self.get_all(filter_name, lang, True)
c_map = {c.var_name: c for c in result}

for key in (x for x in ConfigL10NFilters[filter_name] if x in data):
c_map[key].set_v(data[key])
if key in c_map:
if self.tid != self.pid:
if key in t_result:
if data[key] == p_result[key].value:
self.remove_val(key, lang)
else:
c_map[key].set_v(data[key])
elif data[key] != p_result[key].value:
self.session.add(ConfigL10N({'tid': self.tid, 'lang': lang, 'var_name': key, 'value': data[key]}))
else:
c_map[key].set_v(data[key])

def update_defaults(self, filter_name, langs, data, reset=False):
null = datetime_null()
Expand All @@ -124,6 +215,12 @@ def get_val(self, var_name, lang):

return v.value

def remove_val(self, var_name, lang):
v = self.session.query(ConfigL10N).filter(ConfigL10N.tid == self.tid, ConfigL10N.lang == lang, ConfigL10N.var_name == var_name).one_or_none()
if v:
self.session.delete(v)
self.session.commit()

def set_val(self, var_name, lang, value):
v = self.session.query(ConfigL10N).filter(ConfigL10N.tid == self.tid, ConfigL10N.lang == lang, ConfigL10N.var_name == var_name).one_or_none()
if v:
Expand All @@ -143,6 +240,7 @@ def db_set_config_variable(session, tid, var, val):


def initialize_config(session, tid, mode):
default_tenant_keys = ["subdomain", "onionservice", "https_admin", "https_analyst", "https_cert", "wizard_done", "uuid","default_language","name"]
variables = {}

# Initialization valid for any tenant
Expand All @@ -162,7 +260,8 @@ def initialize_config(session, tid, mode):
variables[name] = root_tenant_node[name]

for name, value in variables.items():
session.add(Config({'tid': tid, 'var_name': name, 'value': value}))
if tid == 1 or name in default_tenant_keys:
session.add(Config({'tid': tid, 'var_name': name, 'value': value}))


def add_new_lang(session, tid, lang, appdata_dict):
Expand Down

0 comments on commit 81247a7

Please sign in to comment.