Skip to content

Commit

Permalink
Fixing the Machine Learning client to use the provided PredictEndpoin…
Browse files Browse the repository at this point in the history
…t as the host for the Predict operation.
  • Loading branch information
jeremeamia committed Jun 1, 2015
1 parent 0e44bb2 commit a51912a
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 1 deletion.
33 changes: 32 additions & 1 deletion src/MachineLearning/MachineLearningClient.php
Expand Up @@ -2,8 +2,39 @@
namespace Aws\MachineLearning;

use Aws\AwsClient;
use Aws\CommandInterface;
use GuzzleHttp\Psr7\Uri;
use Psr\Http\Message\RequestInterface;

/**
* Amazon Machine Learning client.
*/
class MachineLearningClient extends AwsClient {}
class MachineLearningClient extends AwsClient
{
public function __construct(array $config)
{
parent::__construct($config);
$list = $this->getHandlerList();
$list->appendBuild($this->predictEndpoint(), 'ml.predict_endpoint');
}

/**
* Changes the endpoint of the Predict operation to the provided endpoint.
*
* @return callable
*/
private function predictEndpoint()
{
return static function (callable $handler) {
return function (
CommandInterface $command,
RequestInterface $request = null
) use ($handler) {
if ($command->getName() === 'Predict') {
$request = $request->withUri(new Uri($command['PredictEndpoint']));
}
return $handler($command, $request);
};
};
}
}
37 changes: 37 additions & 0 deletions tests/MachineLearning/MachineLearningClientTest.php
@@ -0,0 +1,37 @@
<?php
namespace Aws\Test\MachineLearning;

use Aws\Middleware;
use Aws\MachineLearning\MachineLearningClient;
use Aws\Test\UsesServiceTrait;
use GuzzleHttp\Psr7;
use GuzzleHttp\Psr7\Uri;

/**
* @covers Aws\MachineLearning\MachineLearningClient
*/
class MachineLearningClientTest extends \PHPUnit_Framework_TestCase
{
use UsesServiceTrait;

public function testUpdatesPredictEndpoint()
{
// Setup state of command/request
$predictEndpoint = new Uri('https://realtime.machinelearning.us-east-1.amazonaws.com/foo');
$client = new MachineLearningClient([
'region' => 'us-east-1',
'version' => 'latest'
]);
$this->addMockResults($client, [[]]);
$client->getHandlerList()->appendSign(Middleware::tap(function ($c, $r) use (&$command, &$request) {
$command = $c; $request = $r;
}));
$client->predict([
'MLModelId' => 'foo',
'Record' => ['foo' => 'bar'],
'PredictEndpoint' => (string) $predictEndpoint
]);

$this->assertEquals($predictEndpoint->getHost(), $request->getUri()->getHost());
}
}

0 comments on commit a51912a

Please sign in to comment.