/
SymbolEditorExtensions.cs
102 lines (90 loc) · 3.99 KB
/
SymbolEditorExtensions.cs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
#nullable disable
using System;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
namespace Microsoft.CodeAnalysis.Editing
{
public static class SymbolEditorExtensions
{
/// <summary>
/// Gets the reference to the declaration of the base or interface type as part of the symbol's declaration.
/// </summary>
public static async Task<SyntaxNode> GetBaseOrInterfaceDeclarationReferenceAsync(
this SymbolEditor editor,
ISymbol symbol,
ITypeSymbol baseOrInterfaceType,
CancellationToken cancellationToken = default)
{
if (baseOrInterfaceType == null)
{
throw new ArgumentNullException(nameof(baseOrInterfaceType));
}
if (baseOrInterfaceType.TypeKind != TypeKind.Error)
{
baseOrInterfaceType = (ITypeSymbol)(await editor.GetCurrentSymbolAsync(baseOrInterfaceType, cancellationToken).ConfigureAwait(false));
}
// look for the base or interface declaration in all declarations of the symbol
var currentDecls = await editor.GetCurrentDeclarationsAsync(symbol, cancellationToken).ConfigureAwait(false);
foreach (var decl in currentDecls)
{
var doc = editor.OriginalSolution.GetDocument(decl.SyntaxTree);
var model = await doc.GetSemanticModelAsync(cancellationToken).ConfigureAwait(false);
var gen = SyntaxGenerator.GetGenerator(doc);
var typeRef = gen.GetBaseAndInterfaceTypes(decl).FirstOrDefault(r => model.GetTypeInfo(r, cancellationToken).Type.Equals(baseOrInterfaceType));
if (typeRef != null)
{
return typeRef;
}
}
return null;
}
/// <summary>
/// Changes the base type of the symbol.
/// </summary>
public static async Task<ISymbol> SetBaseTypeAsync(
this SymbolEditor editor,
INamedTypeSymbol symbol,
Func<SyntaxGenerator, SyntaxNode> getNewBaseType,
CancellationToken cancellationToken = default)
{
var baseType = symbol.BaseType;
if (baseType != null)
{
// find existing declaration of the base type
var typeRef = await editor.GetBaseOrInterfaceDeclarationReferenceAsync(symbol, baseType, cancellationToken).ConfigureAwait(false);
if (typeRef != null)
{
return await editor.EditOneDeclarationAsync(
symbol,
typeRef.GetLocation(),
(e, d) => e.ReplaceNode(typeRef, getNewBaseType(e.Generator)),
cancellationToken).ConfigureAwait(false);
}
}
// couldn't find the existing reference to change, so add it to one of the declarations
return await editor.EditOneDeclarationAsync(symbol, (e, decl) =>
{
var newBaseType = getNewBaseType(e.Generator);
if (newBaseType != null)
{
e.ReplaceNode(decl, (d, g) => g.AddBaseType(d, newBaseType));
}
}, cancellationToken).ConfigureAwait(false);
}
/// <summary>
/// Changes the base type of the symbol.
/// </summary>
public static Task<ISymbol> SetBaseTypeAsync(
this SymbolEditor editor,
INamedTypeSymbol symbol,
ITypeSymbol newBaseType,
CancellationToken cancellationToken = default)
{
return editor.SetBaseTypeAsync(symbol, g => newBaseType != null ? g.TypeExpression(newBaseType) : null, cancellationToken);
}
}
}