Skip to content

Commit

Permalink
Added support for an unpacked TypedDict as a type annotation for a `*…
Browse files Browse the repository at this point in the history
…kwargs` parameter.
  • Loading branch information
msfterictraut committed Feb 15, 2022
1 parent 9827e56 commit 5bee749
Show file tree
Hide file tree
Showing 8 changed files with 188 additions and 12 deletions.
31 changes: 26 additions & 5 deletions packages/pyright-internal/src/analyzer/typeEvaluator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -1245,6 +1245,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
associateTypeVarsWithScope: true,
allowTypeVarTuple: paramCategory === ParameterCategory.VarArgList,
disallowRecursiveTypeAlias: true,
allowUnpackedTypedDict: paramCategory === ParameterCategory.VarArgDictionary,
allowUnpackedTuple: paramCategory === ParameterCategory.VarArgList,
});
}
Expand Down Expand Up @@ -1299,6 +1300,10 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
evaluatorFlags |= EvaluatorFlags.DisallowRecursiveTypeAliasPlaceholder;
}

if (options?.allowUnpackedTypedDict) {
evaluatorFlags |= EvaluatorFlags.AllowUnpackedTypedDict;
}

if (options?.allowUnpackedTuple) {
evaluatorFlags |= EvaluatorFlags.AllowUnpackedTupleOrTypeVarTuple;
}
Expand Down Expand Up @@ -13115,6 +13120,20 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
return UnknownType.create();
}

if ((flags & EvaluatorFlags.AllowUnpackedTypedDict) !== 0) {
if (isInstantiableClass(typeArgType) && ClassType.isTypedDictClass(typeArgType)) {
return ClassType.cloneForUnpacked(typeArgType);
}

addDiagnostic(
fileInfo.diagnosticRuleSet.reportGeneralTypeIssues,
DiagnosticRule.reportGeneralTypeIssues,
Localizer.Diagnostic.unpackExpectedTypedDict(),
errorNode
);
return UnknownType.create();
}

addDiagnostic(
fileInfo.diagnosticRuleSet.reportGeneralTypeIssues,
DiagnosticRule.reportGeneralTypeIssues,
Expand Down Expand Up @@ -15254,6 +15273,11 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
return type;
}

// Is this an unpacked TypedDict? If so, return it unmodified.
if (isClassInstance(type) && ClassType.isTypedDictClass(type) && type.isUnpacked) {
return type;
}

