Skip to content

Commit

Permalink
completion on foreign keys.
Browse files Browse the repository at this point in the history
  • Loading branch information
srgsf committed Feb 4, 2022
1 parent 620def1 commit e737ffc
Show file tree
Hide file tree
Showing 17 changed files with 1,249 additions and 172 deletions.
5 changes: 5 additions & 0 deletions README.md
Expand Up @@ -34,6 +34,11 @@ sqls aims to provide advanced intelligence for you to edit sql in your own edito
- DDL(Data Definition Language)
- [ ] CREATE TABLE
- [ ] ALTER TABLE

#### Join completion
If the tables are connected with a foreign key sqls can complete ```JOIN``` statements

![join_completion](imgs/sqls-fk_joins.gif)

#### CodeAction

Expand Down
Binary file added imgs/sqls-fk_joins.gif
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
164 changes: 164 additions & 0 deletions internal/completer/candidates.go
@@ -1,6 +1,7 @@
package completer

import (
"fmt"
"strings"

"github.com/lighttiger2505/sqls/internal/database"
Expand Down Expand Up @@ -155,6 +156,169 @@ func (c *Completer) TableCandidates(parent *completionParent, targetTables []*pa
return candidates
}

func (c *Completer) joinCandidates(lastTable *parseutil.TableInfo,
targetTables, allTables []*parseutil.TableInfo,
joinOn, lowercaseKeywords bool) []lsp.CompletionItem {
var candidates []lsp.CompletionItem
if len(c.DBCache.ForeignKeys) == 0 {
return candidates
}

tMap := make(map[string]*parseutil.TableInfo)
for _, t := range targetTables {
tMap[t.Name] = t
}
fkMap := make(map[string][][]*database.ForeignKey)
if lastTable == nil {
for t := range tMap {
for k, v := range c.DBCache.ForeignKeys[t] {
fkMap[k] = append(fkMap[k], v)
}
}
} else {
delete(tMap, lastTable.Name)
rTab := []*parseutil.TableInfo{lastTable}
if !joinOn {
rTab = resolveTables(lastTable, c.DBCache)
}
for _, lt := range rTab {
for k, v := range c.DBCache.ForeignKeys[lt.Name] {
if _, ok := tMap[k]; ok {
fkMap[lt.Name] = append(fkMap[lt.Name], v)
}
}
}

for _, t := range rTab {
if _, ok := tMap[t.Name]; !ok {
tMap[t.Name] = t
}
}
}

aliases := make(map[string]interface{})
for _, t := range allTables {
if t.Alias != "" {
aliases[t.Alias] = true
}
}

for k, v := range fkMap {
for _, fks := range v {
for _, fk := range fks {
candidates = append(candidates, generateForeignKeyCandidate(k, tMap, aliases,
fk, joinOn, lowercaseKeywords))
}
}
}
return candidates
}

func resolveTables(t *parseutil.TableInfo, cache *database.DBCache) []*parseutil.TableInfo {
if _, ok := cache.ColumnDescs(t.Name); ok {
return []*parseutil.TableInfo{t}
}
var rv []*parseutil.TableInfo
targetName := strings.ToLower(t.Name)
for _, cond := range cache.SortedTables() {
if strings.Contains(strings.ToLower(cond), targetName) {
rv = append(rv, &parseutil.TableInfo{
Name: cond,
})
}
}
return rv
}

func generateTableAlias(target string,
aliases map[string]interface{}) string {
ch := []rune(target)[0]
i := 1
var rv string
for {
rv = fmt.Sprintf("%c%d", ch, i)
if _, ok := aliases[rv]; ok {
i++
continue
}
break
}
return rv
}

func generateForeignKeyCandidate(target string,
tMap map[string]*parseutil.TableInfo,
aliases map[string]interface{},
fk *database.ForeignKey,
joinOn, lowercaseKeywords bool) lsp.CompletionItem {
var tAlias string
if joinOn {
tAlias = tMap[target].Alias
if tAlias == "" {
tAlias = tMap[target].Name
}
} else {
tAlias = generateTableAlias(target, aliases)
}
builder := []struct {
sb *strings.Builder
alias string
}{
{
sb: &strings.Builder{},
alias: tAlias,
},
{
sb: &strings.Builder{},
alias: tAlias,
},
}
if !joinOn {
builder[1].alias = fmt.Sprintf("${1:%s}", tAlias)
onKw := "ON"
if lowercaseKeywords {
onKw = "on"
}
for _, b := range builder {
b.sb.WriteString(fmt.Sprintf("%s %s %s ", target, b.alias, onKw))
}
}
andKw := " AND "
if lowercaseKeywords {
andKw = " and "
}
prefix := ""
for _, cur := range *fk {
tIdx, rIdx := 0, 1
if cur[rIdx].Table == target {
tIdx, rIdx = rIdx, tIdx
}
for _, b := range builder {
b.sb.WriteString(prefix)
}
prefix = andKw
for _, b := range builder {
b.sb.WriteString(strings.Join([]string{b.alias, cur[tIdx].Name}, "."))
b.sb.WriteString(" = ")
}
rAlias := tMap[cur[rIdx].Table].Alias
if rAlias == "" {
rAlias = cur[rIdx].Table
}
for _, b := range builder {
b.sb.WriteString(strings.Join([]string{rAlias, cur[rIdx].Name}, "."))
}
}
builder[1].sb.WriteString("$0")
return lsp.CompletionItem{
Label: builder[0].sb.String(),
Kind: lsp.SnippetCompletion,
Detail: "Join generator for foreign key",
InsertText: builder[1].sb.String(),
InsertTextFormat: lsp.SnippetTextFormat,
}
}

func generateTableCandidates(tables []string, dbCache *database.DBCache) []lsp.CompletionItem {
candidates := []lsp.CompletionItem{}
for _, tableName := range tables {
Expand Down
48 changes: 44 additions & 4 deletions internal/completer/completer.go
Expand Up @@ -32,6 +32,8 @@ const (
CompletionTypeChange
CompletionTypeUser
CompletionTypeSchema
CompletionTypeJoin
CompletionTypeJoinOn
)

func (ct completionType) String() string {
Expand Down Expand Up @@ -112,7 +114,7 @@ func (c *Completer) Complete(text string, params lsp.CompletionParams, lowercase
lastWord := getLastWord(text, params.Position.Line+1, params.Position.Character)
withBackQuote := strings.HasPrefix(lastWord, "`")

items := []lsp.CompletionItem{}
var items []lsp.CompletionItem

if c.DBCache != nil {
if completionTypeIs(ctx.types, CompletionTypeColumn) {
Expand All @@ -130,7 +132,11 @@ func (c *Completer) Complete(text string, params lsp.CompletionParams, lowercase
items = append(items, candidates...)
}
if completionTypeIs(ctx.types, CompletionTypeTable) {
candidates := c.TableCandidates(ctx.parent, definedTables)
excl := definedTables
if completionTypeIs(ctx.types, CompletionTypeJoin) {
excl = nil
}
candidates := c.TableCandidates(ctx.parent, excl)
if withBackQuote {
candidates = toQuotedCandidates(candidates)
}
Expand All @@ -157,6 +163,22 @@ func (c *Completer) Complete(text string, params lsp.CompletionParams, lowercase
}
items = append(items, candidates...)
}
joinOn := completionTypeIs(ctx.types, CompletionTypeJoinOn)
if completionTypeIs(ctx.types, CompletionTypeJoin) || joinOn {
table, err := parseutil.ExtractLastTable(parsed, pos)
if err != nil {
return nil, err
}
tables, err := parseutil.ExtractPrevTables(parsed, pos)
if err != nil {
return nil, err
}
candidates := c.joinCandidates(table, tables, definedTables, joinOn, lowercaseKeywords)
if withBackQuote {
candidates = toQuotedCandidates(candidates) // what to do here?
}
items = append(candidates, items...)
}
}

if completionTypeIs(ctx.types, CompletionTypeKeyword) {
Expand Down Expand Up @@ -185,6 +207,8 @@ func populateSortText(items []lsp.CompletionItem) {
// This prefix defines the alphabetic priority of each kind.
func getSortTextPrefix(kind lsp.CompletionItemKind) string {
switch kind {
case lsp.SnippetCompletion:
return "00"
case lsp.FieldCompletion:
return "0"
case lsp.ClassCompletion:
Expand All @@ -208,7 +232,6 @@ func getSortTextPrefix(kind lsp.CompletionItemKind) string {
lsp.OperatorCompletion,
lsp.PropertyCompletion,
lsp.ReferenceCompletion,
lsp.SnippetCompletion,
lsp.StructCompletion,
lsp.TextCompletion,
lsp.TypeParameterCompletion,
Expand Down Expand Up @@ -249,7 +272,7 @@ func getCompletionTypes(nw *parseutil.NodeWalker) *CompletionContext {
}

syntaxPos := parseutil.CheckSyntaxPosition(nw)
t := []completionType{}
var t []completionType
p := noneParent
switch {
case syntaxPos == parseutil.ColName:
Expand Down Expand Up @@ -349,6 +372,23 @@ func getCompletionTypes(nw *parseutil.NodeWalker) *CompletionContext {
CompletionTypeFunction,
}
}
case syntaxPos == parseutil.JoinClause:
t = []completionType{
CompletionTypeJoin,
CompletionTypeTable,
CompletionTypeReferencedTable,
CompletionTypeSchema,
CompletionTypeView,
CompletionTypeSubQuery,
}
case syntaxPos == parseutil.JoinOn:
t = []completionType{
CompletionTypeJoinOn,
CompletionTypeColumn,
CompletionTypeReferencedTable,
CompletionTypeSubQueryColumn,
CompletionTypeSubQuery,
}
case syntaxPos == parseutil.InsertColumn:
t = []completionType{
CompletionTypeColumn,
Expand Down
35 changes: 35 additions & 0 deletions internal/completer/completer_test.go
Expand Up @@ -165,3 +165,38 @@ func TestComplete(t *testing.T) {
})
}
}

func TestGenerateAlias(t *testing.T) {
noMatchesTable := make(map[string]interface{})
noMatchesTable["XX"] = true
matchesTable := make(map[string]interface{})
matchesTable["XX"] = true
matchesTable["T1"] = true

tests := []struct {
name string
table string
tMap map[string]interface{}
want string
}{
{
"no matches",
"Table",
noMatchesTable,
"T1",
},
{
"matches",
"Table",
matchesTable,
"T2",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := generateTableAlias(tt.table, tt.tMap); got != tt.want {
t.Errorf("generateAlias() = %v, want %v", got, tt.want)
}
})
}
}
32 changes: 30 additions & 2 deletions internal/database/cache.go
Expand Up @@ -28,7 +28,7 @@ func (u *DBCacheGenerator) GenerateDBCachePrimary(ctx context.Context) (*DBCache
return nil, err
}
dbCache.Schemas = make(map[string]string)
for index,element := range schemas{
for index, element := range schemas {
dbCache.Schemas[strings.ToUpper(index)] = element
}

Expand All @@ -45,14 +45,15 @@ func (u *DBCacheGenerator) GenerateDBCachePrimary(ctx context.Context) (*DBCache
return nil, err
}
dbCache.SchemaTables = make(map[string][]string)
for index,element := range schemaTables{
for index, element := range schemaTables {
dbCache.SchemaTables[strings.ToUpper(index)] = element
}

dbCache.ColumnsWithParent, err = u.genColumnCacheCurrent(ctx, dbCache.defaultSchema)
if err != nil {
return nil, err
}
dbCache.ForeignKeys, err = u.genForeignKeysCache(ctx, dbCache.defaultSchema)
return dbCache, nil
}

Expand Down Expand Up @@ -88,6 +89,32 @@ func (u *DBCacheGenerator) genColumnCacheAll(ctx context.Context) (map[string][]
return genColumnMap(columnDescs), nil
}

func (u *DBCacheGenerator) genForeignKeysCache(ctx context.Context, schemaName string) (map[string]map[string][]*ForeignKey, error) {
retVal := make(map[string]map[string][]*ForeignKey)
fk, err := u.repo.DescribeForeignKeysBySchema(ctx, schemaName)
if err != nil {
return nil, err
}

for _, cur := range fk {
elem := (*cur)[0]
refs, ok := retVal[elem[0].Table]
if !ok {
refs = make(map[string][]*ForeignKey)
}
refs[elem[1].Table] = append(refs[elem[1].Table], cur)
retVal[elem[0].Table] = refs

refs, ok = retVal[elem[1].Table]
if !ok {
refs = make(map[string][]*ForeignKey)
}
refs[elem[0].Table] = append(refs[elem[0].Table], cur)
retVal[elem[1].Table] = refs
}
return retVal, nil
}

func genColumnMap(columnDescs []*ColumnDesc) map[string][]*ColumnDesc {
columnMap := map[string][]*ColumnDesc{}
for _, desc := range columnDescs {
Expand All @@ -107,6 +134,7 @@ type DBCache struct {
Schemas map[string]string
SchemaTables map[string][]string
ColumnsWithParent map[string][]*ColumnDesc
ForeignKeys map[string]map[string][]*ForeignKey
}

func (dc *DBCache) Database(dbName string) (db string, ok bool) {
Expand Down

0 comments on commit e737ffc

Please sign in to comment.