Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cli/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
"benchmark_gpt": "npm run build && node dist/benchmark_gpt.js",
"train_gpt": "npm run build && node dist/train_gpt.js",
"hellaswag_gpt": "npm run build && node dist/hellaswag_gpt.js",
"eval_finetuned_gpt2": "npm run build && node dist/evaluate_finetuned_gpt2.js",
"finetune_gpt": "npm run build && node dist/finetune_gpt.js",
"build": "tsc --build",
"test": ": nothing"
},
Expand Down
13 changes: 11 additions & 2 deletions cli/src/args.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ export interface BenchmarkArguments {
roundDuration: number
batchSize: number
validationSplit: number
datasetPath?: string
validationDatasetPath?: string

// DP
epsilon?: number
Expand All @@ -36,11 +38,14 @@ export interface BenchmarkArguments {
maxShareValue?: number

save: boolean
saveModel: boolean
host: URL
}

type BenchmarkUnsafeArguments = Omit<BenchmarkArguments, 'provider'> & {
task: string
datasetPath?: string
validationDatasetPath?: string
help?: boolean
}

Expand All @@ -55,7 +60,10 @@ const unsafeArgs = parse<BenchmarkUnsafeArguments>(
roundDuration: { type: Number, alias: 'r', description: 'Round duration (in epochs)', defaultValue: 2 },
batchSize: { type: Number, alias: 'b', description: 'Training batch size', defaultValue: 10 },
validationSplit : { type: Number, alias: 'v', description: 'Validation dataset ratio', defaultValue: 0.2 },
datasetPath: { type: String, alias: 'd', description: 'Path to the dataset', optional: true },
validationDatasetPath: { type: String, alias: 'V', description: 'Path to the validation dataset', optional: true },
save: { type: Boolean, alias: 's', description: 'Save logs of benchmark', defaultValue: false },
saveModel: { type: Boolean, alias: 'm', description: 'Save trained model to disk', defaultValue: false },
host: {
type: (raw: string) => new URL(raw),
typeLabel: "URL",
Expand Down Expand Up @@ -89,18 +97,19 @@ const unsafeArgs = parse<BenchmarkUnsafeArguments>(

const supportedTasks = Map(
await Promise.all(
Set.of<TaskProvider<"image" | "tabular", Network>>(
Set.of<TaskProvider<"image" | "tabular" | "text", Network>>(
defaultTasks.cifar10,
defaultTasks.lusCovid,
defaultTasks.simpleFace,
defaultTasks.titanic,
defaultTasks.tinderDog,
defaultTasks.mnist,
defaultTasks.privacyrun,
).map(
async (t) =>
[(await t.getTask()).id, t] as [
string,
TaskProvider<"image" | "tabular", Network>,
TaskProvider<"image" | "tabular" | "text", Network>,
],
),
),
Expand Down
45 changes: 39 additions & 6 deletions cli/src/cli.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import fs from 'node:fs/promises'
import { createWriteStream } from "node:fs";
import path from "node:path";

import createDebug from "debug";
import type {
Dataset,
DataFormat,
Expand All @@ -17,25 +17,42 @@
} from "@epfml/discojs";
import { Disco, aggregator as aggregators, client as clients } from '@epfml/discojs'

import { loadText, saveModelToDisk } from "@epfml/discojs-node";
import { getTaskData } from './data.js'
import { args } from './args.js'
import { makeUserLogFile } from "./user_log.js";
import type { UserLogFile } from "./user_log.js";

const debug = createDebug("cli:main");

async function runUser<D extends DataType, N extends Network>(
task: Task<D, N>,
provider: TaskProvider<D, N>,

Check failure on line 30 in cli/src/cli.ts

View workflow job for this annotation

GitHub Actions / lint-most

'provider' is defined but never used. Allowed unused args must match /^_/u
url: URL,
data: Dataset<DataFormat.Raw[D]>,
validationData: Dataset<DataFormat.Raw[D]> | undefined,
userIndex: number,
numberOfUsers: number,
): Promise<List<SummaryLogs>> {
// cast as typescript isn't good with generics
debug(`Starting runUser for client ${userIndex}`);
const userStart = Date.now();

Check failure on line 38 in cli/src/cli.ts

View workflow job for this annotation

GitHub Actions / lint-most

'userStart' is assigned a value but never used. Allowed unused vars must match /^_/u
const trainingScheme = task.trainingInformation.scheme as N
const aggregator = aggregators.getAggregator(task)
const client = clients.getClient(trainingScheme, url, task, aggregator)
const disco = new Disco(task, client, { scheme: trainingScheme });

// For local training, load model from provider before training starts
// if (trainingScheme === "local") {
// debug(`Loading model for training client ${userIndex}...`);
// const modelStart = Date.now();
// console.log("Loading model for local training...");
// disco.trainer.model = await provider.getModel();
// console.log("Model loaded successfully");
// debug(`Model loading took ${Date.now() - modelStart}ms for client ${userIndex}`);
// }



const dir = path.join(".", `${args.testID}`);
await fs.mkdir(dir, { recursive: true });
const streamPath = path.join(dir, `client${userIndex}_local_log.jsonl`);
Expand All @@ -49,16 +66,25 @@
}

try{
for await (const log of disco.trainSummary(data)){
debug(`Starting training for client ${userIndex}`);
const trainStart = Date.now();
for await (const log of disco.trainSummary(data, validationData)){
finalLog.push(log);

if (jsonStream){
jsonStream.write(JSON.stringify(log) + "\n");
}
}
debug(`Training took ${Date.now() - trainStart}ms for client ${userIndex}`);

await new Promise((res, _) => setTimeout(() => res('timeout'), 1000)) // Wait for other peers to finish

// Save the trained model if requested
if (args.saveModel) {
const modelDir = path.join(".", `${args.testID}`, "models");
const modelFileName = `client${userIndex}_model.json`;
await saveModelToDisk(disco.trainer.model, modelDir, modelFileName);
console.log(`Model saved for client ${userIndex} at ${modelDir}/${modelFileName}`);
}
// saving the entire per-user logs
if (args.save) {
const finalPath = path.join(dir, `client${userIndex}_local_log.json`);
Expand Down Expand Up @@ -104,10 +130,17 @@
console.log({ args })

const dataSplits = await Promise.all(
Range(0, numberOfUsers).map(async i => getTaskData(task.id, i, numberOfUsers))
Range(0, numberOfUsers).map(async i => getTaskData(task.id, i, numberOfUsers, args.datasetPath))
)

let validationData: Dataset<DataFormat.Raw[D]> | undefined = undefined;
if (args.validationDatasetPath) {
// Assume text task for now
validationData = loadText(args.validationDatasetPath).cached() as Dataset<DataFormat.Raw[D]>;
}

const logs = await Promise.all(
dataSplits.map((data, i) => runUser(task, args.host, data as Dataset<DataFormat.Raw[D]>, i, numberOfUsers))
dataSplits.map((data, i) => runUser(task, provider, args.host, data as Dataset<DataFormat.Raw[D]>, validationData, i, numberOfUsers))
)

if (args.save) {
Expand Down
10 changes: 8 additions & 2 deletions cli/src/data.ts
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
DataType,
Image,
Task,
Text,

Check failure on line 8 in cli/src/data.ts

View workflow job for this annotation

GitHub Actions / lint-most

'Text' is defined but never used. Allowed unused vars must match /^_/u
} from "@epfml/discojs";
import { loadCSV, loadImage, loadImagesInDir } from "@epfml/discojs-node";
import { loadCSV, loadImage, loadImagesInDir, loadText } from "@epfml/discojs-node";
import { Repeat } from "immutable";

async function loadSimpleFaceData(userIdx: number, totalClient: number): Promise<Dataset<DataFormat.Raw["image"]>> {
Expand Down Expand Up @@ -94,7 +95,10 @@
export async function getTaskData<D extends DataType>(
taskID: Task.ID,
userIdx: number,
totalClient: number
totalClient: number,
datasetPath?: string,
isValidation?: boolean,
validationDatasetPath?: string
): Promise<Dataset<DataFormat.Raw[D]>> {
switch (taskID) {
case "simple_face": // remove
Expand All @@ -118,6 +122,8 @@
case "mnist_federated":
case "mnist":
return loadData("mnist", userIdx) as Dataset<DataFormat.Raw[D]>;
case "privacyrun":
return loadText(isValidation && validationDatasetPath ? validationDatasetPath : datasetPath ?? '../datasets/med_mcq/train.txt') as Dataset<DataFormat.Raw[D]>;
default:
throw new Error(`Data loader for ${taskID} not implemented.`);
}
Expand Down
Loading
Loading