Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using NetDaemon.Client.HomeAssistant.Model;
using Microsoft.CodeAnalysis.CSharp;

namespace NetDaemon.HassModel.CodeGenerator;

Expand All @@ -25,46 +25,38 @@ internal static class ExtensionMethodsGenerator
{
public static IEnumerable<MemberDeclarationSyntax> Generate(IEnumerable<HassServiceDomain> serviceDomains, IReadOnlyCollection<EntityDomainMetadata> entityDomains)
{
var entityDomainNames = entityDomains.Select(d => d.Domain).ToHashSet();

// we only want to generate these classes for entities that
var entityClassNameByDomain = entityDomains.ToLookup(e => e.Domain, e => e.EntityClassName);

return serviceDomains
.Where(sd =>
sd.Services?.Any(s => s.Target?.Entity?.Domain != null && entityDomainNames.Contains(s.Target.Entity.Domain)) == true)
.GroupBy(x => x.Domain, x => x.Services)
.Select(GenerateClass);
.Select(sd => GenerateDomainExtensionClass(sd, entityClassNameByDomain))
.OfType<MemberDeclarationSyntax>(); // filter out nulls
}

private static ClassDeclarationSyntax GenerateClass(IGrouping<string?, IReadOnlyCollection<HassService>?> domainServicesGroup)
private static ClassDeclarationSyntax? GenerateDomainExtensionClass(HassServiceDomain serviceDomain, ILookup<string, string> entityClassNameByDomain)
{
var domain = domainServicesGroup.Key!;

var domainServices = domainServicesGroup
.SelectMany(services => services!)
.Where(s => s.Target?.Entity?.Domain != null)
.Select(group => @group)
var serviceMethodDeclarations = serviceDomain.Services
.OrderBy(x => x.Service)
.ToList();

return GenerateDomainExtensionClass(domain, domainServices);
}

private static ClassDeclarationSyntax GenerateDomainExtensionClass(string domain, IEnumerable<HassService> services)
{
var serviceTypeDeclaration = Class(GetEntityDomainExtensionMethodClassName(domain)).ToPublic().ToStatic();

var serviceMethodDeclarations = services
.SelectMany(service => GenerateExtensionMethod(domain, service))
.SelectMany(service => GenerateExtensionMethods(serviceDomain.Domain, service, entityClassNameByDomain))
.ToArray();

return serviceTypeDeclaration.AddMembers(serviceMethodDeclarations);
if (!serviceMethodDeclarations.Any()) return null;

return SyntaxFactory.ClassDeclaration(GetEntityDomainExtensionMethodClassName(serviceDomain.Domain))
.AddMembers(serviceMethodDeclarations)
.ToPublic()
.ToStatic();
}

private static IEnumerable<MemberDeclarationSyntax> GenerateExtensionMethod(string domain, HassService service)
private static IEnumerable<MemberDeclarationSyntax> GenerateExtensionMethods(string domain, HassService service, ILookup<string, string> entityClassNameByDomain)
{
var serviceName = service.Service!;
var targetEntityDomain = service.Target?.Entity?.Domain;
if (targetEntityDomain == null) yield break;

var entityTypeName = entityClassNameByDomain[targetEntityDomain].FirstOrDefault();
if (entityTypeName == null) yield break;

var serviceName = service.Service;
var serviceArguments = ServiceArguments.Create(domain, service);
var entityTypeName = GetDomainEntityTypeName(service.Target?.Entity?.Domain!);
var enumerableTargetTypeName = $"IEnumerable<{entityTypeName}>";

if (serviceArguments is null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,11 +68,6 @@ public static ClassDeclarationSyntax ClassWithInjected<TInjected>(string classNa
return ParseClass(classCode);
}

public static ClassDeclarationSyntax Class(string name)
{
return ClassDeclaration(name);
}

public static TypeDeclarationSyntax Interface(string name)
{
return InterfaceDeclaration(name);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

internal record HassService
{
public string? Service { get; init; }
public required string Service { get; init; }
public string? Description { get; init; }
public IReadOnlyCollection<HassServiceField>? Fields { get; init; }
public TargetSelector? Target { get; init; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@

internal record HassServiceDomain
{
public string? Domain { get; init; }
public IReadOnlyCollection<HassService>? Services { get; init; }
public required string Domain { get; init; }
public required IReadOnlyCollection<HassService> Services { get; init; }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
using System.Collections.Generic;
using System.IO;
using System.Text;
using Microsoft.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using NetDaemon.Client.HomeAssistant.Model;
using NetDaemon.HassModel.CodeGenerator;
using NetDaemon.HassModel.CodeGenerator.Model;

namespace NetDaemon.HassModel.Tests.CodeGenerator;

internal class CodeGenTestHelper
{
public static CompilationUnitSyntax GenerateCompilationUnit(
CodeGenerationSettings codeGenerationSettings,
IReadOnlyCollection<HassState> states,
IReadOnlyCollection<HassServiceDomain> services)
{
var metaData = EntityMetaDataGenerator.GetEntityDomainMetaData(states);
metaData = EntityMetaDataMerger.Merge(codeGenerationSettings, new EntitiesMetaData(), metaData);
var generatedTypes = Generator.GenerateTypes(metaData.Domains, services).ToArray();
return Generator.BuildCompilationUnit(codeGenerationSettings.Namespace, generatedTypes);

}

public static void AssertCodeCompiles(string generated, string appCode)
{
var syntaxtrees = new []
{
SyntaxFactory.ParseSyntaxTree(generated, path: "generated.cs"),
SyntaxFactory.ParseSyntaxTree(appCode, path: "appcode.cs")
};

var compilation = CSharpCompilation.Create("tempAssembly",
syntaxtrees,
AppDomain.CurrentDomain.GetAssemblies().Where(a => !a.IsDynamic).Select(a => MetadataReference.CreateFromFile(a.Location)).ToArray(),
new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary, nullableContextOptions: NullableContextOptions.Enable)
);

var emitResult = compilation.Emit(Stream.Null);

var warningsOrErrors = emitResult.Diagnostics
.Where(d => d.Severity is DiagnosticSeverity.Error or DiagnosticSeverity.Warning).ToList();

if (!warningsOrErrors.Any()) return;

var msg = new StringBuilder("Compile of generated code failed.\r\n");
foreach (var diagnostic in warningsOrErrors)
{
msg.AppendLine(diagnostic.ToString());
}

msg.AppendLine();
msg.AppendLine("generated.cs");
// output the generated code including line numbers to help debugging
msg.AppendLine(string.Join(Environment.NewLine, generated.Split('\n').Select((l, i) => $"{i+1,4}: {l}")));

Assert.Fail(msg.ToString());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@ public class CodeGeneratorTest
[Fact]
public void RunCodeGenEmpy()
{
var code = GenerateCompilationUnit(_settings, Array.Empty<HassState>(), new HassServiceDomain[0]);
var code = CodeGenTestHelper.GenerateCompilationUnit(_settings, Array.Empty<HassState>(), new HassServiceDomain[0]);

code.DescendantNodes().OfType<FileScopedNamespaceDeclarationSyntax>().First().Name.ToString().Should().Be("RootNameSpace");
AssertCodeCompiles(code.ToString(), string.Empty);

CodeGenTestHelper.AssertCodeCompiles(code.ToString(), string.Empty);
}

[Fact]
Expand All @@ -37,7 +37,7 @@ public void TestIEntityGeneration()
new() { EntityId = "switch.switch2" },
};

var generatedCode = GenerateCompilationUnit(_settings, entityStates, Array.Empty<HassServiceDomain>());
var generatedCode = CodeGenTestHelper.GenerateCompilationUnit(_settings, entityStates, Array.Empty<HassServiceDomain>());
var appCode = """
using NetDaemon.HassModel.Entities;
using NetDaemon.HassModel;
Expand All @@ -55,7 +55,7 @@ public void Run(IHaContext ha)
}
}
""";
AssertCodeCompiles(generatedCode.ToString(), appCode);
CodeGenTestHelper.AssertCodeCompiles(generatedCode.ToString(), appCode);
}

[Fact]
Expand Down Expand Up @@ -84,7 +84,7 @@ public void TestNumericSensorEntityGeneration()
},
};

var generatedCode = GenerateCompilationUnit(_settings, entityStates, Array.Empty<HassServiceDomain>());
var generatedCode = CodeGenTestHelper.GenerateCompilationUnit(_settings, entityStates, Array.Empty<HassServiceDomain>());
var appCode = """
using NetDaemon.HassModel.Entities;
using NetDaemon.HassModel;
Expand All @@ -109,7 +109,7 @@ public void Run(IHaContext ha)
}
}
""";
AssertCodeCompiles(generatedCode.ToString(), appCode);
CodeGenTestHelper.AssertCodeCompiles(generatedCode.ToString(), appCode);
}

[Fact]
Expand Down Expand Up @@ -146,7 +146,7 @@ public void TestNumberExtensionMethodGeneration()
}
};

var generatedCode = GenerateCompilationUnit(_settings, entityStates, hassServiceDomains);
var generatedCode = CodeGenTestHelper.GenerateCompilationUnit(_settings, entityStates, hassServiceDomains);
var appCode = """
using NetDaemon.HassModel.Entities;
using NetDaemon.HassModel;
Expand All @@ -162,7 +162,7 @@ public void Run(IHaContext ha)
}
}
""";
AssertCodeCompiles(generatedCode.ToString(), appCode);
CodeGenTestHelper.AssertCodeCompiles(generatedCode.ToString(), appCode);
}

