Skip to content

Commit

Permalink
feat: validate incoming host header (#163)
Browse files Browse the repository at this point in the history
* feat(request): add RequestError class and its utility function

* fix: return 400 if host header is missing or contains invalid characters

* feat: validate incoming host header

* test: add tests for invalid host header

* feat: use options.hostname as default hostname for request
  • Loading branch information
usualoma authored Apr 19, 2024
1 parent d61c8ec commit 306d98f
Show file tree
Hide file tree
Showing 6 changed files with 158 additions and 32 deletions.
1 change: 1 addition & 0 deletions src/index.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
export { serve, createAdaptorServer } from './server'
export { getRequestListener } from './listener'
export { RequestError } from './request'
export type { HttpBindings, Http2Bindings } from './types'
22 changes: 18 additions & 4 deletions src/listener.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import type { IncomingMessage, ServerResponse, OutgoingHttpHeaders } from 'node:http'
import type { Http2ServerRequest, Http2ServerResponse } from 'node:http2'
import { getAbortController, newRequest, Request as LightweightRequest } from './request'
import {
getAbortController,
newRequest,
Request as LightweightRequest,
toRequestError,
} from './request'
import { cacheKey, getInternalBody, Response as LightweightResponse } from './response'
import type { CustomErrorHandler, FetchCallback, HttpBindings } from './types'
import { writeFromReadableStream, buildOutgoingHttpHeaders } from './utils'
Expand All @@ -10,6 +15,11 @@ import './globals'
const regBuffer = /^no$/i
const regContentType = /^(application\/json\b|text\/(?!event-stream\b))/i

const handleRequestError = (): Response =>
new Response(null, {
status: 400,
})

const handleFetchError = (e: unknown): Response =>
new Response(null, {
status:
Expand Down Expand Up @@ -140,6 +150,7 @@ const responseViaResponseObject = async (
export const getRequestListener = (
fetchCallback: FetchCallback,
options: {
hostname?: string
errorHandler?: CustomErrorHandler
overrideGlobalObjects?: boolean
} = {}
Expand All @@ -157,12 +168,13 @@ export const getRequestListener = (
incoming: IncomingMessage | Http2ServerRequest,
outgoing: ServerResponse | Http2ServerResponse
) => {
let res
// eslint-disable-next-line @typescript-eslint/no-explicit-any
let res, req: any

try {
// `fetchCallback()` requests a Request object, but global.Request is expensive to generate,
// so generate a pseudo Request object with only the minimum required information.
const req = newRequest(incoming)
req = newRequest(incoming, options.hostname)

// Detect if request was aborted.
outgoing.on('close', () => {
Expand All @@ -181,10 +193,12 @@ export const getRequestListener = (
} catch (e: unknown) {
if (!res) {
if (options.errorHandler) {
res = await options.errorHandler(e)
res = await options.errorHandler(req ? e : toRequestError(e))
if (!res) {
return
}
} else if (!req) {
res = handleRequestError()
} else {
res = handleFetchError(e)
}
Expand Down
48 changes: 42 additions & 6 deletions src/request.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,25 @@ import { Http2ServerRequest } from 'node:http2'
import { Readable } from 'node:stream'
import type { TLSSocket } from 'node:tls'

export class RequestError extends Error {
static name = 'RequestError'
constructor(
message: string,
options?: {
cause?: unknown
}
) {
super(message, options)
}
}

export const toRequestError = (e: unknown): RequestError => {
if (e instanceof RequestError) {
return e
}
return new RequestError((e as Error).message, { cause: e })
}

export const GlobalRequest = global.Request
export class Request extends GlobalRequest {
constructor(input: string | Request, options?: RequestInit) {
Expand Down Expand Up @@ -111,18 +130,35 @@ const requestPrototype: Record<string | symbol, any> = {
})
Object.setPrototypeOf(requestPrototype, Request.prototype)

export const newRequest = (incoming: IncomingMessage | Http2ServerRequest) => {
export const newRequest = (
incoming: IncomingMessage | Http2ServerRequest,
defaultHostname?: string
) => {
const req = Object.create(requestPrototype)
req[incomingKey] = incoming
req[urlKey] = new URL(

const host =
(incoming instanceof Http2ServerRequest ? incoming.authority : incoming.headers.host) ||
defaultHostname
if (!host) {
throw new RequestError('Missing host header')
}
const url = new URL(
`${
incoming instanceof Http2ServerRequest ||
(incoming.socket && (incoming.socket as TLSSocket).encrypted)
? 'https'
: 'http'
}://${incoming instanceof Http2ServerRequest ? incoming.authority : incoming.headers.host}${
incoming.url
}`
).href
}://${host}${incoming.url}`
)

// check by length for performance.
// if suspicious, check by host. host header sometimes contains port.
if (url.hostname.length !== host.length && url.hostname !== host.replace(/:\d+$/, '')) {
throw new RequestError('Invalid host header')
}

req[urlKey] = url.href

return req
}
1 change: 1 addition & 0 deletions src/server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import type { Options, ServerType } from './types'
export const createAdaptorServer = (options: Options): ServerType => {
const fetchCallback = options.fetch
const requestListener = getRequestListener(fetchCallback, {
hostname: options.hostname,
overrideGlobalObjects: options.overrideGlobalObjects,
})
// ts will complain about createServerHTTP and createServerHTTP2 not being callable, which works just fine
Expand Down
68 changes: 54 additions & 14 deletions test/listener.test.ts
Original file line number Diff line number Diff line change
@@ -1,28 +1,68 @@
import { createServer } from 'node:http'
import request from 'supertest'
import { getRequestListener } from '../src/listener'
import { GlobalRequest, Request as LightweightRequest } from '../src/request'
import { GlobalRequest, Request as LightweightRequest, RequestError } from '../src/request'
import { GlobalResponse, Response as LightweightResponse } from '../src/response'

describe('Invalid request', () => {
const requestListener = getRequestListener(jest.fn())
const server = createServer(async (req, res) => {
await requestListener(req, res)
describe('default error handler', () => {
const requestListener = getRequestListener(jest.fn())
const server = createServer(requestListener)

if (!res.writableEnded) {
res.writeHead(500, { 'Content-Type': 'text/plain' })
res.end('error handler did not return a response')
}
it('Should return server error for a request w/o host header', async () => {
const res = await request(server).get('/').set('Host', '').send()
expect(res.status).toBe(400)
})

it('Should return server error for a request invalid host header', async () => {
const res = await request(server).get('/').set('Host', 'a b').send()
expect(res.status).toBe(400)
})
})

it('Should return server error for a request w/o host header', async () => {
const res = await request(server).get('/').set('Host', '').send()
expect(res.status).toBe(500)
describe('custom error handler', () => {
const requestListener = getRequestListener(jest.fn(), {
errorHandler: (e) => {
if (e instanceof RequestError) {
return new Response(e.message, { status: 400 })
} else {
return new Response('unknown error', { status: 500 })
}
},
})
const server = createServer(requestListener)

it('Should return server error for a request w/o host header', async () => {
const res = await request(server).get('/').set('Host', '').send()
expect(res.status).toBe(400)
})

it('Should return server error for a request invalid host header', async () => {
const res = await request(server).get('/').set('Host', 'a b').send()
expect(res.status).toBe(400)
})

it('Should return server error for host header with path', async () => {
const res = await request(server).get('/').set('Host', 'a/b').send()
expect(res.status).toBe(400)
})
})

it('Should return server error for a request invalid host header', async () => {
const res = await request(server).get('/').set('Host', 'a b').send()
expect(res.status).toBe(500)
describe('default hostname', () => {
const requestListener = getRequestListener(() => new Response('ok'), {
hostname: 'example.com',
})
const server = createServer(requestListener)

it('Should return 200 for a request w/o host header', async () => {
const res = await request(server).get('/').set('Host', '').send()
expect(res.status).toBe(200)
})

it('Should return server error for a request invalid host header', async () => {
const res = await request(server).get('/').set('Host', 'a b').send()
expect(res.status).toBe(400)
})
})
})

Expand Down
50 changes: 42 additions & 8 deletions test/request.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import {
Request as LightweightRequest,
GlobalRequest,
getAbortController,
RequestError,
} from '../src/request'

Object.defineProperty(global, 'Request', {
Expand Down Expand Up @@ -40,30 +41,30 @@ describe('Request', () => {
expect(req.url).toBe('http://localhost/foo.txt')
})

it('Should resolve double dots in host header', async () => {
it('Should accept hostname and port in host header', async () => {
const req = newRequest({
headers: {
host: 'localhost/..',
host: 'localhost:8080',
},
url: '/foo.txt',
url: '/static/../foo.txt',
} as IncomingMessage)
expect(req).toBeInstanceOf(global.Request)
expect(req.url).toBe('http://localhost/foo.txt')
expect(req.url).toBe('http://localhost:8080/foo.txt')
})

it('should generate only one `AbortController` per `Request` object created', async () => {
const req = newRequest({
headers: {
host: 'localhost/..',
host: 'localhost',
},
rawHeaders: ['host', 'localhost/..'],
rawHeaders: ['host', 'localhost'],
url: '/foo.txt',
} as IncomingMessage)
const req2 = newRequest({
headers: {
host: 'localhost/..',
host: 'localhost',
},
rawHeaders: ['host', 'localhost/..'],
rawHeaders: ['host', 'localhost'],
url: '/foo.txt',
} as IncomingMessage)

Expand All @@ -78,6 +79,39 @@ describe('Request', () => {
expect(z).not.toBe(x)
expect(z).not.toBe(y)
})

it('Should throw error if host header contains path', async () => {
expect(() => {
newRequest({
headers: {
host: 'localhost/..',
},
url: '/foo.txt',
} as IncomingMessage)
}).toThrow(RequestError)
})

it('Should throw error if host header is empty', async () => {
expect(() => {
newRequest({
headers: {
host: '',
},
url: '/foo.txt',
} as IncomingMessage)
}).toThrow(RequestError)
})

it('Should throw error if host header contains query parameter', async () => {
expect(() => {
newRequest({
headers: {
host: 'localhost?foo=bar',
},
url: '/foo.txt',
} as IncomingMessage)
}).toThrow(RequestError)
})
})

describe('GlobalRequest', () => {
Expand Down

0 comments on commit 306d98f

Please sign in to comment.