Skip to content

Commit

Permalink
refactor: modularize file strategies and employ with use of DALL-E
Browse files Browse the repository at this point in the history
  • Loading branch information
danny-avila committed Jan 9, 2024
1 parent 75ef6ad commit 7f9d65c
Show file tree
Hide file tree
Showing 13 changed files with 301 additions and 250 deletions.
2 changes: 1 addition & 1 deletion api/app/clients/PluginsClient.js
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ class PluginsClient extends OpenAIClient {
signal: this.abortController.signal,
openAIApiKey: this.openAIApiKey,
conversationId: this.conversationId,
debug: this.options?.debug,
fileStrategy: this.options.req.app.locals.fileStrategy,
message,
},
});
Expand Down
66 changes: 16 additions & 50 deletions api/app/clients/tools/DALL-E.js
Original file line number Diff line number Diff line change
@@ -1,20 +1,13 @@
// From https://platform.openai.com/docs/api-reference/images/create
// To use this tool, you must pass in a configured OpenAIApi object.
const fs = require('fs');
const path = require('path');
const OpenAI = require('openai');
// const { genAzureEndpoint } = require('~/utils/genAzureEndpoints');
const { v4: uuidv4 } = require('uuid');
const { Tool } = require('langchain/tools');
const { HttpsProxyAgent } = require('https-proxy-agent');
const {
saveImageToFirebaseStorage,
getFirebaseStorageImageUrl,
getFirebaseStorage,
} = require('~/server/services/Files/Firebase');
const { getImageBasename } = require('~/server/services/Files/images');
const { processFileURL } = require('~/server/services/Files/process');
const extractBaseURL = require('~/utils/extractBaseURL');
const saveImageFromUrl = require('./saveImageFromUrl');
const { logger } = require('~/config');

const { DALLE_REVERSE_PROXY, PROXY } = process.env;
Expand All @@ -23,6 +16,7 @@ class OpenAICreateImage extends Tool {
super();

this.userId = fields.userId;
this.fileStrategy = fields.fileStrategy;
let apiKey = fields.DALLE_API_KEY || this.getApiKey();

const config = { apiKey };
Expand Down Expand Up @@ -82,11 +76,7 @@ Guidelines:
.trim();
}

getMarkdownImageUrl(imageName) {
const imageUrl = path
.join(this.relativeImageUrl, imageName)
.replace(/\\/g, '/')
.replace('public/', '');
wrapInMarkdown(imageUrl) {
return `![generated image](/${imageUrl})`;
}

Expand Down Expand Up @@ -118,45 +108,21 @@ Guidelines:
});
}

this.outputPath = path.resolve(
__dirname,
'..',
'..',
'..',
'..',
'client',
'public',
'images',
this.userId,
);

const appRoot = path.resolve(__dirname, '..', '..', '..', '..', 'client');
this.relativeImageUrl = path.relative(appRoot, this.outputPath);

// Check if directory exists, if not create it
if (!fs.existsSync(this.outputPath)) {
fs.mkdirSync(this.outputPath, { recursive: true });
}
try {
const result = await processFileURL({
fileStrategy: this.fileStrategy,
userId: this.userId,
URL: theImageUrl,
fileName: imageName,
basePath: 'images',
});

const storage = getFirebaseStorage();
if (storage) {
try {
await saveImageToFirebaseStorage(this.userId, theImageUrl, imageName);
this.result = await getFirebaseStorageImageUrl(`${this.userId}/${imageName}`);
logger.debug('[DALL-E] result: ' + this.result);
} catch (error) {
logger.error('Error while saving the image to Firebase Storage:', error);
this.result = `Failed to save the image to Firebase Storage. ${error.message}`;
}
} else {
try {
await saveImageFromUrl(theImageUrl, this.outputPath, imageName);
this.result = this.getMarkdownImageUrl(imageName);
} catch (error) {
logger.error('Error while saving the image locally:', error);
this.result = `Failed to save the image locally. ${error.message}`;
}
this.result = this.wrapInMarkdown(result);
} catch (error) {
logger.error('Error while saving the image:', error);
this.result = `Failed to save the image locally. ${error.message}`;
}

return this.result;
}
}
Expand Down
46 changes: 0 additions & 46 deletions api/app/clients/tools/saveImageFromUrl.js

