-
Notifications
You must be signed in to change notification settings - Fork 0
Refactor GeneratesMethodGenerator into a tiny orchestrator with focused helper components
#14
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
ae88967
90d0eff
2eb1a2b
681caa0
e89e0c4
095d0c1
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,331 @@ | ||
| using Microsoft.CodeAnalysis; | ||
| using Microsoft.CodeAnalysis.CSharp; | ||
| using Microsoft.CodeAnalysis.CSharp.Syntax; | ||
| using Microsoft.CodeAnalysis.Emit; | ||
| using System.Collections; | ||
| using System.Reflection; | ||
| using System.Runtime.Loader; | ||
| using System.Text; | ||
|
|
||
| namespace MattSourceGenHelpers.Generators; | ||
|
|
||
| internal sealed record SwitchBodyData( | ||
| IReadOnlyList<(object key, string value)> CasePairs, | ||
| bool HasDefaultCase); | ||
|
|
||
| internal static class GeneratesMethodExecutionRuntime | ||
| { | ||
| internal static (string? value, string? error) ExecuteSimpleGeneratorMethod( | ||
| IMethodSymbol generatorMethod, | ||
| IMethodSymbol partialMethod, | ||
| Compilation compilation) | ||
| { | ||
| IReadOnlyList<IMethodSymbol> allPartials = GetAllUnimplementedPartialMethods(compilation); | ||
| return ExecuteGeneratorMethodWithArgs(generatorMethod, allPartials, compilation, null); | ||
| } | ||
|
|
||
| internal static (SwitchBodyData? record, string? error) ExecuteFluentGeneratorMethod( | ||
| IMethodSymbol generatorMethod, | ||
| IMethodSymbol partialMethod, | ||
| Compilation compilation) | ||
| { | ||
| IReadOnlyList<IMethodSymbol> allPartials = GetAllUnimplementedPartialMethods(compilation); | ||
| CSharpCompilation executableCompilation = BuildExecutionCompilation(allPartials, compilation); | ||
|
|
||
| using MemoryStream stream = new(); | ||
| EmitResult emitResult = executableCompilation.Emit(stream); | ||
| if (!emitResult.Success) | ||
| { | ||
| string errors = string.Join("; ", emitResult.Diagnostics | ||
| .Where(diagnostic => diagnostic.Severity == DiagnosticSeverity.Error) | ||
| .Select(diagnostic => diagnostic.GetMessage())); | ||
| return (null, $"Compilation failed: {errors}"); | ||
| } | ||
|
|
||
| stream.Position = 0; | ||
| AssemblyLoadContext? loadContext = null; | ||
| try | ||
| { | ||
| loadContext = new AssemblyLoadContext("__GeneratorExec", isCollectible: true); | ||
| loadContext.Resolving += (context, assemblyName) => | ||
| { | ||
| PortableExecutableReference? match = compilation.References | ||
| .OfType<PortableExecutableReference>() | ||
| .FirstOrDefault(reference => reference.FilePath is not null && string.Equals( | ||
| Path.GetFileNameWithoutExtension(reference.FilePath), | ||
| assemblyName.Name, | ||
| StringComparison.OrdinalIgnoreCase)); | ||
| return match?.FilePath != null ? context.LoadFromAssemblyPath(match.FilePath) : null; | ||
| }; | ||
|
|
||
| Assembly assembly = loadContext.LoadFromStream(stream); | ||
|
|
||
| PortableExecutableReference? abstractionsReference = compilation.References | ||
| .OfType<PortableExecutableReference>() | ||
| .FirstOrDefault(reference => reference.FilePath is not null && string.Equals( | ||
| Path.GetFileNameWithoutExtension(reference.FilePath), | ||
| "MattSourceGenHelpers.Abstractions", | ||
| StringComparison.OrdinalIgnoreCase)); | ||
|
Comment on lines
+63
to
+68
|
||
|
|
||
| if (abstractionsReference?.FilePath == null) | ||
| { | ||
| return (null, "Could not find MattSourceGenHelpers.Abstractions reference in compilation"); | ||
| } | ||
|
|
||
| string abstractionsAssemblyPath = ResolveImplementationAssemblyPath(abstractionsReference.FilePath); | ||
| Assembly abstractionsAssembly = loadContext.LoadFromAssemblyPath(abstractionsAssemblyPath); | ||
|
|
||
| Type? generatorStaticType = abstractionsAssembly.GetType("MattSourceGenHelpers.Abstractions.Generator"); | ||
| Type? recordingFactoryType = abstractionsAssembly.GetType("MattSourceGenHelpers.Abstractions.RecordingGeneratorsFactory"); | ||
| if (generatorStaticType == null || recordingFactoryType == null) | ||
| { | ||
| return (null, "Could not find Generator or RecordingGeneratorsFactory types in Abstractions assembly"); | ||
| } | ||
|
|
||
| object? recordingFactory = Activator.CreateInstance(recordingFactoryType); | ||
| PropertyInfo? currentGeneratorProperty = generatorStaticType.GetProperty("CurrentGenerator", BindingFlags.Public | BindingFlags.Static); | ||
| currentGeneratorProperty?.SetValue(null, recordingFactory); | ||
|
|
||
| string typeName = generatorMethod.ContainingType.ToDisplayString(); | ||
| Type? loadedType = assembly.GetType(typeName); | ||
| if (loadedType == null) | ||
| { | ||
| return (null, $"Could not find type '{typeName}' in compiled assembly"); | ||
| } | ||
|
|
||
| MethodInfo? generatorMethodInfo = loadedType.GetMethod(generatorMethod.Name, BindingFlags.Static | BindingFlags.NonPublic | BindingFlags.Public); | ||
| if (generatorMethodInfo == null) | ||
| { | ||
| return (null, $"Could not find method '{generatorMethod.Name}' in type '{typeName}'"); | ||
| } | ||
|
|
||
| generatorMethodInfo.Invoke(null, null); | ||
|
|
||
| PropertyInfo? lastRecordProperty = recordingFactoryType.GetProperty("LastRecord"); | ||
| object? lastRecord = lastRecordProperty?.GetValue(recordingFactory); | ||
| if (lastRecord == null) | ||
| { | ||
| return (null, "RecordingGeneratorsFactory did not produce a record"); | ||
| } | ||
|
|
||
| return (ExtractSwitchBodyData(lastRecord, partialMethod.ReturnType), null); | ||
| } | ||
| catch (Exception ex) | ||
| { | ||
| return (null, $"Error executing generator method '{generatorMethod.Name}': {ex.GetBaseException()}"); | ||
| } | ||
| finally | ||
| { | ||
| loadContext?.Unload(); | ||
| } | ||
| } | ||
|
|
||
| internal static (string? value, string? error) ExecuteGeneratorMethodWithArgs( | ||
| IMethodSymbol generatorMethod, | ||
| IReadOnlyList<IMethodSymbol> allPartialMethods, | ||
| Compilation compilation, | ||
| object?[]? args) | ||
| { | ||
| CSharpCompilation executableCompilation = BuildExecutionCompilation(allPartialMethods, compilation); | ||
|
|
||
| using MemoryStream stream = new(); | ||
| EmitResult emitResult = executableCompilation.Emit(stream); | ||
| if (!emitResult.Success) | ||
| { | ||
| string errors = string.Join("; ", emitResult.Diagnostics | ||
| .Where(diagnostic => diagnostic.Severity == DiagnosticSeverity.Error) | ||
| .Select(diagnostic => diagnostic.GetMessage())); | ||
| return (null, $"Compilation failed: {errors}"); | ||
| } | ||
|
|
||
| stream.Position = 0; | ||
| AssemblyLoadContext? loadContext = null; | ||
| try | ||
| { | ||
| loadContext = new AssemblyLoadContext("__GeneratorExec", isCollectible: true); | ||
| loadContext.Resolving += (context, assemblyName) => | ||
| { | ||
| PortableExecutableReference? match = compilation.References | ||
| .OfType<PortableExecutableReference>() | ||
| .FirstOrDefault(reference => reference.FilePath is not null && string.Equals( | ||
| Path.GetFileNameWithoutExtension(reference.FilePath), | ||
| assemblyName.Name, | ||
| StringComparison.OrdinalIgnoreCase)); | ||
| return match?.FilePath != null ? context.LoadFromAssemblyPath(match.FilePath) : null; | ||
| }; | ||
|
|
||
| Assembly assembly = loadContext.LoadFromStream(stream); | ||
| string typeName = generatorMethod.ContainingType.ToDisplayString(); | ||
| Type? loadedType = assembly.GetType(typeName); | ||
| if (loadedType == null) | ||
| { | ||
| return (null, $"Could not find type '{typeName}' in compiled assembly"); | ||
| } | ||
|
|
||
| MethodInfo? generatorMethodInfo = loadedType.GetMethod(generatorMethod.Name, BindingFlags.Static | BindingFlags.NonPublic | BindingFlags.Public); | ||
| if (generatorMethodInfo == null) | ||
| { | ||
| return (null, $"Could not find method '{generatorMethod.Name}' in type '{typeName}'"); | ||
| } | ||
|
|
||
| object?[]? convertedArgs = ConvertArguments(args, generatorMethodInfo); | ||
| object? result = generatorMethodInfo.Invoke(null, convertedArgs); | ||
| return (result?.ToString(), null); | ||
| } | ||
| catch (Exception ex) | ||
| { | ||
| return (null, $"Error executing generator method '{generatorMethod.Name}': {ex.GetBaseException()}"); | ||
| } | ||
| finally | ||
| { | ||
| loadContext?.Unload(); | ||
| } | ||
| } | ||
|
|
||
| internal static IReadOnlyList<IMethodSymbol> GetAllUnimplementedPartialMethods(Compilation compilation) | ||
| { | ||
| List<IMethodSymbol> methods = new(); | ||
| foreach (SyntaxTree syntaxTree in compilation.SyntaxTrees) | ||
| { | ||
| SemanticModel semanticModel = compilation.GetSemanticModel(syntaxTree); | ||
| IEnumerable<MethodDeclarationSyntax> partialMethodDeclarations = syntaxTree.GetRoot().DescendantNodes() | ||
| .OfType<MethodDeclarationSyntax>() | ||
| .Where(method => method.Modifiers.Any(modifier => modifier.IsKind(SyntaxKind.PartialKeyword))); | ||
|
|
||
| foreach (MethodDeclarationSyntax declaration in partialMethodDeclarations) | ||
| { | ||
| if (semanticModel.GetDeclaredSymbol(declaration) is IMethodSymbol symbol && symbol.IsPartialDefinition) | ||
| { | ||
| methods.Add(symbol); | ||
| } | ||
| } | ||
| } | ||
|
|
||
| return methods; | ||
| } | ||
|
|
||
| private static object?[]? ConvertArguments(object?[]? args, MethodInfo methodInfo) | ||
| { | ||
| if (args == null || methodInfo.GetParameters().Length == 0) | ||
| { | ||
| return null; | ||
| } | ||
|
|
||
| Type parameterType = methodInfo.GetParameters()[0].ParameterType; | ||
| return new[] { Convert.ChangeType(args[0], parameterType) }; | ||
| } | ||
|
|
||
| private static SwitchBodyData ExtractSwitchBodyData(object lastRecord, ITypeSymbol returnType) | ||
| { | ||
| Type recordType = lastRecord.GetType(); | ||
| PropertyInfo? caseKeysProperty = recordType.GetProperty("CaseKeys"); | ||
| PropertyInfo? caseValuesProperty = recordType.GetProperty("CaseValues"); | ||
| PropertyInfo? hasDefaultProperty = recordType.GetProperty("HasDefaultCase"); | ||
|
|
||
| IList caseKeys = (caseKeysProperty?.GetValue(lastRecord) as IList) ?? new List<object>(); | ||
| IList caseValues = (caseValuesProperty?.GetValue(lastRecord) as IList) ?? new List<object?>(); | ||
| bool hasDefaultCase = (bool)(hasDefaultProperty?.GetValue(lastRecord) ?? false); | ||
|
|
||
| List<(object key, string value)> pairs = new(); | ||
| for (int index = 0; index < caseKeys.Count; index++) | ||
| { | ||
| object key = caseKeys[index]!; | ||
| string? value = index < caseValues.Count ? caseValues[index]?.ToString() : null; | ||
| pairs.Add((key, GeneratesMethodPatternSourceBuilder.FormatValueAsCSharpLiteral(value, returnType))); | ||
| } | ||
|
|
||
| return new SwitchBodyData(pairs, hasDefaultCase); | ||
| } | ||
|
|
||
| private static string ResolveImplementationAssemblyPath(string path) | ||
| { | ||
| string? directory = Path.GetDirectoryName(path); | ||
| string? parentDirectory = directory != null ? Path.GetDirectoryName(directory) : null; | ||
| if (directory != null && | ||
| parentDirectory != null && | ||
| string.Equals(Path.GetFileName(directory), "ref", StringComparison.OrdinalIgnoreCase)) | ||
| { | ||
| return Path.Combine(parentDirectory, Path.GetFileName(path)); | ||
| } | ||
|
|
||
| return path; | ||
| } | ||
|
|
||
| private static CSharpCompilation BuildExecutionCompilation( | ||
| IReadOnlyList<IMethodSymbol> allPartialMethods, | ||
| Compilation compilation) | ||
| { | ||
| string dummySource = BuildDummyImplementation(allPartialMethods); | ||
| CSharpParseOptions parseOptions = compilation.SyntaxTrees.FirstOrDefault()?.Options as CSharpParseOptions | ||
| ?? CSharpParseOptions.Default; | ||
|
|
||
| return (CSharpCompilation)compilation | ||
| .WithOptions(new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary)) | ||
| .AddSyntaxTrees(CSharpSyntaxTree.ParseText(dummySource, parseOptions)); | ||
| } | ||
|
|
||
| private static string BuildDummyImplementation(IEnumerable<IMethodSymbol> partialMethods) | ||
| { | ||
| StringBuilder builder = new(); | ||
|
|
||
| IEnumerable<IGrouping<(string? Namespace, string TypeName, bool IsStatic, TypeKind TypeKind), IMethodSymbol>> groupedMethods = partialMethods.GroupBy( | ||
| method => (Namespace: method.ContainingType.ContainingNamespace?.IsGlobalNamespace == false | ||
| ? method.ContainingType.ContainingNamespace.ToDisplayString() | ||
| : null, | ||
| TypeName: method.ContainingType.Name, | ||
| IsStatic: method.ContainingType.IsStatic, | ||
| TypeKind: method.ContainingType.TypeKind)); | ||
|
|
||
| foreach (IGrouping<(string? Namespace, string TypeName, bool IsStatic, TypeKind TypeKind), IMethodSymbol> typeGroup in groupedMethods) | ||
| { | ||
| string? namespaceName = typeGroup.Key.Namespace; | ||
| if (namespaceName != null) | ||
| { | ||
| builder.AppendLine($"namespace {namespaceName} {{"); | ||
| } | ||
|
|
||
| string typeKeyword = typeGroup.Key.TypeKind switch | ||
| { | ||
| TypeKind.Struct => "struct", | ||
| _ => "class" | ||
| }; | ||
|
|
||
| string typeModifiers = typeGroup.Key.IsStatic ? "static partial" : "partial"; | ||
| builder.AppendLine($"{typeModifiers} {typeKeyword} {typeGroup.Key.TypeName} {{"); | ||
|
|
||
| foreach (IMethodSymbol partialMethod in typeGroup) | ||
| { | ||
| string accessibility = partialMethod.DeclaredAccessibility switch | ||
| { | ||
| Accessibility.Public => "public", | ||
| Accessibility.Protected => "protected", | ||
| Accessibility.Internal => "internal", | ||
| Accessibility.ProtectedOrInternal => "protected internal", | ||
| Accessibility.ProtectedAndInternal => "private protected", | ||
| _ => "" | ||
| }; | ||
|
|
||
| string staticModifier = partialMethod.IsStatic ? "static " : ""; | ||
| string returnType = partialMethod.ReturnType.ToDisplayString(); | ||
| string parameters = string.Join(", ", partialMethod.Parameters.Select(parameter => $"{parameter.Type.ToDisplayString()} {parameter.Name}")); | ||
|
|
||
| builder.AppendLine($"{accessibility} {staticModifier}partial {returnType} {partialMethod.Name}({parameters}) {{"); | ||
| if (!partialMethod.ReturnsVoid) | ||
| { | ||
| builder.AppendLine("return default!;"); | ||
| } | ||
|
|
||
| builder.AppendLine("}"); | ||
| } | ||
|
|
||
| builder.AppendLine("}"); | ||
|
|
||
| if (namespaceName != null) | ||
| { | ||
| builder.AppendLine("}"); | ||
| } | ||
| } | ||
|
|
||
| return builder.ToString(); | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
PortableExecutableReference.FilePathis nullable, but the predicate callsPath.GetFileNameWithoutExtension(reference.FilePath)without a null check. This can throw (and also triggers nullable warnings) when the compilation contains in-memory references withFilePath == null. Consider filtering toreference.FilePath is not nullbefore callingPath.*(and/or using a safe fallback).There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed in 095d0c1 by guarding
reference.FilePath is not nullbefore callingPath.GetFileNameWithoutExtensionin the resolver lookup.