-
Notifications
You must be signed in to change notification settings - Fork 269
/
mad_hatter.py
303 lines (234 loc) · 11 KB
/
mad_hatter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
import os
import glob
import shutil
import inspect
import traceback
from copy import deepcopy
from typing import List, Dict
from cat.log import log
import cat.utils as utils
from cat.utils import singleton
from cat.db import crud
from cat.db.models import Setting
from cat.mad_hatter.plugin_extractor import PluginExtractor
from cat.mad_hatter.plugin import Plugin
from cat.mad_hatter.decorators.hook import CatHook
from cat.mad_hatter.decorators.tool import CatTool
from cat.experimental.form import CatForm
# This class is responsible for plugins functionality:
# - loading
# - prioritizing
# - executing
@singleton
class MadHatter:
    """Load, prioritize and execute plugins.

    - enters the plugins folder and loads everything that is decorated
      or named properly
    - orders plugged-in hooks by name and priority
    - exposes plugin functionality (hooks, tools, forms) to the Cat
    """

    def __init__(self):
        self.plugins: Dict[str, Plugin] = {}  # plugins dictionary (plugin_id -> Plugin)
        self.hooks: Dict[str, List[CatHook]] = {}  # dict of active plugins hooks ( hook_name -> [CatHook, CatHook, ...])
        self.tools: List[CatTool] = []  # list of active plugins tools
        self.forms: List[CatForm] = []  # list of active plugins forms
        self.active_plugins: List[str] = []  # ids of active plugins

        self.plugins_folder = utils.get_plugins_path()

        # this callback is set from outside to be notified when plugin sync is finished
        self.on_finish_plugins_sync_callback = lambda: None

        self.find_plugins()

    def install_plugin(self, package_plugin):
        """Install a plugin from a zip/tar package and make sure it ends up active.

        The package is extracted into the plugins folder and then deleted.
        A newly installed plugin is activated via `toggle_plugin`. If a plugin
        with the same id was already active (re-install / update), the freshly
        loaded plugin object is activated and caches are re-synced instead,
        because toggling would deactivate it.

        Args:
            package_plugin: path of the zip/tar package to install.
        """
        # extract zip/tar file into plugin folder
        extractor = PluginExtractor(package_plugin)
        plugin_path = extractor.extract(self.plugins_folder)

        # remove zip after extraction
        os.remove(package_plugin)

        # get plugin id (will be its folder name); normpath drops a trailing
        # separator, otherwise basename would return "" (same as find_plugins)
        plugin_id = os.path.basename(os.path.normpath(plugin_path))

        # create plugin obj
        plugin = self.load_plugin(plugin_path)

        if plugin_id not in self.active_plugins:
            # activate it (raises if the plugin could not be loaded)
            self.toggle_plugin(plugin_id)
        elif plugin is not None:
            # already active: activate the new object, as find_plugins does at startup
            plugin.activate()
            self.sync_hooks_tools_and_forms()

    def uninstall_plugin(self, plugin_id):
        """Deactivate (if needed) and delete a plugin, removing its folder from disk.

        Does nothing if the plugin is unknown or is the core plugin.
        """
        if self.plugin_exists(plugin_id) and (plugin_id != "core_plugin"):
            # deactivate plugin if it is active (will sync cache)
            if plugin_id in self.active_plugins:
                self.toggle_plugin(plugin_id)

            # remove plugin from cache
            plugin_path = self.plugins[plugin_id].path
            del self.plugins[plugin_id]

            # remove plugin folder
            shutil.rmtree(plugin_path)

    def find_plugins(self):
        """Discover all plugins on disk and activate those marked active in the DB.

        Rebuilds `self.plugins` from scratch, reloads `self.active_plugins`
        from the DB, then re-syncs hooks, tools and forms.
        """
        # emptying plugin dictionary, plugins will be discovered from disk
        # and stored in a dictionary plugin_id -> plugin_obj
        self.plugins = {}
        self.active_plugins = self.load_active_plugins_from_db()

        # plugins are found in the plugins folder,
        # plus the default core plugin (where default hooks and tools are defined)
        core_plugin_folder = "cat/mad_hatter/core_plugin/"

        # plugin folder is "cat/plugins/" in production, "tests/mocks/mock_plugin_folder/" during tests
        all_plugin_folders = [core_plugin_folder] + glob.glob(f"{self.plugins_folder}*/")

        log.info("ACTIVE PLUGINS:")
        log.info(self.active_plugins)

        # discover plugins, folder by folder
        for folder in all_plugin_folders:
            self.load_plugin(folder)

            plugin_id = os.path.basename(os.path.normpath(folder))

            # a plugin that failed to load is not in self.plugins (load_plugin
            # already logged why); skip it instead of aborting with KeyError
            if plugin_id in self.active_plugins and plugin_id in self.plugins:
                self.plugins[plugin_id].activate()

        self.sync_hooks_tools_and_forms()

    def load_plugin(self, plugin_path):
        """Instantiate the plugin in `plugin_path` and register it in `self.plugins`.

        If the plugin is inactive, only manifest will be loaded.
        If active, also settings, tools and hooks.
        Loading errors are logged and swallowed, so one broken plugin does not
        prevent the others from loading.

        Returns:
            The loaded Plugin, or None if loading failed.
        """
        try:
            plugin = Plugin(plugin_path)
            # if plugin is valid, keep a reference
            self.plugins[plugin.id] = plugin
            return plugin
        except Exception as e:
            # Something happened while loading the plugin.
            # Print the error and go on with the others.
            log.error(f"Could not load plugin in {plugin_path}")
            log.error(str(e))
            return None

    def sync_hooks_tools_and_forms(self):
        """Rebuild the hooks, tools and forms caches from the active plugins.

        Hooks are grouped by hook name, and each group is sorted by priority,
        highest first. Calls `on_finish_plugins_sync_callback` when done.
        """
        # emptying tools, hooks and forms
        self.hooks = {}
        self.tools = []
        self.forms = []

        for plugin in self.plugins.values():
            # load hooks, tools and forms from active plugins only
            if plugin.id not in self.active_plugins:
                continue

            # cache tools and forms
            self.tools += plugin.tools
            self.forms += plugin.forms

            # cache hooks (indexed by hook name)
            for h in plugin.hooks:
                self.hooks.setdefault(h.name, []).append(h)

        # sort each hooks list by priority (stable sort, highest priority first)
        for hook_list in self.hooks.values():
            hook_list.sort(key=lambda x: x.priority, reverse=True)

        # notify sync has finished (the Cat will ensure all tools are embedded in vector memory)
        self.on_finish_plugins_sync_callback()

    def plugin_exists(self, plugin_id):
        """Return True if a plugin with `plugin_id` has been loaded."""
        return plugin_id in self.plugins

    def load_active_plugins_from_db(self):
        """Return the list of active plugin ids stored in the DB.

        "core_plugin" is always included. The returned list is a fresh copy,
        so later mutations of `self.active_plugins` do not alias the object
        handed back by the DB layer.
        """
        setting = crud.get_setting_by_name("active_plugins")
        if setting is None:
            active_plugins = []
        else:
            active_plugins = list(setting["value"])

        # core_plugin is always active
        if "core_plugin" not in active_plugins:
            active_plugins.append("core_plugin")

        return active_plugins

    def save_active_plugins_to_db(self, active_plugins):
        """Upsert the "active_plugins" setting with the given list of plugin ids."""
        new_setting = {
            "name": "active_plugins",
            "value": active_plugins
        }
        new_setting = Setting(**new_setting)
        crud.upsert_setting_by_name(new_setting)

    def toggle_plugin(self, plugin_id):
        """Activate the plugin if it is inactive, deactivate it otherwise.

        Runs the plugin's "deactivated" / "activated" override hook, persists
        the active plugins list to the DB and re-syncs hooks, tools and forms.

        Raises:
            Exception: if no plugin with `plugin_id` has been loaded.
        """
        if not self.plugin_exists(plugin_id):
            raise Exception(f"Plugin {plugin_id} not present in plugins folder")

        plugin = self.plugins[plugin_id]

        # update list of active plugins
        if plugin_id in self.active_plugins:
            log.warning(f"Toggle plugin {plugin_id}: Deactivate")

            # Deactivation hook must run before actual deactivation,
            # otherwise the hook will not be available in _plugin_overrides anymore
            self._run_plugin_override(plugin, "deactivated")

            # Deactivate the plugin and remove it from the list of active plugins
            plugin.deactivate()
            self.active_plugins.remove(plugin_id)
        else:
            log.warning(f"Toggle plugin {plugin_id}: Activate")

            # Activation hook must run after actual activation,
            # otherwise the hook will not be available in _plugin_overrides yet
            plugin.activate()
            self._run_plugin_override(plugin, "activated")

            # Add the plugin to the list of active plugins
            self.active_plugins.append(plugin_id)

        # update DB with list of active plugins, dropping duplicates
        # (dict.fromkeys keeps first-seen order, unlike set)
        self.save_active_plugins_to_db(list(dict.fromkeys(self.active_plugins)))

        # update cache and embeddings
        self.sync_hooks_tools_and_forms()

    def _run_plugin_override(self, plugin, hook_name):
        """Call every override hook of `plugin` named `hook_name`, passing the plugin."""
        for hook in plugin._plugin_overrides:
            if hook.name == hook_name:
                hook.function(plugin)

    def execute_hook(self, hook_name, *args, cat):
        """Execute all hooks registered under `hook_name`, in priority order.

        With no positional args, every hook is called as `hook(cat=cat)` and
        None is returned. Otherwise the first arg (the `tea_cup`) is piped:
        each hook receives a deep copy of the current value plus deep copies
        of the remaining args, and a non-None return value replaces it.
        Errors raised by a hook are logged and the next hook still runs.

        Raises:
            Exception: if no active plugin provides `hook_name`.
        """
        # check if hook is supported
        if hook_name not in self.hooks:
            raise Exception(f"Hook {hook_name} not present in any plugin")

        # Hook has no arguments (aside cat)
        # no need to pipe
        if not args:
            for hook in self.hooks[hook_name]:
                try:
                    log.debug(f"Executing {hook.plugin_id}::{hook.name} with priority {hook.priority}")
                    hook.function(cat=cat)
                except Exception as e:
                    self._log_hook_error(hook, e)
            return

        # Hook with arguments.
        # First argument passed to `execute_hook` is the pipeable one.
        # We call it `tea_cup` as every hook called will receive it as an input,
        # can add sugar, milk, or whatever, and return it for the next hook
        tea_cup = deepcopy(args[0])

        # run hooks
        for hook in self.hooks[hook_name]:
            try:
                # pass tea_cup to the hooks, along other args
                # hook has at least one argument, and it will be piped
                log.debug(f"Executing {hook.plugin_id}::{hook.name} with priority {hook.priority}")
                tea_spoon = hook.function(
                    deepcopy(tea_cup),
                    *deepcopy(args[1:]),
                    cat=cat
                )
                if tea_spoon is not None:
                    tea_cup = tea_spoon
            except Exception as e:
                self._log_hook_error(hook, e)

        # tea_cup has passed through all hooks. Return final output
        return tea_cup

    def _log_hook_error(self, hook, error):
        """Log a failed hook call, its plugin's specific error message and the traceback.

        Must be called from inside an `except` block so print_exc has an exception.
        """
        log.error(f"Error in plugin {hook.plugin_id}::{hook.name}")
        log.error(error)
        plugin_obj = self.plugins.get(hook.plugin_id)
        if plugin_obj is not None:
            log.warning(plugin_obj.plugin_specific_error_message())
        traceback.print_exc()

    # get plugin object (used from within a plugin)
    # TODO: should we allow to take directly another plugins' obj?
    # TODO: throw exception if this method is called from outside the plugins folder
    def get_plugin(self):
        """Return the Plugin object of the plugin whose module calls this method.

        The plugin name is the first path segment of the calling module's file
        path, taken relative to the current working directory, with the plugins
        path removed.
        NOTE(review): splits on "/" only, so this presumably assumes POSIX
        paths and that `utils.get_plugins_path()` matches the relative path
        prefix; confirm on Windows.

        Raises:
            KeyError: if the derived name is not a loaded plugin.
        """
        # who's calling?
        calling_frame = inspect.currentframe().f_back
        # Get the module associated with the frame
        module = inspect.getmodule(calling_frame)
        # Get the absolute and then relative path of the calling module's file
        abs_path = inspect.getabsfile(module)
        rel_path = os.path.relpath(abs_path)
        # Replace the root and get only the current plugin folder
        plugin_suffix = rel_path.replace(utils.get_plugins_path(), "")
        # Plugin's folder
        name = plugin_suffix.split("/")[0]
        return self.plugins[name]

    @property
    def procedures(self):
        """All procedures exposed by active plugins: tools followed by forms."""
        return self.tools + self.forms