Skip to content

Commit

Permalink
__all__ evaluation (microsoft#863)
Browse files Browse the repository at this point in the history
Fixes microsoft#620.
Fixes microsoft#619.

This adds support for concatenating lists using `+`, doing `+=` on `__all__`, and calling `append` and `extend` on `__all__`. If something goes wrong (an unsupported operation on `__all__` or some unsupported value), then the old behavior continues to be used. I don't track uses of `__all__` indirectly (i.e. passing `__all__` to something that modifies it), only direct actions. If `__all__` is in a more complicated lvar (like `__all__, foo = ...`), then it is ignored. This can be improved later on when we fix up our multiple assignment issues.

This works well for Django models (see microsoft#620), but `numpy`'s import cycles prevent this from having an effect, so the old behavior will be used.

~Tests are WIP.~ I'll need to rebase/merge master when the refs per gets merged.
  • Loading branch information
jakebailey committed Apr 2, 2019
1 parent 838c778 commit 237a18c
Show file tree
Hide file tree
Showing 4 changed files with 290 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,11 @@
// See the Apache Version 2.0 License for specific language governing
// permissions and limitations under the License.

using System;
using System.Collections.Generic;
using Microsoft.Python.Analysis.Modules;
using Microsoft.Python.Analysis.Types;
using Microsoft.Python.Analysis.Types.Collections;
using Microsoft.Python.Analysis.Values;
using Microsoft.Python.Parsing;
using Microsoft.Python.Parsing.Ast;
Expand Down Expand Up @@ -123,6 +126,16 @@ internal sealed partial class ExpressionEval {
return left;
}

if (binop.Operator == PythonOperator.Add
&& left.GetPythonType()?.TypeId == BuiltinTypeId.List
&& right.GetPythonType()?.TypeId == BuiltinTypeId.List) {

var leftVar = GetValueFromExpression(binop.Left) as IPythonCollection;
var rightVar = GetValueFromExpression(binop.Right) as IPythonCollection;

return PythonCollectionType.CreateConcatenatedList(Module.Interpreter, GetLoc(expr), leftVar?.Contents, rightVar?.Contents);
}

return left.IsUnknown() ? right : left;
}
}
Expand Down
103 changes: 102 additions & 1 deletion src/Analysis/Ast/Impl/Analyzer/ModuleWalker.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,15 @@
// See the Apache Version 2.0 License for specific language governing
// permissions and limitations under the License.

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using Microsoft.Python.Analysis.Analyzer.Evaluation;
using Microsoft.Python.Analysis.Documents;
using Microsoft.Python.Analysis.Modules;
using Microsoft.Python.Analysis.Types;
using Microsoft.Python.Analysis.Types.Collections;
using Microsoft.Python.Analysis.Values;
using Microsoft.Python.Core;
using Microsoft.Python.Core.Collections;
Expand All @@ -33,6 +36,7 @@ internal class ModuleWalker : AnalysisWalker {

// A hack to use __all__ export in the most simple case.
private int _allReferencesCount;
private bool _allIsUsable = true;

public ModuleWalker(IServiceContainer services, IPythonModule module, PythonAst ast)
: base(new ExpressionEval(services, module, ast)) {
Expand All @@ -47,6 +51,103 @@ public ModuleWalker(IServiceContainer services, IPythonModule module, PythonAst
return base.Walk(node);
}

public override bool Walk(AugmentedAssignStatement node) {
HandleAugmentedAllAssign(node);
return base.Walk(node);
}

public override bool Walk(CallExpression node) {
HandleAllAppendExtend(node);
return base.Walk(node);
}

private void HandleAugmentedAllAssign(AugmentedAssignStatement node) {
if (!IsHandleableAll(node.Left)) {
return;
}

if (node.Right is ErrorExpression) {
return;
}

if (node.Operator != Parsing.PythonOperator.Add) {
_allIsUsable = false;
return;
}

var rightVar = Eval.GetValueFromExpression(node.Right);
var rightContents = (rightVar as IPythonCollection)?.Contents;

if (rightContents == null) {
_allIsUsable = false;
return;
}

ExtendAll(node.Left, rightContents);
}

private void HandleAllAppendExtend(CallExpression node) {
if (!(node.Target is MemberExpression me)) {
return;
}

if (!IsHandleableAll(me.Target)) {
return;
}

if (node.Args.Count == 0) {
return;
}

IReadOnlyList<IMember> contents = null;
var v = Eval.GetValueFromExpression(node.Args[0].Expression);
if (v == null) {
_allIsUsable = false;
return;
}

switch (me.Name) {
case "append":
contents = new List<IMember>() { v };
break;
case "extend":
contents = (v as IPythonCollection)?.Contents;
break;
}

if (contents == null) {
_allIsUsable = false;
return;
}

ExtendAll(node, contents);
}

private void ExtendAll(Node declNode, IReadOnlyList<IMember> values) {
Eval.LookupNameInScopes(AllVariableName, out var scope, LookupOptions.Normal);
if (scope == null) {
return;
}

var loc = Eval.GetLoc(declNode);

var allContents = (scope.Variables[AllVariableName].Value as IPythonCollection)?.Contents;

var list = PythonCollectionType.CreateConcatenatedList(Module.Interpreter, loc, allContents, values);
var source = list.IsGeneric() ? VariableSource.Generic : VariableSource.Declaration;

Eval.DeclareVariable(AllVariableName, list, source, loc);
}

private bool IsHandleableAll(Node node) {
// TODO: handle more complicated lvars
if (!(node is NameExpression ne)) {
return false;
}

return Eval.CurrentScope == Eval.GlobalScope && ne.Name == AllVariableName;
}

public override bool Walk(PythonAst node) {
Check.InvalidOperation(() => Ast == node, "walking wrong AST");

Expand Down Expand Up @@ -98,7 +199,7 @@ public ModuleWalker(IServiceContainer services, IPythonModule module, PythonAst
SymbolTable.ReplacedByStubs.Clear();
MergeStub();

if (_allReferencesCount == 1 && GlobalScope.Variables.TryGetVariable(AllVariableName, out var variable) && variable?.Value is IPythonCollection collection) {
if (_allIsUsable && _allReferencesCount >= 1 && GlobalScope.Variables.TryGetVariable(AllVariableName, out var variable) && variable?.Value is IPythonCollection collection) {
ExportedMemberNames = collection.Contents
.OfType<IPythonConstant>()
.Select(c => c.GetString())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
using System.Linq;
using Microsoft.Python.Analysis.Values;
using Microsoft.Python.Analysis.Values.Collections;
using Microsoft.Python.Core;

namespace Microsoft.Python.Analysis.Types.Collections {
/// <summary>
Expand Down Expand Up @@ -99,6 +100,11 @@ public override IMember Index(IPythonInstance instance, object index)
return new PythonCollection(collectionType, location, contents, flatten);
}

public static IPythonCollection CreateConcatenatedList(IPythonInterpreter interpreter, LocationInfo location, params IReadOnlyList<IMember>[] manyContents) {
var contents = manyContents?.ExcludeDefault().SelectMany().ToList() ?? new List<IMember>();
return CreateList(interpreter, location, contents);
}

public static IPythonCollection CreateTuple(IPythonInterpreter interpreter, LocationInfo location, IReadOnlyList<IMember> contents) {
var collectionType = new PythonCollectionType(null, BuiltinTypeId.Tuple, interpreter, false);
return new PythonCollection(collectionType, location, contents);
Expand Down
169 changes: 169 additions & 0 deletions src/LanguageServer/Test/ImportsTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -459,5 +459,174 @@ class A3(B1):
comps = cs.GetCompletions(analysis, new SourceLocation(2, 21));
comps.Should().HaveLabels("X");
}

[TestMethod, Priority(0)]
public async Task AllSimple() {
var module1Code = @"
class A:
def foo(self):
pass
pass
class B:
def bar(self):
pass
pass
__all__ = ['A']
";

var appCode = @"
from module1 import *
A().
B().
";

var module1Uri = TestData.GetTestSpecificUri("module1.py");
var appUri = TestData.GetTestSpecificUri("app.py");

var root = Path.GetDirectoryName(appUri.AbsolutePath);
await CreateServicesAsync(root, PythonVersions.LatestAvailable3X);
var rdt = Services.GetService<IRunningDocumentTable>();
var analyzer = Services.GetService<IPythonAnalyzer>();

rdt.OpenDocument(module1Uri, module1Code);

var app = rdt.OpenDocument(appUri, appCode);
await analyzer.WaitForCompleteAnalysisAsync();
var analysis = await app.GetAnalysisAsync(-1);

var cs = new CompletionSource(new PlainTextDocumentationSource(), ServerSettings.completion);
var comps = cs.GetCompletions(analysis, new SourceLocation(4, 5));
comps.Should().HaveLabels("foo");

comps = cs.GetCompletions(analysis, new SourceLocation(5, 5));
comps.Should().NotContainLabels("bar");
}

[DataRow(@"
other = ['B']
__all__ = ['A'] + other")]
[DataRow(@"
other = ['B']
__all__ = ['A']
__all__ += other")]
[DataRow(@"
other = ['B']
__all__ = ['A']
__all__.extend(other)")]
[DataRow(@"
__all__ = ['A']
__all__.append('B')")]
[DataTestMethod, Priority(0)]
public async Task AllComplex(string allCode) {
var module1Code = @"
class A:
def foo(self):
pass
pass
class B:
def bar(self):
pass
pass
class C:
def baz(self):
pass
pass
" + allCode;

var appCode = @"
from module1 import *
A().
B().
C().
";

var module1Uri = TestData.GetTestSpecificUri("module1.py");
var appUri = TestData.GetTestSpecificUri("app.py");

var root = Path.GetDirectoryName(appUri.AbsolutePath);
await CreateServicesAsync(root, PythonVersions.LatestAvailable3X);
var rdt = Services.GetService<IRunningDocumentTable>();
var analyzer = Services.GetService<IPythonAnalyzer>();

rdt.OpenDocument(module1Uri, module1Code);

var app = rdt.OpenDocument(appUri, appCode);
await analyzer.WaitForCompleteAnalysisAsync();
var analysis = await app.GetAnalysisAsync(-1);

var cs = new CompletionSource(new PlainTextDocumentationSource(), ServerSettings.completion);
var comps = cs.GetCompletions(analysis, new SourceLocation(4, 5));
comps.Should().HaveLabels("foo");

comps = cs.GetCompletions(analysis, new SourceLocation(5, 5));
comps.Should().HaveLabels("bar");

comps = cs.GetCompletions(analysis, new SourceLocation(6, 5));
comps.Should().NotContainLabels("baz");
}

[DataRow(@"
__all__ = ['A']
__all__.something(A)")]
[DataRow(@"
__all__ = ['A']
__all__ *= ['B']")]
[DataRow(@"
__all__ = ['A']
__all__ += 1234")]
[DataRow(@"
__all__ = ['A']
__all__.extend(123)")]
[DataRow(@"
__all__ = ['A']
__all__.extend(nothing)")]
[DataTestMethod, Priority(0)]
public async Task AllUnsupported(string allCode) {
var module1Code = @"
class A:
def foo(self):
pass
pass
class B:
def bar(self):
pass
pass
" + allCode;

var appCode = @"
from module1 import *
A().
B().
";

var module1Uri = TestData.GetTestSpecificUri("module1.py");
var appUri = TestData.GetTestSpecificUri("app.py");

var root = Path.GetDirectoryName(appUri.AbsolutePath);
await CreateServicesAsync(root, PythonVersions.LatestAvailable3X);
var rdt = Services.GetService<IRunningDocumentTable>();
var analyzer = Services.GetService<IPythonAnalyzer>();

rdt.OpenDocument(module1Uri, module1Code);

var app = rdt.OpenDocument(appUri, appCode);
await analyzer.WaitForCompleteAnalysisAsync();
var analysis = await app.GetAnalysisAsync(-1);

var cs = new CompletionSource(new PlainTextDocumentationSource(), ServerSettings.completion);
var comps = cs.GetCompletions(analysis, new SourceLocation(4, 5));
comps.Should().HaveLabels("foo");

comps = cs.GetCompletions(analysis, new SourceLocation(5, 5));
comps.Should().HaveLabels("bar");
}
}
}

0 comments on commit 237a18c

Please sign in to comment.