This file was deleted.

66 changes: 16 additions & 50 deletions api/app/clients/tools/structured/DALLE3.js
Original file line number Diff line number Diff line change
@@ -1,20 +1,13 @@
// From https://platform.openai.com/docs/guides/images/usage?context=node
// To use this tool, you must pass in a configured OpenAIApi object.
const fs = require('fs');
const path = require('path');
const { z } = require('zod');
const OpenAI = require('openai');
const { v4: uuidv4 } = require('uuid');
const { Tool } = require('langchain/tools');
const { HttpsProxyAgent } = require('https-proxy-agent');
const {
saveImageToFirebaseStorage,
getFirebaseStorageImageUrl,
getFirebaseStorage,
} = require('~/server/services/Files/Firebase');
const { getImageBasename } = require('~/server/services/Files/images');
const { processFileURL } = require('~/server/services/Files/process');
const extractBaseURL = require('~/utils/extractBaseURL');
const saveImageFromUrl = require('../saveImageFromUrl');
const { logger } = require('~/config');

const { DALLE3_SYSTEM_PROMPT, DALLE_REVERSE_PROXY, PROXY } = process.env;
Expand All @@ -23,6 +16,7 @@ class DALLE3 extends Tool {
super();

this.userId = fields.userId;
this.fileStrategy = fields.fileStrategy;
let apiKey = fields.DALLE_API_KEY || this.getApiKey();
const config = { apiKey };
if (DALLE_REVERSE_PROXY) {
Expand Down Expand Up @@ -91,11 +85,7 @@ class DALLE3 extends Tool {
.trim();
}

getMarkdownImageUrl(imageName) {
const imageUrl = path
.join(this.relativeImageUrl, imageName)
.replace(/\\/g, '/')
.replace('public/', '');
wrapInMarkdown(imageUrl) {
return `![generated image](/${imageUrl})`;
}

Expand Down Expand Up @@ -143,43 +133,19 @@ Error Message: ${error.message}`;
});
}

this.outputPath = path.resolve(
__dirname,
'..',
'..',
'..',
'..',
'..',
'client',
'public',
'images',
this.userId,
);
const appRoot = path.resolve(__dirname, '..', '..', '..', '..', '..', 'client');
this.relativeImageUrl = path.relative(appRoot, this.outputPath);

// Check if directory exists, if not create it
if (!fs.existsSync(this.outputPath)) {
fs.mkdirSync(this.outputPath, { recursive: true });
}
const storage = getFirebaseStorage();
if (storage) {
try {
await saveImageToFirebaseStorage(this.userId, theImageUrl, imageName);
this.result = await getFirebaseStorageImageUrl(`${this.userId}/${imageName}`);
logger.debug('[DALL-E-3] result: ' + this.result);
} catch (error) {
logger.error('Error while saving the image to Firebase Storage:', error);
this.result = `Failed to save the image to Firebase Storage. ${error.message}`;
}
} else {
try {
await saveImageFromUrl(theImageUrl, this.outputPath, imageName);
this.result = this.getMarkdownImageUrl(imageName);
} catch (error) {
logger.error('Error while saving the image locally:', error);
this.result = `Failed to save the image locally. ${error.message}`;
}
try {
const result = await processFileURL({
fileStrategy: this.fileStrategy,
userId: this.userId,
URL: theImageUrl,
fileName: imageName,
basePath: 'images',
});

this.result = this.wrapInMarkdown(result);
} catch (error) {
logger.error('Error while saving the image:', error);
this.result = `Failed to save the image locally. ${error.message}`;
}

return this.result;
Expand Down
42 changes: 15 additions & 27 deletions api/app/clients/tools/structured/specs/DALLE3.spec.js
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,15 @@ const fs = require('fs');
const path = require('path');
const OpenAI = require('openai');
const DALLE3 = require('../DALLE3');
const {
getFirebaseStorage,
saveImageToFirebaseStorage,
} = require('~/server/services/Files/Firebase');
const saveImageFromUrl = require('../../saveImageFromUrl');
const { processFileURL } = require('~/server/services/Files/process');
const { saveFileFromURL } = require('~/server/services/Files/Local');

const { logger } = require('~/config');

jest.mock('openai');

jest.mock('~/server/services/Files/Firebase', () => ({
getFirebaseStorage: jest.fn(),
saveImageToFirebaseStorage: jest.fn(),
getFirebaseStorageImageUrl: jest.fn(),
jest.mock('~/server/services/Files/process', () => ({
processFileURL: jest.fn(),
}));

jest.mock('~/server/services/Files/images', () => ({
Expand Down Expand Up @@ -122,7 +118,7 @@ describe('DALLE3', () => {
};

generate.mockResolvedValue(mockResponse);
saveImageFromUrl.mockResolvedValue(true);
saveFileFromURL.mockResolvedValue(true);
fs.existsSync.mockReturnValue(true);
path.resolve.mockReturnValue('/fakepath/images');
path.join.mockReturnValue('/fakepath/images/img-test.png');
Expand Down Expand Up @@ -214,7 +210,7 @@ describe('DALLE3', () => {
};
const error = new Error('Error while saving the image');
generate.mockResolvedValue(mockResponse);
saveImageFromUrl.mockRejectedValue(error);
saveFileFromURL.mockRejectedValue(error);
const result = await dalle._call(mockData);
expect(logger.error).toHaveBeenCalledWith('Error while saving the image locally:', error);
expect(result).toBe('Failed to save the image locally. Error while saving the image');
Expand All @@ -227,16 +223,14 @@ describe('DALLE3', () => {
const mockImageUrl = 'http://example.com/img-test.png';
const mockResponse = { data: [{ url: mockImageUrl }] };
generate.mockResolvedValue(mockResponse);
getFirebaseStorage.mockReturnValue({}); // Simulate Firebase being initialized

await dalle._call(mockData);

expect(getFirebaseStorage).toHaveBeenCalled();
expect(saveImageToFirebaseStorage).toHaveBeenCalledWith(
undefined,
mockImageUrl,
expect.any(String),
);
expect(processFileURL).toHaveBeenCalledWith({
userId: undefined,
URL: mockImageUrl,
fileName: expect.any(String),
});
});

it('should handle error when saving image to Firebase Storage fails', async () => {
Expand All @@ -247,17 +241,11 @@ describe('DALLE3', () => {
const mockResponse = { data: [{ url: mockImageUrl }] };
const error = new Error('Error while saving to Firebase');
generate.mockResolvedValue(mockResponse);
getFirebaseStorage.mockReturnValue({}); // Simulate Firebase being initialized
saveImageToFirebaseStorage.mockRejectedValue(error);
processFileURL.mockRejectedValue(error);

const result = await dalle._call(mockData);

expect(logger.error).toHaveBeenCalledWith(
'Error while saving the image to Firebase Storage:',
error,
);
expect(result).toBe(
'Failed to save the image to Firebase Storage. Error while saving to Firebase',
);
expect(logger.error).toHaveBeenCalledWith('Error while saving the image:', error);
expect(result).toContain('Failed to save the image');
});
});
1 change: 1 addition & 0 deletions api/app/clients/tools/util/handleTools.js
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,7 @@ const loadTools = async ({

const toolOptions = {
serpapi: { location: 'Austin,Texas,United States', hl: 'en', gl: 'us' },
dalle: { fileStrategy: options.fileStrategy },
};

const toolAuthFields = {};
Expand Down
Loading

0 comments on commit 7f9d65c

Please sign in to comment.