-
Notifications
You must be signed in to change notification settings - Fork 245
/
remote.py
1722 lines (1551 loc) · 76 KB
/
remote.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
This module provides the ``FlyteRemote`` object, which is the end-user's main starting point for interacting
with a Flyte backend in an interactive and programmatic way. Think of this experience as kind of like the web UI
but in Python object form.
"""
from __future__ import annotations
import base64
import functools
import hashlib
import os
import pathlib
import time
import typing
import uuid
from collections import OrderedDict
from dataclasses import asdict, dataclass
from datetime import datetime, timedelta
from flyteidl.admin.signal_pb2 import Signal, SignalListRequest, SignalSetRequest
from flyteidl.core import literals_pb2 as literals_pb2
from flytekit import Literal
from flytekit.clients.friendly import SynchronousFlyteClient
from flytekit.clients.helpers import iterate_node_executions, iterate_task_executions
from flytekit.configuration import Config, FastSerializationSettings, ImageConfig, SerializationSettings
from flytekit.core import constants, utils
from flytekit.core.base_task import PythonTask
from flytekit.core.context_manager import FlyteContext, FlyteContextManager
from flytekit.core.data_persistence import FileAccessProvider
from flytekit.core.launch_plan import LaunchPlan
from flytekit.core.python_auto_container import PythonAutoContainerTask
from flytekit.core.reference_entity import ReferenceSpec
from flytekit.core.type_engine import LiteralsResolver, TypeEngine
from flytekit.core.workflow import WorkflowBase
from flytekit.exceptions import user as user_exceptions
from flytekit.exceptions.user import FlyteEntityAlreadyExistsException, FlyteEntityNotExistException
from flytekit.loggers import remote_logger
from flytekit.models import common as common_models
from flytekit.models import filters as filter_models
from flytekit.models import launch_plan as launch_plan_models
from flytekit.models import literals as literal_models
from flytekit.models import task as task_models
from flytekit.models import types as type_models
from flytekit.models.admin import common as admin_common_models
from flytekit.models.admin import workflow as admin_workflow_models
from flytekit.models.admin.common import Sort
from flytekit.models.core import workflow as workflow_model
from flytekit.models.core.identifier import Identifier, ResourceType, SignalIdentifier, WorkflowExecutionIdentifier
from flytekit.models.core.workflow import NodeMetadata
from flytekit.models.execution import (
ExecutionMetadata,
ExecutionSpec,
NodeExecutionGetDataResponse,
NotificationList,
WorkflowExecutionGetDataResponse,
)
from flytekit.remote.entities import FlyteLaunchPlan, FlyteNode, FlyteTask, FlyteTaskNode, FlyteWorkflow
from flytekit.remote.executions import FlyteNodeExecution, FlyteTaskExecution, FlyteWorkflowExecution
from flytekit.remote.interface import TypedInterface
from flytekit.remote.lazy_entity import LazyEntity
from flytekit.remote.remote_callable import RemoteEntity
from flytekit.tools.fast_registration import fast_package
from flytekit.tools.script_mode import fast_register_single_script, hash_file
from flytekit.tools.translator import (
FlyteControlPlaneEntity,
FlyteLocalEntity,
Options,
get_serializable,
get_serializable_launch_plan,
)
# Either of the two admin "get data" responses: for a whole workflow execution or for a single node execution.
ExecutionDataResponse = typing.Union[WorkflowExecutionGetDataResponse, NodeExecutionGetDataResponse]
# Sort order used by listing calls that want the newest entities (by creation time) first.
MOST_RECENT_FIRST = Sort("created_at", Sort.Direction.DESCENDING)
class RegistrationSkipped(Exception):
    """
    RegistrationSkipped error is raised when trying to register an entity that is not registrable.

    ``FlyteRemote.raw_register`` raises it for remote entities that must not be (re-)registered, and
    ``FlyteRemote._serialize_and_register`` catches it so the remaining entities still get registered.
    """
@dataclass
class ResolvedIdentifiers:
    """Plain holder for a fully resolved (project, domain, name, version) entity coordinate."""

    # Flyte project the entity lives in.
    project: str
    # Flyte domain within the project.
    domain: str
    # Entity name.
    name: str
    # Registered version of the entity.
    version: str
def _get_latest_version(list_entities_method: typing.Callable, project: str, domain: str, name: str):
    """
    Look up the version of the most recently created entity with the given name.

    :param list_entities_method: a paginated admin list method (e.g. ``client.list_tasks_paginated``); it is
        called with a NamedEntityIdentifier plus ``limit`` and ``sort_by`` keywords and must return
        ``(entities, token)``.
    :param project: project of the named entity.
    :param domain: domain of the named entity.
    :param name: name of the entity.
    :return: the ``id.version`` of the newest matching entity.
    :raises FlyteEntityNotExistException: if no entity with that name exists.
    """
    named_entity = common_models.NamedEntityIdentifier(project, domain, name)
    newest_first = Sort("created_at", Sort.Direction.DESCENDING)
    # Only the single newest entity is needed; the pagination token is discarded.
    entities, _token = list_entities_method(named_entity, limit=1, sort_by=newest_first)
    if not entities or not entities[0]:
        raise user_exceptions.FlyteEntityNotExistException(f"Named entity {named_entity} not found")
    return entities[0].id.version
def _get_entity_identifier(
    list_entities_method: typing.Callable,
    resource_type: int,  # from flytekit.models.core.identifier.ResourceType
    project: str,
    domain: str,
    name: str,
    version: typing.Optional[str] = None,
):
    """
    Build an Identifier for an entity, resolving the version when it is not given.

    :param list_entities_method: paginated admin list method, used only when ``version`` is None.
    :param resource_type: ResourceType value of the entity.
    :param project: entity project.
    :param domain: entity domain.
    :param name: entity name.
    :param version: explicit version; when None the newest registered version is looked up.
    :return: Identifier(resource_type, project, domain, name, version).
    """
    if version is None:
        # No explicit version requested: fall back to whatever was registered most recently.
        version = _get_latest_version(list_entities_method, project, domain, name)
    return Identifier(resource_type, project, domain, name, version)
def _normalize_git_remote_url(url: str) -> str:
    """
    Convert a git remote URL into the ``host/path`` form used for ``SerializationSettings.git_repo``.

    Both remote syntaxes git accepts are handled:

    * scp-like: ``git@github.com:org/repo.git`` -> ``github.com/org/repo``
    * URL-like: ``https://github.com/org/repo.git`` or ``ssh://git@github.com:22/org/repo.git``
      -> ``github.com/org/repo``

    User info (``user:token@``) and ports are dropped, so credentials embedded in an HTTPS remote never
    end up in serialized output. A trailing ``/`` and a single trailing ``.git`` suffix are removed.

    :param url: the raw remote URL.
    :return: ``host/path``, or ``""`` when no host and path can be extracted (e.g. local-path remotes).
    """
    from urllib.parse import urlsplit

    url = url.strip().rstrip("/")
    # Remove only a *suffix* ".git"; names such as "my.gitrepo" must survive intact.
    if url.endswith(".git"):
        url = url[: -len(".git")]

    if "://" in url:
        parts = urlsplit(url)
        host = parts.hostname or ""
        path = parts.path
    elif ":" in url:
        # scp-like syntax: [user@]host:path
        host_part, _, path = url.partition(":")
        if "/" in host_part:
            # A slash before the first colon means a local path, not scp-like syntax.
            return ""
        host = host_part.rpartition("@")[2]
    else:
        return ""

    path = path.strip("/")
    if not host or not path:
        return ""
    return f"{host}/{path}"


def _get_git_repo_url(source_path):
    """
    Get git repo URL from remote.origin.url, normalized to ``host/path`` (e.g. ``github.com/org/repo``).

    :param source_path: a path inside the git working tree; passed straight to ``git.Repo``.
    :return: the normalized origin URL, or ``""`` when GitPython cannot be imported, ``source_path`` is
        not a git repo (or has no ``origin`` remote), or the remote URL cannot be parsed.
    """
    try:
        from git import Repo

        origin_url = Repo(source_path).remotes.origin.url
    except ImportError:
        remote_logger.warning("Could not import git. is the git executable installed?")
        return ""
    except Exception:
        # If the file isn't in the git repo, we can't get the url from git config
        remote_logger.debug(f"{source_path} is not a git repo.")
        return ""
    return _normalize_git_remote_url(origin_url)
class FlyteRemote(object):
"""Main entrypoint for programmatically accessing a Flyte remote backend.
The term 'remote' is synonymous with 'backend' or 'deployment' and refers to a hosted instance of the
Flyte platform, which comes with a Flyte Admin server on some known URI.
"""
def __init__(
    self,
    config: Config,
    default_project: typing.Optional[str] = None,
    default_domain: typing.Optional[str] = None,
    data_upload_location: str = "s3://my-s3-bucket/data",
    **kwargs,
):
    """Initialize a FlyteRemote object.

    :param config: Flyte configuration; ``config.platform.endpoint`` must be set.
    :param default_project: default project to use when fetching or executing flyte entities.
    :param default_domain: default domain to use when fetching or executing flyte entities.
    :param data_upload_location: this is where all the default data will be uploaded when providing inputs.
        The default location - `s3://my-s3-bucket/data` works for sandbox/demo environment. Please override
        this for non-sandbox cases.
    :param kwargs: All arguments that can be passed to create the SynchronousFlyteClient. These are usually
        grpc parameters, if you want to customize credentials, ssl handling etc.
    :raises FlyteAssertion: if ``config``, its platform config, or the platform endpoint is missing.
    """
    platform = None if config is None else config.platform
    if platform is None or platform.endpoint is None:
        raise user_exceptions.FlyteAssertion("Flyte endpoint should be provided.")

    # Admin client; kwargs are forwarded verbatim to SynchronousFlyteClient.
    self._client = SynchronousFlyteClient(config.platform, **kwargs)
    self._config = config
    self._default_project = default_project
    self._default_domain = default_domain

    metadata_dir = os.path.join(config.local_sandbox_path, "control_plane_metadata")
    self._file_access = FileAccessProvider(
        local_sandbox_dir=metadata_dir,
        raw_output_prefix=data_upload_location,
        data_config=config.data_config,
    )
    # Keep a context bound to this file access provider; it is what `self.context` hands out
    # (e.g. to TypeEngine conversions in set_signal).
    self._ctx = FlyteContextManager.current_context().with_file_access(self._file_access).build()
@property
def context(self) -> FlyteContext:
return self._ctx
@property
def client(self) -> SynchronousFlyteClient:
"""Return a SynchronousFlyteClient for additional operations."""
return self._client
@property
def default_project(self) -> str:
"""Default project to use when fetching or executing flyte entities."""
return self._default_project
@property
def default_domain(self) -> str:
"""Default project to use when fetching or executing flyte entities."""
return self._default_domain
@property
def config(self) -> Config:
"""Image config."""
return self._config
@property
def file_access(self) -> FileAccessProvider:
"""File access provider to use for offloading non-literal inputs/outputs."""
return self._file_access
def remote_context(self):
    """Context manager with remote-specific configuration.

    Activates a copy of the *current* Flyte context whose file access provider is this remote's.
    """
    builder = FlyteContextManager.current_context().with_file_access(self.file_access)
    return FlyteContextManager.with_context(builder)
def fetch_task_lazy(
    self, project: str = None, domain: str = None, name: str = None, version: str = None
) -> LazyEntity[FlyteTask]:
    """
    Similar to fetch_task, just that it returns a LazyEntity, which will fetch the task lazily.

    Arguments are the same as :meth:`fetch_task`. This method itself makes no admin call; the fetch
    happens inside the getter handed to the LazyEntity.

    :raises: FlyteAssertion if name is None
    """
    if name is None:
        raise user_exceptions.FlyteAssertion("the 'name' argument must be specified.")

    def _fetch():
        return self.fetch_task(project=project, domain=domain, name=name, version=version)

    return LazyEntity(name=name, getter=_fetch)
def fetch_task(self, project: str = None, domain: str = None, name: str = None, version: str = None) -> FlyteTask:
    """Fetch a task entity from flyte admin.

    :param project: fetch entity from this project. If None, uses the default_project attribute.
    :param domain: fetch entity from this domain. If None, uses the default_domain attribute.
    :param name: fetch entity with matching name.
    :param version: fetch entity with matching version. If None, gets the latest version of the entity.
    :returns: :class:`~flytekit.remote.tasks.task.FlyteTask`
    :raises: FlyteAssertion if name is None
    """
    if name is None:
        raise user_exceptions.FlyteAssertion("the 'name' argument must be specified.")

    resolved_project = project or self.default_project
    resolved_domain = domain or self.default_domain
    task_id = _get_entity_identifier(
        self.client.list_tasks_paginated,
        ResourceType.TASK,
        resolved_project,
        resolved_domain,
        name,
        version,
    )
    template = self.client.get_task(task_id).closure.compiled_task.template
    flyte_task = FlyteTask.promote_from_model(template)
    # Stamp the fully resolved id (including a looked-up latest version) onto the template.
    flyte_task.template._id = task_id
    return flyte_task
def fetch_workflow_lazy(
    self, project: str = None, domain: str = None, name: str = None, version: str = None
) -> LazyEntity[FlyteWorkflow]:
    """
    Similar to fetch_workflow, just that it returns a LazyEntity, which will fetch the workflow lazily.

    Arguments are the same as :meth:`fetch_workflow`; the admin calls happen inside the LazyEntity's getter.

    :raises: FlyteAssertion if name is None
    """
    if name is None:
        raise user_exceptions.FlyteAssertion("the 'name' argument must be specified.")

    def _fetch():
        return self.fetch_workflow(project=project, domain=domain, name=name, version=version)

    return LazyEntity(name=name, getter=_fetch)
def fetch_workflow(
    self, project: str = None, domain: str = None, name: str = None, version: str = None
) -> FlyteWorkflow:
    """
    Fetch a workflow entity from flyte admin.

    Launch plans referenced by workflow nodes of the primary workflow and its sub-workflows are fetched too,
    so the returned FlyteWorkflow can be built from the full closure.

    :param project: fetch entity from this project. If None, uses the default_project attribute.
    :param domain: fetch entity from this domain. If None, uses the default_domain attribute.
    :param name: fetch entity with matching name.
    :param version: fetch entity with matching version. If None, gets the latest version of the entity.
    :raises: FlyteAssertion if name is None
    """
    if name is None:
        raise user_exceptions.FlyteAssertion("the 'name' argument must be specified.")

    workflow_id = _get_entity_identifier(
        self.client.list_workflows_paginated,
        ResourceType.WORKFLOW,
        project or self.default_project,
        domain or self.default_domain,
        name,
        version,
    )
    compiled_wf = self.client.get_workflow(workflow_id).closure.compiled_workflow
    templates = [compiled_wf.primary.template] + [sub.template for sub in compiled_wf.sub_workflows]

    # launch plan id -> launch plan spec; each distinct launch plan is fetched once.
    # TODO: Inspect branch nodes for launch plans
    launch_plan_specs = {}
    for template in templates:
        for node in FlyteWorkflow.get_non_system_nodes(template.nodes):
            wf_node = node.workflow_node
            if wf_node is None or wf_node.launchplan_ref is None:
                continue
            lp_ref = wf_node.launchplan_ref
            if lp_ref in launch_plan_specs:
                continue
            launch_plan_specs[lp_ref] = self.client.get_launch_plan(lp_ref).spec

    return FlyteWorkflow.promote_from_closure(compiled_wf, launch_plan_specs)
def fetch_launch_plan(
    self, project: str = None, domain: str = None, name: str = None, version: str = None
) -> FlyteLaunchPlan:
    """Fetch a launchplan entity from flyte admin.

    The workflow the launch plan points at is fetched as well and attached to the result.

    :param project: fetch entity from this project. If None, uses the default_project attribute.
    :param domain: fetch entity from this domain. If None, uses the default_domain attribute.
    :param name: fetch entity with matching name.
    :param version: fetch entity with matching version. If None, gets the latest version of the entity.
    :returns: :class:`~flytekit.remote.launch_plan.FlyteLaunchPlan`
    :raises: FlyteAssertion if name is None
    """
    if name is None:
        raise user_exceptions.FlyteAssertion("the 'name' argument must be specified.")

    launch_plan_id = _get_entity_identifier(
        self.client.list_launch_plans_paginated,
        ResourceType.LAUNCH_PLAN,
        project or self.default_project,
        domain or self.default_domain,
        name,
        version,
    )
    lp_spec = self.client.get_launch_plan(launch_plan_id).spec
    flyte_launch_plan = FlyteLaunchPlan.promote_from_model(launch_plan_id, lp_spec)

    # Attach the target workflow and reuse its interface as the launch plan's interface.
    wf_id = flyte_launch_plan.workflow_id
    target_workflow = self.fetch_workflow(wf_id.project, wf_id.domain, wf_id.name, wf_id.version)
    flyte_launch_plan._interface = target_workflow.interface
    flyte_launch_plan._flyte_workflow = target_workflow
    return flyte_launch_plan
def fetch_execution(self, project: str = None, domain: str = None, name: str = None) -> FlyteWorkflowExecution:
    """Fetch a workflow execution entity from flyte admin.

    :param project: fetch entity from this project. If None, uses the default_project attribute.
    :param domain: fetch entity from this domain. If None, uses the default_domain attribute.
    :param name: fetch entity with matching name.
    :returns: :class:`~flytekit.remote.workflow_execution.FlyteWorkflowExecution`
    :raises: FlyteAssertion if name is None
    """
    if name is None:
        raise user_exceptions.FlyteAssertion("the 'name' argument must be specified.")

    execution_id = WorkflowExecutionIdentifier(
        project or self.default_project,
        domain or self.default_domain,
        name,
    )
    admin_execution = self.client.get_execution(execution_id)
    execution = FlyteWorkflowExecution.promote_from_model(admin_execution)
    # The freshly promoted execution is passed through sync_execution before being returned.
    return self.sync_execution(execution)
######################
# Listing Entities #
######################
def list_signals(
    self,
    execution_name: str,
    project: typing.Optional[str] = None,
    domain: typing.Optional[str] = None,
    limit: int = 100,
    filters: typing.Optional[typing.List[filter_models.Filter]] = None,
) -> typing.List[Signal]:
    """
    List the signals (the gates set via approve() / wait_for_input()) of a workflow execution.

    :param execution_name: The name of the execution. This is the tailend of the URL when looking at the workflow execution.
    :param project: The execution project, will default to the Remote's default project.
    :param domain: The execution domain, will default to the Remote's default domain.
    :param limit: The number of signals to fetch
    :param filters: Optional list of filters. They are serialized into the single filter string that
        ``SignalListRequest.filters`` expects.
    :return: the ``signals`` field of the admin response.
    """
    wf_exec_id = WorkflowExecutionIdentifier(
        project=project or self.default_project, domain=domain or self.default_domain, name=execution_name
    )
    # SignalListRequest.filters is a protobuf *string* field (flyteidl admin/signal.proto); a list of Filter
    # models cannot be assigned to it, so serialize them first. No filters serializes to an empty string.
    filter_str = filter_models.FilterList(filters or []).to_flyte_idl()
    req = SignalListRequest(workflow_execution_id=wf_exec_id.to_flyte_idl(), limit=limit, filters=filter_str)
    return self.client.list_signals(req).signals
def set_signal(
    self,
    signal_id: str,
    execution_name: str,
    value: typing.Union[literal_models.Literal, typing.Any],
    project: typing.Optional[str] = None,
    domain: typing.Optional[str] = None,
    python_type: typing.Optional[typing.Type] = None,
    literal_type: typing.Optional[type_models.LiteralType] = None,
):
    """
    Set the value of a signal on a running execution.

    :param signal_id: The name of the signal, this is the key used in the approve() or wait_for_input() call.
    :param execution_name: The name of the execution. This is the tail-end of the URL when looking
        at the workflow execution.
    :param value: This is either a Literal or a Python value which FlyteRemote will invoke the TypeEngine to
        convert into a Literal. This argument is only valid for wait_for_input type signals.
    :param project: The execution project, will default to the Remote's default project.
    :param domain: The execution domain, will default to the Remote's default domain.
    :param python_type: Provide a python type to help with conversion if the value you provided is not a Literal.
        Defaults to ``type(value)``.
    :param literal_type: Provide a Flyte literal type to help with conversion if the value you provided
        is not a Literal. Defaults to the literal type the TypeEngine derives from the python type.
    """
    wf_exec_id = WorkflowExecutionIdentifier(
        project=project or self.default_project, domain=domain or self.default_domain, name=execution_name
    )

    if isinstance(value, Literal):
        remote_logger.debug(f"Using provided {value} as existing Literal value")
        lit = value
    else:
        source_type = python_type or type(value)
        lt = literal_type or TypeEngine.to_literal_type(source_type)
        lit = TypeEngine.to_literal(self.context, value, source_type, lt)
        remote_logger.debug(f"Converted {value} to literal {lit} using literal type {lt}")

    signal_id_idl = SignalIdentifier(signal_id, wf_exec_id).to_flyte_idl()
    req = SignalSetRequest(id=signal_id_idl, value=lit.to_flyte_idl())
    # Response is empty currently, nothing to give back to the user.
    self.client.set_signal(req)
def recent_executions(
    self,
    project: typing.Optional[str] = None,
    domain: typing.Optional[str] = None,
    limit: typing.Optional[int] = 100,
) -> typing.List[FlyteWorkflowExecution]:
    """
    Return the most recently created workflow executions in a project/domain, newest first.

    :param project: defaults to the remote's default project.
    :param domain: defaults to the remote's default domain.
    :param limit: maximum number of executions to return; only a single page is requested.
    """
    # The pagination token is ignored: only the first page is used.
    executions, _token = self.client.list_executions_paginated(
        project or self.default_project,
        domain or self.default_domain,
        limit,
        sort_by=MOST_RECENT_FIRST,
    )
    return list(map(FlyteWorkflowExecution.promote_from_model, executions))
def list_tasks_by_version(
    self,
    version: str,
    project: typing.Optional[str] = None,
    domain: typing.Optional[str] = None,
    limit: typing.Optional[int] = 100,
) -> typing.List[FlyteTask]:
    """
    Return the tasks registered under exactly ``version`` in a project/domain.

    :param version: version to match (required).
    :param project: defaults to the remote's default project.
    :param domain: defaults to the remote's default domain.
    :param limit: maximum number of tasks to return; only a single page is requested.
    :raises ValueError: if ``version`` is empty.
    """
    if not version:
        raise ValueError("Must specify a version")

    named_entity_id = common_models.NamedEntityIdentifier(
        project=project or self.default_project,
        domain=domain or self.default_domain,
    )
    version_filter = filter_models.Filter.from_python_std(f"eq(version,{version})")
    # The pagination token is ignored: only the first page is used.
    task_page, _token = self.client.list_tasks_paginated(
        named_entity_id,
        filters=[version_filter],
        limit=limit,
    )
    return [FlyteTask.promote_from_model(admin_task.closure.compiled_task.template) for admin_task in task_page]
#####################
# Register Entities #
#####################
def _resolve_identifier(
    self, t: int, name: str, version: str, ss: typing.Optional[SerializationSettings]
) -> Identifier:
    """
    Build the Identifier an entity should be registered under.

    Project and domain come from ``ss`` when it sets them, otherwise from the remote's defaults. An explicit
    ``version`` wins over ``ss.version``. ``ss`` may be None (e.g. ``register_workflow`` without settings).

    :param t: ResourceType of the entity.
    :param name: entity name.
    :param version: explicit version, or a falsy value to use ``ss.version``.
    :param ss: serialization settings, or None.
    :raises ValueError: if project, domain, name or version cannot be determined.
    """
    ident = Identifier(
        resource_type=t,
        project=ss.project if ss and ss.project else self.default_project,
        domain=ss.domain if ss and ss.domain else self.default_domain,
        name=name,
        # Guard `ss` just like project/domain above, so a missing version yields the ValueError below
        # instead of an AttributeError on None.
        version=version or (ss.version if ss else None),
    )
    if not ident.project or not ident.domain or not ident.name or not ident.version:
        raise ValueError(
            f"To register a new {ident.resource_type}, (project, domain, name, version) required, "
            f"received ({ident.project}, {ident.domain}, {ident.name}, {ident.version})."
        )
    return ident
def raw_register(
    self,
    cp_entity: FlyteControlPlaneEntity,
    settings: SerializationSettings,
    version: str,
    create_default_launchplan: bool = True,
    options: typing.Optional[Options] = None,
    og_entity: typing.Optional[FlyteLocalEntity] = None,
) -> typing.Optional[Identifier]:
    """
    Raw register method, can be used to register control plane entities. Usually if you have a Flyte Entity like a
    WorkflowBase, Task, LaunchPlan then use other methods. This should be used only if you have already serialized
    entities.

    Registration is idempotent: if admin reports that the entity already exists, this is logged and the
    identifier is still returned.

    :param cp_entity: The controlplane "serializable" version of a flyte entity. This is in the form that FlyteAdmin
        understands.
    :param settings: SerializationSettings to be used for registration - especially to identify the id
    :param version: Version to be registered (ignored for remote FlyteTask/FlyteWorkflow, which keep their own)
    :param create_default_launchplan: boolean that indicates if a default launch plan should be created
        (only used for workflow specs)
    :param options: Options to be used if registering a default launch plan
    :param og_entity: Pass in the original workflow (flytekit type) if create_default_launchplan is true
    :return: Identifier of the created entity, or None for nodes and reference entities, which are not registered.
    :raises RegistrationSkipped: if ``cp_entity`` is a remote entity that must not be (re-)registered.
    :raises FlyteValueException: if a default launch plan is requested for a workflow without ``og_entity``.
    :raises AssertionError: for control plane entity types this method cannot register.
    """
    if isinstance(cp_entity, RemoteEntity):
        # Remote entities are only registered when they are a FlyteWorkflow/FlyteTask that says it should be.
        registrable = isinstance(cp_entity, (FlyteWorkflow, FlyteTask)) and cp_entity.should_register
        if not registrable:
            remote_logger.debug(f"Skipping registration of remote entity: {cp_entity.name}")
            raise RegistrationSkipped(f"Remote task/Workflow {cp_entity.name} is not registrable.")

    if isinstance(
        cp_entity,
        (
            workflow_model.Node,
            workflow_model.WorkflowNode,
            workflow_model.BranchNode,
            workflow_model.TaskNode,
        ),
    ):
        remote_logger.debug("Ignoring nodes for registration.")
        return None
    elif isinstance(cp_entity, ReferenceSpec):
        remote_logger.debug(f"Skipping registration of Reference entity, name: {cp_entity.template.id.name}")
        return None

    if isinstance(cp_entity, task_models.TaskSpec):
        if isinstance(cp_entity, FlyteTask):
            # A remote task keeps the version it was fetched with.
            version = cp_entity.id.version
        ident = self._resolve_identifier(ResourceType.TASK, cp_entity.template.id.name, version, settings)
        try:
            self.client.create_task(task_identifer=ident, task_spec=cp_entity)
        except FlyteEntityAlreadyExistsException:
            remote_logger.info(f" {ident} Already Exists!")
        return ident

    if isinstance(cp_entity, admin_workflow_models.WorkflowSpec):
        if isinstance(cp_entity, FlyteWorkflow):
            # A remote workflow keeps the version it was fetched with.
            version = cp_entity.id.version
        ident = self._resolve_identifier(ResourceType.WORKFLOW, cp_entity.template.id.name, version, settings)
        try:
            self.client.create_workflow(workflow_identifier=ident, workflow_spec=cp_entity)
        except FlyteEntityAlreadyExistsException:
            remote_logger.info(f" {ident} Already Exists!")

        if create_default_launchplan:
            if not og_entity:
                # FlyteValueException takes (received_value, error_message), as in _serialize_and_register.
                raise user_exceptions.FlyteValueException(
                    og_entity,
                    "To create default launch plan, please pass in the original flytekit workflow `og_entity`",
                )
            # Let us also create a default launch-plan, ideally the default launchplan should be added
            # to the orderedDict, but we do not.
            default_lp = LaunchPlan.get_default_launch_plan(self.context, og_entity)
            lp_entity = get_serializable_launch_plan(
                OrderedDict(),
                settings,
                default_lp,
                recurse_downstream=False,
                options=options,
            )
            try:
                self.client.create_launch_plan(lp_entity.id, lp_entity.spec)
            except FlyteEntityAlreadyExistsException:
                remote_logger.info(f" {lp_entity.id} Already Exists!")
        return ident

    if isinstance(cp_entity, launch_plan_models.LaunchPlan):
        ident = self._resolve_identifier(ResourceType.LAUNCH_PLAN, cp_entity.id.name, version, settings)
        try:
            self.client.create_launch_plan(launch_plan_identifer=ident, launch_plan_spec=cp_entity.spec)
        except FlyteEntityAlreadyExistsException:
            remote_logger.info(f" {ident} Already Exists!")
        return ident

    raise AssertionError(f"Unknown entity of type {type(cp_entity)}")
def _serialize_and_register(
    self,
    entity: FlyteLocalEntity,
    settings: typing.Optional[SerializationSettings],
    version: str,
    options: typing.Optional[Options] = None,
    create_default_launchplan: bool = True,
) -> typing.Optional[Identifier]:
    """
    This method serializes and register the given Flyte entity (and every entity it depends on).

    :param entity: the flytekit entity to register.
    :param settings: serialization settings. If None, placeholder settings built from the remote's default
        project/domain and ``version`` are used; these are only accepted when every serialized entity is a
        workflow spec.
    :param version: version to register under.
    :param options: options forwarded to serialization and to default launch plan creation.
    :param create_default_launchplan: forwarded to :meth:`raw_register`; when True a default launch plan is
        created for each registered workflow spec.
    :return: Identifier returned by the last ``raw_register`` call (i.e. for the last serialized entity),
        or None if nothing was registered.
    :raises FlyteValueException: if ``settings`` is None but a non-workflow entity would need registering.
    """
    entity_map = OrderedDict()
    # Create dummy serialization settings for now.
    # TODO: Clean this up by using lazy usage of serialization settings in translator.py
    is_dummy_serialization_setting = not settings
    if is_dummy_serialization_setting:
        serialization_settings = SerializationSettings(
            ImageConfig.auto_default_image(),
            project=self.default_project,
            domain=self.default_domain,
            version=version,
        )
    else:
        serialization_settings = settings

    get_serializable(entity_map, settings=serialization_settings, entity=entity, options=options)

    ident = None
    # Loop variable is not called `entity`, so the argument above is not shadowed.
    for local_entity, cp_entity in entity_map.items():
        if is_dummy_serialization_setting and not isinstance(cp_entity, admin_workflow_models.WorkflowSpec):
            # Only in the case of workflows can we use the dummy serialization settings.
            # Not every control plane entity exposes `.id.name` (task specs keep their id on `.template`),
            # so look the name up defensively rather than letting an AttributeError hide this error.
            entity_id = getattr(cp_entity, "id", None)
            if entity_id is None and hasattr(cp_entity, "template"):
                entity_id = cp_entity.template.id
            entity_name = getattr(entity_id, "name", entity_id)
            raise user_exceptions.FlyteValueException(
                settings,
                f"No serialization settings set, but workflow contains entities that need to be registered. {entity_name}",
            )
        try:
            ident = self.raw_register(
                cp_entity,
                # Hand over the settings actually used for serialization (possibly the placeholder ones),
                # so raw_register never receives None when building identifiers or default launch plans.
                settings=serialization_settings,
                version=version,
                create_default_launchplan=create_default_launchplan,
                options=options,
                og_entity=local_entity,
            )
        except RegistrationSkipped:
            # Deliberate: remote entities that must not be re-registered are simply skipped.
            pass
    return ident


def register_task(
    self, entity: PythonTask, serialization_settings: SerializationSettings, version: typing.Optional[str] = None
) -> FlyteTask:
    """
    Register a qualified task (PythonTask) with Remote
    For any conflicting parameters method arguments are regarded as overrides

    :param entity: PythonTask can be either @task or a instance of a Task class
    :param serialization_settings: Settings that will be used to override various serialization parameters.
    :param version: version that will be used to register. If not specified will default to using the serialization settings default
    :return: the registered task as fetched back from admin, with the local python interface attached.
    """
    ident = self._serialize_and_register(entity=entity, settings=serialization_settings, version=version)
    ft = self.fetch_task(
        ident.project,
        ident.domain,
        ident.name,
        ident.version,
    )
    ft._python_interface = entity.python_interface
    return ft


def register_workflow(
    self,
    entity: WorkflowBase,
    serialization_settings: typing.Optional[SerializationSettings] = None,
    version: typing.Optional[str] = None,
    default_launch_plan: typing.Optional[bool] = True,
    options: typing.Optional[Options] = None,
) -> FlyteWorkflow:
    """
    Use this method to register a workflow.

    :param version: version for the entity to be registered as
    :param entity: The workflow to be registered
    :param serialization_settings: The serialization settings to be used
    :param default_launch_plan: This should be true if a default launch plan should be created for the workflow.
        When false, no default launch plan is created.
    :param options: Additional execution options that can be configured for the default launchplan
    :return: the registered workflow as fetched back from admin, with the local python interface attached.
    """
    ident = self._resolve_identifier(ResourceType.WORKFLOW, entity.name, version, serialization_settings)
    if serialization_settings:
        # Pin the settings to the resolved project/domain/version so every serialized entity agrees with them.
        b = serialization_settings.new_builder()
        b.project = ident.project
        b.domain = ident.domain
        b.version = ident.version
        serialization_settings = b.build()
    ident = self._serialize_and_register(
        entity,
        serialization_settings,
        version,
        options,
        # Honour default_launch_plan here as well; otherwise raw_register creates one unconditionally.
        create_default_launchplan=bool(default_launch_plan),
    )
    if default_launch_plan:
        default_lp = LaunchPlan.get_default_launch_plan(self.context, entity)
        self.register_launch_plan(
            default_lp, version=ident.version, project=ident.project, domain=ident.domain, options=options
        )
        remote_logger.debug("Created default launch plan for Workflow")

    fwf = self.fetch_workflow(ident.project, ident.domain, ident.name, ident.version)
    fwf._python_interface = entity.python_interface
    return fwf
def fast_package(
    self, root: os.PathLike, deref_symlinks: bool = True, output: typing.Optional[str] = None
) -> typing.Tuple[bytes, str]:
    """
    Packages the given paths into an installable zip and returns the md5_bytes and the URL of the uploaded location

    :param root: path to the root of the package system that should be uploaded
    :param deref_symlinks: if symlinks should be dereferenced. Defaults to True
    :param output: output path. Optional, will default to a tempdir
    :return: md5_bytes, url
    """
    # Create a zip file containing all the entries (module-level `fast_package`, not this method).
    zip_file = fast_package(root, output, deref_symlinks)
    # _upload_file hashes the archive itself (the md5 is required to request the signed upload URL),
    # so it is not hashed separately here.
    return self._upload_file(pathlib.Path(zip_file))
def _upload_file(
    self, to_upload: pathlib.Path, project: typing.Optional[str] = None, domain: typing.Optional[str] = None
) -> typing.Tuple[bytes, str]:
    """
    Function will use remote's client to hash and then upload the file using Admin's data proxy service.

    :param to_upload: Must be a single file
    :param project: Project to upload under, if not supplied will use the remote's default
    :param domain: Domain to upload under, if not specified will use the remote's default
    :return: Tuple of (md5 digest bytes of the file, native URL of the uploaded location).
    :raises ValueError: if ``to_upload`` is not a regular file.
    """
    if not to_upload.is_file():
        raise ValueError(f"{to_upload} is not a single file, upload arg must be a single file.")

    md5_bytes, str_digest = hash_file(to_upload)
    remote_logger.debug(f"Text hash of file to upload is {str_digest}")

    # Admin hands back a signed URL to write the bytes to, plus the native URL where they will live.
    upload_location = self.client.get_upload_signed_url(
        project=project or self.default_project,
        domain=domain or self.default_domain,
        content_md5=md5_bytes,
        filename=to_upload.name,
    )
    # Log before uploading so the target is visible even if the upload itself fails.
    remote_logger.debug(
        f"Uploading {to_upload} to {upload_location.signed_url} native url {upload_location.native_url}"
    )
    self._ctx.file_access.put_data(str(to_upload), upload_location.signed_url)
    return md5_bytes, upload_location.native_url
@staticmethod
def _version_from_hash(
    md5_bytes: bytes,
    serialization_settings: SerializationSettings,
    *additional_context: str,
) -> str:
    """
    Derive a registration version from an uploaded file's md5 plus the settings that affect registration.

    The md5 that we send to S3/GCS has to match the file contents exactly, but we don't have to use it
    when registering with the Flyte backend. Mixing in the serialization settings (as JSON), the flytekit
    version and any ``additional_context`` strings means a change to any of them yields a new version even
    when the uploaded file is unchanged.

    :param md5_bytes: md5 digest of the uploaded file; used as the initial hash input.
    :param serialization_settings: settings whose ``to_json()`` output is mixed into the hash.
    :param additional_context: This is for additional context to factor into the version computation,
        meant for objects (like Options for instance) that don't easily consistently stringify.
    :return: URL-safe base64 encoding of the combined md5 digest.
    """
    from flytekit import __version__

    hasher = hashlib.md5(md5_bytes)
    pieces = [serialization_settings.to_json(), __version__, *additional_context]
    for piece in pieces:
        hasher.update(bytes(piece, "utf-8"))
    return base64.urlsafe_b64encode(hasher.digest()).decode("ascii")
def register_script(
    self,
    entity: typing.Union[WorkflowBase, PythonTask],
    image_config: typing.Optional[ImageConfig] = None,
    version: typing.Optional[str] = None,
    project: typing.Optional[str] = None,
    domain: typing.Optional[str] = None,
    destination_dir: str = ".",
    default_launch_plan: typing.Optional[bool] = True,
    options: typing.Optional[Options] = None,
    source_path: typing.Optional[str] = None,
    module_name: typing.Optional[str] = None,
) -> typing.Union[FlyteWorkflow, FlyteTask]:
    """
    Register a workflow or a task via script mode.

    The code identified by ``source_path``/``module_name`` is handed to ``fast_register_single_script`` together
    with a signed-URL factory for ``scriptmode.tar.gz``; the native URL of the resulting upload becomes the
    fast-serialization distribution location.

    :param entity: The workflow or task to register. ``PythonTask`` instances go through ``register_task``;
        everything else goes through ``register_workflow``.
    :param image_config: Image configuration; ``ImageConfig.auto_default_image()`` is used when omitted.
    :param version: Version to register the entity as. When omitted, it is derived from the uploaded hash and
        the serialization settings.
    :param project: Project to register under; the upload falls back to the remote's default project.
    :param domain: Domain to register under; the upload falls back to the remote's default domain.
    :param destination_dir: Destination directory recorded in the fast serialization settings.
    :param default_launch_plan: Whether a default launch plan should be created for the workflow.
    :param options: Additional execution options that can be configured for the default launch plan.
    :param source_path: The root of the project path; also used to look up the git repo URL.
    :param module_name: The name of the module.
    :return: The registered ``FlyteTask`` or ``FlyteWorkflow``.
    """
    resolved_image = image_config if image_config is not None else ImageConfig.auto_default_image()

    request_signed_url = functools.partial(
        self.client.get_upload_signed_url,
        project=project or self.default_project,
        domain=domain or self.default_domain,
        filename="scriptmode.tar.gz",
    )
    upload_location, md5_bytes = fast_register_single_script(source_path, module_name, request_signed_url)

    repo_url = _get_git_repo_url(source_path)
    fast_settings = FastSerializationSettings(
        enabled=True,
        destination_dir=destination_dir,
        distribution_location=upload_location.native_url,
    )
    serialization_settings = SerializationSettings(
        project=project,
        domain=domain,
        image_config=resolved_image,
        git_repo=repo_url,
        fast_serialization_settings=fast_settings,
    )

    if version is None:
        # The MD5 sent to S3/GCS must match the archive bytes exactly, but the registered version need not;
        # mixing in the serialization settings makes the version reflect them too.
        version = self._version_from_hash(md5_bytes, serialization_settings)

    if not isinstance(entity, PythonTask):
        return self.register_workflow(entity, serialization_settings, version, default_launch_plan, options)
    return self.register_task(entity, serialization_settings, version)
def register_launch_plan(
    self,
    entity: LaunchPlan,
    version: str,
    project: typing.Optional[str] = None,
    domain: typing.Optional[str] = None,
    options: typing.Optional[Options] = None,
) -> FlyteLaunchPlan:
    """
    Register a launch plan, possibly applying overrides from ``options``, and return the remote launch plan.

    If the backend reports that the launch plan already exists, that is logged at debug level and ignored; the
    launch plan is fetched back from the backend in either case.

    :param entity: Launch plan to be registered.
    :param version: Version to register the launch plan under.
    :param project: Optional project, placed in the serialization settings used to resolve the identifier.
    :param domain: Optional domain, placed in the serialization settings used to resolve the identifier.
    :param options: Options passed to launch plan serialization.
    :return: The fetched ``FlyteLaunchPlan``, with its python interface copied from ``entity``.
    """
    settings = SerializationSettings(image_config=ImageConfig(), project=project, domain=domain, version=version)
    lp_id = self._resolve_identifier(ResourceType.LAUNCH_PLAN, entity.name, version, settings)
    idl_lp = get_serializable_launch_plan(
        OrderedDict(), settings, entity, recurse_downstream=False, options=options
    )

    try:
        self.client.create_launch_plan(lp_id, idl_lp.spec)
    except FlyteEntityAlreadyExistsException:
        remote_logger.debug("Launchplan already exists, ignoring")

    registered = self.fetch_launch_plan(lp_id.project, lp_id.domain, lp_id.name, lp_id.version)
    registered._python_interface = entity.python_interface
    return registered
####################
# Execute Entities #
####################
def _execute(
    self,
    entity: typing.Union[FlyteTask, FlyteWorkflow, FlyteLaunchPlan],
    inputs: typing.Dict[str, typing.Any],
    project: str = None,
    domain: str = None,
    execution_name: str = None,
    options: typing.Optional[Options] = None,
    wait: bool = False,
    type_hints: typing.Optional[typing.Dict[str, typing.Type]] = None,
    overwrite_cache: bool = None,
) -> FlyteWorkflowExecution:
    """Common method for execution across all entities.

    :param entity: The remote task, workflow or launch plan to execute. Its ``id`` is what gets launched and its
        ``interface`` is used to convert ``inputs`` into literals.
    :param inputs: dictionary mapping argument names to values. ``Literal`` values are used as-is; any other value
        is converted with the TypeEngine.
    :param project: project in which to create the execution; falls back to the remote's default project
    :param domain: domain in which to create the execution; falls back to the remote's default domain
    :param execution_name: name of the execution; a random name starting with ``f`` is generated when omitted
    :param options: execution options (notifications, labels, annotations, raw output data config,
        max parallelism, security context); an empty ``Options()`` is used when omitted
    :param wait: if True, waits for execution to complete
    :param type_hints: map of input names to python types so that the TypeEngine knows how to convert the input
        values into Flyte Literals. Missing hints are guessed from the entity's interface. The caller's dictionary
        is not modified.
    :param overwrite_cache: Allows for all cached values of a workflow and its tasks to be overwritten
        for a single execution. If enabled, all calculations are performed even if cached results would
        be available, overwriting the stored data once execution finishes successfully.
    :raises FlyteValueException: if an input name is not part of the entity's interface, or if no type hint was
        given for an input and none can be guessed from its Flyte type.
    :returns: :class:`~flytekit.remote.workflow_execution.FlyteWorkflowExecution`
    """
    execution_name = execution_name or "f" + uuid.uuid4().hex[:19]
    if not options:
        options = Options()
    if options.disable_notifications is not None:
        if options.disable_notifications:
            notifications = None
        else:
            notifications = NotificationList(options.notifications)
    else:
        notifications = NotificationList([])

    # Work on a copy: guessed hints are stored below and must not leak into the caller's dictionary.
    type_hints = dict(type_hints) if type_hints else {}
    literal_map = {}
    with self.remote_context() as ctx:
        input_flyte_type_map = entity.interface.inputs
        for k, v in inputs.items():
            variable = input_flyte_type_map.get(k)
            if variable is None:
                raise user_exceptions.FlyteValueException(
                    k, f"The {entity.__class__.__name__} doesn't have this input key."
                )
            if isinstance(v, Literal):
                lit = v
            else:
                if k not in type_hints:
                    try:
                        type_hints[k] = TypeEngine.guess_python_type(variable.type)
                    except ValueError as e:
                        # Without a python type the value cannot be converted; previously this fell through to a
                        # bare KeyError on the lookup below.
                        raise user_exceptions.FlyteValueException(
                            k,
                            f"Could not guess a python type for Flyte type {variable.type}, "
                            f"please provide one via type_hints.",
                        ) from e
                hint = type_hints[k]
                lit = TypeEngine.to_literal(ctx, v, hint, variable.type)
            literal_map[k] = lit

        literal_inputs = literal_models.LiteralMap(literals=literal_map)

    try:
        # The execution is created in `project`/`domain` (falling back to the remote's defaults), which are passed
        # separately from `entity.id`; this method does not check that they match the entity's own project/domain.
        exec_id = self.client.create_execution(
            project or self.default_project,
            domain or self.default_domain,
            execution_name,
            ExecutionSpec(
                entity.id,
                ExecutionMetadata(
                    ExecutionMetadata.ExecutionMode.MANUAL,
                    "placeholder",  # Admin replaces this from oidc token if auth is enabled.
                    0,
                ),
                overwrite_cache=overwrite_cache,
                notifications=notifications,
                disable_all=options.disable_notifications,
                labels=options.labels,
                annotations=options.annotations,
                raw_output_data_config=options.raw_output_data_config,
                auth_role=None,
                max_parallelism=options.max_parallelism,
                security_context=options.security_context,
            ),
            literal_inputs,
        )
    except user_exceptions.FlyteEntityAlreadyExistsException:
        remote_logger.warning(
            f"Execution with Execution ID {execution_name} already exists. "
            f"Assuming this is the same execution, returning!"
        )
        exec_id = WorkflowExecutionIdentifier(
            project=project or self.default_project, domain=domain or self.default_domain, name=execution_name
        )
    execution = FlyteWorkflowExecution.promote_from_model(self.client.get_execution(exec_id))

    if wait:
        return self.wait(execution)
    return execution
def _resolve_identifier_kwargs(
    self,
    entity: typing.Any,
    project: str,
    domain: str,
    name: str,
    version: str,
) -> ResolvedIdentifiers:
    """
    Resolve the identifier attributes used to launch an execution from user input.

    ``project`` and ``domain`` fall back to the remote's default project/domain, and ``name`` falls back to
    ``entity.name``. ``version`` is passed through unchanged (no version is generated here, so it may be None).

    :param entity: The entity being executed; only its ``name`` attribute is read, and only when ``name`` is empty.
    :param project: Project supplied by the caller, possibly empty.
    :param domain: Domain supplied by the caller, possibly empty.
    :param name: Name supplied by the caller, possibly empty.
    :param version: Version supplied by the caller, used as-is.
    :return: The resolved identifiers.
    :raises ValueError: If project, domain or name is still empty after applying the fallbacks.
    """
    ident = ResolvedIdentifiers(
        project=project or self.default_project,
        domain=domain or self.default_domain,
        name=name or entity.name,
        version=version,
    )
    if not (ident.project and ident.domain and ident.name):
        raise ValueError(
            f"Cannot launch an execution with missing project/domain/name {ident} for entity type {type(entity)}."
            f" Specify them in the execute method or when intializing FlyteRemote"
        )
    return ident
def execute(
self,
entity: typing.Union[FlyteTask, FlyteLaunchPlan, FlyteWorkflow, PythonTask, WorkflowBase, LaunchPlan],
inputs: typing.Dict[str, typing.Any],
project: str = None,