Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Sum and Average aggregations #1873

Merged
merged 14 commits into from
Oct 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
166 changes: 166 additions & 0 deletions dev/src/aggregate.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
/**
* @license
* Copyright 2023 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

import * as firestore from '@google-cloud/firestore';

import {FieldPath} from './path';
import {google} from '../protos/firestore_v1_proto_api';

import IAggregation = google.firestore.v1.StructuredAggregationQuery.IAggregation;
import * as assert from 'assert';

/**
* Concrete implementation of the Aggregate type.
*/
export class Aggregate {
constructor(
readonly alias: string,
readonly aggregateType: AggregateType,
readonly fieldPath?: string | FieldPath
) {}

/**
* Converts this object to the proto representation of an Aggregate.
* @internal
*/
toProto(): IAggregation {
const proto: IAggregation = {};
if (this.aggregateType === 'count') {
proto.count = {};
} else if (this.aggregateType === 'sum') {
assert(
this.fieldPath !== undefined,
'Missing field path for sum aggregation.'
);
proto.sum = {
field: {
fieldPath: FieldPath.fromArgument(this.fieldPath!).formattedName,
},
};
} else if (this.aggregateType === 'avg') {
assert(
this.fieldPath !== undefined,
'Missing field path for average aggregation.'
);
proto.avg = {
field: {
fieldPath: FieldPath.fromArgument(this.fieldPath!).formattedName,
},
};
} else {
throw new Error(`Aggregate type ${this.aggregateType} unimplemented.`);
}
proto.alias = this.alias;
return proto;
}
}

/**
* Represents an aggregation that can be performed by Firestore.
*/
export class AggregateField<T> implements firestore.AggregateField<T> {
/** A type string to uniquely identify instances of this class. */
readonly type = 'AggregateField';

/**
* The field on which the aggregation is performed.
* @internal
**/
public readonly _field?: string | FieldPath;

/**
* Create a new AggregateField<T>
* @param aggregateType Specifies the type of aggregation operation to perform.
* @param field Optionally specifies the field that is aggregated.
* @internal
*/
private constructor(
public readonly aggregateType: AggregateType,
field?: string | FieldPath
) {
this._field = field;
}

/**
* Compares this object with the given object for equality.
*
* This object is considered "equal" to the other object if and only if
* `other` performs the same kind of aggregation on the same field (if any).
*
* @param other The object to compare to this object for equality.
* @return `true` if this object is "equal" to the given object, as
* defined above, or `false` otherwise.
*/
isEqual(other: AggregateField<T>): boolean {
return (
other instanceof AggregateField &&
this.aggregateType === other.aggregateType &&
((this._field === undefined && other._field === undefined) ||
(this._field !== undefined &&
other._field !== undefined &&
FieldPath.fromArgument(this._field).isEqual(
FieldPath.fromArgument(other._field)
)))
);
}

/**
* Create an AggregateField object that can be used to compute the count of
* documents in the result set of a query.
*/
static count(): AggregateField<number> {
return new AggregateField<number>('count');
}

/**
* Create an AggregateField object that can be used to compute the average of
* a specified field over a range of documents in the result set of a query.
* @param field Specifies the field to average across the result set.
*/
static average(field: string | FieldPath): AggregateField<number | null> {
return new AggregateField<number | null>('avg', field);
}

/**
* Create an AggregateField object that can be used to compute the sum of
* a specified field over a range of documents in the result set of a query.
* @param field Specifies the field to sum across the result set.
*/
static sum(field: string | FieldPath): AggregateField<number> {
return new AggregateField<number>('sum', field);
}
}

/**
* A type whose property values are all `AggregateField` objects.
*/
export interface AggregateSpec {
[field: string]: AggregateFieldType;
}

/**
* The union of all `AggregateField` types that are supported by Firestore.
*/
export type AggregateFieldType =
| ReturnType<typeof AggregateField.count>
| ReturnType<typeof AggregateField.sum>
| ReturnType<typeof AggregateField.average>;

/**
* Union type representing the aggregate type to be performed.
*/
export type AggregateType = 'count' | 'avg' | 'sum';
1 change: 1 addition & 0 deletions dev/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ export {GeoPoint} from './geo-point';
export {CollectionGroup};
export {QueryPartition} from './query-partition';
export {setLogFunction} from './logger';
export {AggregateField, Aggregate} from './aggregate';

const libVersion = require('../../package.json').version;
setLibVersion(libVersion);
Expand Down
93 changes: 80 additions & 13 deletions dev/src/reference.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

import * as firestore from '@google-cloud/firestore';
import * as assert from 'assert';
import {Duplex, Readable, Transform} from 'stream';
import * as deepEqual from 'fast-deep-equal';
import {GoogleError} from 'google-gax';
Expand Down Expand Up @@ -44,6 +45,7 @@ import {
autoId,
Deferred,
isPermanentRpcError,
mapToArray,
requestTag,
wrapError,
} from './util';
Expand All @@ -58,6 +60,7 @@ import {DocumentWatch, QueryWatch} from './watch';
import {validateDocumentData, WriteBatch, WriteResult} from './write-batch';
import api = protos.google.firestore.v1;
import {CompositeFilter, Filter, UnaryFilter} from './filter';
import {AggregateField, Aggregate, AggregateSpec} from './aggregate';