// Wrap the type in a dict with str keys.
const dictType = getBuiltInType(node, 'dict');
const strType = getBuiltInObject(node, 'str');
Expand Down Expand Up @@ -22154,7 +22178,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
if (srcParamInfo.param.name && srcParamInfo.param.category === ParameterCategory.Simple) {
const destParamInfo = destParamMap.get(srcParamInfo.param.name);
const paramDiag = diag?.createAddendum();
const srcParamType = FunctionType.getEffectiveParameterType(srcType, srcParamInfo.index);
const srcParamType = srcParamInfo.type;

if (!destParamInfo) {
if (destParamDetails.kwargsIndex === undefined && !srcParamInfo.param.hasDefault) {
Expand Down Expand Up @@ -22184,10 +22208,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
}
}
} else {
const destParamType = FunctionType.getEffectiveParameterType(
destType,
destParamInfo.index
);
const destParamType = destParamInfo.type;
const specializedDestParamType = destTypeVarMap
? applySolvedTypeVars(destParamType, destTypeVarMap)
: destParamType;
Expand Down
4 changes: 4 additions & 0 deletions packages/pyright-internal/src/analyzer/typeEvaluatorTypes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,9 @@ export const enum EvaluatorFlags {
// the interpreter (within a source file, not a stub) still
// parses the expression and generates parse errors.
InterpreterParsesStringLiteral = 1 << 22,

// Allow Unpack annotation for TypedDict.
AllowUnpackedTypedDict = 1 << 23,
}

export interface TypeArgumentResult {
Expand Down Expand Up @@ -257,6 +260,7 @@ export interface AnnotationTypeOptions {
allowTypeVarTuple?: boolean;
allowParamSpec?: boolean;
disallowRecursiveTypeAlias?: boolean;
allowUnpackedTypedDict?: boolean;
allowUnpackedTuple?: boolean;
notParsedByInterpreter?: boolean;
}
Expand Down
6 changes: 6 additions & 0 deletions packages/pyright-internal/src/analyzer/typePrinter.ts
Original file line number Diff line number Diff line change
Expand Up @@ -694,11 +694,17 @@ export function printFunctionParts(
recursionTypes.length < maxTypeRecursionCount
? printType(paramType, printTypeFlags, returnTypeCallback, recursionTypes)
: '';

if (!param.isNameSynthesized) {
paramString += ': ';
} else if (param.category === ParameterCategory.VarArgList && !isUnpacked(paramType)) {
paramString += '*';
}

if (param.category === ParameterCategory.VarArgDictionary && isUnpacked(paramType)) {
paramString += '**';
}

paramString += paramTypeString;

if (isParamSpec(paramType)) {
Expand Down
37 changes: 30 additions & 7 deletions packages/pyright-internal/src/analyzer/typeUtils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -343,14 +343,37 @@ export function getParameterListDetails(type: FunctionType): ParameterListDetail
}
} else if (param.category === ParameterCategory.VarArgDictionary) {
sawKeywordOnlySeparator = true;
if (result.kwargsIndex === undefined) {
result.kwargsIndex = result.params.length;
}
if (result.firstKeywordOnlyIndex === undefined) {
result.firstKeywordOnlyIndex = result.params.length;
}

addVirtualParameter(param, index);
// Is this an unpacked TypedDict? If so, expand the entries.
if (isClassInstance(param.type) && isUnpackedClass(param.type) && param.type.details.typedDictEntries) {
if (result.firstKeywordOnlyIndex === undefined) {
result.firstKeywordOnlyIndex = result.params.length;
}

param.type.details.typedDictEntries.forEach((entry, name) => {
addVirtualParameter(
{
category: ParameterCategory.Simple,
name,
type: entry.valueType,
hasDeclaredType: true,
hasDefault: !entry.isRequired,
},
index,
entry.valueType
);
});
} else {
if (result.kwargsIndex === undefined) {
result.kwargsIndex = result.params.length;
}

if (result.firstKeywordOnlyIndex === undefined) {
result.firstKeywordOnlyIndex = result.params.length;
}

addVirtualParameter(param, index);
}
} else if (param.category === ParameterCategory.Simple) {
if (param.name && !sawKeywordOnlySeparator) {
result.positionParamCount++;
Expand Down
1 change: 1 addition & 0 deletions packages/pyright-internal/src/localization/localize.ts
Original file line number Diff line number Diff line change
Expand Up @@ -892,6 +892,7 @@ export namespace Localizer {
new ParameterizedString<{ name1: string; name2: string }>(
getRawString('Diagnostic.unpackedTypeVarTupleExpected')
);
export const unpackExpectedTypedDict = () => getRawString('Diagnostic.unpackExpectedTypedDict');
export const unpackExpectedTypeVarTuple = () => getRawString('Diagnostic.unpackExpectedTypeVarTuple');
export const unpackIllegalInComprehension = () => getRawString('Diagnostic.unpackIllegalInComprehension');
export const unpackInAnnotation = () => getRawString('Diagnostic.unpackInAnnotation');
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,7 @@
"unpackedSubscriptIllegal": "Unpack operator in subscript requires Python 3.11 or newer",
"unpackedTypedDictArgument": "Unable to match unpacked TypedDict argument to parameters",
"unpackedTypeVarTupleExpected": "Expected unpacked TypeVarTuple; use Unpack[{name1}] or *{name2}",
"unpackExpectedTypedDict": "Expected TypedDict type argument for Unpack",
"unpackExpectedTypeVarTuple": "Expected TypeVarTuple or Tuple as type argument for Unpack",
"unpackIllegalInComprehension": "Unpack operation not allowed in comprehension",
"unpackInAnnotation": "Unpack operator not allowed in type annotation",
Expand Down
114 changes: 114 additions & 0 deletions packages/pyright-internal/src/tests/samples/kwargsUnpack1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# This sample tests the handling of Unpack[TypedDict] when used with
# a **kwargs parameter in a function signature.

from typing import Protocol, TypedDict
from typing_extensions import NotRequired, Required, Unpack


class TD1(TypedDict):
v1: Required[int]
v2: NotRequired[str]


class TD2(TD1):
v3: Required[str]


def func1(**kwargs: Unpack[TD2]) -> None:
v1 = kwargs["v1"]
reveal_type(v1, expected_text="int")

# This should generate an error because v2 might not be present.
kwargs["v2"]

if "v2" in kwargs:
v2 = kwargs["v2"]
reveal_type(v2, expected_text="str")

v3 = kwargs["v3"]
reveal_type(v3, expected_text="str")


reveal_type(func1, expected_text="(**kwargs: **TD2) -> None")


def func2(v1: int, **kwargs: Unpack[TD1]) -> None:
pass


def func3():
# This should generate an error because it is
# missing required keyword arguments.
func1()

func1(v1=1, v2="", v3="5")

td2 = TD2(v1=2, v3="4")
func1(**td2)

# This should generate an error because v4 is not in TD2.
func1(v1=1, v2="", v3="5", v4=5)

# This should generate an error because args are passed by position.
func1(1, "", "5")

my_dict: dict[str, str] = {}
# This should generate an error because it's an untyped dict.
func1(**my_dict)

func1(**{"v1": 2, "v3": "4", "v4": 4})

# This should generate an error because v1 is already specified.
func1(v1=2, **td2)

# This should generate an error because v1 is already specified.
func2(1, **td2)

# This should generate an error because v1 is matched to a
# named parameter and is not available for kwargs.
func2(v1=1, **td2)


class TDProtocol1(Protocol):
def __call__(self, *, v1: int, v3: str) -> None:
...


class TDProtocol2(Protocol):
def __call__(self, *, v1: int, v3: str, v2: str = "") -> None:
...


class TDProtocol3(Protocol):
def __call__(self, *, v1: int, v2: int, v3: str) -> None:
...


class TDProtocol4(Protocol):
def __call__(self, *, v1: int) -> None:
...


class TDProtocol5(Protocol):
def __call__(self, v1: int, v3: str) -> None:
...


class TDProtocol6(Protocol):
def __call__(self, **kwargs: Unpack[TD2]) -> None:
...


v1: TDProtocol1 = func1
v2: TDProtocol2 = func1

# This should generate an error because v2 is the wrong type.
v3: TDProtocol3 = func1

# This should generate an error because v3 is missing.
v4: TDProtocol4 = func1

# This should generate an error because parameters are positional.
v5: TDProtocol5 = func1

v6: TDProtocol6 = func1
6 changes: 6 additions & 0 deletions packages/pyright-internal/src/tests/typeEvaluator1.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,12 @@ test('Function10', () => {
TestUtils.validateResults(analysisResults, 0);
});

test('KwargsUnpack1', () => {
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['kwargsUnpack1.py']);

TestUtils.validateResults(analysisResults, 11);
});

test('Unreachable1', () => {
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['unreachable1.py']);

Expand Down

0 comments on commit 5bee749

Please sign in to comment.