forked from cockroachdb/cockroach
-
Notifications
You must be signed in to change notification settings - Fork 0
/
targets.go
96 lines (87 loc) · 3.07 KB
/
targets.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
// Copyright 2016 The Cockroach Authors.
//
// Licensed under the Cockroach Community Licence (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://github.com/cockroachdb/cockroach/blob/master/pkg/ccl/LICENSE
package sqlccl
import (
"github.com/cockroachdb/cockroach/pkg/sql/parser"
"github.com/cockroachdb/cockroach/pkg/sql/sqlbase"
"github.com/pkg/errors"
)
// descriptorsMatchingTargets filters descriptors down to the ones selected by
// targets and returns them, database descriptors first (in input order)
// followed by table descriptors (in input order).
//
// A database descriptor is selected when its name appears in
// targets.Databases, when a `db.*`-style pattern in targets.Tables names it, or
// when any table pattern in targets.Tables names it as the table's database
// (whether or not a table of that name is actually present in descriptors).
//
// A table descriptor is selected when its parent database was selected as a
// whole, or when its name equals one of the table names requested for its
// parent database.
//
// When sessionDatabase is non-empty, each table pattern is passed through
// QualifyWithDatabase(sessionDatabase) before it is recorded.
//
// An error is returned if a pattern fails to normalize or qualify, if a
// pattern has an unrecognized concrete type, or if a table descriptor's
// ParentID does not correspond to any database descriptor in descriptors.
func descriptorsMatchingTargets(
	sessionDatabase string, descriptors []sqlbase.Descriptor, targets parser.TargetList,
) ([]sqlbase.Descriptor, error) {
	// TODO(dan): If the session search path starts including more than virtual
	// tables (as of 2017-01-12 it's only pg_catalog), then this method will
	// need to support it.

	// wholeDBs is the set of normalized database names whose contents are
	// selected in their entirety.
	wholeDBs := make(map[string]struct{}, len(targets.Databases))
	for _, name := range targets.Databases {
		wholeDBs[name.Normalize()] = struct{}{}
	}

	// tablesPerDB maps a normalized database name to the normalized names of
	// the individual tables requested from it.
	tablesPerDB := make(map[string][]string, len(targets.Tables))
	for _, raw := range targets.Tables {
		normalized, err := raw.NormalizeTablePattern()
		if err != nil {
			return nil, err
		}

		switch p := normalized.(type) {
		case *parser.TableName:
			if sessionDatabase != "" {
				err = p.QualifyWithDatabase(sessionDatabase)
				if err != nil {
					return nil, err
				}
			}
			dbKey := p.DatabaseName.Normalize()
			tablesPerDB[dbKey] = append(tablesPerDB[dbKey], p.TableName.Normalize())

		case *parser.AllTablesSelector:
			if sessionDatabase != "" {
				err = p.QualifyWithDatabase(sessionDatabase)
				if err != nil {
					return nil, err
				}
			}
			wholeDBs[p.Database.Normalize()] = struct{}{}

		default:
			return nil, errors.Errorf("unknown pattern %T: %+v", normalized, normalized)
		}
	}

	// First pass: index every database descriptor by ID (the table pass below
	// needs all of them, regardless of where they sit in the input) and pick
	// the databases that are referenced by the targets in any form.
	parents := make(map[sqlbase.ID]*sqlbase.DatabaseDescriptor, len(descriptors))
	var matched []sqlbase.Descriptor
	for _, d := range descriptors {
		db := d.GetDatabase()
		if db == nil {
			continue
		}
		parents[db.ID] = db

		dbKey := parser.ReNormalizeName(db.Name)
		_, whole := wholeDBs[dbKey]
		_, partial := tablesPerDB[dbKey]
		if whole || partial {
			matched = append(matched, d)
		}
	}

	// Second pass: pick the table descriptors.
	for _, d := range descriptors {
		tbl := d.GetTable()
		if tbl == nil {
			continue
		}
		parent, found := parents[tbl.ParentID]
		if !found {
			return nil, errors.Errorf("unknown ParentID: %d", tbl.ParentID)
		}

		dbKey := parser.ReNormalizeName(parent.Name)
		if _, whole := wholeDBs[dbKey]; whole {
			matched = append(matched, d)
			continue
		}
		// A database with no individually requested tables yields a nil slice
		// here, so the loop body simply never runs.
		for _, wanted := range tablesPerDB[dbKey] {
			if parser.ReNormalizeName(wanted) == parser.ReNormalizeName(tbl.Name) {
				// Stop at the first hit so a table listed more than once is
				// still only returned once.
				matched = append(matched, d)
				break
			}
		}
	}
	return matched, nil
}