/
extractfunc.go
1061 lines (946 loc) · 33.4 KB
/
extractfunc.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
// Copyright 2015-2018 Auburn University and others. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package refactoring
import (
"bytes"
"fmt"
"go/ast"
"go/token"
"go/types"
"reflect"
"sort"
"github.com/godoctor/godoctor/analysis/cfg"
"github.com/godoctor/godoctor/analysis/dataflow"
"github.com/godoctor/godoctor/text"
"golang.org/x/tools/go/ast/astutil"
"golang.org/x/tools/go/loader"
)
/* -=-=- Sorting -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- */
// varSlice implements sort.Interface to order []*types.Var alphabetically by
// name. Distinct variables that share a name (e.g., a variable and another
// variable shadowing it) are ordered by declaration position, so the result
// does not depend on the order in which the variables were collected (they
// are usually gathered by iterating over maps).
type varSlice []*types.Var

func (t varSlice) Len() int      { return len(t) }
func (t varSlice) Swap(i, j int) { t[i], t[j] = t[j], t[i] }
func (t varSlice) Less(i, j int) bool {
	if t[i].Name() != t[j].Name() {
		return t[i].Name() < t[j].Name()
	}
	// Same name: fall back to declaration position for a deterministic order.
	return t[i].Pos() < t[j].Pos()
}

