Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using Microsoft.CodeAnalysis;
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.Linq;
Expand All @@ -10,7 +11,10 @@ internal static class WrapperElementGenerator
{
internal static string GenerateWrapperTypeSource(StringBuilder source, INamedTypeSymbol elm)
{
var name = $"{elm.Name.Substring(1)}Wrapper";
// Element interface names start with 'I' (e.g., IElement -> ElementWrapper)
var name = elm.Name.Length > 1 && elm.Name.StartsWith("I", StringComparison.Ordinal)
? $"{elm.Name[1..]}Wrapper"
: $"{elm.Name}Wrapper";
var wrappedTypeName = elm.ToDisplayString(GeneratorConfig.SymbolFormat);

source.AppendLine("#nullable enable");
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
#nullable enable
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.Text;
using System;
using System.Collections.Generic;
using System.Collections.Immutable;
using System.IO;
using System.Linq;
using System.Text;
Expand All @@ -14,54 +16,82 @@ public class WrapperElementsGenerator : IIncrementalGenerator
public void Initialize(IncrementalGeneratorInitializationContext context)
{
// Finds the AngleSharp assembly referenced by the target project
// This should prevent the source generator from running unless a
// new symbol is returned.
var angleSharpAssemblyReference = context
// and collects element interface type names into cacheable records.
var elementInterfaces = context
.CompilationProvider
.Select((compilation, cancellationToken) =>
{
var meta = compilation.References.FirstOrDefault(x => x.Display?.EndsWith($"{Path.DirectorySeparatorChar}AngleSharp.dll", StringComparison.Ordinal) ?? false);
return compilation.GetAssemblyOrModuleSymbol(meta);
var assembly = compilation.GetAssemblyOrModuleSymbol(meta);

if (assembly is not IAssemblySymbol angleSharpAssembly)
return null;

var elementInterfaceTypes = FindElementInterfaces(angleSharpAssembly);
// Create cacheable records with just the essential info needed for generation
// Store metadata names instead of symbols for cacheability
return new ElementInterfacesData(
elementInterfaceTypes.Select(t => new ElementTypeInfo(
t.Name,
t.ToDisplayString(GeneratorConfig.SymbolFormat),
GetMetadataName(t)
)).ToImmutableArray());
});

// Combine with compilation to retrieve symbols during execution
var elementInterfacesWithCompilation = elementInterfaces.Combine(context.CompilationProvider);

// Output the hardcoded source files
context.RegisterSourceOutput(angleSharpAssemblyReference, GenerateStaticContent);
context.RegisterSourceOutput(elementInterfaces, GenerateStaticContent);

// Output the generated wrapper types
context.RegisterSourceOutput(angleSharpAssemblyReference, GenerateWrapperTypes);
context.RegisterSourceOutput(elementInterfacesWithCompilation, GenerateWrapperTypes);
}

private static void GenerateStaticContent(SourceProductionContext context, ISymbol assembly)
private static void GenerateStaticContent(SourceProductionContext context, ElementInterfacesData? data)
{
if (assembly is not IAssemblySymbol)
if (data is null)
return;

context.AddSource("IElementWrapperFactory.g.cs", ReadEmbeddedResource("Bunit.Web.AngleSharp.IElementWrapperFactory.cs"));
context.AddSource("IElementWrapper.g.cs", ReadEmbeddedResource("Bunit.Web.AngleSharp.IElementWrapper.cs"));
context.AddSource("WrapperBase.g.cs", ReadEmbeddedResource("Bunit.Web.AngleSharp.WrapperBase.cs"));
}

private static void GenerateWrapperTypes(SourceProductionContext context, ISymbol assembly)
private static void GenerateWrapperTypes(SourceProductionContext context, (ElementInterfacesData? data, Compilation compilation) input)
{
var (data, compilation) = input;
if (data is null)
return;

// Find the AngleSharp assembly in the compilation
var meta = compilation.References.FirstOrDefault(x => x.Display?.EndsWith($"{Path.DirectorySeparatorChar}AngleSharp.dll", StringComparison.Ordinal) ?? false);
var assembly = compilation.GetAssemblyOrModuleSymbol(meta);

if (assembly is not IAssemblySymbol angleSharpAssembly)
return;

var elementInterfacetypes = FindElementInterfaces(angleSharpAssembly);
// Retrieve the actual symbols from the assembly for code generation
var elementSymbols = data.ElementTypes
.Select(t => angleSharpAssembly.GetTypeByMetadataName(t.MetadataName))
.Where(s => s is not null)
.Cast<INamedTypeSymbol>()
.ToList();

var source = new StringBuilder();
foreach (var elm in elementInterfacetypes)
foreach (var elm in elementSymbols)
{
source.Clear();
var name = WrapperElementGenerator.GenerateWrapperTypeSource(source, elm);
context.AddSource($"{name}.g.cs", SourceText.From(source.ToString(), Encoding.UTF8));
}

source.Clear();
GenerateWrapperFactory(source, elementInterfacetypes);
GenerateWrapperFactory(source, data.ElementTypes);
context.AddSource($"WrapperExtensions.g.cs", SourceText.From(source.ToString(), Encoding.UTF8));
}

private static void GenerateWrapperFactory(StringBuilder source, IEnumerable<INamedTypeSymbol> elementInterfacetypes)
private static void GenerateWrapperFactory(StringBuilder source, ImmutableArray<ElementTypeInfo> elementTypes)
{
source.AppendLine("""namespace Bunit.Web.AngleSharp;""");
source.AppendLine();
Expand All @@ -78,10 +108,13 @@ private static void GenerateWrapperFactory(StringBuilder source, IEnumerable<INa
source.AppendLine($"\tpublic static global::AngleSharp.Dom.IElement WrapUsing<TElementFactory>(this global::AngleSharp.Dom.IElement element, TElementFactory elementFactory) where TElementFactory : Bunit.Web.AngleSharp.IElementWrapperFactory => element switch");
source.AppendLine("\t{");

foreach (var elm in elementInterfacetypes)
foreach (var elm in elementTypes)
{
var wrapperName = $"{elm.Name.Substring(1)}Wrapper";
source.AppendLine($"\t\t{elm.ToDisplayString(GeneratorConfig.SymbolFormat)} e => new {wrapperName}(e, elementFactory),");
// Element interface names start with 'I' (e.g., IElement -> ElementWrapper)
var wrapperName = elm.Name.Length > 1 && elm.Name.StartsWith("I", StringComparison.Ordinal)
? $"{elm.Name[1..]}Wrapper"
: $"{elm.Name}Wrapper";
source.AppendLine($"\t\t{elm.FullyQualifiedName} e => new {wrapperName}(e, elementFactory),");
}

source.AppendLine($"\t\t_ => new ElementWrapper(element, elementFactory),");
Expand All @@ -90,6 +123,17 @@ private static void GenerateWrapperFactory(StringBuilder source, IEnumerable<INa
source.AppendLine("}");
}

private static string GetMetadataName(INamedTypeSymbol typeSymbol)
{
// Get the full metadata name that can be used with GetTypeByMetadataName
// This is the fully qualified name without the "global::" prefix
var containingNamespace = typeSymbol.ContainingNamespace;
var namespacePrefix = containingNamespace?.IsGlobalNamespace == false
? containingNamespace.ToDisplayString() + "."
: "";
return namespacePrefix + typeSymbol.Name;
}

private static IReadOnlyList<INamedTypeSymbol> FindElementInterfaces(IAssemblySymbol angleSharpAssembly)
{
var htmlDomNamespace = angleSharpAssembly
Expand All @@ -104,6 +148,9 @@ private static IReadOnlyList<INamedTypeSymbol> FindElementInterfaces(IAssemblySy
var elementInterfaceSymbol = angleSharpAssembly
.GetTypeByMetadataName("AngleSharp.Dom.IElement");

if (elementInterfaceSymbol is null)
return Array.Empty<INamedTypeSymbol>();

var result = htmlDomNamespace
.GetTypeMembers()
.Where(typeSymbol => typeSymbol.TypeKind == TypeKind.Interface && typeSymbol.AllInterfaces.Contains(elementInterfaceSymbol))
Expand Down Expand Up @@ -139,3 +186,13 @@ private static string ReadEmbeddedResource(string resourceName)
return reader.ReadToEnd();
}
}

// Cacheable data structure that stores minimal information about element interfaces
// This allows the incremental generator to cache and reuse results across builds
internal sealed record ElementInterfacesData(
ImmutableArray<ElementTypeInfo> ElementTypes);

internal sealed record ElementTypeInfo(
string Name,
string FullyQualifiedName,
string MetadataName);