Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
331 changes: 331 additions & 0 deletions MattSourceGenHelpers.Generators/GeneratesMethodExecutionRuntime.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,331 @@
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.Emit;
using System.Collections;
using System.Reflection;
using System.Runtime.Loader;
using System.Text;

namespace MattSourceGenHelpers.Generators;

internal sealed record SwitchBodyData(
IReadOnlyList<(object key, string value)> CasePairs,
bool HasDefaultCase);

internal static class GeneratesMethodExecutionRuntime
{
internal static (string? value, string? error) ExecuteSimpleGeneratorMethod(
IMethodSymbol generatorMethod,
IMethodSymbol partialMethod,
Compilation compilation)
{
IReadOnlyList<IMethodSymbol> allPartials = GetAllUnimplementedPartialMethods(compilation);
return ExecuteGeneratorMethodWithArgs(generatorMethod, allPartials, compilation, null);
}

internal static (SwitchBodyData? record, string? error) ExecuteFluentGeneratorMethod(
IMethodSymbol generatorMethod,
IMethodSymbol partialMethod,
Compilation compilation)
{
IReadOnlyList<IMethodSymbol> allPartials = GetAllUnimplementedPartialMethods(compilation);
CSharpCompilation executableCompilation = BuildExecutionCompilation(allPartials, compilation);

using MemoryStream stream = new();
EmitResult emitResult = executableCompilation.Emit(stream);
if (!emitResult.Success)
{
string errors = string.Join("; ", emitResult.Diagnostics
.Where(diagnostic => diagnostic.Severity == DiagnosticSeverity.Error)
.Select(diagnostic => diagnostic.GetMessage()));
return (null, $"Compilation failed: {errors}");
}

stream.Position = 0;
AssemblyLoadContext? loadContext = null;
try
{
loadContext = new AssemblyLoadContext("__GeneratorExec", isCollectible: true);
loadContext.Resolving += (context, assemblyName) =>
{
PortableExecutableReference? match = compilation.References
.OfType<PortableExecutableReference>()
.FirstOrDefault(reference => reference.FilePath is not null && string.Equals(
Path.GetFileNameWithoutExtension(reference.FilePath),
assemblyName.Name,
StringComparison.OrdinalIgnoreCase));
return match?.FilePath != null ? context.LoadFromAssemblyPath(match.FilePath) : null;
Comment on lines +52 to +58
Copy link

Copilot AI Feb 28, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

PortableExecutableReference.FilePath is nullable, but the predicate calls Path.GetFileNameWithoutExtension(reference.FilePath) without a null check. This can throw (and also triggers nullable warnings) when the compilation contains in-memory references with FilePath == null. Consider filtering to reference.FilePath is not null before calling Path.* (and/or using a safe fallback).

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in 095d0c1 by guarding reference.FilePath is not null before calling Path.GetFileNameWithoutExtension in the resolver lookup.

};

Assembly assembly = loadContext.LoadFromStream(stream);

PortableExecutableReference? abstractionsReference = compilation.References
.OfType<PortableExecutableReference>()
.FirstOrDefault(reference => reference.FilePath is not null && string.Equals(
Path.GetFileNameWithoutExtension(reference.FilePath),
"MattSourceGenHelpers.Abstractions",
StringComparison.OrdinalIgnoreCase));
Comment on lines +63 to +68
Copy link

Copilot AI Feb 28, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The lookup for the Abstractions reference assumes reference.FilePath is non-null (Path.GetFileNameWithoutExtension(reference.FilePath)). FilePath can be null for some metadata references, which would crash the generator. Add a null check in the predicate (or otherwise handle null paths) before calling Path.*.

Copilot uses AI. Check for mistakes.
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in 095d0c1 by adding a null-safe predicate for reference.FilePath in the Abstractions-reference lookup.


if (abstractionsReference?.FilePath == null)
{
return (null, "Could not find MattSourceGenHelpers.Abstractions reference in compilation");
}

string abstractionsAssemblyPath = ResolveImplementationAssemblyPath(abstractionsReference.FilePath);
Assembly abstractionsAssembly = loadContext.LoadFromAssemblyPath(abstractionsAssemblyPath);

Type? generatorStaticType = abstractionsAssembly.GetType("MattSourceGenHelpers.Abstractions.Generator");
Type? recordingFactoryType = abstractionsAssembly.GetType("MattSourceGenHelpers.Abstractions.RecordingGeneratorsFactory");
if (generatorStaticType == null || recordingFactoryType == null)
{
return (null, "Could not find Generator or RecordingGeneratorsFactory types in Abstractions assembly");
}

object? recordingFactory = Activator.CreateInstance(recordingFactoryType);
PropertyInfo? currentGeneratorProperty = generatorStaticType.GetProperty("CurrentGenerator", BindingFlags.Public | BindingFlags.Static);
currentGeneratorProperty?.SetValue(null, recordingFactory);

string typeName = generatorMethod.ContainingType.ToDisplayString();
Type? loadedType = assembly.GetType(typeName);
if (loadedType == null)
{
return (null, $"Could not find type '{typeName}' in compiled assembly");
}

MethodInfo? generatorMethodInfo = loadedType.GetMethod(generatorMethod.Name, BindingFlags.Static | BindingFlags.NonPublic | BindingFlags.Public);
if (generatorMethodInfo == null)
{
return (null, $"Could not find method '{generatorMethod.Name}' in type '{typeName}'");
}

generatorMethodInfo.Invoke(null, null);

PropertyInfo? lastRecordProperty = recordingFactoryType.GetProperty("LastRecord");
object? lastRecord = lastRecordProperty?.GetValue(recordingFactory);
if (lastRecord == null)
{
return (null, "RecordingGeneratorsFactory did not produce a record");
}

return (ExtractSwitchBodyData(lastRecord, partialMethod.ReturnType), null);
}
catch (Exception ex)
{
return (null, $"Error executing generator method '{generatorMethod.Name}': {ex.GetBaseException()}");
}
finally
{
loadContext?.Unload();
}
}

internal static (string? value, string? error) ExecuteGeneratorMethodWithArgs(
IMethodSymbol generatorMethod,
IReadOnlyList<IMethodSymbol> allPartialMethods,
Compilation compilation,
object?[]? args)
{
CSharpCompilation executableCompilation = BuildExecutionCompilation(allPartialMethods, compilation);

using MemoryStream stream = new();
EmitResult emitResult = executableCompilation.Emit(stream);
if (!emitResult.Success)
{
string errors = string.Join("; ", emitResult.Diagnostics
.Where(diagnostic => diagnostic.Severity == DiagnosticSeverity.Error)
.Select(diagnostic => diagnostic.GetMessage()));
return (null, $"Compilation failed: {errors}");
}

stream.Position = 0;
AssemblyLoadContext? loadContext = null;
try
{
loadContext = new AssemblyLoadContext("__GeneratorExec", isCollectible: true);
loadContext.Resolving += (context, assemblyName) =>
{
PortableExecutableReference? match = compilation.References
.OfType<PortableExecutableReference>()
.FirstOrDefault(reference => reference.FilePath is not null && string.Equals(
Path.GetFileNameWithoutExtension(reference.FilePath),
assemblyName.Name,
StringComparison.OrdinalIgnoreCase));
return match?.FilePath != null ? context.LoadFromAssemblyPath(match.FilePath) : null;
};

Assembly assembly = loadContext.LoadFromStream(stream);
string typeName = generatorMethod.ContainingType.ToDisplayString();
Type? loadedType = assembly.GetType(typeName);
if (loadedType == null)
{
return (null, $"Could not find type '{typeName}' in compiled assembly");
}

MethodInfo? generatorMethodInfo = loadedType.GetMethod(generatorMethod.Name, BindingFlags.Static | BindingFlags.NonPublic | BindingFlags.Public);
if (generatorMethodInfo == null)
{
return (null, $"Could not find method '{generatorMethod.Name}' in type '{typeName}'");
}

object?[]? convertedArgs = ConvertArguments(args, generatorMethodInfo);
object? result = generatorMethodInfo.Invoke(null, convertedArgs);
return (result?.ToString(), null);
}
catch (Exception ex)
{
return (null, $"Error executing generator method '{generatorMethod.Name}': {ex.GetBaseException()}");
}
finally
{
loadContext?.Unload();
}
}

internal static IReadOnlyList<IMethodSymbol> GetAllUnimplementedPartialMethods(Compilation compilation)
{
List<IMethodSymbol> methods = new();
foreach (SyntaxTree syntaxTree in compilation.SyntaxTrees)
{
SemanticModel semanticModel = compilation.GetSemanticModel(syntaxTree);
IEnumerable<MethodDeclarationSyntax> partialMethodDeclarations = syntaxTree.GetRoot().DescendantNodes()
.OfType<MethodDeclarationSyntax>()
.Where(method => method.Modifiers.Any(modifier => modifier.IsKind(SyntaxKind.PartialKeyword)));

foreach (MethodDeclarationSyntax declaration in partialMethodDeclarations)
{
if (semanticModel.GetDeclaredSymbol(declaration) is IMethodSymbol symbol && symbol.IsPartialDefinition)
{
methods.Add(symbol);
}
}
}

return methods;
}

private static object?[]? ConvertArguments(object?[]? args, MethodInfo methodInfo)
{
if (args == null || methodInfo.GetParameters().Length == 0)
{
return null;
}

Type parameterType = methodInfo.GetParameters()[0].ParameterType;
return new[] { Convert.ChangeType(args[0], parameterType) };
}

private static SwitchBodyData ExtractSwitchBodyData(object lastRecord, ITypeSymbol returnType)
{
Type recordType = lastRecord.GetType();
PropertyInfo? caseKeysProperty = recordType.GetProperty("CaseKeys");
PropertyInfo? caseValuesProperty = recordType.GetProperty("CaseValues");
PropertyInfo? hasDefaultProperty = recordType.GetProperty("HasDefaultCase");

IList caseKeys = (caseKeysProperty?.GetValue(lastRecord) as IList) ?? new List<object>();
IList caseValues = (caseValuesProperty?.GetValue(lastRecord) as IList) ?? new List<object?>();
bool hasDefaultCase = (bool)(hasDefaultProperty?.GetValue(lastRecord) ?? false);

List<(object key, string value)> pairs = new();
for (int index = 0; index < caseKeys.Count; index++)
{
object key = caseKeys[index]!;
string? value = index < caseValues.Count ? caseValues[index]?.ToString() : null;
pairs.Add((key, GeneratesMethodPatternSourceBuilder.FormatValueAsCSharpLiteral(value, returnType)));
}

return new SwitchBodyData(pairs, hasDefaultCase);
}

private static string ResolveImplementationAssemblyPath(string path)
{
string? directory = Path.GetDirectoryName(path);
string? parentDirectory = directory != null ? Path.GetDirectoryName(directory) : null;
if (directory != null &&
parentDirectory != null &&
string.Equals(Path.GetFileName(directory), "ref", StringComparison.OrdinalIgnoreCase))
{
return Path.Combine(parentDirectory, Path.GetFileName(path));
}

return path;
}

private static CSharpCompilation BuildExecutionCompilation(
IReadOnlyList<IMethodSymbol> allPartialMethods,
Compilation compilation)
{
string dummySource = BuildDummyImplementation(allPartialMethods);
CSharpParseOptions parseOptions = compilation.SyntaxTrees.FirstOrDefault()?.Options as CSharpParseOptions
?? CSharpParseOptions.Default;

return (CSharpCompilation)compilation
.WithOptions(new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary))
.AddSyntaxTrees(CSharpSyntaxTree.ParseText(dummySource, parseOptions));
}

