/
utils.go
150 lines (129 loc) · 4.28 KB
/
utils.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
package athena
import (
"context"
pb "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/core"
pluginsIdl "github.com/flyteorg/flyte/flyteidl/gen/pb-go/flyteidl/plugins"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/errors"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/core/template"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/ioutils"
"github.com/flyteorg/flyte/flyteplugins/go/tasks/pluginmachinery/webapi"
"github.com/flyteorg/flyte/flytestdlib/logger"
"github.com/flyteorg/flyte/flytestdlib/utils"
)
// writeOutput records where the query results live as the task's "results"
// output.
//
// It reads the task template and, when the template's output interface declares
// a variable named "results", hands the output writer an in-memory literal map
// whose "results" entry is a schema scalar: Uri is externalLocation and Type is
// the schema type taken from the declared variable.
//
// When the task declares no output variables, or none named "results", it logs
// at info level and returns nil without writing anything. Errors from reading
// the task template or from the output writer are returned unchanged.
func writeOutput(ctx context.Context, tCtx webapi.StatusContext, externalLocation string) error {
	const resultsKey = "results"

	taskTemplate, err := tCtx.TaskReader().Read(ctx)
	if err != nil {
		return err
	}

	// Protobuf-generated getters tolerate nil receivers, so a missing interface,
	// a missing outputs map, or a nil template all yield a nil map here instead
	// of a nil-pointer dereference.
	declared := taskTemplate.GetInterface().GetOutputs().GetVariables()
	if declared == nil {
		logger.Infof(ctx, "The task declares no outputs. Skipping writing the outputs.")
		return nil
	}

	resultsSchema, exists := declared[resultsKey]
	if !exists {
		// The task has outputs, just not the one this plugin knows how to fill.
		logger.Infof(ctx, "The task declares no [%v] output. Skipping writing the outputs.", resultsKey)
		return nil
	}

	schema := &pb.Schema{
		Uri:  externalLocation,
		Type: resultsSchema.GetType().GetSchema(),
	}
	outputs := &pb.LiteralMap{
		Literals: map[string]*pb.Literal{
			resultsKey: {
				Value: &pb.Literal_Scalar{
					Scalar: &pb.Scalar{Value: &pb.Scalar_Schema{Schema: schema}},
				},
			},
		},
	}

	return tCtx.OutputWriter().Put(ctx, ioutils.NewInMemoryOutputReader(outputs, nil, nil))
}
// QueryInfo is the flattened description of a query that extractQueryInfo
// builds from either a hive or a presto task.
type QueryInfo struct {
	// QueryString is the query text after template rendering: the hive query's
	// Query field or the presto Statement.
	QueryString string
	// Workgroup is filled from the presto RoutingGroup; hive tasks leave it empty.
	Workgroup string
	// Catalog is filled from the presto Catalog; hive tasks leave it empty.
	Catalog string
	// Database is filled from the hive ClusterLabel or the presto Schema.
	Database string
}
// validateHiveQuery checks that a QuboleHiveJob carries a usable query.
//
// It returns a BadTaskSpecification error when the nested Query message is
// absent or when that message's query statement is empty, and nil otherwise.
func validateHiveQuery(hiveQuery pluginsIdl.QuboleHiveJob) error {
	query := hiveQuery.Query
	switch {
	case query == nil:
		return errors.Errorf(errors.BadTaskSpecification, "Query is a required field.")
	case query.Query == "":
		return errors.Errorf(errors.BadTaskSpecification, "Query statement is a required field.")
	default:
		return nil
	}
}
// validatePrestoQuery checks that a PrestoQuery has a non-empty Statement.
//
// It returns a BadTaskSpecification error when the statement is empty and nil
// otherwise.
func validatePrestoQuery(prestoQuery pluginsIdl.PrestoQuery) error {
	if prestoQuery.Statement != "" {
		return nil
	}

	return errors.Errorf(errors.BadTaskSpecification, "Statement is a required field.")
}
// extractQueryInfo reads the task template from tCtx and converts its custom
// payload into a QueryInfo.
//
// Two task types are understood:
//
//   - "hive": the custom field is decoded as a QuboleHiveJob. The query text
//     becomes QueryString and the ClusterLabel becomes Database.
//   - "presto": the custom field is decoded as a PrestoQuery. RoutingGroup,
//     Catalog, Schema and Statement become Workgroup, Catalog, Database and
//     QueryString respectively.
//
// In both cases the strings are passed through template.Render before being
// returned. Decoding or validation failures are wrapped with ErrUser; errors
// from reading the task or from rendering are returned unchanged. Any other
// task type yields an ErrUser error.
func extractQueryInfo(ctx context.Context, tCtx webapi.TaskExecutionContextReader) (QueryInfo, error) {
	task, err := tCtx.TaskReader().Read(ctx)
	if err != nil {
		return QueryInfo{}, err
	}

	// render is a closure rather than a pre-built Parameters value so that the
	// tCtx accessors are only invoked once the custom payload has been decoded
	// and validated.
	render := func(templates ...string) ([]string, error) {
		return template.Render(ctx, templates, template.Parameters{
			TaskExecMetadata: tCtx.TaskExecutionMetadata(),
			Inputs:           tCtx.InputReader(),
			OutputPath:       tCtx.OutputWriter(),
			Task:             tCtx.TaskReader(),
		})
	}

	switch task.GetType() {
	case "hive":
		hiveQuery := pluginsIdl.QuboleHiveJob{}
		if err := utils.UnmarshalStructToPb(task.GetCustom(), &hiveQuery); err != nil {
			return QueryInfo{}, errors.Wrapf(ErrUser, err, "Expects a valid QuboleHiveJob proto in custom field.")
		}

		if err := validateHiveQuery(hiveQuery); err != nil {
			return QueryInfo{}, errors.Wrapf(ErrUser, err, "Expects a valid QuboleHiveJob proto in custom field.")
		}

		outputs, err := render(
			hiveQuery.Query.Query,
			hiveQuery.ClusterLabel,
		)
		if err != nil {
			return QueryInfo{}, err
		}

		// The hive job's cluster label is what ends up in Database.
		return QueryInfo{
			QueryString: outputs[0],
			Database:    outputs[1],
		}, nil

	case "presto":
		prestoQuery := pluginsIdl.PrestoQuery{}
		if err := utils.UnmarshalStructToPb(task.GetCustom(), &prestoQuery); err != nil {
			return QueryInfo{}, errors.Wrapf(ErrUser, err, "Expects a valid PrestoQuery proto in custom field.")
		}

		if err := validatePrestoQuery(prestoQuery); err != nil {
			return QueryInfo{}, errors.Wrapf(ErrUser, err, "Expects a valid PrestoQuery proto in custom field.")
		}

		outputs, err := render(
			prestoQuery.RoutingGroup,
			prestoQuery.Catalog,
			prestoQuery.Schema,
			prestoQuery.Statement,
		)
		if err != nil {
			return QueryInfo{}, err
		}

		return QueryInfo{
			Workgroup:   outputs[0],
			Catalog:     outputs[1],
			Database:    outputs[2],
			QueryString: outputs[3],
		}, nil
	}

	return QueryInfo{}, errors.Errorf(ErrUser, "Unexpected task type [%v].", task.GetType())
}