Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Project as database for create/retrain/finetune #9135

Merged
Merged 7 commits on May 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
13 changes: 12 additions & 1 deletion mindsdb/api/executor/command_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -967,7 +967,7 @@ def _sync_predictor_check(self, phase_name):
def answer_retrain_predictor(self, statement, database_name):
model_record = self._get_model_info(statement.name, database_name=database_name)["model_record"]

- if statement.integration_name is None:
+ if statement.query_str is None:
if model_record.data_integration_ref is not None:
if model_record.data_integration_ref["type"] == "integration":
integration = self.session.integration_controller.get_by_id(
Expand All @@ -977,6 +977,9 @@ def answer_retrain_predictor(self, statement, database_name):
raise EntityNotExistsError(
"The database from which the model was trained no longer exists"
)
+ elif statement.integration_name is None:
+ # set to current project
+ statement.integration_name = Identifier(database_name)

ml_handler = None
if statement.using is not None:
Expand Down Expand Up @@ -1018,6 +1021,10 @@ def answer_finetune_predictor(self, statement, database_name):
# repack using with lower names
statement.using = {k.lower(): v for k, v in statement.using.items()}

+ if statement.query_str is not None and statement.integration_name is None:
+ # set to current project
+ statement.integration_name = Identifier(database_name)

# use current ml handler
integration_record = get_predictor_integration(model_record)
if integration_record is None:
Expand Down Expand Up @@ -1515,6 +1522,10 @@ def answer_create_predictor(self, statement: CreatePredictor, database_name):

ml_integration_name = statement.using.pop("engine", ml_integration_name)

+ if statement.query_str is not None and statement.integration_name is None:
+ # set to current project
+ statement.integration_name = Identifier(database_name)

try:
ml_handler = self.session.integration_controller.get_ml_handler(
ml_integration_name
Expand Down
3 changes: 3 additions & 0 deletions mindsdb/integrations/libs/ml_handler_process/learn_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,9 @@ def learn_process(data_integration_ref: dict, problem_definition: dict, fetch_da
query_ast = parse_sql(fetch_data_query, dialect='mindsdb')
view_meta = project.query_view(query_ast)
sqlquery = SQLQuery(view_meta['query_ast'], session=sql_session)
+ elif data_integration_ref['type'] == 'project':
+ query_ast = parse_sql(fetch_data_query, dialect='mindsdb')
+ sqlquery = SQLQuery(query_ast, session=sql_session)

result = sqlquery.fetch(view='dataframe')
training_data_df = result['result']
Expand Down
6 changes: 5 additions & 1 deletion mindsdb/interfaces/jobs/jobs_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,7 @@ def delete(self, name, project_name):

# delete context
query_context_controller.drop_query_context('job', record.id)
+ query_context_controller.drop_query_context('job-if', record.id)

def _delete_record(self, record):
record.deleted_at = dt.datetime.now()
Expand Down Expand Up @@ -336,7 +337,6 @@ def execute_task_local(self, record_id, history_id=None):
if record.user_class is not None:
ctx.user_class = record.user_class

- query_context_controller.set_context('job', record.id)
if history_id is None:
history_record = db.JobsHistory(
job_id=record.id,
Expand All @@ -363,6 +363,7 @@ def execute_task_local(self, record_id, history_id=None):
command_executor = ExecuteCommands(sql_session)

# job with condition?
+ query_context_controller.set_context('job-if', record.id)
error = ''
to_execute_query = True
if record.if_query_str is not None:
Expand Down Expand Up @@ -390,7 +391,10 @@ def execute_task_local(self, record_id, history_id=None):
if error or data is None or len(data) == 0:
to_execute_query = False

+ query_context_controller.release_context('job-if', record.id)
if to_execute_query:
+
+ query_context_controller.set_context('job', record.id)
for sql in split_sql(record.query_str):
try:
# fill template variables
Expand Down
2 changes: 1 addition & 1 deletion mindsdb/interfaces/model/model_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ def _get_data_integration_ref(statement, database_controller):
# TODO improve here. Suppose that it is view
if data_integration_meta['type'] == 'project':
data_integration_ref = {
- 'type': 'view'
+ 'type': 'project'
}
elif data_integration_meta['type'] == 'system':
data_integration_ref = {
Expand Down
3 changes: 3 additions & 0 deletions mindsdb/interfaces/query_context/context_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,9 @@ def handle_db_context_vars(self, query: ASTNode, dn, session) -> tuple:
values = self.__get_init_last_values(l_query, dn, session)
if rec is None:
self.__add_context_record(context_name, query_str, values)
+ if context_name.startswith('job-if-'):
+ # add context for job also
+ self.__add_context_record(context_name.replace('job-if', 'job'), query_str, values)
else:
rec.values = values
else:
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_project_structure.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def test_version_managing(self):
ret = self.run_sql(
'''
CREATE model proj.task_model
- from dummy_data (select * from tasks)
+ from proj (select * from dummy_data.tasks)
PREDICT a
using engine='dummy_ml',
tag = 'first',
Expand Down