Skip to content

Commit

Permalink
LuisDialog supports multiple LuisModel and ILuisService instances, wi…
Browse files Browse the repository at this point in the history
…th test
  • Loading branch information
willportnoy committed Jun 20, 2016
1 parent e10fcbd commit 87177cf
Show file tree
Hide file tree
Showing 4 changed files with 123 additions and 63 deletions.
37 changes: 17 additions & 20 deletions CSharp/Library/Dialogs/LuisDialog.cs
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ private InvalidIntentHandlerException(SerializationInfo info, StreamingContext c
[Serializable]
public class LuisDialog<R> : IDialog<R>
{
private readonly ILuisService service;
private readonly IReadOnlyList<ILuisService> services;

/// <summary> Mapping from intent string to the appropriate handler. </summary>
[NonSerialized]
Expand All @@ -110,22 +110,13 @@ public class LuisDialog<R> : IDialog<R>
/// <summary>
/// Construct the LUIS dialog.
/// </summary>
/// <param name="service">The LUIS service.</param>
public LuisDialog(ILuisService service = null)
/// <param name="services">The LUIS service.</param>
public LuisDialog(params ILuisService[] services)
{
if (service == null)
{
var type = this.GetType();
var luisModel = type.GetCustomAttribute<LuisModelAttribute>(inherit: true);
if (luisModel == null)
{
throw new Exception("Luis model attribute is not set for the class");
}

service = new LuisService(luisModel);
}

SetField.NotNull(out this.service, nameof(service), service);
var type = this.GetType();
var luisModels = type.GetCustomAttributes<LuisModelAttribute>(inherit: true);
services = services.Concat(luisModels.Select(m => new LuisService(m))).ToArray();
SetField.NotNull(out this.services, nameof(services), services);
}

public virtual async Task StartAsync(IDialogContext context)
Expand All @@ -147,19 +138,25 @@ protected virtual async Task MessageReceived(IDialogContext context, IAwaitable<

var message = await item;
var messageText = await GetLuisQueryTextAsync(context, message);
var luisRes = await this.service.QueryAsync(messageText);
var tasks = this.services.Select(s => s.QueryAsync(messageText)).ToArray();
await Task.WhenAll(tasks);

var intentTask = from task in tasks
let result = task.Result
from intent in result.Intents
select new { result, intent };

var intent = BestIntentFrom(luisRes);
var winner = intentTask.MaxBy(it => it.intent.Score ?? 0);

IntentHandler handler = null;
if (intent == null || !this.handlerByIntent.TryGetValue(intent.Intent, out handler))
if (winner == null || !this.handlerByIntent.TryGetValue(winner.intent.Intent, out handler))
{
handler = this.handlerByIntent[string.Empty];
}

if (handler != null)
{
await handler(context, luisRes);
await handler(context, winner?.result);
}
else
{
Expand Down
2 changes: 1 addition & 1 deletion CSharp/Library/Luis/LuisModel.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ namespace Microsoft.Bot.Builder.Luis
/// <summary>
/// The LUIS model information.
/// </summary>
[AttributeUsage(AttributeTargets.Class, AllowMultiple = false)]
[AttributeUsage(AttributeTargets.Class, AllowMultiple = true)]
[Serializable]
public class LuisModelAttribute : Attribute
{
Expand Down
97 changes: 94 additions & 3 deletions CSharp/Tests/Microsoft.Bot.Builder.Tests/LuisTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,24 +34,60 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Linq.Expressions;
using System.Reflection;
using System.Threading.Tasks;

using Microsoft.VisualStudio.TestTools.UnitTesting;

using Autofac;
using Moq;
using Microsoft.Bot.Builder.Luis;
using Microsoft.Bot.Builder.Dialogs;
using Microsoft.Bot.Builder.Luis.Models;

namespace Microsoft.Bot.Builder.Tests
{
public abstract class LuisTestBase : DialogTestBase
{
public static IntentRecommendation[] IntentsFor<D>(Expression<Func<D, Task>> expression, double? score)
{
var body = (MethodCallExpression)expression.Body;
var attribute = body.Method.GetCustomAttribute<LuisIntentAttribute>();
var name = attribute.IntentName;
var intent = new IntentRecommendation(name, score);
return new[] { intent };
}

public static EntityRecommendation EntityFor(string type, string entity)
{
return new EntityRecommendation(type: type) { Entity = entity };
}

public static void SetupLuis<D>(
Mock<ILuisService> luis,
Expression<Func<D, Task>> expression,
double? score,
params EntityRecommendation[] entities
)
{
luis
.Setup(l => l.QueryAsync(It.IsAny<Uri>()))
.ReturnsAsync(new LuisResult()
{
Intents = IntentsFor(expression, score),
Entities = entities
});
}
}

[TestClass]
public sealed class LuisTests
public sealed class LuisTests : LuisTestBase
{
public sealed class DerivedLuisDialog : LuisDialog<object>
{
public DerivedLuisDialog(ILuisService service)
: base(service)
public DerivedLuisDialog(params ILuisService[] services)
: base(services)
{
}

Expand Down Expand Up @@ -109,6 +145,61 @@ public void All_Handlers_Are_Found()
Assert.AreEqual(7, handlers.Length);
}

[Serializable]
public sealed class MultiServiceLuisDialog : LuisDialog<object>
{
public MultiServiceLuisDialog(params ILuisService[] services)
: base(services)
{
}

[LuisIntent("ServiceOne")]
public async Task ServiceOne(IDialogContext context, LuisResult luisResult)
{
await context.PostAsync(luisResult.Entities.Single().Type);
context.Wait(MessageReceived);
}

[LuisIntent("ServiceTwo")]
public async Task ServiceTwo(IDialogContext context, LuisResult luisResult)
{
await context.PostAsync(luisResult.Entities.Single().Type);
context.Wait(MessageReceived);
}
}

[TestMethod]
public async Task All_Services_Are_Called()
{
var service1 = new Mock<ILuisService>();
var service2 = new Mock<ILuisService>();

var dialog = new MultiServiceLuisDialog(service1.Object, service2.Object);

using (new FiberTestBase.ResolveMoqAssembly(service1.Object, service2.Object))
using (var container = Build(Options.ResolveDialogFromContainer, service1.Object, service2.Object))
{
var builder = new ContainerBuilder();
builder
.RegisterInstance(dialog)
.As<IDialog<object>>();
builder.Update(container);

const string EntityOne = "one";
const string EntityTwo = "two";

SetupLuis<MultiServiceLuisDialog>(service1, d => d.ServiceOne(null, null), 1.0, new EntityRecommendation(type: EntityOne));
SetupLuis<MultiServiceLuisDialog>(service2, d => d.ServiceTwo(null, null), 0.0, new EntityRecommendation(type: EntityTwo));

await AssertScriptAsync(container, "hello", EntityOne);

SetupLuis<MultiServiceLuisDialog>(service1, d => d.ServiceOne(null, null), 0.0, new EntityRecommendation(type: EntityOne));
SetupLuis<MultiServiceLuisDialog>(service2, d => d.ServiceTwo(null, null), 1.0, new EntityRecommendation(type: EntityTwo));

await AssertScriptAsync(container, "hello", EntityTwo);
}
}

public sealed class InvalidLuisDialog : LuisDialog<object>
{
public InvalidLuisDialog(ILuisService service)
Expand Down
50 changes: 11 additions & 39 deletions CSharp/Tests/Microsoft.Bot.Sample.Tests/AlarmBotTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -49,53 +49,25 @@
using Autofac;

using Microsoft.VisualStudio.TestTools.UnitTesting;
using Microsoft.Bot.Sample.SimpleAlarmBot;

namespace Microsoft.Bot.Sample.Tests
{
[TestClass]
public sealed class AlarmBotTests : DialogTestBase
public sealed class AlarmBotTests : LuisTestBase
{
public static IntentRecommendation[] IntentsFor(Expression<Func<SimpleAlarmBot.SimpleAlarmDialog, Task>> expression)
{
var body = (MethodCallExpression)expression.Body;
var attribute = body.Method.GetCustomAttribute<LuisIntentAttribute>();
var name = attribute.IntentName;
var intent = new IntentRecommendation(name);
return new[] { intent };
}

public static EntityRecommendation EntityFor(string type, string entity)
{
return new EntityRecommendation(type: type) { Entity = entity };
}

public static void SetupLuis(
Mock<ILuisService> luis,
Expression<Func<SimpleAlarmBot.SimpleAlarmDialog, Task>> expression,
params EntityRecommendation[] entities
)
{
luis
.Setup(l => l.QueryAsync(It.IsAny<Uri>()))
.ReturnsAsync(new LuisResult()
{
Intents = IntentsFor(expression),
Entities = entities
});
}

[TestMethod]
public async Task AlarmDialogFlow()
{
var luis = new Mock<ILuisService>();

// arrange
var now = DateTime.UtcNow;
var entityTitle = EntityFor(SimpleAlarmBot.SimpleAlarmDialog.Entity_Alarm_Title, "title");
var entityDate = EntityFor(SimpleAlarmBot.SimpleAlarmDialog.Entity_Alarm_Start_Date, now.ToString("d", DateTimeFormatInfo.InvariantInfo));
var entityTime = EntityFor(SimpleAlarmBot.SimpleAlarmDialog.Entity_Alarm_Start_Time, now.ToString("t", DateTimeFormatInfo.InvariantInfo));
var entityTitle = EntityFor(SimpleAlarmDialog.Entity_Alarm_Title, "title");
var entityDate = EntityFor(SimpleAlarmDialog.Entity_Alarm_Start_Date, now.ToString("d", DateTimeFormatInfo.InvariantInfo));
var entityTime = EntityFor(SimpleAlarmDialog.Entity_Alarm_Start_Time, now.ToString("t", DateTimeFormatInfo.InvariantInfo));

Func<IDialog<object>> MakeRoot = () => new SimpleAlarmBot.SimpleAlarmDialog(luis.Object);
Func<IDialog<object>> MakeRoot = () => new SimpleAlarmDialog(luis.Object);
var toBot = MakeTestMessage();

using (new FiberTestBase.ResolveMoqAssembly(luis.Object))
Expand All @@ -108,7 +80,7 @@ public async Task AlarmDialogFlow()
var task = scope.Resolve<IPostToBot>();

// arrange
SetupLuis(luis, a => a.SetAlarm(null, null), entityTitle, entityDate, entityTime);
SetupLuis<SimpleAlarmDialog>(luis, a => a.SetAlarm(null, null), 1.0, entityTitle, entityDate, entityTime);

// act
await task.PostAsync(toBot, CancellationToken.None);
Expand All @@ -125,7 +97,7 @@ public async Task AlarmDialogFlow()
var task = scope.Resolve<IPostToBot>();

// arrange
SetupLuis(luis, a => a.FindAlarm(null, null), entityTitle);
SetupLuis<SimpleAlarmDialog>(luis, a => a.FindAlarm(null, null), 1.0, entityTitle);

// act
await task.PostAsync(toBot, CancellationToken.None);
Expand All @@ -142,7 +114,7 @@ public async Task AlarmDialogFlow()
var task = scope.Resolve<IPostToBot>();

// arrange
SetupLuis(luis, a => a.AlarmSnooze(null, null), entityTitle);
SetupLuis<SimpleAlarmDialog>(luis, a => a.AlarmSnooze(null, null), 1.0, entityTitle);

// act
await task.PostAsync(toBot, CancellationToken.None);
Expand All @@ -159,7 +131,7 @@ public async Task AlarmDialogFlow()
var task = scope.Resolve<IPostToBot>();

// arrange
SetupLuis(luis, a => a.TurnOffAlarm(null, null), entityTitle);
SetupLuis<SimpleAlarmDialog>(luis, a => a.TurnOffAlarm(null, null), 1.0, entityTitle);

// act
await task.PostAsync(toBot, CancellationToken.None);
Expand Down Expand Up @@ -210,7 +182,7 @@ public async Task AlarmDialogFlow()
var task = scope.Resolve<IPostToBot>();

// arrange
SetupLuis(luis, a => a.DeleteAlarm(null, null), entityTitle);
SetupLuis<SimpleAlarmDialog>(luis, a => a.DeleteAlarm(null, null), 1.0, entityTitle);

// act
await task.PostAsync(toBot, CancellationToken.None);
Expand Down

0 comments on commit 87177cf

Please sign in to comment.