[Fact]
Expand All @@ -188,7 +188,7 @@ public void TestAttributeClassGeneration_UseAttributeBaseClassesFalse()
},
};

var generatedCode = GenerateCompilationUnit(_settings with { UseAttributeBaseClasses = false }, entityStates, Array.Empty<HassServiceDomain>()).ToString();
var generatedCode = CodeGenTestHelper.GenerateCompilationUnit(_settings with { UseAttributeBaseClasses = false }, entityStates, Array.Empty<HassServiceDomain>()).ToString();

var appCode = """
using NetDaemon.HassModel.Entities;
Expand All @@ -210,7 +210,7 @@ public void Run(IHaContext ha)
}
}
""";
AssertCodeCompiles(generatedCode, appCode);
CodeGenTestHelper.AssertCodeCompiles(generatedCode, appCode);
}


Expand All @@ -235,7 +235,7 @@ public void TestAttributeClassGenerationSkipBaseProperties()
},
};

var generatedCode = GenerateCompilationUnit(_settings with { UseAttributeBaseClasses = true }, entityStates, Array.Empty<HassServiceDomain>()).ToString();
var generatedCode = CodeGenTestHelper.GenerateCompilationUnit(_settings with { UseAttributeBaseClasses = true }, entityStates, Array.Empty<HassServiceDomain>()).ToString();
generatedCode.Should().NotContain("Brightness", because: "It is in the base class");

