Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Expectation to RsTypeInferenceWalker.inferMacroAsExpr #10748

Merged
merged 1 commit into from Jul 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -1284,7 +1284,7 @@ class RsTypeInferenceWalker(
val origin = definition?.containingCrate?.origin
if (origin != null && origin != PackageOrigin.STDLIB) {
inferChildExprsRecursively(macroCall)
return inferMacroAsExpr(macroCall)
return inferMacroAsExpr(macroCall, expected)
}

val name = macroCall.macroName
Expand Down Expand Up @@ -1325,7 +1325,7 @@ class RsTypeInferenceWalker(
inferChildExprsRecursively(macroCall)
return when {
macroCall.assertMacroArgument != null -> TyUnit.INSTANCE
macroCall.formatMacroArgument != null -> inferFormatMacro(macroCall)
macroCall.formatMacroArgument != null -> inferFormatMacro(macroCall, expected)
macroCall.includeMacroArgument != null -> inferIncludeMacro(macroCall)
name == "env" -> TyReference(TyStr.INSTANCE, IMMUTABLE)
name == "option_env" -> items.findOptionForElementTy(TyReference(TyStr.INSTANCE, IMMUTABLE))
Expand All @@ -1335,7 +1335,7 @@ class RsTypeInferenceWalker(
name == "stringify" -> TyReference(TyStr.INSTANCE, IMMUTABLE)
name == "module_path" -> TyReference(TyStr.INSTANCE, IMMUTABLE)
name == "cfg" -> TyBool.INSTANCE
else -> inferMacroAsExpr(macroCall)
else -> inferMacroAsExpr(macroCall, expected)
}
}

Expand All @@ -1347,11 +1347,12 @@ class RsTypeInferenceWalker(
}
}

private fun inferMacroAsExpr(macroCall: RsMacroCall): Ty
= (macroCall.expansion as? MacroExpansion.Expr)?.expr?.inferType() ?: TyUnknown
private fun inferMacroAsExpr(macroCall: RsMacroCall, expected: Expectation = NoExpectation): Ty {
return (macroCall.expansion as? MacroExpansion.Expr)?.expr?.inferType(expected) ?: TyUnknown
}

private fun inferFormatMacro(macroCall: RsMacroCall): Ty {
val inferredTy = inferMacroAsExpr(macroCall)
private fun inferFormatMacro(macroCall: RsMacroCall, expected: Expectation): Ty {
val inferredTy = inferMacroAsExpr(macroCall, expected)
val name = macroCall.macroName
return when {
"print" in name -> TyUnit.INSTANCE
Expand Down
Expand Up @@ -6,6 +6,8 @@
package org.rust.lang.core.type

import org.rust.CheckTestmarkHit
import org.rust.ProjectDescriptor
import org.rust.WithStdlibRustProjectDescriptor
import org.rust.lang.core.macros.MacroExpansionManager

class RsMacroTypeInferenceTest : RsTypificationTestBase() {
Expand Down Expand Up @@ -149,4 +151,20 @@ class RsMacroTypeInferenceTest : RsTypificationTestBase() {
a;
} //^ Data
""")

@ProjectDescriptor(WithStdlibRustProjectDescriptor::class)
fun `test closure from macro`() = testExpr("""
macro_rules! closure {
[$ p:pat => $ tup:expr] => {
|$ p| $ tup
};
}

fn foo() {
let a = vec![(1i32, 2i32)].into_iter().map(
closure![(a, b) => (a, b)]
).next().unwrap();
a;
} //^ (i32, i32)
""")
}