Skip to content

Commit caa59b1

Browse files
authored
Add support for Session-specific resources (#610)
* [session-resources] Add Support for tool-specific resources * [session-resources] session aware now * [session-resources] tests * [session-resources] docs * [session-resources] address feedback * [session-resources] nits * [session-resources] nits * [session-resources] yup * [session-resources] add support to streamable_http * [session-resources] add test * [khan-changes] listing resources * [session-resources] use delete * [session-resources] fixes * [session-resources] fix test
1 parent 6d52180 commit caa59b1

File tree

11 files changed

+609
-22
lines changed

11 files changed

+609
-22
lines changed

go.mod

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module github.com/mark3labs/mcp-go
22

3-
go 1.23
3+
go 1.23.0
44

55
require (
66
github.com/google/uuid v1.6.0

server/server.go

Lines changed: 51 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22
package server
33

44
import (
5+
"cmp"
56
"context"
67
"encoding/base64"
78
"encoding/json"
89
"fmt"
10+
"maps"
911
"slices"
1012
"sort"
1113
"sync"
@@ -826,21 +828,36 @@ func (s *MCPServer) handleListResources(
826828
request mcp.ListResourcesRequest,
827829
) (*mcp.ListResourcesResult, *requestError) {
828830
s.resourcesMu.RLock()
829-
resources := make([]mcp.Resource, 0, len(s.resources))
830-
for _, entry := range s.resources {
831-
resources = append(resources, entry.resource)
831+
resourceMap := make(map[string]mcp.Resource, len(s.resources))
832+
for uri, entry := range s.resources {
833+
resourceMap[uri] = entry.resource
832834
}
833835
s.resourcesMu.RUnlock()
834836

837+
// Check if there are session-specific resources
838+
session := ClientSessionFromContext(ctx)
839+
if session != nil {
840+
if sessionWithResources, ok := session.(SessionWithResources); ok {
841+
if sessionResources := sessionWithResources.GetSessionResources(); sessionResources != nil {
842+
// Merge session-specific resources with global resources
843+
for uri, serverResource := range sessionResources {
844+
resourceMap[uri] = serverResource.Resource
845+
}
846+
}
847+
}
848+
}
849+
835850
// Sort the resources by name
836-
sort.Slice(resources, func(i, j int) bool {
837-
return resources[i].Name < resources[j].Name
851+
resourcesList := slices.SortedFunc(maps.Values(resourceMap), func(a, b mcp.Resource) int {
852+
return cmp.Compare(a.Name, b.Name)
838853
})
854+
855+
// Apply pagination
839856
resourcesToReturn, nextCursor, err := listByPagination(
840857
ctx,
841858
s,
842859
request.Params.Cursor,
843-
resources,
860+
resourcesList,
844861
)
845862
if err != nil {
846863
return nil, &requestError{
@@ -900,9 +917,35 @@ func (s *MCPServer) handleReadResource(
900917
request mcp.ReadResourceRequest,
901918
) (*mcp.ReadResourceResult, *requestError) {
902919
s.resourcesMu.RLock()
920+
921+
// First check session-specific resources
922+
var handler ResourceHandlerFunc
923+
var ok bool
924+
925+
session := ClientSessionFromContext(ctx)
926+
if session != nil {
927+
if sessionWithResources, typeAssertOk := session.(SessionWithResources); typeAssertOk {
928+
if sessionResources := sessionWithResources.GetSessionResources(); sessionResources != nil {
929+
resource, sessionOk := sessionResources[request.Params.URI]
930+
if sessionOk {
931+
handler = resource.Handler
932+
ok = true
933+
}
934+
}
935+
}
936+
}
937+
938+
// If not found in session tools, check global tools
939+
if !ok {
940+
globalResource, rok := s.resources[request.Params.URI]
941+
if rok {
942+
handler = globalResource.handler
943+
ok = true
944+
}
945+
}
946+
903947
// First try direct resource handlers
904-
if entry, ok := s.resources[request.Params.URI]; ok {
905-
handler := entry.handler
948+
if ok {
906949
s.resourcesMu.RUnlock()
907950

908951
finalHandler := handler

server/server_test.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -445,9 +445,8 @@ func TestMCPServer_HandleValidMessages(t *testing.T) {
445445
resp, ok := response.(mcp.JSONRPCResponse)
446446
assert.True(t, ok)
447447

448-
listResult, ok := resp.Result.(mcp.ListResourcesResult)
448+
_, ok = resp.Result.(mcp.ListResourcesResult)
449449
assert.True(t, ok)
450-
assert.NotNil(t, listResult.Resources)
451450
},
452451
},
453452
}

server/session.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,17 @@ type SessionWithTools interface {
3939
SetSessionTools(tools map[string]ServerTool)
4040
}
4141

42+
// SessionWithResources is an extension of ClientSession that can store session-specific resource data
43+
type SessionWithResources interface {
44+
ClientSession
45+
// GetSessionResources returns the resources specific to this session, if any
46+
// This method must be thread-safe for concurrent access
47+
GetSessionResources() map[string]ServerResource
48+
// SetSessionResources sets resources specific to this session
49+
// This method must be thread-safe for concurrent access
50+
SetSessionResources(resources map[string]ServerResource)
51+
}
52+
4253
// SessionWithClientInfo is an extension of ClientSession that can store client info
4354
type SessionWithClientInfo interface {
4455
ClientSession

server/session_test.go

Lines changed: 126 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"context"
55
"encoding/json"
66
"errors"
7+
"maps"
78
"sync"
89
"sync/atomic"
910
"testing"
@@ -100,6 +101,60 @@ func (f *sessionTestClientWithTools) SetSessionTools(tools map[string]ServerTool
100101
f.sessionTools = toolsCopy
101102
}
102103

104+
// sessionTestClientWithResources implements the SessionWithResources interface for testing
105+
type sessionTestClientWithResources struct {
106+
sessionID string
107+
notificationChannel chan mcp.JSONRPCNotification
108+
initialized bool
109+
sessionResources map[string]ServerResource
110+
mu sync.RWMutex // Mutex to protect concurrent access to sessionResources
111+
}
112+
113+
func (f *sessionTestClientWithResources) SessionID() string {
114+
return f.sessionID
115+
}
116+
117+
func (f *sessionTestClientWithResources) NotificationChannel() chan<- mcp.JSONRPCNotification {
118+
return f.notificationChannel
119+
}
120+
121+
func (f *sessionTestClientWithResources) Initialize() {
122+
f.initialized = true
123+
}
124+
125+
func (f *sessionTestClientWithResources) Initialized() bool {
126+
return f.initialized
127+
}
128+
129+
func (f *sessionTestClientWithResources) GetSessionResources() map[string]ServerResource {
130+
f.mu.RLock()
131+
defer f.mu.RUnlock()
132+
133+
if f.sessionResources == nil {
134+
return nil
135+
}
136+
137+
// Return a copy of the map to prevent concurrent modification
138+
resourcesCopy := make(map[string]ServerResource, len(f.sessionResources))
139+
maps.Copy(resourcesCopy, f.sessionResources)
140+
return resourcesCopy
141+
}
142+
143+
func (f *sessionTestClientWithResources) SetSessionResources(resources map[string]ServerResource) {
144+
f.mu.Lock()
145+
defer f.mu.Unlock()
146+
147+
if resources == nil {
148+
f.sessionResources = nil
149+
return
150+
}
151+
152+
// Create a copy of the map to prevent concurrent modification
153+
resourcesCopy := make(map[string]ServerResource, len(resources))
154+
maps.Copy(resourcesCopy, resources)
155+
f.sessionResources = resourcesCopy
156+
}
157+
103158
// sessionTestClientWithClientInfo implements the SessionWithClientInfo interface for testing
104159
type sessionTestClientWithClientInfo struct {
105160
sessionID string
@@ -151,7 +206,7 @@ func (f *sessionTestClientWithClientInfo) SetClientCapabilities(clientCapabiliti
151206
f.clientCapabilities.Store(clientCapabilities)
152207
}
153208

154-
// sessionTestClientWithTools implements the SessionWithLogging interface for testing
209+
// sessionTestClientWithLogging implements the SessionWithLogging interface for testing
155210
type sessionTestClientWithLogging struct {
156211
sessionID string
157212
notificationChannel chan mcp.JSONRPCNotification
@@ -190,6 +245,7 @@ func (f *sessionTestClientWithLogging) GetLogLevel() mcp.LoggingLevel {
190245
var (
191246
_ ClientSession = (*sessionTestClient)(nil)
192247
_ SessionWithTools = (*sessionTestClientWithTools)(nil)
248+
_ SessionWithResources = (*sessionTestClientWithResources)(nil)
193249
_ SessionWithLogging = (*sessionTestClientWithLogging)(nil)
194250
_ SessionWithClientInfo = (*sessionTestClientWithClientInfo)(nil)
195251
)
@@ -260,6 +316,75 @@ func TestSessionWithTools_Integration(t *testing.T) {
260316
})
261317
}
262318

319+
func TestSessionWithResources_Integration(t *testing.T) {
320+
server := NewMCPServer("test-server", "1.0.0")
321+
322+
// Create session-specific resources
323+
sessionResource := ServerResource{
324+
Resource: mcp.NewResource("ui://resource", "session-resource"),
325+
Handler: func(ctx context.Context, request mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) {
326+
return []mcp.ResourceContents{mcp.TextResourceContents{
327+
URI: "ui://resource",
328+
Text: "session-resource result",
329+
}}, nil
330+
},
331+
}
332+
333+
// Create a session with resources
334+
session := &sessionTestClientWithResources{
335+
sessionID: "session-1",
336+
notificationChannel: make(chan mcp.JSONRPCNotification, 10),
337+
initialized: true,
338+
sessionResources: map[string]ServerResource{
339+
"ui://resource": sessionResource,
340+
},
341+
}
342+
343+
// Register the session
344+
err := server.RegisterSession(context.Background(), session)
345+
require.NoError(t, err)
346+
347+
// Test that we can access the session-specific resource
348+
testReq := mcp.ReadResourceRequest{}
349+
testReq.Params.URI = "ui://resource"
350+
testReq.Params.Arguments = map[string]any{}
351+
352+
// Call using session context
353+
sessionCtx := server.WithContext(context.Background(), session)
354+
355+
// Check if the session was stored in the context correctly
356+
s := ClientSessionFromContext(sessionCtx)
357+
require.NotNil(t, s, "Session should be available from context")
358+
assert.Equal(t, session.SessionID(), s.SessionID(), "Session ID should match")
359+
360+
// Check if the session can be cast to SessionWithResources
361+
swr, ok := s.(SessionWithResources)
362+
require.True(t, ok, "Session should implement SessionWithResources")
363+
364+
// Check if the resources are accessible
365+
resources := swr.GetSessionResources()
366+
require.NotNil(t, resources, "Session resources should be available")
367+
require.Contains(t, resources, "ui://resource", "Session should have ui://resource")
368+
369+
// Test session resource access with session context
370+
t.Run("test session resource access", func(t *testing.T) {
371+
// First test directly getting the resource from session resources
372+
resource, exists := resources["ui://resource"]
373+
require.True(t, exists, "Session resource should exist in the map")
374+
require.NotNil(t, resource, "Session resource should not be nil")
375+
376+
// Now test calling directly with the handler
377+
result, err := resource.Handler(sessionCtx, testReq)
378+
require.NoError(t, err, "No error calling session resource handler directly")
379+
require.NotNil(t, result, "Result should not be nil")
380+
require.Len(t, result, 1, "Result should have one content item")
381+
382+
textContent, ok := result[0].(mcp.TextResourceContents)
383+
require.True(t, ok, "Content should be TextResourceContents")
384+
assert.Equal(t, "session-resource result", textContent.Text, "Result text should match")
385+
})
386+
}
387+
263388
func TestMCPServer_ToolsWithSessionTools(t *testing.T) {
264389
// Basic test to verify that session-specific tools are returned correctly in a tools list
265390
server := NewMCPServer("test-server", "1.0.0", WithToolCapabilities(true))

server/sse.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ type sseSession struct {
2929
initialized atomic.Bool
3030
loggingLevel atomic.Value
3131
tools sync.Map // stores session-specific tools
32+
resources sync.Map // stores session-specific resources
3233
clientInfo atomic.Value // stores session-specific client info
3334
clientCapabilities atomic.Value // stores session-specific client capabilities
3435
}
@@ -75,6 +76,27 @@ func (s *sseSession) GetLogLevel() mcp.LoggingLevel {
7576
return level.(mcp.LoggingLevel)
7677
}
7778

79+
func (s *sseSession) GetSessionResources() map[string]ServerResource {
80+
resources := make(map[string]ServerResource)
81+
s.resources.Range(func(key, value any) bool {
82+
if resource, ok := value.(ServerResource); ok {
83+
resources[key.(string)] = resource
84+
}
85+
return true
86+
})
87+
return resources
88+
}
89+
90+
func (s *sseSession) SetSessionResources(resources map[string]ServerResource) {
91+
// Clear existing resources
92+
s.resources.Clear()
93+
94+
// Set new resources
95+
for name, resource := range resources {
96+
s.resources.Store(name, resource)
97+
}
98+
}
99+
78100
func (s *sseSession) GetSessionTools() map[string]ServerTool {
79101
tools := make(map[string]ServerTool)
80102
s.tools.Range(func(key, value any) bool {
@@ -125,6 +147,7 @@ func (s *sseSession) GetClientCapabilities() mcp.ClientCapabilities {
125147
var (
126148
_ ClientSession = (*sseSession)(nil)
127149
_ SessionWithTools = (*sseSession)(nil)
150+
_ SessionWithResources = (*sseSession)(nil)
128151
_ SessionWithLogging = (*sseSession)(nil)
129152
_ SessionWithClientInfo = (*sseSession)(nil)
130153
)

0 commit comments

Comments
 (0)