From ae88967c31de2829a18c53e0d2cecfb6826e43e1 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 27 Feb 2026 19:00:38 +0000 Subject: [PATCH 1/6] Initial plan From 90d0eff48678022333223305a9d770e810f5a5fe Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 27 Feb 2026 19:05:19 +0000 Subject: [PATCH 2/6] refactor: extract GeneratesMethod diagnostics and method collection Co-authored-by: dex3r <3155725+dex3r@users.noreply.github.com> --- ...eneratesMethodGenerationTargetCollector.cs | 90 ++++++ .../GeneratesMethodGenerator.Diagnostics.cs | 40 +++ .../GeneratesMethodGenerator.cs | 306 +++++++----------- 3 files changed, 240 insertions(+), 196 deletions(-) create mode 100644 MattSourceGenHelpers.Generators/GeneratesMethodGenerationTargetCollector.cs create mode 100644 MattSourceGenHelpers.Generators/GeneratesMethodGenerator.Diagnostics.cs diff --git a/MattSourceGenHelpers.Generators/GeneratesMethodGenerationTargetCollector.cs b/MattSourceGenHelpers.Generators/GeneratesMethodGenerationTargetCollector.cs new file mode 100644 index 0000000..1807c2a --- /dev/null +++ b/MattSourceGenHelpers.Generators/GeneratesMethodGenerationTargetCollector.cs @@ -0,0 +1,90 @@ +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using System.Collections.Immutable; + +namespace MattSourceGenHelpers.Generators; + +internal sealed record GeneratesMethodGenerationTarget( + MethodDeclarationSyntax Syntax, + IMethodSymbol Symbol, + string TargetMethodName, + IMethodSymbol PartialMethod, + INamedTypeSymbol ContainingType); + +internal static class GeneratesMethodGenerationTargetCollector +{ + private const string GeneratesMethodAttributeTypeName = "MattSourceGenHelpers.Abstractions.GeneratesMethod"; + + internal static List Collect( + SourceProductionContext context, + ImmutableArray generatorMethods, + Compilation compilation) + { + List validMethods = new(); + + foreach (MethodDeclarationSyntax? generatorMethod in generatorMethods) + { + if (generatorMethod is null) + { + continue; + } + + SemanticModel semanticModel = compilation.GetSemanticModel(generatorMethod.SyntaxTree); + IMethodSymbol? methodSymbol = semanticModel.GetDeclaredSymbol(generatorMethod) as IMethodSymbol; + + if (methodSymbol is null) + { + continue; + } + + if (!methodSymbol.IsStatic) + { + context.ReportDiagnostic(Diagnostic.Create( + GeneratesMethodGeneratorDiagnostics.GeneratorMethodMustBeStaticError, + generatorMethod.GetLocation(), + generatorMethod.Identifier.Text)); + continue; + } + + AttributeData? attribute = methodSymbol + .GetAttributes() + .FirstOrDefault(attributeData => attributeData.AttributeClass?.ToDisplayString() == GeneratesMethodAttributeTypeName); + + if (attribute is null || attribute.ConstructorArguments.Length == 0) + { + continue; + } + + string? targetMethodName = attribute.ConstructorArguments[0].Value?.ToString(); + if (string.IsNullOrWhiteSpace(targetMethodName)) + { + continue; + } + + INamedTypeSymbol containingType = methodSymbol.ContainingType; + IMethodSymbol? partialMethodSymbol = containingType + .GetMembers(targetMethodName) + .OfType() + .FirstOrDefault(method => method.IsPartialDefinition); + + if (partialMethodSymbol is null) + { + context.ReportDiagnostic(Diagnostic.Create( + GeneratesMethodGeneratorDiagnostics.MissingPartialMethodError, + generatorMethod.GetLocation(), + targetMethodName, + containingType.Name)); + continue; + } + + validMethods.Add(new GeneratesMethodGenerationTarget( + generatorMethod, + methodSymbol, + targetMethodName, + partialMethodSymbol, + containingType)); + } + + return validMethods; + } +} diff --git a/MattSourceGenHelpers.Generators/GeneratesMethodGenerator.Diagnostics.cs b/MattSourceGenHelpers.Generators/GeneratesMethodGenerator.Diagnostics.cs new file mode 100644 index 0000000..1d9f5b2 --- /dev/null +++ b/MattSourceGenHelpers.Generators/GeneratesMethodGenerator.Diagnostics.cs @@ -0,0 +1,40 @@ +using Microsoft.CodeAnalysis; + +namespace MattSourceGenHelpers.Generators; + +internal static class GeneratesMethodGeneratorDiagnostics +{ + private const string Category = "GeneratesMethodGenerator"; + + internal static readonly DiagnosticDescriptor MissingPartialMethodError = new( + id: "MSGH001", + title: "Missing partial method", + messageFormat: "Could not find partial method '{0}' in class '{1}'", + category: Category, + defaultSeverity: DiagnosticSeverity.Error, + isEnabledByDefault: true); + + internal static readonly DiagnosticDescriptor GeneratorMethodMustBeStaticError = new( + id: "MSGH002", + title: "Generator method must be static", + messageFormat: "Method '{0}' marked with [GeneratesMethod] must be static", + category: Category, + defaultSeverity: DiagnosticSeverity.Error, + isEnabledByDefault: true); + + internal static readonly DiagnosticDescriptor GeneratingMethodInfo = new( + id: "MSGH003", + title: "Generating partial method implementation", + messageFormat: "Generating implementation for partial method '{0}' in class '{1}' using generator '{2}'", + category: Category, + defaultSeverity: DiagnosticSeverity.Info, + isEnabledByDefault: false); + + internal static readonly DiagnosticDescriptor GeneratorMethodExecutionError = new( + id: "MSGH004", + title: "Generator method execution failed", + messageFormat: "Failed to execute generator method '{0}': {1}", + category: Category, + defaultSeverity: DiagnosticSeverity.Error, + isEnabledByDefault: true); +} diff --git a/MattSourceGenHelpers.Generators/GeneratesMethodGenerator.cs b/MattSourceGenHelpers.Generators/GeneratesMethodGenerator.cs index b195310..736cb8b 100644 --- a/MattSourceGenHelpers.Generators/GeneratesMethodGenerator.cs +++ b/MattSourceGenHelpers.Generators/GeneratesMethodGenerator.cs @@ -1,6 +1,7 @@ -using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Emit; using System.Collections; using System.Collections.Immutable; using System.Reflection; @@ -14,53 +15,14 @@ namespace MattSourceGenHelpers.Generators; [Generator] public class GeneratesMethodGenerator : IIncrementalGenerator { - private static readonly DiagnosticDescriptor MissingPartialMethodError = new( - id: "MSGH001", - title: "Missing partial method", - messageFormat: "Could not find partial method '{0}' in class '{1}'", - category: "GeneratesMethodGenerator", - defaultSeverity: DiagnosticSeverity.Error, - isEnabledByDefault: true); - - private static readonly DiagnosticDescriptor GeneratorMethodMustBeStaticError = new( - id: "MSGH002", - title: "Generator method must be static", - messageFormat: "Method '{0}' marked with [GeneratesMethod] must be static", - category: "GeneratesMethodGenerator", - defaultSeverity: DiagnosticSeverity.Error, - isEnabledByDefault: true); - - private static readonly DiagnosticDescriptor GeneratingMethodInfo = new( - id: "MSGH003", - title: "Generating partial method implementation", - messageFormat: "Generating implementation for partial method '{0}' in class '{1}' using generator '{2}'", - category: "GeneratesMethodGenerator", - defaultSeverity: DiagnosticSeverity.Info, - isEnabledByDefault: false); - - private static readonly DiagnosticDescriptor GeneratorMethodExecutionError = new( - id: "MSGH004", - title: "Generator method execution failed", - messageFormat: "Failed to execute generator method '{0}': {1}", - category: "GeneratesMethodGenerator", - defaultSeverity: DiagnosticSeverity.Error, - isEnabledByDefault: true); - - private record GeneratorMethodInfo( - MethodDeclarationSyntax Syntax, - IMethodSymbol Symbol, - string TargetMethodName, - IMethodSymbol PartialMethod, - INamedTypeSymbol ContainingType); - public void Initialize(IncrementalGeneratorInitializationContext context) { - var methodsWithAttribute = context.SyntaxProvider + IncrementalValueProvider> methodsWithAttribute = context.SyntaxProvider .CreateSyntaxProvider( predicate: IsMethodWithGeneratesMethodAttribute, transform: GetMethodDeclaration) - .Where(m => m != null) - .Collect(); + .Where(m => m != null) + .Collect(); context.RegisterSourceOutput( methodsWithAttribute.Combine(context.CompilationProvider), @@ -87,70 +49,22 @@ private void Execute( ImmutableArray generatorMethods, Compilation compilation) { - var validMethods = new List(); - - foreach (var generatorMethod in generatorMethods) - { - if (generatorMethod == null) - continue; - - var semanticModel = compilation.GetSemanticModel(generatorMethod.SyntaxTree); - var methodSymbol = semanticModel.GetDeclaredSymbol(generatorMethod) as IMethodSymbol; - - if (methodSymbol == null) - continue; - - if (!methodSymbol.IsStatic) - { - context.ReportDiagnostic(Diagnostic.Create( - GeneratorMethodMustBeStaticError, - generatorMethod.GetLocation(), - generatorMethod.Identifier.Text)); - continue; - } - - var attribute = methodSymbol.GetAttributes() - .FirstOrDefault(a => a.AttributeClass?.ToDisplayString() == "MattSourceGenHelpers.Abstractions.GeneratesMethod"); - - if (attribute == null || attribute.ConstructorArguments.Length == 0) - continue; - - var targetMethodName = attribute.ConstructorArguments[0].Value?.ToString(); - if (string.IsNullOrEmpty(targetMethodName)) - continue; - - var containingType = methodSymbol.ContainingType; - var partialMethodSymbol = containingType.GetMembers(targetMethodName) - .OfType() - .FirstOrDefault(m => m.IsPartialDefinition); - - if (partialMethodSymbol == null) - { - context.ReportDiagnostic(Diagnostic.Create( - MissingPartialMethodError, - generatorMethod.GetLocation(), - targetMethodName, - containingType.Name)); - continue; - } - - validMethods.Add(new GeneratorMethodInfo(generatorMethod, methodSymbol, targetMethodName, partialMethodSymbol, containingType)); - } + List validMethods = GeneratesMethodGenerationTargetCollector.Collect(context, generatorMethods, compilation); // Group by (containing type display string, target method name) - var groups = validMethods + IEnumerable> groups = validMethods .GroupBy(m => (TypeKey: m.ContainingType.ToDisplayString(), m.TargetMethodName)); // Collect all unimplemented partial methods once for the whole compilation - var allPartials = GetAllUnimplementedPartialMethods(compilation); + IReadOnlyList allPartials = GetAllUnimplementedPartialMethods(compilation); - foreach (var group in groups) + foreach (IGrouping<(string TypeKey, string TargetMethodName), GeneratesMethodGenerationTarget> group in groups) { - var methods = group.ToList(); - var first = methods[0]; + List methods = group.ToList(); + GeneratesMethodGenerationTarget first = methods[0]; context.ReportDiagnostic(Diagnostic.Create( - GeneratingMethodInfo, + GeneratesMethodGeneratorDiagnostics.GeneratingMethodInfo, first.Syntax.GetLocation(), first.TargetMethodName, first.ContainingType.Name, @@ -178,11 +92,11 @@ private void Execute( else { // Simple pattern: execute first method and use returned value - var (returnValue, error) = ExecuteSimpleGeneratorMethod(first.Symbol, first.PartialMethod, compilation); + (string? returnValue, string? error) = ExecuteSimpleGeneratorMethod(first.Symbol, first.PartialMethod, compilation); if (error != null) { context.ReportDiagnostic(Diagnostic.Create( - GeneratorMethodExecutionError, + GeneratesMethodGeneratorDiagnostics.GeneratorMethodExecutionError, first.Syntax.GetLocation(), first.Symbol.Name, error)); @@ -202,36 +116,36 @@ private void Execute( private static string GenerateFromSwitchAttributes( SourceProductionContext context, - List methods, + List methods, IMethodSymbol partialMethod, INamedTypeSymbol containingType, IReadOnlyList allPartials, Compilation compilation) { - var switchCaseMethods = methods.Where(m => m.Symbol.GetAttributes() + List switchCaseMethods = methods.Where(m => m.Symbol.GetAttributes() .Any(a => a.AttributeClass?.ToDisplayString() == "MattSourceGenHelpers.Abstractions.SwitchCase")).ToList(); - var switchDefaultMethod = methods.FirstOrDefault(m => m.Symbol.GetAttributes() + GeneratesMethodGenerationTarget? switchDefaultMethod = methods.FirstOrDefault(m => m.Symbol.GetAttributes() .Any(a => a.AttributeClass?.ToDisplayString() == "MattSourceGenHelpers.Abstractions.SwitchDefault")); - var cases = new List<(object key, string value)>(); + List<(object key, string value)> cases = new(); // For each [SwitchCase] method, execute it for each case value - foreach (var switchMethod in switchCaseMethods) + foreach (GeneratesMethodGenerationTarget switchMethod in switchCaseMethods) { - var switchCaseAttrs = switchMethod.Symbol.GetAttributes() + IEnumerable switchCaseAttrs = switchMethod.Symbol.GetAttributes() .Where(a => a.AttributeClass?.ToDisplayString() == "MattSourceGenHelpers.Abstractions.SwitchCase"); - foreach (var attr in switchCaseAttrs) + foreach (AttributeData attr in switchCaseAttrs) { if (attr.ConstructorArguments.Length == 0) continue; - var caseArg = attr.ConstructorArguments[0].Value; + object? caseArg = attr.ConstructorArguments[0].Value; if (caseArg == null) continue; - var (result, error) = ExecuteGeneratorMethodWithArgs(switchMethod.Symbol, allPartials, compilation, new[] { caseArg }); + (string? result, string? error) = ExecuteGeneratorMethodWithArgs(switchMethod.Symbol, allPartials, compilation, new[] { caseArg }); if (error != null) { context.ReportDiagnostic(Diagnostic.Create( - GeneratorMethodExecutionError, + GeneratesMethodGeneratorDiagnostics.GeneratorMethodExecutionError, switchMethod.Syntax.GetLocation(), switchMethod.Symbol.Name, error)); @@ -257,11 +171,11 @@ private static string GenerateFromSwitchAttributes( private static string? ExtractDefaultExpressionFromSwitchDefaultMethod(MethodDeclarationSyntax method) { // Expression body: => - var bodyExpr = method.ExpressionBody?.Expression; + ExpressionSyntax? bodyExpr = method.ExpressionBody?.Expression; if (bodyExpr == null && method.Body != null) { // Block body: { return ; } - var returnStmt = method.Body.Statements.OfType().FirstOrDefault(); + ReturnStatementSyntax? returnStmt = method.Body.Statements.OfType().FirstOrDefault(); bodyExpr = returnStmt?.Expression; } return ExtractInnermostLambdaBody(bodyExpr); @@ -273,23 +187,23 @@ private static string GenerateFromSwitchAttributes( private static string GenerateFromFluent( SourceProductionContext context, - GeneratorMethodInfo methodInfo, + GeneratesMethodGenerationTarget methodInfo, IMethodSymbol partialMethod, INamedTypeSymbol containingType, Compilation compilation) { - var (record, error) = ExecuteFluentGeneratorMethod(methodInfo.Symbol, partialMethod, compilation); + (SwitchBodyData? record, string? error) = ExecuteFluentGeneratorMethod(methodInfo.Symbol, partialMethod, compilation); if (error != null) { context.ReportDiagnostic(Diagnostic.Create( - GeneratorMethodExecutionError, + GeneratesMethodGeneratorDiagnostics.GeneratorMethodExecutionError, methodInfo.Syntax.GetLocation(), methodInfo.Symbol.Name, error)); return string.Empty; } - var cases = record!; + SwitchBodyData cases = record!; // Extract default expression from the RuntimeBody or CompileTimeBody call in the method syntax string? defaultExpression = null; @@ -307,14 +221,14 @@ private static string GenerateFromFluent( { // Walk all InvocationExpressionSyntax nodes; find the one named RuntimeBody or CompileTimeBody // that follows a ForDefaultCase() call. - var invocations = method.DescendantNodes().OfType(); - foreach (var inv in invocations) + IEnumerable invocations = method.DescendantNodes().OfType(); + foreach (InvocationExpressionSyntax inv in invocations) { if (inv.Expression is not MemberAccessExpressionSyntax ma) continue; - var name = ma.Name.Identifier.Text; + string name = ma.Name.Identifier.Text; if (name is not ("RuntimeBody" or "CompileTimeBody")) continue; - var arg = inv.ArgumentList.Arguments.FirstOrDefault()?.Expression; + ExpressionSyntax? arg = inv.ArgumentList.Arguments.FirstOrDefault()?.Expression; return ExtractInnermostLambdaBody(arg); } return null; @@ -353,16 +267,16 @@ private static string GenerateSwitchMethodSource( IReadOnlyList<(object key, string value)> cases, string? defaultExpression) { - var sb = new StringBuilder(); + StringBuilder sb = new(); AppendNamespaceAndTypeHeader(sb, containingType, partialMethod); - var paramName = partialMethod.Parameters.Length > 0 ? partialMethod.Parameters[0].Name : "arg"; + string paramName = partialMethod.Parameters.Length > 0 ? partialMethod.Parameters[0].Name : "arg"; sb.AppendLine($" switch ({paramName})"); sb.AppendLine(" {"); - foreach (var (key, value) in cases) + foreach ((object key, string value) in cases) sb.AppendLine($" case {key}: return {value};"); if (defaultExpression != null) @@ -377,7 +291,7 @@ private static string GenerateSwitchMethodSource( private static void AppendNamespaceAndTypeHeader(StringBuilder sb, INamedTypeSymbol containingType, IMethodSymbol partialMethod) { - var namespaceName = containingType.ContainingNamespace?.IsGlobalNamespace == false + string? namespaceName = containingType.ContainingNamespace?.IsGlobalNamespace == false ? containingType.ContainingNamespace.ToDisplayString() : null; @@ -387,18 +301,18 @@ private static void AppendNamespaceAndTypeHeader(StringBuilder sb, INamedTypeSym sb.AppendLine(); } - var typeKeyword = containingType.TypeKind switch + string typeKeyword = containingType.TypeKind switch { TypeKind.Struct => "struct", TypeKind.Interface => "interface", _ => "class" }; - var typeModifiers = containingType.IsStatic ? "static partial" : "partial"; + string typeModifiers = containingType.IsStatic ? "static partial" : "partial"; sb.AppendLine($"{typeModifiers} {typeKeyword} {containingType.Name}"); sb.AppendLine("{"); - var accessibility = partialMethod.DeclaredAccessibility switch + string accessibility = partialMethod.DeclaredAccessibility switch { Accessibility.Public => "public", Accessibility.Protected => "protected", @@ -408,11 +322,11 @@ private static void AppendNamespaceAndTypeHeader(StringBuilder sb, INamedTypeSym _ => "private" }; - var returnTypeName = partialMethod.ReturnType.ToDisplayString(); - var methodName = partialMethod.Name; - var parameters = string.Join(", ", partialMethod.Parameters.Select(p => $"{p.Type.ToDisplayString()} {p.Name}")); + string returnTypeName = partialMethod.ReturnType.ToDisplayString(); + string methodName = partialMethod.Name; + string parameters = string.Join(", ", partialMethod.Parameters.Select(p => $"{p.Type.ToDisplayString()} {p.Name}")); - var methodModifiers = partialMethod.IsStatic ? "static partial" : "partial"; + string methodModifiers = partialMethod.IsStatic ? "static partial" : "partial"; sb.AppendLine($" {accessibility} {methodModifiers} {returnTypeName} {methodName}({parameters})"); sb.AppendLine(" {"); } @@ -438,13 +352,13 @@ private static string GenerateSimplePartialMethod( IMethodSymbol partialMethod, string? returnValue) { - var sb = new StringBuilder(); + StringBuilder sb = new(); AppendNamespaceAndTypeHeader(sb, containingType, partialMethod); if (!partialMethod.ReturnsVoid) { - var literal = FormatCaseValue(returnValue, partialMethod.ReturnType); + string literal = FormatCaseValue(returnValue, partialMethod.ReturnType); sb.AppendLine($" return {literal};"); } @@ -463,8 +377,8 @@ private static (string? value, string? error) ExecuteSimpleGeneratorMethod( IMethodSymbol partialMethod, Compilation compilation) { - var allPartials = GetAllUnimplementedPartialMethods(compilation); - var (result, error) = ExecuteGeneratorMethodWithArgs(generatorMethod, allPartials, compilation, null); + IReadOnlyList allPartials = GetAllUnimplementedPartialMethods(compilation); + (string? result, string? error) = ExecuteGeneratorMethodWithArgs(generatorMethod, allPartials, compilation, null); return (result, error); } @@ -480,14 +394,14 @@ private static (SwitchBodyData? record, string? error) ExecuteFluentGeneratorMet IMethodSymbol partialMethod, Compilation compilation) { - var allPartials = GetAllUnimplementedPartialMethods(compilation); - var dllCompilation = BuildExecutionCompilation(allPartials, compilation); + IReadOnlyList allPartials = GetAllUnimplementedPartialMethods(compilation); + CSharpCompilation dllCompilation = BuildExecutionCompilation(allPartials, compilation); - using var ms = new MemoryStream(); - var emitResult = dllCompilation.Emit(ms); + using MemoryStream ms = new(); + EmitResult emitResult = dllCompilation.Emit(ms); if (!emitResult.Success) { - var errors = string.Join("; ", emitResult.Diagnostics + string errors = string.Join("; ", emitResult.Diagnostics .Where(d => d.Severity == DiagnosticSeverity.Error) .Select(d => d.GetMessage())); return (null, $"Compilation failed: {errors}"); @@ -500,7 +414,7 @@ private static (SwitchBodyData? record, string? error) ExecuteFluentGeneratorMet loadContext = new AssemblyLoadContext("__GeneratorExec", isCollectible: true); loadContext.Resolving += (ctx, assemblyName) => { - var match = compilation.References + PortableExecutableReference? match = compilation.References .OfType() .FirstOrDefault(r => string.Equals( Path.GetFileNameWithoutExtension(r.FilePath), @@ -509,13 +423,13 @@ private static (SwitchBodyData? record, string? error) ExecuteFluentGeneratorMet return match?.FilePath != null ? ctx.LoadFromAssemblyPath(match.FilePath) : null; }; - var assembly = loadContext.LoadFromStream(ms); + Assembly assembly = loadContext.LoadFromStream(ms); // The Generator and RecordingGeneratorsFactory types are in the Abstractions assembly // (a referenced assembly), not in the compiled user code assembly. // The compilation reference might point to a reference assembly (metadata-only), // so we try to find the actual implementation DLL. - var abstractionsRef = compilation.References + PortableExecutableReference? abstractionsRef = compilation.References .OfType() .FirstOrDefault(r => string.Equals( Path.GetFileNameWithoutExtension(r.FilePath), @@ -526,29 +440,29 @@ private static (SwitchBodyData? record, string? error) ExecuteFluentGeneratorMet return (null, "Could not find MattSourceGenHelpers.Abstractions reference in compilation"); // If path is a reference assembly (in a "ref" subdirectory), resolve the implementation DLL - var abstractionsPath = ResolveImplementationAssemblyPath(abstractionsRef.FilePath); + string abstractionsPath = ResolveImplementationAssemblyPath(abstractionsRef.FilePath); - var abstractionsAssembly = loadContext.LoadFromAssemblyPath(abstractionsPath); + Assembly abstractionsAssembly = loadContext.LoadFromAssemblyPath(abstractionsPath); // Set Generator.CurrentGenerator to a fresh RecordingGeneratorsFactory in the loaded assembly - var generatorStaticType = abstractionsAssembly.GetType("MattSourceGenHelpers.Abstractions.Generator"); - var recordingFactoryType = abstractionsAssembly.GetType("MattSourceGenHelpers.Abstractions.RecordingGeneratorsFactory"); + Type? generatorStaticType = abstractionsAssembly.GetType("MattSourceGenHelpers.Abstractions.Generator"); + Type? recordingFactoryType = abstractionsAssembly.GetType("MattSourceGenHelpers.Abstractions.RecordingGeneratorsFactory"); if (generatorStaticType == null || recordingFactoryType == null) return (null, "Could not find Generator or RecordingGeneratorsFactory types in Abstractions assembly"); - var recordingFactory = Activator.CreateInstance(recordingFactoryType); - var currentGeneratorProp = generatorStaticType.GetProperty("CurrentGenerator", + object? recordingFactory = Activator.CreateInstance(recordingFactoryType); + PropertyInfo? currentGeneratorProp = generatorStaticType.GetProperty("CurrentGenerator", BindingFlags.Public | BindingFlags.Static); currentGeneratorProp?.SetValue(null, recordingFactory); // Execute the generator method - var typeName = generatorMethod.ContainingType.ToDisplayString(); - var type = assembly.GetType(typeName); + string typeName = generatorMethod.ContainingType.ToDisplayString(); + Type? type = assembly.GetType(typeName); if (type == null) return (null, $"Could not find type '{typeName}' in compiled assembly"); - var method = type.GetMethod(generatorMethod.Name, + MethodInfo? method = type.GetMethod(generatorMethod.Name, BindingFlags.Static | BindingFlags.NonPublic | BindingFlags.Public); if (method == null) return (null, $"Could not find method '{generatorMethod.Name}' in type '{typeName}'"); @@ -556,25 +470,25 @@ private static (SwitchBodyData? record, string? error) ExecuteFluentGeneratorMet method.Invoke(null, null); // Read the recorded switch body from the factory - var lastRecordProp = recordingFactoryType.GetProperty("LastRecord"); - var lastRecord = lastRecordProp?.GetValue(recordingFactory); + PropertyInfo? lastRecordProp = recordingFactoryType.GetProperty("LastRecord"); + object? lastRecord = lastRecordProp?.GetValue(recordingFactory); if (lastRecord == null) return (null, "RecordingGeneratorsFactory did not produce a record"); - var recordType = lastRecord.GetType(); - var caseKeysProp = recordType.GetProperty("CaseKeys"); - var caseValuesProp = recordType.GetProperty("CaseValues"); - var hasDefaultProp = recordType.GetProperty("HasDefaultCase"); + Type recordType = lastRecord.GetType(); + PropertyInfo? caseKeysProp = recordType.GetProperty("CaseKeys"); + PropertyInfo? caseValuesProp = recordType.GetProperty("CaseValues"); + PropertyInfo? hasDefaultProp = recordType.GetProperty("HasDefaultCase"); - var caseKeys = (caseKeysProp?.GetValue(lastRecord) as IList) ?? new List(); - var caseValues = (caseValuesProp?.GetValue(lastRecord) as IList) ?? new List(); - var hasDefault = (bool)(hasDefaultProp?.GetValue(lastRecord) ?? false); + IList caseKeys = (caseKeysProp?.GetValue(lastRecord) as IList) ?? new List(); + IList caseValues = (caseValuesProp?.GetValue(lastRecord) as IList) ?? new List(); + bool hasDefault = (bool)(hasDefaultProp?.GetValue(lastRecord) ?? false); - var pairs = new List<(object, string)>(); + List<(object, string)> pairs = new(); for (int i = 0; i < caseKeys.Count; i++) { - var k = caseKeys[i]!; - var v = i < caseValues.Count ? caseValues[i]?.ToString() : null; + object k = caseKeys[i]!; + string? v = i < caseValues.Count ? caseValues[i]?.ToString() : null; pairs.Add((k, FormatCaseValue(v, partialMethod.ReturnType))); } @@ -596,14 +510,14 @@ private static (string? value, string? error) ExecuteGeneratorMethodWithArgs( Compilation compilation, object?[]? args) { - var dllCompilation = BuildExecutionCompilation(allPartialMethods, compilation); + CSharpCompilation dllCompilation = BuildExecutionCompilation(allPartialMethods, compilation); - using var ms = new MemoryStream(); - var emitResult = dllCompilation.Emit(ms); + using MemoryStream ms = new(); + EmitResult emitResult = dllCompilation.Emit(ms); if (!emitResult.Success) { - var errors = string.Join("; ", emitResult.Diagnostics + string errors = string.Join("; ", emitResult.Diagnostics .Where(d => d.Severity == DiagnosticSeverity.Error) .Select(d => d.GetMessage())); return (null, $"Compilation failed: {errors}"); @@ -616,7 +530,7 @@ private static (string? value, string? error) ExecuteGeneratorMethodWithArgs( loadContext = new AssemblyLoadContext("__GeneratorExec", isCollectible: true); loadContext.Resolving += (ctx, assemblyName) => { - var match = compilation.References + PortableExecutableReference? match = compilation.References .OfType() .FirstOrDefault(r => string.Equals( Path.GetFileNameWithoutExtension(r.FilePath), @@ -625,14 +539,14 @@ private static (string? value, string? error) ExecuteGeneratorMethodWithArgs( return match?.FilePath != null ? ctx.LoadFromAssemblyPath(match.FilePath) : null; }; - var assembly = loadContext.LoadFromStream(ms); - var typeName = generatorMethod.ContainingType.ToDisplayString(); - var type = assembly.GetType(typeName); + Assembly assembly = loadContext.LoadFromStream(ms); + string typeName = generatorMethod.ContainingType.ToDisplayString(); + Type? type = assembly.GetType(typeName); if (type == null) return (null, $"Could not find type '{typeName}' in compiled assembly"); - var method = type.GetMethod( + MethodInfo? method = type.GetMethod( generatorMethod.Name, BindingFlags.Static | BindingFlags.NonPublic | BindingFlags.Public); @@ -643,11 +557,11 @@ private static (string? value, string? error) ExecuteGeneratorMethodWithArgs( object?[]? convertedArgs = null; if (args != null && method.GetParameters().Length > 0) { - var paramType = method.GetParameters()[0].ParameterType; + Type paramType = method.GetParameters()[0].ParameterType; convertedArgs = new[] { Convert.ChangeType(args[0], paramType) }; } - var result = method.Invoke(null, convertedArgs); + object? result = method.Invoke(null, convertedArgs); return (result?.ToString(), null); } catch (Exception ex) @@ -668,8 +582,8 @@ private static string ResolveImplementationAssemblyPath(string path) { // Reference assemblies are often placed in a "ref" subdirectory // e.g. .../bin/Debug/net10.0/ref/Foo.dll → try .../bin/Debug/net10.0/Foo.dll - var dir = Path.GetDirectoryName(path); - var parentDir = dir != null ? Path.GetDirectoryName(dir) : null; + string? dir = Path.GetDirectoryName(path); + string? parentDir = dir != null ? Path.GetDirectoryName(dir) : null; if (dir != null && parentDir != null && string.Equals(Path.GetFileName(dir), "ref", StringComparison.OrdinalIgnoreCase)) { @@ -683,15 +597,15 @@ private static string ResolveImplementationAssemblyPath(string path) /// private static IReadOnlyList GetAllUnimplementedPartialMethods(Compilation compilation) { - var result = new List(); - foreach (var syntaxTree in compilation.SyntaxTrees) + List result = new(); + foreach (SyntaxTree syntaxTree in compilation.SyntaxTrees) { - var semanticModel = compilation.GetSemanticModel(syntaxTree); - var partialDecls = syntaxTree.GetRoot().DescendantNodes() + SemanticModel semanticModel = compilation.GetSemanticModel(syntaxTree); + IEnumerable partialDecls = syntaxTree.GetRoot().DescendantNodes() .OfType() .Where(m => m.Modifiers.Any(mod => mod.IsKind(SyntaxKind.PartialKeyword))); - foreach (var decl in partialDecls) + foreach (MethodDeclarationSyntax decl in partialDecls) { if (semanticModel.GetDeclaredSymbol(decl) is IMethodSymbol sym && sym.IsPartialDefinition) result.Add(sym); @@ -704,8 +618,8 @@ private static CSharpCompilation BuildExecutionCompilation( IReadOnlyList allPartialMethods, Compilation compilation) { - var dummySource = BuildDummyImplementation(allPartialMethods); - var parseOptions = compilation.SyntaxTrees.FirstOrDefault()?.Options as CSharpParseOptions + string dummySource = BuildDummyImplementation(allPartialMethods); + CSharpParseOptions parseOptions = compilation.SyntaxTrees.FirstOrDefault()?.Options as CSharpParseOptions ?? CSharpParseOptions.Default; return (CSharpCompilation)compilation .WithOptions(new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary)) @@ -714,9 +628,9 @@ private static CSharpCompilation BuildExecutionCompilation( private static string BuildDummyImplementation(IEnumerable partialMethods) { - var sb = new StringBuilder(); + StringBuilder sb = new(); - var grouped = partialMethods.GroupBy( + IEnumerable> grouped = partialMethods.GroupBy( m => (Namespace: m.ContainingType.ContainingNamespace?.IsGlobalNamespace == false ? m.ContainingType.ContainingNamespace.ToDisplayString() : null, @@ -724,24 +638,24 @@ private static string BuildDummyImplementation(IEnumerable partia IsStatic: m.ContainingType.IsStatic, TypeKind: m.ContainingType.TypeKind)); - foreach (var typeGroup in grouped) + foreach (IGrouping<(string? Namespace, string TypeName, bool IsStatic, TypeKind TypeKind), IMethodSymbol> typeGroup in grouped) { - var namespaceName = typeGroup.Key.Namespace; + string? namespaceName = typeGroup.Key.Namespace; if (namespaceName != null) sb.AppendLine($"namespace {namespaceName} {{"); - var typeKeyword = typeGroup.Key.TypeKind switch + string typeKeyword = typeGroup.Key.TypeKind switch { TypeKind.Struct => "struct", _ => "class" }; - var typeModifiers = typeGroup.Key.IsStatic ? "static partial" : "partial"; + string typeModifiers = typeGroup.Key.IsStatic ? "static partial" : "partial"; sb.AppendLine($"{typeModifiers} {typeKeyword} {typeGroup.Key.TypeName} {{"); - foreach (var partialMethod in typeGroup) + foreach (IMethodSymbol partialMethod in typeGroup) { - var accessibility = partialMethod.DeclaredAccessibility switch + string accessibility = partialMethod.DeclaredAccessibility switch { Accessibility.Public => "public", Accessibility.Protected => "protected", @@ -751,9 +665,9 @@ private static string BuildDummyImplementation(IEnumerable partia _ => "" }; - var staticModifier = partialMethod.IsStatic ? "static " : ""; - var returnType = partialMethod.ReturnType.ToDisplayString(); - var parameters = string.Join(", ", partialMethod.Parameters.Select(p => $"{p.Type.ToDisplayString()} {p.Name}")); + string staticModifier = partialMethod.IsStatic ? "static " : ""; + string returnType = partialMethod.ReturnType.ToDisplayString(); + string parameters = string.Join(", ", partialMethod.Parameters.Select(p => $"{p.Type.ToDisplayString()} {p.Name}")); sb.AppendLine($"{accessibility} {staticModifier}partial {returnType} {partialMethod.Name}({parameters}) {{"); if (!partialMethod.ReturnsVoid) @@ -769,4 +683,4 @@ private static string BuildDummyImplementation(IEnumerable partia return sb.ToString(); } -} \ No newline at end of file +} From 2eb1a2ba85c66328e782ca58006230eb7fb84c81 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 27 Feb 2026 19:08:16 +0000 Subject: [PATCH 3/6] chore: address review feedback and finalize generator refactor Co-authored-by: dex3r <3155725+dex3r@users.noreply.github.com> --- .../GeneratesMethodGenerationTargetCollector.cs | 2 +- MattSourceGenHelpers.Generators/GeneratesMethodGenerator.cs | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/MattSourceGenHelpers.Generators/GeneratesMethodGenerationTargetCollector.cs b/MattSourceGenHelpers.Generators/GeneratesMethodGenerationTargetCollector.cs index 1807c2a..6a37ec2 100644 --- a/MattSourceGenHelpers.Generators/GeneratesMethodGenerationTargetCollector.cs +++ b/MattSourceGenHelpers.Generators/GeneratesMethodGenerationTargetCollector.cs @@ -56,7 +56,7 @@ internal static List Collect( } string? targetMethodName = attribute.ConstructorArguments[0].Value?.ToString(); - if (string.IsNullOrWhiteSpace(targetMethodName)) + if (string.IsNullOrEmpty(targetMethodName)) { continue; } diff --git a/MattSourceGenHelpers.Generators/GeneratesMethodGenerator.cs b/MattSourceGenHelpers.Generators/GeneratesMethodGenerator.cs index 736cb8b..d24baf5 100644 --- a/MattSourceGenHelpers.Generators/GeneratesMethodGenerator.cs +++ b/MattSourceGenHelpers.Generators/GeneratesMethodGenerator.cs @@ -21,8 +21,8 @@ public void Initialize(IncrementalGeneratorInitializationContext context) .CreateSyntaxProvider( predicate: IsMethodWithGeneratesMethodAttribute, transform: GetMethodDeclaration) - .Where(m => m != null) - .Collect(); + .Where(m => m != null) + .Collect(); context.RegisterSourceOutput( methodsWithAttribute.Combine(context.CompilationProvider), From 681caa01af6767b18db2d0a42090c45134158b12 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 27 Feb 2026 19:44:45 +0000 Subject: [PATCH 4/6] refactor: shrink GeneratesMethodGenerator and extract runtime helpers Co-authored-by: dex3r <3155725+dex3r@users.noreply.github.com> --- .../GeneratesMethodExecutionRuntime.cs | 331 +++++++++ .../GeneratesMethodGenerationPipeline.cs | 111 +++ .../GeneratesMethodGenerator.cs | 658 +----------------- .../GeneratesMethodPatternSourceBuilder.cs | 260 +++++++ 4 files changed, 709 insertions(+), 651 deletions(-) create mode 100644 MattSourceGenHelpers.Generators/GeneratesMethodExecutionRuntime.cs create mode 100644 MattSourceGenHelpers.Generators/GeneratesMethodGenerationPipeline.cs create mode 100644 MattSourceGenHelpers.Generators/GeneratesMethodPatternSourceBuilder.cs diff --git a/MattSourceGenHelpers.Generators/GeneratesMethodExecutionRuntime.cs b/MattSourceGenHelpers.Generators/GeneratesMethodExecutionRuntime.cs new file mode 100644 index 0000000..9c7c311 --- /dev/null +++ b/MattSourceGenHelpers.Generators/GeneratesMethodExecutionRuntime.cs @@ -0,0 +1,331 @@ +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using Microsoft.CodeAnalysis.Emit; +using System.Collections; +using System.Reflection; +using System.Runtime.Loader; +using System.Text; + +namespace MattSourceGenHelpers.Generators; + +internal sealed record SwitchBodyData( + IReadOnlyList<(object key, string value)> CasePairs, + bool HasDefaultCase); + +internal static class GeneratesMethodExecutionRuntime +{ + internal static (string? value, string? error) ExecuteSimpleGeneratorMethod( + IMethodSymbol generatorMethod, + IMethodSymbol partialMethod, + Compilation compilation) + { + IReadOnlyList allPartials = GetAllUnimplementedPartialMethods(compilation); + return ExecuteGeneratorMethodWithArgs(generatorMethod, allPartials, compilation, null); + } + + internal static (SwitchBodyData? record, string? error) ExecuteFluentGeneratorMethod( + IMethodSymbol generatorMethod, + IMethodSymbol partialMethod, + Compilation compilation) + { + IReadOnlyList allPartials = GetAllUnimplementedPartialMethods(compilation); + CSharpCompilation runtimeCompilation = BuildExecutionCompilation(allPartials, compilation); + + using MemoryStream stream = new(); + EmitResult emitResult = runtimeCompilation.Emit(stream); + if (!emitResult.Success) + { + string errors = string.Join("; ", emitResult.Diagnostics + .Where(diagnostic => diagnostic.Severity == DiagnosticSeverity.Error) + .Select(diagnostic => diagnostic.GetMessage())); + return (null, $"Compilation failed: {errors}"); + } + + stream.Position = 0; + AssemblyLoadContext? loadContext = null; + try + { + loadContext = new AssemblyLoadContext("__GeneratorExec", isCollectible: true); + loadContext.Resolving += (context, assemblyName) => + { + PortableExecutableReference? match = compilation.References + .OfType() + .FirstOrDefault(reference => string.Equals( + Path.GetFileNameWithoutExtension(reference.FilePath), + assemblyName.Name, + StringComparison.OrdinalIgnoreCase)); + return match?.FilePath != null ? context.LoadFromAssemblyPath(match.FilePath) : null; + }; + + Assembly assembly = loadContext.LoadFromStream(stream); + + PortableExecutableReference? abstractionsReference = compilation.References + .OfType() + .FirstOrDefault(reference => string.Equals( + Path.GetFileNameWithoutExtension(reference.FilePath), + "MattSourceGenHelpers.Abstractions", + StringComparison.OrdinalIgnoreCase)); + + if (abstractionsReference?.FilePath == null) + { + return (null, "Could not find MattSourceGenHelpers.Abstractions reference in compilation"); + } + + string abstractionsAssemblyPath = ResolveImplementationAssemblyPath(abstractionsReference.FilePath); + Assembly abstractionsAssembly = loadContext.LoadFromAssemblyPath(abstractionsAssemblyPath); + + Type? generatorStaticType = abstractionsAssembly.GetType("MattSourceGenHelpers.Abstractions.Generator"); + Type? recordingFactoryType = abstractionsAssembly.GetType("MattSourceGenHelpers.Abstractions.RecordingGeneratorsFactory"); + if (generatorStaticType == null || recordingFactoryType == null) + { + return (null, "Could not find Generator or RecordingGeneratorsFactory types in Abstractions assembly"); + } + + object? recordingFactory = Activator.CreateInstance(recordingFactoryType); + PropertyInfo? currentGeneratorProperty = generatorStaticType.GetProperty("CurrentGenerator", BindingFlags.Public | BindingFlags.Static); + currentGeneratorProperty?.SetValue(null, recordingFactory); + + string typeName = generatorMethod.ContainingType.ToDisplayString(); + Type? generatedType = assembly.GetType(typeName); + if (generatedType == null) + { + return (null, $"Could not find type '{typeName}' in compiled assembly"); + } + + MethodInfo? generatorMethodInfo = generatedType.GetMethod(generatorMethod.Name, BindingFlags.Static | BindingFlags.NonPublic | BindingFlags.Public); + if (generatorMethodInfo == null) + { + return (null, $"Could not find method '{generatorMethod.Name}' in type '{typeName}'"); + } + + generatorMethodInfo.Invoke(null, null); + + PropertyInfo? lastRecordProperty = recordingFactoryType.GetProperty("LastRecord"); + object? lastRecord = lastRecordProperty?.GetValue(recordingFactory); + if (lastRecord == null) + { + return (null, "RecordingGeneratorsFactory did not produce a record"); + } + + return (ExtractSwitchBodyData(lastRecord, partialMethod.ReturnType), null); + } + catch (Exception ex) + { + return (null, $"Error executing generator method '{generatorMethod.Name}': {ex.GetBaseException()}"); + } + finally + { + loadContext?.Unload(); + } + } + + internal static (string? value, string? error) ExecuteGeneratorMethodWithArgs( + IMethodSymbol generatorMethod, + IReadOnlyList allPartialMethods, + Compilation compilation, + object?[]? args) + { + CSharpCompilation runtimeCompilation = BuildExecutionCompilation(allPartialMethods, compilation); + + using MemoryStream stream = new(); + EmitResult emitResult = runtimeCompilation.Emit(stream); + if (!emitResult.Success) + { + string errors = string.Join("; ", emitResult.Diagnostics + .Where(diagnostic => diagnostic.Severity == DiagnosticSeverity.Error) + .Select(diagnostic => diagnostic.GetMessage())); + return (null, $"Compilation failed: {errors}"); + } + + stream.Position = 0; + AssemblyLoadContext? loadContext = null; + try + { + loadContext = new AssemblyLoadContext("__GeneratorExec", isCollectible: true); + loadContext.Resolving += (context, assemblyName) => + { + PortableExecutableReference? match = compilation.References + .OfType() + .FirstOrDefault(reference => string.Equals( + Path.GetFileNameWithoutExtension(reference.FilePath), + assemblyName.Name, + StringComparison.OrdinalIgnoreCase)); + return match?.FilePath != null ? context.LoadFromAssemblyPath(match.FilePath) : null; + }; + + Assembly assembly = loadContext.LoadFromStream(stream); + string typeName = generatorMethod.ContainingType.ToDisplayString(); + Type? generatedType = assembly.GetType(typeName); + if (generatedType == null) + { + return (null, $"Could not find type '{typeName}' in compiled assembly"); + } + + MethodInfo? generatorMethodInfo = generatedType.GetMethod(generatorMethod.Name, BindingFlags.Static | BindingFlags.NonPublic | BindingFlags.Public); + if (generatorMethodInfo == null) + { + return (null, $"Could not find method '{generatorMethod.Name}' in type '{typeName}'"); + } + + object?[]? convertedArgs = ConvertArguments(args, generatorMethodInfo); + object? result = generatorMethodInfo.Invoke(null, convertedArgs); + return (result?.ToString(), null); + } + catch (Exception ex) + { + return (null, $"Error executing generator method '{generatorMethod.Name}': {ex.GetBaseException()}"); + } + finally + { + loadContext?.Unload(); + } + } + + internal static IReadOnlyList GetAllUnimplementedPartialMethods(Compilation compilation) + { + List methods = new(); + foreach (SyntaxTree syntaxTree in compilation.SyntaxTrees) + { + SemanticModel semanticModel = compilation.GetSemanticModel(syntaxTree); + IEnumerable partialMethodDeclarations = syntaxTree.GetRoot().DescendantNodes() + .OfType() + .Where(method => method.Modifiers.Any(modifier => modifier.IsKind(SyntaxKind.PartialKeyword))); + + foreach (MethodDeclarationSyntax declaration in partialMethodDeclarations) + { + if (semanticModel.GetDeclaredSymbol(declaration) is IMethodSymbol symbol && symbol.IsPartialDefinition) + { + methods.Add(symbol); + } + } + } + + return methods; + } + + private static object?[]? ConvertArguments(object?[]? args, MethodInfo methodInfo) + { + if (args == null || methodInfo.GetParameters().Length == 0) + { + return null; + } + + Type parameterType = methodInfo.GetParameters()[0].ParameterType; + return new[] { Convert.ChangeType(args[0], parameterType) }; + } + + private static SwitchBodyData ExtractSwitchBodyData(object lastRecord, ITypeSymbol returnType) + { + Type recordType = lastRecord.GetType(); + PropertyInfo? caseKeysProperty = recordType.GetProperty("CaseKeys"); + PropertyInfo? caseValuesProperty = recordType.GetProperty("CaseValues"); + PropertyInfo? hasDefaultProperty = recordType.GetProperty("HasDefaultCase"); + + IList caseKeys = (caseKeysProperty?.GetValue(lastRecord) as IList) ?? new List(); + IList caseValues = (caseValuesProperty?.GetValue(lastRecord) as IList) ?? new List(); + bool hasDefaultCase = (bool)(hasDefaultProperty?.GetValue(lastRecord) ?? false); + + List<(object key, string value)> pairs = new(); + for (int index = 0; index < caseKeys.Count; index++) + { + object key = caseKeys[index]!; + string? value = index < caseValues.Count ? caseValues[index]?.ToString() : null; + pairs.Add((key, GeneratesMethodPatternSourceBuilder.FormatValueAsCSharpLiteral(value, returnType))); + } + + return new SwitchBodyData(pairs, hasDefaultCase); + } + + private static string ResolveImplementationAssemblyPath(string path) + { + string? directory = Path.GetDirectoryName(path); + string? parentDirectory = directory != null ? Path.GetDirectoryName(directory) : null; + if (directory != null && + parentDirectory != null && + string.Equals(Path.GetFileName(directory), "ref", StringComparison.OrdinalIgnoreCase)) + { + return Path.Combine(parentDirectory, Path.GetFileName(path)); + } + + return path; + } + + private static CSharpCompilation BuildExecutionCompilation( + IReadOnlyList allPartialMethods, + Compilation compilation) + { + string dummySource = BuildDummyImplementation(allPartialMethods); + CSharpParseOptions parseOptions = compilation.SyntaxTrees.FirstOrDefault()?.Options as CSharpParseOptions + ?? CSharpParseOptions.Default; + + return (CSharpCompilation)compilation + .WithOptions(new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary)) + .AddSyntaxTrees(CSharpSyntaxTree.ParseText(dummySource, parseOptions)); + } + + private static string BuildDummyImplementation(IEnumerable partialMethods) + { + StringBuilder builder = new(); + + IEnumerable> groupedMethods = partialMethods.GroupBy( + method => (Namespace: method.ContainingType.ContainingNamespace?.IsGlobalNamespace == false + ? method.ContainingType.ContainingNamespace.ToDisplayString() + : null, + TypeName: method.ContainingType.Name, + IsStatic: method.ContainingType.IsStatic, + TypeKind: method.ContainingType.TypeKind)); + + foreach (IGrouping<(string? Namespace, string TypeName, bool IsStatic, TypeKind TypeKind), IMethodSymbol> typeGroup in groupedMethods) + { + string? namespaceName = typeGroup.Key.Namespace; + if (namespaceName != null) + { + builder.AppendLine($"namespace {namespaceName} {{"); + } + + string typeKeyword = typeGroup.Key.TypeKind switch + { + TypeKind.Struct => "struct", + _ => "class" + }; + + string typeModifiers = typeGroup.Key.IsStatic ? "static partial" : "partial"; + builder.AppendLine($"{typeModifiers} {typeKeyword} {typeGroup.Key.TypeName} {{"); + + foreach (IMethodSymbol partialMethod in typeGroup) + { + string accessibility = partialMethod.DeclaredAccessibility switch + { + Accessibility.Public => "public", + Accessibility.Protected => "protected", + Accessibility.Internal => "internal", + Accessibility.ProtectedOrInternal => "protected internal", + Accessibility.ProtectedAndInternal => "private protected", + _ => string.Empty + }; + + string staticModifier = partialMethod.IsStatic ? "static " : string.Empty; + string returnType = partialMethod.ReturnType.ToDisplayString(); + string parameters = string.Join(", ", partialMethod.Parameters.Select(parameter => $"{parameter.Type.ToDisplayString()} {parameter.Name}")); + + builder.AppendLine($"{accessibility} {staticModifier}partial {returnType} {partialMethod.Name}({parameters}) {{"); + if (!partialMethod.ReturnsVoid) + { + builder.AppendLine("return default!;"); + } + + builder.AppendLine("}"); + } + + builder.AppendLine("}"); + + if (namespaceName != null) + { + builder.AppendLine("}"); + } + } + + return builder.ToString(); + } +} diff --git a/MattSourceGenHelpers.Generators/GeneratesMethodGenerationPipeline.cs b/MattSourceGenHelpers.Generators/GeneratesMethodGenerationPipeline.cs new file mode 100644 index 0000000..bad2de8 --- /dev/null +++ b/MattSourceGenHelpers.Generators/GeneratesMethodGenerationPipeline.cs @@ -0,0 +1,111 @@ +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using System.Collections.Immutable; + +namespace MattSourceGenHelpers.Generators; + +internal static class GeneratesMethodGenerationPipeline +{ + private const string SwitchCaseAttributeTypeName = "MattSourceGenHelpers.Abstractions.SwitchCase"; + private const string SwitchDefaultAttributeTypeName = "MattSourceGenHelpers.Abstractions.SwitchDefault"; + private const string FluentGeneratorTypeName = "MattSourceGenHelpers.Abstractions.IMethodImplementationGenerator"; + + internal static void Execute( + SourceProductionContext context, + ImmutableArray generatorMethods, + Compilation compilation) + { + List validMethods = GeneratesMethodGenerationTargetCollector.Collect(context, generatorMethods, compilation); + IReadOnlyList allPartials = GeneratesMethodExecutionRuntime.GetAllUnimplementedPartialMethods(compilation); + + IEnumerable> groups = validMethods + .GroupBy(method => (TypeKey: method.ContainingType.ToDisplayString(), method.TargetMethodName)); + + foreach (IGrouping<(string TypeKey, string TargetMethodName), GeneratesMethodGenerationTarget> group in groups) + { + List methods = group.ToList(); + GeneratesMethodGenerationTarget firstMethod = methods[0]; + + context.ReportDiagnostic(Diagnostic.Create( + GeneratesMethodGeneratorDiagnostics.GeneratingMethodInfo, + firstMethod.Syntax.GetLocation(), + firstMethod.TargetMethodName, + firstMethod.ContainingType.Name, + string.Join(", ", methods.Select(method => method.Symbol.Name)))); + + string source = GenerateSourceForGroup(context, methods, firstMethod, allPartials, compilation); + + if (!string.IsNullOrEmpty(source)) + { + context.AddSource($"{firstMethod.ContainingType.Name}_{firstMethod.TargetMethodName}.g.cs", source); + } + } + } + + private static string GenerateSourceForGroup( + SourceProductionContext context, + List methods, + GeneratesMethodGenerationTarget firstMethod, + IReadOnlyList allPartials, + Compilation compilation) + { + bool hasSwitchCase = methods.Any(method => HasAttribute(method.Symbol, SwitchCaseAttributeTypeName)); + bool hasSwitchDefault = methods.Any(method => HasAttribute(method.Symbol, SwitchDefaultAttributeTypeName)); + bool isFluentPattern = methods.Count == 1 && methods[0].Symbol.ReturnType.ToDisplayString() == FluentGeneratorTypeName; + + if (hasSwitchCase || hasSwitchDefault) + { + return GeneratesMethodPatternSourceBuilder.GenerateFromSwitchAttributes( + context, + methods, + firstMethod.PartialMethod, + firstMethod.ContainingType, + allPartials, + compilation); + } + + if (isFluentPattern) + { + return GeneratesMethodPatternSourceBuilder.GenerateFromFluent( + context, + methods[0], + firstMethod.PartialMethod, + firstMethod.ContainingType, + compilation); + } + + return GenerateFromSimplePattern(context, firstMethod, compilation); + } + + private static string GenerateFromSimplePattern( + SourceProductionContext context, + GeneratesMethodGenerationTarget firstMethod, + Compilation compilation) + { + (string? returnValue, string? error) = GeneratesMethodExecutionRuntime.ExecuteSimpleGeneratorMethod( + firstMethod.Symbol, + firstMethod.PartialMethod, + compilation); + + if (error != null) + { + context.ReportDiagnostic(Diagnostic.Create( + GeneratesMethodGeneratorDiagnostics.GeneratorMethodExecutionError, + firstMethod.Syntax.GetLocation(), + firstMethod.Symbol.Name, + error)); + return string.Empty; + } + + return GeneratesMethodPatternSourceBuilder.GenerateSimplePartialMethod( + firstMethod.ContainingType, + firstMethod.PartialMethod, + returnValue); + } + + private static bool HasAttribute(IMethodSymbol methodSymbol, string fullAttributeTypeName) + { + return methodSymbol.GetAttributes() + .Any(attribute => attribute.AttributeClass?.ToDisplayString() == fullAttributeTypeName); + } +} diff --git a/MattSourceGenHelpers.Generators/GeneratesMethodGenerator.cs b/MattSourceGenHelpers.Generators/GeneratesMethodGenerator.cs index d24baf5..1a6b30d 100644 --- a/MattSourceGenHelpers.Generators/GeneratesMethodGenerator.cs +++ b/MattSourceGenHelpers.Generators/GeneratesMethodGenerator.cs @@ -1,19 +1,13 @@ using Microsoft.CodeAnalysis; -using Microsoft.CodeAnalysis.CSharp; using Microsoft.CodeAnalysis.CSharp.Syntax; -using Microsoft.CodeAnalysis.Emit; -using System.Collections; using System.Collections.Immutable; -using System.Reflection; -using System.Runtime.Loader; -using System.Text; namespace MattSourceGenHelpers.Generators; #pragma warning disable RS1041 // This generator will only work with dotnet 8 to 10 [Generator] -public class GeneratesMethodGenerator : IIncrementalGenerator +public sealed class GeneratesMethodGenerator : IIncrementalGenerator { public void Initialize(IncrementalGeneratorInitializationContext context) { @@ -21,666 +15,28 @@ public void Initialize(IncrementalGeneratorInitializationContext context) .CreateSyntaxProvider( predicate: IsMethodWithGeneratesMethodAttribute, transform: GetMethodDeclaration) - .Where(m => m != null) + .Where(method => method != null) .Collect(); context.RegisterSourceOutput( methodsWithAttribute.Combine(context.CompilationProvider), - (ctx, data) => Execute(ctx, data.Left, data.Right)); + (productionContext, data) => GeneratesMethodGenerationPipeline.Execute(productionContext, data.Left, data.Right)); } private static bool IsMethodWithGeneratesMethodAttribute(SyntaxNode node, CancellationToken _) { if (node is not MethodDeclarationSyntax method) + { return false; + } return method.AttributeLists - .SelectMany(al => al.Attributes) - .Any(a => a.Name.ToString() is "GeneratesMethod" or "GeneratesMethodAttribute"); + .SelectMany(attributeList => attributeList.Attributes) + .Any(attribute => attribute.Name.ToString() is "GeneratesMethod" or "GeneratesMethodAttribute"); } private static MethodDeclarationSyntax? GetMethodDeclaration(GeneratorSyntaxContext context, CancellationToken _) { return context.Node as MethodDeclarationSyntax; } - - private void Execute( - SourceProductionContext context, - ImmutableArray generatorMethods, - Compilation compilation) - { - List validMethods = GeneratesMethodGenerationTargetCollector.Collect(context, generatorMethods, compilation); - - // Group by (containing type display string, target method name) - IEnumerable> groups = validMethods - .GroupBy(m => (TypeKey: m.ContainingType.ToDisplayString(), m.TargetMethodName)); - - // Collect all unimplemented partial methods once for the whole compilation - IReadOnlyList allPartials = GetAllUnimplementedPartialMethods(compilation); - - foreach (IGrouping<(string TypeKey, string TargetMethodName), GeneratesMethodGenerationTarget> group in groups) - { - List methods = group.ToList(); - GeneratesMethodGenerationTarget first = methods[0]; - - context.ReportDiagnostic(Diagnostic.Create( - GeneratesMethodGeneratorDiagnostics.GeneratingMethodInfo, - first.Syntax.GetLocation(), - first.TargetMethodName, - first.ContainingType.Name, - string.Join(", ", methods.Select(m => m.Symbol.Name)))); - - // Check if this group uses the attribute-based switch pattern - bool hasSwitchCase = methods.Any(m => m.Symbol.GetAttributes() - .Any(a => a.AttributeClass?.ToDisplayString() == "MattSourceGenHelpers.Abstractions.SwitchCase")); - bool hasSwitchDefault = methods.Any(m => m.Symbol.GetAttributes() - .Any(a => a.AttributeClass?.ToDisplayString() == "MattSourceGenHelpers.Abstractions.SwitchDefault")); - - // Check if this is a fluent pattern (returns IMethodImplementationGenerator) - bool isFluentPattern = methods.Count == 1 && - methods[0].Symbol.ReturnType.ToDisplayString() == "MattSourceGenHelpers.Abstractions.IMethodImplementationGenerator"; - - string source; - if (hasSwitchCase || hasSwitchDefault) - { - source = GenerateFromSwitchAttributes(context, methods, first.PartialMethod, first.ContainingType, allPartials, compilation); - } - else if (isFluentPattern) - { - source = GenerateFromFluent(context, methods[0], first.PartialMethod, first.ContainingType, compilation); - } - else - { - // Simple pattern: execute first method and use returned value - (string? returnValue, string? error) = ExecuteSimpleGeneratorMethod(first.Symbol, first.PartialMethod, compilation); - if (error != null) - { - context.ReportDiagnostic(Diagnostic.Create( - GeneratesMethodGeneratorDiagnostics.GeneratorMethodExecutionError, - first.Syntax.GetLocation(), - first.Symbol.Name, - error)); - continue; - } - source = GenerateSimplePartialMethod(first.ContainingType, first.PartialMethod, returnValue); - } - - if (!string.IsNullOrEmpty(source)) - context.AddSource($"{first.ContainingType.Name}_{first.TargetMethodName}.g.cs", source); - } - } - - // ────────────────────────────────────────────────────────────────────────── - // Attribute-based switch pattern ([SwitchCase] / [SwitchDefault]) - // ────────────────────────────────────────────────────────────────────────── - - private static string GenerateFromSwitchAttributes( - SourceProductionContext context, - List methods, - IMethodSymbol partialMethod, - INamedTypeSymbol containingType, - IReadOnlyList allPartials, - Compilation compilation) - { - List switchCaseMethods = methods.Where(m => m.Symbol.GetAttributes() - .Any(a => a.AttributeClass?.ToDisplayString() == "MattSourceGenHelpers.Abstractions.SwitchCase")).ToList(); - GeneratesMethodGenerationTarget? switchDefaultMethod = methods.FirstOrDefault(m => m.Symbol.GetAttributes() - .Any(a => a.AttributeClass?.ToDisplayString() == "MattSourceGenHelpers.Abstractions.SwitchDefault")); - - List<(object key, string value)> cases = new(); - - // For each [SwitchCase] method, execute it for each case value - foreach (GeneratesMethodGenerationTarget switchMethod in switchCaseMethods) - { - IEnumerable switchCaseAttrs = switchMethod.Symbol.GetAttributes() - .Where(a => a.AttributeClass?.ToDisplayString() == "MattSourceGenHelpers.Abstractions.SwitchCase"); - - foreach (AttributeData attr in switchCaseAttrs) - { - if (attr.ConstructorArguments.Length == 0) continue; - object? caseArg = attr.ConstructorArguments[0].Value; - if (caseArg == null) continue; - - (string? result, string? error) = ExecuteGeneratorMethodWithArgs(switchMethod.Symbol, allPartials, compilation, new[] { caseArg }); - if (error != null) - { - context.ReportDiagnostic(Diagnostic.Create( - GeneratesMethodGeneratorDiagnostics.GeneratorMethodExecutionError, - switchMethod.Syntax.GetLocation(), - switchMethod.Symbol.Name, - error)); - continue; - } - - cases.Add((caseArg, FormatCaseValue(result, partialMethod.ReturnType))); - } - } - - // Extract default expression from the [SwitchDefault] method's syntax - string? defaultExpression = null; - if (switchDefaultMethod != null) - defaultExpression = ExtractDefaultExpressionFromSwitchDefaultMethod(switchDefaultMethod.Syntax); - - return GenerateSwitchMethodSource(containingType, partialMethod, cases, defaultExpression); - } - - /// - /// Extracts the body expression from a [SwitchDefault] method whose body is a lambda. - /// e.g. "decimalNumber => SlowMath.CalculatePiDecimal(decimalNumber)" → "SlowMath.CalculatePiDecimal(decimalNumber)" - /// - private static string? ExtractDefaultExpressionFromSwitchDefaultMethod(MethodDeclarationSyntax method) - { - // Expression body: => - ExpressionSyntax? bodyExpr = method.ExpressionBody?.Expression; - if (bodyExpr == null && method.Body != null) - { - // Block body: { return ; } - ReturnStatementSyntax? returnStmt = method.Body.Statements.OfType().FirstOrDefault(); - bodyExpr = returnStmt?.Expression; - } - return ExtractInnermostLambdaBody(bodyExpr); - } - - // ────────────────────────────────────────────────────────────────────────── - // Fluent pattern (returns IMethodImplementationGenerator) - // ────────────────────────────────────────────────────────────────────────── - - private static string GenerateFromFluent( - SourceProductionContext context, - GeneratesMethodGenerationTarget methodInfo, - IMethodSymbol partialMethod, - INamedTypeSymbol containingType, - Compilation compilation) - { - (SwitchBodyData? record, string? error) = ExecuteFluentGeneratorMethod(methodInfo.Symbol, partialMethod, compilation); - if (error != null) - { - context.ReportDiagnostic(Diagnostic.Create( - GeneratesMethodGeneratorDiagnostics.GeneratorMethodExecutionError, - methodInfo.Syntax.GetLocation(), - methodInfo.Symbol.Name, - error)); - return string.Empty; - } - - SwitchBodyData cases = record!; - - // Extract default expression from the RuntimeBody or CompileTimeBody call in the method syntax - string? defaultExpression = null; - if (cases.HasDefaultCase) - defaultExpression = ExtractDefaultExpressionFromFluentMethod(methodInfo.Syntax); - - return GenerateSwitchMethodSource(containingType, partialMethod, cases.CasePairs, defaultExpression); - } - - /// - /// Finds RuntimeBody(...) or CompileTimeBody(...) in the ForDefaultCase() chain - /// and extracts the innermost lambda body expression string. - /// - private static string? ExtractDefaultExpressionFromFluentMethod(MethodDeclarationSyntax method) - { - // Walk all InvocationExpressionSyntax nodes; find the one named RuntimeBody or CompileTimeBody - // that follows a ForDefaultCase() call. - IEnumerable invocations = method.DescendantNodes().OfType(); - foreach (InvocationExpressionSyntax inv in invocations) - { - if (inv.Expression is not MemberAccessExpressionSyntax ma) continue; - string name = ma.Name.Identifier.Text; - if (name is not ("RuntimeBody" or "CompileTimeBody")) continue; - - ExpressionSyntax? arg = inv.ArgumentList.Arguments.FirstOrDefault()?.Expression; - return ExtractInnermostLambdaBody(arg); - } - return null; - } - - /// - /// Recursively unwraps nested lambdas and returns the body of the innermost one. - /// e.g. "x => () => Foo(x)" → "Foo(x)" - /// "x => Foo(x)" → "Foo(x)" - /// - private static string? ExtractInnermostLambdaBody(ExpressionSyntax? expr) - { - while (true) - { - switch (expr) - { - case SimpleLambdaExpressionSyntax simple: - expr = simple.Body as ExpressionSyntax; - break; - case ParenthesizedLambdaExpressionSyntax paren: - expr = paren.Body as ExpressionSyntax; - break; - default: - return expr?.ToString(); - } - } - } - - // ────────────────────────────────────────────────────────────────────────── - // Source generation helpers - // ────────────────────────────────────────────────────────────────────────── - - private static string GenerateSwitchMethodSource( - INamedTypeSymbol containingType, - IMethodSymbol partialMethod, - IReadOnlyList<(object key, string value)> cases, - string? defaultExpression) - { - StringBuilder sb = new(); - - AppendNamespaceAndTypeHeader(sb, containingType, partialMethod); - - string paramName = partialMethod.Parameters.Length > 0 ? partialMethod.Parameters[0].Name : "arg"; - - sb.AppendLine($" switch ({paramName})"); - sb.AppendLine(" {"); - - foreach ((object key, string value) in cases) - sb.AppendLine($" case {key}: return {value};"); - - if (defaultExpression != null) - sb.AppendLine($" default: return {defaultExpression};"); - - sb.AppendLine(" }"); - sb.AppendLine(" }"); - sb.AppendLine("}"); - - return sb.ToString(); - } - - private static void AppendNamespaceAndTypeHeader(StringBuilder sb, INamedTypeSymbol containingType, IMethodSymbol partialMethod) - { - string? namespaceName = containingType.ContainingNamespace?.IsGlobalNamespace == false - ? containingType.ContainingNamespace.ToDisplayString() - : null; - - if (namespaceName != null) - { - sb.AppendLine($"namespace {namespaceName};"); - sb.AppendLine(); - } - - string typeKeyword = containingType.TypeKind switch - { - TypeKind.Struct => "struct", - TypeKind.Interface => "interface", - _ => "class" - }; - - string typeModifiers = containingType.IsStatic ? "static partial" : "partial"; - sb.AppendLine($"{typeModifiers} {typeKeyword} {containingType.Name}"); - sb.AppendLine("{"); - - string accessibility = partialMethod.DeclaredAccessibility switch - { - Accessibility.Public => "public", - Accessibility.Protected => "protected", - Accessibility.Internal => "internal", - Accessibility.ProtectedOrInternal => "protected internal", - Accessibility.ProtectedAndInternal => "private protected", - _ => "private" - }; - - string returnTypeName = partialMethod.ReturnType.ToDisplayString(); - string methodName = partialMethod.Name; - string parameters = string.Join(", ", partialMethod.Parameters.Select(p => $"{p.Type.ToDisplayString()} {p.Name}")); - - string methodModifiers = partialMethod.IsStatic ? "static partial" : "partial"; - sb.AppendLine($" {accessibility} {methodModifiers} {returnTypeName} {methodName}({parameters})"); - sb.AppendLine(" {"); - } - - private static string FormatCaseValue(string? value, ITypeSymbol returnType) - { - if (value == null) return "default"; - return returnType.SpecialType switch - { - SpecialType.System_String => SyntaxFactory.Literal(value).Text, - SpecialType.System_Char when value.Length == 1 => SyntaxFactory.Literal(value[0]).Text, - SpecialType.System_Boolean => value.ToLowerInvariant(), - _ => value - }; - } - - // ────────────────────────────────────────────────────────────────────────── - // Simple pattern (existing behaviour) - // ────────────────────────────────────────────────────────────────────────── - - private static string GenerateSimplePartialMethod( - INamedTypeSymbol containingType, - IMethodSymbol partialMethod, - string? returnValue) - { - StringBuilder sb = new(); - - AppendNamespaceAndTypeHeader(sb, containingType, partialMethod); - - if (!partialMethod.ReturnsVoid) - { - string literal = FormatCaseValue(returnValue, partialMethod.ReturnType); - sb.AppendLine($" return {literal};"); - } - - sb.AppendLine(" }"); - sb.AppendLine("}"); - - return sb.ToString(); - } - - // ────────────────────────────────────────────────────────────────────────── - // Compilation / execution helpers - // ────────────────────────────────────────────────────────────────────────── - - private static (string? value, string? error) ExecuteSimpleGeneratorMethod( - IMethodSymbol generatorMethod, - IMethodSymbol partialMethod, - Compilation compilation) - { - IReadOnlyList allPartials = GetAllUnimplementedPartialMethods(compilation); - (string? result, string? error) = ExecuteGeneratorMethodWithArgs(generatorMethod, allPartials, compilation, null); - return (result, error); - } - - /// - /// Holds the recorded switch body data extracted via reflection from the loaded assembly. - /// - private record SwitchBodyData( - IReadOnlyList<(object key, string value)> CasePairs, - bool HasDefaultCase); - - private static (SwitchBodyData? record, string? error) ExecuteFluentGeneratorMethod( - IMethodSymbol generatorMethod, - IMethodSymbol partialMethod, - Compilation compilation) - { - IReadOnlyList allPartials = GetAllUnimplementedPartialMethods(compilation); - CSharpCompilation dllCompilation = BuildExecutionCompilation(allPartials, compilation); - - using MemoryStream ms = new(); - EmitResult emitResult = dllCompilation.Emit(ms); - if (!emitResult.Success) - { - string errors = string.Join("; ", emitResult.Diagnostics - .Where(d => d.Severity == DiagnosticSeverity.Error) - .Select(d => d.GetMessage())); - return (null, $"Compilation failed: {errors}"); - } - - ms.Position = 0; - AssemblyLoadContext? loadContext = null; - try - { - loadContext = new AssemblyLoadContext("__GeneratorExec", isCollectible: true); - loadContext.Resolving += (ctx, assemblyName) => - { - PortableExecutableReference? match = compilation.References - .OfType() - .FirstOrDefault(r => string.Equals( - Path.GetFileNameWithoutExtension(r.FilePath), - assemblyName.Name, - StringComparison.OrdinalIgnoreCase)); - return match?.FilePath != null ? ctx.LoadFromAssemblyPath(match.FilePath) : null; - }; - - Assembly assembly = loadContext.LoadFromStream(ms); - - // The Generator and RecordingGeneratorsFactory types are in the Abstractions assembly - // (a referenced assembly), not in the compiled user code assembly. - // The compilation reference might point to a reference assembly (metadata-only), - // so we try to find the actual implementation DLL. - PortableExecutableReference? abstractionsRef = compilation.References - .OfType() - .FirstOrDefault(r => string.Equals( - Path.GetFileNameWithoutExtension(r.FilePath), - "MattSourceGenHelpers.Abstractions", - StringComparison.OrdinalIgnoreCase)); - - if (abstractionsRef?.FilePath == null) - return (null, "Could not find MattSourceGenHelpers.Abstractions reference in compilation"); - - // If path is a reference assembly (in a "ref" subdirectory), resolve the implementation DLL - string abstractionsPath = ResolveImplementationAssemblyPath(abstractionsRef.FilePath); - - Assembly abstractionsAssembly = loadContext.LoadFromAssemblyPath(abstractionsPath); - - // Set Generator.CurrentGenerator to a fresh RecordingGeneratorsFactory in the loaded assembly - Type? generatorStaticType = abstractionsAssembly.GetType("MattSourceGenHelpers.Abstractions.Generator"); - Type? recordingFactoryType = abstractionsAssembly.GetType("MattSourceGenHelpers.Abstractions.RecordingGeneratorsFactory"); - - if (generatorStaticType == null || recordingFactoryType == null) - return (null, "Could not find Generator or RecordingGeneratorsFactory types in Abstractions assembly"); - - object? recordingFactory = Activator.CreateInstance(recordingFactoryType); - PropertyInfo? currentGeneratorProp = generatorStaticType.GetProperty("CurrentGenerator", - BindingFlags.Public | BindingFlags.Static); - currentGeneratorProp?.SetValue(null, recordingFactory); - - // Execute the generator method - string typeName = generatorMethod.ContainingType.ToDisplayString(); - Type? type = assembly.GetType(typeName); - if (type == null) - return (null, $"Could not find type '{typeName}' in compiled assembly"); - - MethodInfo? method = type.GetMethod(generatorMethod.Name, - BindingFlags.Static | BindingFlags.NonPublic | BindingFlags.Public); - if (method == null) - return (null, $"Could not find method '{generatorMethod.Name}' in type '{typeName}'"); - - method.Invoke(null, null); - - // Read the recorded switch body from the factory - PropertyInfo? lastRecordProp = recordingFactoryType.GetProperty("LastRecord"); - object? lastRecord = lastRecordProp?.GetValue(recordingFactory); - if (lastRecord == null) - return (null, "RecordingGeneratorsFactory did not produce a record"); - - Type recordType = lastRecord.GetType(); - PropertyInfo? caseKeysProp = recordType.GetProperty("CaseKeys"); - PropertyInfo? caseValuesProp = recordType.GetProperty("CaseValues"); - PropertyInfo? hasDefaultProp = recordType.GetProperty("HasDefaultCase"); - - IList caseKeys = (caseKeysProp?.GetValue(lastRecord) as IList) ?? new List(); - IList caseValues = (caseValuesProp?.GetValue(lastRecord) as IList) ?? new List(); - bool hasDefault = (bool)(hasDefaultProp?.GetValue(lastRecord) ?? false); - - List<(object, string)> pairs = new(); - for (int i = 0; i < caseKeys.Count; i++) - { - object k = caseKeys[i]!; - string? v = i < caseValues.Count ? caseValues[i]?.ToString() : null; - pairs.Add((k, FormatCaseValue(v, partialMethod.ReturnType))); - } - - return (new SwitchBodyData(pairs, hasDefault), null); - } - catch (Exception ex) - { - return (null, $"Error executing generator method '{generatorMethod.Name}': {ex.GetBaseException()}"); - } - finally - { - loadContext?.Unload(); - } - } - - private static (string? value, string? error) ExecuteGeneratorMethodWithArgs( - IMethodSymbol generatorMethod, - IReadOnlyList allPartialMethods, - Compilation compilation, - object?[]? args) - { - CSharpCompilation dllCompilation = BuildExecutionCompilation(allPartialMethods, compilation); - - using MemoryStream ms = new(); - EmitResult emitResult = dllCompilation.Emit(ms); - - if (!emitResult.Success) - { - string errors = string.Join("; ", emitResult.Diagnostics - .Where(d => d.Severity == DiagnosticSeverity.Error) - .Select(d => d.GetMessage())); - return (null, $"Compilation failed: {errors}"); - } - - ms.Position = 0; - AssemblyLoadContext? loadContext = null; - try - { - loadContext = new AssemblyLoadContext("__GeneratorExec", isCollectible: true); - loadContext.Resolving += (ctx, assemblyName) => - { - PortableExecutableReference? match = compilation.References - .OfType() - .FirstOrDefault(r => string.Equals( - Path.GetFileNameWithoutExtension(r.FilePath), - assemblyName.Name, - StringComparison.OrdinalIgnoreCase)); - return match?.FilePath != null ? ctx.LoadFromAssemblyPath(match.FilePath) : null; - }; - - Assembly assembly = loadContext.LoadFromStream(ms); - string typeName = generatorMethod.ContainingType.ToDisplayString(); - Type? type = assembly.GetType(typeName); - - if (type == null) - return (null, $"Could not find type '{typeName}' in compiled assembly"); - - MethodInfo? method = type.GetMethod( - generatorMethod.Name, - BindingFlags.Static | BindingFlags.NonPublic | BindingFlags.Public); - - if (method == null) - return (null, $"Could not find method '{generatorMethod.Name}' in type '{typeName}'"); - - // Convert args to match the method's parameter types - object?[]? convertedArgs = null; - if (args != null && method.GetParameters().Length > 0) - { - Type paramType = method.GetParameters()[0].ParameterType; - convertedArgs = new[] { Convert.ChangeType(args[0], paramType) }; - } - - object? result = method.Invoke(null, convertedArgs); - return (result?.ToString(), null); - } - catch (Exception ex) - { - return (null, $"Error executing generator method '{generatorMethod.Name}': {ex.GetBaseException()}"); - } - finally - { - loadContext?.Unload(); - } - } - - /// - /// If the given path is a reference assembly (located in a "ref" subdirectory), - /// returns the path to the corresponding implementation assembly. - /// - private static string ResolveImplementationAssemblyPath(string path) - { - // Reference assemblies are often placed in a "ref" subdirectory - // e.g. .../bin/Debug/net10.0/ref/Foo.dll → try .../bin/Debug/net10.0/Foo.dll - string? dir = Path.GetDirectoryName(path); - string? parentDir = dir != null ? Path.GetDirectoryName(dir) : null; - if (dir != null && parentDir != null && - string.Equals(Path.GetFileName(dir), "ref", StringComparison.OrdinalIgnoreCase)) - { - return Path.Combine(parentDir, Path.GetFileName(path)); - } - return path; - } - - /// - /// Collects all partial method definitions (declarations without implementations) from the compilation. - /// - private static IReadOnlyList GetAllUnimplementedPartialMethods(Compilation compilation) - { - List result = new(); - foreach (SyntaxTree syntaxTree in compilation.SyntaxTrees) - { - SemanticModel semanticModel = compilation.GetSemanticModel(syntaxTree); - IEnumerable partialDecls = syntaxTree.GetRoot().DescendantNodes() - .OfType() - .Where(m => m.Modifiers.Any(mod => mod.IsKind(SyntaxKind.PartialKeyword))); - - foreach (MethodDeclarationSyntax decl in partialDecls) - { - if (semanticModel.GetDeclaredSymbol(decl) is IMethodSymbol sym && sym.IsPartialDefinition) - result.Add(sym); - } - } - return result; - } - - private static CSharpCompilation BuildExecutionCompilation( - IReadOnlyList allPartialMethods, - Compilation compilation) - { - string dummySource = BuildDummyImplementation(allPartialMethods); - CSharpParseOptions parseOptions = compilation.SyntaxTrees.FirstOrDefault()?.Options as CSharpParseOptions - ?? CSharpParseOptions.Default; - return (CSharpCompilation)compilation - .WithOptions(new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary)) - .AddSyntaxTrees(CSharpSyntaxTree.ParseText(dummySource, parseOptions)); - } - - private static string BuildDummyImplementation(IEnumerable partialMethods) - { - StringBuilder sb = new(); - - IEnumerable> grouped = partialMethods.GroupBy( - m => (Namespace: m.ContainingType.ContainingNamespace?.IsGlobalNamespace == false - ? m.ContainingType.ContainingNamespace.ToDisplayString() - : null, - TypeName: m.ContainingType.Name, - IsStatic: m.ContainingType.IsStatic, - TypeKind: m.ContainingType.TypeKind)); - - foreach (IGrouping<(string? Namespace, string TypeName, bool IsStatic, TypeKind TypeKind), IMethodSymbol> typeGroup in grouped) - { - string? namespaceName = typeGroup.Key.Namespace; - if (namespaceName != null) - sb.AppendLine($"namespace {namespaceName} {{"); - - string typeKeyword = typeGroup.Key.TypeKind switch - { - TypeKind.Struct => "struct", - _ => "class" - }; - - string typeModifiers = typeGroup.Key.IsStatic ? "static partial" : "partial"; - sb.AppendLine($"{typeModifiers} {typeKeyword} {typeGroup.Key.TypeName} {{"); - - foreach (IMethodSymbol partialMethod in typeGroup) - { - string accessibility = partialMethod.DeclaredAccessibility switch - { - Accessibility.Public => "public", - Accessibility.Protected => "protected", - Accessibility.Internal => "internal", - Accessibility.ProtectedOrInternal => "protected internal", - Accessibility.ProtectedAndInternal => "private protected", - _ => "" - }; - - string staticModifier = partialMethod.IsStatic ? "static " : ""; - string returnType = partialMethod.ReturnType.ToDisplayString(); - string parameters = string.Join(", ", partialMethod.Parameters.Select(p => $"{p.Type.ToDisplayString()} {p.Name}")); - - sb.AppendLine($"{accessibility} {staticModifier}partial {returnType} {partialMethod.Name}({parameters}) {{"); - if (!partialMethod.ReturnsVoid) - sb.AppendLine("return default!;"); - sb.AppendLine("}"); - } - - sb.AppendLine("}"); - - if (namespaceName != null) - sb.AppendLine("}"); - } - - return sb.ToString(); - } } diff --git a/MattSourceGenHelpers.Generators/GeneratesMethodPatternSourceBuilder.cs b/MattSourceGenHelpers.Generators/GeneratesMethodPatternSourceBuilder.cs new file mode 100644 index 0000000..3d9614e --- /dev/null +++ b/MattSourceGenHelpers.Generators/GeneratesMethodPatternSourceBuilder.cs @@ -0,0 +1,260 @@ +using Microsoft.CodeAnalysis; +using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.CSharp.Syntax; +using System.Text; + +namespace MattSourceGenHelpers.Generators; + +internal static class GeneratesMethodPatternSourceBuilder +{ + private const string SwitchCaseAttributeTypeName = "MattSourceGenHelpers.Abstractions.SwitchCase"; + private const string SwitchDefaultAttributeTypeName = "MattSourceGenHelpers.Abstractions.SwitchDefault"; + + internal static string GenerateFromSwitchAttributes( + SourceProductionContext context, + List methods, + IMethodSymbol partialMethod, + INamedTypeSymbol containingType, + IReadOnlyList allPartials, + Compilation compilation) + { + List switchCaseMethods = methods + .Where(method => method.Symbol.GetAttributes().Any(attribute => attribute.AttributeClass?.ToDisplayString() == SwitchCaseAttributeTypeName)) + .ToList(); + GeneratesMethodGenerationTarget? switchDefaultMethod = methods + .FirstOrDefault(method => method.Symbol.GetAttributes().Any(attribute => attribute.AttributeClass?.ToDisplayString() == SwitchDefaultAttributeTypeName)); + + List<(object key, string value)> cases = new(); + foreach (GeneratesMethodGenerationTarget switchMethod in switchCaseMethods) + { + IEnumerable switchCaseAttributes = switchMethod.Symbol.GetAttributes() + .Where(attribute => attribute.AttributeClass?.ToDisplayString() == SwitchCaseAttributeTypeName); + + foreach (AttributeData switchCaseAttribute in switchCaseAttributes) + { + if (switchCaseAttribute.ConstructorArguments.Length == 0) + { + continue; + } + + object? caseArgument = switchCaseAttribute.ConstructorArguments[0].Value; + if (caseArgument is null) + { + continue; + } + + (string? result, string? error) = GeneratesMethodExecutionRuntime.ExecuteGeneratorMethodWithArgs( + switchMethod.Symbol, + allPartials, + compilation, + new[] { caseArgument }); + + if (error != null) + { + context.ReportDiagnostic(Diagnostic.Create( + GeneratesMethodGeneratorDiagnostics.GeneratorMethodExecutionError, + switchMethod.Syntax.GetLocation(), + switchMethod.Symbol.Name, + error)); + continue; + } + + cases.Add((caseArgument, FormatValueAsCSharpLiteral(result, partialMethod.ReturnType))); + } + } + + string? defaultExpression = switchDefaultMethod is not null + ? ExtractDefaultExpressionFromSwitchDefaultMethod(switchDefaultMethod.Syntax) + : null; + + return GenerateSwitchMethodSource(containingType, partialMethod, cases, defaultExpression); + } + + internal static string GenerateFromFluent( + SourceProductionContext context, + GeneratesMethodGenerationTarget methodInfo, + IMethodSymbol partialMethod, + INamedTypeSymbol containingType, + Compilation compilation) + { + (SwitchBodyData? record, string? error) = GeneratesMethodExecutionRuntime.ExecuteFluentGeneratorMethod( + methodInfo.Symbol, + partialMethod, + compilation); + + if (error != null) + { + context.ReportDiagnostic(Diagnostic.Create( + GeneratesMethodGeneratorDiagnostics.GeneratorMethodExecutionError, + methodInfo.Syntax.GetLocation(), + methodInfo.Symbol.Name, + error)); + return string.Empty; + } + + SwitchBodyData switchBodyData = record!; + string? defaultExpression = switchBodyData.HasDefaultCase + ? ExtractDefaultExpressionFromFluentMethod(methodInfo.Syntax) + : null; + + return GenerateSwitchMethodSource(containingType, partialMethod, switchBodyData.CasePairs, defaultExpression); + } + + internal static string GenerateSimplePartialMethod( + INamedTypeSymbol containingType, + IMethodSymbol partialMethod, + string? returnValue) + { + StringBuilder builder = new(); + AppendNamespaceAndTypeHeader(builder, containingType, partialMethod); + + if (!partialMethod.ReturnsVoid) + { + string literal = FormatValueAsCSharpLiteral(returnValue, partialMethod.ReturnType); + builder.AppendLine($" return {literal};"); + } + + builder.AppendLine(" }"); + builder.AppendLine("}"); + return builder.ToString(); + } + + private static string? ExtractDefaultExpressionFromSwitchDefaultMethod(MethodDeclarationSyntax method) + { + ExpressionSyntax? bodyExpression = method.ExpressionBody?.Expression; + if (bodyExpression == null && method.Body != null) + { + ReturnStatementSyntax? returnStatement = method.Body.Statements.OfType().FirstOrDefault(); + bodyExpression = returnStatement?.Expression; + } + + return ExtractInnermostLambdaBody(bodyExpression); + } + + private static string? ExtractDefaultExpressionFromFluentMethod(MethodDeclarationSyntax method) + { + IEnumerable invocations = method.DescendantNodes().OfType(); + foreach (InvocationExpressionSyntax invocation in invocations) + { + if (invocation.Expression is not MemberAccessExpressionSyntax memberAccessExpression) + { + continue; + } + + string methodName = memberAccessExpression.Name.Identifier.Text; + if (methodName is not ("RuntimeBody" or "CompileTimeBody")) + { + continue; + } + + ExpressionSyntax? argumentExpression = invocation.ArgumentList.Arguments.FirstOrDefault()?.Expression; + return ExtractInnermostLambdaBody(argumentExpression); + } + + return null; + } + + private static string? ExtractInnermostLambdaBody(ExpressionSyntax? expression) + { + while (true) + { + switch (expression) + { + case SimpleLambdaExpressionSyntax simpleLambdaExpression: + expression = simpleLambdaExpression.Body as ExpressionSyntax; + break; + case ParenthesizedLambdaExpressionSyntax parenthesizedLambdaExpression: + expression = parenthesizedLambdaExpression.Body as ExpressionSyntax; + break; + default: + return expression?.ToString(); + } + } + } + + private static string GenerateSwitchMethodSource( + INamedTypeSymbol containingType, + IMethodSymbol partialMethod, + IReadOnlyList<(object key, string value)> cases, + string? defaultExpression) + { + StringBuilder builder = new(); + AppendNamespaceAndTypeHeader(builder, containingType, partialMethod); + + string switchParameterName = partialMethod.Parameters.Length > 0 ? partialMethod.Parameters[0].Name : "arg"; + builder.AppendLine($" switch ({switchParameterName})"); + builder.AppendLine(" {"); + + foreach ((object key, string value) in cases) + { + builder.AppendLine($" case {key}: return {value};"); + } + + if (defaultExpression != null) + { + builder.AppendLine($" default: return {defaultExpression};"); + } + + builder.AppendLine(" }"); + builder.AppendLine(" }"); + builder.AppendLine("}"); + return builder.ToString(); + } + + private static void AppendNamespaceAndTypeHeader(StringBuilder builder, INamedTypeSymbol containingType, IMethodSymbol partialMethod) + { + string? namespaceName = containingType.ContainingNamespace?.IsGlobalNamespace == false + ? containingType.ContainingNamespace.ToDisplayString() + : null; + if (namespaceName != null) + { + builder.AppendLine($"namespace {namespaceName};"); + builder.AppendLine(); + } + + string typeKeyword = containingType.TypeKind switch + { + TypeKind.Struct => "struct", + TypeKind.Interface => "interface", + _ => "class" + }; + + string typeModifiers = containingType.IsStatic ? "static partial" : "partial"; + builder.AppendLine($"{typeModifiers} {typeKeyword} {containingType.Name}"); + builder.AppendLine("{"); + + string accessibility = partialMethod.DeclaredAccessibility switch + { + Accessibility.Public => "public", + Accessibility.Protected => "protected", + Accessibility.Internal => "internal", + Accessibility.ProtectedOrInternal => "protected internal", + Accessibility.ProtectedAndInternal => "private protected", + _ => "private" + }; + + string returnTypeName = partialMethod.ReturnType.ToDisplayString(); + string methodName = partialMethod.Name; + string parameters = string.Join(", ", partialMethod.Parameters.Select(parameter => $"{parameter.Type.ToDisplayString()} {parameter.Name}")); + string methodModifiers = partialMethod.IsStatic ? "static partial" : "partial"; + + builder.AppendLine($" {accessibility} {methodModifiers} {returnTypeName} {methodName}({parameters})"); + builder.AppendLine(" {"); + } + + internal static string FormatValueAsCSharpLiteral(string? value, ITypeSymbol returnType) + { + if (value == null) + { + return "default"; + } + + return returnType.SpecialType switch + { + SpecialType.System_String => SyntaxFactory.Literal(value).Text, + SpecialType.System_Char when value.Length == 1 => SyntaxFactory.Literal(value[0]).Text, + SpecialType.System_Boolean => value.ToLowerInvariant(), + _ => value + }; + } +} From e89e0c441967cf7d7e8187cf6788b73732d84b45 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Fri, 27 Feb 2026 19:47:43 +0000 Subject: [PATCH 5/6] refactor: split generator orchestration and pattern/runtime helpers Co-authored-by: dex3r <3155725+dex3r@users.noreply.github.com> --- .../GeneratesMethodExecutionRuntime.cs | 24 +++++++++---------- .../GeneratesMethodPatternSourceBuilder.cs | 6 ++--- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/MattSourceGenHelpers.Generators/GeneratesMethodExecutionRuntime.cs b/MattSourceGenHelpers.Generators/GeneratesMethodExecutionRuntime.cs index 9c7c311..a69fcf1 100644 --- a/MattSourceGenHelpers.Generators/GeneratesMethodExecutionRuntime.cs +++ b/MattSourceGenHelpers.Generators/GeneratesMethodExecutionRuntime.cs @@ -30,10 +30,10 @@ internal static (SwitchBodyData? record, string? error) ExecuteFluentGeneratorMe Compilation compilation) { IReadOnlyList allPartials = GetAllUnimplementedPartialMethods(compilation); - CSharpCompilation runtimeCompilation = BuildExecutionCompilation(allPartials, compilation); + CSharpCompilation executableCompilation = BuildExecutionCompilation(allPartials, compilation); using MemoryStream stream = new(); - EmitResult emitResult = runtimeCompilation.Emit(stream); + EmitResult emitResult = executableCompilation.Emit(stream); if (!emitResult.Success) { string errors = string.Join("; ", emitResult.Diagnostics @@ -87,13 +87,13 @@ internal static (SwitchBodyData? record, string? error) ExecuteFluentGeneratorMe currentGeneratorProperty?.SetValue(null, recordingFactory); string typeName = generatorMethod.ContainingType.ToDisplayString(); - Type? generatedType = assembly.GetType(typeName); - if (generatedType == null) + Type? loadedType = assembly.GetType(typeName); + if (loadedType == null) { return (null, $"Could not find type '{typeName}' in compiled assembly"); } - MethodInfo? generatorMethodInfo = generatedType.GetMethod(generatorMethod.Name, BindingFlags.Static | BindingFlags.NonPublic | BindingFlags.Public); + MethodInfo? generatorMethodInfo = loadedType.GetMethod(generatorMethod.Name, BindingFlags.Static | BindingFlags.NonPublic | BindingFlags.Public); if (generatorMethodInfo == null) { return (null, $"Could not find method '{generatorMethod.Name}' in type '{typeName}'"); @@ -126,10 +126,10 @@ internal static (string? value, string? error) ExecuteGeneratorMethodWithArgs( Compilation compilation, object?[]? args) { - CSharpCompilation runtimeCompilation = BuildExecutionCompilation(allPartialMethods, compilation); + CSharpCompilation executableCompilation = BuildExecutionCompilation(allPartialMethods, compilation); using MemoryStream stream = new(); - EmitResult emitResult = runtimeCompilation.Emit(stream); + EmitResult emitResult = executableCompilation.Emit(stream); if (!emitResult.Success) { string errors = string.Join("; ", emitResult.Diagnostics @@ -156,13 +156,13 @@ internal static (string? value, string? error) ExecuteGeneratorMethodWithArgs( Assembly assembly = loadContext.LoadFromStream(stream); string typeName = generatorMethod.ContainingType.ToDisplayString(); - Type? generatedType = assembly.GetType(typeName); - if (generatedType == null) + Type? loadedType = assembly.GetType(typeName); + if (loadedType == null) { return (null, $"Could not find type '{typeName}' in compiled assembly"); } - MethodInfo? generatorMethodInfo = generatedType.GetMethod(generatorMethod.Name, BindingFlags.Static | BindingFlags.NonPublic | BindingFlags.Public); + MethodInfo? generatorMethodInfo = loadedType.GetMethod(generatorMethod.Name, BindingFlags.Static | BindingFlags.NonPublic | BindingFlags.Public); if (generatorMethodInfo == null) { return (null, $"Could not find method '{generatorMethod.Name}' in type '{typeName}'"); @@ -302,10 +302,10 @@ private static string BuildDummyImplementation(IEnumerable partia Accessibility.Internal => "internal", Accessibility.ProtectedOrInternal => "protected internal", Accessibility.ProtectedAndInternal => "private protected", - _ => string.Empty + _ => "" }; - string staticModifier = partialMethod.IsStatic ? "static " : string.Empty; + string staticModifier = partialMethod.IsStatic ? "static " : ""; string returnType = partialMethod.ReturnType.ToDisplayString(); string parameters = string.Join(", ", partialMethod.Parameters.Select(parameter => $"{parameter.Type.ToDisplayString()} {parameter.Name}")); diff --git a/MattSourceGenHelpers.Generators/GeneratesMethodPatternSourceBuilder.cs b/MattSourceGenHelpers.Generators/GeneratesMethodPatternSourceBuilder.cs index 3d9614e..7a7b299 100644 --- a/MattSourceGenHelpers.Generators/GeneratesMethodPatternSourceBuilder.cs +++ b/MattSourceGenHelpers.Generators/GeneratesMethodPatternSourceBuilder.cs @@ -24,7 +24,7 @@ internal static string GenerateFromSwitchAttributes( GeneratesMethodGenerationTarget? switchDefaultMethod = methods .FirstOrDefault(method => method.Symbol.GetAttributes().Any(attribute => attribute.AttributeClass?.ToDisplayString() == SwitchDefaultAttributeTypeName)); - List<(object key, string value)> cases = new(); + List<(object key, string value)> switchCases = new(); foreach (GeneratesMethodGenerationTarget switchMethod in switchCaseMethods) { IEnumerable switchCaseAttributes = switchMethod.Symbol.GetAttributes() @@ -59,7 +59,7 @@ internal static string GenerateFromSwitchAttributes( continue; } - cases.Add((caseArgument, FormatValueAsCSharpLiteral(result, partialMethod.ReturnType))); + switchCases.Add((caseArgument, FormatValueAsCSharpLiteral(result, partialMethod.ReturnType))); } } @@ -67,7 +67,7 @@ internal static string GenerateFromSwitchAttributes( ? ExtractDefaultExpressionFromSwitchDefaultMethod(switchDefaultMethod.Syntax) : null; - return GenerateSwitchMethodSource(containingType, partialMethod, cases, defaultExpression); + return GenerateSwitchMethodSource(containingType, partialMethod, switchCases, defaultExpression); } internal static string GenerateFromFluent( From 095d0c148bab02a534b5ca9352c94ba142231e48 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 28 Feb 2026 11:26:14 +0000 Subject: [PATCH 6/6] fix: guard null FilePath in metadata reference lookups Co-authored-by: dex3r <3155725+dex3r@users.noreply.github.com> --- .../GeneratesMethodExecutionRuntime.cs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/MattSourceGenHelpers.Generators/GeneratesMethodExecutionRuntime.cs b/MattSourceGenHelpers.Generators/GeneratesMethodExecutionRuntime.cs index a69fcf1..553f2d7 100644 --- a/MattSourceGenHelpers.Generators/GeneratesMethodExecutionRuntime.cs +++ b/MattSourceGenHelpers.Generators/GeneratesMethodExecutionRuntime.cs @@ -51,7 +51,7 @@ internal static (SwitchBodyData? record, string? error) ExecuteFluentGeneratorMe { PortableExecutableReference? match = compilation.References .OfType() - .FirstOrDefault(reference => string.Equals( + .FirstOrDefault(reference => reference.FilePath is not null && string.Equals( Path.GetFileNameWithoutExtension(reference.FilePath), assemblyName.Name, StringComparison.OrdinalIgnoreCase)); @@ -62,7 +62,7 @@ internal static (SwitchBodyData? record, string? error) ExecuteFluentGeneratorMe PortableExecutableReference? abstractionsReference = compilation.References .OfType() - .FirstOrDefault(reference => string.Equals( + .FirstOrDefault(reference => reference.FilePath is not null && string.Equals( Path.GetFileNameWithoutExtension(reference.FilePath), "MattSourceGenHelpers.Abstractions", StringComparison.OrdinalIgnoreCase)); @@ -147,7 +147,7 @@ internal static (string? value, string? error) ExecuteGeneratorMethodWithArgs( { PortableExecutableReference? match = compilation.References .OfType() - .FirstOrDefault(reference => string.Equals( + .FirstOrDefault(reference => reference.FilePath is not null && string.Equals( Path.GetFileNameWithoutExtension(reference.FilePath), assemblyName.Name, StringComparison.OrdinalIgnoreCase));