-
Notifications
You must be signed in to change notification settings - Fork 2
/
rule.go
156 lines (136 loc) · 5.7 KB
/
rule.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
package sampler
import (
"fmt"
. "github.com/gomlx/exceptions"
"github.com/gomlx/gomlx/types/shapes"
)
// Rule defines one rule of the sampling strategy.
// It's created by [Strategy.Nodes], [Strategy.NodesFromSet], [Rule.FromEdges] and [Rule.IdentitySubRule].
// Don't modify it directly.
//
// Rules are linked into a forest: "Node" rules (see [Rule.IsNode]) are the roots, and every
// other rule points to the rule it was derived from through SourceRule, which in turn lists
// it in its Dependents.
type Rule struct {
// Sampler this rule belongs to; edge sampling rules look up their EdgeType in Sampler.EdgeTypes.
Sampler *Sampler
// Strategy this rule belongs to; rules derived from this one are registered in Strategy.Rules.
Strategy *Strategy
// Name of the [Rule]. Unique within its Strategy.
Name string
// ConvKernelScopeName doesn't affect sampling, but can be used to uniquely identify
// the scope used for the kernels in a GNN to do convolutions on this rule.
// If two rules have the same ConvKernelScopeName, they will share weights.
ConvKernelScopeName string
// UpdateKernelScopeName doesn't affect sampling, but can be used to uniquely identify
// the scope used for the update kernels (as opposed to the convolution kernels named by
// ConvKernelScopeName) in a GNN for this rule.
// If two rules have the same UpdateKernelScopeName, they will share weights.
UpdateKernelScopeName string
// NodeTypeName of the nodes sampled by this rule.
NodeTypeName string
// NumNodes for NodeTypeName. Only used if NodeSet is not provided.
NumNodes int32
// SourceRule is the [Rule] this rule uses as source, or nil if
// this is a "Node" sampling rule (a root/seed sampling).
SourceRule *Rule
// Dependents is the list of Rules that depend on this one.
// That is, other rules that have this Rule as [SourceRule].
// This is to keep track of the graph; they are not involved in the sampling of this rule.
Dependents []*Rule
// EdgeType that connects the [SourceRule] node type to the node type ([NodeTypeName]) of this Rule.
// This is only set if this is an edge sampling rule. A node sampling rule (for seeds) and an
// identity sub-rule (see [Rule.IdentitySubRule]) have this set to nil.
EdgeType *EdgeType
// Count is the number of samples to create. It defines the last dimension of the tensor sampled:
// for rules derived from a SourceRule it is appended as a new last axis of the source rule's Shape.
Count int
// Shape of the sample for this rule.
Shape shapes.Shape
// NodeSet is a set of indices that a "Node" rule is allowed to sample from.
// E.g.: have separate NodeSet for train, test and validation datasets.
NodeSet []int32
}
// IsNode reports whether r is a "Node" (root/seed) rule, that is, a rule that
// was not derived from any other rule and hence has no SourceRule.
func (r *Rule) IsNode() bool {
	hasSource := r.SourceRule != nil
	return !hasSource
}
// IsIdentitySubRule reports whether r is an identity sub-rule: a rule derived from a
// SourceRule (so not a "Node" rule) that has no EdgeType, mapping each source element 1-to-1.
func (r *Rule) IsIdentitySubRule() bool {
	if r.SourceRule == nil {
		// Root ("Node") rules are never identity sub-rules.
		return false
	}
	return r.EdgeType == nil
}
// WithKernelScopeName sets both ConvKernelScopeName and UpdateKernelScopeName to `name`,
// modifying r in place, and returns r itself so calls can be chained.
func (r *Rule) WithKernelScopeName(name string) *Rule {
	r.ConvKernelScopeName, r.UpdateKernelScopeName = name, name
	return r
}
// String implements fmt.Stringer, returning a one-line description of the rule.
//
// The description always includes the rule name, its kind ("Node" or "Edge"), node type,
// Shape and Shape size. Node rules additionally report the NodeSet size, if a NodeSet is set.
// Edge rules report their SourceRule name and EdgeType name. Identity sub-rules report
// "EdgeType=Identity".
func (r *Rule) String() string {
	switch {
	case r.IsNode():
		nodeSetInfo := ""
		if r.NodeSet != nil {
			nodeSetInfo = fmt.Sprintf(", NodeSet.size=%d", len(r.NodeSet))
		}
		return fmt.Sprintf("Rule %q: type=Node, nodeType=%q, Shape=%s (size=%d)%s",
			r.Name, r.NodeTypeName, r.Shape, r.Shape.Size(), nodeSetInfo)
	case r.IsIdentitySubRule():
		return fmt.Sprintf("Rule %q: type=Edge, nodeType=%q, Shape=%s (size=%d), SourceRule=%q, EdgeType=Identity",
			r.Name, r.NodeTypeName, r.Shape, r.Shape.Size(), r.SourceRule.Name)
	default:
		return fmt.Sprintf("Rule %q: type=Edge, nodeType=%q, Shape=%s (size=%d), SourceRule=%q, EdgeType=%q",
			r.Name, r.NodeTypeName, r.Shape, r.Shape.Size(), r.SourceRule.Name, r.EdgeType.Name)
	}
}
// FromEdges returns a [Rule] that samples nodes from the edges connecting the results of the current Rule `r`.
//
// For each element sampled by `r`, the new rule (named `name`) samples `count` nodes connected to it
// through the edge type `edgeTypeName`. The new rule's Shape is r.Shape with `count` appended as a new
// last axis, its node type is the edge type's target node type, and its kernel scope names are set
// to "gnn:"+name. It is registered as a dependent of `r` and in r's Strategy.
//
// It panics if the strategy is frozen, if a rule named `name` already exists, if `count` is not
// positive, if the edge type is not known to the Sampler, or if the edge type's source node type
// doesn't match r.NodeTypeName.
func (r *Rule) FromEdges(name, edgeTypeName string, count int) *Rule {
	r.assertCanAddRule(name)
	if count <= 0 {
		// The count becomes a new axis of the sample shape: a non-positive value would silently
		// create an invalid/empty shape, since it is appended without going through shape validation.
		Panicf("rule %q: the number of samples per source element (count) must be > 0, got %d", name, count)
	}
	edgeDef, found := r.Sampler.EdgeTypes[edgeTypeName]
	if !found {
		Panicf("edge type %q not found to sample from in rule %q", edgeTypeName, name)
	}
	if edgeDef.SourceNodeType != r.NodeTypeName {
		Panicf("edge type %q connects %q to %q: but you are using it on sampling rule %q, which is of node type %q",
			edgeTypeName, edgeDef.SourceNodeType, edgeDef.TargetNodeType, r.Name, r.NodeTypeName)
	}
	return r.addDependentRule(name, edgeDef.TargetNodeType, edgeDef, count)
}

