Skip to content

Commit

Permalink
feat: improve caching by only decoding jwks when necessary (#486)
Browse files Browse the repository at this point in the history
  • Loading branch information
bshaffer committed Feb 8, 2023
1 parent 08c7ba6 commit 78d3ed1
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 13 deletions.
45 changes: 37 additions & 8 deletions src/CachedKeySet.php
Expand Up @@ -3,13 +3,15 @@
namespace Firebase\JWT;

use ArrayAccess;
use InvalidArgumentException;
use LogicException;
use OutOfBoundsException;
use Psr\Cache\CacheItemInterface;
use Psr\Cache\CacheItemPoolInterface;
use Psr\Http\Client\ClientInterface;
use Psr\Http\Message\RequestFactoryInterface;
use RuntimeException;
use UnexpectedValueException;

/**
* @implements ArrayAccess<string, Key>
Expand Down Expand Up @@ -41,7 +43,7 @@ class CachedKeySet implements ArrayAccess
*/
private $cacheItem;
/**
* @var array<string, Key>
* @var array<string, array<mixed>>
*/
private $keySet;
/**
Expand Down Expand Up @@ -101,7 +103,7 @@ public function offsetGet($keyId): Key
if (!$this->keyIdExists($keyId)) {
throw new OutOfBoundsException('Key ID not found');
}
return $this->keySet[$keyId];
return JWK::parseKey($this->keySet[$keyId], $this->defaultAlg);
}

/**
Expand Down Expand Up @@ -130,15 +132,43 @@ public function offsetUnset($offset): void
throw new LogicException('Method not implemented');
}

/**
* @return array<mixed>
*/
private function formatJwksForCache(string $jwks): array
{
$jwks = json_decode($jwks, true);

if (!isset($jwks['keys'])) {
throw new UnexpectedValueException('"keys" member must exist in the JWK Set');
}

if (empty($jwks['keys'])) {
throw new InvalidArgumentException('JWK Set did not contain any keys');
}

$keys = [];
foreach ($jwks['keys'] as $k => $v) {
$kid = isset($v['kid']) ? $v['kid'] : $k;
$keys[(string) $kid] = $v;
}

return $keys;
}

private function keyIdExists(string $keyId): bool
{
if (null === $this->keySet) {
$item = $this->getCacheItem();
// Try to load keys from cache
if ($item->isHit()) {
// item found! Return it
$jwks = $item->get();
$this->keySet = JWK::parseKeySet(json_decode($jwks, true), $this->defaultAlg);
// item found! retrieve it
$this->keySet = $item->get();
// If the cached item is a string, the JWKS response was cached (previous behavior).
// Parse this into expected format array<kid, jwk> instead.
if (\is_string($this->keySet)) {
$this->keySet = $this->formatJwksForCache($this->keySet);
}
}
}

Expand All @@ -148,15 +178,14 @@ private function keyIdExists(string $keyId): bool
}
$request = $this->httpFactory->createRequest('GET', $this->jwksUri);
$jwksResponse = $this->httpClient->sendRequest($request);
$jwks = (string) $jwksResponse->getBody();
$this->keySet = JWK::parseKeySet(json_decode($jwks, true), $this->defaultAlg);
$this->keySet = $this->formatJwksForCache((string) $jwksResponse->getBody());

if (!isset($this->keySet[$keyId])) {
return false;
}

$item = $this->getCacheItem();
$item->set($jwks);
$item->set($this->keySet);
if ($this->expiresAfter) {
$item->expiresAfter($this->expiresAfter);
}
Expand Down
77 changes: 72 additions & 5 deletions tests/CachedKeySetTest.php
Expand Up @@ -17,11 +17,12 @@ class CachedKeySetTest extends TestCase
private $testJwksUri = 'https://jwk.uri';
private $testJwksUriKey = 'jwkshttpsjwk.uri';
private $testJwks1 = '{"keys": [{"kid":"foo","kty":"RSA","alg":"foo","n":"","e":""}]}';
private $testCachedJwks1 = ['foo' => ['kid' => 'foo', 'kty' => 'RSA', 'alg' => 'foo', 'n' => '', 'e' => '']];
private $testJwks2 = '{"keys": [{"kid":"bar","kty":"RSA","alg":"bar","n":"","e":""}]}';
private $testJwks3 = '{"keys": [{"kid":"baz","kty":"RSA","n":"","e":""}]}';

private $googleRsaUri = 'https://www.googleapis.com/oauth2/v3/certs';
// private $googleEcUri = 'https://www.gstatic.com/iap/verify/public_key-jwk';
private $googleEcUri = 'https://www.gstatic.com/iap/verify/public_key-jwk';

public function testEmptyUriThrowsException()
{
Expand Down Expand Up @@ -117,7 +118,7 @@ public function testKeyIdIsCached()
$cacheItem->isHit()
->willReturn(true);
$cacheItem->get()
->willReturn($this->testJwks1);
->willReturn($this->testCachedJwks1);

$cache = $this->prophesize(CacheItemPoolInterface::class);
$cache->getItem($this->testJwksUriKey)
Expand All @@ -136,6 +137,66 @@ public function testKeyIdIsCached()
}

public function testCachedKeyIdRefresh()
{
$cacheItem = $this->prophesize(CacheItemInterface::class);
$cacheItem->isHit()
->shouldBeCalledOnce()
->willReturn(true);
$cacheItem->get()
->shouldBeCalledOnce()
->willReturn($this->testCachedJwks1);
$cacheItem->set(Argument::any())
->shouldBeCalledOnce()
->will(function () {
return $this;
});

$cache = $this->prophesize(CacheItemPoolInterface::class);
$cache->getItem($this->testJwksUriKey)
->shouldBeCalledOnce()
->willReturn($cacheItem->reveal());
$cache->save(Argument::any())
->shouldBeCalledOnce()
->willReturn(true);

$cachedKeySet = new CachedKeySet(
$this->testJwksUri,
$this->getMockHttpClient($this->testJwks2), // updated JWK
$this->getMockHttpFactory(),
$cache->reveal()
);
$this->assertInstanceOf(Key::class, $cachedKeySet['foo']);
$this->assertSame('foo', $cachedKeySet['foo']->getAlgorithm());

$this->assertInstanceOf(Key::class, $cachedKeySet['bar']);
$this->assertSame('bar', $cachedKeySet['bar']->getAlgorithm());
}

public function testKeyIdIsCachedFromPreviousFormat()
{
$cacheItem = $this->prophesize(CacheItemInterface::class);
$cacheItem->isHit()
->willReturn(true);
$cacheItem->get()
->willReturn($this->testJwks1);

$cache = $this->prophesize(CacheItemPoolInterface::class);
$cache->getItem($this->testJwksUriKey)
->willReturn($cacheItem->reveal());
$cache->save(Argument::any())
->willReturn(true);

$cachedKeySet = new CachedKeySet(
$this->testJwksUri,
$this->prophesize(ClientInterface::class)->reveal(),
$this->prophesize(RequestFactoryInterface::class)->reveal(),
$cache->reveal()
);
$this->assertInstanceOf(Key::class, $cachedKeySet['foo']);
$this->assertSame('foo', $cachedKeySet['foo']->getAlgorithm());
}

public function testCachedKeyIdRefreshFromPreviousFormat()
{
$cacheItem = $this->prophesize(CacheItemInterface::class);
$cacheItem->isHit()
Expand Down Expand Up @@ -213,12 +274,18 @@ public function testJwtVerify()
$payload = ['sub' => 'foo', 'exp' => strtotime('+10 seconds')];
$msg = JWT::encode($payload, $privKey1, 'RS256', 'jwk1');

// format the cached value to match the expected format
$cachedJwks = [];
$rsaKeySet = file_get_contents(__DIR__ . '/data/rsa-jwkset.json');
foreach (json_decode($rsaKeySet, true)['keys'] as $k => $v) {
$cachedJwks[$v['kid']] = $v;
}

$cacheItem = $this->prophesize(CacheItemInterface::class);
$cacheItem->isHit()
->willReturn(true);
$cacheItem->get()
->willReturn(file_get_contents(__DIR__ . '/data/rsa-jwkset.json')
);
->willReturn($cachedJwks);

$cache = $this->prophesize(CacheItemPoolInterface::class);
$cache->getItem($this->testJwksUriKey)
Expand Down Expand Up @@ -297,7 +364,7 @@ public function provideFullIntegration()
{
return [
[$this->googleRsaUri],
// [$this->googleEcUri, 'LYyP2g']
[$this->googleEcUri, 'LYyP2g']
];
}

Expand Down

0 comments on commit 78d3ed1

Please sign in to comment.