var appCode = """
Expand All @@ -258,116 +258,7 @@ public void Run(IHaContext ha)
}
}
""";
AssertCodeCompiles(generatedCode, appCode);
}

[Fact]
public void TestServicesGeneration()
{
var readOnlyCollection = new HassState[]
{
new() { EntityId = "light.light1" },
};

var hassServiceDomains = new HassServiceDomain[]
{
new()
{
Domain = "light",
Services = new HassService[] {
new() {
Service = "turn_off",
Target = new TargetSelector { Entity = new() { Domain = "light" } }
},
new() {
Service = "turn_on",
Fields = new HassServiceField[] {
new() { Field = "transition", Selector = new NumberSelector(), },
new() { Field = "brightness", Selector = new NumberSelector { Step = 0.2f }, }
},
Target = new TargetSelector { Entity = new() { Domain = "light" } }
}
}
}
};

// Act:
var code = GenerateCompilationUnit(_settings, readOnlyCollection, hassServiceDomains);

var appCode = """
using NetDaemon.HassModel;
using NetDaemon.HassModel.Entities;
using RootNameSpace;

public class Root
{
public void Run(IHaContext ha)
{
var s = new RootNameSpace.Services(ha);

s.Light.TurnOn(new ServiceTarget() );
s.Light.TurnOn(new ServiceTarget(), transition: 12, brightness: 324.5f);
s.Light.TurnOn(new ServiceTarget(), new (){ Transition = 12L, Brightness = 12.3f });
s.Light.TurnOn(new ServiceTarget(), new (){ Brightness = 12.3f });

s.Light.TurnOff(new ServiceTarget());

var light = new RootNameSpace.LightEntity(ha, "light.testLight");

light.TurnOn();
light.TurnOn(transition: 12, brightness: 324.5f);
light.TurnOn(new (){ Transition = 12L, Brightness = 12.3f });
light.TurnOff();
}
}
""";
AssertCodeCompiles(code.ToString(), appCode);
}

private void AssertCodeCompiles(string generated, string appCode)
{
var syntaxtrees = new []
{
SyntaxFactory.ParseSyntaxTree(generated, path: "generated.cs"),
SyntaxFactory.ParseSyntaxTree(appCode, path: "appcode.cs")
};

var compilation = CSharpCompilation.Create("tempAssembly",
syntaxtrees,
AppDomain.CurrentDomain.GetAssemblies().Where(a => !a.IsDynamic).Select(a => MetadataReference.CreateFromFile(a.Location)).ToArray(),
new CSharpCompilationOptions(OutputKind.DynamicallyLinkedLibrary, nullableContextOptions: NullableContextOptions.Enable)
);

var emitResult = compilation.Emit(Stream.Null);

var warningsAndErrors = emitResult.Diagnostics
.Where(d => d.Severity is DiagnosticSeverity.Error or DiagnosticSeverity.Warning).ToList();

if (warningsAndErrors.Any())
{
var msg = new StringBuilder("Compile of generated code failed.\r\n");
foreach (var diagnostic in warningsAndErrors)
{
msg.AppendLine(diagnostic.ToString());
}

msg.AppendLine();
msg.AppendLine("generated.cs");
// output the generated code including line numbers to help debugging
msg.AppendLine(string.Join(Environment.NewLine, generated.Split('\n').Select((l, i) => $"{i+1,4}: {l}")));

Assert.Fail(msg.ToString());
}
CodeGenTestHelper.AssertCodeCompiles(generatedCode, appCode);
}

private static CompilationUnitSyntax GenerateCompilationUnit(
CodeGenerationSettings codeGenerationSettings,
IReadOnlyCollection<HassState> states,
IReadOnlyCollection<HassServiceDomain> services)
{
var metaData = EntityMetaDataGenerator.GetEntityDomainMetaData(states);
metaData = EntityMetaDataMerger.Merge(codeGenerationSettings, new EntitiesMetaData(), metaData);
var generatedTypes = Generator.GenerateTypes(metaData.Domains, services).ToArray();
return Generator.BuildCompilationUnit(codeGenerationSettings.Namespace, generatedTypes);
}
}
Loading