Skip to content

Commit

Permalink
Allow cors for static files
Browse files Browse the repository at this point in the history
  • Loading branch information
balloob committed Jul 24, 2019
1 parent 10b120f commit e5d7bee
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 4 deletions.
8 changes: 4 additions & 4 deletions homeassistant/components/http/cors.py
@@ -1,5 +1,5 @@
"""Provide CORS support for the HTTP component."""
from aiohttp.web_urldispatcher import Resource, ResourceRoute
from aiohttp.web_urldispatcher import Resource, ResourceRoute, StaticResource
from aiohttp.hdrs import ACCEPT, CONTENT_TYPE, ORIGIN, AUTHORIZATION

from homeassistant.const import (
Expand All @@ -9,7 +9,7 @@
ALLOWED_CORS_HEADERS = [
ORIGIN, ACCEPT, HTTP_HEADER_X_REQUESTED_WITH, CONTENT_TYPE,
HTTP_HEADER_HA_AUTH, AUTHORIZATION]
VALID_CORS_TYPES = (Resource, ResourceRoute)
VALID_CORS_TYPES = (Resource, ResourceRoute, StaticResource)


@callback
Expand Down Expand Up @@ -56,7 +56,7 @@ def _allow_cors(route, config=None):

async def cors_startup(app):
"""Initialize CORS when app starts up."""
for route in list(app.router.routes()):
_allow_cors(route)
for resource in list(app.router.resources()):
_allow_cors(resource)

app.on_startup.append(cors_startup)
20 changes: 20 additions & 0 deletions tests/components/http/test_cors.py
@@ -1,4 +1,5 @@
"""Test cors for the HTTP component."""
from pathlib import Path
from unittest.mock import patch

from aiohttp import web
Expand Down Expand Up @@ -152,3 +153,22 @@ async def test_cors_works_with_frontend(hass, hass_client):
client = await hass_client()
resp = await client.get('/')
assert resp.status == 200


async def test_cors_on_static_files(hass, hass_client):
"""Test that we enable CORS for static files."""
assert await async_setup_component(hass, 'frontend', {
'http': {
'cors_allowed_origins': ['http://www.example.com']
}
})
hass.http.register_static_path('/something', Path(__file__).parent)

client = await hass_client()
resp = await client.options('/something/__init__.py', headers={
'origin': 'http://www.example.com',
ACCESS_CONTROL_REQUEST_METHOD: 'GET',
})
assert resp.status == 200
assert resp.headers[ACCESS_CONTROL_ALLOW_ORIGIN] == \
'http://www.example.com'

0 comments on commit e5d7bee

Please sign in to comment.