// SortVars sorts vars in place by name, breaking ties by declaration
// position.
func SortVars(vars []*types.Var) {
	sort.Sort(varSlice(vars))
}
/* -=-=- stmtRange -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- */
// stmtRange represents a sequence of consecutive statements in the body of a
// BlockStmt, CaseClause, or CommClause, together with the control-flow and
// dataflow information (computed by newStmtRange) needed to extract them.
type stmtRange struct {
// The sequence of ancestor nodes for the statement sequence from the
// enclosing BlockStmt/CaseClause/CommClause upward through the root of
// the AST. pathToRoot[0] will be an instance of *ast.BlockStmt,
// *ast.CaseClause, or *ast.CommClause, and
// pathToRoot[len(pathToRoot)-1] will be an instance of *ast.File.
pathToRoot []ast.Node
// The start and ending indices (inclusive) of the first and last
// statements of this sequence in the list of children for the
// enclosing BlockStmt/CaseClause/CommClause.
firstIdx, lastIdx int
// Control flow graph for the enclosing function
cfg *cfg.CFG
// CFG blocks inside the selected statement range
blocksInRange []ast.Stmt
// For each block in the CFG, variables that are live at its entry
liveIn map[ast.Stmt]map[*types.Var]struct{}
// Definitions reaching the entrypoint(s) to the selection
defsReachingSelection map[ast.Stmt]struct{}
// PackageInfo used to bind variable names to *types.Var objects
pkgInfo *loader.PackageInfo
// The function or method declaration (FuncDecl) enclosing the selection
enclosingFunc *ast.FuncDecl
}
// newStmtRange creates a stmtRange corresponding to a selected region of a
// file. If the selected range of characters does not enclose complete
// statements, the stmtRange is adjusted (if possible) to the closest legal
// selection. The given pkgInfo is used to determine the types and bindings of
// variables in the selection.
//
// The selection runs from start up to end; the character at end-1 is treated
// as the last selected character. The selection is widened to every
// statement of the innermost BlockStmt/CaseClause/CommClause enclosing both
// ends whose text overlaps the selection. An errInvalidSelection error is
// returned when the selection lies in the header/initialization part of an
// if, switch, type switch, for, or range statement; when no such block
// encloses both ends; when the selection is not inside a function
// declaration (FuncDecl); or when no statement in the block overlaps the
// selection.
func newStmtRange(file *ast.File, start, end token.Pos, pkgInfo *loader.PackageInfo) (*stmtRange, error) {
// Each path runs from the innermost enclosing node up to file (the last
// element). The second result (whether the match is exact) is not needed.
startPath, _ := astutil.PathEnclosingInterval(file, start, start)
endPath, _ := astutil.PathEnclosingInterval(file, end-1, end-1)
// Work downward from the root of the AST, counting the number of nodes
// that enclose both the start and end of the selection. A depth of 0 is
// the root (*ast.File); -1 means the paths share no node at all.
deepestCommonAncestorDepth := -1
for i := 0; i < min(len(startPath), len(endPath)); i++ {
if startPath[len(startPath)-1-i] == endPath[len(endPath)-1-i] {
deepestCommonAncestorDepth++
} else {
break
}
}
// Find the depth of the deepest BlockStmt, CaseClause, or CommClause
// enclosing both the start and end of the selection. If the user
// selected the initialization statement in an if-statement (or
// something similar), raise an error; it cannot be extracted.
blockDepth := deepestCommonAncestorDepth
body := []ast.Stmt{}
loop:
for blockDepth > 0 {
switch node := startPath[len(startPath)-1-blockDepth].(type) {
case *ast.BlockStmt:
body = node.List
break loop
case *ast.CaseClause:
body = node.Body
break loop
case *ast.CommClause:
body = node.Body
break loop
case *ast.IfStmt, *ast.SwitchStmt, *ast.TypeSwitchStmt, *ast.ForStmt, *ast.RangeStmt: // statements whose header (init/cond/post/range clause) is not inside a block
if blockDepth != deepestCommonAncestorDepth {
// We are inside one of these constructs, but
// we haven't yet found an enclosing block/etc.
return nil, errInvalidSelection("The initialization statement in an if, switch, type switch, for, or range statement cannot be extracted.")
}
blockDepth--
default:
blockDepth--
}
}
if blockDepth <= 0 {
return nil, errInvalidSelection("Please select a sequence of statements inside a block.")
}
// pathToRoot is the list of ancestor nodes common to all of the
// statements in the selection, from the enclosing
// BlockStmt/CaseClause/CommClause up through the root
pathToRoot := startPath[len(startPath)-1-blockDepth:]
// Search upward from the enclosing block for the declaration of the
// function or method containing the selection.
var enclosingFunc *ast.FuncDecl
for _, node := range pathToRoot {
if f, ok := node.(*ast.FuncDecl); ok {
enclosingFunc = f
break
}
}
if enclosingFunc == nil {
return nil, errInvalidSelection("Please select a sequence of statements inside a function declaration.")
}
// Note: from here on, cfg names this local, shadowing the cfg package.
cfg := cfg.FromFunc(enclosingFunc)
// Find the indices of the first and last statements whose positions
// overlap the selection
firstIdx := -1
lastIdx := -1
for i, stmt := range body {
overlapStart := maxPos(start, stmt.Pos())
overlapEnd := minPos(end, stmt.End())
inSelection := overlapStart < overlapEnd
if inSelection && firstIdx < 0 {
// We found the first statement in the selection
firstIdx = i
lastIdx = i
} else if inSelection && firstIdx >= 0 {
// We found a subsequent statement in the selection
lastIdx = i
} else if !inSelection && lastIdx >= 0 {
// We are beyond the end of the selection; no need to
// check any more statements
break
}
}
if firstIdx < 0 || lastIdx < 0 {
// No statement in the block overlaps the selection. Most likely,
// the user selected an empty block, {}.
return nil, errInvalidSelection("An empty block cannot be extracted")
}
// Only the per-block live-in sets are used; the second result of
// dataflow.LiveVars is discarded.
liveIn, _ := dataflow.LiveVars(cfg, pkgInfo)
result := &stmtRange{
pathToRoot: pathToRoot,
firstIdx: firstIdx,
lastIdx: lastIdx,
cfg: cfg,
blocksInRange: nil,
liveIn: liveIn,
defsReachingSelection: map[ast.Stmt]struct{}{},
pkgInfo: pkgInfo,
enclosingFunc: enclosingFunc,
}
// Determine the subset of blocks in the CFG that correspond to
// statements within the selected region.
blocksInRange := []ast.Stmt{}
for _, stmt := range cfg.Blocks() {
if result.Contains(stmt) {
blocksInRange = append(blocksInRange, stmt)
}
}
result.blocksInRange = blocksInRange
// Find those definitions that reach the entry to the selected region.
// EntryPoints depends on blocksInRange, so this must come after it.
reaching := make(map[ast.Stmt]struct{})
for _, entry := range result.EntryPoints() {
for def := range dataflow.DefsReaching(entry, cfg, pkgInfo) {
reaching[def] = struct{}{}
}
}
result.defsReachingSelection = reaching
return result, nil
}
// min reports the smaller of the integers m and n.
func min(m, n int) int {
	if m > n {
		return n
	}
	return m
}
// minPos returns whichever of the two token positions comes first in the
// file, i.e., the smaller one.
func minPos(m, n token.Pos) token.Pos {
	if n <= m {
		return n
	}
	return m
}
// maxPos returns whichever of the two token positions comes last in the
// file, i.e., the larger one.
func maxPos(m, n token.Pos) token.Pos {
	if n >= m {
		return n
	}
	return m
}
// selectedStmts returns the statements firstIdx through lastIdx (inclusive)
// of the enclosing BlockStmt/CaseClause/CommClause, pathToRoot[0]. Only
// immediate children are returned; use Inspect to visit nested statements.
// It panics if pathToRoot[0] is any other kind of node.
func (r *stmtRange) selectedStmts() []ast.Stmt {
	var children []ast.Stmt
	switch parent := r.pathToRoot[0].(type) {
	case *ast.BlockStmt:
		children = parent.List
	case *ast.CaseClause:
		children = parent.Body
	case *ast.CommClause:
		children = parent.Body
	default:
		panic("unexpected node type")
	}
	return children[r.firstIdx : r.lastIdx+1]
}
// Inspect runs ast.Inspect with f over each selected statement in order, so
// f is applied to every selected statement and its descendants.
func (r *stmtRange) Inspect(f func(ast.Node) bool) {
	stmts := r.selectedStmts()
	for i := range stmts {
		ast.Inspect(stmts[i], f)
	}
}
// IsInAnonymousFunc reports whether any ancestor of the selected statements
// (any node in pathToRoot) is a FuncLit, i.e., an anonymous function.
func (r *stmtRange) IsInAnonymousFunc() bool {
	for i := range r.pathToRoot {
		switch r.pathToRoot[i].(type) {
		case *ast.FuncLit:
			return true
		}
	}
	return false
}
// ContainsAnonymousFunc reports whether a FuncLit (an anonymous function)
// appears in, or anywhere beneath, the selected statements.
func (r *stmtRange) ContainsAnonymousFunc() bool {
	found := false
	r.Inspect(func(n ast.Node) bool {
		_, isLit := n.(*ast.FuncLit)
		if isLit {
			found = true
		}
		// No need to look inside a function literal once it is found.
		return !isLit
	})
	return found
}
// ContainsDefer reports whether any of the selected statements, or any
// statement nested inside them, is a defer statement (DeferStmt).
//
// Bodies of function literals are not searched. A defer inside an anonymous
// function is tied to that function's own return, so moving the literal into
// the extracted function does not change when the deferred call runs.
// Selections containing anonymous functions are reported separately by
// ContainsAnonymousFunc.
func (r *stmtRange) ContainsDefer() bool {
	found := false
	r.Inspect(func(n ast.Node) bool {
		if found {
			return false
		}
		switch n.(type) {
		case *ast.DeferStmt:
			found = true
			return false
		case *ast.FuncLit:
			return false
		}
		return true
	})
	return found
}
// ContainsReturn reports whether any of the selected statements, or any
// statement nested inside them, is a return statement (ReturnStmt) that
// returns from the enclosing function.
//
// Bodies of function literals are not searched. A return inside an anonymous
// function only leaves that function, so extraction does not change its
// meaning. Selections containing anonymous functions are reported separately
// by ContainsAnonymousFunc.
func (r *stmtRange) ContainsReturn() bool {
	found := false
	r.Inspect(func(n ast.Node) bool {
		if found {
			return false
		}
		switch n.(type) {
		case *ast.ReturnStmt:
			found = true
			return false
		case *ast.FuncLit:
			return false
		}
		return true
	})
	return found
}
// Contains reports whether node lies lexically within the text spanned by the
// selected statements, from the start of the first one to the end of the
// last one. Equivalently, it is true when node is a selected statement or a
// descendant of one.
func (r *stmtRange) Contains(node ast.Node) bool {
	stmts := r.selectedStmts()
	lo, hi := stmts[0].Pos(), stmts[len(stmts)-1].End()
	return lo <= node.Pos() && node.End() <= hi
}
// Pos returns the position at which the first selected statement begins.
func (r *stmtRange) Pos() token.Pos {
	first := r.selectedStmts()[0]
	return first.Pos()
}
// End returns the position just past the last selected statement (an
// exclusive end position).
func (r *stmtRange) End() token.Pos {
	stmts := r.selectedStmts()
	last := stmts[len(stmts)-1]
	return last.End()
}
// EntryPoints returns the CFG blocks inside the selection that have at least
// one predecessor outside it. These are the selected statements that may
// execute first, before any other statement in the selection. The result is
// sorted with r.cfg.Sort.
func (r *stmtRange) EntryPoints() []ast.Stmt {
	isEntry := make(map[ast.Stmt]bool)
	for _, block := range r.blocksInRange {
		for _, pred := range r.cfg.Preds(block) {
			if !r.Contains(pred) {
				isEntry[block] = true
				break
			}
		}
	}
	entries := []ast.Stmt{}
	for block := range isEntry {
		entries = append(entries, block)
	}
	r.cfg.Sort(entries)
	return entries
}
// ExitDestinations returns the CFG blocks outside the selection that are
// successors of some block inside it. These are the statements that could run
// first once control leaves the selected statements. The result is sorted
// with r.cfg.Sort.
func (r *stmtRange) ExitDestinations() []ast.Stmt {
	isDest := make(map[ast.Stmt]bool)
	for _, block := range r.blocksInRange {
		for _, succ := range r.cfg.Succs(block) {
			if r.Contains(succ) {
				continue
			}
			isDest[succ] = true
		}
	}
	dests := []ast.Stmt{}
	for block := range isDest {
		dests = append(dests, block)
	}
	r.cfg.Sort(dests)
	return dests
}
// LocalsLiveAtEntry returns the local variables that are live at the
// entrypoint(s) to the selected region, sorted with SortVars.
//
// A variable that is live at several entrypoints appears only once. Without
// this, the variable would be passed twice to the extracted function.
func (r *stmtRange) LocalsLiveAtEntry() []*types.Var {
	seen := make(map[*types.Var]struct{})
	liveEntry := []*types.Var{}
	for _, entry := range r.EntryPoints() {
		for variable := range r.liveIn[entry] {
			if _, dup := seen[variable]; dup {
				continue
			}
			seen[variable] = struct{}{}
			liveEntry = append(liveEntry, variable)
		}
	}
	SortVars(liveEntry)
	return liveEntry
}
// LocalsLiveAfterExit returns the local variables that are live at the
// statements control can reach right after the selected statements
// (see ExitDestinations), sorted with SortVars.
//
// A variable that is live at several exit destinations appears only once.
// Without this, the variable would be returned twice from the extracted
// function.
func (r *stmtRange) LocalsLiveAfterExit() []*types.Var {
	seen := make(map[*types.Var]struct{})
	liveExit := []*types.Var{}
	for _, exit := range r.ExitDestinations() {
		for variable := range r.liveIn[exit] {
			if _, dup := seen[variable]; dup {
				continue
			}
			seen[variable] = struct{}{}
			liveExit = append(liveExit, variable)
		}
	}
	SortVars(liveExit)
	return liveExit
}
// LocalsReferenced returns the local variables that are accessed by one or
// more of the selected statements. It returns the variables that are
// (1) assigned, i.e., whose values are completely overwritten;
// (2) updated, i.e., a struct member or array element is modified;
// (3) declared via a var declaration or := operator;
// (4) used, i.e., whose values are read.
// Variables may appear in multiple sets. Each result is sorted with SortVars
// so that callers see a deterministic order.
func (r *stmtRange) LocalsReferenced() (asgt, updt, decl, use []*types.Var) {
	asgtSet, updtSet, declSet, useSet := dataflow.ReferencedVars(r.blocksInRange, r.pkgInfo)
	for v := range asgtSet {
		asgt = append(asgt, v)
	}
	for v := range updtSet {
		updt = append(updt, v)
	}
	for v := range declSet {
		decl = append(decl, v)
	}
	for v := range useSet {
		use = append(use, v)
	}
	SortVars(asgt)
	SortVars(updt) // previously left in map-iteration order
	SortVars(decl)
	SortVars(use)
	return asgt, updt, decl, use
}
// String implements fmt.Stringer. It names the dynamic AST node types of the
// first and last selected statements.
func (r *stmtRange) String() string {
	stmts := r.selectedStmts()
	first := reflect.TypeOf(stmts[0]).String()
	last := reflect.TypeOf(stmts[len(stmts)-1]).String()
	return "Statement sequence from " + first + " through " + last
}
/* -=-=- extractedFunc -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-= */
// extractedFunc encapsulates information about the new function that will be
// created from the extracted code, along with how it should be called.
// SourceCode renders it as a function declaration plus a call statement.
type extractedFunc struct {
name string // name of the new function
recv *types.Var // receiver variable, or nil to extract a plain function
params []*types.Var // parameters for the new function (also the call's arguments)
returns []*types.Var // variables whose values will be returned and assigned at the call site
locals []*types.Var // local variables to declare at the top of the new function body
localInits map[*types.Var]ast.Expr // initialization expressions for locals; locals mapped to nil (or absent) get a plain var declaration
define bool // x := f() instead of x = f()
code []byte // source text of the selected statements, copied into the function body
pkgFmt func(p *types.Package) string // qualifier passed to types.TypeString to print package-qualified types
}
// SourceCode returns source code for (1) funcDecl, the new function
// declaration that should be inserted, and (2) funcCall, the statement that
// should replace the selected statements.
//
// Variables named "_" are skipped in parameter, local, and result lists (see
// namesAndTypes). If no named results remain, the call is a bare call
// expression. Otherwise the new function ends with a return statement, and
// the call site assigns its results to the same variables, using := when
// f.define is set and = otherwise.
func (f *extractedFunc) SourceCode() (funcDecl, funcCall string) {
	paramNames, paramTypes := namesAndTypes(f.params, f.pkgFmt)
	funcDeclParams := createParamDecls(paramNames, paramTypes)
	funcCallArgs := commaSeparated(paramNames)

	// signature is the declaration without the "func" keyword and results;
	// methods are invoked through the receiver variable.
	var signature, call string
	if f.recv != nil {
		recvType := types.TypeString(f.recv.Type(), f.pkgFmt)
		signature = fmt.Sprintf("(%s %s) %s(%s)",
			f.recv.Name(), recvType, f.name, funcDeclParams)
		call = fmt.Sprintf("%s.%s(%s)", f.recv.Name(), f.name, funcCallArgs)
	} else {
		signature = fmt.Sprintf("%s(%s)", f.name, funcDeclParams)
		call = fmt.Sprintf("%s(%s)", f.name, funcCallArgs)
	}

	localNames, localTypes := namesAndTypes(f.locals, f.pkgFmt)
	localVarDecls := createVarDecls(localNames, localTypes, initStrings(f.localInits))

	// Decide on results from the printable names, so that a result list
	// consisting only of "_" does not yield "return " and " = f()".
	returnNames, returnTypes := namesAndTypes(f.returns, f.pkgFmt)
	if len(returnNames) == 0 {
		funcDecl = fmt.Sprintf("\n\nfunc %s {\n%s%s\n}\n",
			signature, localVarDecls, f.code)
		return funcDecl, call
	}

	resultTypes := commaSeparated(returnTypes)
	if len(returnTypes) > 1 {
		// Multiple result types must be parenthesized.
		resultTypes = "(" + resultTypes + ")"
	}
	returnExprs := commaSeparated(returnNames)
	assignSymbol := " = "
	if f.define {
		assignSymbol = " := "
	}
	funcDecl = fmt.Sprintf("\n\nfunc %s %s {\n%s%s\nreturn %s\n}\n",
		signature, resultTypes, localVarDecls, f.code, returnExprs)
	funcCall = returnExprs + assignSymbol + call
	return funcDecl, funcCall
}
// namesAndTypes returns, in the order given, the names of vars and the
// matching type strings, which are printed with types.TypeString using
// qualifier. Blank ("_") variables are omitted from both results. Both
// results are nil when nothing is emitted.
func namesAndTypes(vars []*types.Var, qualifier types.Qualifier) (names []string, typez []string) {
	for i := range vars {
		v := vars[i]
		if v.Name() == "_" {
			continue
		}
		names = append(names, v.Name())
		typez = append(typez, types.TypeString(v.Type(), qualifier))
	}
	return names, typez
}
// initStrings turns a map from variables to (constant-valued) initialization
// expressions into a map from variable names to the expressions' source text
// (types.ExprString). Variables mapped to a nil expression are left out. This
// makes createVarDecls declare them with a plain "var" declaration, which is
// the intended result in that case. The returned map is never nil.
func initStrings(inits map[*types.Var]ast.Expr) map[string]string {
	byName := map[string]string{}
	for v, e := range inits {
		if e == nil {
			continue
		}
		byName[v.Name()] = types.ExprString(e)
	}
	return byName
}
// createVarDecls returns source code declaring each of the given variables on
// its own newline-terminated line. A variable whose name has an entry in
// localInits is declared with "name := value". Any other variable gets
// "var name type", using the type at the same index in types.
func createVarDecls(names []string, types []string, localInits map[string]string) string {
	var buf bytes.Buffer
	for i, name := range names {
		if value, ok := localInits[name]; ok {
			buf.WriteString(name + " := " + value)
		} else {
			buf.WriteString("var " + name + " " + types[i])
		}
		// Every declaration ends with a newline. (The original condition,
		// i > 1 || i <= len(names)-1, was always true.)
		buf.WriteString("\n")
	}
	return buf.String()
}
// commaSeparated joins the given strings with ", " between consecutive
// elements and no trailing separator. It returns "" for an empty list.
func commaSeparated(strings []string) string {
	var buf bytes.Buffer
	for k, s := range strings {
		if k > 0 {
			buf.WriteString(", ")
		}
		buf.WriteString(s)
	}
	return buf.String()
}
// createParamDecls returns source code for a parameter list (without the
// surrounding parentheses) that declares each name with the type at the same
// index, e.g. "a int, b string". It returns "" when there are no names.
//
// The separator is written only between parameters. The previous condition
// (k > 1 || ...) added a trailing ", " whenever there were three or more
// parameters.
func createParamDecls(names []string, types []string) string {
	var buf bytes.Buffer
	for k := range names {
		if k > 0 {
			buf.WriteString(", ")
		}
		buf.WriteString(names[k] + " " + types[k])
	}
	return buf.String()
}
/* -=-=- ExtractFunc -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- */
// ExtractFunc is the Extract Function refactoring. It breaks down a larger
// function by moving a selected sequence of statements into a new function
// or method, replacing them with a call, so that behavior is unchanged.
// The user selects the statements to extract and enters a valid Go
// identifier as the name of the new function.
type ExtractFunc struct {
RefactoringBase
funcName string // name of the extracted function
stmtRange *stmtRange // selected statements (to be extracted)
}
// Description returns the metadata for the Extract Function refactoring:
// its name, synopsis, usage string, HTML documentation, and its single
// required parameter (the name of the new function).
func (r *ExtractFunc) Description() *Description {
	nameParam := Parameter{
		Label:        "Name:",
		Prompt:       "Enter a name for the new function.",
		DefaultValue: "",
	}
	return &Description{
		Name:           "Extract Function",
		Synopsis:       "Extracts statements to a new function/method",
		Usage:          "<new_name>",
		HTMLDoc:        extractFuncDoc,
		Multifile:      false,
		Params:         []Parameter{nameParam},
		OptionalParams: nil,
		Hidden:         false,
	}
}
// Run performs the Extract Function refactoring. config.Args[0] (a string)
// is the name of the new function.
//
// Run returns without edits in these cases:
//   - initialization logs errors;
//   - the name is not a valid Go identifier;
//   - newStmtRange rejects the selection;
//   - the selection lies inside an anonymous function.
//
// Otherwise Run logs errors for:
//   - anonymous functions, defer statements, or return statements in the
//     selection;
//   - multiple control-flow paths into or out of the selection.
//
// It then downgrades every error logged so far to a warning, adds the edits,
// formats the file, and re-checks the refactored code via UpdateLog.
func (r *ExtractFunc) Run(config *Config) *Result {
	r.Init(config, r.Description())
	if r.Log.ContainsErrors() {
		return &r.Result
	}

	// reportOnSelection logs an error and attaches it to the selected text.
	reportOnSelection := func(msg interface{}) {
		r.Log.Error(msg)
		r.Log.AssociatePos(r.SelectionStart, r.SelectionEnd)
	}

	r.funcName = config.Args[0].(string)
	if !isIdentifierValid(r.funcName) {
		r.Log.Errorf("The name \"%s\" is not a valid Go identifier",
			r.funcName)
		return &r.Result
	}

	var err error
	r.stmtRange, err = newStmtRange(r.File, r.SelectionStart, r.SelectionEnd, r.SelectedNodePkg)
	if err != nil {
		reportOnSelection(err)
		return &r.Result
	}
	if r.stmtRange.IsInAnonymousFunc() {
		reportOnSelection("Code inside an anonymous function cannot be extracted.")
		return &r.Result
	}

	// Problems logged below are non-fatal. Extraction still proceeds, but it
	// may not preserve semantics.
	if r.stmtRange.ContainsAnonymousFunc() {
		reportOnSelection("Code containing anonymous functions may not extract correctly.")
	}
	if r.stmtRange.ContainsDefer() {
		reportOnSelection("Code containing defer statements may change behavior if it is extracted.")
	}
	if r.stmtRange.ContainsReturn() {
		reportOnSelection("Code containing return statements may change behavior if it is extracted.")
	}

	// Single-entry/single-exit check. UpdateLog (below) re-checks the
	// refactored code and usually pinpoints concrete problems, so these
	// messages stay deliberately general.
	if len(r.stmtRange.EntryPoints()) > 1 {
		reportOnSelection("There are multiple control flow paths into the selected statements. Extraction will likely be incorrect.")
	}
	if len(r.stmtRange.ExitDestinations()) > 1 {
		reportOnSelection("There are multiple control flow paths out of the selected statements. Extraction will likely be incorrect.")
	}

	r.Log.ChangeInitialErrorsToWarnings()
	r.addEdits()
	r.FormatFileInEditor()
	r.UpdateLog(config, true) // check the refactored code for errors
	return &r.Result
}
// addEdits records two edits in r.Edits for the current file:
//  1. the selected statements are replaced with a call to the new function;
//  2. the new function's declaration is inserted right after the end of the
//     enclosing function declaration.
func (r *ExtractFunc) addEdits() {
	funcDecl, funcCall := r.createExtractedFunc().SourceCode()
	fset := r.Program.Fset

	// Replace the selected statements with the call.
	from := fset.Position(r.stmtRange.Pos()).Offset
	to := fset.Position(r.stmtRange.End()).Offset
	r.Edits[r.Filename].Add(&text.Extent{from, to - from}, funcCall)

	// Insert the new declaration after the enclosing function.
	afterEnclosing := fset.Position(r.stmtRange.enclosingFunc.End()).Offset
	r.Edits[r.Filename].Add(&text.Extent{afterEnclosing, 0}, funcDecl)
}
// createExtractedFunc combines the variable analysis (analyzeVars) with the
// source text of the selected statements. The result is an extractedFunc
// whose SourceCode method yields the new declaration and the call that
// replaces the selection.
func (r *ExtractFunc) createExtractedFunc() *extractedFunc {
	recv, params, returns, locals, localInits, declareResult := r.analyzeVars()

	fset := r.Program.Fset
	from := fset.Position(r.stmtRange.Pos()).Offset
	to := fset.Position(r.stmtRange.End()).Offset

	f := &extractedFunc{
		name:       r.funcName,
		recv:       recv,
		params:     params,
		returns:    returns,
		locals:     locals,
		localInits: localInits,
		define:     declareResult,
		code:       r.FileContents[from:to],
		pkgFmt:     pkgUseFmt(r.SelectedNodePkg.Pkg),
	}
	return f
}
// pkgUseFmt returns a types.Qualifier similar to types.RelativeTo. Other
// packages are qualified by their package name rather than their import path
// ('github.com/some/pkg' -> 'pkg'). Objects of pkg itself are left
// unqualified.
//
// If pkg is nil, it returns nil. types.TypeString treats a nil Qualifier as
// qualifying by full import path.
func pkgUseFmt(pkg *types.Package) types.Qualifier {
	if pkg == nil {
		return nil
	}
	qualify := func(other *types.Package) string {
		if other == pkg {
			return "" // same package: no qualifier
		}
		return other.Name()
	}
	return qualify
}
// analyzeVars determines:
//  1. whether the extracted function should be a method and, if so, what
//     its receiver should be;
//  2. which local variables used in the selected statements should be passed
//     as arguments to the extracted function;
//  3. which local variables' values must be returned from the extracted
//     function;
//  4. which local variables can be redeclared in the extracted function
//     (i.e., they do not need to be passed as arguments), with localInits
//     holding the constant initializers of parameters demoted to locals;
//  5. when the selected statements are replaced with a function call,
//     whether the call should have the form x := f() or x = f(), i.e.,
//     whether the result variables should be declared or simply assigned.
//
// params, returns, and locals are sorted with SortVars.
func (r *ExtractFunc) analyzeVars() (recv *types.Var,
	params, returns, locals []*types.Var,
	localInits map[*types.Var]ast.Expr,
	declareResult bool) {
	aliveFirst := r.stmtRange.LocalsLiveAtEntry()
	aliveLast := r.stmtRange.LocalsLiveAfterExit()
	assigned, updated, declared, used := r.stmtRange.LocalsReferenced()
	defined := union(union(assigned, updated), declared)

	// params = LIVE_IN[entry(sel)] ∩ (USE[sel] ∪ ASSIGN[sel] ∪ UPDATE[sel])
	params = intersection(aliveFirst, union(union(used, assigned), updated))

	// returns = (LIVE_OUT[exit(sel)] ∩ DEF[sel]) − vars only updated through
	// a pointer or slice.
	// If someStruct is a pointer and someStruct.field is assigned, but
	// someStruct itself is never reassigned, then it does not need to be
	// returned. Likewise, if individual elements of a slice are updated
	// but the slice itself is not reassigned, then the slice variable
	// does not need to be returned.
	updatedOnlyThruPointers := difference(r.varsWithPointerOrSliceTypes(updated), assigned)
	returns = difference(
		intersection(aliveLast, defined),
		updatedOnlyThruPointers)

	// locals = ((ASSIGN[sel] − params) ∪ (USE[sel] − LIVE_IN[entry(sel)])) − DECL[sel]
	locals = difference(
		union(difference(assigned, params),
			difference(used, aliveFirst)),
		declared)

	// If we are returning the value of a variable declared in the
	// selected statements, then the result variable needs to be declared.
	declareResult = len(intersection(returns, declared)) > 0

	// Extract a method only when the enclosing method's receiver is named.
	// An unnamed receiver (func (T) m()) has no Names entry, so indexing it
	// would panic. A blank receiver (func (_ T) m()) would yield an invalid
	// call "_.f()". The selected code cannot refer to either kind of
	// receiver, so a plain function is extracted in both cases.
	if recvList := r.stmtRange.enclosingFunc.Recv; recvList != nil &&
		len(recvList.List) > 0 && len(recvList.List[0].Names) > 0 &&
		recvList.List[0].Names[0].Name != "_" {
		if v, ok := r.SelectedNodePkg.ObjectOf(recvList.List[0].Names[0]).(*types.Var); ok {
			recv = v
			params = difference(params, []*types.Var{recv})
			returns = difference(returns, []*types.Var{recv})
			locals = difference(locals, []*types.Var{recv})
		}
	}

	// If an argument always has a constant value, there is no reason to
	// pass it as an argument. Instead, make it a local variable, and
	// set it equal to its constant value.
	localInits = r.constantValues(params)
	for param := range localInits {
		params = difference(params, []*types.Var{param})
		locals = append(locals, param)
	}

	// Sort each set of variables so we always extract in the same order.
	SortVars(params)
	SortVars(returns)
	SortVars(locals)
	return recv, params, returns, locals, localInits, declareResult
}
// constantValues takes a list of variables and determines which are
// constant-valued. It returns a map from (a subset of) those variables to
// the expressions defining their constant values. If the value expression is
// nil, the variable is defined by a "var" declaration with no initialization
// expression.
//
// The analysis is not too sophisticated. We only say the variable is
// constant-valued if all of the following hold:
//  1. exactly one definition reaches the entry to the selected region;
//  2. that definition has one of the following forms (see constantAssigned):
//     var name type
//     var name type = value
//     name := value
//     name = value
//  3. when there is a value, the type that "name := value" would give the
//     variable is identical to the variable's actual type.
//
// Condition 3 exists because the extracted function re-declares the
// variable as "name := value" (see createVarDecls). Without it, e.g.
// "var x float64 = 5" would become "x := 5", silently making x an int.
func (r *ExtractFunc) constantValues(varList []*types.Var) map[*types.Var]ast.Expr {
	result := make(map[*types.Var]ast.Expr)
	for variable, defs := range r.defsInitializing(varList) {
		def, ok := extractSingleton(defs)
		if !ok {
			continue
		}
		expr, isConstant := r.constantAssigned(def)
		if !isConstant {
			continue
		}
		if expr != nil {
			declType := shortDeclTypeOfConstant(expr)
			if declType == nil || !types.Identical(declType, variable.Type()) {
				continue
			}
		}
		result[variable] = expr
	}
	return result
}

