diff --git a/src/Analysis/Ast/Impl/Extensions/ArgumentSetExtensions.cs b/src/Analysis/Ast/Impl/Extensions/ArgumentSetExtensions.cs index 12acaba44..2595aaf2d 100644 --- a/src/Analysis/Ast/Impl/Extensions/ArgumentSetExtensions.cs +++ b/src/Analysis/Ast/Impl/Extensions/ArgumentSetExtensions.cs @@ -32,8 +32,11 @@ public static IReadOnlyList> Arguments(this IArgument public static T Argument(this IArgumentSet args, int index) where T : class => args.Arguments[index].Value as T; + public static IArgument Argument(this IArgumentSet args, string name, bool excludeDefault = true) + => args.Arguments.FirstOrDefault(a => name.Equals(a.Name) && !(excludeDefault &&a.ValueIsDefault)); + public static T GetArgumentValue(this IArgumentSet args, string name, bool excludeDefault = true) where T : class { - var value = args.Arguments.FirstOrDefault(a => name.Equals(a.Name) && !(excludeDefault && a.ValueIsDefault))?.Value; + var value = Argument(args, name, excludeDefault)?.Value; if (value == null && args.DictionaryArgument?.Arguments != null && args.DictionaryArgument.Arguments.TryGetValue(name, out var m)) { return m as T; } diff --git a/src/Analysis/Ast/Impl/Specializations/Typing/Types/GenericTypeParameter.cs b/src/Analysis/Ast/Impl/Specializations/Typing/Types/GenericTypeParameter.cs index d59dd27ac..bbe016351 100644 --- a/src/Analysis/Ast/Impl/Specializations/Typing/Types/GenericTypeParameter.cs +++ b/src/Analysis/Ast/Impl/Specializations/Typing/Types/GenericTypeParameter.cs @@ -20,6 +20,7 @@ using Microsoft.Python.Analysis.Types; using Microsoft.Python.Analysis.Utilities; using Microsoft.Python.Analysis.Values; +using Microsoft.Python.Core.Diagnostics; using Microsoft.Python.Core.Text; using Microsoft.Python.Parsing; @@ -103,7 +104,13 @@ private static bool TypeVarArgumentsValid(IArgumentSet argSet) { /// private static IPythonType GetBoundType(IArgumentSet argSet) { var eval = argSet.Eval; - var rawBound = argSet.GetArgumentValue("bound"); + var boundArg = argSet.Argument("bound"); + // User did not pass in upper bound, bail + if(boundArg == default) { + return null; + } + + var rawBound = boundArg.Value as IMember; switch (rawBound) { case IPythonType t: return t; @@ -119,6 +126,7 @@ private static IPythonType GetBoundType(IArgumentSet argSet) { } public static IPythonType FromTypeVar(IArgumentSet argSet, IPythonModule declaringModule, IndexSpan indexSpan = default) { + Check.ArgumentNotNull(nameof(argSet.Eval), argSet.Eval); if (!TypeVarArgumentsValid(argSet)) { return declaringModule.Interpreter.UnknownType; } diff --git a/src/Analysis/Ast/Impl/Specializations/Typing/TypingModule.cs b/src/Analysis/Ast/Impl/Specializations/Typing/TypingModule.cs index 36657bee4..54258b312 100644 --- a/src/Analysis/Ast/Impl/Specializations/Typing/TypingModule.cs +++ b/src/Analysis/Ast/Impl/Specializations/Typing/TypingModule.cs @@ -55,7 +55,7 @@ private void SpecializeMembers() { o.SetParameters(new List { new ParameterInfo("name", Interpreter.GetBuiltinType(BuiltinTypeId.Str), ParameterKind.Normal, null), new ParameterInfo("constraints", Interpreter.GetBuiltinType(BuiltinTypeId.Str), ParameterKind.List, null), - new ParameterInfo("bound", Interpreter.GetBuiltinType(BuiltinTypeId.Str), ParameterKind.KeywordOnly, new PythonConstant(null, Interpreter.GetBuiltinType(BuiltinTypeId.NoneType))), + new ParameterInfo("bound", Interpreter.GetBuiltinType(BuiltinTypeId.Type), ParameterKind.KeywordOnly, new PythonConstant(null, Interpreter.GetBuiltinType(BuiltinTypeId.NoneType))), new ParameterInfo("covariant", Interpreter.GetBuiltinType(BuiltinTypeId.Bool), ParameterKind.KeywordOnly, new PythonConstant(false, Interpreter.GetBuiltinType(BuiltinTypeId.Bool))), new ParameterInfo("contravariant", Interpreter.GetBuiltinType(BuiltinTypeId.Bool), ParameterKind.KeywordOnly, new PythonConstant(false, Interpreter.GetBuiltinType(BuiltinTypeId.Bool))) }); diff --git a/src/Analysis/Ast/Impl/Types/PythonFunctionOverload.cs b/src/Analysis/Ast/Impl/Types/PythonFunctionOverload.cs index 785ef804b..d2cabacbc 100644 --- a/src/Analysis/Ast/Impl/Types/PythonFunctionOverload.cs +++ b/src/Analysis/Ast/Impl/Types/PythonFunctionOverload.cs @@ -203,8 +203,8 @@ private IMember CreateSpecificReturnFromTypeVar(IPythonClassType selfClassType, } // Try getting the type from the type parameter bound - if (returnType.Bound != null) { - return returnType.Bound.CreateInstance(args); + if (!returnType.Bound.IsUnknown()) { + return new PythonInstance(returnType.Bound); } // Try returning the constraint diff --git a/src/Analysis/Ast/Test/GenericsTests.cs b/src/Analysis/Ast/Test/GenericsTests.cs index 376c6d47e..385718329 100644 --- a/src/Analysis/Ast/Test/GenericsTests.cs +++ b/src/Analysis/Ast/Test/GenericsTests.cs @@ -1360,6 +1360,44 @@ def get(self) -> T: ... analysis.Should().HaveVariable("x").OfType("A"); } + [TestMethod, Priority(0)] + public async Task GenericDefaultBound() { + const string code = @" +from typing import TypeVar, Generic +from logging import Logger, getLogger + +T = TypeVar('T') + +class A: ... + +class Test(Generic[T]): + def get(self) -> T: ... + +x = Test().get() +"; + var analysis = await GetAnalysisAsync(code); + analysis.Should().HaveVariable("x").OfType("T"); + } + + [TestMethod, Priority(0)] + public async Task GenericUnknownBound() { + const string code = @" +from typing import TypeVar, Generic +from logging import Logger, getLogger + +T = TypeVar('T', bound='unknown_thing') + +class A: ... + +class Test(Generic[T]): + def get(self) -> T: ... + +x = Test().get() +"; + var analysis = await GetAnalysisAsync(code); + analysis.Should().HaveVariable("x").OfType("T"); + } + [TestMethod, Priority(0)] public async Task GenericPath() { const string code = @"