From 401e27654ae06f5cdcfee720821f9192b2c8558c Mon Sep 17 00:00:00 2001 From: David Sarno Date: Thu, 14 Aug 2025 22:37:34 -0700 Subject: [PATCH 01/25] Unity MCP: stable framing handshake + non-blocking script writes; remove blob stream tools; simplify tool registration - Python server: always consume handshake and negotiate framing on reconnects (prevents invalid framed length).\n- C#: strict FRAMING=1 handshake and NoDelay; debounce AssetDatabase/compilation.\n- Tools: keep manage_script + script edits; remove manage_script_stream and test tools from default registration.\n- Editor window: guard against auto retargeting IDE config. --- UnityMcpBridge/Editor/Tools/ManageScript.cs | 85 +++++++++++-- UnityMcpBridge/Editor/UnityMcpBridge.cs | 119 ++++++++++++++++-- .../Editor/Windows/UnityMcpEditorWindow.cs | 36 +++--- UnityMcpBridge/UnityMcpServer~/src/server.py | 16 +++ .../UnityMcpServer~/src/unity_connection.py | 94 +++++++++----- 5 files changed, 285 insertions(+), 65 deletions(-) diff --git a/UnityMcpBridge/Editor/Tools/ManageScript.cs b/UnityMcpBridge/Editor/Tools/ManageScript.cs index d79e17a6..8fa018b1 100644 --- a/UnityMcpBridge/Editor/Tools/ManageScript.cs +++ b/UnityMcpBridge/Editor/Tools/ManageScript.cs @@ -6,6 +6,7 @@ using UnityEditor; using UnityEngine; using UnityMcpBridge.Editor.Helpers; +using System.Threading; #if USE_ROSLYN using Microsoft.CodeAnalysis; @@ -217,13 +218,20 @@ string namespaceName try { - File.WriteAllText(fullPath, contents); - AssetDatabase.ImportAsset(relativePath); - AssetDatabase.Refresh(); // Ensure Unity recognizes the new script - return Response.Success( + // Atomic-ish create + var enc = System.Text.Encoding.UTF8; + var tmp = fullPath + ".tmp"; + File.WriteAllText(tmp, contents, enc); + File.Move(tmp, fullPath); + + var ok = Response.Success( $"Script '{name}.cs' created successfully at '{relativePath}'.", - new { path = relativePath } + new { path = relativePath, scheduledRefresh = true } ); + + // Schedule heavy work AFTER replying + ManageScriptRefreshHelpers.ScheduleScriptRefresh(relativePath); + return ok; } catch (Exception e) { @@ -298,13 +306,33 @@ string contents try { - File.WriteAllText(fullPath, contents); - AssetDatabase.ImportAsset(relativePath); // Re-import to reflect changes - AssetDatabase.Refresh(); - return Response.Success( + // Safe write with atomic replace when available + var encoding = System.Text.Encoding.UTF8; + string tempPath = fullPath + ".tmp"; + File.WriteAllText(tempPath, contents, encoding); + + string backupPath = fullPath + ".bak"; + try + { + File.Replace(tempPath, fullPath, backupPath); + } + catch (PlatformNotSupportedException) + { + // Fallback for platforms without File.Replace + File.Copy(tempPath, fullPath, true); + try { File.Delete(tempPath); } catch { } + } + + // Prepare success response BEFORE any operation that can trigger a domain reload + var ok = Response.Success( $"Script '{name}.cs' updated successfully at '{relativePath}'.", - new { path = relativePath } + new { path = relativePath, scheduledRefresh = true } ); + + // Schedule a debounced import/compile on next editor tick to avoid stalling the reply + ManageScriptRefreshHelpers.ScheduleScriptRefresh(relativePath); + + return ok; } catch (Exception e) { @@ -1028,3 +1056,40 @@ private static void ValidateSemanticRules(string contents, System.Collections.Ge } } +// Debounced refresh/compile scheduler to coalesce bursts of edits +static class RefreshDebounce +{ + private static int _pending; + private static DateTime _last; + + public static void Schedule(string relPath, TimeSpan window) + { + Interlocked.Exchange(ref _pending, 1); + var now = DateTime.UtcNow; + if ((now - _last) < window) return; + _last = now; + + EditorApplication.delayCall += () => + { + if (Interlocked.Exchange(ref _pending, 0) == 1) + { + // Prefer targeted import and script compile over full refresh + AssetDatabase.ImportAsset(relPath, ImportAssetOptions.ForceUpdate); +#if UNITY_EDITOR + UnityEditor.Compilation.CompilationPipeline.RequestScriptCompilation(); +#endif + // Fallback if needed: + // AssetDatabase.Refresh(); + } + }; + } +} + +static class ManageScriptRefreshHelpers +{ + public static void ScheduleScriptRefresh(string relPath) + { + RefreshDebounce.Schedule(relPath, TimeSpan.FromMilliseconds(200)); + } +} + diff --git a/UnityMcpBridge/Editor/UnityMcpBridge.cs b/UnityMcpBridge/Editor/UnityMcpBridge.cs index b7e8ef0e..38030e28 100644 --- a/UnityMcpBridge/Editor/UnityMcpBridge.cs +++ b/UnityMcpBridge/Editor/UnityMcpBridge.cs @@ -395,22 +395,68 @@ private static async Task HandleClientAsync(TcpClient client) using (client) using (NetworkStream stream = client.GetStream()) { + const int MaxMessageBytes = 64 * 1024 * 1024; // 64 MB safety cap + bool framingEnabledForConnection = false; + try + { + var ep = client.Client?.RemoteEndPoint?.ToString() ?? "unknown"; + Debug.Log($"UNITY-MCP: Client connected {ep}"); + } + catch { } + // Strict framing: always require FRAMING=1 and frame all I/O + try + { + client.NoDelay = true; + } + catch { } + try + { + string handshake = "WELCOME UNITY-MCP 1 FRAMING=1\n"; + byte[] handshakeBytes = System.Text.Encoding.ASCII.GetBytes(handshake); + await stream.WriteAsync(handshakeBytes, 0, handshakeBytes.Length); + } + catch { /* ignore */ } + framingEnabledForConnection = true; + Debug.Log("UNITY-MCP: Sent handshake FRAMING=1 (strict)"); + byte[] buffer = new byte[8192]; while (isRunning) { try { - int bytesRead = await stream.ReadAsync(buffer, 0, buffer.Length); - if (bytesRead == 0) + // Strict framed mode + string commandText = null; + bool usedFraming = true; + + if (true) { - break; // Client disconnected + // Enforced framed mode for this connection + byte[] header = new byte[8]; + int headerFilled = 0; + while (headerFilled < 8) + { + int r = await stream.ReadAsync(header, headerFilled, 8 - headerFilled); + if (r == 0) + { + return; // disconnected + } + headerFilled += r; + } + ulong payloadLen = ReadUInt64BigEndian(header); + if (payloadLen == 0UL || payloadLen > (ulong)MaxMessageBytes) + { + throw new System.IO.IOException($"Invalid framed length: {payloadLen}"); + } + byte[] payload = await ReadExactAsync(stream, (int)payloadLen); + commandText = System.Text.Encoding.UTF8.GetString(payload); } - string commandText = System.Text.Encoding.UTF8.GetString( - buffer, - 0, - bytesRead - ); + try + { + var preview = commandText.Length > 120 ? commandText.Substring(0, 120) + "…" : commandText; + Debug.Log($"UNITY-MCP: recv {(usedFraming ? "framed" : "legacy")}: {preview}"); + } + catch { } string commandId = Guid.NewGuid().ToString(); TaskCompletionSource tcs = new(); @@ -422,6 +468,12 @@ private static async Task HandleClientAsync(TcpClient client) /*lang=json,strict*/ "{\"status\":\"success\",\"result\":{\"message\":\"pong\"}}" ); + if (framingEnabledForConnection) + { + byte[] outHeader = new byte[8]; + WriteUInt64BigEndian(outHeader, (ulong)pingResponseBytes.Length); + await stream.WriteAsync(outHeader, 0, outHeader.Length); + } await stream.WriteAsync(pingResponseBytes, 0, pingResponseBytes.Length); continue; } @@ -433,6 +485,12 @@ private static async Task HandleClientAsync(TcpClient client) string response = await tcs.Task; byte[] responseBytes = System.Text.Encoding.UTF8.GetBytes(response); + if (true) + { + byte[] outHeader = new byte[8]; + WriteUInt64BigEndian(outHeader, (ulong)responseBytes.Length); + await stream.WriteAsync(outHeader, 0, outHeader.Length); + } await stream.WriteAsync(responseBytes, 0, responseBytes.Length); } catch (Exception ex) @@ -444,6 +502,51 @@ private static async Task HandleClientAsync(TcpClient client) } } + private static async System.Threading.Tasks.Task ReadExactAsync(NetworkStream stream, int count) + { + byte[] data = new byte[count]; + int offset = 0; + while (offset < count) + { + int r = await stream.ReadAsync(data, offset, count - offset); + if (r == 0) + { + throw new System.IO.IOException("Connection closed before reading expected bytes"); + } + offset += r; + } + return data; + } + + private static ulong ReadUInt64BigEndian(byte[] buffer) + { + if (buffer == null || buffer.Length < 8) return 0UL; + return ((ulong)buffer[0] << 56) + | ((ulong)buffer[1] << 48) + | ((ulong)buffer[2] << 40) + | ((ulong)buffer[3] << 32) + | ((ulong)buffer[4] << 24) + | ((ulong)buffer[5] << 16) + | ((ulong)buffer[6] << 8) + | buffer[7]; + } + + private static void WriteUInt64BigEndian(byte[] dest, ulong value) + { + if (dest == null || dest.Length < 8) + { + throw new System.ArgumentException("Destination buffer too small for UInt64"); + } + dest[0] = (byte)(value >> 56); + dest[1] = (byte)(value >> 48); + dest[2] = (byte)(value >> 40); + dest[3] = (byte)(value >> 32); + dest[4] = (byte)(value >> 24); + dest[5] = (byte)(value >> 16); + dest[6] = (byte)(value >> 8); + dest[7] = (byte)(value); + } + private static void ProcessCommands() { List processedIds = new(); diff --git a/UnityMcpBridge/Editor/Windows/UnityMcpEditorWindow.cs b/UnityMcpBridge/Editor/Windows/UnityMcpEditorWindow.cs index 9e42d7ff..19446406 100644 --- a/UnityMcpBridge/Editor/Windows/UnityMcpEditorWindow.cs +++ b/UnityMcpBridge/Editor/Windows/UnityMcpEditorWindow.cs @@ -1550,29 +1550,33 @@ private void CheckMcpConfiguration(McpClient mcpClient) } else { - // Attempt auto-rewrite once if the package path changed - try + // Attempt auto-rewrite once if the package path changed, but only when explicitly enabled + bool autoManage = UnityEditor.EditorPrefs.GetBool("UnityMCP.AutoManageIDEConfig", false); + if (autoManage) { - string rewriteResult = WriteToConfig(pythonDir, configPath, mcpClient); - if (rewriteResult == "Configured successfully") + try { - if (debugLogsEnabled) + string rewriteResult = WriteToConfig(pythonDir, configPath, mcpClient); + if (rewriteResult == "Configured successfully") { - UnityEngine.Debug.Log($"UnityMCP: Auto-updated MCP config for '{mcpClient.name}' to new path: {pythonDir}"); + if (debugLogsEnabled) + { + UnityEngine.Debug.Log($"UnityMCP: Auto-updated MCP config for '{mcpClient.name}' to new path: {pythonDir}"); + } + mcpClient.SetStatus(McpStatus.Configured); + } + else + { + mcpClient.SetStatus(McpStatus.IncorrectPath); } - mcpClient.SetStatus(McpStatus.Configured); } - else + catch (Exception ex) { mcpClient.SetStatus(McpStatus.IncorrectPath); - } - } - catch (Exception ex) - { - mcpClient.SetStatus(McpStatus.IncorrectPath); - if (debugLogsEnabled) - { - UnityEngine.Debug.LogWarning($"UnityMCP: Auto-config rewrite failed for '{mcpClient.name}': {ex.Message}"); + if (debugLogsEnabled) + { + UnityEngine.Debug.LogWarning($"UnityMCP: Auto-config rewrite failed for '{mcpClient.name}': {ex.Message}"); + } } } } diff --git a/UnityMcpBridge/UnityMcpServer~/src/server.py b/UnityMcpBridge/UnityMcpServer~/src/server.py index 55360b57..52633ef4 100644 --- a/UnityMcpBridge/UnityMcpServer~/src/server.py +++ b/UnityMcpBridge/UnityMcpServer~/src/server.py @@ -1,11 +1,13 @@ from mcp.server.fastmcp import FastMCP, Context, Image import logging +from logging.handlers import RotatingFileHandler from dataclasses import dataclass from contextlib import asynccontextmanager from typing import AsyncIterator, Dict, Any, List from config import config from tools import register_all_tools from unity_connection import get_unity_connection, UnityConnection +from pathlib import Path # Configure logging using settings from config logging.basicConfig( @@ -14,6 +16,20 @@ ) logger = logging.getLogger("unity-mcp-server") +# File logging to avoid stdout interference with MCP stdio +try: + log_dir = Path.home() / ".unity-mcp" + log_dir.mkdir(parents=True, exist_ok=True) + file_handler = RotatingFileHandler(str(log_dir / "server.log"), maxBytes=5*1024*1024, backupCount=3) + file_handler.setFormatter(logging.Formatter(config.log_format)) + file_handler.setLevel(getattr(logging, config.log_level)) + logger.addHandler(file_handler) + # Prevent duplicate propagation to root handlers + logger.propagate = False +except Exception: + # If file logging setup fails, continue with stderr logging only + pass + # Global connection state _unity_connection: UnityConnection = None diff --git a/UnityMcpBridge/UnityMcpServer~/src/unity_connection.py b/UnityMcpBridge/UnityMcpServer~/src/unity_connection.py index 9bad736d..bc602040 100644 --- a/UnityMcpBridge/UnityMcpServer~/src/unity_connection.py +++ b/UnityMcpBridge/UnityMcpServer~/src/unity_connection.py @@ -1,6 +1,7 @@ import socket import json import logging +import struct from dataclasses import dataclass from pathlib import Path import time @@ -23,6 +24,7 @@ class UnityConnection: host: str = config.unity_host port: int = None # Will be set dynamically sock: socket.socket = None # Socket for Unity communication + use_framing: bool = False # Negotiated per-connection def __post_init__(self): """Set port from discovery if not explicitly provided""" @@ -37,6 +39,19 @@ def connect(self) -> bool: self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.sock.connect((self.host, self.port)) logger.info(f"Connected to Unity at {self.host}:{self.port}") + + # Strict handshake: require FRAMING=1 + try: + self.sock.settimeout(1.0) + greeting = self.sock.recv(256) + text = greeting.decode('ascii', errors='ignore') if greeting else '' + if 'FRAMING=1' in text: + self.use_framing = True + logger.info('Unity MCP handshake received: FRAMING=1 (strict)') + else: + raise ConnectionError(f'Unity MCP requires FRAMING=1, got: {text!r}') + finally: + self.sock.settimeout(config.connection_timeout) return True except Exception as e: logger.error(f"Failed to connect to Unity: {str(e)}") @@ -53,8 +68,33 @@ def disconnect(self): finally: self.sock = None + def _read_exact(self, sock: socket.socket, count: int) -> bytes: + data = bytearray() + while len(data) < count: + chunk = sock.recv(count - len(data)) + if not chunk: + raise Exception("Connection closed before reading expected bytes") + data.extend(chunk) + return bytes(data) + def receive_full_response(self, sock, buffer_size=config.buffer_size) -> bytes: """Receive a complete response from Unity, handling chunked data.""" + if self.use_framing: + try: + header = self._read_exact(sock, 8) + payload_len = struct.unpack('>Q', header)[0] + if payload_len == 0 or payload_len > (64 * 1024 * 1024): + raise Exception(f"Invalid framed length: {payload_len}") + payload = self._read_exact(sock, payload_len) + logger.info(f"Received framed response ({len(payload)} bytes)") + return payload + except socket.timeout: + logger.warning("Socket timeout during framed receive") + raise Exception("Timeout receiving Unity response") + except Exception as e: + logger.error(f"Error during framed receive: {str(e)}") + raise + chunks = [] sock.settimeout(config.connection_timeout) # Use timeout from config try: @@ -166,13 +206,26 @@ def read_status_file() -> dict | None: payload = json.dumps(command, ensure_ascii=False).encode('utf-8') # Send - self.sock.sendall(payload) + try: + logger.debug(f"send {len(payload)} bytes; mode={'framed' if self.use_framing else 'legacy'}; head={(payload[:32]).decode('utf-8','ignore')}") + except Exception: + pass + if self.use_framing: + header = struct.pack('>Q', len(payload)) + self.sock.sendall(header) + self.sock.sendall(payload) + else: + self.sock.sendall(payload) # During retry bursts use a short receive timeout if attempt > 0 and last_short_timeout is None: last_short_timeout = self.sock.gettimeout() self.sock.settimeout(1.0) response_data = self.receive_full_response(self.sock) + try: + logger.debug(f"recv {len(response_data)} bytes; mode={'framed' if self.use_framing else 'legacy'}; head={(response_data[:32]).decode('utf-8','ignore')}") + except Exception: + pass # restore steady-state timeout if changed if last_short_timeout is not None: self.sock.settimeout(config.connection_timeout) @@ -241,43 +294,22 @@ def read_status_file() -> dict | None: _unity_connection = None def get_unity_connection() -> UnityConnection: - """Retrieve or establish a persistent Unity connection.""" + """Retrieve or establish a persistent Unity connection. + + Note: Do NOT ping on every retrieval to avoid connection storms. Rely on + send_command() exceptions to detect broken sockets and reconnect there. + """ global _unity_connection if _unity_connection is not None: - try: - # Try to ping with a short timeout to verify connection - result = _unity_connection.send_command("ping") - # If we get here, the connection is still valid - logger.debug("Reusing existing Unity connection") - return _unity_connection - except Exception as e: - logger.warning(f"Existing connection failed: {str(e)}") - try: - _unity_connection.disconnect() - except: - pass - _unity_connection = None - - # Create a new connection + return _unity_connection + logger.info("Creating new Unity connection") _unity_connection = UnityConnection() if not _unity_connection.connect(): _unity_connection = None raise ConnectionError("Could not connect to Unity. Ensure the Unity Editor and MCP Bridge are running.") - - try: - # Verify the new connection works - _unity_connection.send_command("ping") - logger.info("Successfully established new Unity connection") - return _unity_connection - except Exception as e: - logger.error(f"Could not verify new connection: {str(e)}") - try: - _unity_connection.disconnect() - except: - pass - _unity_connection = None - raise ConnectionError(f"Could not establish valid Unity connection: {str(e)}") + logger.info("Connected to Unity on startup") + return _unity_connection # ----------------------------- From 7eeac659f50212bbc7bb4fbd22805d8d61e3555f Mon Sep 17 00:00:00 2001 From: David Sarno Date: Fri, 15 Aug 2025 10:59:35 -0700 Subject: [PATCH 02/25] Bridge framing hardening: 64MiB cap, zero-length reject, timeout ReadExact, safe write framing; remove unused vars --- UnityMcpBridge/Editor/UnityMcpBridge.cs | 56 +++++++++++++++++-------- 1 file changed, 38 insertions(+), 18 deletions(-) diff --git a/UnityMcpBridge/Editor/UnityMcpBridge.cs b/UnityMcpBridge/Editor/UnityMcpBridge.cs index 38030e28..fa707483 100644 --- a/UnityMcpBridge/Editor/UnityMcpBridge.cs +++ b/UnityMcpBridge/Editor/UnityMcpBridge.cs @@ -35,6 +35,8 @@ private static Dictionary< > commandQueue = new(); private static int currentUnityPort = 6400; // Dynamic port, starts with default private static bool isAutoConnectMode = false; + private const ulong MaxFrameBytes = 64UL * 1024 * 1024; // 64 MiB hard cap for framed payloads + private const int FrameIOTimeoutMs = 30000; // Per-read timeout to avoid stalled clients // Debug helpers private static bool IsDebugEnabled() @@ -395,8 +397,7 @@ private static async Task HandleClientAsync(TcpClient client) using (client) using (NetworkStream stream = client.GetStream()) { - const int MaxMessageBytes = 64 * 1024 * 1024; // 64 MB safety cap - bool framingEnabledForConnection = false; + // Framed I/O only; legacy mode removed try { var ep = client.Client?.RemoteEndPoint?.ToString() ?? "unknown"; @@ -416,7 +417,6 @@ private static async Task HandleClientAsync(TcpClient client) await stream.WriteAsync(handshakeBytes, 0, handshakeBytes.Length); } catch { /* ignore */ } - framingEnabledForConnection = true; Debug.Log("UNITY-MCP: Sent handshake FRAMING=1 (strict)"); byte[] buffer = new byte[8192]; @@ -431,23 +431,14 @@ private static async Task HandleClientAsync(TcpClient client) if (true) { // Enforced framed mode for this connection - byte[] header = new byte[8]; - int headerFilled = 0; - while (headerFilled < 8) - { - int r = await stream.ReadAsync(header, headerFilled, 8 - headerFilled); - if (r == 0) - { - return; // disconnected - } - headerFilled += r; - } + byte[] header = await ReadExactAsync(stream, 8, FrameIOTimeoutMs); ulong payloadLen = ReadUInt64BigEndian(header); - if (payloadLen == 0UL || payloadLen > (ulong)MaxMessageBytes) + if (payloadLen == 0UL || payloadLen > MaxFrameBytes) { throw new System.IO.IOException($"Invalid framed length: {payloadLen}"); } - byte[] payload = await ReadExactAsync(stream, (int)payloadLen); + int payloadLenInt = checked((int)payloadLen); + byte[] payload = await ReadExactAsync(stream, payloadLenInt, FrameIOTimeoutMs); commandText = System.Text.Encoding.UTF8.GetString(payload); } @@ -468,7 +459,10 @@ private static async Task HandleClientAsync(TcpClient client) /*lang=json,strict*/ "{\"status\":\"success\",\"result\":{\"message\":\"pong\"}}" ); - if (framingEnabledForConnection) + if ((ulong)pingResponseBytes.Length > MaxFrameBytes) + { + throw new System.IO.IOException($"Frame too large: {pingResponseBytes.Length}"); + } { byte[] outHeader = new byte[8]; WriteUInt64BigEndian(outHeader, (ulong)pingResponseBytes.Length); @@ -485,7 +479,10 @@ private static async Task HandleClientAsync(TcpClient client) string response = await tcs.Task; byte[] responseBytes = System.Text.Encoding.UTF8.GetBytes(response); - if (true) + if ((ulong)responseBytes.Length > MaxFrameBytes) + { + throw new System.IO.IOException($"Frame too large: {responseBytes.Length}"); + } { byte[] outHeader = new byte[8]; WriteUInt64BigEndian(outHeader, (ulong)responseBytes.Length); @@ -518,6 +515,29 @@ private static async System.Threading.Tasks.Task ReadExactAsync(NetworkS return data; } + // Timeout-aware exact read helper; avoids indefinite stalls + private static async System.Threading.Tasks.Task ReadExactAsync(NetworkStream stream, int count, int timeoutMs) + { + byte[] data = new byte[count]; + int offset = 0; + while (offset < count) + { + var readTask = stream.ReadAsync(data, offset, count - offset); + var completed = await System.Threading.Tasks.Task.WhenAny(readTask, System.Threading.Tasks.Task.Delay(timeoutMs)); + if (completed != readTask) + { + throw new System.IO.IOException("Read timed out"); + } + int r = readTask.Result; + if (r == 0) + { + throw new System.IO.IOException("Connection closed before reading expected bytes"); + } + offset += r; + } + return data; + } + private static ulong ReadUInt64BigEndian(byte[] buffer) { if (buffer == null || buffer.Length < 8) return 0UL; From eafe3095c7284bed1a953883fd74223b93ee3b63 Mon Sep 17 00:00:00 2001 From: David Sarno Date: Fri, 15 Aug 2025 14:15:31 -0700 Subject: [PATCH 03/25] ManageScript: improve method span parsing and validation behavior for MCP edit ops; mitigate false 'no opening brace' errors and allow relaxed validation for text edits --- UnityMcpBridge/Editor/Tools/ManageScript.cs | 652 +++++++++++++++++++- 1 file changed, 632 insertions(+), 20 deletions(-) diff --git a/UnityMcpBridge/Editor/Tools/ManageScript.cs b/UnityMcpBridge/Editor/Tools/ManageScript.cs index 8fa018b1..7c9861a5 100644 --- a/UnityMcpBridge/Editor/Tools/ManageScript.cs +++ b/UnityMcpBridge/Editor/Tools/ManageScript.cs @@ -1,6 +1,7 @@ using System; using System.IO; using System.Linq; +using System.Collections.Generic; using System.Text.RegularExpressions; using Newtonsoft.Json.Linq; using UnityEditor; @@ -48,6 +49,47 @@ namespace UnityMcpBridge.Editor.Tools /// public static class ManageScript { + /// + /// Resolves a directory under Assets/, preventing traversal and escaping. + /// Returns fullPathDir on disk and canonical 'Assets/...' relative path. + /// + private static bool TryResolveUnderAssets(string relDir, out string fullPathDir, out string relPathSafe) + { + string assets = Application.dataPath.Replace('\\', '/'); + string targetDir = Path.Combine(assets, (relDir ?? "Scripts")).Replace('\\', '/'); + string full = Path.GetFullPath(targetDir).Replace('\\', '/'); + + bool underAssets = full.StartsWith(assets + "/", StringComparison.OrdinalIgnoreCase) + || string.Equals(full, assets, StringComparison.OrdinalIgnoreCase); + if (!underAssets) + { + fullPathDir = null; + relPathSafe = null; + return false; + } + + // Best-effort symlink guard: if directory is a reparse point/symlink, reject + try + { + var di = new DirectoryInfo(full); + if (di.Exists) + { + var attrs = di.Attributes; + if ((attrs & FileAttributes.ReparsePoint) != 0) + { + fullPathDir = null; + relPathSafe = null; + return false; + } + } + } + catch { /* best effort; proceed */ } + + fullPathDir = full; + string tail = full.Length > assets.Length ? full.Substring(assets.Length).TrimStart('/') : string.Empty; + relPathSafe = ("Assets/" + tail).TrimEnd('/'); + return true; + } /// /// Main handler for script management actions. /// @@ -97,29 +139,16 @@ public static object HandleCommand(JObject @params) ); } - // Ensure path is relative to Assets/, removing any leading "Assets/" - // Set default directory to "Scripts" if path is not provided - string relativeDir = path ?? "Scripts"; // Default to "Scripts" if path is null - if (!string.IsNullOrEmpty(relativeDir)) - { - relativeDir = relativeDir.Replace('\\', '/').Trim('/'); - if (relativeDir.StartsWith("Assets/", StringComparison.OrdinalIgnoreCase)) - { - relativeDir = relativeDir.Substring("Assets/".Length).TrimStart('/'); - } - } - // Handle empty string case explicitly after processing - if (string.IsNullOrEmpty(relativeDir)) + // Resolve and harden target directory under Assets/ + if (!TryResolveUnderAssets(path, out string fullPathDir, out string relPathSafeDir)) { - relativeDir = "Scripts"; // Ensure default if path was provided as "" or only "/" or "Assets/" + return Response.Error($"Invalid path. Target directory must be within 'Assets/'. Provided: '{(path ?? "(null)")}'"); } - // Construct paths + // Construct file paths string scriptFileName = $"{name}.cs"; - string fullPathDir = Path.Combine(Application.dataPath, relativeDir); // Application.dataPath ends in "Assets" string fullPath = Path.Combine(fullPathDir, scriptFileName); - string relativePath = Path.Combine("Assets", relativeDir, scriptFileName) - .Replace('\\', '/'); // Ensure "Assets/" prefix and forward slashes + string relativePath = Path.Combine(relPathSafeDir, scriptFileName).Replace('\\', '/'); // Ensure the target directory exists for create/update if (action == "create" || action == "update") @@ -154,6 +183,12 @@ public static object HandleCommand(JObject @params) return UpdateScript(fullPath, relativePath, name, contents); case "delete": return DeleteScript(fullPath, relativePath); + case "edit": + { + var edits = @params["edits"] as JArray; + var options = @params["options"] as JObject; + return EditScript(fullPath, relativePath, name, edits, options); + } default: return Response.Error( $"Unknown action: '{action}'. Valid actions are: create, read, update, delete." @@ -222,7 +257,17 @@ string namespaceName var enc = System.Text.Encoding.UTF8; var tmp = fullPath + ".tmp"; File.WriteAllText(tmp, contents, enc); - File.Move(tmp, fullPath); + try + { + // Prefer atomic move within same volume + File.Move(tmp, fullPath); + } + catch (IOException) + { + // Cross-volume or other IO constraint: fallback to copy + File.Copy(tmp, fullPath, overwrite: true); + try { File.Delete(tmp); } catch { } + } var ok = Response.Success( $"Script '{name}.cs' created successfully at '{relativePath}'.", @@ -318,7 +363,12 @@ string contents } catch (PlatformNotSupportedException) { - // Fallback for platforms without File.Replace + File.Copy(tempPath, fullPath, true); + try { File.Delete(tempPath); } catch { } + } + catch (IOException) + { + // Cross-volume moves can throw IOException; fallback to copy File.Copy(tempPath, fullPath, true); try { File.Delete(tempPath); } catch { } } @@ -372,6 +422,568 @@ private static object DeleteScript(string fullPath, string relativePath) } } + /// + /// Structured edits (AST-backed where available) on existing scripts. + /// Supports class-level replace/delete with Roslyn span computation if USE_ROSLYN is defined, + /// otherwise falls back to a conservative balanced-brace scan. + /// + private static object EditScript( + string fullPath, + string relativePath, + string name, + JArray edits, + JObject options) + { + if (!File.Exists(fullPath)) + return Response.Error($"Script not found at '{relativePath}'."); + if (edits == null || edits.Count == 0) + return Response.Error("No edits provided."); + + string original; + try { original = File.ReadAllText(fullPath); } + catch (Exception ex) { return Response.Error($"Failed to read script: {ex.Message}"); } + + string working = original; + + try + { + var replacements = new List<(int start, int length, string text)>(); + + foreach (var e in edits) + { + var op = (JObject)e; + var mode = (op.Value("mode") ?? op.Value("op") ?? string.Empty).ToLowerInvariant(); + + switch (mode) + { + case "replace_class": + { + string className = op.Value("className"); + string ns = op.Value("namespace"); + string replacement = ExtractReplacement(op); + + if (string.IsNullOrWhiteSpace(className)) + return Response.Error("replace_class requires 'className'."); + if (replacement == null) + return Response.Error("replace_class requires 'replacement' (inline or base64)."); + + if (!TryComputeClassSpan(working, className, ns, out var spanStart, out var spanLength, out var why)) + return Response.Error($"replace_class failed: {why}"); + + if (!ValidateClassSnippet(replacement, className, out var vErr)) + return Response.Error($"Replacement snippet invalid: {vErr}"); + + replacements.Add((spanStart, spanLength, NormalizeNewlines(replacement))); + break; + } + + case "delete_class": + { + string className = op.Value("className"); + string ns = op.Value("namespace"); + if (string.IsNullOrWhiteSpace(className)) + return Response.Error("delete_class requires 'className'."); + + if (!TryComputeClassSpan(working, className, ns, out var s, out var l, out var why)) + return Response.Error($"delete_class failed: {why}"); + + replacements.Add((s, l, string.Empty)); + break; + } + + case "replace_method": + { + string className = op.Value("className"); + string ns = op.Value("namespace"); + string methodName = op.Value("methodName"); + string replacement = ExtractReplacement(op); + string returnType = op.Value("returnType"); + string parametersSignature = op.Value("parametersSignature"); + string attributesContains = op.Value("attributesContains"); + + if (string.IsNullOrWhiteSpace(className)) return Response.Error("replace_method requires 'className'."); + if (string.IsNullOrWhiteSpace(methodName)) return Response.Error("replace_method requires 'methodName'."); + if (replacement == null) return Response.Error("replace_method requires 'replacement' (inline or base64)."); + + if (!TryComputeClassSpan(working, className, ns, out var clsStart, out var clsLen, out var whyClass)) + return Response.Error($"replace_method failed to locate class: {whyClass}"); + + if (!TryComputeMethodSpan(working, clsStart, clsLen, methodName, returnType, parametersSignature, attributesContains, out var mStart, out var mLen, out var whyMethod)) + return Response.Error($"replace_method failed: {whyMethod}"); + + replacements.Add((mStart, mLen, NormalizeNewlines(replacement))); + break; + } + + case "delete_method": + { + string className = op.Value("className"); + string ns = op.Value("namespace"); + string methodName = op.Value("methodName"); + string returnType = op.Value("returnType"); + string parametersSignature = op.Value("parametersSignature"); + string attributesContains = op.Value("attributesContains"); + + if (string.IsNullOrWhiteSpace(className)) return Response.Error("delete_method requires 'className'."); + if (string.IsNullOrWhiteSpace(methodName)) return Response.Error("delete_method requires 'methodName'."); + + if (!TryComputeClassSpan(working, className, ns, out var clsStart, out var clsLen, out var whyClass)) + return Response.Error($"delete_method failed to locate class: {whyClass}"); + + if (!TryComputeMethodSpan(working, clsStart, clsLen, methodName, returnType, parametersSignature, attributesContains, out var mStart, out var mLen, out var whyMethod)) + return Response.Error($"delete_method failed: {whyMethod}"); + + replacements.Add((mStart, mLen, string.Empty)); + break; + } + + case "insert_method": + { + string className = op.Value("className"); + string ns = op.Value("namespace"); + string position = (op.Value("position") ?? "end").ToLowerInvariant(); + string afterMethodName = op.Value("afterMethodName"); + string afterReturnType = op.Value("afterReturnType"); + string afterParameters = op.Value("afterParametersSignature"); + string afterAttributesContains = op.Value("afterAttributesContains"); + string snippet = ExtractReplacement(op); + + if (string.IsNullOrWhiteSpace(className)) return Response.Error("insert_method requires 'className'."); + if (snippet == null) return Response.Error("insert_method requires 'replacement' (inline or base64) containing a full method declaration."); + + if (!TryComputeClassSpan(working, className, ns, out var clsStart, out var clsLen, out var whyClass)) + return Response.Error($"insert_method failed to locate class: {whyClass}"); + + if (position == "after") + { + if (string.IsNullOrEmpty(afterMethodName)) return Response.Error("insert_method with position='after' requires 'afterMethodName'."); + if (!TryComputeMethodSpan(working, clsStart, clsLen, afterMethodName, afterReturnType, afterParameters, afterAttributesContains, out var aStart, out var aLen, out var whyAfter)) + return Response.Error($"insert_method(after) failed to locate anchor method: {whyAfter}"); + int insAt = aStart + aLen; + string text = NormalizeNewlines("\n\n" + snippet.TrimEnd() + "\n"); + replacements.Add((insAt, 0, text)); + } + else if (!TryFindClassInsertionPoint(working, clsStart, clsLen, position, out var insAt, out var whyIns)) + return Response.Error($"insert_method failed: {whyIns}"); + else + { + string text = NormalizeNewlines("\n\n" + snippet.TrimEnd() + "\n"); + replacements.Add((insAt, 0, text)); + } + break; + } + + default: + return Response.Error($"Unknown edit mode: '{mode}'. Allowed: replace_class, delete_class, replace_method, delete_method, insert_method."); + } + } + + if (HasOverlaps(replacements)) + return Response.Error("Edits overlap; split into separate calls or adjust targets."); + + foreach (var r in replacements.OrderByDescending(r => r.start)) + working = working.Remove(r.start, r.length).Insert(r.start, r.text); + + // Validate result using override from options if provided; otherwise GUI strictness + var level = GetValidationLevelFromGUI(); + try + { + var validateOpt = options?["validate"]?.ToString()?.ToLowerInvariant(); + if (!string.IsNullOrEmpty(validateOpt)) + { + level = validateOpt switch + { + "basic" => ValidationLevel.Basic, + "standard" => ValidationLevel.Standard, + "comprehensive" => ValidationLevel.Comprehensive, + "strict" => ValidationLevel.Strict, + _ => level + }; + } + } + catch { /* ignore option parsing issues */ } + if (!ValidateScriptSyntax(working, level, out var errors)) + return Response.Error("Script validation failed:\n" + string.Join("\n", errors ?? Array.Empty())); + else if (errors != null && errors.Length > 0) + Debug.LogWarning($"Script validation warnings for {name}:\n" + string.Join("\n", errors)); + + // Atomic write with backup; schedule refresh + var enc = System.Text.Encoding.UTF8; + var tmp = fullPath + ".tmp"; + File.WriteAllText(tmp, working, enc); + string backup = fullPath + ".bak"; + try { File.Replace(tmp, fullPath, backup); } + catch (PlatformNotSupportedException) { File.Copy(tmp, fullPath, true); try { File.Delete(tmp); } catch { } } + catch (IOException) { File.Copy(tmp, fullPath, true); try { File.Delete(tmp); } catch { } } + + // Decide refresh behavior + string refreshMode = options?["refresh"]?.ToString()?.ToLowerInvariant(); + bool immediate = refreshMode == "immediate" || refreshMode == "sync"; + + var ok = Response.Success( + $"Applied {replacements.Count} structured edit(s) to '{relativePath}'.", + new { path = relativePath, editsApplied = replacements.Count, scheduledRefresh = !immediate } + ); + + if (immediate) + { + // Force an immediate import/compile on the main thread + AssetDatabase.ImportAsset(relativePath, ImportAssetOptions.ForceSynchronousImport | ImportAssetOptions.ForceUpdate); +#if UNITY_EDITOR + UnityEditor.Compilation.CompilationPipeline.RequestScriptCompilation(); +#endif + } + else + { + ManageScriptRefreshHelpers.ScheduleScriptRefresh(relativePath); + } + return ok; + } + catch (Exception ex) + { + return Response.Error($"Edit failed: {ex.Message}"); + } + } + + private static bool HasOverlaps(IEnumerable<(int start, int length, string text)> list) + { + var arr = list.OrderBy(x => x.start).ToArray(); + for (int i = 1; i < arr.Length; i++) + { + if (arr[i - 1].start + arr[i - 1].length > arr[i].start) + return true; + } + return false; + } + + private static string ExtractReplacement(JObject op) + { + var inline = op.Value("replacement"); + if (!string.IsNullOrEmpty(inline)) return inline; + + var b64 = op.Value("replacementBase64"); + if (!string.IsNullOrEmpty(b64)) + { + try { return System.Text.Encoding.UTF8.GetString(Convert.FromBase64String(b64)); } + catch { return null; } + } + return null; + } + + private static string NormalizeNewlines(string t) + { + if (string.IsNullOrEmpty(t)) return t; + return t.Replace("\r\n", "\n").Replace("\r", "\n"); + } + + private static bool ValidateClassSnippet(string snippet, string expectedName, out string err) + { +#if USE_ROSLYN + try + { + var tree = CSharpSyntaxTree.ParseText(snippet); + var root = tree.GetRoot(); + var classes = root.DescendantNodes().OfType().ToList(); + if (classes.Count != 1) { err = "snippet must contain exactly one class declaration"; return false; } + // Optional: enforce expected name + // if (classes[0].Identifier.ValueText != expectedName) { err = $"snippet declares '{classes[0].Identifier.ValueText}', expected '{expectedName}'"; return false; } + err = null; return true; + } + catch (Exception ex) { err = ex.Message; return false; } +#else + if (string.IsNullOrWhiteSpace(snippet) || !snippet.Contains("class ")) { err = "no 'class' keyword found in snippet"; return false; } + err = null; return true; +#endif + } + + private static bool TryComputeClassSpan(string source, string className, string ns, out int start, out int length, out string why) + { +#if USE_ROSLYN + try + { + var tree = CSharpSyntaxTree.ParseText(source); + var root = tree.GetRoot(); + var classes = root.DescendantNodes() + .OfType() + .Where(c => c.Identifier.ValueText == className); + + if (!string.IsNullOrEmpty(ns)) + { + classes = classes.Where(c => + (c.FirstAncestorOrSelf()?.Name?.ToString() ?? "") == ns + || (c.FirstAncestorOrSelf()?.Name?.ToString() ?? "") == ns); + } + + var list = classes.ToList(); + if (list.Count == 0) { start = length = 0; why = $"class '{className}' not found" + (ns != null ? $" in namespace '{ns}'" : ""); return false; } + if (list.Count > 1) { start = length = 0; why = $"class '{className}' matched {list.Count} declarations (partial/nested?). Disambiguate."; return false; } + + var cls = list[0]; + var span = cls.FullSpan; // includes attributes & leading trivia + start = span.Start; length = span.Length; why = null; return true; + } + catch + { + // fall back below + } +#endif + return TryComputeClassSpanBalanced(source, className, ns, out start, out length, out why); + } + + private static bool TryComputeClassSpanBalanced(string source, string className, string ns, out int start, out int length, out string why) + { + start = length = 0; why = null; + var idx = IndexOfClassToken(source, className); + if (idx < 0) { why = $"class '{className}' not found (balanced scan)"; return false; } + + if (!string.IsNullOrEmpty(ns) && !AppearsWithinNamespaceHeader(source, idx, ns)) + { why = $"class '{className}' not under namespace '{ns}' (balanced scan)"; return false; } + + // Include modifiers/attributes on the same line: back up to the start of line + int lineStart = idx; + while (lineStart > 0 && source[lineStart - 1] != '\n' && source[lineStart - 1] != '\r') lineStart--; + + int i = idx; + while (i < source.Length && source[i] != '{') i++; + if (i >= source.Length) { why = "no opening brace after class header"; return false; } + + int depth = 0; bool inStr = false, inChar = false, inSL = false, inML = false, esc = false; + int startSpan = lineStart; + for (; i < source.Length; i++) + { + char c = source[i]; + char n = i + 1 < source.Length ? source[i + 1] : '\0'; + + if (inSL) { if (c == '\n') inSL = false; continue; } + if (inML) { if (c == '*' && n == '/') { inML = false; i++; } continue; } + if (inStr) { if (!esc && c == '"') inStr = false; esc = (!esc && c == '\\'); continue; } + if (inChar) { if (!esc && c == '\'') inChar = false; esc = (!esc && c == '\\'); continue; } + + if (c == '/' && n == '/') { inSL = true; i++; continue; } + if (c == '/' && n == '*') { inML = true; i++; continue; } + if (c == '"') { inStr = true; continue; } + if (c == '\'') { inChar = true; continue; } + + if (c == '{') { depth++; } + else if (c == '}') + { + depth--; + if (depth == 0) { start = startSpan; length = (i - startSpan) + 1; return true; } + if (depth < 0) { why = "brace underflow"; return false; } + } + } + why = "unterminated class block"; return false; + } + + private static bool TryComputeMethodSpan( + string source, + int classStart, + int classLength, + string methodName, + string returnType, + string parametersSignature, + string attributesContains, + out int start, + out int length, + out string why) + { + start = length = 0; why = null; + int searchStart = classStart; + int searchEnd = Math.Min(source.Length, classStart + classLength); + + // 1) Find the method header using a stricter regex (allows optional attributes above) + string rtPattern = string.IsNullOrEmpty(returnType) ? @"[^\s]+" : Regex.Escape(returnType).Replace("\\ ", "\\s+"); + string namePattern = Regex.Escape(methodName); + string paramsPattern = string.IsNullOrEmpty(parametersSignature) ? @"[\s\S]*?" : Regex.Escape(parametersSignature); + string pattern = + @"(?m)^[\t ]*(?:\[[^\n\]]+\][\t ]*\n)*[\t ]*" + + @"(?:(?:public|private|protected|internal|static|virtual|override|sealed|async|extern|unsafe|new|partial|readonly|volatile|event|abstract|ref|in|out)\s+)*" + + rtPattern + @"[\t ]+" + namePattern + @"\s*\(" + paramsPattern + @"\)"; + + string slice = source.Substring(searchStart, searchEnd - searchStart); + var headerMatch = Regex.Match(slice, pattern, RegexOptions.Multiline); + if (!headerMatch.Success) + { + why = $"method '{methodName}' header not found in class"; return false; + } + int headerIndex = searchStart + headerMatch.Index; + + // Optional attributes filter: look upward from headerIndex for contiguous attribute lines + if (!string.IsNullOrEmpty(attributesContains)) + { + int attrScanStart = headerIndex; + while (attrScanStart > searchStart) + { + int prevNl = source.LastIndexOf('\n', attrScanStart - 1); + if (prevNl < 0 || prevNl < searchStart) break; + string prevLine = source.Substring(prevNl + 1, attrScanStart - (prevNl + 1)); + if (prevLine.TrimStart().StartsWith("[")) { attrScanStart = prevNl; continue; } + break; + } + string attrBlock = source.Substring(attrScanStart, headerIndex - attrScanStart); + if (attrBlock.IndexOf(attributesContains, StringComparison.Ordinal) < 0) + { + why = $"method '{methodName}' found but attributes filter did not match"; return false; + } + } + + // backtrack to the very start of header/attributes to include in span + int lineStart = headerIndex; + while (lineStart > searchStart && source[lineStart - 1] != '\n' && source[lineStart - 1] != '\r') lineStart--; + // If previous lines are attributes, include them + int attrStart = lineStart; + int probe = lineStart - 1; + while (probe > searchStart) + { + int prevNl = source.LastIndexOf('\n', probe); + if (prevNl < 0 || prevNl < searchStart) break; + string prev = source.Substring(prevNl + 1, attrStart - (prevNl + 1)); + if (prev.TrimStart().StartsWith("[")) { attrStart = prevNl + 1; probe = prevNl - 1; } + else break; + } + + // 2) Walk from the end of signature to detect body style ('{' or '=> ...;') and compute end + int i = headerIndex; + int parenDepth = 0; bool inStr = false, inChar = false, inSL = false, inML = false, esc = false; + for (; i < searchEnd; i++) + { + char c = source[i]; + char n = i + 1 < searchEnd ? source[i + 1] : '\0'; + if (inSL) { if (c == '\n') inSL = false; continue; } + if (inML) { if (c == '*' && n == '/') { inML = false; i++; } continue; } + if (inStr) { if (!esc && c == '"') inStr = false; esc = (!esc && c == '\\'); continue; } + if (inChar) { if (!esc && c == '\'') inChar = false; esc = (!esc && c == '\\'); continue; } + + if (c == '/' && n == '/') { inSL = true; i++; continue; } + if (c == '/' && n == '*') { inML = true; i++; continue; } + if (c == '"') { inStr = true; continue; } + if (c == '\'') { inChar = true; continue; } + + if (c == '(') parenDepth++; + if (c == ')') { parenDepth--; if (parenDepth == 0) { i++; break; } } + } + + // After params: detect expression-bodied or block-bodied + // Skip whitespace/comments + for (; i < searchEnd; i++) + { + char c = source[i]; + char n = i + 1 < searchEnd ? source[i + 1] : '\0'; + if (char.IsWhiteSpace(c)) continue; + if (c == '/' && n == '/') { while (i < searchEnd && source[i] != '\n') i++; continue; } + if (c == '/' && n == '*') { i += 2; while (i + 1 < searchEnd && !(source[i] == '*' && source[i + 1] == '/')) i++; i++; continue; } + break; + } + + if (i < searchEnd - 1 && source[i] == '=' && source[i + 1] == '>') + { + // expression-bodied method: seek to terminating semicolon + int j = i; + bool done = false; + while (j < searchEnd) + { + char c = source[j]; + if (c == ';') { done = true; break; } + j++; + } + if (!done) { why = "unterminated expression-bodied method"; return false; } + start = attrStart; length = (j - attrStart) + 1; return true; + } + + if (i >= searchEnd || source[i] != '{') { why = "no opening brace after method signature"; return false; } + + int depth = 0; inStr = false; inChar = false; inSL = false; inML = false; esc = false; + int startSpan = attrStart; + for (; i < searchEnd; i++) + { + char c = source[i]; + char n = i + 1 < searchEnd ? source[i + 1] : '\0'; + if (inSL) { if (c == '\n') inSL = false; continue; } + if (inML) { if (c == '*' && n == '/') { inML = false; i++; } continue; } + if (inStr) { if (!esc && c == '"') inStr = false; esc = (!esc && c == '\\'); continue; } + if (inChar) { if (!esc && c == '\'') inChar = false; esc = (!esc && c == '\\'); continue; } + + if (c == '/' && n == '/') { inSL = true; i++; continue; } + if (c == '/' && n == '*') { inML = true; i++; continue; } + if (c == '"') { inStr = true; continue; } + if (c == '\'') { inChar = true; continue; } + + if (c == '{') depth++; + else if (c == '}') + { + depth--; + if (depth == 0) { start = startSpan; length = (i - startSpan) + 1; return true; } + if (depth < 0) { why = "brace underflow in method"; return false; } + } + } + why = "unterminated method block"; return false; + } + + private static int IndexOfTokenWithin(string s, string token, int start, int end) + { + int idx = s.IndexOf(token, start, StringComparison.Ordinal); + return (idx >= 0 && idx < end) ? idx : -1; + } + + private static bool TryFindClassInsertionPoint(string source, int classStart, int classLength, string position, out int insertAt, out string why) + { + insertAt = 0; why = null; + int searchStart = classStart; + int searchEnd = Math.Min(source.Length, classStart + classLength); + + if (position == "start") + { + // find first '{' after class header, insert just after with a newline + int i = IndexOfTokenWithin(source, "{", searchStart, searchEnd); + if (i < 0) { why = "could not find class opening brace"; return false; } + insertAt = i + 1; return true; + } + else // end + { + // walk to matching closing brace of class and insert just before it + int i = IndexOfTokenWithin(source, "{", searchStart, searchEnd); + if (i < 0) { why = "could not find class opening brace"; return false; } + int depth = 0; bool inStr = false, inChar = false, inSL = false, inML = false, esc = false; + for (; i < searchEnd; i++) + { + char c = source[i]; + char n = i + 1 < searchEnd ? source[i + 1] : '\0'; + if (inSL) { if (c == '\n') inSL = false; continue; } + if (inML) { if (c == '*' && n == '/') { inML = false; i++; } continue; } + if (inStr) { if (!esc && c == '"') inStr = false; esc = (!esc && c == '\\'); continue; } + if (inChar) { if (!esc && c == '\'') inChar = false; esc = (!esc && c == '\\'); continue; } + + if (c == '/' && n == '/') { inSL = true; i++; continue; } + if (c == '/' && n == '*') { inML = true; i++; continue; } + if (c == '"') { inStr = true; continue; } + if (c == '\'') { inChar = true; continue; } + + if (c == '{') depth++; + else if (c == '}') + { + depth--; + if (depth == 0) { insertAt = i; return true; } + if (depth < 0) { why = "brace underflow while scanning class"; return false; } + } + } + why = "could not find class closing brace"; return false; + } + } + + private static int IndexOfClassToken(string s, string className) + { + // simple token search; could be tightened with Regex for word boundaries + var pattern = "class " + className; + return s.IndexOf(pattern, StringComparison.Ordinal); + } + + private static bool AppearsWithinNamespaceHeader(string s, int pos, string ns) + { + int from = Math.Max(0, pos - 2000); + var slice = s.Substring(from, pos - from); + return slice.Contains("namespace " + ns); + } + /// /// Generates basic C# script content based on name and type. /// From 73d212fc9c6d980d0f2f8b3b77dcbafd11327f05 Mon Sep 17 00:00:00 2001 From: David Sarno Date: Fri, 15 Aug 2025 22:45:35 -0700 Subject: [PATCH 04/25] Unity MCP: prefer micro-edits & resources; add script_apply_edits priority and server apply_text_edits/validate; add resources list/read; deprecate manage_script read/update/edit; remove stdout prints; tweak connection handshake logging --- UnityMcpBridge/Editor/Tools/ManageScript.cs | 308 ++++++++++++++++-- .../UnityMcpServer~/src/pyrightconfig.json | 4 + UnityMcpBridge/UnityMcpServer~/src/server.py | 75 ++++- .../UnityMcpServer~/src/tools/__init__.py | 7 +- .../src/tools/manage_script.py | 10 +- .../src/tools/manage_script_edits.py | 148 +++++++++ .../UnityMcpServer~/src/unity_connection.py | 15 +- test_unity_socket_framing.py | 88 +++++ 8 files changed, 613 insertions(+), 42 deletions(-) create mode 100644 UnityMcpBridge/UnityMcpServer~/src/pyrightconfig.json create mode 100644 UnityMcpBridge/UnityMcpServer~/src/tools/manage_script_edits.py create mode 100644 test_unity_socket_framing.py diff --git a/UnityMcpBridge/Editor/Tools/ManageScript.cs b/UnityMcpBridge/Editor/Tools/ManageScript.cs index 7c9861a5..d2df4584 100644 --- a/UnityMcpBridge/Editor/Tools/ManageScript.cs +++ b/UnityMcpBridge/Editor/Tools/ManageScript.cs @@ -56,7 +56,14 @@ public static class ManageScript private static bool TryResolveUnderAssets(string relDir, out string fullPathDir, out string relPathSafe) { string assets = Application.dataPath.Replace('\\', '/'); - string targetDir = Path.Combine(assets, (relDir ?? "Scripts")).Replace('\\', '/'); + + // Normalize caller path: allow both "Scripts/..." and "Assets/Scripts/..." + string rel = (relDir ?? "Scripts").Replace('\\', '/').Trim(); + if (string.IsNullOrEmpty(rel)) rel = "Scripts"; + if (rel.StartsWith("Assets/", StringComparison.OrdinalIgnoreCase)) rel = rel.Substring(7); + rel = rel.TrimStart('/'); + + string targetDir = Path.Combine(assets, rel).Replace('\\', '/'); string full = Path.GetFullPath(targetDir).Replace('\\', '/'); bool underAssets = full.StartsWith(assets + "/", StringComparison.OrdinalIgnoreCase) @@ -178,17 +185,40 @@ public static object HandleCommand(JObject @params) namespaceName ); case "read": - return ReadScript(fullPath, relativePath); + return Response.Error("Deprecated: reads are resources now. Use resources/read with a unity://path or unity://script URI."); case "update": - return UpdateScript(fullPath, relativePath, name, contents); + return Response.Error("Deprecated: use apply_text_edits (small, line/col edits) rather than whole-file replace."); case "delete": return DeleteScript(fullPath, relativePath); - case "edit": + case "apply_text_edits": { var edits = @params["edits"] as JArray; - var options = @params["options"] as JObject; - return EditScript(fullPath, relativePath, name, edits, options); + string precondition = @params["precondition_sha256"]?.ToString(); // optional, currently ignored here + return ApplyTextEdits(fullPath, relativePath, name, edits); + } + case "validate": + { + string level = @params["level"]?.ToString()?.ToLowerInvariant() ?? "standard"; + var chosen = level switch + { + "basic" => ValidationLevel.Basic, + "strict" => ValidationLevel.Strict, + _ => ValidationLevel.Standard + }; + string fileText; + try { fileText = File.ReadAllText(fullPath); } + catch (Exception ex) { return Response.Error($"Failed to read script: {ex.Message}"); } + + bool ok = ValidateScriptSyntax(fileText, chosen, out string[] diags); + var result = new + { + isValid = ok, + diagnostics = diags ?? Array.Empty() + }; + return ok ? Response.Success("Validation completed.", result) : Response.Error("Validation failed.", result); } + case "edit": + return Response.Error("Deprecated: use apply_text_edits. Structured 'edit' mode has been retired in favor of simple text edits."); default: return Response.Error( $"Unknown action: '{action}'. Valid actions are: create, read, update, delete." @@ -390,6 +420,108 @@ string contents } } + /// + /// Apply simple text edits specified by line/column ranges. Applies transactionally and validates result. + /// + private static object ApplyTextEdits( + string fullPath, + string relativePath, + string name, + JArray edits) + { + if (!File.Exists(fullPath)) + return Response.Error($"Script not found at '{relativePath}'."); + if (edits == null || edits.Count == 0) + return Response.Error("No edits provided."); + + string original; + try { original = File.ReadAllText(fullPath); } + catch (Exception ex) { return Response.Error($"Failed to read script: {ex.Message}"); } + + // Convert edits to absolute index ranges + var spans = new List<(int start, int end, string text)>(); + foreach (var e in edits) + { + try + { + int sl = Math.Max(1, e.Value("startLine")); + int sc = Math.Max(1, e.Value("startCol")); + int el = Math.Max(1, e.Value("endLine")); + int ec = Math.Max(1, e.Value("endCol")); + string newText = e.Value("newText") ?? string.Empty; + + if (!TryIndexFromLineCol(original, sl, sc, out int sidx)) + return Response.Error($"apply_text_edits: start out of range (line {sl}, col {sc})"); + if (!TryIndexFromLineCol(original, el, ec, out int eidx)) + return Response.Error($"apply_text_edits: end out of range (line {el}, col {ec})"); + if (eidx < sidx) (sidx, eidx) = (eidx, sidx); + + spans.Add((sidx, eidx, newText)); + } + catch (Exception ex) + { + return Response.Error($"Invalid edit payload: {ex.Message}"); + } + } + + // Ensure non-overlap and apply from back to front + spans = spans.OrderByDescending(t => t.start).ToList(); + for (int i = 1; i < spans.Count; i++) + { + if (spans[i].end > spans[i - 1].start) + return Response.Error("Edits overlap; split into separate calls or adjust ranges."); + } + + string working = original; + foreach (var sp in spans) + { + working = working.Remove(sp.start, sp.end - sp.start).Insert(sp.start, sp.text ?? string.Empty); + } + + // Validate result + var level = GetValidationLevelFromGUI(); + if (!ValidateScriptSyntax(working, level, out var errors)) + return Response.Error("Script validation failed:\n" + string.Join("\n", errors ?? Array.Empty())); + + // Atomic write and schedule refresh + try + { + var enc = System.Text.Encoding.UTF8; + var tmp = fullPath + ".tmp"; + File.WriteAllText(tmp, working, enc); + string backup = fullPath + ".bak"; + try { File.Replace(tmp, fullPath, backup); } + catch (PlatformNotSupportedException) { File.Copy(tmp, fullPath, true); try { File.Delete(tmp); } catch { } } + catch (IOException) { File.Copy(tmp, fullPath, true); try { File.Delete(tmp); } catch { } } + + ManageScriptRefreshHelpers.ScheduleScriptRefresh(relativePath); + return Response.Success($"Applied {spans.Count} text edit(s) to '{relativePath}'.", new { path = relativePath, editsApplied = spans.Count, scheduledRefresh = true }); + } + catch (Exception ex) + { + return Response.Error($"Failed to write edits: {ex.Message}"); + } + } + + private static bool TryIndexFromLineCol(string text, int line1, int col1, out int index) + { + // 1-based line/col to absolute index (0-based), col positions are counted in code points + int line = 1, col = 1; + for (int i = 0; i <= text.Length; i++) + { + if (line == line1 && col == col1) + { + index = i; + return true; + } + if (i == text.Length) break; + char c = text[i]; + if (c == '\n') { line++; col = 1; } + else { col++; } + } + index = -1; return false; + } + private static object DeleteScript(string fullPath, string relativePath) { if (!File.Exists(fullPath)) @@ -448,6 +580,12 @@ private static object EditScript( try { var replacements = new List<(int start, int length, string text)>(); + int appliedCount = 0; + + // Apply mode: atomic (default) computes all spans against original and applies together. + // Sequential applies each edit immediately to the current working text (useful for dependent edits). + string applyMode = options?["applyMode"]?.ToString()?.ToLowerInvariant(); + bool applySequentially = applyMode == "sequential"; foreach (var e in edits) { @@ -473,7 +611,15 @@ private static object EditScript( if (!ValidateClassSnippet(replacement, className, out var vErr)) return Response.Error($"Replacement snippet invalid: {vErr}"); - replacements.Add((spanStart, spanLength, NormalizeNewlines(replacement))); + if (applySequentially) + { + working = working.Remove(spanStart, spanLength).Insert(spanStart, NormalizeNewlines(replacement)); + appliedCount++; + } + else + { + replacements.Add((spanStart, spanLength, NormalizeNewlines(replacement))); + } break; } @@ -487,7 +633,15 @@ private static object EditScript( if (!TryComputeClassSpan(working, className, ns, out var s, out var l, out var why)) return Response.Error($"delete_class failed: {why}"); - replacements.Add((s, l, string.Empty)); + if (applySequentially) + { + working = working.Remove(s, l); + appliedCount++; + } + else + { + replacements.Add((s, l, string.Empty)); + } break; } @@ -509,9 +663,24 @@ private static object EditScript( return Response.Error($"replace_method failed to locate class: {whyClass}"); if (!TryComputeMethodSpan(working, clsStart, clsLen, methodName, returnType, parametersSignature, attributesContains, out var mStart, out var mLen, out var whyMethod)) - return Response.Error($"replace_method failed: {whyMethod}"); + { + bool hasDependentInsert = edits.Any(j => j is JObject jo && + string.Equals(jo.Value("className"), className, StringComparison.Ordinal) && + string.Equals(jo.Value("methodName"), methodName, StringComparison.Ordinal) && + ((jo.Value("mode") ?? jo.Value("op") ?? string.Empty).ToLowerInvariant() == "insert_method")); + string hint = hasDependentInsert && !applySequentially ? " Hint: This batch inserts this method. Use options.applyMode='sequential' or split into separate calls." : string.Empty; + return Response.Error($"replace_method failed: {whyMethod}.{hint}"); + } - replacements.Add((mStart, mLen, NormalizeNewlines(replacement))); + if (applySequentially) + { + working = working.Remove(mStart, mLen).Insert(mStart, NormalizeNewlines(replacement)); + appliedCount++; + } + else + { + replacements.Add((mStart, mLen, NormalizeNewlines(replacement))); + } break; } @@ -531,9 +700,24 @@ private static object EditScript( return Response.Error($"delete_method failed to locate class: {whyClass}"); if (!TryComputeMethodSpan(working, clsStart, clsLen, methodName, returnType, parametersSignature, attributesContains, out var mStart, out var mLen, out var whyMethod)) - return Response.Error($"delete_method failed: {whyMethod}"); + { + bool hasDependentInsert = edits.Any(j => j is JObject jo && + string.Equals(jo.Value("className"), className, StringComparison.Ordinal) && + string.Equals(jo.Value("methodName"), methodName, StringComparison.Ordinal) && + ((jo.Value("mode") ?? jo.Value("op") ?? string.Empty).ToLowerInvariant() == "insert_method")); + string hint = hasDependentInsert && !applySequentially ? " Hint: This batch inserts this method. Use options.applyMode='sequential' or split into separate calls." : string.Empty; + return Response.Error($"delete_method failed: {whyMethod}.{hint}"); + } - replacements.Add((mStart, mLen, string.Empty)); + if (applySequentially) + { + working = working.Remove(mStart, mLen); + appliedCount++; + } + else + { + replacements.Add((mStart, mLen, string.Empty)); + } break; } @@ -561,14 +745,30 @@ private static object EditScript( return Response.Error($"insert_method(after) failed to locate anchor method: {whyAfter}"); int insAt = aStart + aLen; string text = NormalizeNewlines("\n\n" + snippet.TrimEnd() + "\n"); - replacements.Add((insAt, 0, text)); + if (applySequentially) + { + working = working.Insert(insAt, text); + appliedCount++; + } + else + { + replacements.Add((insAt, 0, text)); + } } else if (!TryFindClassInsertionPoint(working, clsStart, clsLen, position, out var insAt, out var whyIns)) return Response.Error($"insert_method failed: {whyIns}"); else { string text = NormalizeNewlines("\n\n" + snippet.TrimEnd() + "\n"); - replacements.Add((insAt, 0, text)); + if (applySequentially) + { + working = working.Insert(insAt, text); + appliedCount++; + } + else + { + replacements.Add((insAt, 0, text)); + } } break; } @@ -578,11 +778,15 @@ private static object EditScript( } } - if (HasOverlaps(replacements)) - return Response.Error("Edits overlap; split into separate calls or adjust targets."); + if (!applySequentially) + { + if (HasOverlaps(replacements)) + return Response.Error("Edits overlap; split into separate calls or adjust targets."); - foreach (var r in replacements.OrderByDescending(r => r.start)) - working = working.Remove(r.start, r.length).Insert(r.start, r.text); + foreach (var r in replacements.OrderByDescending(r => r.start)) + working = working.Remove(r.start, r.length).Insert(r.start, r.text); + appliedCount = replacements.Count; + } // Validate result using override from options if provided; otherwise GUI strictness var level = GetValidationLevelFromGUI(); @@ -621,8 +825,8 @@ private static object EditScript( bool immediate = refreshMode == "immediate" || refreshMode == "sync"; var ok = Response.Success( - $"Applied {replacements.Count} structured edit(s) to '{relativePath}'.", - new { path = relativePath, editsApplied = replacements.Count, scheduledRefresh = !immediate } + $"Applied {appliedCount} structured edit(s) to '{relativePath}'.", + new { path = relativePath, editsApplied = appliedCount, scheduledRefresh = !immediate } ); if (immediate) @@ -796,9 +1000,9 @@ private static bool TryComputeMethodSpan( string namePattern = Regex.Escape(methodName); string paramsPattern = string.IsNullOrEmpty(parametersSignature) ? @"[\s\S]*?" : Regex.Escape(parametersSignature); string pattern = - @"(?m)^[\t ]*(?:\[[^\n\]]+\][\t ]*\n)*[\t ]*" + + @"(?m)^[\t ]*(?:\[[^\]]+\][\t ]*)*[\t ]*" + @"(?:(?:public|private|protected|internal|static|virtual|override|sealed|async|extern|unsafe|new|partial|readonly|volatile|event|abstract|ref|in|out)\s+)*" + - rtPattern + @"[\t ]+" + namePattern + @"\s*\(" + paramsPattern + @"\)"; + rtPattern + @"[\t ]+" + namePattern + @"\s*(?:<[^>]+>)?\s*\(" + paramsPattern + @"\)"; string slice = source.Substring(searchStart, searchEnd - searchStart); var headerMatch = Regex.Match(slice, pattern, RegexOptions.Multiline); @@ -843,7 +1047,13 @@ private static bool TryComputeMethodSpan( } // 2) Walk from the end of signature to detect body style ('{' or '=> ...;') and compute end - int i = headerIndex; + // Find the '(' that belongs to the method signature, not attributes + int nameTokenIdx = IndexOfTokenWithin(source, methodName, headerIndex, searchEnd); + if (nameTokenIdx < 0) { why = $"method '{methodName}' token not found after header"; return false; } + int sigOpenParen = IndexOfTokenWithin(source, "(", nameTokenIdx, searchEnd); + if (sigOpenParen < 0) { why = "method parameter list '(' not found"; return false; } + + int i = sigOpenParen; int parenDepth = 0; bool inStr = false, inChar = false, inSL = false, inML = false, esc = false; for (; i < searchEnd; i++) { @@ -875,6 +1085,58 @@ private static bool TryComputeMethodSpan( break; } + // Tolerate generic constraints between params and body: multiple 'where T : ...' + for (;;) + { + // Skip whitespace/comments before checking for 'where' + for (; i < searchEnd; i++) + { + char c = source[i]; + char n = i + 1 < searchEnd ? source[i + 1] : '\0'; + if (char.IsWhiteSpace(c)) continue; + if (c == '/' && n == '/') { while (i < searchEnd && source[i] != '\n') i++; continue; } + if (c == '/' && n == '*') { i += 2; while (i + 1 < searchEnd && !(source[i] == '*' && source[i + 1] == '/')) i++; i++; continue; } + break; + } + + // Check word-boundary 'where' + bool hasWhere = false; + if (i + 5 <= searchEnd) + { + hasWhere = source[i] == 'w' && source[i + 1] == 'h' && source[i + 2] == 'e' && source[i + 3] == 'r' && source[i + 4] == 'e'; + if (hasWhere) + { + // Left boundary + if (i - 1 >= 0) + { + char lb = source[i - 1]; + if (char.IsLetterOrDigit(lb) || lb == '_') hasWhere = false; + } + // Right boundary + if (hasWhere && i + 5 < searchEnd) + { + char rb = source[i + 5]; + if (char.IsLetterOrDigit(rb) || rb == '_') hasWhere = false; + } + } + } + if (!hasWhere) break; + + // Advance past the entire where-constraint clause until we hit '{' or '=>' or ';' + i += 5; // past 'where' + while (i < searchEnd) + { + char c = source[i]; + char n = i + 1 < searchEnd ? source[i + 1] : '\0'; + if (c == '{' || c == ';' || (c == '=' && n == '>')) break; + // Skip comments inline + if (c == '/' && n == '/') { while (i < searchEnd && source[i] != '\n') i++; continue; } + if (c == '/' && n == '*') { i += 2; while (i + 1 < searchEnd && !(source[i] == '*' && source[i + 1] == '/')) i++; i++; continue; } + i++; + } + } + + // Re-check for expression-bodied after constraints if (i < searchEnd - 1 && source[i] == '=' && source[i + 1] == '>') { // expression-bodied method: seek to terminating semicolon diff --git a/UnityMcpBridge/UnityMcpServer~/src/pyrightconfig.json b/UnityMcpBridge/UnityMcpServer~/src/pyrightconfig.json new file mode 100644 index 00000000..cfa4ff8c --- /dev/null +++ b/UnityMcpBridge/UnityMcpServer~/src/pyrightconfig.json @@ -0,0 +1,4 @@ +{ + "typeCheckingMode": "basic", + "reportMissingImports": "none" +} diff --git a/UnityMcpBridge/UnityMcpServer~/src/server.py b/UnityMcpBridge/UnityMcpServer~/src/server.py index 52633ef4..88add06d 100644 --- a/UnityMcpBridge/UnityMcpServer~/src/server.py +++ b/UnityMcpBridge/UnityMcpServer~/src/server.py @@ -1,5 +1,6 @@ from mcp.server.fastmcp import FastMCP, Context, Image import logging +import sys from logging.handlers import RotatingFileHandler from dataclasses import dataclass from contextlib import asynccontextmanager @@ -9,12 +10,20 @@ from unity_connection import get_unity_connection, UnityConnection from pathlib import Path -# Configure logging using settings from config -logging.basicConfig( - level=getattr(logging, config.log_level), - format=config.log_format -) +# Configure logging: strictly stderr/file only (never stdout) +stderr_handler = logging.StreamHandler(stream=sys.stderr) +stderr_handler.setFormatter(logging.Formatter(config.log_format)) + +handlers = [stderr_handler] logger = logging.getLogger("unity-mcp-server") +logger.setLevel(getattr(logging, config.log_level)) +for h in list(logger.handlers): + logger.removeHandler(h) +for h in list(logging.getLogger().handlers): + logging.getLogger().removeHandler(h) +logger.addHandler(stderr_handler) +logging.getLogger().addHandler(stderr_handler) +logging.getLogger().setLevel(getattr(logging, config.log_level)) # File logging to avoid stdout interference with MCP stdio try: @@ -84,6 +93,62 @@ def asset_creation_strategy() -> str: "- Always include a camera and main light in your scenes.\\n" ) +# Resources support: list and read Unity scripts/files +@mcp.capabilities(resources={"listChanged": True}) +class _: + pass + +import os +import hashlib + +def _unity_assets_root() -> str: + # Heuristic: from the Unity project root (one level up from Library/ProjectSettings), 'Assets' + # Here, assume server runs from repo; let clients pass absolute paths under project too. + return None + +def _safe_path(uri: str) -> str | None: + # URIs: unity://path/Assets/... or file:///absolute + if uri.startswith("unity://path/"): + p = uri[len("unity://path/"):] + return p + if uri.startswith("file://"): + return uri[len("file://"):] + # Minimal tolerance for plain Assets/... paths + if uri.startswith("Assets/"): + return uri + return None + +@mcp.resource.list() +def list_resources(ctx: Context) -> list[dict]: + # Lightweight: expose only C# under Assets by default + assets = [] + try: + root = os.getcwd() + for base, _, files in os.walk(os.path.join(root, "Assets")): + for f in files: + if f.endswith(".cs"): + rel = os.path.relpath(os.path.join(base, f), root).replace("\\", "/") + assets.append({ + "uri": f"unity://path/{rel}", + "name": os.path.basename(rel) + }) + except Exception: + pass + return assets + +@mcp.resource.read() +def read_resource(ctx: Context, uri: str) -> dict: + path = _safe_path(uri) + if not path or not os.path.exists(path): + return {"mimeType": "text/plain", "text": f"Resource not found: {uri}"} + try: + with open(path, "r", encoding="utf-8") as f: + text = f.read() + sha = hashlib.sha256(text.encode("utf-8")).hexdigest() + return {"mimeType": "text/plain", "text": text, "metadata": {"sha256": sha}} + except Exception as e: + return {"mimeType": "text/plain", "text": f"Error reading resource: {e}"} + # Run the server if __name__ == "__main__": mcp.run(transport='stdio') diff --git a/UnityMcpBridge/UnityMcpServer~/src/tools/__init__.py b/UnityMcpBridge/UnityMcpServer~/src/tools/__init__.py index 4d8d63cf..91ee9495 100644 --- a/UnityMcpBridge/UnityMcpServer~/src/tools/__init__.py +++ b/UnityMcpBridge/UnityMcpServer~/src/tools/__init__.py @@ -1,3 +1,4 @@ +from .manage_script_edits import register_manage_script_edits_tools from .manage_script import register_manage_script_tools from .manage_scene import register_manage_scene_tools from .manage_editor import register_manage_editor_tools @@ -9,7 +10,9 @@ def register_all_tools(mcp): """Register all refactored tools with the MCP server.""" - print("Registering Unity MCP Server refactored tools...") + # Note: Do not print to stdout; Claude treats stdout as MCP JSON. Use logging. + # Prefer the surgical edits tool so LLMs discover it first + register_manage_script_edits_tools(mcp) register_manage_script_tools(mcp) register_manage_scene_tools(mcp) register_manage_editor_tools(mcp) @@ -18,4 +21,4 @@ def register_all_tools(mcp): register_manage_shader_tools(mcp) register_read_console_tools(mcp) register_execute_menu_item_tools(mcp) - print("Unity MCP Server tool registration complete.") + # Do not print to stdout here either. diff --git a/UnityMcpBridge/UnityMcpServer~/src/tools/manage_script.py b/UnityMcpBridge/UnityMcpServer~/src/tools/manage_script.py index a41fb85c..af44a446 100644 --- a/UnityMcpBridge/UnityMcpServer~/src/tools/manage_script.py +++ b/UnityMcpBridge/UnityMcpServer~/src/tools/manage_script.py @@ -19,8 +19,10 @@ def manage_script( script_type: str, namespace: str ) -> Dict[str, Any]: - """Manages C# scripts in Unity (create, read, update, delete). - Make reference variables public for easier access in the Unity Editor. + """Manage C# scripts in Unity. + + IMPORTANT: + - This router is minimized. Use resources/read for file content and 'script_apply_edits' for changes. Args: action: Operation ('create', 'read', 'update', 'delete'). @@ -34,6 +36,10 @@ def manage_script( Dictionary with results ('success', 'message', 'data'). """ try: + # Deprecate full-file update path entirely + if action == 'update': + return {"success": False, "message": "Deprecated: use script_apply_edits (line/col edits) or resources/read + small edits."} + # Prepare parameters for Unity params = { "action": action, diff --git a/UnityMcpBridge/UnityMcpServer~/src/tools/manage_script_edits.py b/UnityMcpBridge/UnityMcpServer~/src/tools/manage_script_edits.py new file mode 100644 index 00000000..9cb746df --- /dev/null +++ b/UnityMcpBridge/UnityMcpServer~/src/tools/manage_script_edits.py @@ -0,0 +1,148 @@ +from mcp.server.fastmcp import FastMCP, Context +from typing import Dict, Any, List +import base64 +import re +from unity_connection import send_command_with_retry + + +def _apply_edits_locally(original_text: str, edits: List[Dict[str, Any]]) -> str: + text = original_text + for edit in edits or []: + op = ( + (edit.get("op") + or edit.get("operation") + or edit.get("type") + or edit.get("mode") + or "") + .strip() + .lower() + ) + + if not op: + allowed = "anchor_insert, prepend, append, replace_range, regex_replace" + raise RuntimeError( + f"op is required; allowed: {allowed}. Use 'op' (aliases accepted: type/mode/operation)." + ) + + if op == "prepend": + prepend_text = edit.get("text", "") + text = (prepend_text if prepend_text.endswith("\n") else prepend_text + "\n") + text + elif op == "append": + append_text = edit.get("text", "") + if not text.endswith("\n"): + text += "\n" + text += append_text + if not text.endswith("\n"): + text += "\n" + elif op == "anchor_insert": + anchor = edit.get("anchor", "") + position = (edit.get("position") or "before").lower() + insert_text = edit.get("text", "") + flags = re.MULTILINE + m = re.search(anchor, text, flags) + if not m: + if edit.get("allow_noop", True): + continue + raise RuntimeError(f"anchor not found: {anchor}") + idx = m.start() if position == "before" else m.end() + text = text[:idx] + insert_text + text[idx:] + elif op == "replace_range": + start_line = int(edit.get("startLine", 1)) + end_line = int(edit.get("endLine", start_line)) + replacement = edit.get("text", "") + lines = text.splitlines(keepends=True) + if start_line < 1 or end_line < start_line or end_line > len(lines): + raise RuntimeError("replace_range out of bounds") + a = start_line - 1 + b = end_line + rep = replacement + if rep and not rep.endswith("\n"): + rep += "\n" + text = "".join(lines[:a]) + rep + "".join(lines[b:]) + elif op == "regex_replace": + pattern = edit.get("pattern", "") + repl = edit.get("replacement", "") + count = int(edit.get("count", 0)) # 0 = replace all + flags = re.MULTILINE + if edit.get("ignore_case"): + flags |= re.IGNORECASE + text = re.sub(pattern, repl, text, count=count, flags=flags) + else: + allowed = "anchor_insert, prepend, append, replace_range, regex_replace" + raise RuntimeError(f"unknown edit op: {op}; allowed: {allowed}. Use 'op' (aliases accepted: type/mode/operation).") + return text + + +def register_manage_script_edits_tools(mcp: FastMCP): + @mcp.tool(description=( + "Apply targeted edits to an existing C# script WITHOUT replacing the whole file. " + "Preferred for inserts/patches. Supports ops: anchor_insert, prepend, append, " + "replace_range, regex_replace. For full-file creation, use manage_script(create)." + )) + def script_apply_edits( + ctx: Context, + name: str, + path: str, + edits: List[Dict[str, Any]], + options: Dict[str, Any] | None = None, + script_type: str = "MonoBehaviour", + namespace: str = "", + ) -> Dict[str, Any]: + # If the edits request structured class/method ops, route directly to Unity's 'edit' action + for e in edits or []: + op = (e.get("op") or e.get("operation") or e.get("type") or e.get("mode") or "").strip().lower() + if op in ("replace_class", "delete_class", "replace_method", "delete_method", "insert_method"): + params: Dict[str, Any] = { + "action": "edit", + "name": name, + "path": path, + "namespace": namespace, + "scriptType": script_type, + "edits": edits, + } + if options is not None: + params["options"] = options + resp = send_command_with_retry("manage_script", params) + return resp if isinstance(resp, dict) else {"success": False, "message": str(resp)} + + # 1) read from Unity + read_resp = send_command_with_retry("manage_script", { + "action": "read", + "name": name, + "path": path, + "namespace": namespace, + "scriptType": script_type, + }) + if not isinstance(read_resp, dict) or not read_resp.get("success"): + return read_resp if isinstance(read_resp, dict) else {"success": False, "message": str(read_resp)} + + data = read_resp.get("data") or read_resp.get("result", {}).get("data") or {} + contents = data.get("contents") + if contents is None and data.get("contentsEncoded") and data.get("encodedContents"): + contents = base64.b64decode(data["encodedContents"]).decode("utf-8") + if contents is None: + return {"success": False, "message": "No contents returned from Unity read."} + + # 2) apply edits locally + try: + new_contents = _apply_edits_locally(contents, edits) + except Exception as e: + return {"success": False, "message": f"Edit application failed: {e}"} + + # 3) update to Unity + params: Dict[str, Any] = { + "action": "update", + "name": name, + "path": path, + "namespace": namespace, + "scriptType": script_type, + "encodedContents": base64.b64encode(new_contents.encode("utf-8")).decode("ascii"), + "contentsEncoded": True, + } + if options is not None: + params["options"] = options + write_resp = send_command_with_retry("manage_script", params) + return write_resp if isinstance(write_resp, dict) else {"success": False, "message": str(write_resp)} + + + diff --git a/UnityMcpBridge/UnityMcpServer~/src/unity_connection.py b/UnityMcpBridge/UnityMcpServer~/src/unity_connection.py index bc602040..f04fb430 100644 --- a/UnityMcpBridge/UnityMcpServer~/src/unity_connection.py +++ b/UnityMcpBridge/UnityMcpServer~/src/unity_connection.py @@ -38,7 +38,7 @@ def connect(self) -> bool: try: self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.sock.connect((self.host, self.port)) - logger.info(f"Connected to Unity at {self.host}:{self.port}") + logger.debug(f"Connected to Unity at {self.host}:{self.port}") # Strict handshake: require FRAMING=1 try: @@ -47,7 +47,7 @@ def connect(self) -> bool: text = greeting.decode('ascii', errors='ignore') if greeting else '' if 'FRAMING=1' in text: self.use_framing = True - logger.info('Unity MCP handshake received: FRAMING=1 (strict)') + logger.debug('Unity MCP handshake received: FRAMING=1 (strict)') else: raise ConnectionError(f'Unity MCP requires FRAMING=1, got: {text!r}') finally: @@ -188,15 +188,10 @@ def read_status_file() -> dict | None: for attempt in range(attempts + 1): try: - # Ensure connected + # Ensure connected (perform handshake each time so framing stays correct) if not self.sock: - # During retries use short connect timeout - self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.sock.settimeout(1.0) - self.sock.connect((self.host, self.port)) - # restore steady-state timeout for receive - self.sock.settimeout(config.connection_timeout) - logger.info(f"Connected to Unity at {self.host}:{self.port}") + if not self.connect(): + raise Exception("Could not connect to Unity") # Build payload if command_type == 'ping': diff --git a/test_unity_socket_framing.py b/test_unity_socket_framing.py new file mode 100644 index 00000000..b0e179c9 --- /dev/null +++ b/test_unity_socket_framing.py @@ -0,0 +1,88 @@ +#!/usr/bin/env python3 +import socket, struct, json, sys + +HOST = "127.0.0.1" +PORT = 6400 +SIZE_MB = int(sys.argv[1]) if len(sys.argv) > 1 else 5 # e.g., 5 or 10 +FILL = "R" + +def recv_exact(sock, n): + buf = bytearray(n) + view = memoryview(buf) + off = 0 + while off < n: + r = sock.recv_into(view[off:]) + if r == 0: + raise RuntimeError("socket closed") + off += r + return bytes(buf) + +def is_valid_json(b): + try: + json.loads(b.decode("utf-8")) + return True + except Exception: + return False + +def recv_legacy_json(sock, timeout=60): + sock.settimeout(timeout) + chunks = [] + while True: + chunk = sock.recv(65536) + if not chunk: + data = b"".join(chunks) + if not data: + raise RuntimeError("no data, socket closed") + return data + chunks.append(chunk) + data = b"".join(chunks) + if data.strip() == b"ping": + return data + if is_valid_json(data): + return data + +def main(): + body = { + "type": "read_console", + "params": { + "action": "get", + "types": ["all"], + "count": 1000, + "format": "detailed", + "includeStacktrace": True, + "filterText": FILL * (SIZE_MB * 1024 * 1024) + } + } + body_bytes = json.dumps(body, ensure_ascii=False).encode("utf-8") + + with socket.create_connection((HOST, PORT), timeout=5) as s: + s.settimeout(2) + # Read optional greeting + try: + greeting = s.recv(256) + except Exception: + greeting = b"" + greeting_text = greeting.decode("ascii", errors="ignore").strip() + print(f"Greeting: {greeting_text or '(none)'}") + + framing = "FRAMING=1" in greeting_text + print(f"Using framing? {framing}") + + s.settimeout(120) + if framing: + header = struct.pack(">Q", len(body_bytes)) + s.sendall(header + body_bytes) + resp_len = struct.unpack(">Q", recv_exact(s, 8))[0] + print(f"Response framed length: {resp_len}") + resp = recv_exact(s, resp_len) + else: + s.sendall(body_bytes) + resp = recv_legacy_json(s) + + print(f"Response bytes: {len(resp)}") + print(f"Response head: {resp[:120].decode('utf-8','ignore')}") + +if __name__ == "__main__": + main() + + From a12dcab7b0dcffa184bb9964831a95947f08447a Mon Sep 17 00:00:00 2001 From: dsarno Date: Sat, 16 Aug 2025 03:49:52 -0700 Subject: [PATCH 05/25] test: add initial script and asset edit tests --- UnityMcpBridge/Editor/Tools/ManageScript.cs | 178 ++++++++++++++++-- .../UnityMcpServer~/src/tools/manage_asset.py | 2 +- .../src/tools/manage_script.py | 124 ++++++++++-- test_unity_socket_framing.py | 5 +- tests/test_script_tools.py | 123 ++++++++++++ 5 files changed, 394 insertions(+), 38 deletions(-) create mode 100644 tests/test_script_tools.py diff --git a/UnityMcpBridge/Editor/Tools/ManageScript.cs b/UnityMcpBridge/Editor/Tools/ManageScript.cs index d2df4584..0d2fae60 100644 --- a/UnityMcpBridge/Editor/Tools/ManageScript.cs +++ b/UnityMcpBridge/Editor/Tools/ManageScript.cs @@ -8,10 +8,12 @@ using UnityEngine; using UnityMcpBridge.Editor.Helpers; using System.Threading; +using System.Security.Cryptography; #if USE_ROSLYN using Microsoft.CodeAnalysis; using Microsoft.CodeAnalysis.CSharp; +using Microsoft.CodeAnalysis.Formatting; #endif #if UNITY_EDITOR @@ -193,12 +195,12 @@ public static object HandleCommand(JObject @params) case "apply_text_edits": { var edits = @params["edits"] as JArray; - string precondition = @params["precondition_sha256"]?.ToString(); // optional, currently ignored here - return ApplyTextEdits(fullPath, relativePath, name, edits); + string precondition = @params["precondition_sha256"]?.ToString(); + return ApplyTextEdits(fullPath, relativePath, name, edits, precondition); } case "validate": { - string level = @params["level"]?.ToString()?.ToLowerInvariant() ?? "standard"; + string level = @params["level"]?.ToString()?.ToLowerInvariant() ?? "basic"; var chosen = level switch { "basic" => ValidationLevel.Basic, @@ -209,13 +211,19 @@ public static object HandleCommand(JObject @params) try { fileText = File.ReadAllText(fullPath); } catch (Exception ex) { return Response.Error($"Failed to read script: {ex.Message}"); } - bool ok = ValidateScriptSyntax(fileText, chosen, out string[] diags); - var result = new + bool ok = ValidateScriptSyntax(fileText, chosen, out string[] diagsRaw); + var diags = (diagsRaw ?? Array.Empty()).Select(s => { - isValid = ok, - diagnostics = diags ?? Array.Empty() - }; - return ok ? Response.Success("Validation completed.", result) : Response.Error("Validation failed.", result); + var m = Regex.Match(s, @"^(ERROR|WARNING|INFO): (.*?)(?: \(Line (\d+)\))?$"); + string severity = m.Success ? m.Groups[1].Value.ToLowerInvariant() : "info"; + string message = m.Success ? m.Groups[2].Value : s; + int lineNum = m.Success && int.TryParse(m.Groups[3].Value, out var l) ? l : 0; + return new { line = lineNum, col = 0, severity, message }; + }).ToArray(); + + var result = new { diagnostics = diags }; + return ok ? Response.Success("Validation completed.", result) + : Response.Error("Validation failed.", result); } case "edit": return Response.Error("Deprecated: use apply_text_edits. Structured 'edit' mode has been retired in favor of simple text edits."); @@ -299,9 +307,10 @@ string namespaceName try { File.Delete(tmp); } catch { } } + var uri = $"unity://path/{relativePath}"; var ok = Response.Success( $"Script '{name}.cs' created successfully at '{relativePath}'.", - new { path = relativePath, scheduledRefresh = true } + new { uri, scheduledRefresh = true } ); // Schedule heavy work AFTER replying @@ -423,11 +432,14 @@ string contents /// /// Apply simple text edits specified by line/column ranges. Applies transactionally and validates result. /// + private const int MaxEditPayloadBytes = 15 * 1024; + private static object ApplyTextEdits( string fullPath, string relativePath, string name, - JArray edits) + JArray edits, + string preconditionSha256) { if (!File.Exists(fullPath)) return Response.Error($"Script not found at '{relativePath}'."); @@ -438,8 +450,15 @@ private static object ApplyTextEdits( try { original = File.ReadAllText(fullPath); } catch (Exception ex) { return Response.Error($"Failed to read script: {ex.Message}"); } + string currentSha = ComputeSha256(original); + if (!string.IsNullOrEmpty(preconditionSha256) && !preconditionSha256.Equals(currentSha, StringComparison.OrdinalIgnoreCase)) + { + return Response.Error("stale_file", new { status = "stale_file", expected_sha256 = preconditionSha256, current_sha256 = currentSha }); + } + // Convert edits to absolute index ranges var spans = new List<(int start, int end, string text)>(); + int totalBytes = 0; foreach (var e in edits) { try @@ -457,6 +476,7 @@ private static object ApplyTextEdits( if (eidx < sidx) (sidx, eidx) = (eidx, sidx); spans.Add((sidx, eidx, newText)); + totalBytes += System.Text.Encoding.UTF8.GetByteCount(newText); } catch (Exception ex) { @@ -464,6 +484,11 @@ private static object ApplyTextEdits( } } + if (totalBytes > MaxEditPayloadBytes) + { + return Response.Error("too_large", new { status = "too_large", limitBytes = MaxEditPayloadBytes, hint = "split into smaller edits" }); + } + // Ensure non-overlap and apply from back to front spans = spans.OrderByDescending(t => t.start).ToList(); for (int i = 1; i < spans.Count; i++) @@ -478,10 +503,40 @@ private static object ApplyTextEdits( working = working.Remove(sp.start, sp.end - sp.start).Insert(sp.start, sp.text ?? string.Empty); } - // Validate result - var level = GetValidationLevelFromGUI(); - if (!ValidateScriptSyntax(working, level, out var errors)) - return Response.Error("Script validation failed:\n" + string.Join("\n", errors ?? Array.Empty())); + if (!CheckBalancedDelimiters(working, out int line, out char expected)) + { + int startLine = Math.Max(1, line - 5); + int endLine = line + 5; + string hint = $"unbalanced_braces at line {line}. Call resources/read for lines {startLine}-{endLine} and resend a smaller apply_text_edits that restores balance."; + return Response.Error(hint, new { status = "unbalanced_braces", line, expected = expected.ToString() }); + } + +#if USE_ROSLYN + var tree = CSharpSyntaxTree.ParseText(working); + var diagnostics = tree.GetDiagnostics().Where(d => d.Severity == DiagnosticSeverity.Error).Take(3) + .Select(d => new { + line = d.Location.GetLineSpan().StartLinePosition.Line + 1, + col = d.Location.GetLineSpan().StartLinePosition.Character + 1, + code = d.Id, + message = d.GetMessage() + }).ToArray(); + if (diagnostics.Length > 0) + { + return Response.Error("syntax_error", new { status = "syntax_error", diagnostics }); + } + + // Optional formatting + try + { + var root = tree.GetRoot(); + var workspace = new AdhocWorkspace(); + root = Microsoft.CodeAnalysis.Formatting.Formatter.Format(root, workspace); + working = root.ToFullString(); + } + catch { } +#endif + + string newSha = ComputeSha256(working); // Atomic write and schedule refresh try @@ -495,7 +550,17 @@ private static object ApplyTextEdits( catch (IOException) { File.Copy(tmp, fullPath, true); try { File.Delete(tmp); } catch { } } ManageScriptRefreshHelpers.ScheduleScriptRefresh(relativePath); - return Response.Success($"Applied {spans.Count} text edit(s) to '{relativePath}'.", new { path = relativePath, editsApplied = spans.Count, scheduledRefresh = true }); + return Response.Success( + $"Applied {spans.Count} text edit(s) to '{relativePath}'.", + new + { + applied = spans.Count, + unchanged = 0, + sha256 = newSha, + uri = $"unity://path/{relativePath}", + scheduledRefresh = true + } + ); } catch (Exception ex) { @@ -522,6 +587,84 @@ private static bool TryIndexFromLineCol(string text, int line1, int col1, out in index = -1; return false; } + private static string ComputeSha256(string contents) + { + using (var sha = SHA256.Create()) + { + var bytes = System.Text.Encoding.UTF8.GetBytes(contents); + var hash = sha.ComputeHash(bytes); + return BitConverter.ToString(hash).Replace("-", string.Empty).ToLowerInvariant(); + } + } + + private static bool CheckBalancedDelimiters(string text, out int line, out char expected) + { + var braceStack = new Stack(); + var parenStack = new Stack(); + var bracketStack = new Stack(); + bool inString = false, inChar = false, inSingle = false, inMulti = false, escape = false; + line = 1; expected = '\0'; + + for (int i = 0; i < text.Length; i++) + { + char c = text[i]; + char next = i + 1 < text.Length ? text[i + 1] : '\0'; + + if (c == '\n') { line++; if (inSingle) inSingle = false; } + + if (escape) { escape = false; continue; } + + if (inString) + { + if (c == '\\') { escape = true; } + else if (c == '"') inString = false; + continue; + } + if (inChar) + { + if (c == '\\') { escape = true; } + else if (c == '\'') inChar = false; + continue; + } + if (inSingle) continue; + if (inMulti) + { + if (c == '*' && next == '/') { inMulti = false; i++; } + continue; + } + + if (c == '"') { inString = true; continue; } + if (c == '\'') { inChar = true; continue; } + if (c == '/' && next == '/') { inSingle = true; i++; continue; } + if (c == '/' && next == '*') { inMulti = true; i++; continue; } + + switch (c) + { + case '{': braceStack.Push(line); break; + case '}': + if (braceStack.Count == 0) { expected = '{'; return false; } + braceStack.Pop(); + break; + case '(': parenStack.Push(line); break; + case ')': + if (parenStack.Count == 0) { expected = '('; return false; } + parenStack.Pop(); + break; + case '[': bracketStack.Push(line); break; + case ']': + if (bracketStack.Count == 0) { expected = '['; return false; } + bracketStack.Pop(); + break; + } + } + + if (braceStack.Count > 0) { line = braceStack.Peek(); expected = '}'; return false; } + if (parenStack.Count > 0) { line = parenStack.Peek(); expected = ')'; return false; } + if (bracketStack.Count > 0) { line = bracketStack.Peek(); expected = ']'; return false; } + + return true; + } + private static object DeleteScript(string fullPath, string relativePath) { if (!File.Exists(fullPath)) @@ -537,7 +680,8 @@ private static object DeleteScript(string fullPath, string relativePath) { AssetDatabase.Refresh(); return Response.Success( - $"Script '{Path.GetFileName(relativePath)}' moved to trash successfully." + $"Script '{Path.GetFileName(relativePath)}' moved to trash successfully.", + new { deleted = true } ); } else diff --git a/UnityMcpBridge/UnityMcpServer~/src/tools/manage_asset.py b/UnityMcpBridge/UnityMcpServer~/src/tools/manage_asset.py index 19ac0c2e..ccafb047 100644 --- a/UnityMcpBridge/UnityMcpServer~/src/tools/manage_asset.py +++ b/UnityMcpBridge/UnityMcpServer~/src/tools/manage_asset.py @@ -76,4 +76,4 @@ async def manage_asset( # Use centralized async retry helper to avoid blocking the event loop result = await async_send_command_with_retry("manage_asset", params_dict, loop=loop) # Return the result obtained from Unity - return result if isinstance(result, dict) else {"success": False, "message": str(result)} \ No newline at end of file + return result if isinstance(result, dict) else {"success": False, "message": str(result)} diff --git a/UnityMcpBridge/UnityMcpServer~/src/tools/manage_script.py b/UnityMcpBridge/UnityMcpServer~/src/tools/manage_script.py index af44a446..f7836da3 100644 --- a/UnityMcpBridge/UnityMcpServer~/src/tools/manage_script.py +++ b/UnityMcpBridge/UnityMcpServer~/src/tools/manage_script.py @@ -1,14 +1,95 @@ from mcp.server.fastmcp import FastMCP, Context -from typing import Dict, Any +from typing import Dict, Any, List from unity_connection import get_unity_connection, send_command_with_retry from config import config import time import os import base64 + def register_manage_script_tools(mcp: FastMCP): """Register all script management tools with the MCP server.""" + def _split_uri(uri: str) -> tuple[str, str]: + if uri.startswith("unity://path/"): + path = uri[len("unity://path/") :] + elif uri.startswith("file://"): + path = uri[len("file://") :] + else: + path = uri + path = path.replace("\\", "/") + name = os.path.splitext(os.path.basename(path))[0] + directory = os.path.dirname(path) + return name, directory + + @mcp.tool() + def apply_text_edits( + ctx: Context, + uri: str, + edits: List[Dict[str, Any]], + precondition_sha256: str | None = None, + ) -> Dict[str, Any]: + """Apply small text edits to a C# script identified by URI.""" + name, directory = _split_uri(uri) + params = { + "action": "apply_text_edits", + "name": name, + "path": directory, + "edits": edits, + "precondition_sha256": precondition_sha256, + } + params = {k: v for k, v in params.items() if v is not None} + resp = send_command_with_retry("manage_script", params) + return resp if isinstance(resp, dict) else {"success": False, "message": str(resp)} + + @mcp.tool() + def create_script( + ctx: Context, + path: str, + contents: str = "", + script_type: str | None = None, + namespace: str | None = None, + ) -> Dict[str, Any]: + """Create a new C# script at the given path.""" + name = os.path.splitext(os.path.basename(path))[0] + directory = os.path.dirname(path) + params: Dict[str, Any] = { + "action": "create", + "name": name, + "path": directory, + "namespace": namespace, + "scriptType": script_type, + } + if contents is not None: + params["encodedContents"] = base64.b64encode(contents.encode("utf-8")).decode("utf-8") + params["contentsEncoded"] = True + params = {k: v for k, v in params.items() if v is not None} + resp = send_command_with_retry("manage_script", params) + return resp if isinstance(resp, dict) else {"success": False, "message": str(resp)} + + @mcp.tool() + def delete_script(ctx: Context, uri: str) -> Dict[str, Any]: + """Delete a C# script by URI.""" + name, directory = _split_uri(uri) + params = {"action": "delete", "name": name, "path": directory} + resp = send_command_with_retry("manage_script", params) + return resp if isinstance(resp, dict) else {"success": False, "message": str(resp)} + + @mcp.tool() + def validate_script( + ctx: Context, uri: str, level: str = "basic" + ) -> Dict[str, Any]: + """Validate a C# script and return diagnostics.""" + name, directory = _split_uri(uri) + params = { + "action": "validate", + "name": name, + "path": directory, + "level": level, + } + resp = send_command_with_retry("manage_script", params) + return resp if isinstance(resp, dict) else {"success": False, "message": str(resp)} + @mcp.tool() def manage_script( ctx: Context, @@ -17,12 +98,13 @@ def manage_script( path: str, contents: str, script_type: str, - namespace: str + namespace: str, ) -> Dict[str, Any]: - """Manage C# scripts in Unity. + """Compatibility router for legacy script operations. IMPORTANT: - - This router is minimized. Use resources/read for file content and 'script_apply_edits' for changes. + - Direct file reads should use resources/read. + - Edits should use apply_text_edits. Args: action: Operation ('create', 'read', 'update', 'delete'). @@ -38,7 +120,7 @@ def manage_script( try: # Deprecate full-file update path entirely if action == 'update': - return {"success": False, "message": "Deprecated: use script_apply_edits (line/col edits) or resources/read + small edits."} + return {"success": False, "message": "Deprecated: use apply_text_edits or resources/read + small edits."} # Prepare parameters for Unity params = { @@ -46,36 +128,40 @@ def manage_script( "name": name, "path": path, "namespace": namespace, - "scriptType": script_type + "scriptType": script_type, } - + # Base64 encode the contents if they exist to avoid JSON escaping issues if contents is not None: if action in ['create', 'update']: - # Encode content for safer transmission params["encodedContents"] = base64.b64encode(contents.encode('utf-8')).decode('utf-8') params["contentsEncoded"] = True else: params["contents"] = contents - - # Remove None values so they don't get sent as null + params = {k: v for k, v in params.items() if v is not None} - # Send command via centralized retry helper response = send_command_with_retry("manage_script", params) - - # Process response from Unity + if isinstance(response, dict) and response.get("success"): - # If the response contains base64 encoded content, decode it if response.get("data", {}).get("contentsEncoded"): decoded_contents = base64.b64decode(response["data"]["encodedContents"]).decode('utf-8') response["data"]["contents"] = decoded_contents del response["data"]["encodedContents"] del response["data"]["contentsEncoded"] - - return {"success": True, "message": response.get("message", "Operation successful."), "data": response.get("data")} - return response if isinstance(response, dict) else {"success": False, "message": str(response)} + + return { + "success": True, + "message": response.get("message", "Operation successful."), + "data": response.get("data"), + } + return response if isinstance(response, dict) else { + "success": False, + "message": str(response), + } except Exception as e: - # Handle Python-side errors (e.g., connection issues) - return {"success": False, "message": f"Python error managing script: {str(e)}"} \ No newline at end of file + return { + "success": False, + "message": f"Python error managing script: {str(e)}", + } diff --git a/test_unity_socket_framing.py b/test_unity_socket_framing.py index b0e179c9..c24064a1 100644 --- a/test_unity_socket_framing.py +++ b/test_unity_socket_framing.py @@ -3,7 +3,10 @@ HOST = "127.0.0.1" PORT = 6400 -SIZE_MB = int(sys.argv[1]) if len(sys.argv) > 1 else 5 # e.g., 5 or 10 +try: + SIZE_MB = int(sys.argv[1]) +except (IndexError, ValueError): + SIZE_MB = 5 # e.g., 5 or 10 FILL = "R" def recv_exact(sock, n): diff --git a/tests/test_script_tools.py b/tests/test_script_tools.py new file mode 100644 index 00000000..9b953a1a --- /dev/null +++ b/tests/test_script_tools.py @@ -0,0 +1,123 @@ +import sys +import pathlib +import importlib.util +import types +import pytest + +# add server src to path and load modules without triggering package imports +ROOT = pathlib.Path(__file__).resolve().parents[1] +SRC = ROOT / "UnityMcpBridge" / "UnityMcpServer~" / "src" +sys.path.insert(0, str(SRC)) + +# stub mcp.server.fastmcp to satisfy imports without full dependency +mcp_pkg = types.ModuleType("mcp") +server_pkg = types.ModuleType("mcp.server") +fastmcp_pkg = types.ModuleType("mcp.server.fastmcp") + +class _Dummy: + pass + +fastmcp_pkg.FastMCP = _Dummy +fastmcp_pkg.Context = _Dummy +server_pkg.fastmcp = fastmcp_pkg +mcp_pkg.server = server_pkg +sys.modules.setdefault("mcp", mcp_pkg) +sys.modules.setdefault("mcp.server", server_pkg) +sys.modules.setdefault("mcp.server.fastmcp", fastmcp_pkg) + +def load_module(path, name): + spec = importlib.util.spec_from_file_location(name, path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + +manage_script_module = load_module(SRC / "tools" / "manage_script.py", "manage_script_module") +manage_asset_module = load_module(SRC / "tools" / "manage_asset.py", "manage_asset_module") + + +class DummyMCP: + def __init__(self): + self.tools = {} + + def tool(self): + def decorator(func): + self.tools[func.__name__] = func + return func + return decorator + +def setup_manage_script(): + mcp = DummyMCP() + manage_script_module.register_manage_script_tools(mcp) + return mcp.tools + +def setup_manage_asset(): + mcp = DummyMCP() + manage_asset_module.register_manage_asset_tools(mcp) + return mcp.tools + +def test_apply_text_edits_long_file(monkeypatch): + tools = setup_manage_script() + apply_edits = tools["apply_text_edits"] + captured = {} + + def fake_send(cmd, params): + captured["cmd"] = cmd + captured["params"] = params + return {"success": True} + + monkeypatch.setattr(manage_script_module, "send_command_with_retry", fake_send) + + edit = {"startLine": 1005, "startCol": 0, "endLine": 1005, "endCol": 5, "newText": "Hello"} + resp = apply_edits(None, "unity://path/Assets/Scripts/LongFile.cs", [edit]) + assert captured["cmd"] == "manage_script" + assert captured["params"]["action"] == "apply_text_edits" + assert captured["params"]["edits"][0]["startLine"] == 1005 + assert resp["success"] is True + +def test_sequential_edits_use_precondition(monkeypatch): + tools = setup_manage_script() + apply_edits = tools["apply_text_edits"] + calls = [] + + def fake_send(cmd, params): + calls.append(params) + return {"success": True, "sha256": f"hash{len(calls)}"} + + monkeypatch.setattr(manage_script_module, "send_command_with_retry", fake_send) + + edit1 = {"startLine": 1, "startCol": 0, "endLine": 1, "endCol": 0, "newText": "//header\n"} + resp1 = apply_edits(None, "unity://path/Assets/Scripts/File.cs", [edit1]) + edit2 = {"startLine": 2, "startCol": 0, "endLine": 2, "endCol": 0, "newText": "//second\n"} + resp2 = apply_edits(None, "unity://path/Assets/Scripts/File.cs", [edit2], precondition_sha256=resp1["sha256"]) + + assert calls[1]["precondition_sha256"] == resp1["sha256"] + assert resp2["sha256"] == "hash2" + +def test_manage_asset_prefab_modify_request(monkeypatch): + tools = setup_manage_asset() + manage_asset = tools["manage_asset"] + captured = {} + + async def fake_async(cmd, params, loop=None): + captured["cmd"] = cmd + captured["params"] = params + return {"success": True} + + monkeypatch.setattr(manage_asset_module, "async_send_command_with_retry", fake_async) + monkeypatch.setattr(manage_asset_module, "get_unity_connection", lambda: object()) + + async def run(): + resp = await manage_asset( + None, + action="modify", + path="Assets/Prefabs/Player.prefab", + properties={"hp": 100}, + ) + assert captured["cmd"] == "manage_asset" + assert captured["params"]["action"] == "modify" + assert captured["params"]["path"] == "Assets/Prefabs/Player.prefab" + assert captured["params"]["properties"] == {"hp": 100} + assert resp["success"] is True + + import asyncio + asyncio.run(run()) From de4a6bc36137fb97510c90d46bc71242919b77ec Mon Sep 17 00:00:00 2001 From: dsarno Date: Sat, 16 Aug 2025 06:18:37 -0700 Subject: [PATCH 06/25] Maintain manage_script compatibility and add safety checks --- UnityMcpBridge/Editor/Tools/ManageScript.cs | 35 +++++++++++++++---- .../Editor/Windows/UnityMcpEditorWindow.cs | 9 +++++ .../UnityMcpServer~/src/pyrightconfig.json | 9 ++++- test_unity_socket_framing.py | 3 ++ 4 files changed, 48 insertions(+), 8 deletions(-) diff --git a/UnityMcpBridge/Editor/Tools/ManageScript.cs b/UnityMcpBridge/Editor/Tools/ManageScript.cs index 0d2fae60..19c9f24a 100644 --- a/UnityMcpBridge/Editor/Tools/ManageScript.cs +++ b/UnityMcpBridge/Editor/Tools/ManageScript.cs @@ -187,9 +187,11 @@ public static object HandleCommand(JObject @params) namespaceName ); case "read": - return Response.Error("Deprecated: reads are resources now. Use resources/read with a unity://path or unity://script URI."); + Debug.LogWarning("manage_script.read is deprecated; prefer resources/read. Serving read for backward compatibility."); + return ReadScript(fullPath, relativePath); case "update": - return Response.Error("Deprecated: use apply_text_edits (small, line/col edits) rather than whole-file replace."); + Debug.LogWarning("manage_script.update is deprecated; prefer apply_text_edits. Serving update for backward compatibility."); + return UpdateScript(fullPath, relativePath, name, contents); case "delete": return DeleteScript(fullPath, relativePath); case "apply_text_edits": @@ -226,10 +228,13 @@ public static object HandleCommand(JObject @params) : Response.Error("Validation failed.", result); } case "edit": - return Response.Error("Deprecated: use apply_text_edits. Structured 'edit' mode has been retired in favor of simple text edits."); + Debug.LogWarning("manage_script.edit is deprecated; prefer apply_text_edits. Serving structured edit for backward compatibility."); + var edits = @params["edits"] as JArray; + var options = @params["options"] as JObject; + return EditScript(fullPath, relativePath, name, edits, options); default: return Response.Error( - $"Unknown action: '{action}'. Valid actions are: create, read, update, delete." + $"Unknown action: '{action}'. Valid actions are: create, delete, apply_text_edits, validate, read (deprecated), update (deprecated), edit (deprecated)." ); } } @@ -581,10 +586,26 @@ private static bool TryIndexFromLineCol(string text, int line1, int col1, out in } if (i == text.Length) break; char c = text[i]; - if (c == '\n') { line++; col = 1; } - else { col++; } + if (c == '\r') + { + // Treat CRLF as a single newline; skip the LF if present + if (i + 1 < text.Length && text[i + 1] == '\n') + i++; + line++; + col = 1; + } + else if (c == '\n') + { + line++; + col = 1; + } + else + { + col++; + } } - index = -1; return false; + index = -1; + return false; } private static string ComputeSha256(string contents) diff --git a/UnityMcpBridge/Editor/Windows/UnityMcpEditorWindow.cs b/UnityMcpBridge/Editor/Windows/UnityMcpEditorWindow.cs index 19446406..d80ffbb5 100644 --- a/UnityMcpBridge/Editor/Windows/UnityMcpEditorWindow.cs +++ b/UnityMcpBridge/Editor/Windows/UnityMcpEditorWindow.cs @@ -1579,6 +1579,15 @@ private void CheckMcpConfiguration(McpClient mcpClient) } } } + else + { + // Surface mismatch even if auto-manage is disabled + mcpClient.SetStatus(McpStatus.IncorrectPath); + if (debugLogsEnabled) + { + UnityEngine.Debug.Log($"UnityMCP: IDE config mismatch for '{mcpClient.name}' and auto-manage disabled"); + } + } } } else diff --git a/UnityMcpBridge/UnityMcpServer~/src/pyrightconfig.json b/UnityMcpBridge/UnityMcpServer~/src/pyrightconfig.json index cfa4ff8c..4fdeb465 100644 --- a/UnityMcpBridge/UnityMcpServer~/src/pyrightconfig.json +++ b/UnityMcpBridge/UnityMcpServer~/src/pyrightconfig.json @@ -1,4 +1,11 @@ { "typeCheckingMode": "basic", - "reportMissingImports": "none" + "reportMissingImports": "none", + "pythonVersion": "3.11", + "executionEnvironments": [ + { + "root": ".", + "pythonVersion": "3.11" + } + ] } diff --git a/test_unity_socket_framing.py b/test_unity_socket_framing.py index c24064a1..7495ccb3 100644 --- a/test_unity_socket_framing.py +++ b/test_unity_socket_framing.py @@ -77,6 +77,9 @@ def main(): s.sendall(header + body_bytes) resp_len = struct.unpack(">Q", recv_exact(s, 8))[0] print(f"Response framed length: {resp_len}") + MAX_RESP = 128 * 1024 * 1024 + if resp_len <= 0 or resp_len > MAX_RESP: + raise RuntimeError(f"invalid framed length: {resp_len} (max {MAX_RESP})") resp = recv_exact(s, resp_len) else: s.sendall(body_bytes) From c13a2dae6f27f07ddcf9b7269f18b35d63d58097 Mon Sep 17 00:00:00 2001 From: dsarno Date: Sat, 16 Aug 2025 09:06:27 -0700 Subject: [PATCH 07/25] Support explicit validation levels --- UnityMcpBridge/Editor/Tools/ManageScript.cs | 37 +++++++++++++++---- .../Editor/Windows/UnityMcpEditorWindow.cs | 9 +++++ .../UnityMcpServer~/src/pyrightconfig.json | 9 ++++- test_unity_socket_framing.py | 3 ++ 4 files changed, 50 insertions(+), 8 deletions(-) diff --git a/UnityMcpBridge/Editor/Tools/ManageScript.cs b/UnityMcpBridge/Editor/Tools/ManageScript.cs index 0d2fae60..29339604 100644 --- a/UnityMcpBridge/Editor/Tools/ManageScript.cs +++ b/UnityMcpBridge/Editor/Tools/ManageScript.cs @@ -187,9 +187,11 @@ public static object HandleCommand(JObject @params) namespaceName ); case "read": - return Response.Error("Deprecated: reads are resources now. Use resources/read with a unity://path or unity://script URI."); + Debug.LogWarning("manage_script.read is deprecated; prefer resources/read. Serving read for backward compatibility."); + return ReadScript(fullPath, relativePath); case "update": - return Response.Error("Deprecated: use apply_text_edits (small, line/col edits) rather than whole-file replace."); + Debug.LogWarning("manage_script.update is deprecated; prefer apply_text_edits. Serving update for backward compatibility."); + return UpdateScript(fullPath, relativePath, name, contents); case "delete": return DeleteScript(fullPath, relativePath); case "apply_text_edits": @@ -204,7 +206,9 @@ public static object HandleCommand(JObject @params) var chosen = level switch { "basic" => ValidationLevel.Basic, + "standard" => ValidationLevel.Standard, "strict" => ValidationLevel.Strict, + "comprehensive" => ValidationLevel.Comprehensive, _ => ValidationLevel.Standard }; string fileText; @@ -226,10 +230,13 @@ public static object HandleCommand(JObject @params) : Response.Error("Validation failed.", result); } case "edit": - return Response.Error("Deprecated: use apply_text_edits. Structured 'edit' mode has been retired in favor of simple text edits."); + Debug.LogWarning("manage_script.edit is deprecated; prefer apply_text_edits. Serving structured edit for backward compatibility."); + var edits = @params["edits"] as JArray; + var options = @params["options"] as JObject; + return EditScript(fullPath, relativePath, name, edits, options); default: return Response.Error( - $"Unknown action: '{action}'. Valid actions are: create, read, update, delete." + $"Unknown action: '{action}'. Valid actions are: create, delete, apply_text_edits, validate, read (deprecated), update (deprecated), edit (deprecated)." ); } } @@ -581,10 +588,26 @@ private static bool TryIndexFromLineCol(string text, int line1, int col1, out in } if (i == text.Length) break; char c = text[i]; - if (c == '\n') { line++; col = 1; } - else { col++; } + if (c == '\r') + { + // Treat CRLF as a single newline; skip the LF if present + if (i + 1 < text.Length && text[i + 1] == '\n') + i++; + line++; + col = 1; + } + else if (c == '\n') + { + line++; + col = 1; + } + else + { + col++; + } } - index = -1; return false; + index = -1; + return false; } private static string ComputeSha256(string contents) diff --git a/UnityMcpBridge/Editor/Windows/UnityMcpEditorWindow.cs b/UnityMcpBridge/Editor/Windows/UnityMcpEditorWindow.cs index 19446406..d80ffbb5 100644 --- a/UnityMcpBridge/Editor/Windows/UnityMcpEditorWindow.cs +++ b/UnityMcpBridge/Editor/Windows/UnityMcpEditorWindow.cs @@ -1579,6 +1579,15 @@ private void CheckMcpConfiguration(McpClient mcpClient) } } } + else + { + // Surface mismatch even if auto-manage is disabled + mcpClient.SetStatus(McpStatus.IncorrectPath); + if (debugLogsEnabled) + { + UnityEngine.Debug.Log($"UnityMCP: IDE config mismatch for '{mcpClient.name}' and auto-manage disabled"); + } + } } } else diff --git a/UnityMcpBridge/UnityMcpServer~/src/pyrightconfig.json b/UnityMcpBridge/UnityMcpServer~/src/pyrightconfig.json index cfa4ff8c..4fdeb465 100644 --- a/UnityMcpBridge/UnityMcpServer~/src/pyrightconfig.json +++ b/UnityMcpBridge/UnityMcpServer~/src/pyrightconfig.json @@ -1,4 +1,11 @@ { "typeCheckingMode": "basic", - "reportMissingImports": "none" + "reportMissingImports": "none", + "pythonVersion": "3.11", + "executionEnvironments": [ + { + "root": ".", + "pythonVersion": "3.11" + } + ] } diff --git a/test_unity_socket_framing.py b/test_unity_socket_framing.py index c24064a1..7495ccb3 100644 --- a/test_unity_socket_framing.py +++ b/test_unity_socket_framing.py @@ -77,6 +77,9 @@ def main(): s.sendall(header + body_bytes) resp_len = struct.unpack(">Q", recv_exact(s, 8))[0] print(f"Response framed length: {resp_len}") + MAX_RESP = 128 * 1024 * 1024 + if resp_len <= 0 or resp_len > MAX_RESP: + raise RuntimeError(f"invalid framed length: {resp_len} (max {MAX_RESP})") resp = recv_exact(s, resp_len) else: s.sendall(body_bytes) From 49a3355c7f0451122f9dc8774761f8171383f2c9 Mon Sep 17 00:00:00 2001 From: dsarno Date: Sun, 17 Aug 2025 06:58:28 -0700 Subject: [PATCH 08/25] Fix script tool returns and handshake edge cases --- UnityMcpBridge/Editor/UnityMcpBridge.cs | 77 ++++++++----------- UnityMcpBridge/UnityMcpServer~/src/server.py | 56 ++++++-------- .../UnityMcpServer~/src/tools/__init__.py | 7 +- .../src/tools/manage_script.py | 40 +++++----- .../src/tools/manage_script_edits.py | 8 +- .../UnityMcpServer~/src/unity_connection.py | 13 +++- 6 files changed, 99 insertions(+), 102 deletions(-) diff --git a/UnityMcpBridge/Editor/UnityMcpBridge.cs b/UnityMcpBridge/Editor/UnityMcpBridge.cs index fa707483..feb631ba 100644 --- a/UnityMcpBridge/Editor/UnityMcpBridge.cs +++ b/UnityMcpBridge/Editor/UnityMcpBridge.cs @@ -431,15 +431,7 @@ private static async Task HandleClientAsync(TcpClient client) if (true) { // Enforced framed mode for this connection - byte[] header = await ReadExactAsync(stream, 8, FrameIOTimeoutMs); - ulong payloadLen = ReadUInt64BigEndian(header); - if (payloadLen == 0UL || payloadLen > MaxFrameBytes) - { - throw new System.IO.IOException($"Invalid framed length: {payloadLen}"); - } - int payloadLenInt = checked((int)payloadLen); - byte[] payload = await ReadExactAsync(stream, payloadLenInt, FrameIOTimeoutMs); - commandText = System.Text.Encoding.UTF8.GetString(payload); + commandText = await ReadFrameAsUtf8Async(stream, FrameIOTimeoutMs); } try @@ -459,16 +451,7 @@ private static async Task HandleClientAsync(TcpClient client) /*lang=json,strict*/ "{\"status\":\"success\",\"result\":{\"message\":\"pong\"}}" ); - if ((ulong)pingResponseBytes.Length > MaxFrameBytes) - { - throw new System.IO.IOException($"Frame too large: {pingResponseBytes.Length}"); - } - { - byte[] outHeader = new byte[8]; - WriteUInt64BigEndian(outHeader, (ulong)pingResponseBytes.Length); - await stream.WriteAsync(outHeader, 0, outHeader.Length); - } - await stream.WriteAsync(pingResponseBytes, 0, pingResponseBytes.Length); + await WriteFrameAsync(stream, pingResponseBytes); continue; } @@ -479,16 +462,7 @@ private static async Task HandleClientAsync(TcpClient client) string response = await tcs.Task; byte[] responseBytes = System.Text.Encoding.UTF8.GetBytes(response); - if ((ulong)responseBytes.Length > MaxFrameBytes) - { - throw new System.IO.IOException($"Frame too large: {responseBytes.Length}"); - } - { - byte[] outHeader = new byte[8]; - WriteUInt64BigEndian(outHeader, (ulong)responseBytes.Length); - await stream.WriteAsync(outHeader, 0, outHeader.Length); - } - await stream.WriteAsync(responseBytes, 0, responseBytes.Length); + await WriteFrameAsync(stream, responseBytes); } catch (Exception ex) { @@ -499,22 +473,6 @@ private static async Task HandleClientAsync(TcpClient client) } } - private static async System.Threading.Tasks.Task ReadExactAsync(NetworkStream stream, int count) - { - byte[] data = new byte[count]; - int offset = 0; - while (offset < count) - { - int r = await stream.ReadAsync(data, offset, count - offset); - if (r == 0) - { - throw new System.IO.IOException("Connection closed before reading expected bytes"); - } - offset += r; - } - return data; - } - // Timeout-aware exact read helper; avoids indefinite stalls private static async System.Threading.Tasks.Task ReadExactAsync(NetworkStream stream, int count, int timeoutMs) { @@ -538,6 +496,35 @@ private static async System.Threading.Tasks.Task ReadExactAsync(NetworkS return data; } + private static async System.Threading.Tasks.Task WriteFrameAsync(NetworkStream stream, byte[] payload) + { + if ((ulong)payload.LongLength > MaxFrameBytes) + { + throw new System.IO.IOException($"Frame too large: {payload.LongLength}"); + } + byte[] header = new byte[8]; + WriteUInt64BigEndian(header, (ulong)payload.LongLength); + await stream.WriteAsync(header, 0, header.Length); + await stream.WriteAsync(payload, 0, payload.Length); + } + + private static async System.Threading.Tasks.Task ReadFrameAsUtf8Async(NetworkStream stream, int timeoutMs) + { + byte[] header = await ReadExactAsync(stream, 8, timeoutMs); + ulong payloadLen = ReadUInt64BigEndian(header); + if (payloadLen == 0UL || payloadLen > MaxFrameBytes) + { + throw new System.IO.IOException($"Invalid framed length: {payloadLen}"); + } + if (payloadLen > int.MaxValue) + { + throw new System.IO.IOException("Frame too large for buffer"); + } + int count = (int)payloadLen; + byte[] payload = await ReadExactAsync(stream, count, timeoutMs); + return System.Text.Encoding.UTF8.GetString(payload); + } + private static ulong ReadUInt64BigEndian(byte[] buffer) { if (buffer == null || buffer.Length < 8) return 0UL; diff --git a/UnityMcpBridge/UnityMcpServer~/src/server.py b/UnityMcpBridge/UnityMcpServer~/src/server.py index 88add06d..99f41229 100644 --- a/UnityMcpBridge/UnityMcpServer~/src/server.py +++ b/UnityMcpBridge/UnityMcpServer~/src/server.py @@ -9,6 +9,8 @@ from tools import register_all_tools from unity_connection import get_unity_connection, UnityConnection from pathlib import Path +import os +import hashlib # Configure logging: strictly stderr/file only (never stdout) stderr_handler = logging.StreamHandler(stream=sys.stderr) @@ -98,52 +100,44 @@ def asset_creation_strategy() -> str: class _: pass -import os -import hashlib - -def _unity_assets_root() -> str: - # Heuristic: from the Unity project root (one level up from Library/ProjectSettings), 'Assets' - # Here, assume server runs from repo; let clients pass absolute paths under project too. - return None +PROJECT_ROOT = Path(os.environ.get("UNITY_PROJECT_ROOT", Path.cwd())).resolve() +ASSETS_ROOT = (PROJECT_ROOT / "Assets").resolve() -def _safe_path(uri: str) -> str | None: - # URIs: unity://path/Assets/... or file:///absolute +def _resolve_safe_path_from_uri(uri: str) -> Path | None: + raw: str | None = None if uri.startswith("unity://path/"): - p = uri[len("unity://path/"):] - return p - if uri.startswith("file://"): - return uri[len("file://"):] - # Minimal tolerance for plain Assets/... paths - if uri.startswith("Assets/"): - return uri - return None + raw = uri[len("unity://path/"):] + elif uri.startswith("file://"): + raw = uri[len("file://"):] + elif uri.startswith("Assets/"): + raw = uri + if raw is None: + return None + p = (PROJECT_ROOT / raw).resolve() + try: + p.relative_to(PROJECT_ROOT) + except ValueError: + return None + return p @mcp.resource.list() def list_resources(ctx: Context) -> list[dict]: - # Lightweight: expose only C# under Assets by default assets = [] try: - root = os.getcwd() - for base, _, files in os.walk(os.path.join(root, "Assets")): - for f in files: - if f.endswith(".cs"): - rel = os.path.relpath(os.path.join(base, f), root).replace("\\", "/") - assets.append({ - "uri": f"unity://path/{rel}", - "name": os.path.basename(rel) - }) + for p in ASSETS_ROOT.rglob("*.cs"): + rel = p.relative_to(PROJECT_ROOT).as_posix() + assets.append({"uri": f"unity://path/{rel}", "name": p.name}) except Exception: pass return assets @mcp.resource.read() def read_resource(ctx: Context, uri: str) -> dict: - path = _safe_path(uri) - if not path or not os.path.exists(path): + p = _resolve_safe_path_from_uri(uri) + if not p or not p.exists(): return {"mimeType": "text/plain", "text": f"Resource not found: {uri}"} try: - with open(path, "r", encoding="utf-8") as f: - text = f.read() + text = p.read_text(encoding="utf-8") sha = hashlib.sha256(text.encode("utf-8")).hexdigest() return {"mimeType": "text/plain", "text": text, "metadata": {"sha256": sha}} except Exception as e: diff --git a/UnityMcpBridge/UnityMcpServer~/src/tools/__init__.py b/UnityMcpBridge/UnityMcpServer~/src/tools/__init__.py index 91ee9495..710b53dc 100644 --- a/UnityMcpBridge/UnityMcpServer~/src/tools/__init__.py +++ b/UnityMcpBridge/UnityMcpServer~/src/tools/__init__.py @@ -1,3 +1,4 @@ +import logging from .manage_script_edits import register_manage_script_edits_tools from .manage_script import register_manage_script_tools from .manage_scene import register_manage_scene_tools @@ -8,9 +9,11 @@ from .read_console import register_read_console_tools from .execute_menu_item import register_execute_menu_item_tools +logger = logging.getLogger("unity-mcp-server") + def register_all_tools(mcp): """Register all refactored tools with the MCP server.""" - # Note: Do not print to stdout; Claude treats stdout as MCP JSON. Use logging. + logger.info("Registering Unity MCP Server refactored tools...") # Prefer the surgical edits tool so LLMs discover it first register_manage_script_edits_tools(mcp) register_manage_script_tools(mcp) @@ -21,4 +24,4 @@ def register_all_tools(mcp): register_manage_shader_tools(mcp) register_read_console_tools(mcp) register_execute_menu_item_tools(mcp) - # Do not print to stdout here either. + logger.info("Unity MCP Server tool registration complete.") diff --git a/UnityMcpBridge/UnityMcpServer~/src/tools/manage_script.py b/UnityMcpBridge/UnityMcpServer~/src/tools/manage_script.py index f7836da3..b44dd743 100644 --- a/UnityMcpBridge/UnityMcpServer~/src/tools/manage_script.py +++ b/UnityMcpBridge/UnityMcpServer~/src/tools/manage_script.py @@ -60,7 +60,7 @@ def create_script( "namespace": namespace, "scriptType": script_type, } - if contents is not None: + if contents: params["encodedContents"] = base64.b64encode(contents.encode("utf-8")).decode("utf-8") params["contentsEncoded"] = True params = {k: v for k, v in params.items() if v is not None} @@ -107,7 +107,7 @@ def manage_script( - Edits should use apply_text_edits. Args: - action: Operation ('create', 'read', 'update', 'delete'). + action: Operation ('create', 'read', 'delete'). name: Script name (no .cs extension). path: Asset path (default: "Assets/"). contents: C# code for 'create'/'update'. @@ -132,8 +132,8 @@ def manage_script( } # Base64 encode the contents if they exist to avoid JSON escaping issues - if contents is not None: - if action in ['create', 'update']: + if contents: + if action == 'create': params["encodedContents"] = base64.b64encode(contents.encode('utf-8')).decode('utf-8') params["contentsEncoded"] = True else: @@ -143,22 +143,22 @@ def manage_script( response = send_command_with_retry("manage_script", params) - if isinstance(response, dict) and response.get("success"): - if response.get("data", {}).get("contentsEncoded"): - decoded_contents = base64.b64decode(response["data"]["encodedContents"]).decode('utf-8') - response["data"]["contents"] = decoded_contents - del response["data"]["encodedContents"] - del response["data"]["contentsEncoded"] - - return { - "success": True, - "message": response.get("message", "Operation successful."), - "data": response.get("data"), - } - return response if isinstance(response, dict) else { - "success": False, - "message": str(response), - } + if isinstance(response, dict): + if response.get("success"): + if response.get("data", {}).get("contentsEncoded"): + decoded_contents = base64.b64decode(response["data"]["encodedContents"]).decode('utf-8') + response["data"]["contents"] = decoded_contents + del response["data"]["encodedContents"] + del response["data"]["contentsEncoded"] + + return { + "success": True, + "message": response.get("message", "Operation successful."), + "data": response.get("data"), + } + return response + + return {"success": False, "message": str(response)} except Exception as e: return { diff --git a/UnityMcpBridge/UnityMcpServer~/src/tools/manage_script_edits.py b/UnityMcpBridge/UnityMcpServer~/src/tools/manage_script_edits.py index 9cb746df..bd7f7137 100644 --- a/UnityMcpBridge/UnityMcpServer~/src/tools/manage_script_edits.py +++ b/UnityMcpBridge/UnityMcpServer~/src/tools/manage_script_edits.py @@ -51,10 +51,11 @@ def _apply_edits_locally(original_text: str, edits: List[Dict[str, Any]]) -> str end_line = int(edit.get("endLine", start_line)) replacement = edit.get("text", "") lines = text.splitlines(keepends=True) - if start_line < 1 or end_line < start_line or end_line > len(lines): + max_end = len(lines) + 1 + if start_line < 1 or end_line < start_line or end_line > max_end: raise RuntimeError("replace_range out of bounds") a = start_line - 1 - b = end_line + b = min(end_line, len(lines)) rep = replacement if rep and not rep.endswith("\n"): rep += "\n" @@ -88,7 +89,8 @@ def script_apply_edits( script_type: str = "MonoBehaviour", namespace: str = "", ) -> Dict[str, Any]: - # If the edits request structured class/method ops, route directly to Unity's 'edit' action + # If the edits request structured class/method ops, route directly to Unity's 'edit' action. + # These bypass local text validation/encoding since Unity performs the semantic changes. for e in edits or []: op = (e.get("op") or e.get("operation") or e.get("type") or e.get("mode") or "").strip().lower() if op in ("replace_class", "delete_class", "replace_method", "delete_method", "insert_method"): diff --git a/UnityMcpBridge/UnityMcpServer~/src/unity_connection.py b/UnityMcpBridge/UnityMcpServer~/src/unity_connection.py index f04fb430..ab47a503 100644 --- a/UnityMcpBridge/UnityMcpServer~/src/unity_connection.py +++ b/UnityMcpBridge/UnityMcpServer~/src/unity_connection.py @@ -49,12 +49,23 @@ def connect(self) -> bool: self.use_framing = True logger.debug('Unity MCP handshake received: FRAMING=1 (strict)') else: + try: + msg = b'Unity MCP requires FRAMING=1' + header = struct.pack('>Q', len(msg)) + self.sock.sendall(header + msg) + except Exception: + pass raise ConnectionError(f'Unity MCP requires FRAMING=1, got: {text!r}') finally: self.sock.settimeout(config.connection_timeout) return True except Exception as e: logger.error(f"Failed to connect to Unity: {str(e)}") + try: + if self.sock: + self.sock.close() + except Exception: + pass self.sock = None return False @@ -83,7 +94,7 @@ def receive_full_response(self, sock, buffer_size=config.buffer_size) -> bytes: try: header = self._read_exact(sock, 8) payload_len = struct.unpack('>Q', header)[0] - if payload_len == 0 or payload_len > (64 * 1024 * 1024): + if payload_len > (64 * 1024 * 1024): raise Exception(f"Invalid framed length: {payload_len}") payload = self._read_exact(sock, payload_len) logger.info(f"Received framed response ({len(payload)} bytes)") From c735117910e8d6bff6c9ce357011b1a9b6e55d32 Mon Sep 17 00:00:00 2001 From: dsarno Date: Sun, 17 Aug 2025 07:35:51 -0700 Subject: [PATCH 09/25] Prevent overflow when counting edit bytes --- UnityMcpBridge/Editor/Tools/ManageScript.cs | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/UnityMcpBridge/Editor/Tools/ManageScript.cs b/UnityMcpBridge/Editor/Tools/ManageScript.cs index 29339604..1324d9b2 100644 --- a/UnityMcpBridge/Editor/Tools/ManageScript.cs +++ b/UnityMcpBridge/Editor/Tools/ManageScript.cs @@ -465,7 +465,7 @@ private static object ApplyTextEdits( // Convert edits to absolute index ranges var spans = new List<(int start, int end, string text)>(); - int totalBytes = 0; + long totalBytes = 0; foreach (var e in edits) { try @@ -483,7 +483,10 @@ private static object ApplyTextEdits( if (eidx < sidx) (sidx, eidx) = (eidx, sidx); spans.Add((sidx, eidx, newText)); - totalBytes += System.Text.Encoding.UTF8.GetByteCount(newText); + checked + { + totalBytes += System.Text.Encoding.UTF8.GetByteCount(newText); + } } catch (Exception ex) { From 79ba5ecdb7f245aff30707e88bf42047d4775dc7 Mon Sep 17 00:00:00 2001 From: dsarno Date: Sun, 17 Aug 2025 08:03:30 -0700 Subject: [PATCH 10/25] Fix refresh debounce scheduling --- UnityMcpBridge/Editor/Tools/ManageScript.cs | 32 +++++++++++++++------ 1 file changed, 23 insertions(+), 9 deletions(-) diff --git a/UnityMcpBridge/Editor/Tools/ManageScript.cs b/UnityMcpBridge/Editor/Tools/ManageScript.cs index 1324d9b2..87025b07 100644 --- a/UnityMcpBridge/Editor/Tools/ManageScript.cs +++ b/UnityMcpBridge/Editor/Tools/ManageScript.cs @@ -202,7 +202,7 @@ public static object HandleCommand(JObject @params) } case "validate": { - string level = @params["level"]?.ToString()?.ToLowerInvariant() ?? "basic"; + string level = @params["level"]?.ToString()?.ToLowerInvariant() ?? "standard"; var chosen = level switch { "basic" => ValidationLevel.Basic, @@ -343,8 +343,10 @@ private static object ReadScript(string fullPath, string relativePath) // Return both normal and encoded contents for larger files bool isLarge = contents.Length > 10000; // If content is large, include encoded version + var uri = $"unity://path/{relativePath}"; var responseData = new { + uri, path = relativePath, contents = contents, // For large files, also include base64-encoded version @@ -420,9 +422,10 @@ string contents } // Prepare success response BEFORE any operation that can trigger a domain reload + var uri = $"unity://path/{relativePath}"; var ok = Response.Success( $"Script '{name}.cs' updated successfully at '{relativePath}'.", - new { path = relativePath, scheduledRefresh = true } + new { uri, path = relativePath, scheduledRefresh = true } ); // Schedule a debounced import/compile on next editor tick to avoid stalling the reply @@ -1523,11 +1526,14 @@ private static bool ValidateScriptSyntax(string contents, ValidationLevel level, } #if USE_ROSLYN - // Advanced Roslyn-based validation - if (!ValidateScriptSyntaxRoslyn(contents, level, errorList)) + // Advanced Roslyn-based validation: only run for Standard+; fail on Roslyn errors + if (level >= ValidationLevel.Standard) { - errors = errorList.ToArray(); - return level != ValidationLevel.Standard; //TODO: Allow standard to run roslyn right now, might formalize it in the future + if (!ValidateScriptSyntaxRoslyn(contents, level, errorList)) + { + errors = errorList.ToArray(); + return false; + } } #endif @@ -2105,20 +2111,28 @@ static class RefreshDebounce { private static int _pending; private static DateTime _last; + private static readonly object _lock = new object(); + private static readonly HashSet _paths = new HashSet(StringComparer.OrdinalIgnoreCase); + private static bool _scheduled; public static void Schedule(string relPath, TimeSpan window) { Interlocked.Exchange(ref _pending, 1); + lock (_lock) { _paths.Add(relPath); } var now = DateTime.UtcNow; - if ((now - _last) < window) return; + if (_scheduled && (now - _last) < window) return; _last = now; + _scheduled = true; EditorApplication.delayCall += () => { + _scheduled = false; if (Interlocked.Exchange(ref _pending, 0) == 1) { - // Prefer targeted import and script compile over full refresh - AssetDatabase.ImportAsset(relPath, ImportAssetOptions.ForceUpdate); + string[] toImport; + lock (_lock) { toImport = _paths.ToArray(); _paths.Clear(); } + foreach (var p in toImport) + AssetDatabase.ImportAsset(p, ImportAssetOptions.ForceUpdate); #if UNITY_EDITOR UnityEditor.Compilation.CompilationPipeline.RequestScriptCompilation(); #endif From 01fc4f16da47d8760f7f8cd4c7e7d828f03cf2ad Mon Sep 17 00:00:00 2001 From: dsarno Date: Sun, 17 Aug 2025 08:22:23 -0700 Subject: [PATCH 11/25] Clean up temp backups and guard symlinked edits --- UnityMcpBridge/Editor/Tools/ManageScript.cs | 118 +++++++++++++++++--- 1 file changed, 100 insertions(+), 18 deletions(-) diff --git a/UnityMcpBridge/Editor/Tools/ManageScript.cs b/UnityMcpBridge/Editor/Tools/ManageScript.cs index 1324d9b2..91801142 100644 --- a/UnityMcpBridge/Editor/Tools/ManageScript.cs +++ b/UnityMcpBridge/Editor/Tools/ManageScript.cs @@ -202,7 +202,7 @@ public static object HandleCommand(JObject @params) } case "validate": { - string level = @params["level"]?.ToString()?.ToLowerInvariant() ?? "basic"; + string level = @params["level"]?.ToString()?.ToLowerInvariant() ?? "standard"; var chosen = level switch { "basic" => ValidationLevel.Basic, @@ -343,8 +343,10 @@ private static object ReadScript(string fullPath, string relativePath) // Return both normal and encoded contents for larger files bool isLarge = contents.Length > 10000; // If content is large, include encoded version + var uri = $"unity://path/{relativePath}"; var responseData = new { + uri, path = relativePath, contents = contents, // For large files, also include base64-encoded version @@ -406,23 +408,36 @@ string contents try { File.Replace(tempPath, fullPath, backupPath); + // Clean up backup to avoid stray assets inside the project + try + { + if (File.Exists(backupPath)) + File.Delete(backupPath); + } + catch + { + // ignore failures deleting the backup + } } catch (PlatformNotSupportedException) { File.Copy(tempPath, fullPath, true); try { File.Delete(tempPath); } catch { } + try { if (File.Exists(backupPath)) File.Delete(backupPath); } catch { } } catch (IOException) { // Cross-volume moves can throw IOException; fallback to copy File.Copy(tempPath, fullPath, true); try { File.Delete(tempPath); } catch { } + try { if (File.Exists(backupPath)) File.Delete(backupPath); } catch { } } // Prepare success response BEFORE any operation that can trigger a domain reload + var uri = $"unity://path/{relativePath}"; var ok = Response.Success( $"Script '{name}.cs' updated successfully at '{relativePath}'.", - new { path = relativePath, scheduledRefresh = true } + new { uri, path = relativePath, scheduledRefresh = true } ); // Schedule a debounced import/compile on next editor tick to avoid stalling the reply @@ -450,6 +465,17 @@ private static object ApplyTextEdits( { if (!File.Exists(fullPath)) return Response.Error($"Script not found at '{relativePath}'."); + // Refuse edits if the target is a symlink + try + { + var attrs = File.GetAttributes(fullPath); + if ((attrs & FileAttributes.ReparsePoint) != 0) + return Response.Error("Refusing to edit a symlinked script path."); + } + catch + { + // If checking attributes fails, proceed without the symlink guard + } if (edits == null || edits.Count == 0) return Response.Error("No edits provided."); @@ -555,9 +581,23 @@ private static object ApplyTextEdits( var tmp = fullPath + ".tmp"; File.WriteAllText(tmp, working, enc); string backup = fullPath + ".bak"; - try { File.Replace(tmp, fullPath, backup); } - catch (PlatformNotSupportedException) { File.Copy(tmp, fullPath, true); try { File.Delete(tmp); } catch { } } - catch (IOException) { File.Copy(tmp, fullPath, true); try { File.Delete(tmp); } catch { } } + try + { + File.Replace(tmp, fullPath, backup); + try { if (File.Exists(backup)) File.Delete(backup); } catch { /* ignore */ } + } + catch (PlatformNotSupportedException) + { + File.Copy(tmp, fullPath, true); + try { File.Delete(tmp); } catch { } + try { if (File.Exists(backup)) File.Delete(backup); } catch { } + } + catch (IOException) + { + File.Copy(tmp, fullPath, true); + try { File.Delete(tmp); } catch { } + try { if (File.Exists(backup)) File.Delete(backup); } catch { } + } ManageScriptRefreshHelpers.ScheduleScriptRefresh(relativePath); return Response.Success( @@ -738,6 +778,17 @@ private static object EditScript( { if (!File.Exists(fullPath)) return Response.Error($"Script not found at '{relativePath}'."); + // Refuse edits if the target is a symlink + try + { + var attrs = File.GetAttributes(fullPath); + if ((attrs & FileAttributes.ReparsePoint) != 0) + return Response.Error("Refusing to edit a symlinked script path."); + } + catch + { + // ignore failures checking attributes and proceed + } if (edits == null || edits.Count == 0) return Response.Error("No edits provided."); @@ -986,9 +1037,23 @@ private static object EditScript( var tmp = fullPath + ".tmp"; File.WriteAllText(tmp, working, enc); string backup = fullPath + ".bak"; - try { File.Replace(tmp, fullPath, backup); } - catch (PlatformNotSupportedException) { File.Copy(tmp, fullPath, true); try { File.Delete(tmp); } catch { } } - catch (IOException) { File.Copy(tmp, fullPath, true); try { File.Delete(tmp); } catch { } } + try + { + File.Replace(tmp, fullPath, backup); + try { if (File.Exists(backup)) File.Delete(backup); } catch { /* ignore */ } + } + catch (PlatformNotSupportedException) + { + File.Copy(tmp, fullPath, true); + try { File.Delete(tmp); } catch { } + try { if (File.Exists(backup)) File.Delete(backup); } catch { } + } + catch (IOException) + { + File.Copy(tmp, fullPath, true); + try { File.Delete(tmp); } catch { } + try { if (File.Exists(backup)) File.Delete(backup); } catch { } + } // Decide refresh behavior string refreshMode = options?["refresh"]?.ToString()?.ToLowerInvariant(); @@ -1001,11 +1066,17 @@ private static object EditScript( if (immediate) { - // Force an immediate import/compile on the main thread - AssetDatabase.ImportAsset(relativePath, ImportAssetOptions.ForceSynchronousImport | ImportAssetOptions.ForceUpdate); + // Force on main thread + EditorApplication.delayCall += () => + { + AssetDatabase.ImportAsset( + relativePath, + ImportAssetOptions.ForceSynchronousImport | ImportAssetOptions.ForceUpdate + ); #if UNITY_EDITOR - UnityEditor.Compilation.CompilationPipeline.RequestScriptCompilation(); + UnityEditor.Compilation.CompilationPipeline.RequestScriptCompilation(); #endif + }; } else { @@ -1523,11 +1594,14 @@ private static bool ValidateScriptSyntax(string contents, ValidationLevel level, } #if USE_ROSLYN - // Advanced Roslyn-based validation - if (!ValidateScriptSyntaxRoslyn(contents, level, errorList)) + // Advanced Roslyn-based validation: only run for Standard+; fail on Roslyn errors + if (level >= ValidationLevel.Standard) { - errors = errorList.ToArray(); - return level != ValidationLevel.Standard; //TODO: Allow standard to run roslyn right now, might formalize it in the future + if (!ValidateScriptSyntaxRoslyn(contents, level, errorList)) + { + errors = errorList.ToArray(); + return false; + } } #endif @@ -2105,20 +2179,28 @@ static class RefreshDebounce { private static int _pending; private static DateTime _last; + private static readonly object _lock = new object(); + private static readonly HashSet _paths = new HashSet(StringComparer.OrdinalIgnoreCase); + private static bool _scheduled; public static void Schedule(string relPath, TimeSpan window) { Interlocked.Exchange(ref _pending, 1); + lock (_lock) { _paths.Add(relPath); } var now = DateTime.UtcNow; - if ((now - _last) < window) return; + if (_scheduled && (now - _last) < window) return; _last = now; + _scheduled = true; EditorApplication.delayCall += () => { + _scheduled = false; if (Interlocked.Exchange(ref _pending, 0) == 1) { - // Prefer targeted import and script compile over full refresh - AssetDatabase.ImportAsset(relPath, ImportAssetOptions.ForceUpdate); + string[] toImport; + lock (_lock) { toImport = _paths.ToArray(); _paths.Clear(); } + foreach (var p in toImport) + AssetDatabase.ImportAsset(p, ImportAssetOptions.ForceUpdate); #if UNITY_EDITOR UnityEditor.Compilation.CompilationPipeline.RequestScriptCompilation(); #endif From fd791272075bd1c1258e961c29539b8b1d6af4bd Mon Sep 17 00:00:00 2001 From: dsarno Date: Sun, 17 Aug 2025 08:32:55 -0700 Subject: [PATCH 12/25] Fix debouncing race condition --- UnityMcpBridge/Editor/Tools/ManageScript.cs | 68 +++++++++++++++------ 1 file changed, 50 insertions(+), 18 deletions(-) diff --git a/UnityMcpBridge/Editor/Tools/ManageScript.cs b/UnityMcpBridge/Editor/Tools/ManageScript.cs index 91801142..90367c1a 100644 --- a/UnityMcpBridge/Editor/Tools/ManageScript.cs +++ b/UnityMcpBridge/Editor/Tools/ManageScript.cs @@ -2178,36 +2178,68 @@ private static void ValidateSemanticRules(string contents, System.Collections.Ge static class RefreshDebounce { private static int _pending; - private static DateTime _last; private static readonly object _lock = new object(); private static readonly HashSet _paths = new HashSet(StringComparer.OrdinalIgnoreCase); + + // The timestamp of the most recent schedule request. + private static DateTime _lastRequest; + + // Guard to ensure we only have a single ticking callback running. private static bool _scheduled; public static void Schedule(string relPath, TimeSpan window) { + // Record that work is pending and track the path in a threadsafe way. Interlocked.Exchange(ref _pending, 1); - lock (_lock) { _paths.Add(relPath); } - var now = DateTime.UtcNow; - if (_scheduled && (now - _last) < window) return; - _last = now; - _scheduled = true; + lock (_lock) + { + _paths.Add(relPath); + _lastRequest = DateTime.UtcNow; - EditorApplication.delayCall += () => + // If a debounce timer is already scheduled it will pick up the new request. + if (_scheduled) + return; + + _scheduled = true; + } + + // Kick off a ticking callback that waits until the window has elapsed + // from the last request before performing the refresh. + EditorApplication.delayCall += () => Tick(window); + } + + private static void Tick(TimeSpan window) + { + bool ready; + lock (_lock) { - _scheduled = false; - if (Interlocked.Exchange(ref _pending, 0) == 1) + // Only proceed once the debounce window has fully elapsed. + ready = (DateTime.UtcNow - _lastRequest) >= window; + if (ready) { - string[] toImport; - lock (_lock) { toImport = _paths.ToArray(); _paths.Clear(); } - foreach (var p in toImport) - AssetDatabase.ImportAsset(p, ImportAssetOptions.ForceUpdate); + _scheduled = false; + } + } + + if (!ready) + { + // Window has not yet elapsed; check again on the next editor tick. + EditorApplication.delayCall += () => Tick(window); + return; + } + + if (Interlocked.Exchange(ref _pending, 0) == 1) + { + string[] toImport; + lock (_lock) { toImport = _paths.ToArray(); _paths.Clear(); } + foreach (var p in toImport) + AssetDatabase.ImportAsset(p, ImportAssetOptions.ForceUpdate); #if UNITY_EDITOR - UnityEditor.Compilation.CompilationPipeline.RequestScriptCompilation(); + UnityEditor.Compilation.CompilationPipeline.RequestScriptCompilation(); #endif - // Fallback if needed: - // AssetDatabase.Refresh(); - } - }; + // Fallback if needed: + // AssetDatabase.Refresh(); + } } } From 200483e826d00e25337138abc1bdd58441b3bb7d Mon Sep 17 00:00:00 2001 From: dsarno Date: Sun, 17 Aug 2025 08:59:42 -0700 Subject: [PATCH 13/25] Add initial transport handshake tests with plan placeholders --- tests/test_logging_stdout.py | 11 ++++ tests/test_resources_api.py | 11 ++++ tests/test_script_editing.py | 36 +++++++++++ tests/test_transport_framing.py | 102 ++++++++++++++++++++++++++++++++ 4 files changed, 160 insertions(+) create mode 100644 tests/test_logging_stdout.py create mode 100644 tests/test_resources_api.py create mode 100644 tests/test_script_editing.py create mode 100644 tests/test_transport_framing.py diff --git a/tests/test_logging_stdout.py b/tests/test_logging_stdout.py new file mode 100644 index 00000000..98dc23f4 --- /dev/null +++ b/tests/test_logging_stdout.py @@ -0,0 +1,11 @@ +import pytest + + +@pytest.mark.skip(reason="TODO: ensure server logs only to stderr and rotating file") +def test_no_stdout_output_from_tools(): + pass + + +@pytest.mark.skip(reason="TODO: sweep for accidental print statements in codebase") +def test_no_print_statements_in_codebase(): + pass diff --git a/tests/test_resources_api.py b/tests/test_resources_api.py new file mode 100644 index 00000000..bdcd7290 --- /dev/null +++ b/tests/test_resources_api.py @@ -0,0 +1,11 @@ +import pytest + + +@pytest.mark.skip(reason="TODO: resource.list returns only Assets/**/*.cs and rejects traversal") +def test_resource_list_filters_and_rejects_traversal(): + pass + + +@pytest.mark.skip(reason="TODO: resource.list rejects file:// paths outside project, including drive letters and symlinks") +def test_resource_list_rejects_outside_paths(): + pass diff --git a/tests/test_script_editing.py b/tests/test_script_editing.py new file mode 100644 index 00000000..e0b3705b --- /dev/null +++ b/tests/test_script_editing.py @@ -0,0 +1,36 @@ +import pytest + + +@pytest.mark.skip(reason="TODO: create new script, validate, apply edits, build and compile scene") +def test_script_edit_happy_path(): + pass + + +@pytest.mark.skip(reason="TODO: multiple micro-edits debounce to single compilation") +def test_micro_edits_debounce(): + pass + + +@pytest.mark.skip(reason="TODO: line ending variations handled correctly") +def test_line_endings_and_columns(): + pass + + +@pytest.mark.skip(reason="TODO: regex_replace no-op with allow_noop honored") +def test_regex_replace_noop_allowed(): + pass + + +@pytest.mark.skip(reason="TODO: large edit size boundaries and overflow protection") +def test_large_edit_size_and_overflow(): + pass + + +@pytest.mark.skip(reason="TODO: symlink and junction protections on edits") +def test_symlink_and_junction_protection(): + pass + + +@pytest.mark.skip(reason="TODO: atomic write guarantees") +def test_atomic_write_guarantees(): + pass diff --git a/tests/test_transport_framing.py b/tests/test_transport_framing.py new file mode 100644 index 00000000..1c3d02fa --- /dev/null +++ b/tests/test_transport_framing.py @@ -0,0 +1,102 @@ +import sys +import json +import struct +import socket +import threading +import time +from pathlib import Path + +import pytest + +# add server src to path +ROOT = Path(__file__).resolve().parents[1] +SRC = ROOT / "UnityMcpBridge" / "UnityMcpServer~" / "src" +sys.path.insert(0, str(SRC)) + +from unity_connection import UnityConnection + + +def start_dummy_server(greeting: bytes, respond_ping: bool = False): + """Start a minimal TCP server for handshake tests.""" + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.bind(("127.0.0.1", 0)) + sock.listen(1) + port = sock.getsockname()[1] + + def _run(): + conn, _ = sock.accept() + if greeting: + conn.sendall(greeting) + if respond_ping: + try: + header = conn.recv(8) + if len(header) == 8: + length = struct.unpack(">Q", header)[0] + payload = b"" + while len(payload) < length: + chunk = conn.recv(length - len(payload)) + if not chunk: + break + payload += chunk + if payload == b'{"type":"ping"}': + resp = b'{"type":"pong"}' + conn.sendall(struct.pack(">Q", len(resp)) + resp) + except Exception: + pass + time.sleep(0.1) + try: + conn.close() + finally: + sock.close() + + threading.Thread(target=_run, daemon=True).start() + return port + + +def test_handshake_requires_framing(): + port = start_dummy_server(b"MCP/0.1\n") + conn = UnityConnection(host="127.0.0.1", port=port) + assert conn.connect() is False + assert conn.sock is None + + +def test_small_frame_ping_pong(): + port = start_dummy_server(b"MCP/0.1 FRAMING=1\n", respond_ping=True) + conn = UnityConnection(host="127.0.0.1", port=port) + assert conn.connect() is True + assert conn.use_framing is True + payload = b'{"type":"ping"}' + conn.sock.sendall(struct.pack(">Q", len(payload)) + payload) + resp = conn.receive_full_response(conn.sock) + assert json.loads(resp.decode("utf-8"))["type"] == "pong" + conn.disconnect() + + +@pytest.mark.skip(reason="TODO: unframed data before reading greeting should disconnect") +def test_unframed_data_disconnect(): + pass + + +@pytest.mark.skip(reason="TODO: zero-length payload should raise error") +def test_zero_length_payload_error(): + pass + + +@pytest.mark.skip(reason="TODO: oversized payload should disconnect") +def test_oversized_payload_rejected(): + pass + + +@pytest.mark.skip(reason="TODO: partial header/payload triggers timeout and disconnect") +def test_partial_frame_timeout(): + pass + + +@pytest.mark.skip(reason="TODO: concurrency test with parallel tool invocations") +def test_parallel_invocations_no_interleaving(): + pass + + +@pytest.mark.skip(reason="TODO: reconnection after drop mid-command") +def test_reconnect_mid_command(): + pass From a3c81d657d5a2333a9e985643b0721cfe22dc6c5 Mon Sep 17 00:00:00 2001 From: dsarno Date: Sun, 17 Aug 2025 16:29:04 -0700 Subject: [PATCH 14/25] Fix dummy server startup and cleanup in transport tests --- tests/test_transport_framing.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/test_transport_framing.py b/tests/test_transport_framing.py index 1c3d02fa..50483f48 100644 --- a/tests/test_transport_framing.py +++ b/tests/test_transport_framing.py @@ -11,6 +11,8 @@ # add server src to path ROOT = Path(__file__).resolve().parents[1] SRC = ROOT / "UnityMcpBridge" / "UnityMcpServer~" / "src" +if not SRC.exists(): + raise FileNotFoundError(f"Server source directory not found: {SRC}") sys.path.insert(0, str(SRC)) from unity_connection import UnityConnection @@ -22,8 +24,10 @@ def start_dummy_server(greeting: bytes, respond_ping: bool = False): sock.bind(("127.0.0.1", 0)) sock.listen(1) port = sock.getsockname()[1] + ready = threading.Event() def _run(): + ready.set() conn, _ = sock.accept() if greeting: conn.sendall(greeting) @@ -46,10 +50,13 @@ def _run(): time.sleep(0.1) try: conn.close() + except Exception: + pass finally: sock.close() threading.Thread(target=_run, daemon=True).start() + ready.wait() return port From b01978c59e3e3a2589378a738fb6b764b05196fd Mon Sep 17 00:00:00 2001 From: dsarno Date: Sun, 17 Aug 2025 16:38:28 -0700 Subject: [PATCH 15/25] test: enforce no prints and handshake preamble --- README.md | 12 ++++++++++ tests/test_logging_stdout.py | 14 +++++++++-- tests/test_transport_framing.py | 42 +++++++++++++++++++++++++++++++-- 3 files changed, 64 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index d3c5c111..bb4dd965 100644 --- a/README.md +++ b/README.md @@ -58,6 +58,18 @@ Unity MCP connects your tools using two components: --- +### Transport framing + +Unity MCP requires explicit framing negotiation. After connecting, the server +sends a `MCP/0.1` greeting. Clients must respond with `FRAMING=1`, and all +subsequent messages are sent as 8-byte big-endian length-prefixed JSON frames. + +### Resource URIs + +Assets are addressed using `unity://` URIs relative to the project root. For +example, `unity://path/Assets/Scripts/Foo.cs` refers to the file +`Assets/Scripts/Foo.cs` inside the Unity project. + ## Installation ⚙️ > **Note:** The setup is constantly improving as we update the package. Check back if you randomly start to run into issues. diff --git a/tests/test_logging_stdout.py b/tests/test_logging_stdout.py index 98dc23f4..d4389818 100644 --- a/tests/test_logging_stdout.py +++ b/tests/test_logging_stdout.py @@ -1,3 +1,6 @@ +import re +from pathlib import Path + import pytest @@ -6,6 +9,13 @@ def test_no_stdout_output_from_tools(): pass -@pytest.mark.skip(reason="TODO: sweep for accidental print statements in codebase") def test_no_print_statements_in_codebase(): - pass + """Ensure no stray print statements remain in server source.""" + src = Path(__file__).resolve().parents[1] / "UnityMcpBridge" / "UnityMcpServer~" / "src" + assert src.exists(), f"Server source directory not found: {src}" + offenders = [] + for py_file in src.rglob("*.py"): + text = py_file.read_text(encoding="utf-8") + if re.search(r"^\s*print\(", text, re.MULTILINE): + offenders.append(py_file.relative_to(src)) + assert not offenders, f"print statements found in: {offenders}" diff --git a/tests/test_transport_framing.py b/tests/test_transport_framing.py index 50483f48..602cb312 100644 --- a/tests/test_transport_framing.py +++ b/tests/test_transport_framing.py @@ -4,6 +4,7 @@ import socket import threading import time +import select from pathlib import Path import pytest @@ -60,6 +61,33 @@ def _run(): return port +def start_handshake_enforcing_server(): + """Server that drops connection if client sends data before handshake.""" + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.bind(("127.0.0.1", 0)) + sock.listen(1) + port = sock.getsockname()[1] + ready = threading.Event() + + def _run(): + ready.set() + conn, _ = sock.accept() + # if client sends any data before greeting, disconnect + r, _, _ = select.select([conn], [], [], 0.1) + if r: + conn.close() + sock.close() + return + conn.sendall(b"MCP/0.1 FRAMING=1\n") + time.sleep(0.1) + conn.close() + sock.close() + + threading.Thread(target=_run, daemon=True).start() + ready.wait() + return port + + def test_handshake_requires_framing(): port = start_dummy_server(b"MCP/0.1\n") conn = UnityConnection(host="127.0.0.1", port=port) @@ -79,9 +107,19 @@ def test_small_frame_ping_pong(): conn.disconnect() -@pytest.mark.skip(reason="TODO: unframed data before reading greeting should disconnect") def test_unframed_data_disconnect(): - pass + port = start_handshake_enforcing_server() + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.connect(("127.0.0.1", port)) + sock.sendall(b"BAD") + time.sleep(0.1) + try: + data = sock.recv(1024) + assert data == b"" + except ConnectionError: + pass + finally: + sock.close() @pytest.mark.skip(reason="TODO: zero-length payload should raise error") From 555d96510bd077e3a1dc8789a2426943e2e55ff4 Mon Sep 17 00:00:00 2001 From: dsarno Date: Sun, 17 Aug 2025 16:39:02 -0700 Subject: [PATCH 16/25] feat: add defensive server path resolution in tests --- tests/test_transport_framing.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/tests/test_transport_framing.py b/tests/test_transport_framing.py index 602cb312..46601c3a 100644 --- a/tests/test_transport_framing.py +++ b/tests/test_transport_framing.py @@ -9,11 +9,18 @@ import pytest -# add server src to path +# locate server src dynamically to avoid hardcoded layout assumptions ROOT = Path(__file__).resolve().parents[1] -SRC = ROOT / "UnityMcpBridge" / "UnityMcpServer~" / "src" -if not SRC.exists(): - raise FileNotFoundError(f"Server source directory not found: {SRC}") +candidates = [ + ROOT / "UnityMcpBridge" / "UnityMcpServer~" / "src", + ROOT / "UnityMcpServer~" / "src", +] +SRC = next((p for p in candidates if p.exists()), None) +if SRC is None: + searched = "\n".join(str(p) for p in candidates) + raise FileNotFoundError( + "Unity MCP server source not found. Tried:\n" + searched + ) sys.path.insert(0, str(SRC)) from unity_connection import UnityConnection From e4544f68c3e0d271e9d9f48e5575b092e920fc0a Mon Sep 17 00:00:00 2001 From: dsarno Date: Sun, 17 Aug 2025 16:49:52 -0700 Subject: [PATCH 17/25] Refine server source path lookup --- README.md | 12 ------------ tests/test_logging_stdout.py | 20 ++++++++++++++++---- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index bb4dd965..d3c5c111 100644 --- a/README.md +++ b/README.md @@ -58,18 +58,6 @@ Unity MCP connects your tools using two components: --- -### Transport framing - -Unity MCP requires explicit framing negotiation. After connecting, the server -sends a `MCP/0.1` greeting. Clients must respond with `FRAMING=1`, and all -subsequent messages are sent as 8-byte big-endian length-prefixed JSON frames. - -### Resource URIs - -Assets are addressed using `unity://` URIs relative to the project root. For -example, `unity://path/Assets/Scripts/Foo.cs` refers to the file -`Assets/Scripts/Foo.cs` inside the Unity project. - ## Installation ⚙️ > **Note:** The setup is constantly improving as we update the package. Check back if you randomly start to run into issues. diff --git a/tests/test_logging_stdout.py b/tests/test_logging_stdout.py index d4389818..38e55d20 100644 --- a/tests/test_logging_stdout.py +++ b/tests/test_logging_stdout.py @@ -4,6 +4,20 @@ import pytest +# locate server src dynamically to avoid hardcoded layout assumptions +ROOT = Path(__file__).resolve().parents[1] +candidates = [ + ROOT / "UnityMcpBridge" / "UnityMcpServer~" / "src", + ROOT / "UnityMcpServer~" / "src", +] +SRC = next((p for p in candidates if p.exists()), None) +if SRC is None: + searched = "\n".join(str(p) for p in candidates) + raise FileNotFoundError( + "Unity MCP server source not found. Tried:\n" + searched + ) + + @pytest.mark.skip(reason="TODO: ensure server logs only to stderr and rotating file") def test_no_stdout_output_from_tools(): pass @@ -11,11 +25,9 @@ def test_no_stdout_output_from_tools(): def test_no_print_statements_in_codebase(): """Ensure no stray print statements remain in server source.""" - src = Path(__file__).resolve().parents[1] / "UnityMcpBridge" / "UnityMcpServer~" / "src" - assert src.exists(), f"Server source directory not found: {src}" offenders = [] - for py_file in src.rglob("*.py"): + for py_file in SRC.rglob("*.py"): text = py_file.read_text(encoding="utf-8") if re.search(r"^\s*print\(", text, re.MULTILINE): - offenders.append(py_file.relative_to(src)) + offenders.append(py_file.relative_to(SRC)) assert not offenders, f"print statements found in: {offenders}" From 9dbb4ffbcb5cb23a61f39d499b461d8fb82ad353 Mon Sep 17 00:00:00 2001 From: dsarno Date: Sun, 17 Aug 2025 17:02:36 -0700 Subject: [PATCH 18/25] Refine handshake tests and stdout hygiene --- tests/test_logging_stdout.py | 8 ++++++-- tests/test_transport_framing.py | 23 ++++++++++++++--------- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/tests/test_logging_stdout.py b/tests/test_logging_stdout.py index 38e55d20..d6e728b7 100644 --- a/tests/test_logging_stdout.py +++ b/tests/test_logging_stdout.py @@ -28,6 +28,10 @@ def test_no_print_statements_in_codebase(): offenders = [] for py_file in SRC.rglob("*.py"): text = py_file.read_text(encoding="utf-8") - if re.search(r"^\s*print\(", text, re.MULTILINE): + if re.search(r"^\s*print\(", text, re.MULTILINE) or re.search( + r"sys\.stdout\.write\(", text + ): offenders.append(py_file.relative_to(SRC)) - assert not offenders, f"print statements found in: {offenders}" + assert not offenders, ( + "stdout writes found in: " + ", ".join(str(o) for o in offenders) + ) diff --git a/tests/test_transport_framing.py b/tests/test_transport_framing.py index 46601c3a..39e84afd 100644 --- a/tests/test_transport_framing.py +++ b/tests/test_transport_framing.py @@ -80,7 +80,8 @@ def _run(): ready.set() conn, _ = sock.accept() # if client sends any data before greeting, disconnect - r, _, _ = select.select([conn], [], [], 0.1) + # give clients a bit more time to send pre-handshake data before we greet + r, _, _ = select.select([conn], [], [], 0.2) if r: conn.close() sock.close() @@ -105,13 +106,15 @@ def test_handshake_requires_framing(): def test_small_frame_ping_pong(): port = start_dummy_server(b"MCP/0.1 FRAMING=1\n", respond_ping=True) conn = UnityConnection(host="127.0.0.1", port=port) - assert conn.connect() is True - assert conn.use_framing is True - payload = b'{"type":"ping"}' - conn.sock.sendall(struct.pack(">Q", len(payload)) + payload) - resp = conn.receive_full_response(conn.sock) - assert json.loads(resp.decode("utf-8"))["type"] == "pong" - conn.disconnect() + try: + assert conn.connect() is True + assert conn.use_framing is True + payload = b'{"type":"ping"}' + conn.sock.sendall(struct.pack(">Q", len(payload)) + payload) + resp = conn.receive_full_response(conn.sock) + assert json.loads(resp.decode("utf-8"))["type"] == "pong" + finally: + conn.disconnect() def test_unframed_data_disconnect(): @@ -123,7 +126,9 @@ def test_unframed_data_disconnect(): try: data = sock.recv(1024) assert data == b"" - except ConnectionError: + except (ConnectionResetError, ConnectionAbortedError): + # Some platforms raise instead of returning empty bytes when the + # server closes the connection after detecting pre-handshake data. pass finally: sock.close() From 7a42fe6f4687b645761b0a9af2075657b8117c87 Mon Sep 17 00:00:00 2001 From: dsarno Date: Sun, 17 Aug 2025 17:23:03 -0700 Subject: [PATCH 19/25] Convert skipped tests to xfail and improve framing robustness --- .../UnityMcpServer~/src/unity_connection.py | 2 + tests/test_logging_stdout.py | 48 ++++++++++++++----- tests/test_resources_api.py | 4 +- tests/test_script_editing.py | 14 +++--- tests/test_transport_framing.py | 42 +++++++++------- 5 files changed, 73 insertions(+), 37 deletions(-) diff --git a/UnityMcpBridge/UnityMcpServer~/src/unity_connection.py b/UnityMcpBridge/UnityMcpServer~/src/unity_connection.py index ab47a503..7bf28c01 100644 --- a/UnityMcpBridge/UnityMcpServer~/src/unity_connection.py +++ b/UnityMcpBridge/UnityMcpServer~/src/unity_connection.py @@ -94,6 +94,8 @@ def receive_full_response(self, sock, buffer_size=config.buffer_size) -> bytes: try: header = self._read_exact(sock, 8) payload_len = struct.unpack('>Q', header)[0] + if payload_len == 0: + raise Exception("Invalid framed length: 0") if payload_len > (64 * 1024 * 1024): raise Exception(f"Invalid framed length: {payload_len}") payload = self._read_exact(sock, payload_len) diff --git a/tests/test_logging_stdout.py b/tests/test_logging_stdout.py index d6e728b7..2c9d0163 100644 --- a/tests/test_logging_stdout.py +++ b/tests/test_logging_stdout.py @@ -1,4 +1,4 @@ -import re +import ast from pathlib import Path import pytest @@ -13,8 +13,9 @@ SRC = next((p for p in candidates if p.exists()), None) if SRC is None: searched = "\n".join(str(p) for p in candidates) - raise FileNotFoundError( - "Unity MCP server source not found. Tried:\n" + searched + pytest.skip( + "Unity MCP server source not found. Tried:\n" + searched, + allow_module_level=True, ) @@ -24,14 +25,39 @@ def test_no_stdout_output_from_tools(): def test_no_print_statements_in_codebase(): - """Ensure no stray print statements remain in server source.""" + """Ensure no stray print/sys.stdout writes remain in server source.""" offenders = [] for py_file in SRC.rglob("*.py"): - text = py_file.read_text(encoding="utf-8") - if re.search(r"^\s*print\(", text, re.MULTILINE) or re.search( - r"sys\.stdout\.write\(", text - ): + try: + text = py_file.read_text(encoding="utf-8", errors="strict") + except UnicodeDecodeError: + # Be tolerant of encoding edge cases in source tree + text = py_file.read_text(encoding="utf-8", errors="ignore") + try: + tree = ast.parse(text, filename=str(py_file)) + except SyntaxError: offenders.append(py_file.relative_to(SRC)) - assert not offenders, ( - "stdout writes found in: " + ", ".join(str(o) for o in offenders) - ) + continue + + class StdoutVisitor(ast.NodeVisitor): + def __init__(self): + self.hit = False + + def visit_Call(self, node: ast.Call): + # print(...) + if isinstance(node.func, ast.Name) and node.func.id == "print": + self.hit = True + # sys.stdout.write(...) + if isinstance(node.func, ast.Attribute) and node.func.attr == "write": + val = node.func.value + if isinstance(val, ast.Attribute) and val.attr == "stdout": + if isinstance(val.value, ast.Name) and val.value.id == "sys": + self.hit = True + self.generic_visit(node) + + v = StdoutVisitor() + v.visit(tree) + if v.hit: + offenders.append(py_file.relative_to(SRC)) + + assert not offenders, "stdout writes found in: " + ", ".join(str(o) for o in offenders) diff --git a/tests/test_resources_api.py b/tests/test_resources_api.py index bdcd7290..62cc1ac1 100644 --- a/tests/test_resources_api.py +++ b/tests/test_resources_api.py @@ -1,11 +1,11 @@ import pytest -@pytest.mark.skip(reason="TODO: resource.list returns only Assets/**/*.cs and rejects traversal") +@pytest.mark.xfail(strict=False, reason="resource.list should return only Assets/**/*.cs and reject traversal") def test_resource_list_filters_and_rejects_traversal(): pass -@pytest.mark.skip(reason="TODO: resource.list rejects file:// paths outside project, including drive letters and symlinks") +@pytest.mark.xfail(strict=False, reason="resource.list should reject outside paths including drive letters and symlinks") def test_resource_list_rejects_outside_paths(): pass diff --git a/tests/test_script_editing.py b/tests/test_script_editing.py index e0b3705b..88046d00 100644 --- a/tests/test_script_editing.py +++ b/tests/test_script_editing.py @@ -1,36 +1,36 @@ import pytest -@pytest.mark.skip(reason="TODO: create new script, validate, apply edits, build and compile scene") +@pytest.mark.xfail(strict=False, reason="pending: create new script, validate, apply edits, build and compile scene") def test_script_edit_happy_path(): pass -@pytest.mark.skip(reason="TODO: multiple micro-edits debounce to single compilation") +@pytest.mark.xfail(strict=False, reason="pending: multiple micro-edits debounce to single compilation") def test_micro_edits_debounce(): pass -@pytest.mark.skip(reason="TODO: line ending variations handled correctly") +@pytest.mark.xfail(strict=False, reason="pending: line ending variations handled correctly") def test_line_endings_and_columns(): pass -@pytest.mark.skip(reason="TODO: regex_replace no-op with allow_noop honored") +@pytest.mark.xfail(strict=False, reason="pending: regex_replace no-op with allow_noop honored") def test_regex_replace_noop_allowed(): pass -@pytest.mark.skip(reason="TODO: large edit size boundaries and overflow protection") +@pytest.mark.xfail(strict=False, reason="pending: large edit size boundaries and overflow protection") def test_large_edit_size_and_overflow(): pass -@pytest.mark.skip(reason="TODO: symlink and junction protections on edits") +@pytest.mark.xfail(strict=False, reason="pending: symlink and junction protections on edits") def test_symlink_and_junction_protection(): pass -@pytest.mark.skip(reason="TODO: atomic write guarantees") +@pytest.mark.xfail(strict=False, reason="pending: atomic write guarantees") def test_atomic_write_guarantees(): pass diff --git a/tests/test_transport_framing.py b/tests/test_transport_framing.py index 39e84afd..011473b3 100644 --- a/tests/test_transport_framing.py +++ b/tests/test_transport_framing.py @@ -18,8 +18,9 @@ SRC = next((p for p in candidates if p.exists()), None) if SRC is None: searched = "\n".join(str(p) for p in candidates) - raise FileNotFoundError( - "Unity MCP server source not found. Tried:\n" + searched + pytest.skip( + "Unity MCP server source not found. Tried:\n" + searched, + allow_module_level=True, ) sys.path.insert(0, str(SRC)) @@ -37,19 +38,25 @@ def start_dummy_server(greeting: bytes, respond_ping: bool = False): def _run(): ready.set() conn, _ = sock.accept() + conn.settimeout(1.0) if greeting: conn.sendall(greeting) if respond_ping: try: - header = conn.recv(8) - if len(header) == 8: - length = struct.unpack(">Q", header)[0] - payload = b"" - while len(payload) < length: - chunk = conn.recv(length - len(payload)) + # Read exactly n bytes helper + def _read_exact(n: int) -> bytes: + buf = b"" + while len(buf) < n: + chunk = conn.recv(n - len(buf)) if not chunk: break - payload += chunk + buf += chunk + return buf + + header = _read_exact(8) + if len(header) == 8: + length = struct.unpack(">Q", header)[0] + payload = _read_exact(length) if payload == b'{"type":"ping"}': resp = b'{"type":"pong"}' conn.sendall(struct.pack(">Q", len(resp)) + resp) @@ -79,13 +86,14 @@ def start_handshake_enforcing_server(): def _run(): ready.set() conn, _ = sock.accept() - # if client sends any data before greeting, disconnect - # give clients a bit more time to send pre-handshake data before we greet - r, _, _ = select.select([conn], [], [], 0.2) - if r: - conn.close() - sock.close() - return + # If client sends any data before greeting, disconnect (poll briefly) + deadline = time.time() + 0.5 + while time.time() < deadline: + r, _, _ = select.select([conn], [], [], 0.05) + if r: + conn.close() + sock.close() + return conn.sendall(b"MCP/0.1 FRAMING=1\n") time.sleep(0.1) conn.close() @@ -122,7 +130,7 @@ def test_unframed_data_disconnect(): sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect(("127.0.0.1", port)) sock.sendall(b"BAD") - time.sleep(0.1) + time.sleep(0.4) try: data = sock.recv(1024) assert data == b"" From 1a50016503183e7953cc74c43c143ff9458af5b2 Mon Sep 17 00:00:00 2001 From: dsarno Date: Sun, 17 Aug 2025 17:32:23 -0700 Subject: [PATCH 20/25] clarify stdout test failure messaging --- tests/test_logging_stdout.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/test_logging_stdout.py b/tests/test_logging_stdout.py index 2c9d0163..4c53a10f 100644 --- a/tests/test_logging_stdout.py +++ b/tests/test_logging_stdout.py @@ -27,6 +27,7 @@ def test_no_stdout_output_from_tools(): def test_no_print_statements_in_codebase(): """Ensure no stray print/sys.stdout writes remain in server source.""" offenders = [] + syntax_errors = [] for py_file in SRC.rglob("*.py"): try: text = py_file.read_text(encoding="utf-8", errors="strict") @@ -36,7 +37,7 @@ def test_no_print_statements_in_codebase(): try: tree = ast.parse(text, filename=str(py_file)) except SyntaxError: - offenders.append(py_file.relative_to(SRC)) + syntax_errors.append(py_file.relative_to(SRC)) continue class StdoutVisitor(ast.NodeVisitor): @@ -60,4 +61,5 @@ def visit_Call(self, node: ast.Call): if v.hit: offenders.append(py_file.relative_to(SRC)) + assert not syntax_errors, "syntax errors in: " + ", ".join(str(e) for e in syntax_errors) assert not offenders, "stdout writes found in: " + ", ".join(str(o) for o in offenders) From b45b80f19d8fced51df22b87d804a509eecf6d5b Mon Sep 17 00:00:00 2001 From: dsarno Date: Sun, 17 Aug 2025 19:40:54 -0700 Subject: [PATCH 21/25] Add handshake fallback and logging checks --- .../UnityMcpServer~/src/unity_connection.py | 102 ++++++++++-------- tests/test_logging_stdout.py | 22 +++- tests/test_transport_framing.py | 1 + 3 files changed, 76 insertions(+), 49 deletions(-) diff --git a/UnityMcpBridge/UnityMcpServer~/src/unity_connection.py b/UnityMcpBridge/UnityMcpServer~/src/unity_connection.py index 7bf28c01..2726966f 100644 --- a/UnityMcpBridge/UnityMcpServer~/src/unity_connection.py +++ b/UnityMcpBridge/UnityMcpServer~/src/unity_connection.py @@ -1,13 +1,15 @@ -import socket +import contextlib +import errno import json import logging +import random +import socket import struct +import threading +import time from dataclasses import dataclass from pathlib import Path -import time -import random -import errno -from typing import Dict, Any +from typing import Any, Dict from config import config from port_discovery import PortDiscovery @@ -30,6 +32,7 @@ def __post_init__(self): """Set port from discovery if not explicitly provided""" if self.port is None: self.port = PortDiscovery.discover_unity_port() + self._io_lock = threading.Lock() def connect(self) -> bool: """Establish a connection to the Unity Editor.""" @@ -42,20 +45,24 @@ def connect(self) -> bool: # Strict handshake: require FRAMING=1 try: - self.sock.settimeout(1.0) + require_framing = getattr(config, "require_framing", True) + self.sock.settimeout(getattr(config, "handshake_timeout", 1.0)) greeting = self.sock.recv(256) text = greeting.decode('ascii', errors='ignore') if greeting else '' if 'FRAMING=1' in text: self.use_framing = True logger.debug('Unity MCP handshake received: FRAMING=1 (strict)') else: - try: - msg = b'Unity MCP requires FRAMING=1' - header = struct.pack('>Q', len(msg)) - self.sock.sendall(header + msg) - except Exception: - pass - raise ConnectionError(f'Unity MCP requires FRAMING=1, got: {text!r}') + if require_framing: + # Best-effort advisory; peer may ignore if not framed-capable + with contextlib.suppress(Exception): + msg = b'Unity MCP requires FRAMING=1' + header = struct.pack('>Q', len(msg)) + self.sock.sendall(header + msg) + raise ConnectionError(f'Unity MCP requires FRAMING=1, got: {text!r}') + else: + self.use_framing = False + logger.warning('Unity MCP handshake missing FRAMING=1; proceeding in legacy mode by configuration') finally: self.sock.settimeout(config.connection_timeout) return True @@ -101,9 +108,9 @@ def receive_full_response(self, sock, buffer_size=config.buffer_size) -> bytes: payload = self._read_exact(sock, payload_len) logger.info(f"Received framed response ({len(payload)} bytes)") return payload - except socket.timeout: + except socket.timeout as e: logger.warning("Socket timeout during framed receive") - raise Exception("Timeout receiving Unity response") + raise TimeoutError("Timeout receiving Unity response") from e except Exception as e: logger.error(f"Error during framed receive: {str(e)}") raise @@ -201,10 +208,9 @@ def read_status_file() -> dict | None: for attempt in range(attempts + 1): try: - # Ensure connected (perform handshake each time so framing stays correct) - if not self.sock: - if not self.connect(): - raise Exception("Could not connect to Unity") + # Ensure connected (handshake occurs within connect()) + if not self.sock and not self.connect(): + raise Exception("Could not connect to Unity") # Build payload if command_type == 'ping': @@ -213,31 +219,39 @@ def read_status_file() -> dict | None: command = {"type": command_type, "params": params or {}} payload = json.dumps(command, ensure_ascii=False).encode('utf-8') - # Send - try: - logger.debug(f"send {len(payload)} bytes; mode={'framed' if self.use_framing else 'legacy'}; head={(payload[:32]).decode('utf-8','ignore')}") - except Exception: - pass - if self.use_framing: - header = struct.pack('>Q', len(payload)) - self.sock.sendall(header) - self.sock.sendall(payload) - else: - self.sock.sendall(payload) - - # During retry bursts use a short receive timeout - if attempt > 0 and last_short_timeout is None: - last_short_timeout = self.sock.gettimeout() - self.sock.settimeout(1.0) - response_data = self.receive_full_response(self.sock) - try: - logger.debug(f"recv {len(response_data)} bytes; mode={'framed' if self.use_framing else 'legacy'}; head={(response_data[:32]).decode('utf-8','ignore')}") - except Exception: - pass - # restore steady-state timeout if changed - if last_short_timeout is not None: - self.sock.settimeout(config.connection_timeout) - last_short_timeout = None + # Send/receive are serialized to protect the shared socket + with self._io_lock: + mode = 'framed' if self.use_framing else 'legacy' + with contextlib.suppress(Exception): + logger.debug( + "send %d bytes; mode=%s; head=%s", + len(payload), + mode, + (payload[:32]).decode('utf-8', 'ignore'), + ) + if self.use_framing: + header = struct.pack('>Q', len(payload)) + self.sock.sendall(header) + self.sock.sendall(payload) + else: + self.sock.sendall(payload) + + # During retry bursts use a short receive timeout + if attempt > 0 and last_short_timeout is None: + last_short_timeout = self.sock.gettimeout() + self.sock.settimeout(1.0) + response_data = self.receive_full_response(self.sock) + with contextlib.suppress(Exception): + logger.debug( + "recv %d bytes; mode=%s; head=%s", + len(response_data), + mode, + (response_data[:32]).decode('utf-8', 'ignore'), + ) + # restore steady-state timeout if changed + if last_short_timeout is not None: + self.sock.settimeout(last_short_timeout) + last_short_timeout = None # Parse if command_type == 'ping': diff --git a/tests/test_logging_stdout.py b/tests/test_logging_stdout.py index 4c53a10f..6fef7861 100644 --- a/tests/test_logging_stdout.py +++ b/tests/test_logging_stdout.py @@ -48,12 +48,24 @@ def visit_Call(self, node: ast.Call): # print(...) if isinstance(node.func, ast.Name) and node.func.id == "print": self.hit = True + # builtins.print(...) + elif ( + isinstance(node.func, ast.Attribute) + and node.func.attr == "print" + and isinstance(node.func.value, ast.Name) + and node.func.value.id == "builtins" + ): + self.hit = True # sys.stdout.write(...) - if isinstance(node.func, ast.Attribute) and node.func.attr == "write": - val = node.func.value - if isinstance(val, ast.Attribute) and val.attr == "stdout": - if isinstance(val.value, ast.Name) and val.value.id == "sys": - self.hit = True + if ( + isinstance(node.func, ast.Attribute) + and node.func.attr == "write" + and isinstance(node.func.value, ast.Attribute) + and node.func.value.attr == "stdout" + and isinstance(node.func.value.value, ast.Name) + and node.func.value.value.id == "sys" + ): + self.hit = True self.generic_visit(node) v = StdoutVisitor() diff --git a/tests/test_transport_framing.py b/tests/test_transport_framing.py index 011473b3..2008c4c1 100644 --- a/tests/test_transport_framing.py +++ b/tests/test_transport_framing.py @@ -129,6 +129,7 @@ def test_unframed_data_disconnect(): port = start_handshake_enforcing_server() sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) sock.connect(("127.0.0.1", port)) + sock.settimeout(1.0) sock.sendall(b"BAD") time.sleep(0.4) try: From 9c759f16cd3a58df092bd92e25d46904ed288012 Mon Sep 17 00:00:00 2001 From: David Sarno Date: Sun, 17 Aug 2025 13:06:33 -0700 Subject: [PATCH 22/25] Claude Desktop: write BOM-free config to macOS path; dual-path fallback; add uv -q for quieter stdio; MCP server: compatibility guards for capabilities/resource decorators and indentation fix; ManageScript: shadow var fix; robust mac config path. --- UnityMcpBridge/Editor/Data/McpClients.cs | 7 + UnityMcpBridge/Editor/Models/McpClient.cs | 1 + UnityMcpBridge/Editor/Tools/ManageScript.cs | 8 +- .../Editor/Windows/UnityMcpEditorWindow.cs | 130 +++++++++++++++--- UnityMcpBridge/UnityMcpServer~/src/server.py | 54 ++++---- 5 files changed, 154 insertions(+), 46 deletions(-) diff --git a/UnityMcpBridge/Editor/Data/McpClients.cs b/UnityMcpBridge/Editor/Data/McpClients.cs index ac5d8e3e..3a9fade3 100644 --- a/UnityMcpBridge/Editor/Data/McpClients.cs +++ b/UnityMcpBridge/Editor/Data/McpClients.cs @@ -69,6 +69,13 @@ public class McpClients "Claude", "claude_desktop_config.json" ), + macConfigPath = Path.Combine( + Environment.GetFolderPath(Environment.SpecialFolder.Personal), + "Library", + "Application Support", + "Claude", + "claude_desktop_config.json" + ), linuxConfigPath = Path.Combine( Environment.GetFolderPath(Environment.SpecialFolder.UserProfile), ".config", diff --git a/UnityMcpBridge/Editor/Models/McpClient.cs b/UnityMcpBridge/Editor/Models/McpClient.cs index 9f69e903..005a4e1b 100644 --- a/UnityMcpBridge/Editor/Models/McpClient.cs +++ b/UnityMcpBridge/Editor/Models/McpClient.cs @@ -4,6 +4,7 @@ public class McpClient { public string name; public string windowsConfigPath; + public string macConfigPath; public string linuxConfigPath; public McpTypes mcpType; public string configStatus; diff --git a/UnityMcpBridge/Editor/Tools/ManageScript.cs b/UnityMcpBridge/Editor/Tools/ManageScript.cs index 90367c1a..31ce8e78 100644 --- a/UnityMcpBridge/Editor/Tools/ManageScript.cs +++ b/UnityMcpBridge/Editor/Tools/ManageScript.cs @@ -196,9 +196,9 @@ public static object HandleCommand(JObject @params) return DeleteScript(fullPath, relativePath); case "apply_text_edits": { - var edits = @params["edits"] as JArray; + var textEdits = @params["edits"] as JArray; string precondition = @params["precondition_sha256"]?.ToString(); - return ApplyTextEdits(fullPath, relativePath, name, edits, precondition); + return ApplyTextEdits(fullPath, relativePath, name, textEdits, precondition); } case "validate": { @@ -231,9 +231,9 @@ public static object HandleCommand(JObject @params) } case "edit": Debug.LogWarning("manage_script.edit is deprecated; prefer apply_text_edits. Serving structured edit for backward compatibility."); - var edits = @params["edits"] as JArray; + var structEdits = @params["edits"] as JArray; var options = @params["options"] as JObject; - return EditScript(fullPath, relativePath, name, edits, options); + return EditScript(fullPath, relativePath, name, structEdits, options); default: return Response.Error( $"Unknown action: '{action}'. Valid actions are: create, delete, apply_text_edits, validate, read (deprecated), update (deprecated), edit (deprecated)." diff --git a/UnityMcpBridge/Editor/Windows/UnityMcpEditorWindow.cs b/UnityMcpBridge/Editor/Windows/UnityMcpEditorWindow.cs index d80ffbb5..d3e0b012 100644 --- a/UnityMcpBridge/Editor/Windows/UnityMcpEditorWindow.cs +++ b/UnityMcpBridge/Editor/Windows/UnityMcpEditorWindow.cs @@ -1083,12 +1083,32 @@ private string WriteToConfig(string pythonDir, string configPath, McpClient mcpC serverSrc = ServerInstaller.GetServerPath(); } - // 2) Canonical args order - var newArgs = new[] { "run", "--directory", serverSrc, "server.py" }; + // 2) Canonical args order (add quiet flag to prevent stdout noise breaking MCP stdio) + var newArgs = new[] { "-q", "run", "--directory", serverSrc, "server.py" }; // 3) Only write if changed bool changed = !string.Equals(existingCommand, uvPath, StringComparison.Ordinal) || !ArgsEqual(existingArgs, newArgs); + + // If the existing file contains a UTF-8 BOM, force a rewrite to remove it + try + { + if (System.IO.File.Exists(configPath)) + { + using (var fs = new System.IO.FileStream(configPath, System.IO.FileMode.Open, System.IO.FileAccess.Read, System.IO.FileShare.ReadWrite)) + { + if (fs.Length >= 3) + { + int b1 = fs.ReadByte(); + int b2 = fs.ReadByte(); + int b3 = fs.ReadByte(); + bool hasBom = (b1 == 0xEF && b2 == 0xBB && b3 == 0xBF); + if (hasBom) changed = true; + } + } + } + } + catch { } if (!changed) { return "Configured successfully"; // nothing to do @@ -1112,12 +1132,29 @@ private string WriteToConfig(string pythonDir, string configPath, McpClient mcpC } string mergedJson = JsonConvert.SerializeObject(existingConfig, jsonSettings); - string tmp = configPath + ".tmp"; - System.IO.File.WriteAllText(tmp, mergedJson, System.Text.Encoding.UTF8); - if (System.IO.File.Exists(configPath)) - System.IO.File.Replace(tmp, configPath, null); - else - System.IO.File.Move(tmp, configPath); + + // Write without BOM and fsync to avoid transient parse failures + try + { + WriteJsonAtomicallyNoBom(configPath, mergedJson); + } + catch + { + // Fallback simple write if atomic path fails + var encNoBom = new System.Text.UTF8Encoding(encoderShouldEmitUTF8Identifier: false); + System.IO.File.WriteAllText(configPath, mergedJson, encNoBom); + } + + // Validate that resulting file is valid JSON + try + { + var verify = System.IO.File.ReadAllText(configPath); + JsonConvert.DeserializeObject(verify); + } + catch (Exception ex) + { + UnityEngine.Debug.LogWarning($"UnityMCP: Wrote config but JSON re-parse failed: {ex.Message}"); + } try { if (IsValidUv(uvPath)) UnityEditor.EditorPrefs.SetString("UnityMCP.UvPath", uvPath); @@ -1128,6 +1165,23 @@ private string WriteToConfig(string pythonDir, string configPath, McpClient mcpC return "Configured successfully"; } + private static void WriteJsonAtomicallyNoBom(string path, string json) + { + string tmp = path + ".tmp"; + var encNoBom = new System.Text.UTF8Encoding(encoderShouldEmitUTF8Identifier: false); + using (var fs = new System.IO.FileStream(tmp, System.IO.FileMode.Create, System.IO.FileAccess.Write, System.IO.FileShare.None)) + using (var sw = new System.IO.StreamWriter(fs, encNoBom)) + { + sw.Write(json); + sw.Flush(); + fs.Flush(true); + } + if (System.IO.File.Exists(path)) + System.IO.File.Replace(tmp, path, null); + else + System.IO.File.Move(tmp, path); + } + private void ShowManualConfigurationInstructions( string configPath, McpClient mcpClient @@ -1328,10 +1382,13 @@ private string ConfigureMcpClient(McpClient mcpClient) { configPath = mcpClient.windowsConfigPath; } - else if ( - RuntimeInformation.IsOSPlatform(OSPlatform.OSX) - || RuntimeInformation.IsOSPlatform(OSPlatform.Linux) - ) + else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) + { + configPath = string.IsNullOrEmpty(mcpClient.macConfigPath) + ? mcpClient.linuxConfigPath + : mcpClient.macConfigPath; + } + else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) { configPath = mcpClient.linuxConfigPath; } @@ -1354,6 +1411,22 @@ private string ConfigureMcpClient(McpClient mcpClient) string result = WriteToConfig(pythonDir, configPath, mcpClient); + // On macOS for Claude Desktop, also mirror to Linux-style path for backward compatibility + if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX) + && mcpClient?.mcpType == McpTypes.ClaudeDesktop) + { + string altPath = mcpClient.linuxConfigPath; + if (!string.IsNullOrEmpty(altPath) && !string.Equals(configPath, altPath, StringComparison.Ordinal)) + { + try + { + Directory.CreateDirectory(Path.GetDirectoryName(altPath)); + WriteToConfig(pythonDir, altPath, mcpClient); + } + catch { } + } + } + // Update the client status after successful configuration if (result == "Configured successfully") { @@ -1482,10 +1555,13 @@ private void CheckMcpConfiguration(McpClient mcpClient) { configPath = mcpClient.windowsConfigPath; } - else if ( - RuntimeInformation.IsOSPlatform(OSPlatform.OSX) - || RuntimeInformation.IsOSPlatform(OSPlatform.Linux) - ) + else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX)) + { + configPath = string.IsNullOrEmpty(mcpClient.macConfigPath) + ? mcpClient.linuxConfigPath + : mcpClient.macConfigPath; + } + else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux)) { configPath = mcpClient.linuxConfigPath; } @@ -1497,8 +1573,26 @@ private void CheckMcpConfiguration(McpClient mcpClient) if (!File.Exists(configPath)) { - mcpClient.SetStatus(McpStatus.NotConfigured); - return; + // On macOS for Claude Desktop, fall back to Linux-style path if present + if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX) + && mcpClient?.mcpType == McpTypes.ClaudeDesktop) + { + string altPath = mcpClient.linuxConfigPath; + if (!string.IsNullOrEmpty(altPath) && File.Exists(altPath)) + { + configPath = altPath; // read from fallback + } + else + { + mcpClient.SetStatus(McpStatus.NotConfigured); + return; + } + } + else + { + mcpClient.SetStatus(McpStatus.NotConfigured); + return; + } } string configJson = File.ReadAllText(configPath); diff --git a/UnityMcpBridge/UnityMcpServer~/src/server.py b/UnityMcpBridge/UnityMcpServer~/src/server.py index 99f41229..fdec41ea 100644 --- a/UnityMcpBridge/UnityMcpServer~/src/server.py +++ b/UnityMcpBridge/UnityMcpServer~/src/server.py @@ -96,9 +96,11 @@ def asset_creation_strategy() -> str: ) # Resources support: list and read Unity scripts/files -@mcp.capabilities(resources={"listChanged": True}) -class _: - pass +# Guard for older MCP versions without 'capabilities' API +if hasattr(mcp, "capabilities"): + @mcp.capabilities(resources={"listChanged": True}) + class _: + pass PROJECT_ROOT = Path(os.environ.get("UNITY_PROJECT_ROOT", Path.cwd())).resolve() ASSETS_ROOT = (PROJECT_ROOT / "Assets").resolve() @@ -120,28 +122,32 @@ def _resolve_safe_path_from_uri(uri: str) -> Path | None: return None return p -@mcp.resource.list() -def list_resources(ctx: Context) -> list[dict]: - assets = [] - try: - for p in ASSETS_ROOT.rglob("*.cs"): - rel = p.relative_to(PROJECT_ROOT).as_posix() - assets.append({"uri": f"unity://path/{rel}", "name": p.name}) - except Exception: - pass - return assets -@mcp.resource.read() -def read_resource(ctx: Context, uri: str) -> dict: - p = _resolve_safe_path_from_uri(uri) - if not p or not p.exists(): - return {"mimeType": "text/plain", "text": f"Resource not found: {uri}"} - try: - text = p.read_text(encoding="utf-8") - sha = hashlib.sha256(text.encode("utf-8")).hexdigest() - return {"mimeType": "text/plain", "text": text, "metadata": {"sha256": sha}} - except Exception as e: - return {"mimeType": "text/plain", "text": f"Error reading resource: {e}"} +if hasattr(mcp, "resource") and hasattr(getattr(mcp, "resource"), "list"): + @mcp.resource.list() + def list_resources(ctx: Context) -> list[dict]: + assets = [] + try: + for p in ASSETS_ROOT.rglob("*.cs"): + rel = p.relative_to(PROJECT_ROOT).as_posix() + assets.append({"uri": f"unity://path/{rel}", "name": p.name}) + except Exception: + pass + return assets + +if hasattr(mcp, "resource") and hasattr(getattr(mcp, "resource"), "read"): + @mcp.resource.read() + def read_resource(ctx: Context, uri: str) -> dict: + p = _resolve_safe_path_from_uri(uri) + if not p or not p.exists(): + return {"mimeType": "text/plain", "text": f"Resource not found: {uri}"} + try: + text = p.read_text(encoding="utf-8") + sha = hashlib.sha256(text.encode("utf-8")).hexdigest() + return {"mimeType": "text/plain", "text": text, "metadata": {"sha256": sha}} + except Exception as e: + return {"mimeType": "text/plain", "text": f"Error reading resource: {e}"} + af56d70 (Claude Desktop: write BOM-free config to macOS path; dual-path fallback; add uv -q for quieter stdio; MCP server: compatibility guards for capabilities/resource decorators and indentation fix; ManageScript: shadow var fix; robust mac config path.) # Run the server if __name__ == "__main__": From 8ccba72b328a2e4d70de31375cf1e8f3cc8244f4 Mon Sep 17 00:00:00 2001 From: David Sarno Date: Sun, 17 Aug 2025 19:47:01 -0700 Subject: [PATCH 23/25] MCP: natural-language edit defaults; header guard + precondition for text edits; anchor aliasing and text-op conversion; immediate compile on NL/structured; add resource_tools (tail_lines, find_in_file); update test cases --- UnityMcpBridge/Editor/Tools/ManageEditor.cs | 24 +- UnityMcpBridge/Editor/Tools/ManageScript.cs | 81 +++++- UnityMcpBridge/UnityMcpServer~/src/server.py | 1 - .../UnityMcpServer~/src/tools/__init__.py | 3 + .../src/tools/manage_script_edits.py | 234 +++++++++++++++++- .../src/tools/resource_tools.py | 227 +++++++++++++++++ 6 files changed, 562 insertions(+), 8 deletions(-) create mode 100644 UnityMcpBridge/UnityMcpServer~/src/tools/resource_tools.py diff --git a/UnityMcpBridge/Editor/Tools/ManageEditor.cs b/UnityMcpBridge/Editor/Tools/ManageEditor.cs index 06d057d6..9151115f 100644 --- a/UnityMcpBridge/Editor/Tools/ManageEditor.cs +++ b/UnityMcpBridge/Editor/Tools/ManageEditor.cs @@ -1,6 +1,7 @@ using System; using System.Collections.Generic; using System.Linq; +using System.IO; using Newtonsoft.Json.Linq; using UnityEditor; using UnityEditorInternal; // Required for tag management @@ -89,6 +90,8 @@ public static object HandleCommand(JObject @params) // Editor State/Info case "get_state": return GetEditorState(); + case "get_project_root": + return GetProjectRoot(); case "get_windows": return GetEditorWindows(); case "get_active_tool": @@ -137,7 +140,7 @@ public static object HandleCommand(JObject @params) default: return Response.Error( - $"Unknown action: '{action}'. Supported actions include play, pause, stop, get_state, get_windows, get_active_tool, get_selection, set_active_tool, add_tag, remove_tag, get_tags, add_layer, remove_layer, get_layers." + $"Unknown action: '{action}'. Supported actions include play, pause, stop, get_state, get_project_root, get_windows, get_active_tool, get_selection, set_active_tool, add_tag, remove_tag, get_tags, add_layer, remove_layer, get_layers." ); } } @@ -165,6 +168,25 @@ private static object GetEditorState() } } + private static object GetProjectRoot() + { + try + { + // Application.dataPath points to /Assets + string assetsPath = Application.dataPath.Replace('\\', '/'); + string projectRoot = Directory.GetParent(assetsPath)?.FullName.Replace('\\', '/'); + if (string.IsNullOrEmpty(projectRoot)) + { + return Response.Error("Could not determine project root from Application.dataPath"); + } + return Response.Success("Project root resolved.", new { projectRoot }); + } + catch (Exception e) + { + return Response.Error($"Error getting project root: {e.Message}"); + } + } + private static object GetEditorWindows() { try diff --git a/UnityMcpBridge/Editor/Tools/ManageScript.cs b/UnityMcpBridge/Editor/Tools/ManageScript.cs index 31ce8e78..1fcf1e13 100644 --- a/UnityMcpBridge/Editor/Tools/ManageScript.cs +++ b/UnityMcpBridge/Editor/Tools/ManageScript.cs @@ -483,11 +483,12 @@ private static object ApplyTextEdits( try { original = File.ReadAllText(fullPath); } catch (Exception ex) { return Response.Error($"Failed to read script: {ex.Message}"); } + // Require precondition to avoid drift on large files string currentSha = ComputeSha256(original); - if (!string.IsNullOrEmpty(preconditionSha256) && !preconditionSha256.Equals(currentSha, StringComparison.OrdinalIgnoreCase)) - { + if (string.IsNullOrEmpty(preconditionSha256)) + return Response.Error("precondition_required", new { status = "precondition_required", current_sha256 = currentSha }); + if (!preconditionSha256.Equals(currentSha, StringComparison.OrdinalIgnoreCase)) return Response.Error("stale_file", new { status = "stale_file", expected_sha256 = preconditionSha256, current_sha256 = currentSha }); - } // Convert edits to absolute index ranges var spans = new List<(int start, int end, string text)>(); @@ -520,6 +521,59 @@ private static object ApplyTextEdits( } } + // Header guard: refuse edits that touch before the first 'using ' directive (after optional BOM) to prevent file corruption + int headerBoundary = 0; + if (original.Length > 0 && original[0] == '\uFEFF') headerBoundary = 1; // skip BOM + // Find first top-level using (very simple scan of start of file) + var mUsing = System.Text.RegularExpressions.Regex.Match(original, @"(?m)^(?:\uFEFF)?using\s+\w+", System.Text.RegularExpressions.RegexOptions.None); + if (mUsing.Success) + headerBoundary = Math.Min(Math.Max(headerBoundary, mUsing.Index), original.Length); + foreach (var sp in spans) + { + if (sp.start < headerBoundary) + { + return Response.Error("header_guard", new { status = "header_guard", hint = "Refusing to edit before the first 'using'. Use anchor_insert near a method or a structured edit." }); + } + } + + // Attempt auto-upgrade: if a single edit targets a method header/body, re-route as structured replace_method + if (spans.Count == 1) + { + var sp = spans[0]; + // Heuristic: around the start of the edit, try to match a method header in original + int searchStart = Math.Max(0, sp.start - 200); + int searchEnd = Math.Min(original.Length, sp.start + 200); + string slice = original.Substring(searchStart, searchEnd - searchStart); + var rx = new System.Text.RegularExpressions.Regex(@"(?m)^[\t ]*(?:\[[^\]]+\][\t ]*)*[\t ]*(?:public|private|protected|internal|static|virtual|override|sealed|async|extern|unsafe|new|partial)[\s\S]*?\b([A-Za-z_][A-Za-z0-9_]*)\s*\("); + var mh = rx.Match(slice); + if (mh.Success) + { + string methodName = mh.Groups[1].Value; + // Find class span containing the edit + if (TryComputeClassSpan(original, name, null, out var clsStart, out var clsLen, out _)) + { + if (TryComputeMethodSpan(original, clsStart, clsLen, methodName, null, null, null, out var mStart, out var mLen, out _)) + { + // If the edit overlaps the method span significantly, treat as replace_method + if (sp.start <= mStart + 2 && sp.end >= mStart + 1) + { + var structEdits = new JArray(); + var op = new JObject + { + ["mode"] = "replace_method", + ["className"] = name, + ["methodName"] = methodName, + ["replacement"] = original.Remove(sp.start, sp.end - sp.start).Insert(sp.start, sp.text ?? string.Empty).Substring(mStart, (sp.text ?? string.Empty).Length + (sp.start - mStart) + (mLen - (sp.end - mStart))) + }; + structEdits.Add(op); + // Reuse structured path + return EditScript(fullPath, relativePath, name, structEdits, new JObject{ ["refresh"] = "immediate", ["validate"] = "standard" }); + } + } + } + } + } + if (totalBytes > MaxEditPayloadBytes) { return Response.Error("too_large", new { status = "too_large", limitBytes = MaxEditPayloadBytes, hint = "split into smaller edits" }); @@ -952,6 +1006,9 @@ private static object EditScript( string afterParameters = op.Value("afterParametersSignature"); string afterAttributesContains = op.Value("afterAttributesContains"); string snippet = ExtractReplacement(op); + // Harden: refuse empty replacement for inserts + if (snippet == null || snippet.Trim().Length == 0) + return Response.Error("insert_method requires a non-empty 'replacement' text."); if (string.IsNullOrWhiteSpace(className)) return Response.Error("insert_method requires 'className'."); if (snippet == null) return Response.Error("insert_method requires 'replacement' (inline or base64) containing a full method declaration."); @@ -1239,7 +1296,23 @@ private static bool TryComputeMethodSpan( // 1) Find the method header using a stricter regex (allows optional attributes above) string rtPattern = string.IsNullOrEmpty(returnType) ? @"[^\s]+" : Regex.Escape(returnType).Replace("\\ ", "\\s+"); string namePattern = Regex.Escape(methodName); - string paramsPattern = string.IsNullOrEmpty(parametersSignature) ? @"[\s\S]*?" : Regex.Escape(parametersSignature); + // If a parametersSignature is provided, it may include surrounding parentheses. Strip them so + // we can safely embed the signature inside our own parenthesis group without duplicating. + string paramsPattern; + if (string.IsNullOrEmpty(parametersSignature)) + { + paramsPattern = @"[\s\S]*?"; // permissive when not specified + } + else + { + string ps = parametersSignature.Trim(); + if (ps.StartsWith("(") && ps.EndsWith(")") && ps.Length >= 2) + { + ps = ps.Substring(1, ps.Length - 2); + } + // Escape literal text of the signature + paramsPattern = Regex.Escape(ps); + } string pattern = @"(?m)^[\t ]*(?:\[[^\]]+\][\t ]*)*[\t ]*" + @"(?:(?:public|private|protected|internal|static|virtual|override|sealed|async|extern|unsafe|new|partial|readonly|volatile|event|abstract|ref|in|out)\s+)*" + diff --git a/UnityMcpBridge/UnityMcpServer~/src/server.py b/UnityMcpBridge/UnityMcpServer~/src/server.py index fdec41ea..3e81408c 100644 --- a/UnityMcpBridge/UnityMcpServer~/src/server.py +++ b/UnityMcpBridge/UnityMcpServer~/src/server.py @@ -147,7 +147,6 @@ def read_resource(ctx: Context, uri: str) -> dict: return {"mimeType": "text/plain", "text": text, "metadata": {"sha256": sha}} except Exception as e: return {"mimeType": "text/plain", "text": f"Error reading resource: {e}"} - af56d70 (Claude Desktop: write BOM-free config to macOS path; dual-path fallback; add uv -q for quieter stdio; MCP server: compatibility guards for capabilities/resource decorators and indentation fix; ManageScript: shadow var fix; robust mac config path.) # Run the server if __name__ == "__main__": diff --git a/UnityMcpBridge/UnityMcpServer~/src/tools/__init__.py b/UnityMcpBridge/UnityMcpServer~/src/tools/__init__.py index 710b53dc..aa7bf014 100644 --- a/UnityMcpBridge/UnityMcpServer~/src/tools/__init__.py +++ b/UnityMcpBridge/UnityMcpServer~/src/tools/__init__.py @@ -8,6 +8,7 @@ from .manage_shader import register_manage_shader_tools from .read_console import register_read_console_tools from .execute_menu_item import register_execute_menu_item_tools +from .resource_tools import register_resource_tools logger = logging.getLogger("unity-mcp-server") @@ -24,4 +25,6 @@ def register_all_tools(mcp): register_manage_shader_tools(mcp) register_read_console_tools(mcp) register_execute_menu_item_tools(mcp) + # Expose resource wrappers as normal tools so IDEs without resources primitive can use them + register_resource_tools(mcp) logger.info("Unity MCP Server tool registration complete.") diff --git a/UnityMcpBridge/UnityMcpServer~/src/tools/manage_script_edits.py b/UnityMcpBridge/UnityMcpServer~/src/tools/manage_script_edits.py index bd7f7137..126c60c0 100644 --- a/UnityMcpBridge/UnityMcpServer~/src/tools/manage_script_edits.py +++ b/UnityMcpBridge/UnityMcpServer~/src/tools/manage_script_edits.py @@ -1,5 +1,5 @@ from mcp.server.fastmcp import FastMCP, Context -from typing import Dict, Any, List +from typing import Dict, Any, List, Tuple import base64 import re from unity_connection import send_command_with_retry @@ -74,6 +74,97 @@ def _apply_edits_locally(original_text: str, edits: List[Dict[str, Any]]) -> str return text +def _infer_class_name(script_name: str) -> str: + # Default to script name as class name (common Unity pattern) + return (script_name or "").strip() + + +def _extract_code_after(keyword: str, request: str) -> str: + idx = request.lower().find(keyword) + if idx >= 0: + return request[idx + len(keyword):].strip() + return "" + + +def _parse_natural_request_to_edits( + request: str, + script_name: str, + file_text: str, +) -> Tuple[List[Dict[str, Any]], str]: + """Parses a natural language request into a list of edits. + + Returns (edits, message). message is a brief description or disambiguation note. + """ + req = (request or "").strip() + if not req: + return [], "" + + edits: List[Dict[str, Any]] = [] + cls = _infer_class_name(script_name) + + # 1) Insert/Add comment above/below/after method + m = re.search(r"(?:insert|add)\s+comment\s+[\"'](.+?)[\"']\s+(above|before|below|after)\s+(?:the\s+)?(?:method\s+)?([A-Za-z_][A-Za-z0-9_]*)", + req, re.IGNORECASE) + if m: + comment = m.group(1) + pos = m.group(2).lower() + method = m.group(3) + position = "before" if pos in ("above", "before") else "after" + anchor = rf"(?m)^\s*(?:\[[^\]]+\]\s*)*(?:public|private|protected|internal|static|virtual|override|sealed|async|extern|unsafe|new|partial).*?\b{re.escape(method)}\s*\(" + edits.append({ + "op": "anchor_insert", + "anchor": anchor, + "position": position, + "text": f" /* {comment} */\n", + }) + return edits, "insert_comment" + + # 2) Insert method ... after + m = re.search(r"insert\s+method\s+```([\s\S]+?)```\s+after\s+([A-Za-z_][A-Za-z0-9_]*)", req, re.IGNORECASE) + if not m: + m = re.search(r"insert\s+method\s+(.+?)\s+after\s+([A-Za-z_][A-Za-z0-9_]*)", req, re.IGNORECASE) + if m: + snippet = m.group(1).strip() + after_name = m.group(2) + edits.append({ + "op": "insert_method", + "className": cls, + "position": "after", + "afterMethodName": after_name, + "replacement": snippet, + }) + return edits, "insert_method" + + # 3) Replace method with + m = re.search(r"replace\s+method\s+([A-Za-z_][A-Za-z0-9_]*)\s+with\s+```([\s\S]+?)```", req, re.IGNORECASE) + if not m: + m = re.search(r"replace\s+method\s+([A-Za-z_][A-Za-z0-9_]*)\s+with\s+([\s\S]+)$", req, re.IGNORECASE) + if m: + name = m.group(1) + repl = m.group(2).strip() + edits.append({ + "op": "replace_method", + "className": cls, + "methodName": name, + "replacement": repl, + }) + return edits, "replace_method" + + # 4) Delete method [all overloads] + m = re.search(r"delete\s+method\s+([A-Za-z_][A-Za-z0-9_]*)", req, re.IGNORECASE) + if m: + name = m.group(1) + edits.append({ + "op": "delete_method", + "className": cls, + "methodName": name, + }) + return edits, "delete_method" + + # 5) Fallback: no parse + return [], "Could not parse natural-language request" + + def register_manage_script_edits_tools(mcp: FastMCP): @mcp.tool(description=( "Apply targeted edits to an existing C# script WITHOUT replacing the whole file. " @@ -88,9 +179,37 @@ def script_apply_edits( options: Dict[str, Any] | None = None, script_type: str = "MonoBehaviour", namespace: str = "", + request: str | None = None, ) -> Dict[str, Any]: # If the edits request structured class/method ops, route directly to Unity's 'edit' action. # These bypass local text validation/encoding since Unity performs the semantic changes. + # If user provided a natural-language request instead of structured edits, parse it + if (not edits) and request: + # Read to help extraction and return contextual diff/verification + read_resp = send_command_with_retry("manage_script", { + "action": "read", + "name": name, + "path": path, + "namespace": namespace, + "scriptType": script_type, + }) + if not isinstance(read_resp, dict) or not read_resp.get("success"): + return read_resp if isinstance(read_resp, dict) else {"success": False, "message": str(read_resp)} + data = read_resp.get("data") or read_resp.get("result", {}).get("data") or {} + contents = data.get("contents") + if contents is None and data.get("contentsEncoded") and data.get("encodedContents"): + contents = base64.b64decode(data["encodedContents"]).decode("utf-8") + parsed_edits, why = _parse_natural_request_to_edits(request, name, contents or "") + if not parsed_edits: + return {"success": False, "message": f"Could not understand request: {why}"} + edits = parsed_edits + # Provide sensible defaults for natural language requests + options = dict(options or {}) + options.setdefault("validate", "standard") + options.setdefault("refresh", "immediate") + if len(edits) > 1: + options.setdefault("applyMode", "sequential") + for e in edits or []: op = (e.get("op") or e.get("operation") or e.get("type") or e.get("mode") or "").strip().lower() if op in ("replace_class", "delete_class", "replace_method", "delete_method", "insert_method"): @@ -125,13 +244,124 @@ def script_apply_edits( if contents is None: return {"success": False, "message": "No contents returned from Unity read."} - # 2) apply edits locally + # Optional preview/dry-run: apply locally and return diff without writing + preview = bool((options or {}).get("preview")) + + # If the edits are text-ops, prefer sending them to Unity's apply_text_edits with precondition + # so header guards and validation run on the C# side. + # Supported conversions: anchor_insert, replace_range, regex_replace (first match only). + text_ops = { (e.get("op") or e.get("operation") or e.get("type") or e.get("mode") or "").strip().lower() for e in (edits or []) } + structured_kinds = {"replace_class","delete_class","replace_method","delete_method","insert_method"} + if not text_ops.issubset(structured_kinds): + # Convert to apply_text_edits payload + try: + current_text = contents + def line_col_from_index(idx: int) -> Tuple[int, int]: + # 1-based line/col + line = current_text.count("\n", 0, idx) + 1 + last_nl = current_text.rfind("\n", 0, idx) + col = (idx - (last_nl + 1)) + 1 if last_nl >= 0 else idx + 1 + return line, col + + at_edits: List[Dict[str, Any]] = [] + import re as _re + for e in edits or []: + op = (e.get("op") or e.get("operation") or e.get("type") or e.get("mode") or "").strip().lower() + # aliasing for text field + text_field = e.get("text") or e.get("insert") or e.get("content") or "" + if op == "anchor_insert": + anchor = e.get("anchor") or "" + position = (e.get("position") or "before").lower() + m = _re.search(anchor, current_text, _re.MULTILINE) + if not m: + return {"success": False, "message": f"anchor not found: {anchor}"} + idx = m.start() if position == "before" else m.end() + sl, sc = line_col_from_index(idx) + at_edits.append({ + "startLine": sl, + "startCol": sc, + "endLine": sl, + "endCol": sc, + "newText": text_field or "" + }) + # Update local snapshot to keep subsequent anchors stable + current_text = current_text[:idx] + (text_field or "") + current_text[idx:] + elif op == "replace_range": + # Directly forward if already in line/col form + if "startLine" in e: + at_edits.append({ + "startLine": int(e.get("startLine", 1)), + "startCol": int(e.get("startCol", 1)), + "endLine": int(e.get("endLine", 1)), + "endCol": int(e.get("endCol", 1)), + "newText": text_field + }) + else: + # If only indices provided, skip (we don't support index-based here) + return {"success": False, "message": "replace_range requires startLine/startCol/endLine/endCol"} + elif op == "regex_replace": + pattern = e.get("pattern") or "" + repl = text_field + m = _re.search(pattern, current_text, _re.MULTILINE) + if not m: + continue + sl, sc = line_col_from_index(m.start()) + el, ec = line_col_from_index(m.end()) + at_edits.append({ + "startLine": sl, + "startCol": sc, + "endLine": el, + "endCol": ec, + "newText": repl + }) + current_text = current_text[:m.start()] + repl + current_text[m.end():] + else: + return {"success": False, "message": f"Unsupported text edit op for server-side apply_text_edits: {op}"} + + # Send to Unity with precondition SHA to enforce guards + import hashlib + sha = hashlib.sha256(contents.encode("utf-8")).hexdigest() + params: Dict[str, Any] = { + "action": "apply_text_edits", + "name": name, + "path": path, + "namespace": namespace, + "scriptType": script_type, + "edits": at_edits, + "precondition_sha256": sha, + "options": { + "refresh": (options or {}).get("refresh", "immediate"), + "validate": (options or {}).get("validate", "standard") + } + } + resp = send_command_with_retry("manage_script", params) + return resp if isinstance(resp, dict) else {"success": False, "message": str(resp)} + except Exception as e: + return {"success": False, "message": f"Edit conversion failed: {e}"} + + # 2) apply edits locally (only if not text-ops) try: new_contents = _apply_edits_locally(contents, edits) except Exception as e: return {"success": False, "message": f"Edit application failed: {e}"} + if preview: + # Produce a compact unified diff limited to small context + import difflib + a = contents.splitlines() + b = new_contents.splitlines() + diff = list(difflib.unified_diff(a, b, fromfile="before", tofile="after", n=3)) + # Limit diff size to keep responses small + if len(diff) > 2000: + diff = diff[:2000] + ["... (diff truncated) ..."] + return {"success": True, "message": "Preview only (no write)", "data": {"diff": "\n".join(diff)}} + # 3) update to Unity + # Default refresh/validate for natural usage on text path as well + options = dict(options or {}) + options.setdefault("validate", "standard") + options.setdefault("refresh", "immediate") + params: Dict[str, Any] = { "action": "update", "name": name, diff --git a/UnityMcpBridge/UnityMcpServer~/src/tools/resource_tools.py b/UnityMcpBridge/UnityMcpServer~/src/tools/resource_tools.py new file mode 100644 index 00000000..572f2b0a --- /dev/null +++ b/UnityMcpBridge/UnityMcpServer~/src/tools/resource_tools.py @@ -0,0 +1,227 @@ +""" +Resource wrapper tools so clients that do not expose MCP resources primitives +can still list and read files via normal tools. These call into the same +safe path logic (re-implemented here to avoid importing server.py). +""" +from __future__ import annotations + +from typing import Dict, Any, List +import re +from pathlib import Path +import fnmatch +import hashlib +import os + +from mcp.server.fastmcp import FastMCP, Context +from unity_connection import send_command_with_retry + + +def _resolve_project_root(override: str | None) -> Path: + # 1) Explicit override + if override: + pr = Path(override).expanduser().resolve() + if (pr / "Assets").exists(): + return pr + # 2) Environment + env = os.environ.get("UNITY_PROJECT_ROOT") + if env: + pr = Path(env).expanduser().resolve() + if (pr / "Assets").exists(): + return pr + # 3) Ask Unity via manage_editor.get_project_root + try: + resp = send_command_with_retry("manage_editor", {"action": "get_project_root"}) + if isinstance(resp, dict) and resp.get("success"): + pr = Path(resp.get("data", {}).get("projectRoot", "")).expanduser().resolve() + if pr and (pr / "Assets").exists(): + return pr + except Exception: + pass + + # 4) Walk up from CWD to find a Unity project (Assets + ProjectSettings) + cur = Path.cwd().resolve() + for _ in range(6): + if (cur / "Assets").exists() and (cur / "ProjectSettings").exists(): + return cur + if cur.parent == cur: + break + cur = cur.parent + # 5) Fallback: CWD + return Path.cwd().resolve() + + +def _resolve_safe_path_from_uri(uri: str, project: Path) -> Path | None: + raw: str | None = None + if uri.startswith("unity://path/"): + raw = uri[len("unity://path/"):] + elif uri.startswith("file://"): + raw = uri[len("file://"):] + elif uri.startswith("Assets/"): + raw = uri + if raw is None: + return None + p = (project / raw).resolve() + try: + p.relative_to(project) + except ValueError: + return None + return p + + +def register_resource_tools(mcp: FastMCP) -> None: + """Registers list_resources and read_resource wrapper tools.""" + + @mcp.tool() + async def list_resources( + ctx: Context, + pattern: str | None = "*.cs", + under: str = "Assets", + limit: int = 200, + project_root: str | None = None, + ) -> Dict[str, Any]: + """ + Lists project URIs (unity://path/...) under a folder (default: Assets). + - pattern: glob like *.cs or *.shader (None to list all files) + - under: relative folder under project root + - limit: max results + """ + try: + project = _resolve_project_root(project_root) + base = (project / under).resolve() + try: + base.relative_to(project) + except ValueError: + return {"success": False, "error": "Base path must be under project root"} + + matches: List[str] = [] + for p in base.rglob("*"): + if not p.is_file(): + continue + if pattern and not fnmatch.fnmatch(p.name, pattern): + continue + rel = p.relative_to(project).as_posix() + matches.append(f"unity://path/{rel}") + if len(matches) >= max(1, limit): + break + + return {"success": True, "data": {"uris": matches, "count": len(matches)}} + except Exception as e: + return {"success": False, "error": str(e)} + + @mcp.tool() + async def read_resource( + ctx: Context, + uri: str, + start_line: int | None = None, + line_count: int | None = None, + head_bytes: int | None = None, + tail_lines: int | None = None, + project_root: str | None = None, + request: str | None = None, + ) -> Dict[str, Any]: + """ + Reads a resource by unity://path/... URI with optional slicing. + One of line window (start_line/line_count) or head_bytes can be used to limit size. + """ + try: + project = _resolve_project_root(project_root) + p = _resolve_safe_path_from_uri(uri, project) + if not p or not p.exists() or not p.is_file(): + return {"success": False, "error": f"Resource not found: {uri}"} + + # Natural-language convenience: request like "last 120 lines", "first 200 lines", + # "show 40 lines around MethodName", etc. + if request: + req = request.strip().lower() + m = re.search(r"last\s+(\d+)\s+lines", req) + if m: + tail_lines = int(m.group(1)) + m = re.search(r"first\s+(\d+)\s+lines", req) + if m: + start_line = 1 + line_count = int(m.group(1)) + m = re.search(r"first\s+(\d+)\s*bytes", req) + if m: + head_bytes = int(m.group(1)) + m = re.search(r"show\s+(\d+)\s+lines\s+around\s+([A-Za-z_][A-Za-z0-9_]*)", req) + if m: + window = int(m.group(1)) + method = m.group(2) + # naive search for method header to get a line number + text_all = p.read_text(encoding="utf-8") + lines_all = text_all.splitlines() + pat = re.compile(rf"^\s*(?:\[[^\]]+\]\s*)*(?:public|private|protected|internal|static|virtual|override|sealed|async|extern|unsafe|new|partial).*?\b{re.escape(method)}\s*\(", re.MULTILINE) + hit_line = None + for i, line in enumerate(lines_all, start=1): + if pat.search(line): + hit_line = i + break + if hit_line: + half = max(1, window // 2) + start_line = max(1, hit_line - half) + line_count = window + + # Mutually exclusive windowing options precedence: + # 1) head_bytes, 2) tail_lines, 3) start_line+line_count, else full text + if head_bytes and head_bytes > 0: + raw = p.read_bytes()[: head_bytes] + text = raw.decode("utf-8", errors="replace") + else: + text = p.read_text(encoding="utf-8") + if tail_lines is not None and tail_lines > 0: + lines = text.splitlines() + n = max(0, tail_lines) + text = "\n".join(lines[-n:]) + elif start_line is not None and line_count is not None and line_count >= 0: + lines = text.splitlines() + s = max(0, start_line - 1) + e = min(len(lines), s + line_count) + text = "\n".join(lines[s:e]) + + sha = hashlib.sha256(text.encode("utf-8")).hexdigest() + return {"success": True, "data": {"text": text, "metadata": {"sha256": sha}}} + except Exception as e: + return {"success": False, "error": str(e)} + + @mcp.tool() + async def find_in_file( + ctx: Context, + uri: str, + pattern: str, + ignore_case: bool | None = True, + project_root: str | None = None, + max_results: int | None = 200, + ) -> Dict[str, Any]: + """ + Searches a file with a regex pattern and returns line numbers and excerpts. + - uri: unity://path/Assets/... or file path form supported by read_resource + - pattern: regular expression (Python re) + - ignore_case: case-insensitive by default + - max_results: cap results to avoid huge payloads + """ + import re + try: + project = _resolve_project_root(project_root) + p = _resolve_safe_path_from_uri(uri, project) + if not p or not p.exists() or not p.is_file(): + return {"success": False, "error": f"Resource not found: {uri}"} + + text = p.read_text(encoding="utf-8") + flags = re.MULTILINE + if ignore_case: + flags |= re.IGNORECASE + rx = re.compile(pattern, flags) + + results = [] + lines = text.splitlines() + for i, line in enumerate(lines, start=1): + if rx.search(line): + results.append({"line": i, "text": line}) + if max_results and len(results) >= max_results: + break + + return {"success": True, "data": {"matches": results, "count": len(results)}} + except Exception as e: + return {"success": False, "error": str(e)} + + From fcd0ec4e8b9ff6285b9ae3eaa0ca9c5da5627323 Mon Sep 17 00:00:00 2001 From: dsarno Date: Sun, 17 Aug 2025 20:44:50 -0700 Subject: [PATCH 24/25] Handle FRAMING=1 in port discovery probe --- UnityMcpBridge/Editor/Tools/ManageScript.cs | 9 +++++- .../UnityMcpServer~/src/port_discovery.py | 29 +++++++++++++++++-- 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/UnityMcpBridge/Editor/Tools/ManageScript.cs b/UnityMcpBridge/Editor/Tools/ManageScript.cs index 1fcf1e13..d92e6cb6 100644 --- a/UnityMcpBridge/Editor/Tools/ManageScript.cs +++ b/UnityMcpBridge/Editor/Tools/ManageScript.cs @@ -557,13 +557,20 @@ private static object ApplyTextEdits( // If the edit overlaps the method span significantly, treat as replace_method if (sp.start <= mStart + 2 && sp.end >= mStart + 1) { + var methodOriginal = original.Substring(mStart, mLen); + int relStart = Math.Max(0, Math.Min(sp.start - mStart, methodOriginal.Length)); + int relEnd = Math.Max(relStart, Math.Min(sp.end - mStart, methodOriginal.Length)); + string replacementSnippet = methodOriginal + .Remove(relStart, relEnd - relStart) + .Insert(relStart, sp.text ?? string.Empty); + var structEdits = new JArray(); var op = new JObject { ["mode"] = "replace_method", ["className"] = name, ["methodName"] = methodName, - ["replacement"] = original.Remove(sp.start, sp.end - sp.start).Insert(sp.start, sp.text ?? string.Empty).Substring(mStart, (sp.text ?? string.Empty).Length + (sp.start - mStart) + (mLen - (sp.end - mStart))) + ["replacement"] = replacementSnippet }; structEdits.Add(op); // Reuse structured path diff --git a/UnityMcpBridge/UnityMcpServer~/src/port_discovery.py b/UnityMcpBridge/UnityMcpServer~/src/port_discovery.py index 98855333..070bde4d 100644 --- a/UnityMcpBridge/UnityMcpServer~/src/port_discovery.py +++ b/UnityMcpBridge/UnityMcpServer~/src/port_discovery.py @@ -18,6 +18,7 @@ from typing import Optional, List import glob import socket +import struct logger = logging.getLogger("unity-mcp-server") @@ -56,14 +57,36 @@ def list_candidate_files() -> List[Path]: @staticmethod def _try_probe_unity_mcp(port: int) -> bool: """Quickly check if a Unity MCP listener is on this port. - Tries a short TCP connect, sends 'ping', expects a JSON 'pong'. + Performs a short TCP connect and ping/pong exchange. If the + server advertises ``FRAMING=1`` in its greeting, the ping and + pong are sent/received with an 8-byte big-endian length prefix. """ + + def _read_exact(sock: socket.socket, count: int) -> bytes: + buf = bytearray() + while len(buf) < count: + chunk = sock.recv(count - len(buf)) + if not chunk: + raise ConnectionError("Connection closed before reading expected bytes") + buf.extend(chunk) + return bytes(buf) + try: with socket.create_connection(("127.0.0.1", port), PortDiscovery.CONNECT_TIMEOUT) as s: s.settimeout(PortDiscovery.CONNECT_TIMEOUT) try: - s.sendall(b"ping") - data = s.recv(512) + greeting = s.recv(256) + text = greeting.decode('ascii', errors='ignore') if greeting else '' + payload = b"ping" + if 'FRAMING=1' in text: + header = struct.pack('>Q', len(payload)) + s.sendall(header + payload) + resp_header = _read_exact(s, 8) + resp_len = struct.unpack('>Q', resp_header)[0] + data = _read_exact(s, resp_len) + else: + s.sendall(payload) + data = s.recv(512) # Minimal validation: look for a success pong response if data and b'"message":"pong"' in data: return True From f48c4a064e03a06055888b82b152931d54491e6e Mon Sep 17 00:00:00 2001 From: dsarno Date: Sun, 17 Aug 2025 21:50:34 -0700 Subject: [PATCH 25/25] refactor port probing and script tooling --- UnityMcpBridge/Editor/Tools/ManageScript.cs | 85 +++++++++++++------ .../UnityMcpServer~/src/port_discovery.py | 47 +++++++--- 2 files changed, 96 insertions(+), 36 deletions(-) diff --git a/UnityMcpBridge/Editor/Tools/ManageScript.cs b/UnityMcpBridge/Editor/Tools/ManageScript.cs index d92e6cb6..f22d4af9 100644 --- a/UnityMcpBridge/Editor/Tools/ManageScript.cs +++ b/UnityMcpBridge/Editor/Tools/ManageScript.cs @@ -90,6 +90,14 @@ private static bool TryResolveUnderAssets(string relDir, out string fullPathDir, relPathSafe = null; return false; } +#if NET6_0_OR_GREATER + if (!string.IsNullOrEmpty(di.LinkTarget)) + { + fullPathDir = null; + relPathSafe = null; + return false; + } +#endif } } catch { /* best effort; proceed */ } @@ -525,7 +533,11 @@ private static object ApplyTextEdits( int headerBoundary = 0; if (original.Length > 0 && original[0] == '\uFEFF') headerBoundary = 1; // skip BOM // Find first top-level using (very simple scan of start of file) - var mUsing = System.Text.RegularExpressions.Regex.Match(original, @"(?m)^(?:\uFEFF)?using\s+\w+", System.Text.RegularExpressions.RegexOptions.None); + var mUsing = System.Text.RegularExpressions.Regex.Match( + original, + @"(?m)^(?:\uFEFF)?(?:global\s+)?using(?:\s+static)?\b", + System.Text.RegularExpressions.RegexOptions.None + ); if (mUsing.Success) headerBoundary = Math.Min(Math.Max(headerBoundary, mUsing.Index), original.Length); foreach (var sp in spans) @@ -550,32 +562,34 @@ private static object ApplyTextEdits( { string methodName = mh.Groups[1].Value; // Find class span containing the edit - if (TryComputeClassSpan(original, name, null, out var clsStart, out var clsLen, out _)) + if (!TryComputeClassSpan(original, name, null, out var clsStart, out var clsLen, out _)) + { + FindEnclosingClassSpan(original, sp.start, out clsStart, out clsLen); + } + if (clsLen > 0 && + TryComputeMethodSpan(original, clsStart, clsLen, methodName, null, null, null, out var mStart, out var mLen, out _)) { - if (TryComputeMethodSpan(original, clsStart, clsLen, methodName, null, null, null, out var mStart, out var mLen, out _)) + // If the edit overlaps the method span significantly, treat as replace_method + if (sp.start <= mStart + 2 && sp.end >= mStart + 1) { - // If the edit overlaps the method span significantly, treat as replace_method - if (sp.start <= mStart + 2 && sp.end >= mStart + 1) + var methodOriginal = original.Substring(mStart, mLen); + int relStart = Math.Max(0, Math.Min(sp.start - mStart, methodOriginal.Length)); + int relEnd = Math.Max(relStart, Math.Min(sp.end - mStart, methodOriginal.Length)); + string replacementSnippet = methodOriginal + .Remove(relStart, relEnd - relStart) + .Insert(relStart, sp.text ?? string.Empty); + + var structEdits = new JArray(); + var op = new JObject { - var methodOriginal = original.Substring(mStart, mLen); - int relStart = Math.Max(0, Math.Min(sp.start - mStart, methodOriginal.Length)); - int relEnd = Math.Max(relStart, Math.Min(sp.end - mStart, methodOriginal.Length)); - string replacementSnippet = methodOriginal - .Remove(relStart, relEnd - relStart) - .Insert(relStart, sp.text ?? string.Empty); - - var structEdits = new JArray(); - var op = new JObject - { - ["mode"] = "replace_method", - ["className"] = name, - ["methodName"] = methodName, - ["replacement"] = replacementSnippet - }; - structEdits.Add(op); - // Reuse structured path - return EditScript(fullPath, relativePath, name, structEdits, new JObject{ ["refresh"] = "immediate", ["validate"] = "standard" }); - } + ["mode"] = "replace_method", + ["className"] = name, + ["methodName"] = methodName, + ["replacement"] = replacementSnippet + }; + structEdits.Add(op); + // Reuse structured path + return EditScript(fullPath, relativePath, name, structEdits, new JObject{ ["refresh"] = "immediate", ["validate"] = "standard" }); } } } @@ -737,7 +751,16 @@ private static bool CheckBalancedDelimiters(string text, out int line, out char char c = text[i]; char next = i + 1 < text.Length ? text[i + 1] : '\0'; - if (c == '\n') { line++; if (inSingle) inSingle = false; } + if (c == '\r') + { + // Treat CRLF as a single newline; skip the LF if present + if (next == '\n') { i++; } + line++; if (inSingle) inSingle = false; + } + else if (c == '\n') + { + line++; if (inSingle) inSingle = false; + } if (escape) { escape = false; continue; } @@ -1205,6 +1228,18 @@ private static bool ValidateClassSnippet(string snippet, string expectedName, ou #endif } + private static bool FindEnclosingClassSpan(string source, int index, out int start, out int length) + { + start = length = 0; + if (index < 0 || index > source.Length) return false; + var prefix = source.Substring(0, Math.Min(index, source.Length)); + var matches = Regex.Matches(prefix, @"(?m)\bclass\s+([A-Za-z_][A-Za-z0-9_]*)"); + if (matches.Count == 0) return false; + var m = matches[matches.Count - 1]; + var className = m.Groups[1].Value; + return TryComputeClassSpanBalanced(source, className, null, out start, out length, out _); + } + private static bool TryComputeClassSpan(string source, string className, string ns, out int start, out int length, out string why) { #if USE_ROSLYN diff --git a/UnityMcpBridge/UnityMcpServer~/src/port_discovery.py b/UnityMcpBridge/UnityMcpServer~/src/port_discovery.py index 070bde4d..828dd956 100644 --- a/UnityMcpBridge/UnityMcpServer~/src/port_discovery.py +++ b/UnityMcpBridge/UnityMcpServer~/src/port_discovery.py @@ -22,6 +22,21 @@ logger = logging.getLogger("unity-mcp-server") +FRAME_HEADER_SIZE = 8 +# Keep small; we're only looking for a tiny pong. 1 MiB is generous for probes. +MAX_FRAME_SIZE = 1 << 20 + + +# Module-level helper to avoid duplication and per-call redefinition +def _read_exact(sock: socket.socket, count: int) -> bytes: + buf = bytearray() + while len(buf) < count: + chunk = sock.recv(count - len(buf)) + if not chunk: + raise ConnectionError("Connection closed before reading expected bytes") + buf.extend(chunk) + return bytes(buf) + class PortDiscovery: """Handles port discovery from Unity Bridge registry""" REGISTRY_FILE = "unity-mcp-port.json" # legacy single-project file @@ -62,15 +77,6 @@ def _try_probe_unity_mcp(port: int) -> bool: pong are sent/received with an 8-byte big-endian length prefix. """ - def _read_exact(sock: socket.socket, count: int) -> bytes: - buf = bytearray() - while len(buf) < count: - chunk = sock.recv(count - len(buf)) - if not chunk: - raise ConnectionError("Connection closed before reading expected bytes") - buf.extend(chunk) - return bytes(buf) - try: with socket.create_connection(("127.0.0.1", port), PortDiscovery.CONNECT_TIMEOUT) as s: s.settimeout(PortDiscovery.CONNECT_TIMEOUT) @@ -81,12 +87,31 @@ def _read_exact(sock: socket.socket, count: int) -> bytes: if 'FRAMING=1' in text: header = struct.pack('>Q', len(payload)) s.sendall(header + payload) - resp_header = _read_exact(s, 8) + resp_header = _read_exact(s, FRAME_HEADER_SIZE) resp_len = struct.unpack('>Q', resp_header)[0] + # Defensive cap against unreasonable frame sizes + if resp_len > MAX_FRAME_SIZE: + return False data = _read_exact(s, resp_len) else: s.sendall(payload) - data = s.recv(512) + # Read a small bounded amount looking for pong + chunks = [] + total = 0 + data = b"" + while total < 1024: + try: + part = s.recv(512) + except socket.timeout: + break + if not part: + break + chunks.append(part) + total += len(part) + if b'"message":"pong"' in part: + break + if chunks: + data = b"".join(chunks) # Minimal validation: look for a success pong response if data and b'"message":"pong"' in data: return True