From ada62ffd33989b4cef321bf28ae2b1b97b19fd53 Mon Sep 17 00:00:00 2001 From: Koen Vlaswinkel Date: Wed, 4 Oct 2023 13:23:54 +0200 Subject: [PATCH] Convert `yaml.ts` to handle multiple models per method This changes YAML parsing/creating functions for the model editor to handle multiple models per method. The changes in the actual YAML handling are fairly small because the format itself already supports multiple models per method. I've introduced a few helper functions to convert between the old and new types. This should only be necessary while we're in the middle of the transition to the new types and can be removed later. For now, we'll just take the first model in the array when converting from the new to the old type. This is a change in the behavior since currently we always take the last model in the array but this behavior is undocumented and unsupported, so it should be fine to change it. --- .../src/model-editor/auto-modeler.ts | 7 +- .../src/model-editor/modeled-method-fs.ts | 12 +- .../model-editor/modeled-methods-legacy.ts | 33 ++ extensions/ql-vscode/src/model-editor/yaml.ts | 49 ++- .../test/unit-tests/model-editor/yaml.test.ts | 410 ++++++++++-------- 5 files changed, 299 insertions(+), 212 deletions(-) create mode 100644 extensions/ql-vscode/src/model-editor/modeled-methods-legacy.ts diff --git a/extensions/ql-vscode/src/model-editor/auto-modeler.ts b/extensions/ql-vscode/src/model-editor/auto-modeler.ts index fc359e7e086..f691aeeb897 100644 --- a/extensions/ql-vscode/src/model-editor/auto-modeler.ts +++ b/extensions/ql-vscode/src/model-editor/auto-modeler.ts @@ -16,6 +16,7 @@ import { QueryRunner } from "../query-server"; import { DatabaseItem } from "../databases/local-databases"; import { Mode } from "./shared/mode"; import { CancellationTokenSource } from "vscode"; +import { convertToLegacyModeledMethods } from "./modeled-methods-legacy"; // Limit the number of candidates we send to the model in each request // to avoid long requests. 
@@ -192,11 +193,13 @@ export class AutoModeler { filename: "auto-model.yml", }); - const loadedMethods = loadDataExtensionYaml(models); - if (!loadedMethods) { + const rawLoadedMethods = loadDataExtensionYaml(models); + if (!rawLoadedMethods) { return; } + const loadedMethods = convertToLegacyModeledMethods(rawLoadedMethods); + // Any candidate that was part of the response is a negative result // meaning that the canidate is not a sink for the kinds that the LLM is checking for. // For now we model this as a sink neutral method, however this is subject diff --git a/extensions/ql-vscode/src/model-editor/modeled-method-fs.ts b/extensions/ql-vscode/src/model-editor/modeled-method-fs.ts index bf70b1fc1cd..8f3ce34aaee 100644 --- a/extensions/ql-vscode/src/model-editor/modeled-method-fs.ts +++ b/extensions/ql-vscode/src/model-editor/modeled-method-fs.ts @@ -10,6 +10,11 @@ import { getOnDiskWorkspaceFolders } from "../common/vscode/workspace-folders"; import { load as loadYaml } from "js-yaml"; import { CodeQLCliServer } from "../codeql-cli/cli"; import { pathsEqual } from "../common/files"; +import { + convertFromLegacyModeledMethods, + convertFromLegacyModeledMethodsFiles, + convertToLegacyModeledMethods, +} from "./modeled-methods-legacy"; export async function saveModeledMethods( extensionPack: ExtensionPack, @@ -29,8 +34,8 @@ export async function saveModeledMethods( const yamls = createDataExtensionYamls( language, methods, - modeledMethods, - existingModeledMethods, + convertFromLegacyModeledMethods(modeledMethods), + convertFromLegacyModeledMethodsFiles(existingModeledMethods), mode, ); @@ -68,7 +73,8 @@ async function loadModeledMethodFiles( ); continue; } - modeledMethodsByFile[modelFile] = modeledMethods; + modeledMethodsByFile[modelFile] = + convertToLegacyModeledMethods(modeledMethods); } return modeledMethodsByFile; diff --git a/extensions/ql-vscode/src/model-editor/modeled-methods-legacy.ts b/extensions/ql-vscode/src/model-editor/modeled-methods-legacy.ts 
new file mode 100644 index 00000000000..a482af0e34c --- /dev/null +++ b/extensions/ql-vscode/src/model-editor/modeled-methods-legacy.ts @@ -0,0 +1,33 @@ +import { ModeledMethod } from "./modeled-method"; + +export function convertFromLegacyModeledMethods( + modeledMethods: Record<string, ModeledMethod>, +): Record<string, ModeledMethod[]> { + // Convert a single ModeledMethod to an array of ModeledMethods + return Object.fromEntries( + Object.entries(modeledMethods).map(([signature, modeledMethod]) => { + return [signature, [modeledMethod]]; + }), + ); +} + +export function convertToLegacyModeledMethods( + modeledMethods: Record<string, ModeledMethod[]>, +): Record<string, ModeledMethod> { + // Always take the first modeled method in the array + return Object.fromEntries( + Object.entries(modeledMethods).map(([signature, modeledMethods]) => { + return [signature, modeledMethods[0]]; + }), + ); +} + +export function convertFromLegacyModeledMethodsFiles( + modeledMethods: Record<string, Record<string, ModeledMethod>>, +): Record<string, Record<string, ModeledMethod[]>> { + return Object.fromEntries( + Object.entries(modeledMethods).map(([filename, modeledMethods]) => { + return [filename, convertFromLegacyModeledMethods(modeledMethods)]; + }), + ); +} diff --git a/extensions/ql-vscode/src/model-editor/yaml.ts b/extensions/ql-vscode/src/model-editor/yaml.ts index c1c70448161..5ace7fab13a 100644 --- a/extensions/ql-vscode/src/model-editor/yaml.ts +++ b/extensions/ql-vscode/src/model-editor/yaml.ts @@ -71,8 +71,8 @@ ${extensions.join("\n")}`; export function createDataExtensionYamls( language: string, methods: Method[], - newModeledMethods: Record<string, ModeledMethod>, - existingModeledMethods: Record<string, Record<string, ModeledMethod>>, + newModeledMethods: Record<string, ModeledMethod[]>, + existingModeledMethods: Record<string, Record<string, ModeledMethod[]>>, mode: Mode, ) { switch (mode) { @@ -98,11 +98,11 @@ export function createDataExtensionYamls( function createDataExtensionYamlsByGrouping( language: string, methods: Method[], - newModeledMethods: Record<string, ModeledMethod>, - existingModeledMethods: Record<string, Record<string, ModeledMethod>>, + newModeledMethods: Record<string, ModeledMethod[]>, + existingModeledMethods: Record<string, Record<string, ModeledMethod[]>>, createFilename: (method: Method) => string, ): Record<string, string> { - const methodsByFilename: Record<string, Record<string, ModeledMethod>> = {}; + 
const methodsByFilename: Record<string, Record<string, ModeledMethod[]>> = {}; // We only want to generate a yaml file when it's a known external API usage // and there are new modeled methods for it. This avoids us overwriting other @@ -114,10 +114,12 @@ function createDataExtensionYamlsByGrouping( } // First populate methodsByFilename with any existing modeled methods. - for (const [filename, methods] of Object.entries(existingModeledMethods)) { + for (const [filename, methodsBySignature] of Object.entries( + existingModeledMethods, + )) { if (filename in methodsByFilename) { - for (const [signature, method] of Object.entries(methods)) { - methodsByFilename[filename][signature] = method; + for (const [signature, methods] of Object.entries(methodsBySignature)) { + methodsByFilename[filename][signature] = methods; } } } @@ -125,10 +127,12 @@ function createDataExtensionYamlsByGrouping( // Add the new modeled methods, potentially overwriting existing modeled methods // but not removing existing modeled methods that are not in the new set. for (const method of methods) { - const newMethod = newModeledMethods[method.signature]; - if (newMethod) { + const newMethods = newModeledMethods[method.signature]; + if (newMethods) { const filename = createFilename(method); - methodsByFilename[filename][newMethod.signature] = newMethod; + + // Override any existing modeled methods with the new ones. 
+ methodsByFilename[filename][method.signature] = newMethods; } } @@ -137,7 +141,7 @@ function createDataExtensionYamlsByGrouping( for (const [filename, methods] of Object.entries(methodsByFilename)) { result[filename] = createDataExtensionYaml( language, - Object.values(methods), + Object.values(methods).flatMap((methods) => methods), ); } @@ -147,8 +151,8 @@ function createDataExtensionYamlsByGrouping( export function createDataExtensionYamlsForApplicationMode( language: string, methods: Method[], - newModeledMethods: Record<string, ModeledMethod>, - existingModeledMethods: Record<string, Record<string, ModeledMethod>>, + newModeledMethods: Record<string, ModeledMethod[]>, + existingModeledMethods: Record<string, Record<string, ModeledMethod[]>>, ): Record<string, string> { return createDataExtensionYamlsByGrouping( language, @@ -162,8 +166,8 @@ export function createDataExtensionYamlsForApplicationMode( export function createDataExtensionYamlsForFrameworkMode( language: string, methods: Method[], - newModeledMethods: Record<string, ModeledMethod>, - existingModeledMethods: Record<string, Record<string, ModeledMethod>>, + newModeledMethods: Record<string, ModeledMethod[]>, + existingModeledMethods: Record<string, Record<string, ModeledMethod[]>>, ): Record<string, string> { return createDataExtensionYamlsByGrouping( language, @@ -228,14 +232,14 @@ function validateModelExtensionFile(data: unknown): data is ModelExtensionFile { export function loadDataExtensionYaml( data: unknown, -): Record<string, ModeledMethod> | undefined { +): Record<string, ModeledMethod[]> | undefined { if (!validateModelExtensionFile(data)) { return undefined; } const extensions = data.extensions; - const modeledMethods: Record<string, ModeledMethod> = {}; + const modeledMethods: Record<string, ModeledMethod[]> = {}; for (const extension of extensions) { const addsTo = extension.addsTo; @@ -250,11 +254,16 @@ export function loadDataExtensionYaml( } for (const row of data) { - const modeledMethod = definition.readModeledMethod(row); + const modeledMethod: ModeledMethod = definition.readModeledMethod(row); if (!modeledMethod) { continue; } - modeledMethods[modeledMethod.signature] = modeledMethod; + + if (!(modeledMethod.signature in modeledMethods)) { + modeledMethods[modeledMethod.signature] = []; + } + + 
modeledMethods[modeledMethod.signature].push(modeledMethod); } } diff --git a/extensions/ql-vscode/test/unit-tests/model-editor/yaml.test.ts b/extensions/ql-vscode/test/unit-tests/model-editor/yaml.test.ts index b8382e2ec31..14f3bb2344c 100644 --- a/extensions/ql-vscode/test/unit-tests/model-editor/yaml.test.ts +++ b/extensions/ql-vscode/test/unit-tests/model-editor/yaml.test.ts @@ -225,43 +225,49 @@ describe("createDataExtensionYamlsForApplicationMode", () => { }, ], { - "org.sql2o.Connection#createQuery(String)": { - type: "sink", - input: "Argument[0]", - output: "", - kind: "sql", - provenance: "df-generated", - signature: "org.sql2o.Connection#createQuery(String)", - packageName: "org.sql2o", - typeName: "Connection", - methodName: "createQuery", - methodParameters: "(String)", - }, - "org.springframework.boot.SpringApplication#run(Class,String[])": { - type: "neutral", - input: "", - output: "", - kind: "summary", - provenance: "manual", - signature: - "org.springframework.boot.SpringApplication#run(Class,String[])", - packageName: "org.springframework.boot", - typeName: "SpringApplication", - methodName: "run", - methodParameters: "(Class,String[])", - }, - "org.sql2o.Sql2o#Sql2o(String,String,String)": { - type: "sink", - input: "Argument[0]", - output: "", - kind: "jndi", - provenance: "manual", - signature: "org.sql2o.Sql2o#Sql2o(String,String,String)", - packageName: "org.sql2o", - typeName: "Sql2o", - methodName: "Sql2o", - methodParameters: "(String,String,String)", - }, + "org.sql2o.Connection#createQuery(String)": [ + { + type: "sink", + input: "Argument[0]", + output: "", + kind: "sql", + provenance: "df-generated", + signature: "org.sql2o.Connection#createQuery(String)", + packageName: "org.sql2o", + typeName: "Connection", + methodName: "createQuery", + methodParameters: "(String)", + }, + ], + "org.springframework.boot.SpringApplication#run(Class,String[])": [ + { + type: "neutral", + input: "", + output: "", + kind: "summary", + provenance: 
"manual", + signature: + "org.springframework.boot.SpringApplication#run(Class,String[])", + packageName: "org.springframework.boot", + typeName: "SpringApplication", + methodName: "run", + methodParameters: "(Class,String[])", + }, + ], + "org.sql2o.Sql2o#Sql2o(String,String,String)": [ + { + type: "sink", + input: "Argument[0]", + output: "", + kind: "jndi", + provenance: "manual", + signature: "org.sql2o.Sql2o#Sql2o(String,String,String)", + packageName: "org.sql2o", + typeName: "Sql2o", + methodName: "Sql2o", + methodParameters: "(String,String,String)", + }, + ], }, {}, ); @@ -463,84 +469,97 @@ describe("createDataExtensionYamlsForApplicationMode", () => { }, ], { - "org.sql2o.Connection#createQuery(String)": { - type: "sink", - input: "Argument[0]", - output: "", - kind: "sql", - provenance: "df-generated", - signature: "org.sql2o.Connection#createQuery(String)", - packageName: "org.sql2o", - typeName: "Connection", - methodName: "createQuery", - methodParameters: "(String)", - }, - "org.springframework.boot.SpringApplication#run(Class,String[])": { - type: "neutral", - input: "", - output: "", - kind: "summary", - provenance: "manual", - signature: - "org.springframework.boot.SpringApplication#run(Class,String[])", - packageName: "org.springframework.boot", - typeName: "SpringApplication", - methodName: "run", - methodParameters: "(Class,String[])", - }, - "org.sql2o.Sql2o#Sql2o(String,String,String)": { - type: "sink", - input: "Argument[0]", - output: "", - kind: "jndi", - provenance: "manual", - signature: "org.sql2o.Sql2o#Sql2o(String,String,String)", - packageName: "org.sql2o", - typeName: "Sql2o", - methodName: "Sql2o", - methodParameters: "(String,String,String)", - }, - }, - { - "models/sql2o.model.yml": { - "org.sql2o.Connection#createQuery(String)": { - type: "neutral", - input: "", + "org.sql2o.Connection#createQuery(String)": [ + { + type: "sink", + input: "Argument[0]", output: "", - kind: "summary", - provenance: "manual", + kind: "sql", + 
provenance: "df-generated", signature: "org.sql2o.Connection#createQuery(String)", packageName: "org.sql2o", typeName: "Connection", methodName: "createQuery", methodParameters: "(String)", }, - "org.sql2o.Query#executeScalar(Class)": { + ], + "org.springframework.boot.SpringApplication#run(Class,String[])": [ + { type: "neutral", input: "", output: "", kind: "summary", provenance: "manual", - signature: "org.sql2o.Query#executeScalar(Class)", + signature: + "org.springframework.boot.SpringApplication#run(Class,String[])", + packageName: "org.springframework.boot", + typeName: "SpringApplication", + methodName: "run", + methodParameters: "(Class,String[])", + }, + ], + "org.sql2o.Sql2o#Sql2o(String,String,String)": [ + { + type: "sink", + input: "Argument[0]", + output: "", + kind: "jndi", + provenance: "manual", + signature: "org.sql2o.Sql2o#Sql2o(String,String,String)", packageName: "org.sql2o", - typeName: "Query", - methodName: "executeScalar", - methodParameters: "(Class)", + typeName: "Sql2o", + methodName: "Sql2o", + methodParameters: "(String,String,String)", }, + ], + }, + { + "models/sql2o.model.yml": { + "org.sql2o.Connection#createQuery(String)": [ + { + type: "neutral", + input: "", + output: "", + kind: "summary", + provenance: "manual", + signature: "org.sql2o.Connection#createQuery(String)", + packageName: "org.sql2o", + typeName: "Connection", + methodName: "createQuery", + methodParameters: "(String)", + }, + ], + "org.sql2o.Query#executeScalar(Class)": [ + { + type: "neutral", + input: "", + output: "", + kind: "summary", + provenance: "manual", + signature: "org.sql2o.Query#executeScalar(Class)", + packageName: "org.sql2o", + typeName: "Query", + methodName: "executeScalar", + methodParameters: "(Class)", + }, + ], }, "models/gson.model.yml": { - "com.google.gson.TypeAdapter#fromJsonTree(JsonElement)": { - type: "summary", - input: "Argument[this]", - output: "ReturnValue", - kind: "taint", - provenance: "df-generated", - signature: 
"com.google.gson.TypeAdapter#fromJsonTree(JsonElement)", - packageName: "com.google.gson", - typeName: "TypeAdapter", - methodName: "fromJsonTree", - methodParameters: "(JsonElement)", - }, + "com.google.gson.TypeAdapter#fromJsonTree(JsonElement)": [ + { + type: "summary", + input: "Argument[this]", + output: "ReturnValue", + kind: "taint", + provenance: "df-generated", + signature: + "com.google.gson.TypeAdapter#fromJsonTree(JsonElement)", + packageName: "com.google.gson", + typeName: "TypeAdapter", + methodName: "fromJsonTree", + methodParameters: "(JsonElement)", + }, + ], }, }, ); @@ -694,30 +713,34 @@ describe("createDataExtensionYamlsForFrameworkMode", () => { }, ], { - "org.sql2o.Connection#createQuery(String)": { - type: "sink", - input: "Argument[0]", - output: "", - kind: "sql", - provenance: "df-generated", - signature: "org.sql2o.Connection#createQuery(String)", - packageName: "org.sql2o", - typeName: "Connection", - methodName: "createQuery", - methodParameters: "(String)", - }, - "org.sql2o.Sql2o#Sql2o(String,String,String)": { - type: "sink", - input: "Argument[0]", - output: "", - kind: "jndi", - provenance: "manual", - signature: "org.sql2o.Sql2o#Sql2o(String,String,String)", - packageName: "org.sql2o", - typeName: "Sql2o", - methodName: "Sql2o", - methodParameters: "(String,String,String)", - }, + "org.sql2o.Connection#createQuery(String)": [ + { + type: "sink", + input: "Argument[0]", + output: "", + kind: "sql", + provenance: "df-generated", + signature: "org.sql2o.Connection#createQuery(String)", + packageName: "org.sql2o", + typeName: "Connection", + methodName: "createQuery", + methodParameters: "(String)", + }, + ], + "org.sql2o.Sql2o#Sql2o(String,String,String)": [ + { + type: "sink", + input: "Argument[0]", + output: "", + kind: "jndi", + provenance: "manual", + signature: "org.sql2o.Sql2o#Sql2o(String,String,String)", + packageName: "org.sql2o", + typeName: "Sql2o", + methodName: "Sql2o", + methodParameters: "(String,String,String)", + }, 
+ ], }, {}, ); @@ -846,71 +869,82 @@ describe("createDataExtensionYamlsForFrameworkMode", () => { }, ], { - "org.sql2o.Connection#createQuery(String)": { - type: "sink", - input: "Argument[0]", - output: "", - kind: "sql", - provenance: "df-generated", - signature: "org.sql2o.Connection#createQuery(String)", - packageName: "org.sql2o", - typeName: "Connection", - methodName: "createQuery", - methodParameters: "(String)", - }, - "org.sql2o.Sql2o#Sql2o(String,String,String)": { - type: "sink", - input: "Argument[0]", - output: "", - kind: "jndi", - provenance: "manual", - signature: "org.sql2o.Sql2o#Sql2o(String,String,String)", - packageName: "org.sql2o", - typeName: "Sql2o", - methodName: "Sql2o", - methodParameters: "(String,String,String)", - }, - }, - { - "models/org.sql2o.model.yml": { - "org.sql2o.Connection#createQuery(String)": { - type: "neutral", - input: "", + "org.sql2o.Connection#createQuery(String)": [ + { + type: "sink", + input: "Argument[0]", output: "", - kind: "summary", - provenance: "manual", + kind: "sql", + provenance: "df-generated", signature: "org.sql2o.Connection#createQuery(String)", packageName: "org.sql2o", typeName: "Connection", methodName: "createQuery", methodParameters: "(String)", }, - "org.sql2o.Query#executeScalar(Class)": { - type: "neutral", - input: "", + ], + "org.sql2o.Sql2o#Sql2o(String,String,String)": [ + { + type: "sink", + input: "Argument[0]", output: "", - kind: "summary", + kind: "jndi", provenance: "manual", - signature: "org.sql2o.Query#executeScalar(Class)", + signature: "org.sql2o.Sql2o#Sql2o(String,String,String)", packageName: "org.sql2o", - typeName: "Query", - methodName: "executeScalar", - methodParameters: "(Class)", + typeName: "Sql2o", + methodName: "Sql2o", + methodParameters: "(String,String,String)", }, + ], + }, + { + "models/org.sql2o.model.yml": { + "org.sql2o.Connection#createQuery(String)": [ + { + type: "neutral", + input: "", + output: "", + kind: "summary", + provenance: "manual", + signature: 
"org.sql2o.Connection#createQuery(String)", + packageName: "org.sql2o", + typeName: "Connection", + methodName: "createQuery", + methodParameters: "(String)", + }, + ], + "org.sql2o.Query#executeScalar(Class)": [ + { + type: "neutral", + input: "", + output: "", + kind: "summary", + provenance: "manual", + signature: "org.sql2o.Query#executeScalar(Class)", + packageName: "org.sql2o", + typeName: "Query", + methodName: "executeScalar", + methodParameters: "(Class)", + }, + ], }, "models/gson.model.yml": { - "com.google.gson.TypeAdapter#fromJsonTree(JsonElement)": { - type: "summary", - input: "Argument[this]", - output: "ReturnValue", - kind: "taint", - provenance: "df-generated", - signature: "com.google.gson.TypeAdapter#fromJsonTree(JsonElement)", - packageName: "com.google.gson", - typeName: "TypeAdapter", - methodName: "fromJsonTree", - methodParameters: "(JsonElement)", - }, + "com.google.gson.TypeAdapter#fromJsonTree(JsonElement)": [ + { + type: "summary", + input: "Argument[this]", + output: "ReturnValue", + kind: "taint", + provenance: "df-generated", + signature: + "com.google.gson.TypeAdapter#fromJsonTree(JsonElement)", + packageName: "com.google.gson", + typeName: "TypeAdapter", + methodName: "fromJsonTree", + methodParameters: "(JsonElement)", + }, + ], }, }, ); @@ -980,18 +1014,20 @@ describe("loadDataExtensionYaml", () => { }); expect(data).toEqual({ - "org.sql2o.Connection#createQuery(String)": { - input: "Argument[0]", - kind: "sql", - output: "", - type: "sink", - provenance: "manual", - signature: "org.sql2o.Connection#createQuery(String)", - packageName: "org.sql2o", - typeName: "Connection", - methodName: "createQuery", - methodParameters: "(String)", - }, + "org.sql2o.Connection#createQuery(String)": [ + { + input: "Argument[0]", + kind: "sql", + output: "", + type: "sink", + provenance: "manual", + signature: "org.sql2o.Connection#createQuery(String)", + packageName: "org.sql2o", + typeName: "Connection", + methodName: "createQuery", + 
methodParameters: "(String)", + }, + ], }); });