Skip to content

Commit

Permalink
Merge pull request #107 from jonisavo/fix/no-mandatory-logger-in-di
Browse files Browse the repository at this point in the history
Drop requirement of Logger field from dependency injection codegen
  • Loading branch information
jonisavo committed Sep 21, 2023
2 parents f941d19 + b562ca7 commit b6e219c
Show file tree
Hide file tree
Showing 13 changed files with 155 additions and 42 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -76,4 +76,5 @@ crashlytics-build.properties
*.coverage
/dist
coveragereport/
TestResults/
TestResults/
.DS_Store
59 changes: 59 additions & 0 deletions Assets/UIComponents.Tests/Roslyn/ProvideErrorTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
using System;
using System.Collections.Generic;
using NUnit.Framework;
using UIComponents.DependencyInjection;
using UnityEngine;
using UnityEngine.TestTools;

namespace UIComponents.Tests.Roslyn
{
public interface IMissingDependency {}

public partial class ProvideErrorComponent : UIComponent
{
[Provide] public ILogger MyLogger;
[Provide] private IMissingDependency _dependency;
}

[Dependency(typeof(ILogger), provide: typeof(DebugLogger))]
public partial class ProvideErrorClass : IDependencyConsumer
{
[Provide] public ILogger MyLogger;
[Provide] private IMissingDependency _dependency;
private readonly DependencyInjector _dependencyInjector;

public ProvideErrorClass()
{
DiContext.Current.RegisterConsumer(this);
_dependencyInjector = DiContext.Current.GetInjector(GetType());
UIC_PopulateProvideFields();
}

private T Provide<T>() where T : class
{
return _dependencyInjector.Provide<T>();
}
}

[TestFixture]
public class ProvideErrorTests
{
[Test]
public void Error_Is_Printed_With_Logger()
{
var component = new ProvideErrorComponent();

Assert.That(component.MyLogger, Is.InstanceOf<DebugLogger>());
LogAssert.Expect(LogType.Error, "[ProvideErrorComponent] Could not provide IMissingDependency to _dependency");
}

[Test]
public void Error_Is_Printed_With_DebugLog()
{
var instance = new ProvideErrorClass();

Assert.That(instance.MyLogger, Is.InstanceOf<DebugLogger>());
LogAssert.Expect(LogType.Error, "Could not provide IMissingDependency to _dependency");
}
}
}
3 changes: 3 additions & 0 deletions Assets/UIComponents.Tests/Roslyn/ProvideErrorTests.cs.meta

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Binary file modified Assets/UIComponents/Roslyn/UIComponents.Roslyn.Common.dll
Binary file not shown.
Binary file modified Assets/UIComponents/Roslyn/UIComponents.Roslyn.Common.pdb
Binary file not shown.
Binary file modified Assets/UIComponents/Roslyn/UIComponents.Roslyn.Generation.dll
Binary file not shown.
Binary file modified Assets/UIComponents/Roslyn/UIComponents.Roslyn.Generation.pdb
Binary file not shown.
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ public static bool HasBaseType(INamedTypeSymbol type, INamedTypeSymbol desiredBa
{
var current = type;

while (current != null)
while (current != null && current.SpecialType != SpecialType.System_Object)
{
if (current.Equals(desiredBaseType, SymbolEqualityComparer.Default))
return true;
Expand Down Expand Up @@ -177,7 +177,7 @@ public static IEnumerable<ISymbol> GetAllMembersOfType(ITypeSymbol typeSymbol)
{
var current = typeSymbol;

while (current != null)
while (current != null && current.SpecialType != SpecialType.System_Object)
{
foreach (var member in current.GetMembers())
yield return member;
Expand All @@ -190,7 +190,7 @@ public static IEnumerable<AttributeData> GetAllAttributesOfType(ITypeSymbol type
{
var current = typeSymbol;

while (current != null)
while (current != null && current.SpecialType != SpecialType.System_Object)
{
foreach (var attribute in current.GetAttributes())
yield return attribute;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ public interface IService {}
namespace MyLibrary.GUI
{
public class GuiComponent
public class GuiComponent : UIComponent
{
[Provide]
public IService service;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ public class DebugLogger : ILogger {}
[Dependency(typeof(ILogger), typeof(DebugLogger))]
public abstract class UIComponent : VisualElement, IDependencyConsumer
{
protected readonly ILogger? Logger;

public class UxmlFactory<T> where T : UIComponent {}

protected virtual void UIC_PopulateProvideFields() {}
}
}
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
//HintName: MyClass.Provide.g.cs
//HintName: MyClass.Provide.g.cs
// <auto-generated>
// This file has been generated automatically by UIComponents.Roslyn.
// Do not attempt to modify it. Any changes will be overridden during compilation.
// </auto-generated>

using UIComponents;
using System;
using UnityEngine;
using System.CodeDom.Compiler;
using UnityEngine.UIElements;

Expand All @@ -20,16 +21,16 @@ public partial class MyClass
}
catch (MissingProviderException)
{
Logger.LogError($"Could not provide {typeof(TField).Name} to {fieldName}", this);
Debug.LogError($"Could not provide {typeof(TField).Name} to {fieldName}");
}
catch (InvalidCastException)
{
Logger.LogError($"Could not cast {typeof(TCastFrom).Name} to {typeof(TField).Name}", this);
Debug.LogError($"Could not cast {typeof(TCastFrom).Name} to {typeof(TField).Name}");
}
}

[GeneratedCode("UIComponents.Roslyn.Generation", "1.0.0-beta.4")]
protected override void UIC_PopulateProvideFields()
protected void UIC_PopulateProvideFields()
{
UIC_SetProvideField<IDependency, IDependency>(ref Dependency, "Dependency");
UIC_SetProvideField<Dependency, IDependency>(ref ConcreteDependency, "ConcreteDependency");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,14 @@ private void ExecuteForType(TypeDeclarationSyntax node, GeneratorExecutionContex
foreach (var usingName in usingsList)
_currentContext.Usings.Add(usingName);

AddAdditionalUsings(_currentContext.Usings);

_currentContext.CurrentTypeSymbol =
_currentContext.ClassSemanticModel.GetDeclaredSymbol(node) as INamedTypeSymbol;

if (!ShouldGenerateSource(_currentContext))
return;

AddAdditionalUsings(_currentContext.Usings);

_stringBuilder.AppendLine(@"// <auto-generated>
// This file has been generated automatically by UIComponents.Roslyn.
// Do not attempt to modify it. Any changes will be overridden during compilation.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,33 +12,28 @@ namespace UIComponents.Roslyn.Generation.Generators.DependencyInjection
public sealed class ProvideAugmentGenerator : AugmentGenerator<ClassSyntaxReceiver>
{
private INamedTypeSymbol _provideAttributeSymbol;
private bool _hasLogger = false;
private bool _hasPopulateFieldsMethod = false;
private readonly List<ProvideDescription> _provideDescriptions =
new List<ProvideDescription>();

private const string MissingProviderExceptionMessage =
"$\"Could not provide {typeof(TField).Name} to {fieldName}\"";
private const string InvalidCastExceptionMessage =
"$\"Could not cast {typeof(TCastFrom).Name} to {typeof(TField).Name}\"";
private const string PopulateMethodName = "UIC_PopulateProvideFields";

protected override void OnBeforeExecute(GeneratorExecutionContext context)
{
_provideAttributeSymbol =
context.Compilation.GetTypeByMetadataName("UIComponents.ProvideAttribute");
}

private void GetProvideDescriptions(AugmentGenerationContext context, IList<ProvideDescription> output)
private void GetProvideDescriptions(IEnumerable<IFieldSymbol> fields, IList<ProvideDescription> output)
{
var members = RoslynUtilities.GetAllMembersOfType(context.CurrentTypeSymbol);

foreach (var member in members)
foreach (var field in fields)
{
if (!(member is IFieldSymbol fieldSymbol))
continue;

var memberType = fieldSymbol.Type;

var memberTypeIsClassOrInterface =
memberType.TypeKind == TypeKind.Class || memberType.TypeKind == TypeKind.Interface;

if (!memberTypeIsClassOrInterface)
continue;

var provideAttributes = fieldSymbol
var provideAttributes = field
.GetAttributes()
.Where((attribute) => attribute.AttributeClass.Equals(_provideAttributeSymbol, SymbolEqualityComparer.Default))
.ToList();
Expand All @@ -62,19 +57,56 @@ private void GetProvideDescriptions(AugmentGenerationContext context, IList<Prov
if (arguments.TryGetValue("CastFrom", out var castFromArg))
castFromType = castFromArg.Value as INamedTypeSymbol;

output.Add(new ProvideDescription(fieldSymbol, castFromType));
output.Add(new ProvideDescription(field, castFromType));
}
}
}

private bool DoesFieldsContainLoggerField(IEnumerable<IFieldSymbol> fields)
{
foreach (var field in fields)
{
if (field.Name == "Logger")
return true;
}

return false;
}

protected override bool ShouldGenerateSource(AugmentGenerationContext context)
{
_provideDescriptions.Clear();
_hasLogger = false;
_hasPopulateFieldsMethod = false;

if (_provideAttributeSymbol == null)
return false;

GetProvideDescriptions(context, _provideDescriptions);
var members = RoslynUtilities.GetAllMembersOfType(context.CurrentTypeSymbol).ToList();

_hasPopulateFieldsMethod = members.Any((member) =>
{
if (!(member is IMethodSymbol methodSymbol))
return false;
return member.Name == PopulateMethodName;
});

var fields = members.Where((member) =>
{
if (!(member is IFieldSymbol fieldSymbol))
return false;
var memberType = fieldSymbol.Type;
var memberTypeIsClassOrInterface =
memberType.TypeKind == TypeKind.Class || memberType.TypeKind == TypeKind.Interface;
return memberTypeIsClassOrInterface;
}).Cast<IFieldSymbol>().ToList();

GetProvideDescriptions(fields, _provideDescriptions);
_hasLogger = DoesFieldsContainLoggerField(fields);

return _provideDescriptions.Count > 0;
}
Expand All @@ -83,36 +115,49 @@ protected override void AddAdditionalUsings(HashSet<string> usings)
{
usings.Add("System");
usings.Add("UIComponents");

if (!_hasLogger)
usings.Add("UnityEngine");

base.AddAdditionalUsings(usings);
}

private string CreateExceptionMessage(string message)
{
if (_hasLogger)
return $"Logger.LogError({message}, this);";
else
return $"Debug.LogError({message});";
}

protected override void GenerateSource(AugmentGenerationContext context, StringBuilder stringBuilder)
{
stringBuilder
.AppendPadding()
.AppendCodeGeneratedAttribute()
.AppendLineWithPadding(@"private void UIC_SetProvideField<TField, TCastFrom>(ref TField value, string fieldName) where TField : class where TCastFrom : class
{
.AppendLineWithPadding($@"private void UIC_SetProvideField<TField, TCastFrom>(ref TField value, string fieldName) where TField : class where TCastFrom : class
{{
try
{
{{
value = (TField) (object) Provide<TCastFrom>();
}
}}
catch (MissingProviderException)
{
Logger.LogError($""Could not provide {typeof(TField).Name} to {fieldName}"", this);
}
{{
{CreateExceptionMessage(MissingProviderExceptionMessage)}
}}
catch (InvalidCastException)
{
Logger.LogError($""Could not cast {typeof(TCastFrom).Name} to {typeof(TField).Name}"", this);
}
}
{{
{CreateExceptionMessage(InvalidCastExceptionMessage)}
}}
}}
");
var keyword = _hasPopulateFieldsMethod ? " override " : " ";

stringBuilder
.AppendPadding()
.AppendCodeGeneratedAttribute()
.AppendLineWithPadding(@"protected override void UIC_PopulateProvideFields()
{");
.AppendLineWithPadding($@"protected{keyword}void {PopulateMethodName}()
{{");

foreach (var provideDescription in _provideDescriptions)
{
Expand Down

0 comments on commit b6e219c

Please sign in to comment.