Skip to content

Commit

Permalink
Added support for new assert_type call, which is being added to Pyt…
Browse files Browse the repository at this point in the history
…hon 3.11 and typing_extensions.
  • Loading branch information
msfterictraut committed Mar 21, 2022
1 parent d78f737 commit 4edb1b4
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 0 deletions.
35 changes: 35 additions & 0 deletions packages/pyright-internal/src/analyzer/typeEvaluator.ts
Expand Up @@ -6738,6 +6738,9 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
} else if (isFunction(baseTypeResult.type) && baseTypeResult.type.details.builtInName === 'reveal_type') {
// Handle the "typing.reveal_type" call.
returnResult = getTypeFromRevealType(node, expectedType);
} else if (isFunction(baseTypeResult.type) && baseTypeResult.type.details.builtInName === 'assert_type') {
// Handle the "typing.assert_type" call.
returnResult = getTypeFromAssertType(node, expectedType);
} else if (
isAnyOrUnknown(baseTypeResult.type) &&
node.leftExpression.nodeType === ParseNodeType.Name &&
Expand Down Expand Up @@ -6806,6 +6809,38 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
return returnResult;
}

function getTypeFromAssertType(node: CallNode, expectedType: Type | undefined): TypeResult {
if (
node.arguments.length !== 2 ||
node.arguments[0].argumentCategory !== ArgumentCategory.Simple ||
node.arguments[0].name !== undefined ||
node.arguments[0].argumentCategory !== ArgumentCategory.Simple ||
node.arguments[1].name !== undefined
) {
addError(Localizer.Diagnostic.assertTypeArgs(), node);
return { node, type: UnknownType.create() };
}

const arg0TypeResult = getTypeOfExpression(node.arguments[0].valueExpression, expectedType);
if (arg0TypeResult.isIncomplete) {
return { node, type: UnknownType.create(), isIncomplete: true };
}

const assertedType = convertToInstance(getTypeForArgumentExpectingType(node.arguments[1]).type);

if (!isTypeSame(assertedType, arg0TypeResult.type)) {
addError(
Localizer.Diagnostic.assertTypeTypeMismatch().format({
expected: printType(assertedType),
received: printType(arg0TypeResult.type),
}),
node.arguments[0].valueExpression
);
}

return { node, type: arg0TypeResult.type };
}

function getTypeFromRevealType(node: CallNode, expectedType: Type | undefined): TypeResult {
let arg0Value: ExpressionNode | undefined;
let expectedRevealTypeNode: ExpressionNode | undefined;
Expand Down
5 changes: 5 additions & 0 deletions packages/pyright-internal/src/localization/localize.ts
Expand Up @@ -208,6 +208,11 @@ export namespace Localizer {
export const argTypePartiallyUnknown = () => getRawString('Diagnostic.argTypePartiallyUnknown');
export const argTypeUnknown = () => getRawString('Diagnostic.argTypeUnknown');
export const assertAlwaysTrue = () => getRawString('Diagnostic.assertAlwaysTrue');
export const assertTypeArgs = () => getRawString('Diagnostic.assertTypeArgs');
export const assertTypeTypeMismatch = () =>
new ParameterizedString<{ expected: string; received: string }>(
getRawString('Diagnostic.assertTypeTypeMismatch')
);
export const assignmentExprContext = () => getRawString('Diagnostic.assignmentExprContext');
export const assignmentExprComprehension = () =>
new ParameterizedString<{ name: string }>(getRawString('Diagnostic.assignmentExprComprehension'));
Expand Down
Expand Up @@ -20,6 +20,8 @@
"argTypePartiallyUnknown": "Argument type is partially unknown",
"argTypeUnknown": "Argument type is unknown",
"assertAlwaysTrue": "Assert expression always evaluates to true",
"assertTypeArgs": "\"assert_type\" expects two positional arguments",
"assertTypeTypeMismatch": "\"assert_type\" mismatch: expected \"{expected}\" but received \"{received}\"",
"assignmentExprContext": "Assignment expression must be within module, function or lambda",
"assignmentExprComprehension": "Assignment expression target \"{name}\" cannot use same name as comprehension for target",
"assignmentInProtocol": "Instance or class variables within a Protocol class must be explicitly declared within the class body",
Expand Down
46 changes: 46 additions & 0 deletions packages/pyright-internal/src/tests/samples/assertType1.py
@@ -0,0 +1,46 @@
# This sample tests the assert_type call.

from typing import Any, Literal
from typing_extensions import assert_type

def func1():
# This should generate an error.
assert_type()

# This should generate an error.
assert_type(1)

# This should generate an error.
assert_type(1, 2, 3)

# This should generate an error.
assert_type(*[])


def func2(x: int, y: int | str):
assert_type(x, int)

# This should generate an error.
assert_type(x, str)

# This should generate an error.
assert_type(x, Any)

x = 3
assert_type(x, Literal[3])

# This should generate an error.
assert_type(x, int)

assert_type(y, int | str)
assert_type(y, str | int)

# This should generate an error.
assert_type(y, str)

# This should generate an error.
assert_type(y, None)

# This should generate two errors.
assert_type(y, 3)

6 changes: 6 additions & 0 deletions packages/pyright-internal/src/tests/typeEvaluator2.test.ts
Expand Up @@ -332,6 +332,12 @@ test('RevealedType1', () => {
TestUtils.validateResults(analysisResults, 2, 0, 7);
});

test('AssertType1', () => {
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['assertType1.py']);

TestUtils.validateResults(analysisResults, 11);
});

test('NameBindings1', () => {
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['nameBindings1.py']);

Expand Down
Expand Up @@ -166,5 +166,7 @@ def dataclass_transform(

# Types not yet implemented in typing_extensions library

def assert_type(val: _T, typ: Any, /) -> _T: ...

# Proposed extension to PEP 647
StrictTypeGuard: _SpecialForm = ...

0 comments on commit 4edb1b4

Please sign in to comment.