diff --git a/src/__tests__/client.ts b/src/__tests__/client.ts index 906cd89..a77845e 100644 --- a/src/__tests__/client.ts +++ b/src/__tests__/client.ts @@ -19,13 +19,13 @@ checkEndpoints.push({ encoding: "json", }); -function withClient(f: (c: hrana.Client) => Promise): () => Promise { +function withClient(f: (c: hrana.Client) => Promise, config?: hrana.ClientConfig): () => Promise { return async () => { let client: hrana.Client; if (isWs) { - client = hrana.openWs(url, jwt, 3); + client = hrana.openWs(url, jwt, 3, config); } else if (isHttp) { - client = hrana.openHttp(url, jwt, undefined, 3); + client = hrana.openHttp(url, jwt, undefined, 3, config); } else { throw new Error("expected either ws or http URL"); } @@ -82,7 +82,7 @@ test("Stream.queryValue() without value", withClient(async (c) => { test("Stream.queryRow() with row", withClient(async (c) => { const s = c.openStream(); - + const res = await s.queryRow( "SELECT 1 AS one, 'elephant' AS two, 42.5 AS three, NULL as four"); expect(res.columnNames).toStrictEqual(["one", "two", "three", "four"]); @@ -102,7 +102,7 @@ test("Stream.queryRow() with row", withClient(async (c) => { test("Stream.queryRow() without row", withClient(async (c) => { const s = c.openStream(); - + const res = await s.queryValue("SELECT 1 AS one WHERE 0 = 1"); expect(res.value).toStrictEqual(undefined); expect(res.columnNames).toStrictEqual(["one"]); @@ -318,6 +318,63 @@ describe("returned integers", () => { }); }); +describe("returned booleans", () => { + const columnName = 'isActive'; + describe("booleans are JS integers", () => { + test('without config', withClient(async (c) => { + const s = c.openStream(); + await s.run("BEGIN"); + await s.run("DROP TABLE IF EXISTS t"); + await s.run(`CREATE TABLE t (id INTEGER PRIMARY KEY, ${columnName} BOOLEAN)`); + await s.run("INSERT INTO t VALUES (1, true)"); + await s.run("INSERT INTO t VALUES (2, false)"); + await s.run("COMMIT"); + + const resTrue = await s.queryRow(`SELECT ${columnName} FROM t WHERE id = 1`); + const valTrue = resTrue.row?.[columnName]; + expect(typeof valTrue).toStrictEqual("number"); + expect(valTrue).toStrictEqual(1); + + const resFalse = await s.queryRow(`SELECT ${columnName} FROM t WHERE id = 2`); + const valFalse = resFalse.row?.[columnName]; + expect(typeof valFalse).toStrictEqual("number"); + expect(valFalse).toStrictEqual(0); + })); + + test('with config', withClient(async (c) => { + const s = c.openStream(); + const resTrue = await s.queryRow(`SELECT ${columnName} FROM t WHERE id = 1`); + const valTrue = resTrue.row?.[columnName]; + expect(typeof valTrue).toStrictEqual("number"); + expect(valTrue).toStrictEqual(1); + + const resFalse = await s.queryRow(`SELECT ${columnName} FROM t WHERE id = 2`); + const valFalse = resFalse.row?.[columnName]; + expect(typeof valFalse).toStrictEqual("number"); + expect(valFalse).toStrictEqual(0); + }, { castBooleans: false })); + }); + + describe("booleans are JS booleans", () => { + test('with config', withClient(async (c) => { + const s = c.openStream(); + const resTrue = await s.queryRow(`SELECT ${columnName} FROM t WHERE id = 1`); + const valTrue = resTrue.row?.[columnName]; + expect(typeof valTrue).toStrictEqual("boolean"); + expect(valTrue).toStrictEqual(true); + + const resFalse = await s.queryRow(`SELECT ${columnName} FROM t WHERE id = 2`); + const valFalse = resFalse.row?.[columnName]; + expect(typeof valFalse).toStrictEqual("boolean"); + expect(valFalse).toStrictEqual(false); + + await s.run("BEGIN"); + await s.run("DROP TABLE t"); + await s.run("COMMIT"); + }, { castBooleans: true })); + }); +}); + test("response error", withClient(async (c) => { const s = c.openStream(); await expect(s.queryValue("SELECT")).rejects.toBeInstanceOf(hrana.ResponseError); @@ -539,7 +596,7 @@ for (const useCursor of [false, true]) { test("failing statement", withClient(async (c) => { if (useCursor) { await c.getVersion(); } const s = c.openStream(); - + const batch = s.batch(useCursor); const prom1 = batch.step().queryValue("SELECT 1"); const prom2 = batch.step().queryValue("SELECT foobar"); @@ -904,7 +961,7 @@ for (const useCursor of [false, true]) { for (const useCursor of [false, true]) { (version >= 3 || !useCursor ? test : test.skip)( useCursor ? "batch w/ cursor" : "batch w/o cursor", - withSqlOwner(async (s, owner) => + withSqlOwner(async (s, owner) => { const sql1 = owner.storeSql("SELECT 11"); const sql2 = owner.storeSql("SELECT 'one', 'two'"); diff --git a/src/batch.ts b/src/batch.ts index 8000aff..8f6eda5 100644 --- a/src/batch.ts +++ b/src/batch.ts @@ -1,5 +1,5 @@ +import { ClientConfig } from "./client.js"; import { ProtoError, MisuseError } from "./errors.js"; -import { IdAlloc } from "./id_alloc.js"; import type { RowsResult, RowResult, ValueResult, StmtResult } from "./result.js"; import { stmtResultFromProto, rowsResultFromProto, @@ -11,8 +11,7 @@ import type { InStmt } from "./stmt.js"; import { stmtToProto } from "./stmt.js"; import { Stream } from "./stream.js"; import { impossible } from "./util.js"; -import type { Value, InValue, IntMode } from "./value.js"; -import { valueToProto, valueFromProto } from "./value.js"; +import type { IntMode } from "./value.js"; /** A builder for creating a batch and executing it on the server. */ export class Batch { @@ -205,7 +204,7 @@ export class BatchStep { #add( inStmt: InStmt, wantRows: boolean, - fromProto: (result: proto.StmtResult, intMode: IntMode) => T, + fromProto: (result: proto.StmtResult, intMode: IntMode, config: ClientConfig) => T, ): Promise { if (this._index !== undefined) { throw new MisuseError("This BatchStep has already been added to the batch"); @@ -234,7 +233,7 @@ export class BatchStep { } else if (stepError !== undefined) { errorCallback(errorFromProto(stepError)); } else if (stepResult !== undefined) { - outputCallback(fromProto(stepResult, this._batch._stream.intMode)); + outputCallback(fromProto(stepResult, this._batch._stream.intMode, this._batch._stream.config)); } else { outputCallback(undefined); } @@ -282,7 +281,7 @@ export class BatchCond { return new BatchCond(cond._batch, {type: "not", cond: cond._proto}); } - /** Create a condition that is a logical AND of other conditions. + /** Create a condition that is a logical AND of other conditions. */ static and(batch: Batch, conds: Array): BatchCond { for (const cond of conds) { @@ -291,7 +290,7 @@ export class BatchCond { return new BatchCond(batch, {type: "and", conds: conds.map(e => e._proto)}); } - /** Create a condition that is a logical OR of other conditions. + /** Create a condition that is a logical OR of other conditions. */ static or(batch: Batch, conds: Array): BatchCond { for (const cond of conds) { diff --git a/src/client.ts b/src/client.ts index 731bb99..a180503 100644 --- a/src/client.ts +++ b/src/client.ts @@ -3,11 +3,15 @@ import type { IntMode } from "./value.js"; export type ProtocolVersion = 1 | 2 | 3; export type ProtocolEncoding = "json" | "protobuf"; +export type ClientConfig = { + castBooleans?: boolean; +}; /** A client for the Hrana protocol (a "database connection pool"). */ export abstract class Client { /** @private */ - constructor() { + constructor(config: ClientConfig) { + this.config = config; this.intMode = "number"; } @@ -36,4 +40,7 @@ export abstract class Client { * override the integer mode for every stream by setting {@link Stream.intMode} on the stream. */ intMode: IntMode; + + /** Stores the client configuration. See {@link ClientConfig}. */ + config: ClientConfig; } diff --git a/src/http/client.ts b/src/http/client.ts index 6b1d56d..acf74da 100644 --- a/src/http/client.ts +++ b/src/http/client.ts @@ -1,6 +1,6 @@ import { fetch, Request } from "@libsql/isomorphic-fetch"; -import type { ProtocolVersion, ProtocolEncoding } from "../client.js"; +import type { ProtocolVersion, ProtocolEncoding, ClientConfig } from "../client.js"; import { Client } from "../client.js"; import { ClientError, ClosedError, ProtocolVersionError } from "../errors.js"; @@ -56,8 +56,8 @@ export class HttpClient extends Client { _endpoint: Endpoint | undefined; /** @private */ - constructor(url: URL, jwt: string | undefined, customFetch: unknown | undefined, protocolVersion: ProtocolVersion = 2) { - super(); + constructor(url: URL, jwt: string | undefined, customFetch: unknown | undefined, protocolVersion: ProtocolVersion = 2, config: ClientConfig) { + super(config); this.#url = url; this.#jwt = jwt; this.#fetch = (customFetch as typeof fetch) ?? fetch; diff --git a/src/http/stream.ts b/src/http/stream.ts index 154ae0c..818a713 100644 --- a/src/http/stream.ts +++ b/src/http/stream.ts @@ -66,7 +66,7 @@ export class HttpStream extends Stream implements SqlOwner { /** @private */ constructor(client: HttpClient, baseUrl: URL, jwt: string | undefined, customFetch: typeof fetch) { - super(client.intMode); + super(client.intMode, client.config); this.#client = client; this.#baseUrl = baseUrl.toString(); this.#jwt = jwt; diff --git a/src/index.ts b/src/index.ts index c46ac4c..6b79f92 100644 --- a/src/index.ts +++ b/src/index.ts @@ -5,13 +5,13 @@ import { WebSocketUnsupportedError } from "./errors.js"; import { HttpClient } from "./http/client.js"; import { WsClient } from "./ws/client.js"; -import { ProtocolVersion } from "./client.js"; +import { ClientConfig, ProtocolVersion } from "./client.js"; export { WebSocket } from "@libsql/isomorphic-ws"; export type { RequestInit, Response } from "@libsql/isomorphic-fetch"; export { fetch, Request, Headers } from "@libsql/isomorphic-fetch"; -export type { ProtocolVersion, ProtocolEncoding } from "./client.js"; +export type { ClientConfig, ProtocolVersion, ProtocolEncoding } from "./client.js"; export { Client } from "./client.js"; export type { DescribeResult, DescribeColumn } from "./describe.js"; export * from "./errors.js"; @@ -32,7 +32,7 @@ export { WsClient } from "./ws/client.js"; export { WsStream } from "./ws/stream.js"; /** Open a Hrana client over WebSocket connected to the given `url`. */ -export function openWs(url: string | URL, jwt?: string, protocolVersion: ProtocolVersion = 2): WsClient { +export function openWs(url: string | URL, jwt?: string, protocolVersion: ProtocolVersion = 2, config: ClientConfig = {}): WsClient { if (typeof WebSocket === "undefined") { throw new WebSocketUnsupportedError("WebSockets are not supported in this environment"); } @@ -43,7 +43,7 @@ export function openWs(url: string | URL, jwt?: string, protocolVersion: Protoco subprotocols = Array.from(subprotocolsV2.keys()); } const socket = new WebSocket(url, subprotocols); - return new WsClient(socket, jwt); + return new WsClient(socket, jwt, config); } /** Open a Hrana client over HTTP connected to the given `url`. @@ -52,6 +52,6 @@ export function openWs(url: string | URL, jwt?: string, protocolVersion: Protoco * from `@libsql/isomorphic-fetch`. This function is always called with a `Request` object from * `@libsql/isomorphic-fetch`. */ -export function openHttp(url: string | URL, jwt?: string, customFetch?: unknown | undefined, protocolVersion: ProtocolVersion = 2): HttpClient { - return new HttpClient(url instanceof URL ? url : new URL(url), jwt, customFetch, protocolVersion); +export function openHttp(url: string | URL, jwt?: string, customFetch?: unknown | undefined, protocolVersion: ProtocolVersion = 2, config: ClientConfig = {}): HttpClient { + return new HttpClient(url instanceof URL ? url : new URL(url), jwt, customFetch, protocolVersion, config); } diff --git a/src/result.ts b/src/result.ts index d4b333c..b1da854 100644 --- a/src/result.ts +++ b/src/result.ts @@ -1,3 +1,4 @@ +import { ClientConfig } from "./client.js"; import { ClientError, ProtoError, ResponseError } from "./errors.js"; import type * as proto from "./shared/proto.js"; import type { Value, IntMode } from "./value.js"; @@ -52,17 +53,17 @@ export function stmtResultFromProto(result: proto.StmtResult): StmtResult { }; } -export function rowsResultFromProto(result: proto.StmtResult, intMode: IntMode): RowsResult { +export function rowsResultFromProto(result: proto.StmtResult, intMode: IntMode, config: ClientConfig): RowsResult { const stmtResult = stmtResultFromProto(result); - const rows = result.rows.map(row => rowFromProto(stmtResult.columnNames, row, intMode)); + const rows = result.rows.map(row => rowFromProto(stmtResult.columnNames, row, intMode, stmtResult.columnDecltypes, config)); return {...stmtResult, rows}; } -export function rowResultFromProto(result: proto.StmtResult, intMode: IntMode): RowResult { +export function rowResultFromProto(result: proto.StmtResult, intMode: IntMode, config: ClientConfig): RowResult { const stmtResult = stmtResultFromProto(result); let row: Row | undefined; if (result.rows.length > 0) { - row = rowFromProto(stmtResult.columnNames, result.rows[0], intMode); + row = rowFromProto(stmtResult.columnNames, result.rows[0], intMode, stmtResult.columnDecltypes, config); } return {...stmtResult, row}; } @@ -71,6 +72,7 @@ export function valueResultFromProto(result: proto.StmtResult, intMode: IntMode) const stmtResult = stmtResultFromProto(result); let value: Value | undefined; if (result.rows.length > 0 && stmtResult.columnNames.length > 0) { + // TODO: How do we solve this? AFAICS we don't have column data when fetching a single value, so we don't know when to cast ints to booleans value = valueFromProto(result.rows[0][0], intMode); } return {...stmtResult, value}; @@ -80,12 +82,14 @@ function rowFromProto( colNames: Array, values: Array, intMode: IntMode, + colDecltypes: Array, + config: ClientConfig ): Row { const row = {}; // make sure that the "length" property is not enumerable Object.defineProperty(row, "length", { value: values.length }); for (let i = 0; i < values.length; ++i) { - const value = valueFromProto(values[i], intMode); + const value = valueFromProto(values[i], intMode, colDecltypes[i], config.castBooleans); Object.defineProperty(row, i, { value }); const colName = colNames[i]; diff --git a/src/stream.ts b/src/stream.ts index 950aeb5..dd0f05d 100644 --- a/src/stream.ts +++ b/src/stream.ts @@ -1,5 +1,5 @@ import { Batch } from "./batch.js"; -import type { Client } from "./client.js"; +import type { Client, ClientConfig } from "./client.js"; import type { Cursor } from "./cursor.js"; import type { DescribeResult } from "./describe.js"; import { describeResultFromProto } from "./describe.js"; @@ -18,8 +18,9 @@ import type { IntMode } from "./value.js"; /** A stream for executing SQL statements (a "database connection"). */ export abstract class Stream { /** @private */ - constructor(intMode: IntMode) { + constructor(intMode: IntMode, config: ClientConfig) { this.intMode = intMode; + this.config = config; } /** Get the client object that this stream belongs to. */ @@ -61,10 +62,10 @@ export abstract class Stream { #execute( inStmt: InStmt, wantRows: boolean, - fromProto: (result: proto.StmtResult, intMode: IntMode) => T, + fromProto: (result: proto.StmtResult, intMode: IntMode, config: ClientConfig) => T, ): Promise { const stmt = stmtToProto(this._sqlOwner(), inStmt, wantRows); - return this._execute(stmt).then((r) => fromProto(r, this.intMode)); + return this._execute(stmt).then((r) => fromProto(r, this.intMode, this.config)); } /** Return a builder for creating and executing a batch. @@ -120,4 +121,7 @@ export abstract class Stream { * This value affects the results of all operations on this stream. */ intMode: IntMode; + + /** Stores the client configuration. */ + config: ClientConfig; } diff --git a/src/value.ts b/src/value.ts index 2fc5a70..b1e6ec2 100644 --- a/src/value.ts +++ b/src/value.ts @@ -1,4 +1,4 @@ -import { ClientError, ProtoError, MisuseError } from "./errors.js"; +import { ProtoError, MisuseError } from "./errors.js"; import type * as proto from "./shared/proto.js"; import { impossible } from "./util.js"; @@ -7,6 +7,7 @@ export type Value = | null | string | number + | boolean | bigint | ArrayBuffer @@ -65,7 +66,7 @@ export function valueToProto(value: InValue): proto.Value { const minInteger = -9223372036854775808n; const maxInteger = 9223372036854775807n; -export function valueFromProto(value: proto.Value, intMode: IntMode): Value { +export function valueFromProto(value: proto.Value, intMode: IntMode, colDecltype?: string, castBooleans?: boolean): Value { if (value === null) { return null; } else if (typeof value === "number") { @@ -73,6 +74,9 @@ export function valueFromProto(value: proto.Value, intMode: IntMode): Value { } else if (typeof value === "string") { return value; } else if (typeof value === "bigint") { + if (castBooleans && colDecltype?.toLowerCase() === 'boolean') { + return Boolean(value); + } if (intMode === "number") { const num = Number(value); if (!Number.isSafeInteger(num)) { diff --git a/src/ws/client.ts b/src/ws/client.ts index 784ee92..5e35c44 100644 --- a/src/ws/client.ts +++ b/src/ws/client.ts @@ -1,6 +1,6 @@ import { WebSocket } from "@libsql/isomorphic-ws"; -import type { ProtocolVersion, ProtocolEncoding } from "../client.js"; +import type { ProtocolVersion, ProtocolEncoding, ClientConfig } from "../client.js"; import { Client } from "../client.js"; import { readJsonObject, writeJsonObject, readProtobufMessage, writeProtobufMessage, @@ -73,8 +73,8 @@ export class WsClient extends Client implements SqlOwner { #sqlIdAlloc: IdAlloc; /** @private */ - constructor(socket: WebSocket, jwt: string | undefined) { - super(); + constructor(socket: WebSocket, jwt: string | undefined, config: ClientConfig) { + super(config); this.#socket = socket; this.#openCallbacks = []; @@ -268,7 +268,7 @@ export class WsClient extends Client implements SqlOwner { } else { throw impossible(encoding, "Impossible encoding"); } - + this.#handleMsg(msg); } catch (e) { this.#socket.close(3007, "Could not handle message"); diff --git a/src/ws/stream.ts b/src/ws/stream.ts index a5d816f..8b8c578 100644 --- a/src/ws/stream.ts +++ b/src/ws/stream.ts @@ -47,7 +47,7 @@ export class WsStream extends Stream { /** @private */ constructor(client: WsClient, streamId: number) { - super(client.intMode); + super(client.intMode, client.config); this.#client = client; this.#streamId = streamId;