Skip to content

Commit

Permalink
[Text Translation] Adding support for AAD authentication (Azure#29145)
Browse files Browse the repository at this point in the history
### Packages impacted by this PR
Text Translation SDK

### Describe the problem that is addressed by this PR
Adding support for AAD authentication using Text Translation endpoints.
Those endpoints require use of special header `Ocp-Apim-ResourceId` vs
using custom endpoint.

### Checklists
- [x] Added impacted package name to the issue description
- [ ] Does this PR needs any fixes in the SDK Generator?** _(If so,
create an Issue in the
[Autorest/typescript](https://github.com/Azure/autorest.typescript)
repository and link it here)_
- [x] Added a changelog (if necessary)

---------

Co-authored-by: Michal Materna <mimat@microsoft.com>
  • Loading branch information
MikeyMCZ and Michal Materna committed Apr 2, 2024
1 parent ef6ae49 commit 40a6141
Show file tree
Hide file tree
Showing 10 changed files with 121 additions and 8 deletions.
1 change: 1 addition & 0 deletions sdk/translation/ai-translation-text-rest/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
## 1.0.0-beta.2 (Unreleased)

### Features Added
- Added support for AAD authentication.

### Breaking Changes

Expand Down
2 changes: 1 addition & 1 deletion sdk/translation/ai-translation-text-rest/assets.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@
"AssetsRepo": "Azure/azure-sdk-assets",
"AssetsRepoPrefixPath": "js",
"TagPrefix": "js/translation/ai-translation-text-rest",
"Tag": "js/translation/ai-translation-text-rest_8b443c94ab"
"Tag": "js/translation/ai-translation-text-rest_6dacbcc4a1"
}
2 changes: 2 additions & 0 deletions sdk/translation/ai-translation-text-rest/karma.conf.js
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,8 @@ module.exports = function (config) {
"TEXT_TRANSLATION_API_KEY",
"TEXT_TRANSLATION_REGION",
"RECORDINGS_RELATIVE_PATH",
"TEXT_TRANSLATION_AAD_REGION",
"TEXT_TRANSLATION_RESOURCE_ID",
],

// test results reporter to use
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ export interface CommonScriptModelOutput {
}

// @public
function createClient(endpoint: undefined | string, credential?: undefined | TranslatorCredential | KeyCredential | TokenCredential, options?: ClientOptions): TextTranslationClient;
function createClient(endpoint: undefined | string, credential?: undefined | TranslatorCredential | TranslatorTokenCredential | KeyCredential | TokenCredential, options?: ClientOptions): TextTranslationClient;
export default createClient;

// @public
Expand Down Expand Up @@ -546,6 +546,16 @@ export interface TranslatorCredential {
region: string;
}

// @public (undocumented)
export interface TranslatorTokenCredential {
// (undocumented)
azureResourceId: string;
// (undocumented)
region: string;
// (undocumented)
tokenCredential: TokenCredential;
}

// @public
export interface TransliterableScriptOutput extends CommonScriptModelOutput {
toScripts: Array<CommonScriptModelOutput>;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
// Copyright (c) Microsoft Corporation.
// Licensed under the MIT license.

import { AzureKeyCredential } from "@azure/core-auth";
import { AzureKeyCredential, TokenCredential } from "@azure/core-auth";
import {
PipelinePolicy,
PipelineRequest,
Expand All @@ -11,12 +11,20 @@ import {

const APIM_KEY_HEADER_NAME = "Ocp-Apim-Subscription-Key";
const APIM_REGION_HEADER_NAME = "Ocp-Apim-Subscription-Region";
const APIM_RESOURCE_ID = "Ocp-Apim-ResourceId";
export const DEFAULT_SCOPE = "https://cognitiveservices.azure.com/.default";

export interface TranslatorCredential {
key: string;
region: string;
}

export interface TranslatorTokenCredential {
tokenCredential: TokenCredential;
region: string;
azureResourceId: string;
}

export class TranslatorAuthenticationPolicy implements PipelinePolicy {
name: string = "TranslatorAuthenticationPolicy";
credential: TranslatorCredential;
Expand Down Expand Up @@ -47,3 +55,19 @@ export class TranslatorAzureKeyAuthenticationPolicy implements PipelinePolicy {
return next(request);
}
}

export class TranslatorTokenCredentialAuthenticationPolicy implements PipelinePolicy {
name: string = "TranslatorTokenCredentialAuthenticationPolicy";
credential: TranslatorTokenCredential;

constructor(credential: TranslatorTokenCredential) {
this.credential = credential;
}

sendRequest(request: PipelineRequest, next: SendRequest): Promise<PipelineResponse> {
request.headers.set(APIM_REGION_HEADER_NAME, this.credential.region);
request.headers.set(APIM_RESOURCE_ID, this.credential.azureResourceId);

return next(request);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@ import { logger } from "../logger";
import * as coreRestPipeline from "@azure/core-rest-pipeline";
import { TextTranslationClient } from "../clientDefinitions";
import {
DEFAULT_SCOPE,
TranslatorCredential,
TranslatorTokenCredential,
TranslatorAuthenticationPolicy,
TranslatorAzureKeyAuthenticationPolicy,
TranslatorTokenCredentialAuthenticationPolicy,
} from "./authentication";
import { AzureKeyCredential, KeyCredential, TokenCredential } from "@azure/core-auth";

const DEFAULT_SCOPE = "https://cognitiveservices.azure.com/.default";
const DEFAULT_ENPOINT = "https://api.cognitive.microsofttranslator.com";
const PLATFORM_HOST = "cognitiveservices";
const PLATFORM_PATH = "/translator/text/v3.0";
Expand All @@ -25,6 +27,17 @@ function isTranslatorKeyCredential(credential: any): credential is TranslatorCre
return (credential as TranslatorCredential)?.key !== undefined;
}

function isTokenCredential(credential: any): credential is TokenCredential {
return (credential as TokenCredential)?.getToken !== undefined;
}

function isTranslatorTokenCredential(credential: any): credential is TranslatorTokenCredential {
return (
(credential as TranslatorTokenCredential)?.tokenCredential !== undefined &&
(credential as TranslatorTokenCredential)?.azureResourceId !== undefined
);
}

/**
* Initialize a new instance of `TextTranslationClient`
* @param endpoint type: string, Supported Text Translation endpoints (protocol and hostname, for example:
Expand All @@ -33,7 +46,12 @@ function isTranslatorKeyCredential(credential: any): credential is TranslatorCre
*/
export default function createClient(
endpoint: undefined | string,
credential: undefined | TranslatorCredential | KeyCredential | TokenCredential = undefined,
credential:
| undefined
| TranslatorCredential
| TranslatorTokenCredential
| KeyCredential
| TokenCredential = undefined,
options: ClientOptions = {},
): TextTranslationClient {
let serviceEndpoint: string;
Expand Down Expand Up @@ -77,13 +95,23 @@ export default function createClient(
credential as AzureKeyCredential,
);
client.pipeline.addPolicy(mtKeyAuthenticationPolicy);
} else if (credential) {
} else if (isTokenCredential(credential)) {
client.pipeline.addPolicy(
coreRestPipeline.bearerTokenAuthenticationPolicy({
credential: credential as TokenCredential,
scopes: DEFAULT_SCOPE,
}),
);
} else if (isTranslatorTokenCredential(credential)) {
client.pipeline.addPolicy(
coreRestPipeline.bearerTokenAuthenticationPolicy({
credential: (credential as TranslatorTokenCredential).tokenCredential,
scopes: DEFAULT_SCOPE,
}),
);
client.pipeline.addPolicy(
new TranslatorTokenCredentialAuthenticationPolicy(credential as TranslatorTokenCredential),
);
}

return client;
Expand Down
2 changes: 1 addition & 1 deletion sdk/translation/ai-translation-text-rest/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,6 @@ export * from "./isUnexpected";
export * from "./models";
export * from "./outputModels";
export * from "./serializeHelper";
export { TranslatorCredential } from "./custom/authentication";
export { TranslatorCredential, TranslatorTokenCredential } from "./custom/authentication";

export default TextTranslationClient;
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ describe("BreakSentence tests", () => {

const breakSentences = response.body as BreakSentenceItemOutput[];
assert.isTrue(breakSentences[0].detectedLanguage?.language === "en");
assert.isTrue(breakSentences[0].detectedLanguage?.score === 1.0);
assert.isTrue(breakSentences[0].detectedLanguage?.score === 0.98);
assert.isTrue(breakSentences[0].sentLen[0] === 11);
});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import {
createCustomTranslationClient,
createTranslationClient,
createTokenTranslationClient,
createAADAuthenticationTranslationClient,
startRecorder,
} from "./utils/recordedClient";
import { Context } from "mocha";
Expand Down Expand Up @@ -373,4 +374,17 @@ describe("Translate tests", () => {
});
assert.equal(response.status, "200");
});

it("with AAD authentication", async () => {
const tokenClient = await createAADAuthenticationTranslationClient({ recorder });
const inputText: InputTextItem[] = [{ text: "This is a test." }];
const parameters: TranslateQueryParamProperties & Record<string, unknown> = {
to: "cs",
};
const response = await tokenClient.path("/translate").post({
body: inputText,
queryParameters: parameters,
});
assert.equal(response.status, "200");
});
});
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,21 @@ import {
import { StaticAccessTokenCredential } from "./StaticAccessTokenCredential";
import createTextTranslationClient, {
TranslatorCredential,
TranslatorTokenCredential,
TextTranslationClient,
} from "../../../src";
import { ClientOptions } from "@azure-rest/core-client";
import { createDefaultHttpClient, createPipelineRequest } from "@azure/core-rest-pipeline";
import { TokenCredential } from "@azure/core-auth";
import { ClientSecretCredential } from "@azure/identity";

const envSetupForPlayback: Record<string, string> = {
TEXT_TRANSLATION_API_KEY: "fakeapikey",
TEXT_TRANSLATION_ENDPOINT: "https://fakeEndpoint.cognitive.microsofttranslator.com",
TEXT_TRANSLATION_CUSTOM_ENDPOINT: "https://fakeCustomEndpoint.cognitiveservices.azure.com",
TEXT_TRANSLATION_REGION: "fakeregion",
TEXT_TRANSLATION_AAD_REGION: "fakeregion",
TEXT_TRANSLATION_RESOURCE_ID: "fakeresourceid",
};

const recorderEnvSetup: RecorderStartOptions = {
Expand Down Expand Up @@ -113,6 +117,36 @@ export async function createTokenTranslationClient(options: {
return client;
}

export async function createAADAuthenticationTranslationClient(options: {
recorder?: Recorder;
clientOptions?: ClientOptions;
}): Promise<TextTranslationClient> {
const { recorder, clientOptions = {} } = options;
const updatedOptions = recorder ? recorder.configureClientOptions(clientOptions) : clientOptions;
const endpoint = assertEnvironmentVariable("TEXT_TRANSLATION_ENDPOINT");
const region = assertEnvironmentVariable("TEXT_TRANSLATION_AAD_REGION");
const azureResourceId = assertEnvironmentVariable("TEXT_TRANSLATION_RESOURCE_ID");

let tokenCredential: TokenCredential;
if (isPlaybackMode()) {
tokenCredential = createMockToken();
} else {
const clientId = assertEnvironmentVariable("TEXT_TRANSLATION_CLIENT_ID");
const tenantId = assertEnvironmentVariable("TEXT_TRANSLATION_TENANT_ID");
const secret = assertEnvironmentVariable("TEXT_TRANSLATION_CLIENT_SECRET");

tokenCredential = new ClientSecretCredential(tenantId, clientId, secret);
}

const translatorTokenCredentials: TranslatorTokenCredential = {
tokenCredential,
azureResourceId,
region,
};
const client = createTextTranslationClient(endpoint, translatorTokenCredentials, updatedOptions);
return client;
}

export function createMockToken(): {
getToken: (_scopes: string) => Promise<{ token: string; expiresOnTimestamp: number }>;
} {
Expand Down

0 comments on commit 40a6141

Please sign in to comment.