Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Rewrite/Directory.Build.props
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,7 @@
<VersionPrefix>0.8.7</VersionPrefix>
<RewriteRemoteVersion>0.8.2-rc.*</RewriteRemoteVersion>
<RewriteRemoteDir Condition="'$(RewriteRemoteDir)'==''">..\..\..\..\..\moderneinc\rewrite-remote\Rewrite.Remote</RewriteRemoteDir>
<DefineConstants>TRACE;DEBUG;REMOTE_PRINTER</DefineConstants>
</PropertyGroup>
<Import Project="Directory.Build.props.user" Condition="Exists('Directory.Build.props.user')"/>
</Project>
107 changes: 54 additions & 53 deletions Rewrite/src/Rewrite.CSharp/CSharpPrinter.cs

Large diffs are not rendered by default.

45 changes: 45 additions & 0 deletions Rewrite/src/Rewrite.CSharp/Extensions.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
using System.Collections;
using Rewrite.Core;
using Rewrite.RewriteJava.Tree;

namespace Rewrite.RewriteCSharp;

public static class Extensions
{
public static IEnumerable<Core.Tree> Descendents<TNode>(this TNode source) where TNode : Core.Tree
{
var searchVisitor = new SearchVisitor();
searchVisitor.Visit(source, null);
return searchVisitor.Visited;
}

public static TRoot ReplaceNode<TRoot>(this TRoot root, J oldNode, J newNode)
where TRoot : Core.Tree
{
var newRoot = new ReplaceVisitor(oldNode, newNode).Visit(root, null);
return (TRoot)newRoot!;
}

private class ReplaceVisitor(J oldNode, J newNode) : CSharpVisitor<object?>
{
public override J? PreVisit(Core.Tree? tree, object? p)
{
if (oldNode.Equals(tree))
{
return newNode;
}
return base.PreVisit(tree, p);
}
}

private class SearchVisitor : CSharpVisitor<object?>
{
public List<Core.Tree> Visited { get; } = new ();
public override J? PostVisit(Core.Tree tree, object? p)
{

Visited.Add(tree);
return base.PostVisit(tree, p);
}
}
}
10 changes: 10 additions & 0 deletions Rewrite/src/Rewrite.CSharp/Marker/MemberBinding.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
namespace Rewrite.RewriteCSharp.Marker;

public record MemberBinding(Guid Id) : Core.Marker.Marker
{
public virtual bool Equals(Core.Marker.Marker? other)
{
return other is MemberBinding && other.Id == Id;
}

}
206 changes: 170 additions & 36 deletions Rewrite/src/Rewrite.CSharp/Parser/CSharpParserVisitor.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
Expand Down Expand Up @@ -399,12 +400,18 @@ public override J VisitIdentifierName(IdentifierNameSyntax node)
var select = Convert<J>(node.Expression);
if (select is J.FieldAccess fa)
{
var mae = (MemberAccessExpressionSyntax)node.Expression;
var operatorToken = node.Expression switch
{
MemberAccessExpressionSyntax mae => mae.OperatorToken,
MemberBindingExpressionSyntax mbe => mbe.OperatorToken,
_ => throw new InvalidOperationException($"Unexpected node of type {node.Expression.GetType()} encountered.")
};

return new J.MethodInvocation(
Core.Tree.RandomId(),
prefix,
Markers.EMPTY,
new JRightPadded<Expression>(fa.Target, fa.Padding.Name.Before, Markers.EMPTY),
fa.Markers,
new JRightPadded<Expression>(fa.Target, Format(operatorToken.LeadingTrivia), Markers.EMPTY),
null,
fa.Name,
MapArgumentList(node.ArgumentList),
Expand Down Expand Up @@ -444,7 +451,7 @@ pt.Clazz is J.FieldAccess lfa
pt.TypeParameters.Select(JRightPadded<Expression>.Build).ToList(),
Markers.EMPTY
)
: null, // TODO type parameters
: null, // TODO: type parameters
pt.Clazz is J.Identifier i
? i
: (pt.Clazz as J.FieldAccess)?.Name ??
Expand All @@ -453,6 +460,27 @@ pt.Clazz is J.Identifier i
MapType(node) as JavaType.Method
);
}
else if (select is J.MethodInvocation mi) // chained method invocation (method returns a delegate). ex. Something()()
{

return new J.MethodInvocation(
Core.Tree.RandomId(),
prefix,
Markers.EMPTY,
JRightPadded<Expression>.Build(mi),
null,
new J.Identifier(
Core.Tree.RandomId(),
Space.EMPTY,
Markers.EMPTY,
new List<J.Annotation>(),
"",
null,
null),
MapArgumentList(node.ArgumentList),
MapType(node) as JavaType.Method
);
}

