syntax = "proto3";
package tensorflow_federated.v0;
import "google/protobuf/any.proto";
import "tensorflow_federated/proto/v0/array.proto";
import "tensorflow_federated/proto/v0/data_type.proto";
// A core data structure that contains a serialized representation of a unit of
// processing to perform by the TensorFlow Federated framework. This data
// structure is the primary unit of composition and the means by which we
// represent, store, and exchange federated computations and their constituents
// between system components. It is the lowest and smallest programmable
// abstraction layer that a range of higher-level APIs will be layered upon,
// structured around the minimum set of concepts and abstractions that provide
// a level of expressiveness sufficient to efficiently support current and
// anticipated uses. This layer is not intended for consumption by most users.
//
// In its most general sense, an instance of a Computation as defined here is
// simply an expression that produces a certain value. The structure of this
// expression, typically nested, determines how this value is intended to be
// computed (hence the term "computation"). We may use terms "expression" and
// "computation" interchangeably in this and other files, although technically,
// the term "computation" refers to a process, whereas "expression" refers to
// a specification of that process.
message Computation {
// The type of what's represented by this structure, which may be functional
// or non-functional. If it is a TensorFlow block or a lambda expression,
// the type will be functional. If it is a Struct, or a Call that returns a
// tensor or a Struct in the result, the type will be non-functional.
//
// A Call is a typical way to represent an invocation of a top-level federated
// computation with all its parameters fully specified. Thus, a top-level
// computation with all of its parameters filled in may have a non-functional
// type (the same as type of the result it computes). The illustrative
// example to think of is "(x -> x + 10)(20)", the type of which is an int, a
// non-functional type. If a top-level federated computation has all of its
// parameters filled in, it will assume a similar form.
Type type = 1;
// The specification of the computation to perform.
//
// A hypothetical example of a federated computation definition in Python,
// expressed in a yet-to-be-defined syntax, might translate into definitions
// in a serialized form as shown below.
//
//   @tff.computation
//   def fed_eval(model):
//
//     @tfe.defun
//     def local_eval(model):
//       ...
//       return {'loss': ..., 'accuracy': ...}
//
//     client_model = tff.federated_broadcast(model)
//     client_metrics = tff.federated_map(local_eval, client_model)
//     return tff.federated_mean(client_metrics)
//
//
//   fed_eval = Computation(lambda=Lambda(
//       parameter_name='model',
//       result=Computation(block=Block(
//           local=[
//               Block.Local(name='local_eval', value=Computation(
//                   tensorflow=TensorFlow(...))),
//               Block.Local(name='client_model', value=Computation(
//                   call=Call(
//                       function=Computation(
//                           intrinsic=Intrinsic(uri='federated_broadcast')),
//                       argument=Computation(
//                           reference=Reference(name='model'))))),
//               Block.Local(name='client_metrics', value=Computation(
//                   call=Call(
//                       function=Computation(
//                           intrinsic=Intrinsic(uri='federated_map')),
//                       argument=Computation(
//                           struct=Struct(element=[
//                               Struct.Element(
//                                   value=Computation(
//                                       reference=Reference(
//                                           name='local_eval'))),
//                               Struct.Element(
//                                   value=Computation(
//                                       reference=Reference(
//                                           name='client_model')))
//                           ])))))],
//           result=Computation(
//               call=Call(
//                   function=Computation(
//                       intrinsic=Intrinsic(uri='federated_mean')),
//                   argument=Computation(
//                       reference=Reference(name='client_metrics'))))))))
//
oneof computation {
// NON-COMPOSITIONAL CONSTRUCTS.
//
// The following constructs are the basic building blocks that can be
// composed into larger computations with the use of the compositional
// constructs defined below.
// TensorFlow computation. TensorFlow computations have functional type
// signatures that cannot contain FederatedTypes, as they execute locally.
// In order to construct a TensorFlow computation that maps a federated
// value pointwise, one must use a federated map intrinsic (to be defined).
TensorFlow tensorflow = 2;
// A built-in federated communication operator such as broadcast, federated
// sum, etc., or one of the custom operators added to the framework, and
// recognized by the compiler pipeline. Intrinsics have functional types,
// and most are defined as templates that can operate on abstract types,
// and/or federated values with arbitrary placements.
Intrinsic intrinsic = 3;
// An external source of data to be used by a computation.
Data data = 10;
// COMPOSITIONAL CONSTRUCTS.
//
// The following constructs can be used to combine simpler computations
// into more complex ones. For example, they can be used to express the
// top-level orchestration logic of a federated computation that combines
// blocks of client-side and server-side TensorFlow code with federated
// communication operators such as federated aggregation or broadcast.
// A lambda expression is the primary means of defining new parameterized
// computations. Lambdas always have functional types.
Lambda lambda = 4;
// A block of computation logic, i.e., a series of expressions that refer
// to one-another. This mechanism is intended as a primary means of
// breaking down longer sequences of processing into simpler parts. A block
// can have a functional or a non-functional type (matching the type of its
// result), as it is primarily a mechanism for organizing code.
Block block = 5;
// A reference to a name defined in a surrounding context, such as a Lambda
// or a Block, with the usual scoping rules (the name refers to the
// innermost scope in which it is defined). Always matches the type of the
// parameter it references, i.e., T if the type of the lambda is T->T'.
// For example, in a Lambda "x : int -> foo(x)", which locally associates the
// name "x" with its parameter, the reference to "x" will be of type "int",
// just as the parameter of the lambda in which the name "x" is defined.
Reference reference = 6;
// A function call is the primary means of using lambdas, TensorFlow blocks,
// and other types of functional constructs to compute a specific result in
// a concrete context. The type of the call is the same as the type of the
// result of the function being called, i.e., a call with parameter of type
// T to a function of type T -> T' has type T'.
Call call = 7;
// A struct is explicitly constructed from individual member values.
Struct struct = 8;
// A selection by name or index from the result of another expression that
// returns a Struct. The type of the selection matches the type of the
// Struct element being selected (known statically, as the name or index is
// known statically, rather than computed).
//
// Note: In higher layers of the API, we will offer convenience mechanisms
// such as selection from a federated type. For example, if "x" is of a
// federated type "{<foo=int, bar=string>}@clients", we will allow notation
// such as "x.foo" as a convenient shortcut for a pointwise selection that
// might be written as "federated_map(x, y->y.foo)" in a more complete form
// even though "x" is technically not a Struct. Here at the level of
// the Computation proto, however, we will represent computations in their
// fully fleshed-out form, with map and other implicit operators already
// injected at construction time by the framework as it translates a Python
// source code that defines a computation into this serialized form.
Selection selection = 9;
// A placement literal.
Placement placement = 11;
// A value literal.
Literal literal = 14;
// EXPERIMENTAL CONSTRUCTS.
//
// The following constructs are currently considered experimental, and are
// not formally supported yet. They may be recognized, partially or fully
// handled by parts of the TFF codebase, but at this stage, one should not
// depend on this being the case. Their exact representation may continue
// to evolve, or they may be removed altogether. When these constructs are
// ready for public consumption and come with an explicit support, we will
// remove the "experimental" designation.
// A local (non-federated) computation expressed in XLA. XLA computations
// have functional type signatures, and are used in a manner similar to
// local computations expressed in TensorFlow. However, not all types of
// TensorFlow computations are expressible in XLA (see below for the list
// of current limitations).
Xla xla = 12;
}
// Reserved: a previously used field that has been deleted.
reserved 13;
// NEXT ID: 15
}
// A generic representation of an arbitrary type, defined as a variant over a
// number of primitive and compound types that can be nested. Note that not all
// nestings expressible with this structure may be valid, e.g., it may not make
// sense to declare a sequence of functions, or a federated type in which
// individual member values are themselves federated. However, rather than
// constraining the set of possible nestings at the syntactic level, which would
// increase boilerplate and could prove limiting in the future, we keep this
// variant structure simple, and we let the set of all valid type nestings be
// determined by the set of the currently supported operators. The current
// limitations on nesting are as follows:
// - FederatedType and FunctionType cannot be nested within a FederatedType or
// within a SequenceType. Currently, these may only be nested within a
// StructType.
// - A SequenceType currently cannot be nested within another SequenceType.
message Type {
oneof type {
FunctionType function = 1;
StructType struct = 2;
SequenceType sequence = 3;
TensorType tensor = 4;
AbstractType abstract = 5;
PlacementType placement = 6;
FederatedType federated = 7;
}
}
// A representation of a functional type. Functions must have at most a single
// parameter and a single result. Multiple parameters or results are to be
// modeled as compound types (e.g., as Structs). Note that since functions accept
// generic types, one can declare functions as parameters or results of other
// functions. We may not support functions as first-class values directly in
// the API surface, but the ability to express this is useful in defining type
// signatures for federated communication operators, and to support various
// types of extensibility.
// Concise syntax for examples of functional types: "T -> T'", where T, T' are
// the types of parameter and result, respectively.
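//
// As an illustrative sketch (in the same constructor-style pseudocode as the
// fed_eval example near the top of this file, and assuming an int32 DataType
// is available in data_type.proto), the functional type "int32 -> int32"
// might be encoded as:
//
//   Type(function=FunctionType(
//       parameter=Type(tensor=TensorType(dtype=DT_INT32)),
//       result=Type(tensor=TensorType(dtype=DT_INT32))))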
message FunctionType {
Type parameter = 1;
Type result = 2;
}
// A representation of a type of a struct. A struct is a compound type
// based on a similar type in Python that defines a finite set of named members,
// the types of which are known statically, that are arranged in a prescribed
// order and can be referred to by their position within the Struct. Note that
// besides structs, this abstract type can also be used to represent dicts,
// OrderedDicts, and regular tuples in Python.
// Concise syntax for examples of struct types: "T_i" or "name_i=T_i" separated
// by commas and optionally enclosed in "<>" (e.g., "<bool, foo=string>"),
// where name_i is the optional name, and T_i is the type of i-th element.
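//
// As an illustrative sketch (assuming bool and string DataTypes are available
// in data_type.proto), the struct type "<bool, foo=string>" from the example
// above might be encoded as:
//
//   Type(struct=StructType(element=[
//       StructType.Element(value=Type(tensor=TensorType(dtype=DT_BOOL))),
//       StructType.Element(name='foo',
//                          value=Type(tensor=TensorType(dtype=DT_STRING)))]))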
message StructType {
repeated Element element = 1;
message Element {
string name = 1;
Type value = 2;
}
}
// A representation of a type of a sequence. A sequence is a data structure
// that contains multiple elements of the same type that can be accessed only
// in a sequential manner, i.e., through an iterator. For now, we assume that
// a sequence can only be consumed once, i.e., there's no concept of iterator
// reset, as this facilitates high-performance implementations. We may add a
// notion of resettability in the future by introducing additional fields here
// while keeping non-resettability of sequences as the default.
// Concise syntax for examples of sequence types: "T*", where T is the type of
// elements.
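//
// As an illustrative sketch (assuming an int32 DataType), the sequence type
// "int32*" might be encoded as:
//
//   Type(sequence=SequenceType(
//       element=Type(tensor=TensorType(dtype=DT_INT32))))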
message SequenceType {
Type element = 1;
}
// A representation of a type of a single tensor in TensorFlow. Aspects such
// as sparseness are not intended to be represented at this level.
// Concise syntax for examples of tensor types: "dtype[shape]" or "dtype" for
// scalars, e.g., "bool[10]".
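//
// As an illustrative sketch (assuming bool and float32 DataTypes), the tensor
// type "bool[10]" from the example above might be encoded as:
//
//   TensorType(dtype=DT_BOOL, dims=[10])
//
// and a float32 matrix with 10 rows and a statically unknown number of
// columns as:
//
//   TensorType(dtype=DT_FLOAT, dims=[10, -1])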
message TensorType {
// The data type of the tensor.
DataType dtype = 1;
// The sizes of each dimension of the tensor.
//
// Undefined dimensions are allowed and represented by -1. Defined and
// undefined dimensions are to be considered distinct for type checking
// purposes.
repeated int64 dims = 2;
// True iff the number of dimensions is unknown.
//
// If `dims` is unset:
// - `unknown_rank` == True corresponds to None
// - `unknown_rank` == False corresponds to []
bool unknown_rank = 3;
}
// A representation of an abstract type identified by a string label (analogous
// to "typename T" in C++, with "T" being the label). All occurrences of an
// abstract type with the same label within a type signature are interpreted as
// referring to the same concrete type. Abstract types can thus be used to
// represent templates similar to templates in C++. The label does not have any
// specific meaning otherwise. Any bijective renaming of all labels within a
// type signature is semantically a no-op (i.e., the resulting type definition
// is semantically identical to the original before renaming). The label may be
// modified by the compiler (e.g., due to naming conflicts).
// An AbstractType T might be used, for example, to define a signature of a
// generic aggregation operator as "federated_sum: {T}@clients -> T@server".
// Concise syntax for examples of abstract types: variations of uppercase "T",
// e.g., as in "T -> T'".
message AbstractType {
// The label used to refer to this abstract type within a type signature.
string label = 1;
}
// The term `placement` refers to a representation of an instance of a built-in
// opaque type that conceptually represents a (membership of some) collective
// of participants in a distributed system that may participate in some part of
// a federated computation.
//
// In a typical federated computation, there would typically be at least one
// group of client devices, one or more groups of intermediate aggregators in a
// multi-tiered server architecture, and a central coordinator (perhaps a
// singleton group). With each of these groups, one would associate a separate
// placement (a separate instance of the built-in "placement" type).
//
// Placements are intended to be passed as arguments to some of the federated
// communication operators to determine the group of participants involved in
// the underlying federated communication protocol.
//
// In addition, placements can be used to define federated types (see below),
// i.e., types, values of which are hosted by members of a given collective, and
// thus potentially distributed across multiple locations. In a fully-specified
// federated computation, each concrete value (e.g., tensor) would typically
// have an associated concrete placement value to indicate which group of system
// participants (clients, aggregator or coordinator instances, etc.) it is
// hosted on.
//
// While placement is a first-class type, instances of which may be passed as
// parameters or returned as results, it is not equivalent to a simple vector
// of device addresses. A computation cannot list, add, remove, or test for
// existence of a particular device in a placement, as membership could be
// determined or influenced by factors outside of the programmer's control. For
// example, the membership of the collective of client devices represented by
// a "client" placement will depend on which devices choose to join the system
// and further influenced by factors such as failures and network delays. In
// most types of environments, the membership of a given group of participants
// could be dynamically evolving over time. Federated computations are defined
// at a higher level of abstraction that does not involve dealing with the
// identities of the individual devices.
// A specification of a placement in a federated type. There are two ways of
// specifying a placement in this context that correspond to the two fields in
// the oneof below. Placement labels are used to construct template types of
// federated communication operators that can be applied to federated values.
// They relate all the identically-labeled placements that appear in the type
// signature without prescribing what specifically those placements must be.
// For example, consider the type signature below:
//
// federated_broadcast: T, p: placement -> T@p
//
// Here, "p" is a placement label, the role of which is simply to link the left
// and right sides of the type signature. The representation of this type
// signature will use PlacementLabel on the left side.
//
// Concrete placement values are essentially placement literals, same as those
// that might appear in a computation body. They are used to bind types to
// specific placements with definite global meaning in a
// particular type of runtime environment.
message PlacementSpec {
oneof placement {
PlacementLabel label = 1;
Placement value = 2;
}
}
// A representation of an abstract placement identified by a string label.
// All occurrences of this abstract placement label within a type signature are
// interpreted as referring to the same specific placement, similarly to how
// this is done for abstract type labels (except that equality of placement
// labels indicates equality of values, not just types). The abstract placement
// label does not have any specific meaning otherwise, and it is not intended to
// be compared with anything other than another abstract placement label
// contained within the same type signature. A bijective renaming of all
// abstract placement labels contained in a type signature is a semantic no-op.
// The label may be modified by the compiler (e.g., due to name conflicts).
message PlacementLabel {
// The label used to refer to this specific placement within a type signature.
string label = 1;
}
// A representation of a specific placement defined globally by the runtime
// environment, and embedded as a literal of the "placement" type within a type
// signature or a computation definition. Unlike the abstract placement labels,
// the URIs in these placement values have a definite global meaning for all
// computations executed within the same environment. The exact set of global
// placement URIs and their meaning will depend on the system architecture and
// the capabilities of the platform. For example, in a production setting, these
// might include dedicated URIs to represent clients, intermediate aggregators,
// and coordinator placements.
message Placement {
// The globally unique URI that defines a specific global placement instance.
// For example, a URI might represent the global collective of all mobile
// devices running a certain app, or it might represent the specific
// well-known address of a central coordinator. The exact naming schemes and
// interpretation of these URIs are TBD, and will be documented later.
string uri = 1;
}
// A representation of a federated type, i.e., one in which member components of
// the federated value are hosted on a collective of devices in a distributed
// system (where in some cases, that collective may be a singleton). As noted
// above in the comment on "PlacementType", examples of such collectives could
// include client devices, intermediate aggregators, central coordinator, etc.,
// with one or more participants. Note that a federated type is a dependent
// type, as the placement label or value contained herein binds it to a specific
// placement, either one that's defined globally, or one that's supplied as a
// parameter and defined in another part of a computation's type signature.
// Concise syntax for federated types: "T@p" or "{T}@p" when "all_equal" is True
// or False, respectively, where "T" is the type of members, and "p" is either
// a placement label or a placement value (generally clear from context).
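//
// As an illustrative sketch (the placement URI "clients" is a purely
// hypothetical name here; actual URIs are defined by the runtime environment,
// see Placement above), the federated type "{int32}@clients" might be
// encoded as:
//
//   Type(federated=FederatedType(
//       placement=PlacementSpec(value=Placement(uri='clients')),
//       all_equal=False,
//       member=Type(tensor=TensorType(dtype=DT_INT32))))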
message FederatedType {
// A specification of placement that identifies the collective of participants
// in a distributed system on which member components of this federated value
// are hosted.
//
// If the federated type appears as a part of a functional type signature,
// this placement will generally be defined using a PlacementLabel to bind it
// to the type of the parameter, e.g., as below:
//
// federated_broadcast: T, p: placement -> T@p
//
// In the above "T@p" is a federated type, with label "p" (represented in the
// type as a PlacementLabel) simply serving as a reference to the parameter
// on the left.
//
// On the other hand, if a federated type appears on its own, not tied to the
// placement of any function parameter, the placement specified here will be
// a concrete placement literal (represented by a Placement in the value field).
PlacementSpec placement = 1;
// A bit that, if set, indicates that the member components of the federated
// value are all equal (if not set, member components may vary). This
// distinction is only meaningful for placements that represent collectives,
// such as clients or intermediate aggregators. For placements that represent
// centralized components (such as a central coordinator), this property is
// trivially satisfied (and still documented by setting this bit to True).
bool all_equal = 2;
// The type of the local member components of the federated value, i.e., the
// components that are locally hosted on each individual participant (member
// of the collective determined by the "placement" above).
Type member = 3;
}
// A representation of the type of placements (see the discussion above by the
// definition of the Placement message that represents instances of this type).
// This message is only used in situations where placement is passed as a
// first-class value (e.g., in the argument to broadcast). The specifications of
// federated types only refer to specific placements (see Placement above).
// Note that there is only a single primitive "placement" type. The embedded
// field "instance_label" does not qualify the type and does not affect type
// equality. It is only used to annotate the instance of this type as it appears
// in a type signature in order to form dependent types.
// Concise syntax for the placement type: "placement" for the type itself, and
// "p: placement" to annotate the specific entry in the type signature with the
// label "p".
message PlacementType {
// An optional label that can be used to refer to the specific instance of the
// "placement" type represented by this entry in the type signature. If this
// field is present in the PlacementType message, generally as a parameter in
// a functional type signature, the label is associated with the specific
// placement value supplied in that parameter, which allows it to be used to
// specify a federated type hosted by the collective of participants
// represented by this placement. For example, consider this type signature:
//
// federated_broadcast: T, p: placement -> T@p
//
// The type specification of the 2nd element of the broadcast argument Struct
// would be PlacementType(instance_label=PlacementLabel(label='p')). Here, the
// type of the second element is still simply "placement"; as noted above,
// there is only one such built-in type to represent all sorts of collectives.
// The presence of the label only associates 'p' with the value of the second
// element of the parameter Struct. On the right side, the specification of the
// federated result type contains PlacementSpec(label=PlacementLabel(label='p')),
// thus binding the placement of the result to the value in the argument. When
// comparing types, the presence of this label is ignored.
PlacementLabel instance_label = 1;
}
// A representation of a section of TensorFlow code.
//
// The type signature associated with this type of computation must be defined
// only in terms of tensors, structs, and sequences. Sequences cannot be nested.
//
// At the moment, we only allow sequences as parameters (note that pointwise
// transformations of sequences can still be expressed using a map intrinsic).
// This restriction may be relaxed in the future when support for handling data
// sets as first-class objects in TensorFlow evolves.
//
// Note that unlike in polymorphic functions created by tf.defuns, the chosen
// representation requires all type signatures, including those of individual
// elements of a sequence, to be fully specified. In case of sequences, the
// structure of their elements is effectively encoded in the parts of the graph
// that constitute the serialized representation of tf.data.Datasets and
// iterators.
//
// While we will offer support for writing polymorphic TensorFlow logic, types
// will be captured automatically and made concrete based on usage at the Python
// level of the API. Users of TFF will not need to declare them explicitly, but
// template specialization will happen before computation logic gets serialized.
//
// Next id: 8
message TensorFlow {
// The semantics is as follows: the graph embedded here will be instantiated,
// with all placeholder components of the parameter bound to concrete tensors
// or, in case of sequences, to iterators associated with concrete datasets.
// The components of the result will then all be simultaneously evaluated in
// what corresponds to a single Session.run() in non-eager mode.
// Note: Currently, there is no way to represent any higher-level scripting
// over the graph. We require that all control flow logic be expressed using
// control dependencies and other TensorFlow constructs and triggered by the
// evaluation of outputs within a single Session.run(), as postulated above.
// Depending on how restrictive this turns out to be we might, or might not,
// add a script that describes a sequence of Session.run() calls, one-off or
// repeated in a loop, as an optional component in a TensorFlow computation,
// to address the impedance mismatch between push- and pull-based styles of
// processing supported by various parts of the target execution environment.
// A serialized representation of a TensorFlow graph to execute.
//
// Stores a tensorflow.GraphDef message.
// Note: This representation may evolve, e.g., get replaced with a MetaGraph,
// SavedModel, or a similar structure. Dependencies on the exact form of the
// graph encoding used here should be kept to minimum, and proxied by wrapper
// libraries for composing computations in python/core/impl/.
//
// TODO: b/117428091 - Update this representation based on the emerging TF 2.0
// serialization standards as needed if/when they meet the constraints of the
// target production environments, and provided that they don't introduce
// additional complexity.
google.protobuf.Any graph_def = 1;
// String name of an initialization op to run on the graph before fetching
// results. This op is intended only to be used for running tf.Variable
// initializers.
string initialize_op = 4;
// String name of a tensor which may be fed a unique identifier token for the
// current session. This allows TensorFlow custom ops to refer to
// session-global values created by the runner of the current session.
string session_token_tensor_name = 6;
// A pair of bindings for the parameter and the result. The parameter binding
// can be omitted if the computation does not declare a parameter. The result
// binding is mandatory, as all TensorFlow computations must declare results.
Binding parameter = 2;
Binding result = 3;
// A general representation of a binding of either a parameter or a result to
// a part of the embedded TensorFlow graph. Note that the structure of the
// binding is nested, and parallels the structure of the corresponding part of
// the type signature.
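//
// As an illustrative sketch (the graph tensor names "a:0" and "b:0" are
// purely hypothetical placeholders), a parameter of type
// "<a=float32, b=float32>" might be bound to the graph as:
//
//   Binding(struct=StructBinding(element=[
//       Binding(tensor=TensorBinding(tensor_name='a:0')),
//       Binding(tensor=TensorBinding(tensor_name='b:0'))]))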
message Binding {
oneof binding {
// A binding associated with a struct in the type signature. Specifies an
// individual binding for each element of the struct.
StructBinding struct = 1;
// A binding associated with a (logical) tensor in the type signature.
// Associates that tensor to one or more (concrete) tensors in the graph.
TensorBinding tensor = 2;
// A binding associated with a sequence. Associates the sequence with a
// part of the TensorFlow graph that will represent a data set iterator,
// next element, or an equivalent iterator-like structure.
SequenceBinding sequence = 3;
}
}
// A binding of a Struct declared in the type signature to parts of the
// embedded TensorFlow graph.
message StructBinding {
// Bindings for elements of the Struct. The number of elements in this field
// must be equal to the number of Struct elements declared in the type
// signature, with the k-th binding declared here corresponding to the k-th
// Struct element in the type signature. The element names are omitted since
// they are redundant (correspondence is established by element order).
repeated Binding element = 1;
}
// A binding of a single tensor declared in the type signature to a tensor in
// the serialized graph representation embedded here.
message TensorBinding {
oneof binding {
// The name of a dense tensor in a TensorFlow graph that corresponds to a
// single tensor component in the type signature.
string tensor_name = 1;
// Note: This structure may eventually be extended with non-dense tensor
// encodings, such as .tensorflow.TensorInfo.CooSparse.
}
}
// A representation of a sequence declared in the type signature.
message SequenceBinding {
// Previously was `iterator_string_handle_name`, but now only
// `variant_tensor_name` is supported.
reserved 1;
oneof binding {
// The name of the variant tensor that represents the data set created
// using `tf.data.experimental.from_variant`.
string variant_tensor_name = 2;
// The name of the string tensor that represents the data set created
// using `tf.raw_ops.DatasetFromGraph`.
string graph_def_tensor_name = 3;
// Note: This structure will likely evolve and get extended with other
// means of encoding data sets in the serialized graph representation.
}
}
// An optional id that can be used to identify identical TensorFlow messages
// without having to compare the (potentially large) `graph_def` fields.
//
// This field is not intended to be set during computation
// construction/tracing. Rather, it is intended to be set by a final
// compilation pass, allowing execution stacks to "cache" the graphs across
// invoke calls and avoid costly graph parsing on every invocation.
//
// The id is NOT required to be unique across machines, meaning two machines
// may produce different graph_defs that carry the same id. Such machines,
// however, should not be talking to the same execution stack.
//
// NOTE: the default value of 0 has the same meaning as having the field
// unset, and having no id. Any code setting this value should exclude zero.
message CacheKey {
uint64 id = 1;
}
CacheKey cache_key = 5;
// A map of layout information for variables and inputs for the DTensor-based
// executor.
// The keys are variable node names or input tensor binding names; the values
// are DTensor Layout sharding specs.
//
// For example:
// "Dense//kernel" , "X"
// "Dense1//kernel" , "Unsharded"
//
// This map is specified by the user when declaring a tf_computation and is
// consulted only when the DTensor executor is used. At runtime, the layout
// specification is supplemented with the DTensor Mesh provided for execution.
// The Layout sharding spec must have dimension names matching those of the
// Mesh provided at runtime.
//
// The layout map attributes will be ignored by executors other than the
// DTensor executor.
message LayoutMap {
map<string, string> name_to_sharding_spec = 1;
}
LayoutMap layout_map = 7;
}
// A representation of an intrinsic function. Intrinsics are functions that are
// known to the framework, and uniquely identified by a URI. This includes both
// the standard federated communication operators, such as broadcast,
// federated sum, secure aggregation, and custom operators that might be added
// by the user to the pipeline. The compiler recognizes the intrinsics, and
// replaces them with a suitable implementation. Intrinsics may be both generic
// and specialized, low- and high-level. The exact naming scheme used to
// identify them, and how it can be extended to support new operators defined by
// external contributors, will be described elsewhere.
message Intrinsic {
// The URI that uniquely identifies the intrinsic within the set of operators
// built into the framework.
string uri = 1;
}
// A representation of a parameterized computation defined as a lambda
// expression that consists of a single parameter name, and an expression that
// contains references to this parameter name (the "name" computation variant).
// Lambdas can be nested, e.g., the result can also be a lambda or contain a
// lambda. Inner lambdas are allowed to refer to the parameter defined in the
// outer lambdas. We assume the usual rules of name hiding: inner names obscure
// the outer names.
//
// Concise syntax for lambdas: "parameter_name -> comp" where "comp" represents
// a parameterized computation that produces the result, or in the more general
// form "parameter_name : T -> comp" to indicate that parameter is of type "T".
// For example, a lambda that takes a 2-Struct of a unary operator and an
// integer as input, and returns the result of calling the unary operator
// on the integer, can be written as "x -> x[0](x[1])", or in the full form with
// type annotation as "x: <(int->int), arg=int> -> x[0](x[1])".
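//
// As an illustrative sketch, the identity lambda "x -> x" might be encoded as
// (omitting the enclosing Computation's type field for brevity):
//
//   Computation(lambda=Lambda(
//       parameter_name='x',
//       result=Computation(reference=Reference(name='x'))))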
message Lambda {
// The name to use internally within this lambda to refer to the parameter.
// The parameter is mandatory. The name defined here can be used internally
// anywhere in the result computation, except if overridden in a nested
// lambda, where it can be hidden by a parameter with a conflicting name.
string parameter_name = 1;
// A computation that represents the result of applying the lambda to the
// parameter. The result may (almost always will) contain references to the
// parameter defined above.
Computation result = 2;
// Note that a Lambda as a whole must have a functional type T -> T', where
// T' is the type of the result, and T is the type of all references to the
// parameter within the result.
}
// A representation of a body of computation logic broken down into a sequence
// of local definitions that gradually build up towards a single final result
// expression. A block defines a sequence of local names, each associated with
// a computation. Computations associated with names introduced later can
// refer to names introduced earlier. At the end of a block is a single result
// computation defined in terms of those locals. It is similar to LET* in LISP.
//
// The intended usage of this abstraction is to break down complex processing
// into simpler, smaller, easier to understand units that are easier to work
// with in this broken-down representation, as opposed to a single monolithic
// complex expression. We expect it to be used, e.g., to represent top-level
// federated orchestration logic.
//
// A block is technically a redundant abstraction, as it can be equivalently
// represented using lambda expressions. For example, a simple block of the
// form "let x=y in z" is equivalent to "(x->z)(y)". Larger blocks can likewise
// be represented similarly as nested lambdas. The main purpose of introducing
// this abstraction is simplicity. While expressible via lambdas, a sequential
// representation is preferred over nested lambdas as it is more readable and
// easier to debug, and more closely matches how code is expected to be executed
// by a runtime environment, in which higher-order functions may be unsupported.
//
// One way to think of blocks is as a generalization of a GraphDef and, as such,
// a mechanism for constructing data flow graphs that can include TensorFlow
// blocks and various federated communication operators as processing nodes.
// Indeed, this is the primary intended usage of blocks. In this interpretation,
// a block can be thought of as a directed acyclic graph, with the locals and
// the result being the graph "nodes". Locals represent various partial results
// computed along the way, and the result is the "op" that represents the
// output. Each node has associated with it an expression (computation) that
// specifies how to derive its value from the values represented by other nodes
// referenced by name. The presence of such a reference to one node's name inside
// another node's expression (computation) can be interpreted as a dependency
// edge in a data flow graph. Indeed, the data flow interpretation corresponds
// to the manner in which processing is expected to flow.
//
// Concise syntax: "let name_1=comp_1, ...., name_k=comp_k in comp" with
// "name_k" and "comp_k" representing the names of the locals, and computations
// that compute the results that those names represent. For example, a complex
// expression "x[0](x[1])" can be represented in a slightly more expanded
// form as "let f=x[0], v=x[1] in f(v)".
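//
// As an illustrative sketch, that same block "let f=x[0], v=x[1] in f(v)"
// might be encoded as (omitting type fields for brevity):
//
//   Computation(block=Block(
//       local=[
//           Block.Local(name='f', value=Computation(selection=Selection(
//               source=Computation(reference=Reference(name='x')),
//               index=0))),
//           Block.Local(name='v', value=Computation(selection=Selection(
//               source=Computation(reference=Reference(name='x')),
//               index=1)))],
//       result=Computation(call=Call(
//           function=Computation(reference=Reference(name='f')),
//           argument=Computation(reference=Reference(name='v'))))))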
message Block {
// One or more locals defined within the block, each associating a name with a
// computation. Computations, whether those associated with the locals, or
// that associated with the result below, can contain references to names
// defined earlier, here or in the surrounding context. Self-references are
// prohibited. All names introduced here must be different. Since execution
// semantics at this level is purely functional without side effects, the
// ordering in which the locals are declared is not significant, as it is only
// the dependencies between the computations that effectively determine the
// causal relationships that constrain the order of execution.
//
// Blocks can be nested, just as lambdas, and the same name scoping rules
// apply, i.e., blocks (or lambdas) contained within an embedded computation,
// whether in a local or in the result, are allowed to refer to names defined
// in an outer lambda or block (unless obscured by a nested declaration).
// If names defined in the outer context conflict with those defined in the
// inner context (here), the inner names hide outer names in the context in
// which they are defined. Thus, for example, in "x -> let x=1, y=x+1 in y",
// the "x=1" would hide the lambda parameter, and therefore "y=x+1" would
// refer to the inner "x".
repeated Local local = 1;
message Local {
string name = 1;
Computation value = 2;
}
// The result computation. Always required. The computation typically refers
// to locals defined above by name, just like the result in a lambda.
Computation result = 2;
}
// A reference to a computation defined as a local in a block, or to the
// parameter of a lambda.
message Reference {
string name = 1;
}
// A representation of a function call.
//
// The concise notation for function calls is "f(x)" or "f()", where "f" is the
// function, and "x" is the optional argument.
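//
// As an illustrative sketch, a call "f(x)" in which both the function and the
// argument are references to names defined in the surrounding context might
// be encoded as (omitting type fields for brevity):
//
//   Computation(call=Call(
//       function=Computation(reference=Reference(name='f')),
//       argument=Computation(reference=Reference(name='x'))))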
message Call {
// A computation that represents the function to call. The value that this
// represents must be of a functional type.
Computation function = 1;
// A computation that represents the argument to the function specified above.
// Present if and only if "function" declares a parameter. Must match the
// function's parameter type (i.e., the function's parameter type must be
// assignable from the argument type).
Computation argument = 2;
}
// A representation of a Struct constructor.
//
// The concise representation of a Struct constructor is a "<>"-enclosed,
// comma-separated list of value or "name=value" sections, for example "<1,2>"
// or "<foo=1,bar=2>".
message Struct {
// The ordering of Struct elements is significant, and determines the type of
// the value represented by the expression. The names are optional.
repeated Element element = 1;
message Element {
string name = 1;
Computation value = 2;
}
}
// A representation of a value selected from a Struct returned by another
// computation.
//
// The concise representation of a selection is "x[index]" for positional
// selection, and "x.name" for name-based selection, where "x" represents the
// source from which to select. For example, in the lambda "x -> x[0](x[1])",
// "x[0]" and "x[1]" both represent positional selections of elements from the
// Struct "x".
message Selection {
// The source of selection, always required. This is a computation that
// returns a Struct (possibly nested), from which to select an element
// by name or by index.
Computation source = 1;
// A specification of what to select from the source (Struct). Indexes,
// when applied to Structs, are 0-based, i.e., "[0]" selects the first
// element.
int32 index = 3;
}
// A specification of an external source of data to be used by a computation.
//
// Data streams are currently expected to always be nested structures composed
// of sequences, Structs, and tensor types. Sequences cannot be nested.
// Structs may appear at the outer level (to return multiple sequences,
// e.g., training and testing samples), or at the element level (if sequences
// contain structured elements, e.g., examples already parsed into individual
// features).
//
// Although data could conceivably be modeled via intrinsics, we factor it out
// to more conveniently express various types of input pipelines without having
// to pack everything into a URI. Sources of data could include training
// examples emitted by a mobile app, files on a filesystem, data to obtain from
// a location on the web, etc., and the specification, in addition to the
// origin of the data, could include things like example selection criteria,
// data decoding or simple transformations. For now, this structure is a
// specification to be interpreted by the runtime environment. To be extended
// as needed.
message Data {
oneof data {
// A specification of the data stream as a URI to be interpreted by the
// environment.
string uri = 1 [deprecated = true];
// A specification of the data stream to be interpreted by the environment.
google.protobuf.Any content = 3;
}
reserved 2;
}
// A representation of a section of XLA code (experimental-only).
//
// The type signature associated with this type of computation must be defined
// as a function which accepts and returns tensors and potentially nested
// structures of tensors.
message Xla {
// A serialized representation of XLA code to execute.
//
// Stores an `HloModuleProto` message, as defined in the TensorFlow repo in
// the file "tensorflow/compiler/xla/service/hlo.proto" in the main branch.
//
// It is recommended, albeit not required, that the entry computation in this
// module accepts its parameters as a single tuple.
//
// NOTE: As it is experimental-only, this representation may evolve, possibly
// in a manner that is backwards-incompatible. Make sure not to depend on the
// current form of this representation, and not to persist it in places where
// subsequent changes could cause breakages.
google.protobuf.Any hlo_module = 1;
// A pair of bindings for the parameter and the result. The parameter binding
// can be omitted if the computation does not declare a parameter. The result
// binding is mandatory, as all XLA computations must declare results.
Binding parameter = 2;
Binding result = 3;
// A general representation of a binding of either a parameter or a result to
// a part of the embedded HLO module. Note that the structure of the binding
// is nested, and it parallels the structure of the corresponding part of the
// type signature.
message Binding {
oneof binding {
StructBinding struct = 1;
TensorBinding tensor = 2;
}
}
// A binding associated with a struct in the type signature. Specifies an
// individual binding for each element of the struct.
message StructBinding {
// Bindings for elements of the struct. The number of elements in this field
// must be equal to the number of struct elements declared in the type
// signature, with the k-th binding declared here corresponding to the k-th
// struct element in the type signature. The element names are omitted since
// they are redundant (correspondence is established by element order).
repeated Binding element = 1;
}
// A binding associated with a (logical) tensor in the type signature.
// Associates that tensor to one or more (concrete) tensors in the inputs
// or outputs of a computation in the module.
message TensorBinding {
oneof binding {
// The 0-based index of this tensor in (the flattened form of) either the
// parameter or result tuple for the entry computation of the HLO module,
// i.e., the `HloComputationProto` with the id matching the module's
// `entry_computation_id`.
//
// The order of indexes associated with the result tensors is defined by
// the order in which tensors appear in the DFS traversal of the root
// instruction in the computation (which can be a tensor, or a possibly
// recursively nested tuple). For example, if the XLA computation returns
// a nested tuple ((int32, int32), int32), the indexes of the tensors in
// the result are ((0, 1), 2), accordingly.
//
// The order of indexes for parameter tensors is defined likewise. In the
// case of multiple arguments, tensor indexes are determined by traversing
// arguments in the order in which they appear on the parameter list (the
// order of `parameter_number` in the `HloInstructionProto`s).
// For example, if the computation takes 2 arguments, the first of which
// is a 2-tuple of tensors, and the second of which is a tensor, the
// indexes identifying the individual portions of the argument list would
// be (0, 1), 2, i.e., 0 would refer to the first tuple element of the
// first parameter, etc.
int32 index = 1;
}
}
}
// A representation of a literal value.
//
// The type signature associated with this type of computation must be defined
// as a tensor.
message Literal {
Array value = 1;
}