Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 38 additions & 18 deletions lib/DBSQLClient.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ import { EventEmitter } from 'events';
import TCLIService from '../thrift/TCLIService';
import { TProtocolVersion } from '../thrift/TCLIService_types';
import IDBSQLClient, { ClientOptions, ConnectionOptions, OpenSessionRequest } from './contracts/IDBSQLClient';
import IDriver from './contracts/IDriver';
import IClientContext from './contracts/IClientContext';
import HiveDriver from './hive/HiveDriver';
import { Int64 } from './hive/Types';
import DBSQLSession from './DBSQLSession';
Expand Down Expand Up @@ -41,13 +43,17 @@ function getInitialNamespaceOptions(catalogName?: string, schemaName?: string) {
};
}

export default class DBSQLClient extends EventEmitter implements IDBSQLClient {
export default class DBSQLClient extends EventEmitter implements IDBSQLClient, IClientContext {
private connectionProvider?: IConnectionProvider;

private authProvider?: IAuthentication;

private client?: TCLIService.Client;

private readonly driver = new HiveDriver({
context: this,
});

private readonly logger: IDBSQLLogger;

private readonly thrift = thrift;
Expand All @@ -73,7 +79,7 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient {
};
}

private getAuthProvider(options: ConnectionOptions, authProvider?: IAuthentication): IAuthentication {
private initAuthProvider(options: ConnectionOptions, authProvider?: IAuthentication): IAuthentication {
if (authProvider) {
return authProvider;
}
Expand All @@ -84,15 +90,16 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient {
return new PlainHttpAuthentication({
username: 'token',
password: options.token,
context: this,
});
case 'databricks-oauth':
return new DatabricksOAuth({
host: options.host,
logger: this.logger,
persistence: options.persistence,
azureTenantId: options.azureTenantId,
clientId: options.oauthClientId,
clientSecret: options.oauthClientSecret,
context: this,
});
case 'custom':
return options.provider;
Expand All @@ -110,7 +117,7 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient {
* const session = client.connect({host, path, token});
*/
public async connect(options: ConnectionOptions, authProvider?: IAuthentication): Promise<IDBSQLClient> {
this.authProvider = this.getAuthProvider(options, authProvider);
this.authProvider = this.initAuthProvider(options, authProvider);

this.connectionProvider = new HttpConnection(this.getConnectionOptions(options));

Expand Down Expand Up @@ -156,44 +163,57 @@ export default class DBSQLClient extends EventEmitter implements IDBSQLClient {
* const session = await client.openSession();
*/
public async openSession(request: OpenSessionRequest = {}): Promise<IDBSQLSession> {
const driver = new HiveDriver(() => this.getClient());

const response = await driver.openSession({
const response = await this.driver.openSession({
client_protocol_i64: new Int64(TProtocolVersion.SPARK_CLI_SERVICE_PROTOCOL_V8),
...getInitialNamespaceOptions(request.initialCatalog, request.initialSchema),
});

Status.assert(response.status);
const session = new DBSQLSession(driver, definedOrError(response.sessionHandle), {
logger: this.logger,
const session = new DBSQLSession({
handle: definedOrError(response.sessionHandle),
context: this,
});
this.sessions.add(session);
return session;
}

private async getClient() {
public async close(): Promise<void> {
await this.sessions.closeAll();

this.client = undefined;
this.connectionProvider = undefined;
this.authProvider = undefined;
}

public getLogger(): IDBSQLLogger {
return this.logger;
}

public async getConnectionProvider(): Promise<IConnectionProvider> {
if (!this.connectionProvider) {
throw new HiveDriverError('DBSQLClient: not connected');
}

return this.connectionProvider;
}

public async getClient(): Promise<TCLIService.Client> {
const connectionProvider = await this.getConnectionProvider();

if (!this.client) {
this.logger.log(LogLevel.info, 'DBSQLClient: initializing thrift client');
this.client = this.thrift.createClient(TCLIService, await this.connectionProvider.getThriftConnection());
this.client = this.thrift.createClient(TCLIService, await connectionProvider.getThriftConnection());
}

if (this.authProvider) {
const authHeaders = await this.authProvider.authenticate();
this.connectionProvider.setHeaders(authHeaders);
connectionProvider.setHeaders(authHeaders);
}

return this.client;
}

public async close(): Promise<void> {
await this.sessions.closeAll();

this.client = undefined;
this.connectionProvider = undefined;
this.authProvider = undefined;
public async getDriver(): Promise<IDriver> {
return this.driver;
}
}
11 changes: 6 additions & 5 deletions lib/DBSQLOperation/FetchResultsHelper.ts
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ import {
TRowSet,
} from '../../thrift/TCLIService_types';
import { ColumnCode, FetchType, Int64 } from '../hive/Types';
import HiveDriver from '../hive/HiveDriver';
import Status from '../dto/Status';
import IClientContext from '../contracts/IClientContext';

function checkIfOperationHasMoreRows(response: TFetchResultsResp): boolean {
if (response.hasMoreRows) {
Expand Down Expand Up @@ -36,7 +36,7 @@ function checkIfOperationHasMoreRows(response: TFetchResultsResp): boolean {
}

export default class FetchResultsHelper {
private readonly driver: HiveDriver;
private readonly context: IClientContext;

private readonly operationHandle: TOperationHandle;

Expand All @@ -49,12 +49,12 @@ export default class FetchResultsHelper {
public hasMoreRows: boolean = false;

constructor(
driver: HiveDriver,
context: IClientContext,
operationHandle: TOperationHandle,
prefetchedResults: Array<TFetchResultsResp | undefined>,
returnOnlyPrefetchedResults: boolean,
) {
this.driver = driver;
this.context = context;
this.operationHandle = operationHandle;
prefetchedResults.forEach((item) => {
if (item) {
Expand Down Expand Up @@ -85,7 +85,8 @@ export default class FetchResultsHelper {
return this.processFetchResponse(prefetchedResponse);
}

const response = await this.driver.fetchResults({
const driver = await this.context.getDriver();
const response = await driver.fetchResults({
operationHandle: this.operationHandle,
orientation: this.fetchOrientation,
maxRows: new Int64(maxRows),
Expand Down
68 changes: 34 additions & 34 deletions lib/DBSQLOperation/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import IOperation, {
GetSchemaOptions,
WaitUntilReadyOptions,
} from '../contracts/IOperation';
import HiveDriver from '../hive/HiveDriver';
import {
TGetOperationStatusResp,
TOperationHandle,
Expand All @@ -18,19 +17,22 @@ import {
} from '../../thrift/TCLIService_types';
import Status from '../dto/Status';
import FetchResultsHelper from './FetchResultsHelper';
import IDBSQLLogger, { LogLevel } from '../contracts/IDBSQLLogger';
import { LogLevel } from '../contracts/IDBSQLLogger';
import OperationStateError, { OperationStateErrorCode } from '../errors/OperationStateError';
import IOperationResult from '../result/IOperationResult';
import JsonResult from '../result/JsonResult';
import ArrowResult from '../result/ArrowResult';
import CloudFetchResult from '../result/CloudFetchResult';
import { definedOrError } from '../utils';
import HiveDriverError from '../errors/HiveDriverError';
import IClientContext from '../contracts/IClientContext';

const defaultMaxRows = 100000;

interface DBSQLOperationConstructorOptions {
logger: IDBSQLLogger;
handle: TOperationHandle;
directResults?: TSparkDirectResults;
context: IClientContext;
}

async function delay(ms?: number): Promise<void> {
Expand All @@ -42,12 +44,10 @@ async function delay(ms?: number): Promise<void> {
}

export default class DBSQLOperation implements IOperation {
private readonly driver: HiveDriver;
private readonly context: IClientContext;

private readonly operationHandle: TOperationHandle;

private readonly logger: IDBSQLLogger;

public onClose?: () => void;

private readonly _data: FetchResultsHelper;
Expand All @@ -70,32 +70,26 @@ export default class DBSQLOperation implements IOperation {

private resultHandler?: IOperationResult;

constructor(
driver: HiveDriver,
operationHandle: TOperationHandle,
{ logger }: DBSQLOperationConstructorOptions,
directResults?: TSparkDirectResults,
) {
this.driver = driver;
this.operationHandle = operationHandle;
this.logger = logger;
constructor({ handle, directResults, context }: DBSQLOperationConstructorOptions) {
this.operationHandle = handle;
this.context = context;

const useOnlyPrefetchedResults = Boolean(directResults?.closeOperation);

this.hasResultSet = operationHandle.hasResultSet;
this.hasResultSet = this.operationHandle.hasResultSet;
if (directResults?.operationStatus) {
this.processOperationStatusResponse(directResults.operationStatus);
}

this.metadata = directResults?.resultSetMetadata;
this._data = new FetchResultsHelper(
this.driver,
this.context,
this.operationHandle,
[directResults?.resultSet],
useOnlyPrefetchedResults,
);
this.closeOperation = directResults?.closeOperation;
this.logger.log(LogLevel.debug, `Operation created with id: ${this.getId()}`);
this.context.getLogger().log(LogLevel.debug, `Operation created with id: ${this.getId()}`);
}

public getId() {
Expand All @@ -118,7 +112,7 @@ export default class DBSQLOperation implements IOperation {
const chunk = await this.fetchChunk(options);
data.push(chunk);
} while (await this.hasMoreRows()); // eslint-disable-line no-await-in-loop
this.logger?.log(LogLevel.debug, `Fetched all data from operation with id: ${this.getId()}`);
this.context.getLogger().log(LogLevel.debug, `Fetched all data from operation with id: ${this.getId()}`);

return data.flat();
}
Expand Down Expand Up @@ -149,10 +143,12 @@ export default class DBSQLOperation implements IOperation {
await this.failIfClosed();

const result = await resultHandler.getValue(data ? [data] : []);
this.logger?.log(
LogLevel.debug,
`Fetched chunk of size: ${options?.maxRows || defaultMaxRows} from operation with id: ${this.getId()}`,
);
this.context
.getLogger()
.log(
LogLevel.debug,
`Fetched chunk of size: ${options?.maxRows || defaultMaxRows} from operation with id: ${this.getId()}`,
);
return result;
}

Expand All @@ -163,13 +159,14 @@ export default class DBSQLOperation implements IOperation {
*/
public async status(progress: boolean = false): Promise<TGetOperationStatusResp> {
await this.failIfClosed();
this.logger?.log(LogLevel.debug, `Fetching status for operation with id: ${this.getId()}`);
this.context.getLogger().log(LogLevel.debug, `Fetching status for operation with id: ${this.getId()}`);

if (this.operationStatus) {
return this.operationStatus;
}

const response = await this.driver.getOperationStatus({
const driver = await this.context.getDriver();
const response = await driver.getOperationStatus({
operationHandle: this.operationHandle,
getProgressUpdate: progress,
});
Expand All @@ -186,9 +183,10 @@ export default class DBSQLOperation implements IOperation {
return Status.success();
}

this.logger?.log(LogLevel.debug, `Cancelling operation with id: ${this.getId()}`);
this.context.getLogger().log(LogLevel.debug, `Cancelling operation with id: ${this.getId()}`);

const response = await this.driver.cancelOperation({
const driver = await this.context.getDriver();
const response = await driver.cancelOperation({
operationHandle: this.operationHandle,
});
Status.assert(response.status);
Expand All @@ -209,11 +207,12 @@ export default class DBSQLOperation implements IOperation {
return Status.success();
}

this.logger?.log(LogLevel.debug, `Closing operation with id: ${this.getId()}`);
this.context.getLogger().log(LogLevel.debug, `Closing operation with id: ${this.getId()}`);

const driver = await this.context.getDriver();
const response =
this.closeOperation ??
(await this.driver.closeOperation({
(await driver.closeOperation({
operationHandle: this.operationHandle,
}));
Status.assert(response.status);
Expand Down Expand Up @@ -254,7 +253,7 @@ export default class DBSQLOperation implements IOperation {

await this.waitUntilReady(options);

this.logger?.log(LogLevel.debug, `Fetching schema for operation with id: ${this.getId()}`);
this.context.getLogger().log(LogLevel.debug, `Fetching schema for operation with id: ${this.getId()}`);
const metadata = await this.fetchMetadata();
return metadata.schema ?? null;
}
Expand Down Expand Up @@ -332,7 +331,8 @@ export default class DBSQLOperation implements IOperation {

private async fetchMetadata() {
if (!this.metadata) {
const metadata = await this.driver.getResultSetMetadata({
const driver = await this.context.getDriver();
const metadata = await driver.getResultSetMetadata({
operationHandle: this.operationHandle,
});
Status.assert(metadata.status);
Expand All @@ -349,13 +349,13 @@ export default class DBSQLOperation implements IOperation {
if (!this.resultHandler) {
switch (resultFormat) {
case TSparkRowSetType.COLUMN_BASED_SET:
this.resultHandler = new JsonResult(metadata.schema);
this.resultHandler = new JsonResult(this.context, metadata.schema);
break;
case TSparkRowSetType.ARROW_BASED_SET:
this.resultHandler = new ArrowResult(metadata.schema, metadata.arrowSchema);
this.resultHandler = new ArrowResult(this.context, metadata.schema, metadata.arrowSchema);
break;
case TSparkRowSetType.URL_BASED_SET:
this.resultHandler = new CloudFetchResult(metadata.schema);
this.resultHandler = new CloudFetchResult(this.context, metadata.schema);
break;
default:
this.resultHandler = undefined;
Expand Down
Loading