Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 8 additions & 45 deletions app/api/generate-query/route.ts
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this endpoint is unused, could you please remove it?

Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
import { getDatabaseSchema } from '~/lib/schema';
import { NextRequest, NextResponse } from 'next/server';
import { generateSqlQueries, detectDatabaseTypeFromPrompt, type Table } from '~/lib/.server/llm/database-source';
import { generateSqlQueries } from '~/lib/.server/llm/database-source';
import { createScopedLogger } from '~/utils/logger';
import { z } from 'zod';
import { prisma } from '~/lib/prisma';
import { requireUserId } from '~/auth/session';
import { getConnectionProtocol } from '@liblab/data-access/utils/connection';
import { DataAccessor } from '@liblab/data-access/dataAccessor';

const logger = createScopedLogger('generate-sql');

const requestSchema = z.object({
prompt: z.string(),
existingQuery: z.string().optional(),
dataSourceId: z.string().optional(),
dataSourceId: z.string(),
suggestedDatabaseType: z.string().optional(),
});

Expand All @@ -22,54 +20,19 @@ export async function POST(request: NextRequest) {

try {
const body = await request.json();
const { prompt, existingQuery, dataSourceId, suggestedDatabaseType } = requestSchema.parse(body);
const { prompt, existingQuery, dataSourceId } = requestSchema.parse(body);
const existingQueries = existingQuery ? [existingQuery] : [];

let schema: Table[];
let type: string;
const schema = await getDatabaseSchema(dataSourceId, userId);

// If dataSourceId is provided, use the existing data source
if (dataSourceId) {
schema = await getDatabaseSchema(dataSourceId, userId);

const dataSource = await prisma.dataSource.findUniqueOrThrow({
where: { id: dataSourceId, createdById: userId },
});

type = getConnectionProtocol(dataSource.connectionString);
} else {
// Use AI to determine database type from prompt
const availableTypes = DataAccessor.getAvailableDatabaseTypes();
const detectedType = suggestedDatabaseType || (await detectDatabaseTypeFromPrompt(prompt, availableTypes));

if (!detectedType) {
return NextResponse.json(
{ error: 'Could not determine database type from prompt. Please specify a data source.' },
{ status: 400 },
);
}

type = detectedType;

logger.warn(
`⚠️ WARNING: No dataSourceId provided. Using sample schema for ${type}. Create a data source to query real data.`,
);

// For demonstration purposes, create a sample schema based on common patterns
// In a real implementation, you might want to ask the user to specify their schema
const sampleSchema = DataAccessor.getSampleSchema(type);

if (!sampleSchema) {
return NextResponse.json({ error: 'Unsupported database type' }, { status: 400 });
}

schema = sampleSchema;
}
const dataSource = await prisma.dataSource.findUniqueOrThrow({
where: { id: dataSourceId, createdById: userId },
});

const queries = await generateSqlQueries({
schema,
userPrompt: prompt,
databaseType: type,
connectionString: dataSource.connectionString,
existingQueries,
});

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { classNames } from '~/utils/classNames';
import { useState } from 'react';
import { toast } from 'sonner';
import { XCircle, CheckCircle, Loader2, Plug, Save } from 'lucide-react';
import { CheckCircle, Loader2, Plug, Save, XCircle } from 'lucide-react';
import type { TestConnectionResponse } from '~/components/@settings/tabs/data/DataTab';
import { z } from 'zod';
import { BaseSelect } from '~/components/ui/Select';
Expand Down Expand Up @@ -181,6 +181,7 @@ export default function AddDataSourceForm({ isSubmitting, setIsSubmitting, onSuc
<div className="min-w-[160px] flex-1">
<label className="mb-3 block text-sm font-medium text-secondary">Data source</label>
<BaseSelect
dataTestId={'add-data-source-select'}
value={dbType}
onChange={(value) => {
setDbType(value as DataSourceOption);
Expand Down
1 change: 1 addition & 0 deletions app/components/sidebar/Menu.client.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ export const Menu = () => {
animate={open ? 'open' : 'closed'}
variants={menuVariants}
style={{ width: '340px' }}
data-testid="menu"
className={classNames(
'flex selection-accent flex-col side-menu fixed top-0 h-full',
'bg-white dark:bg-gray-950 border-r border-gray-100 dark:border-gray-800/50',
Expand Down
3 changes: 3 additions & 0 deletions app/components/ui/IconButton.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ interface BaseIconButtonProps {
title?: string;
disabled?: boolean;
onClick?: (event: React.MouseEvent<HTMLButtonElement, MouseEvent>) => void;
dataTestId?: string;
}

type IconButtonWithoutChildrenProps = {
Expand All @@ -37,6 +38,7 @@ export const IconButton = forwardRef(
title,
onClick,
children,
dataTestId,
}: IconButtonProps,
ref: ForwardedRef<HTMLButtonElement>,
) => {
Expand All @@ -52,6 +54,7 @@ export const IconButton = forwardRef(
)}
title={title}
disabled={disabled}
data-testid={dataTestId}
onClick={(event) => {
if (disabled) {
return;
Expand Down
4 changes: 3 additions & 1 deletion app/components/ui/Select.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ interface SelectProps<T extends SelectOption = SelectOption> {
components?: any;
styles?: Partial<StylesConfig<T, false>>;
controlIcon?: React.ReactNode;
dataTestId?: string;
}

const createDefaultStyles = <T extends SelectOption>(): StylesConfig<T, false> => ({
Expand Down Expand Up @@ -207,6 +208,7 @@ export const BaseSelect = <T extends SelectOption = SelectOption>({
components: customComponents,
styles: customStyles,
controlIcon,
dataTestId,
}: SelectProps<T>) => {
const defaultStyles = createDefaultStyles<T>();
const mergedStyles = { ...defaultStyles, ...customStyles };
Expand All @@ -219,7 +221,7 @@ export const BaseSelect = <T extends SelectOption = SelectOption>({

return (
<ClientOnly>
<div style={{ width, minWidth }}>
<div style={{ width, minWidth }} data-testid={dataTestId}>
<Select<T>
value={value}
onChange={onChange}
Expand Down
1 change: 1 addition & 0 deletions app/components/ui/SettingsButton.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ export const SettingsButton = memo(({ onClick }: SettingsButtonProps) => {
size="xl"
title="Settings"
className="text-[#666] hover:text-primary hover:bg-depth-3/10 transition-colors"
dataTestId="settings-button"
>
<Settings className="w-5 h-5" />
</IconButton>
Expand Down
14 changes: 7 additions & 7 deletions app/layout.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ async function getRootData() {

let user = null;
let dataSources: any[] = [];
let pluginAccess = FREE_PLUGIN_ACCESS;
let dataSourceTypes: any[] = [];

if (session?.user) {
Expand All @@ -42,14 +41,15 @@ async function getRootData() {
// Get data sources for the user
const userAbility = await getUserAbility(session.user.id);
dataSources = await getDataSources(userAbility);
}

// Initialize plugin manager
await PluginManager.getInstance().initialize();
pluginAccess = PluginManager.getInstance().getAccessMap();
// Initialize plugin manager
await PluginManager.getInstance().initialize();

// Get available data source types
dataSourceTypes = DataSourcePluginManager.getAvailableDatabaseTypes();
}
const pluginAccess = PluginManager.getInstance().getAccessMap();

// Get available data source types
dataSourceTypes = DataSourcePluginManager.getAvailableDatabaseTypes();

return {
user,
Expand Down
12 changes: 7 additions & 5 deletions app/lib/.server/llm/database-source.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ import { generateObject } from 'ai';
import { z } from 'zod';
import { DataAccessor } from '@liblab/data-access/dataAccessor';
import { getLlm } from './get-llm';
import { getConnectionProtocol } from '@liblab/data-access/utils/connection';

const queryDecisionSchema = z.object({
shouldUpdateSql: z.boolean(),
Expand Down Expand Up @@ -60,27 +61,28 @@ export interface Table {
export type GenerateSqlQueriesOptions = {
schema: Table[];
userPrompt: string;
databaseType: string;
connectionString: string;
implementationPlan?: string;
existingQueries?: string[];
};

export async function generateSqlQueries({
schema,
userPrompt,
databaseType,
connectionString,
existingQueries,
}: GenerateSqlQueriesOptions): Promise<SqlQueryOutput | undefined> {
const dbSchema = formatDbSchemaForLLM(schema);

// Get the appropriate accessor for this database type
const accessor = DataAccessor.getByDatabaseType(databaseType);
const accessor = DataAccessor.getAccessor(connectionString);

if (!accessor) {
throw new Error(`No accessor found for database type: ${databaseType}`);
const protocol = getConnectionProtocol(connectionString);
throw new Error(`No accessor found for database type: ${protocol}`);
}

const systemPrompt = accessor.generateSystemPrompt(databaseType, dbSchema, existingQueries, userPrompt);
const systemPrompt = accessor.generateSystemPrompt(accessor.label, dbSchema, existingQueries, userPrompt);

try {
const llm = await getLlm();
Expand Down
5 changes: 1 addition & 4 deletions app/lib/.server/llm/stream-text.ts
Original file line number Diff line number Diff line change
Expand Up @@ -136,13 +136,10 @@ ${props.summary}
where: { id: currentDataSourceId, createdById: userId },
});

const connectionDetails = new URL(dataSource.connectionString);
const type = connectionDetails.protocol.replace(':', '');

const sqlQueries = await generateSqlQueries({
schema,
userPrompt: lastUserMessage,
databaseType: type,
connectionString: dataSource.connectionString,
implementationPlan,
existingQueries,
});
Expand Down
Loading
Loading