/**
* The direction of a `Query.orderBy()` clause is specified as 'desc' or 'asc'
Expand Down Expand Up @@ -1848,7 +1851,47 @@ export class Query<
AppModelType,
DbModelType
> {
return new AggregateQuery(this, {count: {}});
return this.aggregate({
count: AggregateField.count(),
});
}

/**
* Returns a query that can perform the given aggregations.
*
* The returned query, when executed, calculates the specified aggregations
* over the documents in the result set of this query, without actually
* downloading the documents.
*
* Using the returned query to perform aggregations is efficient because only
* the final aggregation values, not the documents' data, is downloaded. The
* returned query can even perform aggregations of the documents if the result set
* would be prohibitively large to download entirely (e.g. thousands of documents).
*
* @param aggregateSpec An `AggregateSpec` object that specifies the aggregates
* to perform over the result set. The AggregateSpec specifies aliases for each
* aggregate, which can be used to retrieve the aggregate result.
* @example
* ```typescript
* const aggregateQuery = col.aggregate(query, {
* countOfDocs: count(),
* totalHours: sum('hours'),
* averageScore: average('score')
* });
*
* const aggregateSnapshot = await aggregateQuery.get();
* const countOfDocs: number = aggregateSnapshot.data().countOfDocs;
* const totalHours: number = aggregateSnapshot.data().totalHours;
* const averageScore: number | null = aggregateSnapshot.data().averageScore;
* ```
*/
aggregate<T extends firestore.AggregateSpec>(
aggregateSpec: T
): AggregateQuery<T, AppModelType, DbModelType> {
return new AggregateQuery<T, AppModelType, DbModelType>(
this,
aggregateSpec
);
}

/**
Expand Down Expand Up @@ -3163,12 +3206,15 @@ export class CollectionReference<
* A query that calculates aggregations over an underlying query.
*/
export class AggregateQuery<
AggregateSpecType extends firestore.AggregateSpec,
AggregateSpecType extends AggregateSpec,
AppModelType = firestore.DocumentData,
DbModelType extends firestore.DocumentData = firestore.DocumentData,
> implements
firestore.AggregateQuery<AggregateSpecType, AppModelType, DbModelType>
{
private readonly clientAliasToServerAliasMap: Record<string, string> = {};
private readonly serverAliasToClientAliasMap: Record<string, string> = {};

/**
* @private
* @internal
Expand All @@ -3181,7 +3227,19 @@ export class AggregateQuery<
// eslint-disable-next-line @typescript-eslint/no-explicit-any
private readonly _query: Query<AppModelType, DbModelType>,
private readonly _aggregates: AggregateSpecType
) {}
) {
// Client-side aliases may be too long and exceed the 1500-byte string size limit.
// Such long strings do not need to be transferred over the wire either.
// The client maps the user's alias to a short form alias and send that to the server.
let aggregationNum = 0;
for (const clientAlias in this._aggregates) {
if (Object.prototype.hasOwnProperty.call(this._aggregates, clientAlias)) {
const serverAlias = `aggregate_${aggregationNum++}`;
this.clientAliasToServerAliasMap[clientAlias] = serverAlias;
this.serverAliasToClientAliasMap[serverAlias] = clientAlias;
}
}
}

/** The query whose aggregations will be calculated by this object. */
get query(): Query<AppModelType, DbModelType> {
Expand Down Expand Up @@ -3323,12 +3381,17 @@ export class AggregateQuery<
if (fields) {
const serializer = this._query.firestore._serializer!;
for (const prop of Object.keys(fields)) {
if (this._aggregates[prop] === undefined) {
const alias = this.serverAliasToClientAliasMap[prop];
assert(
alias !== null && alias !== undefined,
`'${prop}' not present in server-client alias mapping.`
);
if (this._aggregates[alias] === undefined) {
throw new Error(
`Unexpected alias [${prop}] in result aggregate result`
);
}
data[prop] = serializer.decodeValue(fields[prop]);
data[alias] = serializer.decodeValue(fields[prop]);
}
}
return data;
Expand All @@ -3344,18 +3407,22 @@ export class AggregateQuery<
*/
toProto(transactionId?: Uint8Array): api.IRunAggregationQueryRequest {
const queryProto = this._query.toProto();
//TODO(tomandersen) inspect _query to build request - this is just hard
// coded count right now.
const runQueryRequest: api.IRunAggregationQueryRequest = {
parent: queryProto.parent,
structuredAggregationQuery: {
structuredQuery: queryProto.structuredQuery,
aggregations: [
{
alias: 'count',
count: {},
},
],
aggregations: mapToArray(this._aggregates, (aggregate, clientAlias) => {
const serverAlias = this.clientAliasToServerAliasMap[clientAlias];
assert(
serverAlias !== null && serverAlias !== undefined,
`'${clientAlias}' not present in client-server alias mapping.`
);
return new Aggregate(
serverAlias,
aggregate.aggregateType,
aggregate._field
).toProto();
}),
},
};

Expand Down
21 changes: 21 additions & 0 deletions dev/src/util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import {randomBytes} from 'crypto';
import type {CallSettings, ClientConfig, GoogleError} from 'google-gax';
import type {BackoffSettings} from 'google-gax/build/src/gax';
import * as gapicConfig from './v1/firestore_client_config.json';
import Dict = NodeJS.Dict;

/**
* A Promise implementation that supports deferred resolution.
Expand Down Expand Up @@ -246,3 +247,23 @@ export function tryGetPreferRestEnvironmentVariable(): boolean | undefined {
return undefined;
}
}

/**
* Returns an array of values that are calculated by performing the given `fn`
* on all keys in the given `obj` dictionary.
*
* @private
* @internal
*/
export function mapToArray<V, R>(
obj: Dict<V>,
fn: (element: V, key: string, obj: Dict<V>) => R
): R[] {
const result: R[] = [];
for (const key in obj) {
if (Object.prototype.hasOwnProperty.call(obj, key)) {
result.push(fn(obj[key]!, key, obj));
}
}
return result;
}