Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions src/Workflow/Exporter/MermaidExporter.php
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,13 @@ public function export(Workflow $graph): string

private function getShortClassName(string $class): string
{
$reflection = new ReflectionClass($class);
return $reflection->getShortName();
// Check if it's a class name (contains namespace separator) and class exists
if (strpos($class, '\\') !== false && class_exists($class)) {
$reflection = new ReflectionClass($class);
return $reflection->getShortName();
}

// Otherwise, it's a custom string key, use it directly
return $class;
}
}
38 changes: 31 additions & 7 deletions src/Workflow/Workflow.php
Original file line number Diff line number Diff line change
Expand Up @@ -216,19 +216,31 @@ protected function edges(): array
return [];
}

public function addNode(NodeInterface $node): self
public function addNode(NodeInterface $node, ?string $key = null): self
{
$this->nodes[$node::class] = $node;
$nodeKey = $key ?? $node::class;
$this->nodes[$nodeKey] = $node;
return $this;
}

/**
* @param NodeInterface[] $nodes
* @param NodeInterface[]|array<string, NodeInterface> $nodes
*/
public function addNodes(array $nodes): Workflow
{
foreach ($nodes as $node) {
$this->addNode($node);
// Check if it's an associative array
$isAssociative = count(array_filter(array_keys($nodes), 'is_string')) > 0;

if ($isAssociative) {
// If associative, use the keys
foreach ($nodes as $key => $node) {
$this->addNode($node, $key);
}
} else {
// If indexed, use class names as keys
foreach ($nodes as $node) {
$this->addNode($node);
}
}
return $this;
}
Expand All @@ -239,8 +251,20 @@ public function addNodes(array $nodes): Workflow
public function getNodes(): array
{
if ($this->nodes === []) {
foreach ($this->nodes() as $node) {
$this->addNode($node);
$nodeDefinitions = $this->nodes();

// Check if it's an associative array (has string keys)
// An associative array has string keys or non-sequential numeric keys
$isAssociative = count(array_filter(array_keys($nodeDefinitions), 'is_string')) > 0;

if ($isAssociative) {
// New behavior: use provided keys directly
$this->nodes = $nodeDefinitions;
} else {
// Old behavior: use class names as keys
foreach ($nodeDefinitions as $node) {
$this->addNode($node);
}
}
}

Expand Down
235 changes: 235 additions & 0 deletions tests/Workflow/WorkflowStringKeysTest.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,235 @@
<?php

declare(strict_types=1);

namespace NeuronAI\Tests\Workflow;

use NeuronAI\Workflow\Edge;
use NeuronAI\Workflow\Node;
use NeuronAI\Workflow\Workflow;
use NeuronAI\Workflow\WorkflowState;
use PHPUnit\Framework\TestCase;

// Calculator nodes that can be reused with different values
class AddNode extends Node
{
public function __construct(private int $value)
{
}

public function run(WorkflowState $state): WorkflowState
{
$current = $state->get('value', 0);
$state->set('value', $current + $this->value);
$history = $state->get('history', []);
$history[] = "Added {$this->value}";
$state->set('history', $history);
return $state;
}
}

class MultiplyNode extends Node
{
public function __construct(private int $value)
{
}

public function run(WorkflowState $state): WorkflowState
{
$current = $state->get('value', 0);
$state->set('value', $current * $this->value);
$history = $state->get('history', []);
$history[] = "Multiplied by {$this->value}";
$state->set('history', $history);
return $state;
}
}

class SubtractNode extends Node
{
public function __construct(private int $value)
{
}

public function run(WorkflowState $state): WorkflowState
{
$current = $state->get('value', 0);
$state->set('value', $current - $this->value);
$history = $state->get('history', []);
$history[] = "Subtracted {$this->value}";
$state->set('history', $history);
return $state;
}
}

class FinishEvenNode extends Node
{
public function run(WorkflowState $state): WorkflowState
{
$state->set('result_type', 'even');
return $state;
}
}

class FinishOddNode extends Node
{
public function run(WorkflowState $state): WorkflowState
{
$state->set('result_type', 'odd');
return $state;
}
}

// Test workflow that uses string keys
class CalculatorWorkflow extends Workflow
{
public function nodes(): array
{
return [
'add1' => new AddNode(1),
'multiply3_first' => new MultiplyNode(3),
'multiply3_second' => new MultiplyNode(3),
'sub1' => new SubtractNode(1),
'finish_even' => new FinishEvenNode(),
'finish_odd' => new FinishOddNode()
];
}

public function edges(): array
{
return [
// ((startingValue + 1) * 3) * 3) - 1
new Edge('add1', 'multiply3_first'),
new Edge('multiply3_first', 'multiply3_second'),
new Edge('multiply3_second', 'sub1'),

// Branch based on even/odd
new Edge('sub1', 'finish_even', fn($state) => $state->get('value') % 2 === 0),
new Edge('sub1', 'finish_odd', fn($state) => $state->get('value') % 2 !== 0)
];
}

protected function start(): string
{
return 'add1';
}

protected function end(): array
{
return ['finish_even', 'finish_odd'];
}
}

class WorkflowStringKeysTest extends TestCase
{
public function test_workflow_with_string_keys(): void
{
$workflow = new CalculatorWorkflow();

// Test with initial value 2: ((2 + 1) * 3) * 3) - 1 = 26 (even)
$initialState = new WorkflowState(['value' => 2]);
$result = $workflow->run($initialState);

$this->assertEquals(26, $result->get('value'));
$this->assertEquals('even', $result->get('result_type'));
$this->assertContains('Added 1', $result->get('history'));
$this->assertContains('Multiplied by 3', $result->get('history'));
$this->assertContains('Subtracted 1', $result->get('history'));
}

public function test_workflow_with_string_keys_odd_result(): void
{
$workflow = new CalculatorWorkflow();

// Test with initial value 1: ((1 + 1) * 3) * 3) - 1 = 17 (odd)
$initialState = new WorkflowState(['value' => 1]);
$result = $workflow->run($initialState);

$this->assertEquals(17, $result->get('value'));
$this->assertEquals('odd', $result->get('result_type'));
}

public function test_programmatic_workflow_with_string_keys(): void
{
$workflow = new Workflow();
$workflow->addNodes([
'add1' => new AddNode(1),
'multiply2' => new MultiplyNode(2),
'finish_even' => new FinishEvenNode(),
'finish_odd' => new FinishOddNode()
])
->addEdges([
new Edge('add1', 'multiply2'),
new Edge('multiply2', 'finish_even', fn($state) => $state->get('value') % 2 === 0),
new Edge('multiply2', 'finish_odd', fn($state) => $state->get('value') % 2 !== 0)
])
->setStart('add1')
->setEnd('finish_even')
->setEnd('finish_odd');

// Test with initial value 3: (3 + 1) * 2 = 8 (even)
$initialState = new WorkflowState(['value' => 3]);
$result = $workflow->run($initialState);

$this->assertEquals(8, $result->get('value'));
$this->assertEquals('even', $result->get('result_type'));
}

public function test_mermaid_export_with_string_keys(): void
{
$workflow = new Workflow();
$workflow->addNodes([
'start' => new AddNode(1),
'middle' => new MultiplyNode(2),
'finish' => new FinishEvenNode()
])
->addEdges([
new Edge('start', 'middle'),
new Edge('middle', 'finish')
])
->setStart('start')
->setEnd('finish');

$export = $workflow->export();

$this->assertStringContainsString('start --> middle', $export);
$this->assertStringContainsString('middle --> finish', $export);
}

public function test_backward_compatibility_with_class_names(): void
{
// This test ensures the old behavior still works
$workflow = new Workflow();
$workflow->addNode(new StartNode())
->addNode(new FinishNode())
->addEdge(new Edge(StartNode::class, FinishNode::class))
->setStart(StartNode::class)
->setEnd(FinishNode::class);

$result = $workflow->run();

$this->assertEquals('end', $result->get('step'));
}

public function test_mixed_mode_nodes_and_edges(): void
{
// Test mixing both approaches - indexed array with class name edges
$workflow = new Workflow();
$workflow->addNodes([
new StartNode(),
new MiddleNode(),
new FinishNode()
])
->addEdges([
new Edge(StartNode::class, MiddleNode::class),
new Edge(MiddleNode::class, FinishNode::class)
])
->setStart(StartNode::class)
->setEnd(FinishNode::class);

$result = $workflow->run();

$this->assertEquals('end', $result->get('step'));
$this->assertEquals(1, $result->get('counter'));
}
}