Skip to content

Commit

Permalink
Implement create/delete methods (#4)
Browse files Browse the repository at this point in the history
  • Loading branch information
dluc committed Dec 17, 2023
1 parent 11b2225 commit 8fc7850
Show file tree
Hide file tree
Showing 8 changed files with 569 additions and 13 deletions.
3 changes: 2 additions & 1 deletion Directory.Packages.props
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
<ManagePackageVersionsCentrally>true</ManagePackageVersionsCentrally>
</PropertyGroup>
<ItemGroup>
<PackageVersion Include="Microsoft.KernelMemory.Abstractions" Version="0.18.231207.2-preview" />
<PackageVersion Include="Microsoft.KernelMemory.Abstractions" Version="0.22.231215.1" />
<PackageVersion Include="Microsoft.Extensions.DependencyInjection" Version="8.0.0" />
<PackageVersion Include="Pgvector" Version="0.2.0" />
</ItemGroup>
<!-- Sources -->
<ItemGroup>
Expand Down
10 changes: 10 additions & 0 deletions PostgresMemoryStorage/PostgresConfig.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,18 @@ namespace Microsoft.KernelMemory.Postgres;
/// </summary>
public class PostgresConfig
{
/// <summary>
/// Name of the default schema, used as the initial value of <see cref="Schema"/>.
/// </summary>
public const string DefaultSchema = "public";

/// <summary>
/// Connection string required to connect to Postgres.
/// Initialized to an empty string, so it must be set before use.
/// </summary>
public string ConnString { get; set; } = string.Empty;

/// <summary>
/// Name of the schema where to read and write records.
/// Defaults to <see cref="DefaultSchema"/>.
/// </summary>
public string Schema { get; set; } = DefaultSchema;
}
383 changes: 383 additions & 0 deletions PostgresMemoryStorage/PostgresDbClient.cs

Large diffs are not rendered by default.

92 changes: 82 additions & 10 deletions PostgresMemoryStorage/PostgresMemory.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,25 @@

using System;
using System.Collections.Generic;
using System.Text.Json;
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Microsoft.KernelMemory.AI;
using Microsoft.KernelMemory.Diagnostics;
using Microsoft.KernelMemory.MemoryStorage;
using Pgvector;

namespace Microsoft.KernelMemory.Postgres;

/// <summary>
/// Postgres connector for Kernel Memory.
/// </summary>
public class PostgresMemory : IMemoryDb
public class PostgresMemory : IMemoryDb, IDisposable
{
private readonly ILogger<PostgresMemory> _log;
private readonly ITextEmbeddingGenerator _embeddingGenerator;
private readonly PostgresDbClient _db;

/// <summary>
/// Create a new instance of Postgres KM connector
Expand All @@ -33,44 +36,73 @@ public class PostgresMemory : IMemoryDb
this._log = log ?? DefaultLogger<PostgresMemory>.Instance;

this._embeddingGenerator = embeddingGenerator;

if (this._embeddingGenerator == null)
{
throw new PostgresException("Embedding generator not configured");
}

this._db = new PostgresDbClient(config.ConnString, config.Schema);
}

/// <inheritdoc />
public Task CreateIndexAsync(
public async Task CreateIndexAsync(
string index,
int vectorSize,
CancellationToken cancellationToken = default)
{
throw new NotImplementedException();
// Validate the index name; the normalized value is used as the table name below.
index = NormalizeIndexName(index);

// Nothing to do when a table with this name already exists.
// NOTE(review): the check and the creation are two separate calls, so two concurrent
// callers could both reach CreateTableAsync - confirm the DB client tolerates that.
if (await this._db.DoesTableExistsAsync(index, cancellationToken).ConfigureAwait(false))
{
return;
}

// The vector size is forwarded to the DB client together with the table name.
await this._db.CreateTableAsync(index, vectorSize, cancellationToken).ConfigureAwait(false);
}

/// <inheritdoc />
public Task<IEnumerable<string>> GetIndexesAsync(
public async Task<IEnumerable<string>> GetIndexesAsync(
CancellationToken cancellationToken = default)
{
throw new NotImplementedException();
// Drain the async stream of table names returned by the DB client into a list,
// so the caller receives a fully materialized collection.
var result = new List<string>();
var tables = this._db.GetTablesAsync(cancellationToken).ConfigureAwait(false);
await foreach (string name in tables)
{
result.Add(name);
}

return result;
}

/// <inheritdoc />
public Task DeleteIndexAsync(
string index,
CancellationToken cancellationToken = default)
{
throw new NotImplementedException();
// This method is not async: an invalid index name makes NormalizeIndexName throw
// synchronously, before any Task is returned to the caller.
index = NormalizeIndexName(index);

// The DB client's task is returned as-is, without awaiting it here.
return this._db.DeleteTableAsync(index, cancellationToken);
}

/// <inheritdoc />
public Task<string> UpsertAsync(
public async Task<string> UpsertAsync(
string index,
MemoryRecord record,
CancellationToken cancellationToken = default)
{
throw new NotImplementedException();
// Validate the index name; the normalized value is used as the table name below.
index = NormalizeIndexName(index);

// The record is split into separate values for the DB client:
// - the embedding data is wrapped in a Vector
// - tags are flattened to strings by PostgresSchema.GetTags
// - the text content is extracted by PostgresSchema.GetContent
// - the remaining payload (content excluded) is serialized to JSON
// - lastUpdate is the current UTC time at the moment of this call
await this._db.UpsertAsync(
tableName: index,
id: record.Id,
embedding: new Vector(record.Vector.Data),
tags: PostgresSchema.GetTags(record),
content: PostgresSchema.GetContent(record),
payload: JsonSerializer.Serialize(PostgresSchema.GetPayload(record)),
lastUpdate: DateTimeOffset.UtcNow,
cancellationToken).ConfigureAwait(false);

// The ID is returned unchanged, i.e. the one carried by the input record.
return record.Id;
}

/// <inheritdoc />
Expand All @@ -83,6 +115,8 @@ public class PostgresMemory : IMemoryDb
bool withEmbeddings = false,
CancellationToken cancellationToken = new CancellationToken())
{
index = NormalizeIndexName(index);

if (filters != null)
{
foreach (MemoryFilter filter in filters)
Expand All @@ -107,6 +141,8 @@ public class PostgresMemory : IMemoryDb
bool withEmbeddings = false,
CancellationToken cancellationToken = default)
{
index = NormalizeIndexName(index);

if (filters != null)
{
foreach (MemoryFilter filter in filters)
Expand All @@ -129,6 +165,42 @@ public class PostgresMemory : IMemoryDb
MemoryRecord record,
CancellationToken cancellationToken = default)
{
throw new NotImplementedException();
index = NormalizeIndexName(index);

return this._db.DeleteAsync(tableName: index, id: record.Id, cancellationToken);
}

/// <inheritdoc/>
public void Dispose()
{
    // Release resources through the overridable overload first...
    this.Dispose(disposing: true);

    // ...then tell the GC that a finalizer run is no longer needed for this instance.
    GC.SuppressFinalize(this);
}

/// <summary>
/// Releases the resources held by this instance.
/// </summary>
/// <param name="disposing">
/// True when invoked from <see cref="Dispose()"/>; the DB client is disposed only in that case.
/// </param>
protected virtual void Dispose(bool disposing)
{
    if (!disposing)
    {
        return;
    }

    // The DB client is disposed only if it exposes IDisposable.
    if (this._db is IDisposable disposableDb)
    {
        disposableDb.Dispose();
    }
}

#region private ================================================================================

/// <summary>
/// Turn the index name received from the caller into the name used for the Postgres table:
/// fall back to the default index when no name is given, then validate the result.
/// </summary>
/// <param name="index">Index name provided by the caller; may be null, empty or blank</param>
/// <returns>The validated name, with leading/trailing whitespace removed</returns>
/// <exception cref="PostgresException">The name contains characters that are not allowed</exception>
private static string NormalizeIndexName(string index)
{
    // The fallback must run BEFORE validation: ValidateTableName requires at least one
    // allowed character (and Regex.IsMatch throws on null), so validating first made
    // the default index unreachable for null/empty/blank names.
    if (string.IsNullOrWhiteSpace(index))
    {
        index = Constants.DefaultIndex;
    }

    PostgresSchema.ValidateTableName(index);

    // Retained from the original implementation as a final clean-up of the validated name.
    return index.Trim();
}

#endregion
}
1 change: 1 addition & 0 deletions PostgresMemoryStorage/PostgresMemoryStorage.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

<ItemGroup>
<PackageReference Include="Microsoft.KernelMemory.Abstractions" />
<PackageReference Include="Pgvector" />
</ItemGroup>

<Import Project="../code-analysis.props" />
Expand Down
89 changes: 89 additions & 0 deletions PostgresMemoryStorage/PostgresSchema.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
// Copyright (c) Microsoft. All rights reserved.

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text.RegularExpressions;
using Microsoft.KernelMemory.MemoryStorage;

namespace Microsoft.KernelMemory.Postgres;

/// <summary>
/// Helpers describing how a <see cref="MemoryRecord"/> is mapped to a Postgres table:
/// column names, the comment used to mark index tables, extraction of the values
/// to store, and validation of schema/table names.
/// </summary>
internal static class PostgresSchema
{
    // Names may contain only ASCII letters, digits and underscores (at least one char).
    // The pattern ends with \z rather than $ because in .NET "$" also matches right
    // before a final newline, which would let a name like "abc\n" pass validation.
    private static readonly Regex s_schemaNameRegex = new(@"^[a-zA-Z0-9_]+\z");
    private static readonly Regex s_tableNameRegex = new(@"^[a-zA-Z0-9_]+\z");

    // TODO: make these configurable
    public const string FieldsId = "id";
    public const string FieldsEmbedding = "embedding";
    public const string FieldsTags = "tags";
    public const string FieldsContent = "content";
    public const string FieldsPayload = "payload";
    public const string FieldsUpdatedAt = "last_update";

    /// <summary>
    /// This is used to filter the list of tables when retrieving the list.
    /// Only tables with this comment are considered Indexes.
    /// TODO: allow to turn off/customize the filtering logic.
    /// </summary>
    public const string TableComment = "KernelMemoryIndex";

    /// <summary>
    /// Copy payload from MemoryRecord, excluding the content, which is stored separately.
    /// </summary>
    /// <param name="record">Source record to copy from</param>
    /// <returns>New dictionary with all the payload, except for content; empty if the record is null</returns>
    public static Dictionary<string, object> GetPayload(MemoryRecord? record)
    {
        if (record == null)
        {
            return new Dictionary<string, object>();
        }

        return record.Payload
            .Where(kv => kv.Key != Constants.ReservedPayloadTextField)
            .ToDictionary(kv => kv.Key, kv => kv.Value);
    }

    /// <summary>
    /// Extract content from MemoryRecord
    /// </summary>
    /// <param name="record">Source record to extract from</param>
    /// <returns>Text content, or empty string if none found</returns>
    public static string GetContent(MemoryRecord? record)
    {
        if (record == null)
        {
            return string.Empty;
        }

        if (record.Payload.TryGetValue(Constants.ReservedPayloadTextField, out object? value))
        {
            // Payload values are typed as object: a hard cast would return null for a
            // null value (breaking the "empty string" contract) and throw for any value
            // that is not a string. ToString() returns a string value unchanged.
            return value?.ToString() ?? string.Empty;
        }

        return string.Empty;
    }

    /// <summary>
    /// Flatten the tags of a MemoryRecord into a list of strings, one per key/value pair,
    /// formatted as key + <see cref="Constants.ReservedEqualsChar"/> + value.
    /// </summary>
    /// <param name="record">Source record to read the tags from</param>
    /// <returns>Array of formatted tags; empty if the record is null</returns>
    public static string[] GetTags(MemoryRecord? record)
    {
        if (record == null)
        {
            return Array.Empty<string>();
        }

        return record.Tags.Pairs
            .Select(tag => $"{tag.Key}{Constants.ReservedEqualsChar}{tag.Value}")
            .ToArray();
    }

    /// <summary>
    /// Verify that a schema name contains only ASCII letters, digits and underscores.
    /// </summary>
    /// <param name="name">Schema name to check</param>
    /// <exception cref="PostgresException">The name is empty or contains other characters</exception>
    public static void ValidateSchemaName(string name)
    {
        if (s_schemaNameRegex.IsMatch(name)) { return; }

        throw new PostgresException("The schema name contains invalid chars");
    }

    /// <summary>
    /// Verify that a table/index name contains only ASCII letters, digits and underscores.
    /// </summary>
    /// <param name="name">Table name to check</param>
    /// <exception cref="PostgresException">The name is empty or contains other characters</exception>
    public static void ValidateTableName(string name)
    {
        if (s_tableNameRegex.IsMatch(name)) { return; }

        throw new PostgresException("The table/index name contains invalid chars");
    }
}
2 changes: 1 addition & 1 deletion TestApplication/TestApplication.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
</ItemGroup>

<ItemGroup>
<PackageReference Include="Microsoft.KernelMemory.Core" Version="0.18.231207.2-preview" />
<PackageReference Include="Microsoft.KernelMemory.Core" Version="0.22.231215.1" />
</ItemGroup>

<ItemGroup>
Expand Down
2 changes: 1 addition & 1 deletion nuget-package.props
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
<Authors>Microsoft</Authors>
<Company>Microsoft</Company>
<Product>Kernel Memory adapter for Postgres</Product>
<Description>Postgres connector for Microsoft Kernel Memory, to store and search memory using Postgres vector indexing and Postgres features.</Description>
<Description>Postgres(with pgvector extension) connector for Microsoft Kernel Memory, to store and search memory using Postgres vector indexing and Postgres features.</Description>
<PackageTags>Copilot, Memory, RAG, Kernel Memory, Postgres, AI, Artificial Intelligence, Embeddings, Vector DB, Vector Search, ETL</PackageTags>
<PackageId>$(AssemblyName)</PackageId>

Expand Down

0 comments on commit 8fc7850

Please sign in to comment.