diff --git a/invokeai/frontend/web/public/locales/en.json b/invokeai/frontend/web/public/locales/en.json index be5ecd9a9cf..a50b0cb9efd 100644 --- a/invokeai/frontend/web/public/locales/en.json +++ b/invokeai/frontend/web/public/locales/en.json @@ -501,6 +501,14 @@ "title": "Next Prompt in History", "desc": "When the prompt is focused, move to the next (newer) prompt in your history." }, + "promptWeightUp": { + "title": "Increase Weight of Prompt Selection", + "desc": "When the prompt is focused and text is selected, increase the weight of the selected prompt." + }, + "promptWeightDown": { + "title": "Decrease Weight of Prompt Selection", + "desc": "When the prompt is focused and text is selected, decrease the weight of the selected prompt." + }, "toggleLeftPanel": { "title": "Toggle Left Panel", "desc": "Show or hide the left panel." diff --git a/invokeai/frontend/web/src/common/util/promptAST.test.ts b/invokeai/frontend/web/src/common/util/promptAST.test.ts new file mode 100644 index 00000000000..25786d417af --- /dev/null +++ b/invokeai/frontend/web/src/common/util/promptAST.test.ts @@ -0,0 +1,267 @@ +import { describe, expect, it } from 'vitest'; + +import { parseTokens, serialize, tokenize } from './promptAST'; + +describe('promptAST', () => { + describe('tokenize', () => { + it('should tokenize basic text', () => { + const tokens = tokenize('a cat'); + expect(tokens).toEqual([ + { type: 'word', value: 'a' }, + { type: 'whitespace', value: ' ' }, + { type: 'word', value: 'cat' }, + ]); + }); + + it('should tokenize groups with parentheses', () => { + const tokens = tokenize('(a cat)'); + expect(tokens).toEqual([ + { type: 'lparen' }, + { type: 'word', value: 'a' }, + { type: 'whitespace', value: ' ' }, + { type: 'word', value: 'cat' }, + { type: 'rparen' }, + ]); + }); + + it('should tokenize escaped parentheses', () => { + const tokens = tokenize('\\(medium\\)'); + expect(tokens).toEqual([ + { type: 'escaped_paren', value: '(' }, + { type: 'word', value: 'medium' }, + { type: 'escaped_paren', value: ')' }, + ]); + }); + + it('should tokenize mixed escaped and unescaped parentheses', () => { + const tokens = tokenize('colored pencil \\(medium\\) (enhanced)'); + expect(tokens).toEqual([ + { type: 'word', value: 'colored' }, + { type: 'whitespace', value: ' ' }, + { type: 'word', value: 'pencil' }, + { type: 'whitespace', value: ' ' }, + { type: 'escaped_paren', value: '(' }, + { type: 'word', value: 'medium' }, + { type: 'escaped_paren', value: ')' }, + { type: 'whitespace', value: ' ' }, + { type: 'lparen' }, + { type: 'word', value: 'enhanced' }, + { type: 'rparen' }, + ]); + }); + + it('should tokenize groups with weights', () => { + const tokens = tokenize('(a cat)1.2'); + expect(tokens).toEqual([ + { type: 'lparen' }, + { type: 'word', value: 'a' }, + { type: 'whitespace', value: ' ' }, + { type: 'word', value: 'cat' }, + { type: 'rparen' }, + { type: 'weight', value: 1.2 }, + ]); + }); + + it('should tokenize words with weights', () => { + const tokens = tokenize('cat+'); + expect(tokens).toEqual([ + { type: 'word', value: 'cat' }, + { type: 'weight', value: '+' }, + ]); + }); + + it('should tokenize embeddings', () => { + const tokens = tokenize(''); + expect(tokens).toEqual([{ type: 'lembed' }, { type: 'word', value: 'embedding_name' }, { type: 'rembed' }]); + }); + }); + + describe('parseTokens', () => { + it('should parse basic text', () => { + const tokens = tokenize('a cat'); + const ast = parseTokens(tokens); + expect(ast).toEqual([ + { type: 'word', text: 'a' }, + { type: 'whitespace', value: ' ' }, + { type: 'word', text: 'cat' }, + ]); + }); + + it('should parse groups', () => { + const tokens = tokenize('(a cat)'); + const ast = parseTokens(tokens); + expect(ast).toEqual([ + { + type: 'group', + children: [ + { type: 'word', text: 'a' }, + { type: 'whitespace', value: ' ' }, + { type: 'word', text: 'cat' }, + ], + }, + ]); + }); + + it('should parse escaped parentheses', () => { + const tokens = tokenize('\\(medium\\)'); + const ast = parseTokens(tokens); + expect(ast).toEqual([ + { type: 'escaped_paren', value: '(' }, + { type: 'word', text: 'medium' }, + { type: 'escaped_paren', value: ')' }, + ]); + }); + + it('should parse mixed escaped and unescaped parentheses', () => { + const tokens = tokenize('colored pencil \\(medium\\) (enhanced)'); + const ast = parseTokens(tokens); + expect(ast).toEqual([ + { type: 'word', text: 'colored' }, + { type: 'whitespace', value: ' ' }, + { type: 'word', text: 'pencil' }, + { type: 'whitespace', value: ' ' }, + { type: 'escaped_paren', value: '(' }, + { type: 'word', text: 'medium' }, + { type: 'escaped_paren', value: ')' }, + { type: 'whitespace', value: ' ' }, + { + type: 'group', + children: [{ type: 'word', text: 'enhanced' }], + }, + ]); + }); + + it('should parse groups with attention', () => { + const tokens = tokenize('(a cat)1.2'); + const ast = parseTokens(tokens); + expect(ast).toEqual([ + { + type: 'group', + attention: 1.2, + children: [ + { type: 'word', text: 'a' }, + { type: 'whitespace', value: ' ' }, + { type: 'word', text: 'cat' }, + ], + }, + ]); + }); + + it('should parse words with attention', () => { + const tokens = tokenize('cat+'); + const ast = parseTokens(tokens); + expect(ast).toEqual([{ type: 'word', text: 'cat', attention: '+' }]); + }); + + it('should parse embeddings', () => { + const tokens = tokenize(''); + const ast = parseTokens(tokens); + expect(ast).toEqual([{ type: 'embedding', value: 'embedding_name' }]); + }); + }); + + describe('serialize', () => { + it('should serialize basic text', () => { + const tokens = tokenize('a cat'); + const ast = parseTokens(tokens); + const result = serialize(ast); + expect(result).toBe('a cat'); + }); + + it('should serialize groups', () => { + const tokens = tokenize('(a cat)'); + const ast = parseTokens(tokens); + const result = serialize(ast); + expect(result).toBe('(a cat)'); + }); + + it('should serialize escaped parentheses', () => { + const tokens = tokenize('\\(medium\\)'); + const ast = parseTokens(tokens); + const result = serialize(ast); + expect(result).toBe('\\(medium\\)'); + }); + + it('should serialize mixed escaped and unescaped parentheses', () => { + const tokens = tokenize('colored pencil \\(medium\\) (enhanced)'); + const ast = parseTokens(tokens); + const result = serialize(ast); + expect(result).toBe('colored pencil \\(medium\\) (enhanced)'); + }); + + it('should serialize groups with attention', () => { + const tokens = tokenize('(a cat)1.2'); + const ast = parseTokens(tokens); + const result = serialize(ast); + expect(result).toBe('(a cat)1.2'); + }); + + it('should serialize words with attention', () => { + const tokens = tokenize('cat+'); + const ast = parseTokens(tokens); + const result = serialize(ast); + expect(result).toBe('cat+'); + }); + + it('should serialize embeddings', () => { + const tokens = tokenize(''); + const ast = parseTokens(tokens); + const result = serialize(ast); + expect(result).toBe(''); + }); + }); + + describe('compel compatibility examples', () => { + it('should handle escaped parentheses for literal text', () => { + const prompt = 'A bear \\(with razor-sharp teeth\\) in a forest.'; + const tokens = tokenize(prompt); + const ast = parseTokens(tokens); + const result = serialize(ast); + expect(result).toBe(prompt); + }); + + it('should handle unescaped parentheses as grouping syntax', () => { + const prompt = 'A bear (with razor-sharp teeth) in a forest.'; + const tokens = tokenize(prompt); + const ast = parseTokens(tokens); + const result = serialize(ast); + expect(result).toBe(prompt); + }); + + it('should handle colored pencil medium example', () => { + const prompt = 'colored pencil \\(medium\\)'; + const tokens = tokenize(prompt); + const ast = parseTokens(tokens); + const result = serialize(ast); + expect(result).toBe(prompt); + }); + + it('should distinguish between escaped and unescaped in same prompt', () => { + const prompt = 'portrait \\(realistic\\) (high quality)1.2'; + const tokens = tokenize(prompt); + const ast = parseTokens(tokens); + + // Should have escaped parens as nodes and a group with attention + expect(ast).toEqual([ + { type: 'word', text: 'portrait' }, + { type: 'whitespace', value: ' ' }, + { type: 'escaped_paren', value: '(' }, + { type: 'word', text: 'realistic' }, + { type: 'escaped_paren', value: ')' }, + { type: 'whitespace', value: ' ' }, + { + type: 'group', + attention: 1.2, + children: [ + { type: 'word', text: 'high' }, + { type: 'whitespace', value: ' ' }, + { type: 'word', text: 'quality' }, + ], + }, + ]); + + const result = serialize(ast); + expect(result).toBe(prompt); + }); + }); +}); diff --git a/invokeai/frontend/web/src/common/util/promptAST.ts b/invokeai/frontend/web/src/common/util/promptAST.ts new file mode 100644 index 00000000000..ab9df32e064 --- /dev/null +++ b/invokeai/frontend/web/src/common/util/promptAST.ts @@ -0,0 +1,356 @@ +/** + * Expected as either '+', '-', '++', '--', etc. or a numeric string like '1.2', '0.8', etc. + */ +export type Attention = string | number; + +type Word = string; + +type Punct = string; + +type Whitespace = string; + +type Embedding = string; + +export type Token = + | { type: 'word'; value: Word } + | { type: 'whitespace'; value: Whitespace } + | { type: 'punct'; value: Punct } + | { type: 'lparen' } + | { type: 'rparen' } + | { type: 'weight'; value: Attention } + | { type: 'lembed' } + | { type: 'rembed' } + | { type: 'escaped_paren'; value: '(' | ')' }; + +export type ASTNode = + | { type: 'word'; text: Word; attention?: Attention } + | { type: 'group'; children: ASTNode[]; attention?: Attention } + | { type: 'embedding'; value: Embedding } + | { type: 'whitespace'; value: Whitespace } + | { type: 'punct'; value: Punct } + | { type: 'escaped_paren'; value: '(' | ')' }; + +const WEIGHT_PATTERN = /^[+-]?(\d+(\.\d+)?|[+-]+)/; +const WHITESPACE_PATTERN = /^\s+/; +const PUNCTUATION_PATTERN = /^[.,]/; +const OTHER_PATTERN = /\s/; + +/** + * Convert a prompt string into an AST. + * @param prompt string + * @returns ASTNode[] + */ +export function tokenize(prompt: string): Token[] { + if (!prompt) { + return []; + } + + const len = prompt.length; + let tokens: Token[] = []; + let i = 0; + + while (i < len) { + const char = prompt[i]; + if (!char) { + break; + } + + const result = + tokenizeWhitespace(char, i) || + tokenizeEscapedParen(prompt, i) || + tokenizeLeftParen(char, i) || + tokenizeRightParen(prompt, i) || + tokenizeEmbedding(char, i) || + tokenizeWord(prompt, i) || + tokenizePunctuation(char, i) || + tokenizeOther(char, i); + + if (result) { + if (result.token) { + tokens.push(result.token); + } + if (result.extraToken) { + tokens.push(result.extraToken); + } + i = result.nextIndex; + } else { + i++; + } + } + + return tokens; +} + +type TokenizeResult = { + token?: Token; + extraToken?: Token; + nextIndex: number; +} | null; + +function tokenizeWhitespace(char: string, i: number): TokenizeResult { + if (WHITESPACE_PATTERN.test(char)) { + return { + token: { type: 'whitespace', value: char }, + nextIndex: i + 1, + }; + } + return null; +} + +function tokenizeEscapedParen(prompt: string, i: number): TokenizeResult { + const char = prompt[i]; + if (char === '\\' && i + 1 < prompt.length) { + const nextChar = prompt[i + 1]; + if (nextChar === '(' || nextChar === ')') { + return { + token: { type: 'escaped_paren', value: nextChar }, + nextIndex: i + 2, + }; + } + } + return null; +} + +function tokenizeLeftParen(char: string, i: number): TokenizeResult { + if (char === '(') { + return { + token: { type: 'lparen' }, + nextIndex: i + 1, + }; + } + return null; +} + +function tokenizeRightParen(prompt: string, i: number): TokenizeResult { + const char = prompt[i]; + if (char === ')') { + // Look ahead for weight like ')1.1' or ')-0.9' or ')+' or ')-' + const weightMatch = prompt.slice(i + 1).match(WEIGHT_PATTERN); + if (weightMatch && weightMatch[0]) { + let weight: Attention = weightMatch[0]; + if (!isNaN(Number(weight))) { + weight = Number(weight); + } + return { + token: { type: 'rparen' }, + extraToken: { type: 'weight', value: weight }, + nextIndex: i + 1 + weightMatch[0].length, + }; + } + return { + token: { type: 'rparen' }, + nextIndex: i + 1, + }; + } + return null; +} + +function tokenizePunctuation(char: string, i: number): TokenizeResult { + if (PUNCTUATION_PATTERN.test(char)) { + return { + token: { type: 'punct', value: char }, + nextIndex: i + 1, + }; + } + return null; +} + +function tokenizeWord(prompt: string, i: number): TokenizeResult { + const char = prompt[i]; + if (!char) { + return null; + } + + if (/[a-zA-Z0-9_]/.test(char)) { + let j = i; + while (j < prompt.length && /[a-zA-Z0-9_]/.test(prompt[j]!)) { + j++; + } + const word = prompt.slice(i, j); + + // Check for weight immediately after word (e.g., "Lorem+", "consectetur-") + const weightMatch = prompt.slice(j).match(/^[+-]?(\d+(\.\d+)?|[+-]+)/); + if (weightMatch && weightMatch[0]) { + return { + token: { type: 'word', value: word }, + extraToken: { type: 'weight', value: weightMatch[0] }, + nextIndex: j + weightMatch[0].length, + }; + } + + return { + token: { type: 'word', value: word }, + nextIndex: j, + }; + } + return null; +} + +function tokenizeEmbedding(char: string, i: number): TokenizeResult { + if (char === '<') { + return { + token: { type: 'lembed' }, + nextIndex: i + 1, + }; + } + if (char === '>') { + return { + token: { type: 'rembed' }, + nextIndex: i + 1, + }; + } + return null; +} + +function tokenizeOther(char: string, i: number): TokenizeResult { + // Any other single character punctuation + if (OTHER_PATTERN.test(char)) { + return { + token: { type: 'punct', value: char }, + nextIndex: i + 1, + }; + } + return null; +} + +/** + * Convert tokens into an AST. + * @param tokens Token[] + * @returns ASTNode[] + */ +export function parseTokens(tokens: Token[]): ASTNode[] { + let pos = 0; + + function peek(): Token | undefined { + return tokens[pos]; + } + + function consume(): Token | undefined { + return tokens[pos++]; + } + + function parseGroup(): ASTNode[] { + const nodes: ASTNode[] = []; + + while (pos < tokens.length) { + const token = peek(); + if (!token || token.type === 'rparen') { + break; + } + // console.log('Parsing token:', token); + + switch (token.type) { + case 'whitespace': { + const wsToken = consume() as Token & { type: 'whitespace' }; + nodes.push({ type: 'whitespace', value: wsToken.value }); + break; + } + case 'lparen': { + consume(); + const groupChildren = parseGroup(); + + let attention: Attention | undefined; + if (peek()?.type === 'rparen') { + consume(); // consume ')' + if (peek()?.type === 'weight') { + attention = (consume() as Token & { type: 'weight' }).value; + } + } + + nodes.push({ type: 'group', children: groupChildren, attention }); + break; + } + case 'lembed': { + consume(); // consume '<' + let embedValue = ''; + while (peek() && peek()!.type !== 'rembed') { + const embedToken = consume()!; + embedValue += + embedToken.type === 'word' || embedToken.type === 'punct' || embedToken.type === 'whitespace' + ? embedToken.value + : ''; + } + if (peek()?.type === 'rembed') { + consume(); // consume '>' + } + nodes.push({ type: 'embedding', value: embedValue.trim() }); + break; + } + case 'word': { + const wordToken = consume() as Token & { type: 'word' }; + let attention: Attention | undefined; + + // Check for immediate weight after word + if (peek()?.type === 'weight') { + attention = (consume() as Token & { type: 'weight' }).value; + } + + nodes.push({ type: 'word', text: wordToken.value, attention }); + break; + } + case 'punct': { + const punctToken = consume() as Token & { type: 'punct' }; + nodes.push({ type: 'punct', value: punctToken.value }); + break; + } + case 'escaped_paren': { + const escapedToken = consume() as Token & { type: 'escaped_paren' }; + nodes.push({ type: 'escaped_paren', value: escapedToken.value }); + break; + } + default: { + consume(); + } + } + } + + return nodes; + } + + return parseGroup(); +} + +/** + * Convert an AST back into a prompt string. + * @param ast ASTNode[] + * @returns string + */ +export function serialize(ast: ASTNode[]): string { + let prompt = ''; + + for (const node of ast) { + switch (node.type) { + case 'punct': + case 'whitespace': { + prompt += node.value; + break; + } + case 'escaped_paren': { + prompt += `\\${node.value}`; + break; + } + case 'word': { + prompt += node.text; + if (node.attention) { + prompt += String(node.attention); + } + break; + } + case 'group': { + prompt += '('; + prompt += serialize(node.children); + prompt += ')'; + if (node.attention) { + prompt += String(node.attention); + } + break; + } + case 'embedding': { + prompt += `<${node.value}>`; + break; + } + } + } + + return prompt; +} diff --git a/invokeai/frontend/web/src/common/util/promptAttention.ts b/invokeai/frontend/web/src/common/util/promptAttention.ts new file mode 100644 index 00000000000..2d5808d28e2 --- /dev/null +++ b/invokeai/frontend/web/src/common/util/promptAttention.ts @@ -0,0 +1,565 @@ +import { logger } from 'app/logging/logger'; +import { serializeError } from 'serialize-error'; + +import { type ASTNode, type Attention, parseTokens, serialize, tokenize } from './promptAST'; + +const log = logger('events'); + +/** + * Behavior Rules: + * + * ATTENTION SYNTAX: + * - Words support +/- attention: `word+`, `word-`, `word++`, `word--`, etc. + * - Groups support both +/- and numeric: `(words)+`, `(words)1.2`, `(words)0.8` + * - `word++` is roughly equivalent to `(word)1.2` in effect + * - Mixed attention like `word+-` or numeric on words `word1.2` is invalid + * + * ADJUSTMENT RULES: + * - `word++` down → `word+` + * - `word+` down → `word` (attention removed) + * - `word` down → `word-` + * - `word-` down → `word--` + * - `(content)1.2` down → `(content)1.1` + * - `(content)1.1` down → `content` (group unwrapped) + * - `content content` down → `(content content)0.9` (new group created) + * + * SELECTION BEHAVIOR: + * - Cursor/selection within a word → expand to full word, adjust word attention + * - Cursor touching group boundary (parens or weight) → adjust group attention + * - Selection entirely within a group → adjust that group's attention + * - Selection spans multiple content nodes → create new group with initial attention + * - Whitespace and punctuation are ignored for content selection + */ + +export type AttentionDirection = 'up' | 'down'; + +type SelectionBounds = { + start: number; + end: number; +}; + +export type AdjustmentResult = { + prompt: string; + selectionStart: number; + selectionEnd: number; +}; + +/** + * A node with its position in the serialized prompt string. + */ +type PositionedNode = { + node: ASTNode; + start: number; + end: number; + /** Position of content start (after opening paren for groups) */ + contentStart: number; + /** Position of content end (before closing paren for groups) */ + contentEnd: number; + parent: PositionedNode | null; +}; + +// ============================================================================ +// ATTENTION COMPUTATION +// ============================================================================ + +/** + * Checks if attention is symbol-based (+, -, ++, etc.) + */ +function isSymbolAttention(attention: Attention | undefined): attention is string { + return typeof attention === 'string' && /^[+-]+$/.test(attention); +} + +/** + * Checks if attention is numeric (1.2, 0.8, etc.) + */ +function isNumericAttention(attention: Attention | undefined): attention is number { + return typeof attention === 'number'; +} + +/** + * Computes adjusted attention for symbol-based attention (+/-). + * Used for both words and groups with symbol attention. + */ +function adjustSymbolAttention(direction: AttentionDirection, attention: string | undefined): string | undefined { + if (!attention) { + return direction === 'up' ? '+' : '-'; + } + + if (direction === 'up') { + // Going up: remove '-' if present, otherwise add '+' + if (attention.endsWith('-')) { + const result = attention.slice(0, -1); + return result || undefined; + } + return `${attention}+`; + } else { + // Going down: remove '+' if present, otherwise add '-' + if (attention.endsWith('+')) { + const result = attention.slice(0, -1); + return result || undefined; + } + return `${attention}-`; + } +} + +/** + * Computes adjusted attention for numeric attention. + * Only used for groups. + */ +function adjustNumericAttention(direction: AttentionDirection, attention: number): number | undefined { + const step = direction === 'up' ? 0.1 : -0.1; + const result = parseFloat((attention + step).toFixed(1)); + + // 1.0 is default - return undefined to signal unwrapping + if (result === 1.0) { + return undefined; + } + + return result; +} + +/** + * Computes the new attention value based on direction and current attention. + * Returns undefined if attention should be removed (normalize to default). + */ +function computeAttention( + direction: AttentionDirection, + attention: Attention | undefined, + isGroup: boolean +): Attention | undefined { + // No current attention + if (attention === undefined) { + if (isGroup) { + // Groups going down from neutral get 0.9 + return direction === 'up' ? '+' : 0.9; + } + return direction === 'up' ? '+' : '-'; + } + + // Symbol attention (+, -, ++, etc.) + if (isSymbolAttention(attention)) { + return adjustSymbolAttention(direction, attention); + } + + // Numeric attention (only valid for groups) + if (isNumericAttention(attention)) { + return adjustNumericAttention(direction, attention); + } + + // Parse string numbers + const numValue = parseFloat(String(attention)); + if (!isNaN(numValue)) { + return adjustNumericAttention(direction, numValue); + } + + // Fallback: treat as no attention + return direction === 'up' ? '+' : '-'; +} + +// ============================================================================ +// POSITION MAPPING +// ============================================================================ + +/** + * Builds a flat map of all nodes with their positions in the prompt string. + * Groups include both their full bounds and content bounds. + */ +function buildPositionMap( + ast: ASTNode[], + startPos = 0, + parent: PositionedNode | null = null +): { positions: PositionedNode[]; endPos: number } { + const positions: PositionedNode[] = []; + let currentPos = startPos; + + for (const node of ast) { + const nodeStart = currentPos; + let contentStart = currentPos; + let contentEnd = currentPos; + let nodeEnd = currentPos; + + switch (node.type) { + case 'word': { + nodeEnd = currentPos + node.text.length; + if (node.attention !== undefined) { + nodeEnd += String(node.attention).length; + } + contentStart = currentPos; + contentEnd = currentPos + node.text.length; + currentPos = nodeEnd; + break; + } + + case 'whitespace': + case 'punct': { + nodeEnd = currentPos + node.value.length; + contentStart = nodeStart; + contentEnd = nodeEnd; + currentPos = nodeEnd; + break; + } + + case 'escaped_paren': { + nodeEnd = currentPos + 2; // \( or \) + contentStart = nodeStart; + contentEnd = nodeEnd; + currentPos = nodeEnd; + break; + } + + case 'embedding': { + nodeEnd = currentPos + node.value.length + 2; // + contentStart = currentPos + 1; + contentEnd = currentPos + 1 + node.value.length; + currentPos = nodeEnd; + break; + } + + case 'group': { + // Opening paren + currentPos += 1; + contentStart = currentPos; + + // Create placeholder for parent reference + const groupNode: PositionedNode = { + node, + start: nodeStart, + end: nodeStart, // Will be updated + contentStart, + contentEnd: contentStart, // Will be updated + parent, + }; + + // Process children with this group as parent + const childResult = buildPositionMap(node.children, currentPos, groupNode); + positions.push(...childResult.positions); + currentPos = childResult.endPos; + + contentEnd = currentPos; + + // Closing paren + currentPos += 1; + + // Attention + if (node.attention !== undefined) { + currentPos += String(node.attention).length; + } + + nodeEnd = currentPos; + + // Update the group node with final positions + groupNode.end = nodeEnd; + groupNode.contentEnd = contentEnd; + + positions.push(groupNode); + continue; // Skip the push at the end + } + } + + positions.push({ + node, + start: nodeStart, + end: nodeEnd, + contentStart, + contentEnd, + parent, + }); + } + + return { positions, endPos: currentPos }; +} + +// ============================================================================ +// NODE FINDING +// ============================================================================ + +/** + * Finds the deepest group that fully contains the selection. + * Returns null if selection is not fully within any group. + */ +function findEnclosingGroup(positions: PositionedNode[], selection: SelectionBounds): PositionedNode | null { + const groups = positions + .filter((p) => p.node.type === 'group') + .filter((p) => selection.start >= p.start && selection.end <= p.end) + // Sort by size (smallest = deepest nesting) + .sort((a, b) => a.end - a.start - (b.end - b.start)); + + return groups[0] ?? null; +} + +/** + * Checks if the cursor/selection is at the boundary of a group + * (touching parentheses or weight). + */ +function isTouchingGroupBoundary(group: PositionedNode, selection: SelectionBounds): boolean { + const { start, end, contentStart, contentEnd } = group; + + // Touching or at opening paren + if (selection.start <= contentStart && selection.end <= contentStart) { + return true; + } + + // Touching or at closing paren/weight + if (selection.start >= contentEnd && selection.end >= contentEnd) { + return true; + } + + // Selection spans the entire group content + if (selection.start <= contentStart && selection.end >= contentEnd) { + return true; + } + + // Cursor is exactly at group start or end + if (selection.start === selection.end) { + if (selection.start === start || selection.start === end) { + return true; + } + } + + return false; +} + +/** + * Finds content nodes (words, groups, embeddings) that intersect with selection. + */ +function findContentNodes(positions: PositionedNode[], selection: SelectionBounds): PositionedNode[] { + return positions.filter((p) => { + // Only content nodes + if (p.node.type !== 'word' && p.node.type !== 'group' && p.node.type !== 'embedding') { + return false; + } + + // Check intersection + return !(selection.end <= p.start || selection.start >= p.end); + }); +} + +/** + * Finds the single word the cursor is within (not just touching). + */ +function findWordAtCursor(positions: PositionedNode[], selection: SelectionBounds): PositionedNode | null { + const words = positions.filter((p) => p.node.type === 'word' && selection.start >= p.start && selection.end <= p.end); + + return words[0] ?? null; +} + +// ============================================================================ +// AST MANIPULATION +// ============================================================================ + +/** + * Replaces a node in the AST with replacement node(s). + * Uses reference equality to find the target. + */ +function replaceNodeInAST(ast: ASTNode[], target: ASTNode, replacement: ASTNode | ASTNode[]): ASTNode[] { + const replacements = Array.isArray(replacement) ? replacement : [replacement]; + + return ast.flatMap((node) => { + if (node === target) { + return replacements; + } + + if (node.type === 'group') { + const newChildren = replaceNodeInAST(node.children, target, replacement); + // Only create new object if children changed + if (newChildren !== node.children) { + return [{ ...node, children: newChildren }]; + } + } + + return [node]; + }); +} + +// ============================================================================ +// MAIN ADJUSTMENT FUNCTION +// ============================================================================ + +/** + * Determines the adjustment strategy based on selection and AST structure. + */ +type AdjustmentStrategy = + | { type: 'adjust-word'; node: PositionedNode } + | { type: 'adjust-group'; node: PositionedNode } + | { type: 'create-group'; nodes: PositionedNode[] } + | { type: 'no-op' }; + +function determineStrategy(positions: PositionedNode[], selection: SelectionBounds): AdjustmentStrategy { + const contentNodes = findContentNodes(positions, selection); + + if (contentNodes.length === 0) { + return { type: 'no-op' }; + } + + // Check if we're in a group context first + const enclosingGroup = findEnclosingGroup(positions, selection); + + if (enclosingGroup) { + // If touching group boundary, adjust the group + if (isTouchingGroupBoundary(enclosingGroup, selection)) { + return { type: 'adjust-group', node: enclosingGroup }; + } + + // Check for single word within the group + const wordAtCursor = findWordAtCursor(positions, selection); + if (wordAtCursor) { + return { type: 'adjust-word', node: wordAtCursor }; + } + + // Selection spans content within group - adjust the group + return { type: 'adjust-group', node: enclosingGroup }; + } + + // No enclosing group - check for single word + const wordAtCursor = findWordAtCursor(positions, selection); + if (wordAtCursor) { + return { type: 'adjust-word', node: wordAtCursor }; + } + + // Single content node (could be word, embedding, or group) + if (contentNodes.length === 1) { + const node = contentNodes[0]!; + if (node.node.type === 'group') { + return { type: 'adjust-group', node }; + } + if (node.node.type === 'word') { + return { type: 'adjust-word', node }; + } + // Embeddings don't support attention adjustment - wrap in group + return { type: 'create-group', nodes: contentNodes }; + } + + // Multiple content nodes - create a new group + return { type: 'create-group', nodes: contentNodes }; +} + +/** + * Adjusts the attention of the prompt at the current cursor/selection position. + */ +export function adjustPromptAttention( + prompt: string, + selectionStart: number, + selectionEnd: number, + direction: AttentionDirection +): AdjustmentResult { + try { + // Handle empty prompt + if (!prompt.trim()) { + return { prompt, selectionStart, selectionEnd }; + } + + // Normalize selection + const selection: SelectionBounds = { + start: Math.min(selectionStart, selectionEnd), + end: Math.max(selectionStart, selectionEnd), + }; + + // Parse and build position map + const tokens = tokenize(prompt); + const ast = parseTokens(tokens); + const { positions } = buildPositionMap(ast); + + // Determine what to do + const strategy = determineStrategy(positions, selection); + + switch (strategy.type) { + case 'no-op': + return { prompt, selectionStart, selectionEnd }; + + case 'adjust-word': { + const wordPos = strategy.node; + const word = wordPos.node as ASTNode & { type: 'word' }; + const newAttention = computeAttention(direction, word.attention, false); + + const updatedWord: ASTNode = { + type: 'word', + text: word.text, + attention: newAttention, + }; + + const newAST = replaceNodeInAST(ast, word, updatedWord); + const newPrompt = serialize(newAST); + const newWordText = serialize([updatedWord]); + + return { + prompt: newPrompt, + selectionStart: wordPos.start, + selectionEnd: wordPos.start + newWordText.length, + }; + } + + case 'adjust-group': { + const groupPos = strategy.node; + const group = groupPos.node as ASTNode & { type: 'group' }; + const newAttention = computeAttention(direction, group.attention, true); + + // If attention becomes undefined (1.0), unwrap the group + if (newAttention === undefined) { + const newAST = replaceNodeInAST(ast, group, group.children); + const newPrompt = serialize(newAST); + const childrenText = serialize(group.children); + + return { + prompt: newPrompt, + selectionStart: groupPos.start, + selectionEnd: groupPos.start + childrenText.length, + }; + } + + const updatedGroup: ASTNode = { + type: 'group', + children: group.children, + attention: newAttention, + }; + + const newAST = replaceNodeInAST(ast, group, updatedGroup); + const newPrompt = serialize(newAST); + const newGroupText = serialize([updatedGroup]); + + return { + prompt: newPrompt, + selectionStart: groupPos.start, + selectionEnd: groupPos.start + newGroupText.length, + }; + } + + case 'create-group': { + const nodes = strategy.nodes; + const firstNode = nodes[0]!; + const lastNode = nodes[nodes.length - 1]!; + + // Get the text range to wrap + const wrapStart = firstNode.start; + const wrapEnd = lastNode.end; + + // Parse just the selected portion + const selectedText = prompt.substring(wrapStart, wrapEnd); + const selectedTokens = tokenize(selectedText); + const selectedAST = parseTokens(selectedTokens); + + // Create new group with appropriate attention + const newAttention = computeAttention(direction, undefined, true); + const newGroup: ASTNode = { + type: 'group', + children: selectedAST, + attention: newAttention, + }; + + // Reconstruct prompt + const before = prompt.substring(0, wrapStart); + const after = prompt.substring(wrapEnd); + const newGroupText = serialize([newGroup]); + const newPrompt = before + newGroupText + after; + + return { + prompt: newPrompt, + selectionStart: wrapStart, + selectionEnd: wrapStart + newGroupText.length, + }; + } + } + } catch (error) { + log.error({ error: serializeError(error as Error) }, 'Error adjusting prompt attention'); + return { prompt, selectionStart, selectionEnd }; + } +} diff --git a/invokeai/frontend/web/src/features/parameters/components/Core/ParamNegativePrompt.tsx b/invokeai/frontend/web/src/features/parameters/components/Core/ParamNegativePrompt.tsx index 1ba98fa774f..f845c0958b9 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Core/ParamNegativePrompt.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Core/ParamNegativePrompt.tsx @@ -63,7 +63,6 @@ export const ParamNegativePrompt = memo(() => { value={prompt} onChange={onChange} onKeyDown={onKeyDown} - fontSize="sm" variant="darkFilled" minH={28} borderTopWidth={24} // This prevents the prompt from being hidden behind the header @@ -71,6 +70,8 @@ export const ParamNegativePrompt = memo(() => { paddingInlineStart={3} paddingTop={0} paddingBottom={3} + fontFamily="mono" + fontSize="sm" /> diff --git a/invokeai/frontend/web/src/features/parameters/components/Core/ParamPositivePrompt.tsx b/invokeai/frontend/web/src/features/parameters/components/Core/ParamPositivePrompt.tsx index 73d22f0eac7..849945209e2 100644 --- a/invokeai/frontend/web/src/features/parameters/components/Core/ParamPositivePrompt.tsx +++ b/invokeai/frontend/web/src/features/parameters/components/Core/ParamPositivePrompt.tsx @@ -1,6 +1,7 @@ import { Box, Flex, Textarea } from '@invoke-ai/ui-library'; import { useAppDispatch, useAppSelector, useAppStore } from 'app/store/storeHooks'; import { usePersistedTextAreaSize } from 'common/hooks/usePersistedTextareaSize'; +import { adjustPromptAttention } from 'common/util/promptAttention'; import { positivePromptChanged, selectModelSupportsNegativePrompt, @@ -192,6 +193,58 @@ export const ParamPositivePrompt = memo(() => { dependencies: [promptHistoryApi.next, isPromptFocused], }); + // Adjust prompt attention up + useRegisteredHotkeys({ + id: 'promptWeightUp', + category: 'app', + callback: (e) => { + if (isPromptFocused() && textareaRef.current) { + e.preventDefault(); + const textarea = textareaRef.current; + const result = adjustPromptAttention(textarea.value, textarea.selectionStart, textarea.selectionEnd, 'up'); + + // Update the prompt + dispatch(positivePromptChanged(result.prompt)); + + // Update selection after React re-renders + setTimeout(() => { + if (textareaRef.current) { + textareaRef.current.setSelectionRange(result.selectionStart, result.selectionEnd); + textareaRef.current.focus(); + } + }, 0); + } + }, + options: { preventDefault: true, enableOnFormTags: ['TEXTAREA'] }, + dependencies: [dispatch, isPromptFocused], + }); + + // Adjust prompt attention down + useRegisteredHotkeys({ + id: 'promptWeightDown', + category: 'app', + callback: (e) => { + if (isPromptFocused() && textareaRef.current) { + e.preventDefault(); + const textarea = textareaRef.current; + const result = adjustPromptAttention(textarea.value, textarea.selectionStart, textarea.selectionEnd, 'down'); + + // Update the prompt + dispatch(positivePromptChanged(result.prompt)); + + // Update selection after React re-renders + setTimeout(() => { + if (textareaRef.current) { + textareaRef.current.setSelectionRange(result.selectionStart, result.selectionEnd); + textareaRef.current.focus(); + } + }, 0); + } + }, + options: { preventDefault: true, enableOnFormTags: ['TEXTAREA'] }, + dependencies: [dispatch, isPromptFocused], + }); + return ( @@ -211,6 +264,8 @@ export const ParamPositivePrompt = memo(() => { paddingBottom={3} resize="vertical" minH={32} + fontFamily="mono" + fontSize="sm" /> diff --git a/invokeai/frontend/web/src/features/system/components/HotkeysModal/useHotkeyData.ts b/invokeai/frontend/web/src/features/system/components/HotkeysModal/useHotkeyData.ts index cdd4fa099e2..c819e0ab2a9 100644 --- a/invokeai/frontend/web/src/features/system/components/HotkeysModal/useHotkeyData.ts +++ b/invokeai/frontend/web/src/features/system/components/HotkeysModal/useHotkeyData.ts @@ -92,6 +92,8 @@ export const useHotkeyData = (): HotkeysData => { // Prompt/history navigation (when prompt textarea is focused) addHotkey('app', 'promptHistoryPrev', ['alt+up']); addHotkey('app', 'promptHistoryNext', ['alt+down']); + addHotkey('app', 'promptWeightUp', ['ctrl+up']); + addHotkey('app', 'promptWeightDown', ['ctrl+down']); addHotkey('app', 'focusPrompt', ['alt+a']); addHotkey('app', 'toggleLeftPanel', ['t', 'o']);