private static string BuildDummyImplementation(IEnumerable<IMethodSymbol> partialMethods)
{
StringBuilder builder = new();

IEnumerable<IGrouping<(string? Namespace, string TypeName, bool IsStatic, TypeKind TypeKind), IMethodSymbol>> groupedMethods = partialMethods.GroupBy(
method => (Namespace: method.ContainingType.ContainingNamespace?.IsGlobalNamespace == false
? method.ContainingType.ContainingNamespace.ToDisplayString()
: null,
TypeName: method.ContainingType.Name,
IsStatic: method.ContainingType.IsStatic,
TypeKind: method.ContainingType.TypeKind));

foreach (IGrouping<(string? Namespace, string TypeName, bool IsStatic, TypeKind TypeKind), IMethodSymbol> typeGroup in groupedMethods)
{
string? namespaceName = typeGroup.Key.Namespace;
if (namespaceName != null)
{
builder.AppendLine($"namespace {namespaceName} {{");
}

string typeKeyword = typeGroup.Key.TypeKind switch
{
TypeKind.Struct => "struct",
_ => "class"
};

string typeModifiers = typeGroup.Key.IsStatic ? "static partial" : "partial";
builder.AppendLine($"{typeModifiers} {typeKeyword} {typeGroup.Key.TypeName} {{");

foreach (IMethodSymbol partialMethod in typeGroup)
{
string accessibility = partialMethod.DeclaredAccessibility switch
{
Accessibility.Public => "public",
Accessibility.Protected => "protected",
Accessibility.Internal => "internal",
Accessibility.ProtectedOrInternal => "protected internal",
Accessibility.ProtectedAndInternal => "private protected",
_ => ""
};

string staticModifier = partialMethod.IsStatic ? "static " : "";
string returnType = partialMethod.ReturnType.ToDisplayString();
string parameters = string.Join(", ", partialMethod.Parameters.Select(parameter => $"{parameter.Type.ToDisplayString()} {parameter.Name}"));

builder.AppendLine($"{accessibility} {staticModifier}partial {returnType} {partialMethod.Name}({parameters}) {{");
if (!partialMethod.ReturnsVoid)
{
builder.AppendLine("return default!;");
}

builder.AppendLine("}");
}

builder.AppendLine("}");

if (namespaceName != null)
{
builder.AppendLine("}");
}
}

return builder.ToString();
}
}
Loading