diff --git a/clang-tools-extra/clangd/refactor/tweaks/CMakeLists.txt b/clang-tools-extra/clangd/refactor/tweaks/CMakeLists.txt index 526a073f619ea..b01053faf738a 100644 --- a/clang-tools-extra/clangd/refactor/tweaks/CMakeLists.txt +++ b/clang-tools-extra/clangd/refactor/tweaks/CMakeLists.txt @@ -21,6 +21,7 @@ add_clang_library(clangDaemonTweaks OBJECT ExpandMacro.cpp ExtractFunction.cpp ExtractVariable.cpp + InlineConceptRequirement.cpp MemberwiseConstructor.cpp ObjCLocalizeStringLiteral.cpp ObjCMemberwiseInitializer.cpp diff --git a/clang-tools-extra/clangd/refactor/tweaks/InlineConceptRequirement.cpp b/clang-tools-extra/clangd/refactor/tweaks/InlineConceptRequirement.cpp new file mode 100644 index 0000000000000..201ad8ee15dc5 --- /dev/null +++ b/clang-tools-extra/clangd/refactor/tweaks/InlineConceptRequirement.cpp @@ -0,0 +1,265 @@ +//===--- InlineConceptRequirement.cpp ----------------------------*- C++-*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +#include "ParsedAST.h" +#include "SourceCode.h" +#include "refactor/Tweak.h" +#include "support/Logger.h" +#include "clang/AST/ASTContext.h" +#include "clang/AST/ExprConcepts.h" +#include "clang/Tooling/Core/Replacement.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Error.h" + +namespace clang { +namespace clangd { +namespace { +/// Inlines a concept requirement. +/// +/// Before: +/// template void f(T) requires foo {} +/// ^^^^^^ +/// After: +/// template void f(T) {} +class InlineConceptRequirement : public Tweak { +public: + const char *id() const final; + + auto prepare(const Selection &Inputs) -> bool override; + auto apply(const Selection &Inputs) -> Expected override; + auto title() const -> std::string override { + return "Inline concept requirement"; + } + auto kind() const -> llvm::StringLiteral override { + return CodeAction::REFACTOR_KIND; + } + +private: + const ConceptSpecializationExpr *ConceptSpecializationExpression; + const TemplateTypeParmDecl *TemplateTypeParameterDeclaration; + const syntax::Token *RequiresToken; + + static auto getTemplateParameterIndexOfTemplateArgument( + const TemplateArgument &TemplateArgument) -> std::optional; + auto generateRequiresReplacement(ASTContext &) + -> llvm::Expected; + auto generateRequiresTokenReplacement(const syntax::TokenBuffer &) + -> tooling::Replacement; + auto generateTemplateParameterReplacement(ASTContext &Context) + -> llvm::Expected; + + static auto findToken(const ParsedAST *, const SourceRange &, + const tok::TokenKind) -> const syntax::Token *; + + template + static auto findNode(const SelectionTree::Node &Root) + -> std::tuple; + + template + static auto findExpression(const SelectionTree::Node &Root) + -> std::tuple { + return findNode(Root); + } + + template + static auto findDeclaration(const SelectionTree::Node &Root) + -> std::tuple { + return findNode(Root); + } +}; + +REGISTER_TWEAK(InlineConceptRequirement) + +auto InlineConceptRequirement::prepare(const Selection &Inputs) -> bool { + // Check if C++ version is 20 or higher + if (!Inputs.AST->getLangOpts().CPlusPlus20) + return false; + + const auto *Root = Inputs.ASTSelection.commonAncestor(); + if (!Root) + return false; + + const SelectionTree::Node *ConceptSpecializationExpressionTreeNode; + std::tie(ConceptSpecializationExpression, + ConceptSpecializationExpressionTreeNode) = + findExpression(*Root); + if (!ConceptSpecializationExpression) + return false; + + // Only allow concepts that are direct children of function template + // declarations or function declarations. This excludes conjunctions of + // concepts which are not handled. + const auto *ParentDeclaration = + ConceptSpecializationExpressionTreeNode->Parent->ASTNode.get(); + if (!isa_and_nonnull(ParentDeclaration) && + !isa_and_nonnull(ParentDeclaration)) + return false; + + const FunctionTemplateDecl *FunctionTemplateDeclaration = + std::get<0>(findDeclaration(*Root)); + if (!FunctionTemplateDeclaration) + return false; + + auto TemplateArguments = + ConceptSpecializationExpression->getTemplateArguments(); + if (TemplateArguments.size() != 1) + return false; + + auto TemplateParameterIndex = + getTemplateParameterIndexOfTemplateArgument(TemplateArguments[0]); + if (!TemplateParameterIndex) + return false; + + TemplateTypeParameterDeclaration = dyn_cast_or_null( + FunctionTemplateDeclaration->getTemplateParameters()->getParam( + *TemplateParameterIndex)); + if (!TemplateTypeParameterDeclaration->wasDeclaredWithTypename()) + return false; + + RequiresToken = + findToken(Inputs.AST, FunctionTemplateDeclaration->getSourceRange(), + tok::kw_requires); + if (!RequiresToken) + return false; + + return true; +} + +auto InlineConceptRequirement::apply(const Selection &Inputs) + -> Expected { + auto &Context = Inputs.AST->getASTContext(); + auto &TokenBuffer = Inputs.AST->getTokens(); + + tooling::Replacements Replacements{}; + + auto TemplateParameterReplacement = + generateTemplateParameterReplacement(Context); + + if (auto Err = TemplateParameterReplacement.takeError()) + return Err; + + if (auto Err = Replacements.add(*TemplateParameterReplacement)) + return Err; + + auto RequiresReplacement = generateRequiresReplacement(Context); + + if (auto Err = RequiresReplacement.takeError()) + return Err; + + if (auto Err = Replacements.add(*RequiresReplacement)) + return Err; + + if (auto Err = + Replacements.add(generateRequiresTokenReplacement(TokenBuffer))) + return Err; + + return Effect::mainFileEdit(Context.getSourceManager(), Replacements); +} + +auto InlineConceptRequirement::getTemplateParameterIndexOfTemplateArgument( + const TemplateArgument &TemplateArgument) -> std::optional { + if (TemplateArgument.getKind() != TemplateArgument.Type) + return {}; + + auto TemplateArgumentType = TemplateArgument.getAsType(); + if (!TemplateArgumentType->isTemplateTypeParmType()) + return {}; + + const auto *TemplateTypeParameterType = + TemplateArgumentType->getAs(); + if (!TemplateTypeParameterType) + return {}; + + return TemplateTypeParameterType->getIndex(); +} + +auto InlineConceptRequirement::generateRequiresReplacement(ASTContext &Context) + -> llvm::Expected { + auto &SourceManager = Context.getSourceManager(); + + auto RequiresRange = + toHalfOpenFileRange(SourceManager, Context.getLangOpts(), + ConceptSpecializationExpression->getSourceRange()); + if (!RequiresRange) + return error("Could not obtain range of the 'requires' branch. Macros?"); + + return tooling::Replacement( + SourceManager, CharSourceRange::getCharRange(*RequiresRange), ""); +} + +auto InlineConceptRequirement::generateRequiresTokenReplacement( + const syntax::TokenBuffer &TokenBuffer) -> tooling::Replacement { + auto &SourceManager = TokenBuffer.sourceManager(); + + auto Spelling = + TokenBuffer.spelledForExpanded(llvm::ArrayRef(*RequiresToken)); + + auto DeletionRange = + syntax::Token::range(SourceManager, Spelling->front(), Spelling->back()) + .toCharRange(SourceManager); + + return tooling::Replacement(SourceManager, DeletionRange, ""); +} + +auto InlineConceptRequirement::generateTemplateParameterReplacement( + ASTContext &Context) -> llvm::Expected { + auto &SourceManager = Context.getSourceManager(); + + auto ConceptName = ConceptSpecializationExpression->getNamedConcept() + ->getQualifiedNameAsString(); + + auto TemplateParameterName = + TemplateTypeParameterDeclaration->getQualifiedNameAsString(); + + auto TemplateParameterReplacement = ConceptName + ' ' + TemplateParameterName; + + auto TemplateParameterRange = + toHalfOpenFileRange(SourceManager, Context.getLangOpts(), + TemplateTypeParameterDeclaration->getSourceRange()); + + if (!TemplateParameterRange) + return error("Could not obtain range of the template parameter. Macros?"); + + return tooling::Replacement( + SourceManager, CharSourceRange::getCharRange(*TemplateParameterRange), + TemplateParameterReplacement); +} + +auto clang::clangd::InlineConceptRequirement::findToken( + const ParsedAST *AST, const SourceRange &SourceRange, + const tok::TokenKind TokenKind) -> const syntax::Token * { + auto &TokenBuffer = AST->getTokens(); + const auto &Tokens = TokenBuffer.expandedTokens(SourceRange); + + const auto Predicate = [TokenKind](const auto &Token) { + return Token.kind() == TokenKind; + }; + + auto It = std::find_if(Tokens.begin(), Tokens.end(), Predicate); + + if (It == Tokens.end()) + return nullptr; + + return It; +} + +template +auto InlineConceptRequirement::findNode(const SelectionTree::Node &Root) + -> std::tuple { + + for (const auto *Node = &Root; Node; Node = Node->Parent) { + if (const T *Result = dyn_cast_or_null(Node->ASTNode.get())) + return {Result, Node}; + } + + return {}; +} + +} // namespace +} // namespace clangd +} // namespace clang diff --git a/clang-tools-extra/clangd/unittests/CMakeLists.txt b/clang-tools-extra/clangd/unittests/CMakeLists.txt index 8d02b91fdd716..58773a445545c 100644 --- a/clang-tools-extra/clangd/unittests/CMakeLists.txt +++ b/clang-tools-extra/clangd/unittests/CMakeLists.txt @@ -125,6 +125,7 @@ add_unittest(ClangdUnitTests ClangdTests tweaks/ExpandMacroTests.cpp tweaks/ExtractFunctionTests.cpp tweaks/ExtractVariableTests.cpp + tweaks/InlineConceptRequirementTests.cpp tweaks/MemberwiseConstructorTests.cpp tweaks/ObjCLocalizeStringLiteralTests.cpp tweaks/ObjCMemberwiseInitializerTests.cpp diff --git a/clang-tools-extra/clangd/unittests/tweaks/InlineConceptRequirementTests.cpp b/clang-tools-extra/clangd/unittests/tweaks/InlineConceptRequirementTests.cpp new file mode 100644 index 0000000000000..54688318c95a2 --- /dev/null +++ b/clang-tools-extra/clangd/unittests/tweaks/InlineConceptRequirementTests.cpp @@ -0,0 +1,94 @@ +//===-- InlineConceptRequirementTests.cpp -----------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#include "TweakTesting.h" +#include "gtest/gtest.h" + +namespace clang { +namespace clangd { +namespace { + +TWEAK_TEST(InlineConceptRequirement); + +TEST_F(InlineConceptRequirementTest, Test) { + Header = R"cpp( + template + concept foo = true; + + template + concept bar = true; + + template + concept baz = true; + )cpp"; + + ExtraArgs = {"-std=c++20"}; + + // + // Extra spaces are expected and will be stripped by the formatter. + // + + EXPECT_EQ( + apply("template void f(T) requires f^oo {}"), + "template void f(T) {}"); + + EXPECT_EQ( + apply("template requires foo<^T> void f(T) {}"), + "template void f(T) {}"); + + EXPECT_EQ(apply("template