-
Notifications
You must be signed in to change notification settings - Fork 4
/
inference.ts
106 lines (99 loc) · 3.46 KB
/
inference.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import { StringDict } from "../common";
import { ExtraField } from "./extras/extras";
import { Page } from "./page";
import type { Prediction } from "./prediction";
import { Product } from "./product";
/**
 * Base class for the result of a single API inference: product metadata,
 * a document-level prediction, and per-page predictions.
 *
 * @typeParam DocT an extension of a `Prediction`. Is generic by default to
 * allow for easier optional `PageT` generic typing.
 * @typeParam PageT an extension of a `DocT` (`Prediction`). Should only be set
 * if a document's pages have specific implementation.
 */
export abstract class Inference<
  DocT extends Prediction = Prediction,
  PageT extends DocT = DocT
> {
  /** A boolean denoting whether a given inference result was rotated. */
  isRotationApplied?: boolean;
  /** Name and version of a given product. */
  product: Product;
  /** Wrapper for a document's pages prediction. */
  pages: Page<PageT>[] = [];
  /** A document's top-level `Prediction`. */
  prediction!: DocT;
  /** Extraneous fields relating to specific tools for some APIs. */
  extras?: ExtraField[] = [];
  /** Name of a document's endpoint. Has a default value for OTS APIs. */
  endpointName?: string;
  /** A document's version. Has a default value for OTS APIs. */
  endpointVersion?: string;

  /**
   * @param rawPrediction raw inference payload; only `is_rotation_applied`
   * and `product` are read here. `prediction` and `pages` are not set by this
   * constructor (presumably filled in by subclasses — TODO confirm).
   */
  constructor(rawPrediction: StringDict) {
    // `?? undefined` normalizes an explicit `null` from the payload to `undefined`.
    this.isRotationApplied = rawPrediction?.is_rotation_applied ?? undefined;
    this.product = rawPrediction?.product;
  }

  /**
   * Default string representation: a header with product name/version and
   * rotation status, the document-level prediction (followed by a newline when
   * non-empty), then a "Page Predictions" section listing every page when at
   * least one page exists.
   */
  toString(): string {
    const header = [
      "Inference",
      "#########",
      `:Product: ${this.product.name} v${this.product.version}`,
      `:Rotation applied: ${this.isRotationApplied ? "Yes" : "No"}`,
      "Prediction",
      "==========",
    ].join("\n");
    // Stringify the prediction once and reuse it.
    const predictionStr = this.prediction.toString();
    let out =
      header + "\n" + (predictionStr.length === 0 ? "" : predictionStr + "\n");
    if (this.pages.length > 0) {
      // Each page is stringified exactly once.
      const pageStrs = this.pages.map((page: Page<PageT>) => page.toString());
      out += "\nPage Predictions\n================\n" + pageStrs.join("\n");
    }
    return out;
  }

  /**
   * Removes a single space character directly preceding each line break
   * (`" \n"` becomes `"\n"`). Only one space is removed per line break, so
   * `"a  \n"` becomes `"a \n"`.
   * @param outStr string to clean up
   * @returns the cleaned-up string
   */
  static cleanOutString(outStr: string): string {
    return outStr.replace(/ \n/g, "\n");
  }
}
/**
 * Factory to allow for static-like property access syntax in TypeScript.
 * Used to retrieve endpoint data for standard products.
 */
export class InferenceFactory {
  /**
   * Builds a blank product of the given type & sends back the endpointName &
   * endpointVersion parameters of OTS classes.
   * Note: this is needed to avoid passing anything other than the class of the
   * object to the parse()/enqueue() call.
   * @param inferenceClass Class of the product we are using.
   * @returns A `[endpointName, endpointVersion]` tuple read from an empty
   * instance of the given product.
   * @throws Error if `inferenceClass` is named `CustomV1`, or if the empty
   * instance has a missing or empty `endpointName` / `endpointVersion`.
   */
  public static getEndpoint<T extends Inference>(
    inferenceClass: new (httpResponse: StringDict) => T
  ): [string, string] {
    if (inferenceClass.name === "CustomV1") {
      throw new Error(
        "Cannot process custom endpoint as OTS API endpoints. Please provide an endpoint name & version manually."
      );
    }
    // Instantiate with a minimal payload solely to read the endpoint fields.
    const emptyProduct = new inferenceClass({
      prediction: {},
      pages: [],
    });
    const { endpointName, endpointVersion } = emptyProduct;
    // A falsy check rejects both `undefined` and the empty string.
    if (!endpointName || !endpointVersion) {
      throw new Error(
        `Error during endpoint verification, no endpoint found for product ${inferenceClass.name}.`
      );
    }
    return [endpointName, endpointVersion];
  }
}