diff --git a/cli/package.json b/cli/package.json index cc0f741e2..8ed779b3f 100644 --- a/cli/package.json +++ b/cli/package.json @@ -9,6 +9,8 @@ "benchmark_gpt": "npm run build && node dist/benchmark_gpt.js", "train_gpt": "npm run build && node dist/train_gpt.js", "hellaswag_gpt": "npm run build && node dist/hellaswag_gpt.js", + "eval_finetuned_gpt2": "npm run build && node dist/evaluate_finetuned_gpt2.js", + "finetune_gpt": "npm run build && node dist/finetune_gpt.js", "build": "tsc --build", "test": ": nothing" }, diff --git a/cli/src/args.ts b/cli/src/args.ts index ced893a72..00d72102e 100644 --- a/cli/src/args.ts +++ b/cli/src/args.ts @@ -21,6 +21,8 @@ export interface BenchmarkArguments { roundDuration: number batchSize: number validationSplit: number + datasetPath?: string + validationDatasetPath?: string // DP epsilon?: number @@ -36,11 +38,14 @@ export interface BenchmarkArguments { maxShareValue?: number save: boolean + saveModel: boolean host: URL } type BenchmarkUnsafeArguments = Omit & { task: string + datasetPath?: string + validationDatasetPath?: string help?: boolean } @@ -55,7 +60,10 @@ const unsafeArgs = parse( roundDuration: { type: Number, alias: 'r', description: 'Round duration (in epochs)', defaultValue: 2 }, batchSize: { type: Number, alias: 'b', description: 'Training batch size', defaultValue: 10 }, validationSplit : { type: Number, alias: 'v', description: 'Validation dataset ratio', defaultValue: 0.2 }, + datasetPath: { type: String, alias: 'd', description: 'Path to the dataset', optional: true }, + validationDatasetPath: { type: String, alias: 'V', description: 'Path to the validation dataset', optional: true }, save: { type: Boolean, alias: 's', description: 'Save logs of benchmark', defaultValue: false }, + saveModel: { type: Boolean, alias: 'm', description: 'Save trained model to disk', defaultValue: false }, host: { type: (raw: string) => new URL(raw), typeLabel: "URL", @@ -89,18 +97,19 @@ const unsafeArgs = parse( const supportedTasks = Map( await Promise.all( - Set.of>( + Set.of>( defaultTasks.cifar10, defaultTasks.lusCovid, defaultTasks.simpleFace, defaultTasks.titanic, defaultTasks.tinderDog, defaultTasks.mnist, + defaultTasks.privacyrun, ).map( async (t) => [(await t.getTask()).id, t] as [ string, - TaskProvider<"image" | "tabular", Network>, + TaskProvider<"image" | "tabular" | "text", Network>, ], ), ), diff --git a/cli/src/cli.ts b/cli/src/cli.ts index 2e23c6514..14523d44f 100644 --- a/cli/src/cli.ts +++ b/cli/src/cli.ts @@ -5,7 +5,7 @@ import { List, Range } from 'immutable' import fs from 'node:fs/promises' import { createWriteStream } from "node:fs"; import path from "node:path"; - +import createDebug from "debug"; import type { Dataset, DataFormat, @@ -17,25 +17,42 @@ import type { } from "@epfml/discojs"; import { Disco, aggregator as aggregators, client as clients } from '@epfml/discojs' +import { loadText, saveModelToDisk } from "@epfml/discojs-node"; import { getTaskData } from './data.js' import { args } from './args.js' import { makeUserLogFile } from "./user_log.js"; import type { UserLogFile } from "./user_log.js"; +const debug = createDebug("cli:main"); async function runUser( task: Task, + provider: TaskProvider, url: URL, data: Dataset, + validationData: Dataset | undefined, userIndex: number, numberOfUsers: number, ): Promise> { - // cast as typescript isn't good with generics + debug(`Starting runUser for client ${userIndex}`); + const userStart = Date.now(); const trainingScheme = task.trainingInformation.scheme as N const aggregator = aggregators.getAggregator(task) const client = clients.getClient(trainingScheme, url, task, aggregator) const disco = new Disco(task, client, { scheme: trainingScheme }); + // For local training, load model from provider before training starts + // if (trainingScheme === "local") { + // debug(`Loading model for training client ${userIndex}...`); + // const modelStart = Date.now(); + // console.log("Loading model for local training..."); + // disco.trainer.model = await provider.getModel(); + // console.log("Model loaded successfully"); + // debug(`Model loading took ${Date.now() - modelStart}ms for client ${userIndex}`); + // } + + + const dir = path.join(".", `${args.testID}`); await fs.mkdir(dir, { recursive: true }); const streamPath = path.join(dir, `client${userIndex}_local_log.jsonl`); @@ -49,16 +66,25 @@ async function runUser( } try{ - for await (const log of disco.trainSummary(data)){ + debug(`Starting training for client ${userIndex}`); + const trainStart = Date.now(); + for await (const log of disco.trainSummary(data, validationData)){ finalLog.push(log); if (jsonStream){ jsonStream.write(JSON.stringify(log) + "\n"); } } + debug(`Training took ${Date.now() - trainStart}ms for client ${userIndex}`); await new Promise((res, _) => setTimeout(() => res('timeout'), 1000)) // Wait for other peers to finish - + // Save the trained model if requested + if (args.saveModel) { + const modelDir = path.join(".", `${args.testID}`, "models"); + const modelFileName = `client${userIndex}_model.json`; + await saveModelToDisk(disco.trainer.model, modelDir, modelFileName); + console.log(`Model saved for client ${userIndex} at ${modelDir}/${modelFileName}`); + } // saving the entire per-user logs if (args.save) { const finalPath = path.join(dir, `client${userIndex}_local_log.json`); @@ -104,10 +130,17 @@ async function main( console.log({ args }) const dataSplits = await Promise.all( - Range(0, numberOfUsers).map(async i => getTaskData(task.id, i, numberOfUsers)) + Range(0, numberOfUsers).map(async i => getTaskData(task.id, i, numberOfUsers, args.datasetPath)) ) + + let validationData: Dataset | undefined = undefined; + if (args.validationDatasetPath) { + // Assume text task for now + validationData = loadText(args.validationDatasetPath).cached() as Dataset; + } + const logs = await Promise.all( - dataSplits.map((data, i) => runUser(task, args.host, data as Dataset, i, numberOfUsers)) + dataSplits.map((data, i) => runUser(task, provider, args.host, data as Dataset, validationData, i, numberOfUsers)) ) if (args.save) { diff --git a/cli/src/data.ts b/cli/src/data.ts index aa4d0a330..f3b6ff834 100644 --- a/cli/src/data.ts +++ b/cli/src/data.ts @@ -5,8 +5,9 @@ import { DataType, Image, Task, + Text, } from "@epfml/discojs"; -import { loadCSV, loadImage, loadImagesInDir } from "@epfml/discojs-node"; +import { loadCSV, loadImage, loadImagesInDir, loadText } from "@epfml/discojs-node"; import { Repeat } from "immutable"; async function loadSimpleFaceData(userIdx: number, totalClient: number): Promise> { @@ -94,7 +95,10 @@ function loadData(dataName: string, split: number): Dataset( taskID: Task.ID, userIdx: number, - totalClient: number + totalClient: number, + datasetPath?: string, + isValidation?: boolean, + validationDatasetPath?: string ): Promise> { switch (taskID) { case "simple_face": // remove @@ -118,6 +122,8 @@ export async function getTaskData( case "mnist_federated": case "mnist": return loadData("mnist", userIdx) as Dataset; + case "privacyrun": + return loadText(isValidation && validationDatasetPath ? validationDatasetPath : datasetPath ?? '../datasets/med_mcq/train.txt') as Dataset; default: throw new Error(`Data loader for ${taskID} not implemented.`); } diff --git a/cli/src/evaluate_finetuned_gpt2.ts b/cli/src/evaluate_finetuned_gpt2.ts new file mode 100644 index 000000000..ecccc6bc1 --- /dev/null +++ b/cli/src/evaluate_finetuned_gpt2.ts @@ -0,0 +1,249 @@ +import "@tensorflow/tfjs-node"; +import * as tf from "@tensorflow/tfjs"; +import fs from "node:fs/promises"; +import { parse } from "ts-command-line-args"; +import { models, Tokenizer } from "@epfml/discojs"; +import { loadModelFromDisk } from "@epfml/discojs-node"; + +interface Args { + modelPath: string; + testPath: string; + maxSamples?: number; + savePath?: string; + help?: boolean; +} + +// ========================= +// HOW TO RUN +// ========================= +// npm -w cli run eval_finetuned_gpt2 -- --modelPath absolute_path_to_model/model.json --testPath absolute_path_to_test_data/train_no_exp.txt --maxSamples 100 + +// ========================= +// LOAD DATASET +// ========================= +async function loadDataset(filePath: string, limit = -1): Promise { + const text = await fs.readFile(filePath, "utf-8"); + const lines = text.split("\n"); + + const samples: string[] = []; + let current = ""; + + for (const line of lines) { + const l = line.trim(); + + if (l.includes("<|startoftext|>")) { + current = ""; + } else if (l.includes("<|endoftext|>")) { + samples.push(current.trim()); + if (limit !== -1 && samples.length >= limit) break; + } else { + current += l + "\n"; + } + } + + return samples; +} + +// ========================= +// PARSE SAMPLE +// ========================= +function parseSample(sample: string) { + const lines = sample.split("\n"); + + let answer = ""; + const promptLines: string[] = []; + + for (const line of lines) { + if (line.startsWith("Answer:")) { + answer = line.replace("Answer:", "").trim(); + } else { + promptLines.push(line); + } + } + + const basePrompt = promptLines.join("\n"); + return { basePrompt, answer }; +} + +// ========================= +// SOFTMAX (for safety) +// ========================= +async function scoreText( + tfModel: tf.LayersModel, + tokenizer: Tokenizer, + text: string +): Promise { + const tokens = tokenizer.tokenize(text); + + if (tokens.size < 2) return -Infinity; + + const inputTokens = tokens.slice(0, tokens.size - 1).toArray(); + const targets = tokens.slice(1).toArray(); + + const inputTensor = tf.tensor([inputTokens], [1, inputTokens.length], "int32"); + + const logits = tfModel.predict(inputTensor) as tf.Tensor; + const logitsArray = await logits.array() as number[][][]; + + let score = 0; + + for (let i = 0; i < targets.length; i++) { + const stepLogits = logitsArray[0][i]; + + const logit = stepLogits[targets[i]] ?? -100; + + score += logit; + } + + inputTensor.dispose(); + logits.dispose(); + + return score; +} + +// ========================= +// SCORE OPTIONS +// ========================= +async function scoreOptions( + tfModel: tf.LayersModel, + tokenizer: Tokenizer, + texts: string[] +): Promise { + const scores: number[] = []; + + for (const t of texts) { + const s = await scoreText(tfModel, tokenizer, t); + scores.push(s); + } + + return scores; +} + +// ========================= +// BENCHMARK +// ========================= +async function benchmarkQA( + model: models.GPT, + tokenizer: Tokenizer, + dataset: string[], + savePath?: string +) { + console.log("=== QA LOGPROB BENCHMARK ==="); + + const tfModel = model.extract(); + + let correct = 0; + let total = 0; + + const options = ["A", "B", "C", "D"]; + + const confusion: Record> = { + A: { A: 0, B: 0, C: 0, D: 0 }, + B: { A: 0, B: 0, C: 0, D: 0 }, + C: { A: 0, B: 0, C: 0, D: 0 }, + D: { A: 0, B: 0, C: 0, D: 0 } + }; + + type PredictionLog = { + predicted: string; + answer: string; + correct: boolean; + }; + + const logs: PredictionLog[] = []; + + const start = Date.now(); + + for (const sample of dataset) { + const { basePrompt, answer } = parseSample(sample); + + const texts = options.map( + (opt) => `${basePrompt}\nAnswer: ${opt}` + ); + + const scores = await scoreOptions(tfModel, tokenizer, texts); + + let bestIdx = 0; + for (let i = 1; i < scores.length; i++) { + if (scores[i] > scores[bestIdx]) bestIdx = i; + } + + const predicted = options[bestIdx]; + + if (predicted === answer) correct++; + total++; + + if (confusion[answer]) { + confusion[answer][predicted]++; + } + + logs.push({ + predicted, + answer, + correct: predicted === answer + }); + + if (total % 50 === 0) { + console.log(`Processed ${total} samples...`); + } + } + + const accuracy = correct / total; + const duration = ((Date.now() - start) / 1000).toFixed(2); + + console.log("\n========================="); + console.log(`Accuracy: ${(accuracy * 100).toFixed(2)}%`); + console.log(`Time: ${duration}s`); + console.log("=========================\n"); + + console.log("Confusion Matrix:"); + console.table(confusion); + + console.log("\nPer-class accuracy:"); + for (const cls of options) { + const totalCls = Object.values(confusion[cls]).reduce((a, b) => a + b, 0); + const correctCls = confusion[cls][cls]; + const acc = totalCls ? (correctCls / totalCls) * 100 : 0; + + console.log(`${cls}: ${acc.toFixed(2)}%`); + } + + if (savePath) { + await fs.writeFile(savePath, JSON.stringify(logs, null, 2)); + console.log(`Saved results to ${savePath}`); + } +} + +// ========================= +// MAIN +// ========================= +async function main() { + const args = parse({ + modelPath: { type: String }, + testPath: { type: String }, + maxSamples: { type: Number, optional: true, defaultValue: 100 }, + savePath: { type: String, optional: true }, + help: { type: Boolean, optional: true } + }); + + console.log("Loading tokenizer..."); + const tokenizer = await Tokenizer.from_pretrained("Xenova/gpt2"); + + console.log("Loading model..."); + const model = await loadModelFromDisk(args.modelPath); + + if (!(model instanceof models.GPT)) { + throw new Error("Model must be GPT"); + } + + console.log("Loading dataset..."); + const dataset = await loadDataset(args.testPath, args.maxSamples); + + console.log(`Loaded ${dataset.length} samples`); + + await benchmarkQA(model, tokenizer, dataset, args.savePath); + + console.log("Done."); +} + +main().catch(console.error); \ No newline at end of file diff --git a/discojs/src/client/client.ts b/discojs/src/client/client.ts index 9e1298b93..5c8aa1304 100644 --- a/discojs/src/client/client.ts +++ b/discojs/src/client/client.ts @@ -186,8 +186,11 @@ export abstract class Client extends EventEmitter<{ } url.pathname += `tasks/${this.task.id}/model.json` + debug("fetching latest model from server at %0 for task %1...", url.href, this.task.id) + const response = await fetch(url); - if (!response.ok) throw new Error(`fetch: HTTP status ${response.status}`); + if (!response.ok) throw new Error(`fetch: HTTP status ${response.status}`) + else debug("response ok, decoding model...") const encoded = new Uint8Array(await response.arrayBuffer()) return await serialization.model.decode(encoded) diff --git a/discojs/src/client/event_connection.ts b/discojs/src/client/event_connection.ts index 3e3aec409..82722d111 100644 --- a/discojs/src/client/event_connection.ts +++ b/discojs/src/client/event_connection.ts @@ -118,6 +118,10 @@ export class WebSocketServer extends EventEmitter<{ [K in type]: NarrowMessage { + debug("websocket closed: code=%o reason=%o wasClean=%o", event.code, event.reason, event.wasClean) + } + return await new Promise((resolve, reject) => { ws.onerror = (err: WebSocket.ErrorEvent) => { reject(new Error(`Server unreachable: ${err.message}`)) diff --git a/discojs/src/client/federated/federated_client.ts b/discojs/src/client/federated/federated_client.ts index a89c65ad6..b6c2c59d5 100644 --- a/discojs/src/client/federated/federated_client.ts +++ b/discojs/src/client/federated/federated_client.ts @@ -88,7 +88,9 @@ export class FederatedClient extends Client<"federated"> { // Upon connecting, the server answers with a boolean // which indicates whether there are enough participants or not debug(`[${shortenId(this.ownId)}] upon connecting, wait for participant flag %o`, this.waitingForMoreParticipants) - model.weights = serialization.weights.decode(payload) + if (payload != null) { + model.weights = serialization.weights.decode(payload) + } return model } diff --git a/discojs/src/client/federated/messages.ts b/discojs/src/client/federated/messages.ts index 3733d2c1c..c6dee6698 100644 --- a/discojs/src/client/federated/messages.ts +++ b/discojs/src/client/federated/messages.ts @@ -18,7 +18,7 @@ export interface NewFederatedNodeInfo { type: type.NewFederatedNodeInfo id: NodeID waitForMoreParticipants: boolean - payload: serialization.Encoded; + payload?: serialization.Encoded | null; round: number nbOfParticipants: number } diff --git a/discojs/src/default_tasks/index.ts b/discojs/src/default_tasks/index.ts index 43adf0d3c..f46f27026 100644 --- a/discojs/src/default_tasks/index.ts +++ b/discojs/src/default_tasks/index.ts @@ -4,4 +4,5 @@ export { mnist } from './mnist.js' export { simpleFace } from './simple_face.js' export { titanic } from './titanic.js' export { wikitext } from './wikitext.js' -export { tinderDog } from './tinder_dog.js' \ No newline at end of file +export { tinderDog } from './tinder_dog.js' +export { privacyrun } from './privacyrun.js' \ No newline at end of file diff --git a/discojs/src/default_tasks/privacyrun.ts b/discojs/src/default_tasks/privacyrun.ts new file mode 100644 index 000000000..d667e3c9c --- /dev/null +++ b/discojs/src/default_tasks/privacyrun.ts @@ -0,0 +1,68 @@ +import type { TaskProvider } from "../index.js"; +import { Tokenizer, models, serialization } from "../index.js"; + +export const privacyrun: TaskProvider<"text", "federated"> = { + async getTask() { + return { + id: 'privacyrun', + dataType: "text", + displayInformation: { + title: "GPT Privacy-Preserving Fine-tuning", + summary: { + preview: 'Fine-tune a pre-trained GPT model collaboratively and privately.', + overview: "Fine-tune a pre-trained GPT-2 model created by the ONNX converter in your browser collaboratively without sharing your raw data. The model is loaded from Google Cloud Storage and fine-tuned using federated learning." + }, + model: [ + "The model is a pre-trained GPT-2 architecture converted from ONNX and loaded from Google Cloud Storage.", + "The tokenizer used for preprocessing is the GPT-2 Byte-Pair encoding tokenizer.", + "The model is trained via an Adam optimizer with unit gradient clipping and softmax cross-entropy loss.", + "Context length is kept at 1024 to match the pre-trained model, with batch size at 1.", + ].join(" "), + dataFormatInformation: 'You can use any natural language (text) dataset. The dataset should be formatted as a plain text file with each line representing a segment of text.', + dataExample: + "For the first twenty years of its existence , the only staged performances of Parsifal took place in the Bayreuth Festspielhaus , the venue for which Wagner conceived the work.", + }, + trainingInformation: { + scheme: 'federated', + aggregationStrategy: 'mean', + minNbOfParticipants: 2, + epochs: 6, + validationSplit: 0.1, + roundDuration: 2, + batchSize: 8, + tokenizer: await Tokenizer.from_pretrained("Xenova/gpt2"), + contextLength: 1024, + tensorBackend: 'gpt' + } + } + }, + + async getModel() { + // Load the pre-trained ONNX-converted model from Google Cloud Storage + // The model should be in DiscoJS serialization format (created by onnx-converter) + const modelUrl = "https://storage.googleapis.com/deai-313515.appspot.com/model.json"; + + try { + const response = await fetch(modelUrl); + if (!response.ok) { + throw new Error(`HTTP ${response.status}: ${response.statusText}`); + } + + const arrayBuffer = await response.arrayBuffer(); + const encodedData = new Uint8Array(arrayBuffer); + + const model = await serialization.model.decode(encodedData); + + if (!(model instanceof models.GPT)) { + throw new Error("Loaded model is not a GPT model"); + } + + console.log("Successfully loaded pre-trained GPT model from Google Cloud Storage"); + + return model; + } catch (error) { + console.error("Failed to load model from Google Cloud Storage:", error); + throw new Error(`Could not load model from ${modelUrl}. Make sure the URL is correct and the model exists in DiscoJS serialization format.`); + } + }, +} diff --git a/discojs/src/models/gpt/index.ts b/discojs/src/models/gpt/index.ts index 228d60cc4..886421892 100644 --- a/discojs/src/models/gpt/index.ts +++ b/discojs/src/models/gpt/index.ts @@ -228,8 +228,16 @@ export class GPT extends Model<"text"> { } static deserialize(data: GPTSerialization): Model<"text"> { + + debug("GPT model deserialization started") + const model = new GPT(data.config); + + debug("GPT model config initialized: %O", data.config) + model.weights = data.weights; + + debug("GPT model weights initialized") return model; } diff --git a/discojs/src/models/gpt/model.ts b/discojs/src/models/gpt/model.ts index 01ee51e92..9207e8d47 100644 --- a/discojs/src/models/gpt/model.ts +++ b/discojs/src/models/gpt/model.ts @@ -64,8 +64,12 @@ export class GPTModel extends tf.LayersModel { let accuracyFraction: [number, number] = [0, 0]; let averageLoss = 0 let iteration = 1 + + debug("before iterator init") const iterator = await dataset.iterator() + debug("after getting iterator, before next") let next = await iterator.next() + debug("after next of iterator") while (next.done !== true && iteration <= this.config.maxIter) { let weightUpdateTime = performance.now() @@ -73,7 +77,9 @@ export class GPTModel extends tf.LayersModel { const { xs, ys } = next.value as { xs: tf.Tensor2D, ys: tf.Tensor3D } let preprocessingTime = performance.now() + debug("await batch data before {} iteration", iteration) await Promise.all([xs.data(), ys.data()]) + debug("after await batch data {} iteration", iteration) preprocessingTime = performance.now() - preprocessingTime // TODO include as a tensor inside the model diff --git a/discojs/src/serialization/model.ts b/discojs/src/serialization/model.ts index 020d147af..4df0fab0c 100644 --- a/discojs/src/serialization/model.ts +++ b/discojs/src/serialization/model.ts @@ -7,6 +7,10 @@ import { GPTConfig } from '../models/index.js' import * as coder from "./coder.js"; import { Encoded, isEncoded } from "./coder.js"; +import createDebug from "debug" + +const debug = createDebug("discojs:serialization:model"); + const Type = { TFJS: 0, GPT: 1 @@ -16,11 +20,13 @@ export async function encode(model: Model): Promise { switch (true) { case model instanceof models.TFJS: { const serialized = await model.serialize(); + debug("TFJS model serialized"); return coder.encode([Type.TFJS, ...serialized]); } case model instanceof models.GPT: { const { weights, config } = model.serialize(); const serializedWeights = await serialization.weights.encode(weights); + debug("GPT model weights serialized"); return coder.encode([Type.GPT, serializedWeights, config]); } default: @@ -30,23 +36,34 @@ export async function encode(model: Model): Promise { export async function decode(encoded: Encoded): Promise> { const raw = coder.decode(encoded) + + debug("IMPORTANT:model decoded") if (!Array.isArray(raw) || raw.length < 2) { throw new Error("invalid encoding, encoding isn't an array or doesn't contain enough values") } + + debug("model encoding array length: %d", raw.length) + const type = raw[0] as unknown if (typeof type !== 'number') { throw new Error('invalid encoding, first encoding field should be the model type') } + + debug("model type: %d", type) + const rawModel = raw[1] as unknown switch (type) { case Type.TFJS: { + debug("TFJS model decoding started"); if (raw.length !== 3) throw new Error( "invalid TFJS model encoding: should be an array of length 3", ); const [rawDatatype, rawModel] = raw.slice(1) as unknown[]; + debug("TFJS model datatype: %s", rawDatatype); + let datatype; switch (rawDatatype) { case "image": @@ -70,6 +87,7 @@ export async function decode(encoded: Encoded): Promise> { if (raw.length == 2) { config = undefined } else if (raw.length == 3) { + debug("GPT model config decoding") config = raw[2] as GPTConfig } else { throw new Error('invalid encoding, gpt-tfjs model encoding should be an array of length 2 or 3') @@ -79,7 +97,12 @@ export async function decode(encoded: Encoded): Promise> { throw new Error( "invalid encoding, gpt-tfjs model weights should be an encoding of its weights", ); + + debug("GPT model weights decoding...") const weights = serialization.weights.decode(rawModel) + + debug("GPT model weights decoded, deserializing model... CONFIG MIGHT BE WRONG") + debug("GPT model config: %O", config || "undefined, using default config") return models.GPT.deserialize({weights, config}) } default: diff --git a/discojs/src/training/disco.ts b/discojs/src/training/disco.ts index 7dd019b2d..33c63c244 100644 --- a/discojs/src/training/disco.ts +++ b/discojs/src/training/disco.ts @@ -21,8 +21,12 @@ import { getAggregator } from "../aggregator/index.js"; import { enumerate, split } from "../utils/async_iterator.js"; import { EventEmitter } from "../utils/event_emitter.js"; +import createDebug from "debug" + import { RoundLogs, Trainer } from "./trainer.js"; +const debug = createDebug("discojs:training:disco"); + interface DiscoConfig { scheme: N; logger: Logger; @@ -159,20 +163,24 @@ export class Disco extends EventEmitter<{ /** Train on dataset, yielding logs of every batch. */ async *trainByBatch( dataset: Dataset, + validationDataset?: Dataset, ): AsyncGenerator { - for await (const round of this.train(dataset)) + for await (const round of this.train(dataset, validationDataset)) for await (const epoch of round) yield* epoch; } /** Train on dataset, yielding summary logs */ async *trainSummary( dataset: Dataset, + validationDataset?: Dataset, ): AsyncGenerator { - for await (const [roundNum, round] of enumerate(this.train(dataset))) { + for await (const [roundNum, round] of enumerate(this.train(dataset, validationDataset))) { const [roundGen, roundLogsPromise] = async_iterator.split(round); const epochResults: Array<{epochNum: number; epochLogs: EpochLogs}> = []; + debug("Starting round %d", roundNum) + for await (const [epochNum, epoch] of enumerate(roundGen)) { const [epochGen, epochLogsPromise] = async_iterator.split(epoch); for await (const _ of epochGen); @@ -190,8 +198,8 @@ export class Disco extends EventEmitter<{ } /** Run whole train on dataset. */ - async trainFully(dataset: Dataset): Promise { - for await (const round of this.train(dataset)) + async trainFully(dataset: Dataset, validationDataset?: Dataset): Promise { + for await (const round of this.train(dataset, validationDataset)) for await (const epoch of round) for await (const _ of epoch); } @@ -203,20 +211,28 @@ export class Disco extends EventEmitter<{ **/ async *train( dataset: Dataset, + validationDataset?: Dataset, ): AsyncGenerator< AsyncGenerator, RoundLogs> > { this.#logger.success("Training started"); - const [trainingDataset, validationDataset] = - await this.#preprocessSplitAndBatch(dataset); + const [trainingDataset, validationDataset_] = + validationDataset !== undefined + ? await this.#preprocessDatasets(dataset, validationDataset) + : await this.#preprocessSplitAndBatch(dataset); // the client fetches the latest weights upon connection // TODO unsafe cast + debug("Connecting to client and fetching initial model..."); this.trainer.model = (await this.#client.connect()) as Model; + debug("Initial model fetched successfully"); + if (this.trainer.model === null) { + debug(`No pre-trained model provided for client, initializing randomly...`); + } for await (const [roundNum, round] of enumerate( - this.trainer.train(trainingDataset, validationDataset), + this.trainer.train(trainingDataset, validationDataset_), )) { yield async function* (this: Disco) { const [roundGen, roundLogsPromise] = split(round); @@ -297,6 +313,31 @@ export class Disco extends EventEmitter<{ validation.batch(batchSize).cached(), ]; } + + async #preprocessDatasets( + trainingDataset: Dataset, + validationDataset: Dataset, + ): Promise< + [ + Dataset>, + Dataset> | undefined, + ] + > { + const { batchSize } = this.#task.trainingInformation; + + let preprocessedTraining = processing.preprocess(this.#task, trainingDataset); + let preprocessedValidation = processing.preprocess(this.#task, validationDataset); + + if (this.#preprocessOnce) { + preprocessedTraining = new Dataset(await arrayFromAsync(preprocessedTraining)); + preprocessedValidation = new Dataset(await arrayFromAsync(preprocessedValidation)); + } + + return [ + preprocessedTraining.batch(batchSize).cached(), + preprocessedValidation.batch(batchSize).cached(), + ]; + } } // Array.fromAsync not yet widely used (2024) diff --git a/discojs/src/training/trainer.ts b/discojs/src/training/trainer.ts index 68e716bcc..73cae1f7c 100644 --- a/discojs/src/training/trainer.ts +++ b/discojs/src/training/trainer.ts @@ -16,8 +16,11 @@ import { } from "../index.js"; import { privacy } from "../index.js"; import { Client } from "../client/index.js"; +import createDebug from "debug"; import * as async_iterator from "../utils/async_iterator.js"; +const debug = createDebug("discojs:training:trainer"); + export interface RoundLogs { epochs: List; participants: number; @@ -88,6 +91,7 @@ export class Trainer { AsyncGenerator, RoundLogs>, void > { + debug("Start train") if (this.#training !== undefined) throw new Error( "training already running, stop it before launching a new one", @@ -109,6 +113,9 @@ export class Trainer { void > { const totalRound = Math.trunc(this.#epochs / this.#roundDuration); + + debug("Run rounds") + for (let round = 0; round < totalRound; round++) { await this.#client.onRoundBeginCommunication(); @@ -150,6 +157,8 @@ export class Trainer { ): AsyncGenerator, RoundLogs> { let epochsLogs = List(); + debug("Run round") + // Before starting the training, get the validation of global model const validation = validationDataset !== undefined ? await this.model.evaluate(validationDataset) : undefined; diff --git a/server/src/controllers/federated_controller.ts b/server/src/controllers/federated_controller.ts index 1e52fe539..cf2df0fa8 100644 --- a/server/src/controllers/federated_controller.ts +++ b/server/src/controllers/federated_controller.ts @@ -89,7 +89,7 @@ export class FederatedController extends TrainingController< type: MessageTypes.NewFederatedNodeInfo, id: clientId, waitForMoreParticipants: this.connections.size < minNbOfParticipants, - payload: this.#latestGlobalWeights, + payload: this.#aggregator.round === 0 ? undefined : this.#latestGlobalWeights, round: this.#aggregator.round, nbOfParticipants: this.connections.size }