// shortDeclTypeOfConstant returns the type that a short variable declaration
// (name := expr) gives the constant expression expr. Basic literals map to
// int, float64, complex128, rune, or string; the identifiers true and false
// map to bool. For any other expression it returns nil.
func shortDeclTypeOfConstant(expr ast.Expr) types.Type {
	switch x := expr.(type) {
	case *ast.BasicLit:
		switch x.Kind {
		case token.INT:
			return types.Typ[types.Int]
		case token.FLOAT:
			return types.Typ[types.Float64]
		case token.IMAG:
			return types.Typ[types.Complex128]
		case token.CHAR:
			return types.Typ[types.Rune]
		case token.STRING:
			return types.Typ[types.String]
		}
	case *ast.Ident:
		if x.Name == "true" || x.Name == "false" {
			return types.Typ[types.Bool]
		}
	}
	return nil
}
// defsInitializing maps each variable in varList to the set of statements
// that assign, update, or declare it and whose definitions reach the entry to
// the selected region (r.stmtRange.defsReachingSelection). Every variable in
// varList gets an entry, possibly an empty set.
//
// This is used to determine which variables are constant-valued.
func (r *ExtractFunc) defsInitializing(varList []*types.Var) map[*types.Var]map[ast.Stmt]struct{} {
	// When an entrypoint of the selection is a for/range loop, a definition
	// inside the loop can reach the beginning of the selection (around the
	// loop's back edge). Such definitions do not determine the variables'
	// initial values, so every statement in those loops is ignored.
	excluded := map[ast.Stmt]bool{}
	for _, entry := range r.stmtRange.EntryPoints() {
		if !isForOrRangeStmt(entry) {
			continue
		}
		ast.Inspect(entry, func(n ast.Node) bool {
			if s, isStmt := n.(ast.Stmt); isStmt {
				excluded[s] = true
			}
			return true
		})
	}

	// Give every requested variable an (initially empty) set.
	result := make(map[*types.Var]map[ast.Stmt]struct{}, len(varList))
	for _, v := range varList {
		result[v] = map[ast.Stmt]struct{}{}
	}

	for def := range r.stmtRange.defsReachingSelection {
		if excluded[def] {
			continue
		}
		// record notes def as a definition of v, if v was requested.
		record := func(v *types.Var) {
			if stmts, tracked := result[v]; tracked {
				stmts[def] = struct{}{}
			}
		}
		asgtSet, updtSet, declSet, _ := dataflow.ReferencedVars(
			[]ast.Stmt{def}, r.stmtRange.pkgInfo)
		for v := range asgtSet {
			record(v)
		}
		for v := range updtSet {
			record(v)
		}
		for v := range declSet {
			record(v)
		}
	}
	return result
}
// isForOrRangeStmt reports whether stmt is a for loop (*ast.ForStmt) or a
// for-range loop (*ast.RangeStmt).
func isForOrRangeStmt(stmt ast.Stmt) bool {
	if _, ok := stmt.(*ast.ForStmt); ok {
		return true
	}
	_, ok := stmt.(*ast.RangeStmt)
	return ok
}
// extractSingleton returns the sole element of set and true when set has
// exactly one element; otherwise it returns nil and false.
func extractSingleton(set map[ast.Stmt]struct{}) (ast.Stmt, bool) {
	if len(set) == 1 {
		for only := range set {
			return only, true
		}
	}
	return nil, false
}
// constantAssigned determines whether stmt gives a single identifier a
// constant value (see isConstant). It returns the constant expression and
// true if so. For "var name type" with no initializer it returns nil and
// true. Otherwise it returns nil and false.
func (r *ExtractFunc) constantAssigned(stmt ast.Stmt) (ast.Expr, bool) {
	switch s := stmt.(type) {
	case *ast.AssignStmt:
		// Only plain assignment (=) and short variable declaration (:=)
		// store exactly the right-hand value. Compound assignments such as
		// x += 1 combine it with the variable's previous value.
		if (s.Tok == token.ASSIGN || s.Tok == token.DEFINE) &&
			len(s.Lhs) == 1 && len(s.Rhs) == 1 &&
			isIdentifier(s.Lhs[0]) && isConstant(s.Rhs[0]) {
			// identifier = value  or  identifier := value
			return s.Rhs[0], true
		}
		return nil, false
	case *ast.DeclStmt:
		if decl, ok := s.Decl.(*ast.GenDecl); ok && len(decl.Specs) == 1 {
			if valueSpec, ok := decl.Specs[0].(*ast.ValueSpec); ok {
				if len(valueSpec.Names) == 1 &&
					isIdentifier(valueSpec.Names[0]) {
					if len(valueSpec.Values) == 0 {
						// var name type
						return nil, true
					}
					if len(valueSpec.Values) == 1 &&
						isConstant(valueSpec.Values[0]) {
						// var name [type] = value
						return valueSpec.Values[0], true
					}
				}
			}
		}
		return nil, false
	default:
		return nil, false
	}
}
// isIdentifier reports whether expr is a bare identifier (*ast.Ident).
func isIdentifier(expr ast.Expr) bool {
	switch expr.(type) {
	case *ast.Ident:
		return true
	default:
		return false
	}
}
// constantIds lists the identifiers that isConstant treats as constants.
// See analyzeVars and constantValues.
//
// "nil" is deliberately omitted to avoid introducing "use of untyped nil"
// errors. Supporting it would require always introducing a var declaration
// for nil rather than using the short assignment operator.
var constantIds = map[string]bool{
	"true":  true,
	"false": true,
}
// isConstant reports whether expr counts as a constant for the purposes of
// analyzeVars and constantValues. These are BasicLits (integer, float,
// imaginary, character, and string literals) and the identifiers listed in
// constantIds (true and false).
func isConstant(expr ast.Expr) bool {
	if _, isLit := expr.(*ast.BasicLit); isLit {
		return true
	}
	if id, isIdent := expr.(*ast.Ident); isIdent {
		return constantIds[id.String()]
	}
	return false
}
// varsWithPointerOrSliceTypes returns the variables in varList whose types
// are pointer or slice types, in their original order. This includes named
// types whose underlying type is a pointer or slice (e.g., type IntList
// []int). analyzeVars uses this to avoid returning variables that are only
// updated through the pointer/slice (p.f = x, s[i] = x) and never reassigned.
func (r *ExtractFunc) varsWithPointerOrSliceTypes(varList []*types.Var) []*types.Var {
	result := []*types.Var{}
	for _, a := range varList {
		// Look through named types so that, e.g., a defined slice type is
		// treated like []T.
		switch a.Type().Underlying().(type) {
		case *types.Pointer, *types.Slice:
			result = append(result, a)
		}
	}
	return result
}
func intersection(s1, s2 []*types.Var) []*types.Var {
result := []*types.Var{}
for i := 0; i < len(s2); i++ {
for j := 0; j < len(s1); j++ {
if s2[i] == s1[j] {
result = append(result, s2[i])
}
}
}
return result