diff --git a/dotnet/typeagent/examples/examplesLib/KnowProWriter.cs b/dotnet/typeagent/examples/examplesLib/KnowProWriter.cs index 3ef6234f1..51df09c89 100644 --- a/dotnet/typeagent/examples/examplesLib/KnowProWriter.cs +++ b/dotnet/typeagent/examples/examplesLib/KnowProWriter.cs @@ -1,6 +1,7 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +using System.Threading.Tasks; using TypeAgent.ExamplesLib.CommandLine; namespace TypeAgent.ExamplesLib; @@ -35,16 +36,16 @@ public static async Task WriteSemanticRefsAsync(IConversation conversation) } } -public static async Task WriteMessagesAsync(IConversation conversation) -{ - await foreach (var message in conversation.Messages) + public static async Task WriteMessagesAsync(IConversation conversation) { - WriteMessage(message); - WriteLine(); + await foreach (var message in conversation.Messages) + { + WriteMessage(message); + WriteLine(); + } } -} -public static void WriteMessage(IMessage message) + public static void WriteMessage(IMessage message) { PushColor(ConsoleColor.Cyan); WriteNameValue("Timestamp", message.Timestamp); @@ -71,7 +72,7 @@ public static void WriteMetadata(IMessage message) } } -public static void WriteEntity(ConcreteEntity? entity) + public static void WriteEntity(ConcreteEntity? entity) { if (entity is not null) { @@ -85,10 +86,36 @@ public static void WriteEntity(ConcreteEntity? entity) } } + public static void WriteAction(TypeAgent.KnowPro.Action? action) + { + if (action is not null) + { + WriteLine(action.ToString()); + } + } + + public static void WriteTopic(Topic? topic) + { + if (topic is not null) + { + WriteLine(topic.Text); + } + } + + public static void WriteTag(Tag tag) + { + if (tag is not null) + { + WriteLine(tag.Text); + } + } + public static async Task WriteConversationSearchResultsAsync( - IConversation conversation, - ConversationSearchResult? searchResult, - bool verbose = false + IConversation conversation, + ConversationSearchResult? searchResult, + bool showKnowledge, + bool showMessages, + bool verbose = false ) { if (searchResult is null) @@ -118,7 +145,7 @@ public static async Task WriteConversationSearchResultsAsync( if (!searchResult.KnowledgeMatches.IsNullOrEmpty()) { WriteLineHeading("Knowledge"); - WriteKnowledgeSearchResults(conversation, searchResult.KnowledgeMatches); + await WriteKnowledgeSearchResultsAsync(conversation, searchResult.KnowledgeMatches); } } @@ -131,9 +158,11 @@ IList messageOrdinals WriteJson(messageOrdinals); } - public static void WriteKnowledgeSearchResults( + public static async Task WriteKnowledgeSearchResultsAsync( IConversation conversation, - IDictionary? results + IDictionary? results, + int? maxToDisplay = null, + bool isAsc = false ) { if (results.IsNullOrEmpty()) @@ -144,15 +173,17 @@ public static void WriteKnowledgeSearchResults( foreach (var kv in results!) { - WriteKnowledgeSearchResult(conversation, kv.Key, kv.Value); + await WriteKnowledgeSearchResultAsync(conversation, kv.Key, kv.Value, maxToDisplay, isAsc); WriteLine(); } } - public static void WriteKnowledgeSearchResult( + public static async Task WriteKnowledgeSearchResultAsync( IConversation conversation, KnowledgeType kType, - SemanticRefSearchResult result + SemanticRefSearchResult result, + int? maxToDisplay = null, + bool isAsc = true ) { WriteLineUnderline(kType.ToString().ToUpper()); @@ -162,7 +193,102 @@ SemanticRefSearchResult result ListType.Ol) ); WriteLine($"{result.SemanticRefMatches.Count} matches"); - WriteJson(result.SemanticRefMatches); + + await WriteScoredSemanticRefsAsync( + result.SemanticRefMatches, + conversation.SemanticRefs, + kType, + maxToDisplay is not null ? maxToDisplay.Value : result.SemanticRefMatches.Count, + isAsc + ); + } + + public static void WriteSemanticRef(SemanticRef sr) + { + switch (sr.KnowledgeType) + { + default: + break; + + case KnowledgeType.EntityTypeName: + case KnowledgeType.STagTypeName: + WriteEntity(sr.AsEntity()); + break; + + case KnowledgeType.ActionTypeName: + WriteAction(sr.AsAction()); + break; + + case KnowledgeType.TopicTypeName: + WriteTopic(sr.AsTopic()); + break; + + case KnowledgeType.TagTypeName: + WriteTag(sr.AsTag()); + break; + } + } + + public static async Task WriteScoredSemanticRefsAsync( + IList semanticRefMatches, + ISemanticRefCollection semanticRefCollection, + KnowledgeType kType, + int maxToDisplay, + bool isAsc = true + ) + { + if (isAsc) + { + WriteLine("Sorted in ascending order(lowest first)"); + } + + var matchesToDisplay = semanticRefMatches.Slice(0, maxToDisplay); + WriteLine($"Displaying {matchesToDisplay.Count} matches of total {semanticRefMatches.Count}"); + + if (kType == KnowledgeType.Entity) + { + IList> entities = await semanticRefCollection.GetDistinctEntitiesAsync(matchesToDisplay); + for (int i = 0; i < entities.Count; ++i) + { + var pos = isAsc ? matchesToDisplay.Count - (i + 1) : i; + WriteLine( + ConsoleColor.Green, + $"{pos + 1} / {matchesToDisplay.Count}: [{entities[i].Score}]" + ); + WriteEntity(entities[i]); + WriteLine(); + } + } + else + { + IList semanticRefs = await semanticRefCollection.GetAsync(matchesToDisplay); + for (int i = 0; i < matchesToDisplay.Count; ++i) + { + var pos = isAsc ? matchesToDisplay.Count - (i + 1) : i; + WriteScoredRef( + pos, + matchesToDisplay.Count, + matchesToDisplay[pos], + semanticRefs[pos] + ); + } + } + + } + + public static void WriteScoredRef( + int matchNumber, + int totalMatches, + ScoredSemanticRefOrdinal scoredRef, + SemanticRef semanticRef + ) + { + WriteLine( + ConsoleColor.Green, + $"#{matchNumber + 1} / {totalMatches}: <{scoredRef.SemanticRefOrdinal}::{semanticRef.Range.Start.MessageOrdinal}> {semanticRef.KnowledgeType} [{scoredRef.Score}]" + ); + WriteSemanticRef(semanticRef); + WriteLine(); } public static void WriteDataFileStats(ConversationData data) diff --git a/dotnet/typeagent/examples/knowProConsole/MemoryCommands.cs b/dotnet/typeagent/examples/knowProConsole/MemoryCommands.cs index 072cb0709..66ebaba17 100644 --- a/dotnet/typeagent/examples/knowProConsole/MemoryCommands.cs +++ b/dotnet/typeagent/examples/knowProConsole/MemoryCommands.cs @@ -154,7 +154,12 @@ private async Task SearchRagAsync(ParseResult args, CancellationToken cancellati namedArgs.Get("budget"), cancellationToken ); - await KnowProWriter.WriteConversationSearchResultsAsync(conversation, matches, true); + await KnowProWriter.WriteConversationSearchResultsAsync( + conversation, + matches, + true, + true + ); } private IConversation EnsureConversation() diff --git a/dotnet/typeagent/examples/knowProConsole/TestCommands.cs b/dotnet/typeagent/examples/knowProConsole/TestCommands.cs index 8f422bc1d..b9979c9a7 100644 --- a/dotnet/typeagent/examples/knowProConsole/TestCommands.cs +++ b/dotnet/typeagent/examples/knowProConsole/TestCommands.cs @@ -124,7 +124,12 @@ private async Task SearchMessagesAsync(ParseResult result, CancellationToken can null, cancellationToken ); - await KnowProWriter.WriteConversationSearchResultsAsync(conversation, searchResults); + await KnowProWriter.WriteConversationSearchResultsAsync( + conversation, + searchResults, + true, + true + ); DateRange? conversationDateRange = await conversation.GetDateRangeAsync(); if (conversationDateRange is not null) @@ -148,7 +153,7 @@ private async Task SearchMessagesAsync(ParseResult result, CancellationToken can }, cancellationToken ); - await KnowProWriter.WriteConversationSearchResultsAsync(conversation, searchResults, true); + await KnowProWriter.WriteConversationSearchResultsAsync(conversation, searchResults, true, true); } private Command TestEmbeddingsDef() @@ -280,7 +285,7 @@ async Task TestSearchKnowledgeAsync(IConversation conversation, SearchTermGroup cancellationToken ).ConfigureAwait(false); - KnowProWriter.WriteKnowledgeSearchResults(_kpContext.Conversation!, results); + await KnowProWriter.WriteKnowledgeSearchResultsAsync(_kpContext.Conversation!, results); } private Command SearchQueryTermsDef() @@ -334,9 +339,14 @@ private async Task SearchLangAsync(ParseResult args, CancellationToken cancellat if (conversation is IMemory memory) { IList results = await memory.SearchAsync(query, null, null, null, cancellationToken); - foreach(var result in results) + foreach (var result in results) { - await KnowProWriter.WriteConversationSearchResultsAsync(conversation, result); + await KnowProWriter.WriteConversationSearchResultsAsync( + conversation, + result, + true, + false + ); } } } diff --git a/dotnet/typeagent/src/common/EnumerationExtensions.cs b/dotnet/typeagent/src/common/EnumerationExtensions.cs index 92fee2d87..401db42c5 100644 --- a/dotnet/typeagent/src/common/EnumerationExtensions.cs +++ b/dotnet/typeagent/src/common/EnumerationExtensions.cs @@ -78,4 +78,13 @@ public static IEnumerable> Batch(this IEnumerable items, int batch yield return batch; } } + + public static List> GetTopK(this IEnumerable> items, int topk) + { + ArgumentVerify.ThrowIfNull(items, nameof(items)); + + var topNList = new TopNCollection(topk); + topNList.Add(items); + return topNList.ByRankAndClear(); + } } diff --git a/dotnet/typeagent/src/common/ListExtensions.cs b/dotnet/typeagent/src/common/ListExtensions.cs index d5ab76e70..b55177188 100644 --- a/dotnet/typeagent/src/common/ListExtensions.cs +++ b/dotnet/typeagent/src/common/ListExtensions.cs @@ -140,13 +140,6 @@ public static int BinarySearchFirst( return lo; } - public static List> GetTopK(this IEnumerable> list, int topK) - { - var topNList = new TopNCollection(topK); - topNList.Add(list); - return topNList.ByRankAndClear(); - } - public static void Fill(this IList list, T value, int count) { for (int i = 0; i < count; ++i) diff --git a/dotnet/typeagent/src/common/Multiset.cs b/dotnet/typeagent/src/common/Multiset.cs index dae84d644..98be3c582 100644 --- a/dotnet/typeagent/src/common/Multiset.cs +++ b/dotnet/typeagent/src/common/Multiset.cs @@ -73,6 +73,23 @@ public void Add(IEnumerable> keyValues) } } + public void AddUnique(TKey key, TValue value) + { + List? values = Get(key); + if (values is null) + { + Add(key, value); + } + else + { + int pos = values.IndexOf(value); + if (pos < 0) + { + Add(key, value); + } + } + } + public void Remove(TKey key, TValue value) { if (TryGetValue(key, out var valueList)) diff --git a/dotnet/typeagent/src/common/StringExtensions.cs b/dotnet/typeagent/src/common/StringExtensions.cs index 6aa6b6a27..4582d1910 100644 --- a/dotnet/typeagent/src/common/StringExtensions.cs +++ b/dotnet/typeagent/src/common/StringExtensions.cs @@ -47,4 +47,16 @@ int maxCharsPerChunk yield return chunk; } } + + public static List LowerAndSort(this List list) + { + int count = list.Count; + for (int i = 0; i < count; ++i) + { + list[i] = list[i].ToLower(); + } + list.Sort(); + return list; + } + } diff --git a/dotnet/typeagent/src/knowpro/ISemanticRefCollection.cs b/dotnet/typeagent/src/knowpro/ISemanticRefCollection.cs index d7958f509..c0ed566f3 100644 --- a/dotnet/typeagent/src/knowpro/ISemanticRefCollection.cs +++ b/dotnet/typeagent/src/knowpro/ISemanticRefCollection.cs @@ -14,14 +14,64 @@ public interface ISemanticRefCollection : IAsyncCollection public static class SemanticRefCollectionExtensions { + // + // These methods use IAsyncCollectionReader because then they also work + // with Caches...see ConversationCache.cs + // + public static ValueTask> GetAsync( this IAsyncCollectionReader semanticRefs, IList scoredOrdinals, CancellationToken cancellationToken = default) { + ArgumentVerify.ThrowIfNull(scoredOrdinals, nameof(scoredOrdinals)); + return semanticRefs.GetAsync( [.. scoredOrdinals.ToOrdinals()], cancellationToken ); } + + public static async ValueTask>> GetScoredAsync( + this IAsyncCollectionReader semanticRefs, + IList scoredOrdinals, + CancellationToken cancellationToken = default) + { + ArgumentVerify.ThrowIfNull(scoredOrdinals, nameof(scoredOrdinals)); + + IList refs = await semanticRefs.GetAsync( + scoredOrdinals, + cancellationToken + ).ConfigureAwait(false); + + List> scored = new List>(refs.Count); + int count = scoredOrdinals.Count; + for (int i = 0; i < count; ++i) + { + scored.Add(new Scored(refs[i], scoredOrdinals[i].Score)); + } + + return scored; + } + + public static async ValueTask>> GetDistinctEntitiesAsync( + this IAsyncCollectionReader semanticRefs, + IList semanticRefMatches, + int? topK = null + ) + { + var scoredEntities = await semanticRefs.GetScoredAsync( + semanticRefMatches + ).ConfigureAwait(false); + + Dictionary> mergedEntities = MergedEntity.MergeScoredEntities(scoredEntities, false); + IEnumerable> entitites = mergedEntities.Values.Select((v) => + { + return new Scored(v.Item.ToConcrete(), v.Score); + }); + + return (topK is not null) + ? entitites.GetTopK(topK.Value) + : [.. entitites]; + } } diff --git a/dotnet/typeagent/src/knowpro/KnowledgeExtractor/KnowledgeMerge.cs b/dotnet/typeagent/src/knowpro/KnowledgeExtractor/KnowledgeMerge.cs deleted file mode 100644 index 7d5625143..000000000 --- a/dotnet/typeagent/src/knowpro/KnowledgeExtractor/KnowledgeMerge.cs +++ /dev/null @@ -1,33 +0,0 @@ -// Copyright (c) Microsoft Corporation. -// Licensed under the MIT License. - -namespace TypeAgent.KnowPro.KnowledgeExtractor; - -public static class KnowledgeMergeExtensions -{ -} - -internal class MergedKnowledge -{ - HashSet? SourceMessageOrdinals { get; set; } = null; - -} - -internal class MergedEntity : MergedKnowledge -{ - public string Name { get; set; } - - public string[] Type { get; set; } - - public MergedFacets? Facets { get; set; } = null; -} - -internal class MergedTopic : MergedKnowledge -{ - public Topic Topic { get; set; } -} - -internal class MergedFacets : Multiset -{ - -} diff --git a/dotnet/typeagent/src/knowpro/KnowledgeImpl.cs b/dotnet/typeagent/src/knowpro/KnowledgeImpl.cs index 6beffe3b9..47a49cc43 100644 --- a/dotnet/typeagent/src/knowpro/KnowledgeImpl.cs +++ b/dotnet/typeagent/src/knowpro/KnowledgeImpl.cs @@ -40,6 +40,35 @@ public void MergeEntityFacet(Facet facet) } Facets = Facets.Append(facet); } + + internal MergedEntity ToMerged() + { + List types = [.. Type]; + types.LowerAndSort(); + + return new MergedEntity() + { + Name = Name.ToLower(), + Type = types, + Facets = !Facets.IsNullOrEmpty() ? ToMergedFacets() : null + }; + } + + internal MergedFacets ToMergedFacets() + { + MergedFacets mergedFacets = []; + if (!Facets.IsNullOrEmpty()) + { + foreach (var facet in Facets) + { + string name = facet.Name.ToLower(); + string value = facet.Value.ToString().ToLower(); + mergedFacets.AddUnique(name, value); + } + } + return mergedFacets; + } + } public partial class Action @@ -74,6 +103,37 @@ private static bool IsDefined(string value) { return !string.IsNullOrEmpty(value) && value != NoneEntityName; } + + public override string ToString() + { + StringBuilder text = new StringBuilder(); + + AppendEntityName(text, SubjectEntityName); + + text.Append($" [{VerbString()}]"); + + AppendEntityName(text, ObjectEntityName); + AppendEntityName(text, IndirectObjectEntityName); + + text.Append($" {{{VerbTense}}}"); + + if (SubjectEntityFacet is not null) + { + text.Append($" <{SubjectEntityFacet.ToString()}>"); + } + return text.ToString(); + } + + private void AppendEntityName(StringBuilder text, string? name) + { + if (text.Length > 0) + { + text.Append(' '); + } + text.Append(IsDefined(name) + ? $"<{name}>" + : "<>"); + } } public partial class Topic @@ -165,10 +225,7 @@ internal void MergeActionKnowledge() if (action.SubjectEntityFacet is not null) { ConcreteEntity? entity = Array.Find(Entities, (c) => c.Name == action.SubjectEntityName); - if (entity is not null) - { - entity.MergeEntityFacet(action.SubjectEntityFacet); - } + entity?.MergeEntityFacet(action.SubjectEntityFacet); action.SubjectEntityFacet = null; } } diff --git a/dotnet/typeagent/src/knowpro/KnowledgeMerge.cs b/dotnet/typeagent/src/knowpro/KnowledgeMerge.cs new file mode 100644 index 000000000..31372b118 --- /dev/null +++ b/dotnet/typeagent/src/knowpro/KnowledgeMerge.cs @@ -0,0 +1,165 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +namespace TypeAgent.KnowPro; + +public static class KnowledgeMergeExtensions +{ +} + +internal class MergedKnowledge +{ + public HashSet? SourceMessageOrdinals { get; set; } = null; + + public void MergeMessageOrdinals(SemanticRef sr) + { + SourceMessageOrdinals ??= []; + SourceMessageOrdinals.Add(sr.Range.Start.MessageOrdinal); + } + +} + +internal class MergedEntity : MergedKnowledge +{ + public MergedEntity() + { + + } + + public string Name { get; set; } + + public IList Type { get; set; } + + public MergedFacets? Facets { get; set; } = null; + + public ConcreteEntity ToConcrete() + { + var entity = new ConcreteEntity + { + Name = Name, + Type = [.. Type], + }; + if (!Facets.IsNullOrEmpty()) + { + entity.Facets = [.. Facets.ToFacets()]; + } + return entity; + } + + public static bool Union(MergedEntity to, MergedEntity other) + { + if (to.Name != other.Name) + { + return false; + } + + to.Type = [.. to.Type.Union(other.Type)]; + to.Facets = MergedFacets.Union(to.Facets, other.Facets); + return false; + } + + public static Dictionary> MergeScoredEntities( + IEnumerable> scoredEntities, + bool mergeOrdinals + ) + { + Dictionary> mergedEntities = []; + + foreach (var scoredEntity in scoredEntities) + { + if (scoredEntity.Item.KnowledgeType != KnowledgeType.Entity) + { + continue; + } + + MergedEntity mergedEntity = scoredEntity.Item.AsEntity().ToMerged(); + Scored? target = null; + if (mergedEntities.TryGetValue(mergedEntity.Name, out var existing)) + { + if (Union(existing.Item, mergedEntity)) + { + if (existing.Score < scoredEntity.Score) + { + existing.Score = scoredEntity.Score; + } + target = existing; + } + else + { + target = null; + } + } + else + { + var newMerged = new Scored(mergedEntity, scoredEntity.Score); + mergedEntities.Add(mergedEntity.Name, newMerged); + target = newMerged; + } + if (target is not null && mergeOrdinals) + { + target.Value.Item.MergeMessageOrdinals(scoredEntity); + } + } + + return mergedEntities; + } +} + +internal class MergedTopic : MergedKnowledge +{ + public Topic Topic { get; set; } +} + +internal class MergedFacets : Multiset +{ + public MergedFacets() + : base() + { + } + + public MergedFacets(IEqualityComparer comparer) + : base(comparer) + { + } + + public static MergedFacets Union(MergedFacets? to, MergedFacets? other) + { + if (to is null) + { + return other; + } + if (other is null) + { + return to; + } + + foreach (var facetName in other.Keys) + { + List? facetValues = other.Get(facetName); + if (!facetValues.IsNullOrEmpty()) + { + int count = facetValues.Count; + for (int i = 0; i < count; ++i) + { + to.AddUnique(facetName, facetValues[i]); + } + } + } + return to; + } + + public IEnumerable ToFacets() + { + foreach (KeyValuePair> kv in this) + { + if (!kv.Value.IsNullOrEmpty()) + { + yield return new Facet + { + Name = kv.Key, + Value = new StringFacetValue(string.Join("; ", kv.Value)) + }; + } + } + } +} diff --git a/dotnet/typeagent/src/knowpro/KnowledgeType.cs b/dotnet/typeagent/src/knowpro/KnowledgeType.cs index 470266cec..9e1c63dd3 100644 --- a/dotnet/typeagent/src/knowpro/KnowledgeType.cs +++ b/dotnet/typeagent/src/knowpro/KnowledgeType.cs @@ -28,15 +28,15 @@ public struct KnowledgeType : IEquatable /// public static readonly KnowledgeType STag = new("sTag"); - internal const string EntityTypeName = "entity"; - internal const string ActionTypeName = "action"; - internal const string TopicTypeName = "topic"; - internal const string TagTypeName = "tag"; - internal const string STagTypeName = "sTag"; + public const string EntityTypeName = "entity"; + public const string ActionTypeName = "action"; + public const string TopicTypeName = "topic"; + public const string TagTypeName = "tag"; + public const string STagTypeName = "sTag"; public static bool IsKnowledgeType(string type) { - return + return type == EntityTypeName || type == ActionTypeName || type == TopicTypeName ||