for (var index = 0; index < node.ArgumentList.Arguments.Count; index++)
{
Expand Down Expand Up @@ -1446,40 +1474,116 @@ private J.Identifier MapIdentifier(SyntaxToken identifier, JavaType? type)

public override J? VisitConditionalAccessExpression(ConditionalAccessExpressionSyntax node)
{
// the AST hierarchy in Roslyn is very different from that with `J.FieldAccess`
// see `VisitMemberBindingExpression()` for more details
// return Convert<Expression>(node.WhenNotNull)!;
return base.VisitConditionalAccessExpression(node);
// conditional expressions appear in their "natural order"
// meaning for an expression like this "a?.b", node a will be at the top of hierarchy
// vs in normal expression such as "a.b", b will be at the top of hierarchy

// LST doesn't reverse this order, so we need to traverse down any chain of nullable expression tree, and then
// process them in reverse
var conditionalExpressions = new List<(ExpressionSyntax, Space)>();
ExpressionSyntax currentNode = node;
while(currentNode is ConditionalAccessExpressionSyntax conditionalNode)
{
conditionalExpressions.Add((conditionalNode.Expression, Format(Leading(conditionalNode.OperatorToken))));
currentNode = conditionalNode.WhenNotNull;
}
conditionalExpressions.Add((currentNode, Format(Leading(currentNode))));
// at this point conditionalExpressions for something like this: a?.b?.c
// would look like this ['a','.b','.c']

Expression currentExpression = null!;// = Convert<Expression>(conditionalExpressions[0].Item1)!;
// each item in list will be individual expressions that form null access path, last one being the "normal"
// expression that is at
var i = 0;
foreach (var (expressionPortion, afterSpace) in conditionalExpressions)
{
var isLastSegment = i == conditionalExpressions.Count - 1;
var lstNode = Convert<Expression>(expressionPortion)!;
// somewhere in this node, a MemberBindingExpression got converted to either FieldAccess or MethodInvocation
// the expression is "fake" and needs to be adjusted. luckly we got a marker to locate this special node that needs to be
// fixed up. The expression for it will become lhs from previous loop iteration (stored in currentExpression)
// ps: god help you if you need to fix this logic :)
var bindingNode = lstNode.Descendents()
.FirstOrDefault(x => x.Markers.Contains<MemberBinding>());
if (bindingNode != null)
{
if (bindingNode is J.MethodInvocation methodNode)
{
var newMethod = methodNode.WithSelect(currentExpression);
lstNode = methodNode.Equals(lstNode) ? newMethod : lstNode.ReplaceNode(methodNode, newMethod);
}
else if (bindingNode is J.FieldAccess fieldAccess)
{
var newFieldAccess = fieldAccess.WithTarget(currentExpression);
lstNode = fieldAccess.Equals(lstNode) ? newFieldAccess : lstNode.ReplaceNode(fieldAccess, newFieldAccess);
}
}

// right hand side is the root and doesn't get wrapped
if (!isLastSegment)
{
lstNode = new Cs.NullSafeExpression(
Core.Tree.RandomId(),
Format(Leading(expressionPortion)),
Markers.EMPTY,
new JRightPadded<Expression>(
lstNode!,
afterSpace,
Markers.EMPTY
)
);
}

currentExpression = lstNode;

i++;
}

// var result = Convert<Expression>(node.WhenNotNull)!;
return currentExpression;

// return base.VisitConditionalAccessExpression(node);
}

/// <summary>
/// Very similar to MemberAccessExpression, but doesn't have an expression portion - just identifier
/// Used in ConditionalAccessExpression since they are constructed left to right, then right to left like normal field access
/// </summary>
public override J? VisitMemberBindingExpression(MemberBindingExpressionSyntax node)
{


// due to the fact that the `ConditionalAccessExpressionSyntax` is at the root of an expression like `foo?.Bar.Baz`
// we need to find that root here, as the containment hierarchy using `J.FieldAccess` and `Cs.NullSafeExpression`
// ends up being very different
ExpressionSyntax? parent = node;
while (parent is not ConditionalAccessExpressionSyntax)
if ((parent = parent.Parent as ExpressionSyntax) == null)
throw new InvalidOperationException(
"Cannot find a `ConditionalAccessExpressionSyntax` in the containment hierarchy.");

var conditionalAccess = (ConditionalAccessExpressionSyntax)parent;
var lhs = new Cs.NullSafeExpression(
Core.Tree.RandomId(),
Format(Leading(node)),
Markers.EMPTY,
new JRightPadded<Expression>(
Convert<Expression>(conditionalAccess.Expression)!,
Format(Leading(conditionalAccess.OperatorToken)),
Markers.EMPTY
)
);
// ExpressionSyntax? parent = node;
// while (parent is not ConditionalAccessExpressionSyntax)
// if ((parent = parent.Parent as ExpressionSyntax) == null)
// throw new InvalidOperationException(
// "Cannot find a `ConditionalAccessExpressionSyntax` in the containment hierarchy.");
//
// var conditionalAccess = (ConditionalAccessExpressionSyntax)parent;
// var lhs = new Cs.NullSafeExpression(
// Core.Tree.RandomId(),
// Format(Leading(node)),
// Markers.EMPTY,
// new JRightPadded<Expression>(
// Convert<Expression>(conditionalAccess.Expression)!,
// Format(Leading(conditionalAccess.OperatorToken)),
// Markers.EMPTY
// )
// );

return new J.FieldAccess(
Core.Tree.RandomId(),
Format(Leading(node)),
Markers.EMPTY,
lhs,
new Markers(
Core.Tree.RandomId(),
new List<Core.Marker.Marker>
{
new MemberBinding(Core.Tree.RandomId())
}),
Convert<Expression>(node.Name)!,
new JLeftPadded<J.Identifier>(
Format(Leading(node.OperatorToken)),
Convert<J.Identifier>(node.Name)!,
Expand Down Expand Up @@ -3484,7 +3588,9 @@ private JRightPadded<Expression> MapSwitchCaseLabel(SwitchLabelSyntax sls)
{
return base.VisitNullableDirectiveTrivia(node);
}

#if DEBUG_VISITOR
[DebuggerStepThrough]
#endif
T? Convert<T>(SyntaxNode? node) where T : class, J
{
if (node == null) return default;
Expand Down Expand Up @@ -3596,21 +3702,30 @@ private JRightPadded<T> MapParameter<T>(ParameterSyntax tps) where T : J
)).ToList();
}

#if DEBUG_VISITOR
[DebuggerStepThrough]
#endif
private List<Cs.AttributeList>? MapAttributes(SyntaxList<AttributeListSyntax> m)
{
return m.Count == 0 ? null : m.Select(x => Convert<Cs.AttributeList>(x)!).ToList();
}

#if DEBUG_VISITOR
[DebuggerStepThrough]
#endif
private JavaType MapType(ExpressionSyntax ins)
{
return _typeMapping.Type(semanticModel.GetTypeInfo(ins).Type);
}

#if DEBUG_VISITOR
[DebuggerStepThrough]
#endif
private JavaType MapType(SyntaxNode ins)
{
return _typeMapping.Type(semanticModel.GetDeclaredSymbol(ins) ?? semanticModel.GetTypeInfo(ins).Type);
}

// #if DEBUG_VISITOR
// [DebuggerStepThrough]
// #endif
private JRightPadded<Statement> MapStatement(StatementSyntax statementSyntax)
{
var statement = (Visit(statementSyntax) as Statement)!;
Expand All @@ -3628,18 +3743,26 @@ statementSyntax is LocalFunctionStatementSyntax
: Markers.EMPTY
);
}

#if DEBUG_VISITOR
[DebuggerStepThrough]
#endif
private SyntaxTriviaList Leading<T>(SyntaxList<T> list) where T : SyntaxNode
{
return list.Count == 0 ? SyntaxTriviaList.Empty : Leading(list.First());
}

#if DEBUG_VISITOR
[DebuggerStepThrough]
#endif
private SyntaxTriviaList Leading(SyntaxNode node)
{
var firstToken = node.GetFirstToken();
return Leading(firstToken);
}

#if DEBUG_VISITOR
[DebuggerStepThrough]
#endif
private SyntaxTriviaList Leading(SyntaxToken token)
{
var previousToken = token.GetPreviousToken();
Expand All @@ -3652,6 +3775,9 @@ private SyntaxTriviaList Leading(SyntaxToken token)
return OnlyUnseenTrivia(trailing, leading);
}

#if DEBUG_VISITOR
[DebuggerStepThrough]
#endif
private SyntaxTriviaList OnlyUnseenTrivia(SyntaxTriviaList trivia)
{
var span = trivia.Span;
Expand All @@ -3665,7 +3791,9 @@ private SyntaxTriviaList OnlyUnseenTrivia(SyntaxTriviaList trivia)
_seenTriviaSpans.Sort();
return trivia;
}

#if DEBUG_VISITOR
[DebuggerStepThrough]
#endif
private SyntaxTriviaList OnlyUnseenTrivia(SyntaxTriviaList trivia1, SyntaxTriviaList trivia2)
{
var span = new TextSpan(trivia1.Span.Start, trivia2.Span.End - trivia1.Span.Start);
Expand All @@ -3679,17 +3807,23 @@ private SyntaxTriviaList OnlyUnseenTrivia(SyntaxTriviaList trivia1, SyntaxTrivia
_seenTriviaSpans.Sort();
return trivia1.AddRange(trivia2);
}

#if DEBUG_VISITOR
[DebuggerStepThrough]
#endif
private SyntaxTriviaList Trailing(SyntaxNode node)
{
return Trailing(node.GetLastToken());
}

#if DEBUG_VISITOR
[DebuggerStepThrough]
#endif
private SyntaxTriviaList Trailing(SyntaxToken token)
{
return Leading(token.GetNextToken());
}

#if DEBUG_VISITOR
[DebuggerStepThrough]
#endif
private static Space Format(SyntaxTriviaList trivia)
{
// FIXME optimize
Expand Down
Loading