// IdentitySubRule creates a sub-rule that copies over the current rule, adding one rank (but same size).
// This is useful when trying to split updates into different parts, with the "IdentitySubRule" taking a
// subset of the dependents.
//
// The sub-rule (named `name`) has the same node type as `r`, a nil EdgeType (which is what
// [Rule.IsIdentitySubRule] checks), Count 1 and Shape equal to r.Shape with a new last axis of
// dimension 1. Its kernel scope names are set to "gnn:"+name.
//
// It panics if the strategy is frozen or if a rule named `name` already exists.
func (r *Rule) IdentitySubRule(name string) *Rule {
	r.assertCanAddRule(name)
	// nil EdgeType marks an identity sub-rule; Count 1 gives the 1-to-1 mapping.
	return r.addDependentRule(name, r.NodeTypeName, nil, 1)
}

// assertCanAddRule panics if r's Strategy is frozen or already has a rule named `name`.
func (r *Rule) assertCanAddRule(name string) {
	strategy := r.Strategy
	if strategy.frozen {
		Panicf("Strategy is frozen, that is, a dataset was already created and used with NewDataset() and hence can no longer be modified.")
	}
	if prevRule, found := strategy.Rules[name]; found {
		Panicf("rule named %q already exists: %s", name, prevRule)
	}
}

// addDependentRule creates a rule derived from r (r.Shape plus a trailing axis of `count`),
// registers it in r.Dependents and in r.Strategy.Rules, and returns it.
func (r *Rule) addDependentRule(name, nodeTypeName string, edgeType *EdgeType, count int) *Rule {
	newShape := r.Shape.Copy()
	newShape.Dimensions = append(newShape.Dimensions, count)
	newRule := &Rule{
		Sampler:      r.Sampler,
		Strategy:     r.Strategy,
		Name:         name,
		NodeTypeName: nodeTypeName,
		SourceRule:   r,
		EdgeType:     edgeType,
		Count:        count,
		Shape:        newShape,
	}
	newRule = newRule.WithKernelScopeName("gnn:" + name)
	r.Dependents = append(r.Dependents, newRule)
	r.Strategy.Rules[name] = newRule
	return newRule
}