Sure! Like I keep telling you, contributions are welcome! Posting duplicates of issues isn't going to make things happen faster. Please stop posting duplicate issues.
@saudet, I am currently using Java reflection to call the `forward()` method of any custom model class that inherits from `Module` in my project. This implementation works, but it is not elegant: I had to use this workaround because of issues with `AnyModule`.
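For comparison with the reflection lookup in the code that follows, one reflection-free option is to let the caller hand the wrapper its forward function as a method reference. The sketch below is only an illustration under stated assumptions: `double[]` stands in for `Tensor`, `ScaleModel`/`ShiftModel` are hypothetical stand-ins for custom `Module` subclasses, and the `Trainer(UnaryOperator)` constructor is hypothetical, not part of `DDPTrainer` or the JavaCPP presets.

```java
import java.util.Arrays;
import java.util.Objects;
import java.util.function.UnaryOperator;

public class ForwardDispatchSketch {

    // Hypothetical stand-ins for custom Module subclasses; double[] plays the role of Tensor.
    public static class ScaleModel {
        public double[] forward(double[] x) {
            double[] y = x.clone();
            for (int i = 0; i < y.length; i++) {
                y[i] *= 2.0;
            }
            return y;
        }
    }

    public static class ShiftModel {
        public double[] forward(double[] x) {
            double[] y = x.clone();
            for (int i = 0; i < y.length; i++) {
                y[i] += 1.0;
            }
            return y;
        }
    }

    // The wrapper receives the forward function explicitly instead of looking it up
    // with getDeclaredMethod(...): no setAccessible(), no Method.invoke(), and no
    // InvocationTargetException to unwrap.
    public static class Trainer {
        private final UnaryOperator<double[]> forwardFn;
        private long numForwardCalls = 0;

        public Trainer(UnaryOperator<double[]> forwardFn) {
            this.forwardFn = Objects.requireNonNull(forwardFn, "forwardFn is required");
        }

        public double[] forward(double[] input) {
            numForwardCalls++;
            return forwardFn.apply(input);
        }

        public long getNumForwardCalls() {
            return numForwardCalls;
        }
    }

    public static void main(String[] args) {
        Trainer scaled = new Trainer(new ScaleModel()::forward);
        Trainer shifted = new Trainer(new ShiftModel()::forward);
        System.out.println(Arrays.toString(scaled.forward(new double[]{1, 2, 3})));  // [2.0, 4.0, 6.0]
        System.out.println(Arrays.toString(shifted.forward(new double[]{1, 2, 3}))); // [2.0, 3.0, 4.0]
    }
}
```

In `DDPTrainer` terms this would correspond to a hypothetical builder option such as `.forward(model::forward)` next to `.module(model)`; the `Module` reference is still needed for `parameters()` and `to()`. The cost is one extra argument at the call site; in exchange the compiler checks the `forward` signature, and exceptions thrown by `forward()` surface directly.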
```java
package torch.distributed;
import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.Module;
import org.bytedeco.pytorch.global.torch;
import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.lang.reflect.Method;
/**
* DistributedDataParallel (DDP) - Wrapper for distributed training.
* Mirrors PyTorch's torch.nn.parallel.DistributedDataParallel API.
*
* Key features:
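* - Gradient synchronization via allreduce (call synchronize() after backward(), or use step())
* - GPU-to-GPU communication through the process group's backend (NCCL on CUDA devices)
* - Gradient averaging (allreduce SUM, then division by world size)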
* - Automatic device placement
*
* Usage:
* <pre>
* Module model = new MyModel();
* ProcessGroupWrapper pg = ProcessGroupWrapper.create(rank, worldSize, store);
* DDPTrainer ddp = DDPTrainer.builder()
* .module(model)
* .processGroup(pg)
* .bucketCapKb(25 * 1024) // 25 MB; stored, but not yet used for bucketing
* .build();
* </pre>
*/
public class DDPTrainer implements AutoCloseable {
public static final String VERSION = "1.0";
private final Module module;
private final Method forwardMethod; // cached via reflection
private final ProcessGroupWrapper pg;
private final boolean isDevicecuda;
private final boolean broadcastBuffers;
private final boolean gradientAsBucketstore;
private final int bucketCapKb;
private List<Tensor> savedModels;
private Map<String, Object> extraState;
// Statistics
private long numForwardCalls = 0;
private long numBackwardCalls = 0;
private static final Map<String, DDPTrainer> instances = new ConcurrentHashMap<>();
/**
* Builder pattern for DDPTrainer configuration.
*/
public static class Builder {
private Module module;
private ProcessGroupWrapper processGroup;
private boolean broadcastBuffers = true;
private boolean gradientAsBucketstore = true;
private int bucketCapKb = 25 * 1024; // 25MB default
public Builder module(Module m) { this.module = m; return this; }
public Builder processGroup(ProcessGroupWrapper pg) { this.processGroup = pg; return this; }
public Builder broadcastBuffers(boolean b) { this.broadcastBuffers = b; return this; }
public Builder gradientAsBucketstore(boolean b) { this.gradientAsBucketstore = b; return this; }
public Builder bucketCapKb(int kb) { this.bucketCapKb = kb; return this; }
public DDPTrainer build() {
Objects.requireNonNull(module, "module is required");
Objects.requireNonNull(processGroup, "processGroup is required");
return new DDPTrainer(this);
}
}
private DDPTrainer(Builder builder) {
this.module = builder.module;
this.forwardMethod = findForwardMethod(builder.module);
this.pg = builder.processGroup;
this.broadcastBuffers = builder.broadcastBuffers;
this.gradientAsBucketstore = builder.gradientAsBucketstore;
this.bucketCapKb = builder.bucketCapKb;
this.isDevicecuda = pg.getDevice().type() == torch.DeviceType.CUDA;
this.extraState = new HashMap<>();
initialize();
}
/**
* Create a DDPTrainer instance.
*/
public static DDPTrainer create(Module module, ProcessGroupWrapper pg) {
return builder().module(module).processGroup(pg).build();
}
/**
* Get a new builder instance.
*/
public static Builder builder() {
return new Builder();
}
// ==================== Initialization ====================
private void initialize() {
Device device = pg.getDevice();
// Move model to the correct device
module.to(device, true);
if (broadcastBuffers) {
broadcastInitialParameters();
}
System.out.printf("[DDPTrainer] Initialized on rank %d with device=%s, worldSize=%d%n",
pg.getRank(), device, pg.getWorldSize());
}
/**
* Broadcast initial model parameters from rank 0.
* Skip if worldSize is 1 (no other ranks to broadcast from/to).
*/
private void broadcastInitialParameters() {
// Skip broadcast if only one rank - nothing to synchronize
if (pg.getWorldSize() <= 1) {
return;
}
List<Tensor> params = collectParameters();
if (!params.isEmpty()) {
// Broadcast each parameter individually to ensure consistent shapes
for (Tensor p : params) {
pg.broadcast(p, 0);
}
}
}
/**
* Collect all model parameters for in-place collectives such as broadcast.
* Note: each entry is p.detach(), which shares storage with the parameter, so writing
* into it updates the model itself. A clone() would receive the broadcast values while
* the model kept its own, unsynchronized weights.
*/
private List<Tensor> collectParameters() {
List<Tensor> params = new ArrayList<>();
TensorVector paramVec = module.parameters();
var begin = paramVec.begin();
var end = paramVec.end();
while (!begin.equals(end)) {
Tensor p = begin.get();
if (p != null && !p.isNull()) {
params.add(p.detach());
}
begin.increment();
}
return params;
}
// ==================== Reflection-based Forward ====================
/**
* Find the forward(Tensor) method on the module class hierarchy.
*/
private Method findForwardMethod(Module module) {
Class<?> cls = module.getClass();
while (cls != null && cls != Object.class) {
try {
Method m = cls.getDeclaredMethod("forward", Tensor.class);
m.setAccessible(true);
System.out.printf("[DDPTrainer] Found forward method on %s%n", cls.getName());
return m;
} catch (NoSuchMethodException e) {
cls = cls.getSuperclass();
}
}
throw new UnsupportedOperationException(
"Module " + module.getClass().getName() + " has no forward(Tensor) method");
}
/**
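* Invoke forward via reflection.
* Exceptions thrown by forward() itself are rethrown unwrapped rather than hidden
* inside InvocationTargetException (whose getMessage() is usually null).
*/
private Tensor invokeForward(Method m, Module module, Tensor input) {
try {
return (Tensor) m.invoke(module, input);
} catch (java.lang.reflect.InvocationTargetException e) {
Throwable cause = e.getCause();
if (cause instanceof RuntimeException) throw (RuntimeException) cause;
if (cause instanceof Error) throw (Error) cause;
throw new RuntimeException("forward() failed: " + cause, cause);
} catch (IllegalAccessException e) {
throw new RuntimeException("Failed to invoke forward: " + e.getMessage(), e);
}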
}
// ==================== Forward/Backward ====================
/**
* Forward pass through the wrapped module (the main entry point for DDP training).
* Gradients are not synchronized automatically: call synchronize() after backward(), or use step().
*/
public Tensor forward(Tensor input) {
numForwardCalls++;
return invokeForward(forwardMethod, module, input);
}
/**
* Execute a training step: forward, cross-entropy loss, backward, gradient sync, and optimizer step.
*
* @return the loss tensor
*/
public Tensor step(Tensor input, Tensor target, Optimizer optimizer) {
// Forward
Tensor output = forward(input);
Tensor loss = torch.cross_entropy(output, target);
// Backward
optimizer.zero_grad();
loss.backward();
numBackwardCalls++;
// Synchronize gradients
reduceGradients();
// Optimizer step
optimizer.step();
return loss;
}
/**
* Simplified training step with loss computation.
*/
public Tensor training_step(Tensor input, Tensor target, Optimizer optimizer) {
return step(input, target, optimizer);
}
/**
* Reduce gradients across all ranks.
*/
private void reduceGradients() {
// Skip if world size is 1 - no need to reduce
if (pg.getWorldSize() <= 1) {
return;
}
List<Tensor> gradients = new ArrayList<>();
try {
TensorVector params = module.parameters();
var begin = params.begin();
var end = params.end();
while(!begin.equals(end)){
Tensor p = begin.get();
if (p != null && !p.isNull()) {
try {
Tensor grad = p.grad();
if (grad != null && !grad.isNull() && grad.defined()) {
gradients.add(grad);
}
} catch (Exception e) {
// Skip if grad access fails
}
}
begin.increment();
}
} catch (Exception e) {
// Skip gradient reduction on error
return;
}
if (!gradients.isEmpty()) {
pg.allreduce(gradients, ReduceOp.RedOpType.SUM);
// Divide by world size to get the average
int worldSize = pg.getWorldSize();
for (Tensor grad : gradients) {
grad.div_(new Scalar(worldSize));
}
}
}
/**
* Individual gradient synchronization (for custom training loops).
*/
public void synchronize() {
reduceGradients();
}
// ==================== Model Access ====================
/**
* Get the wrapped module.
*/
public Module getModule() {
return module;
}
/**
* Get local module (same as getModule).
*/
public Module getLocalModule() {
return module;
}
/**
* Get module for training (same as getModule).
*/
public Module getModuleForTraining() {
return module;
}
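/**
* Get the module's parameters as a list.
*/
public Iterable<Tensor> parameters() {
// Call module.parameters() once: begin() and end() must come from the same
// vector, otherwise the loop compares iterators of two different containers.
TensorVector paramVec = module.parameters();
var begin = paramVec.begin();
var end = paramVec.end();
var list = new ArrayList<Tensor>();
while (!begin.equals(end)) {
list.add(begin.get());
begin.increment();
}
return list;
}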
/**
* Replace each parameter's data, in order, with the given tensors.
* Uses set_() (instead of the deprecated data().copy_()) under NoGradGuard, because
* autograd rejects in-place operations on leaf tensors that require grad.
*/
public void setParameters(List<Tensor> params) {
TensorVector paramVec = module.parameters();
int i = 0;
var begin = paramVec.begin();
var end = paramVec.end();
try (NoGradGuard noGrad = new NoGradGuard()) {
while (!begin.equals(end) && i < params.size()) {
Tensor p = begin.get();
if (p != null && !p.isNull()) {
p.set_(params.get(i));
i++;
}
begin.increment();
}
}
}
// ==================== State Management ====================
/**
* Get the module's buffers, keyed by position ("buffer_0", "buffer_1", ...).
*/
public Map<String, Tensor> namedBuffers() {
Map<String, Tensor> buffers = new LinkedHashMap<>();
// Call module.buffers() once so begin() and end() come from the same vector
var bufVec = module.buffers();
var begin = bufVec.begin();
var end = bufVec.end();
while (!begin.equals(end)) {
// Buffers may not have names in Java version
buffers.put("buffer_" + buffers.size(), begin.get());
begin.increment();
}
return buffers;
}
/**
* Get extra state dict for checkpointing.
*/
public Map<String, Object> getTempStateDict() {
return extraState;
}
/**
* Load extra state dict.
*/
public void loadTempstate_dict(Map<String, Object> state) {
this.extraState.putAll(state);
}
// ==================== Utilities ====================
/**
* Set to training mode.
*/
public void train() {
module.train(true);
}
/**
* Set to evaluation mode.
*/
public void eval() {
module.eval();
}
/**
* Check if in training mode.
*/
public boolean isTraining() {
return module.is_training();
}
// ==================== Statistics ====================
/**
* Get number of forward calls.
*/
public long getNumForwardCalls() {
return numForwardCalls;
}
/**
* Get number of backward calls.
*/
public long getNumBackwardCalls() {
return numBackwardCalls;
}
/**
* Reset statistics.
*/
public void resetStats() {
numForwardCalls = 0;
numBackwardCalls = 0;
}
// ==================== Getters ====================
public ProcessGroupWrapper getProcessGroup() {
return pg;
}
public int getRank() {
return pg.getRank();
}
public int getWorldSize() {
return pg.getWorldSize();
}
public boolean isMainProcess() {
return pg.isMainProcess();
}
public Device getDevice() {
return pg.getDevice();
}
@Override
public void close() {
module.close();
instances.values().removeIf(d -> d == this);
}
@Override
public String toString() {
return String.format("DDPTrainer{rank=%d, worldSize=%d, device=%s, forwardCalls=%d}",
pg.getRank(), pg.getWorldSize(), pg.getDevice(), numForwardCalls);
}
}
```
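`reduceGradients()` above averages gradients by running allreduce with SUM and then dividing every gradient in place by the world size. The arithmetic can be checked without libtorch or a process group. The sketch below simulates the ranks inside one process with plain `float[]` arrays (a hypothetical stand-in for gradient tensors); it illustrates the result of the collective, not how NCCL or Gloo implement it.

```java
import java.util.Arrays;

public class AllreduceMeanSketch {

    // Simulated allreduce(SUM): afterwards every rank holds the elementwise sum.
    public static void allreduceSum(float[][] perRankGrads) {
        int n = perRankGrads[0].length;
        float[] sum = new float[n];
        for (float[] g : perRankGrads) {
            for (int i = 0; i < n; i++) {
                sum[i] += g[i];
            }
        }
        for (float[] g : perRankGrads) {
            System.arraycopy(sum, 0, g, 0, n);
        }
    }

    // Same steps as reduceGradients(): SUM across ranks, then divide by worldSize in place.
    public static void averageGradients(float[][] perRankGrads) {
        int worldSize = perRankGrads.length;
        if (worldSize <= 1) {
            return; // same early exit as reduceGradients()
        }
        allreduceSum(perRankGrads);
        for (float[] g : perRankGrads) {
            for (int i = 0; i < g.length; i++) {
                g[i] /= worldSize;
            }
        }
    }

    public static void main(String[] args) {
        float[][] grads = {
            {1f, 2f, 3f}, // rank 0
            {3f, 4f, 5f}, // rank 1
        };
        averageGradients(grads);
        System.out.println(Arrays.toString(grads[0])); // [2.0, 3.0, 4.0]
        System.out.println(Arrays.toString(grads[1])); // [2.0, 3.0, 4.0]
    }
}
```

Because every rank starts from the weights broadcast by rank 0 and then applies the same averaged gradient, the replicas stay in sync. With a mean-reduced loss (the default for `cross_entropy`) and equal per-rank batch sizes, the averaged gradient equals the gradient of the loss over the combined global batch.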
```java
package torch.benchmark;
import org.bytedeco.javacpp.PointerScope;
import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.global.torch;
import torch.BenchmarkNet;
import torch.amp.AutoCast;
import torch.amp.GradScaler;
import torch.distributed.DistributedStore;
import torch.distributed.ProcessGroupWrapper;
import torch.distributed.DDPTrainer;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;
/**
* DDP Integration Benchmark.
* Tests the DDPTrainer wrapper for distributed training.
*
* Tests:
* - Gradient synchronization correctness
* - Loss convergence
* - Throughput and scaling
* - Mixed precision training
* - Memory stability
*/
public class DDPIntegrationBenchmark {
public static class TestResult {
public final String name;
public final boolean passed;
public final String message;
public final double metric;
public final long durationMs;
public TestResult(String name, boolean passed, String message, double metric, long durationMs) {
this.name = name;
this.passed = passed;
this.message = message;
this.metric = metric;
this.durationMs = durationMs;
}
@Override
public String toString() {
String status = passed ? "✅ PASS" : "❌ FAIL";
return String.format(" %-40s %s [%s] (%.2f)", name, status, message, metric);
}
}
public static class TestConfig {
public int worldSize = 1;
public int warmupSteps = 5;
public int testSteps = 50;
public int batchSize = 32;
public int inputSize = 512;
public int hiddenSize = 256;
public int numClasses = 10;
public boolean useGpu = true; // main() falls back to CPU when CUDA is unavailable; --cpu forces CPU
public boolean useMixedPrecision = true;
public static TestConfig fromArgs(String[] args) {
TestConfig config = new TestConfig();
for (int i = 0; i < args.length; i++) {
switch (args[i]) {
case "--world-size" -> config.worldSize = Integer.parseInt(args[++i]);
case "--steps" -> config.testSteps = Integer.parseInt(args[++i]);
case "--batch-size" -> config.batchSize = Integer.parseInt(args[++i]);
case "--input-size" -> config.inputSize = Integer.parseInt(args[++i]);
case "--use-gpu" -> config.useGpu = true;
case "--cpu-fallback", "--cpu" -> config.useGpu = false;
case "--mixed-precision" -> config.useMixedPrecision = true;
}
}
return config;
}
}
public static void main(String[] args) {
TestConfig config = TestConfig.fromArgs(args);
System.out.println("╔══════════════════════════════════════════════════════════╗");
System.out.println("║ DDP Integration Benchmark ║");
System.out.println("╚══════════════════════════════════════════════════════════╝");
System.out.println();
boolean cudaAvailable = torch.cuda_is_available();
boolean useGpu = config.useGpu && cudaAvailable;
System.out.println("Environment:");
System.out.println(" CUDA Available: " + cudaAvailable);
System.out.println(" Using GPU: " + useGpu);
System.out.println(" World Size: " + config.worldSize + " (forced to 1 for single-process mode)");
config.worldSize = 1; // Force single-process mode - multi-rank requires torchrun/mpirun
System.out.println();
List<TestResult> results = new ArrayList<>();
// Test gradient synchronization
results.add(testGradientSynchronization("Gradient-Sync", config, useGpu));
// Test loss convergence
results.add(testLossConvergence("Loss-Convergence", config, useGpu));
// Test throughput
results.add(testThroughput("Throughput", config, useGpu));
// Test scaling
results.add(testScaling("Scaling", config, useGpu));
// Test mixed precision
if (useGpu) {
results.add(testMixedPrecision("Mixed-Precision", config));
}
// Test memory stability
results.add(testMemoryStability("Memory-Stability", config, useGpu));
// Test numerical stability
results.add(testNumericalStability("Numerical-Stability", config, useGpu));
// Print results
printResults(results);
// Summary
int passed = (int) results.stream().filter(r -> r.passed).count();
int failed = results.size() - passed;
System.out.println();
System.out.printf("Summary: %d passed, %d failed%n", passed, failed);
System.out.println(failed == 0 ? "✅ ALL DDP INTEGRATION TESTS PASSED!" : "❌ SOME DDP INTEGRATION TESTS FAILED!");
}
static TestResult testGradientSynchronization(String name, TestConfig config, boolean useGpu) {
long start = System.currentTimeMillis();
AtomicInteger successCount = new AtomicInteger(0);
AtomicInteger totalCount = new AtomicInteger(0);
// Run in parallel for each rank
for (int rank = 0; rank < config.worldSize; rank++) {
final int r = rank;
CompletableFuture.runAsync(() -> {
try (PointerScope scope = new PointerScope()) {
setupAndRun(r, config, useGpu, () -> {
// Create model and DDP
BenchmarkNet model = new BenchmarkNet(config.inputSize, config.hiddenSize, config.numClasses);
// var dict = new StringAnyModuleDict();
DistributedStore.Options storeOptions = new DistributedStore.Options()
.type(DistributedStore.StoreType.FILE)
.masterPort(29500 + r);
DistributedStore store = DistributedStore.create(storeOptions, r, config.worldSize);
ProcessGroupWrapper.Options pgOptions = new ProcessGroupWrapper.Options()
.backend(useGpu ? ProcessGroupWrapper.BackendType.NCCL : ProcessGroupWrapper.BackendType.GLOO);
ProcessGroupWrapper pg = ProcessGroupWrapper.create(pgOptions, r, config.worldSize, store);
DDPTrainer ddp = DDPTrainer.create(model, pg);
Optimizer optimizer = new Adam(model.parameters(), new AdamOptions(1e-3));
// Training loop
for (int step = 0; step < config.testSteps; step++) {
Tensor input = torch.randn(new long[]{config.batchSize, config.inputSize});
Tensor target = torch.randint(config.numClasses, new long[]{config.batchSize});
if (useGpu) {
input = input.to(new Device(torch.DeviceType.CUDA, (byte) r), torch.ScalarType.Float);
target = target.to(new Device(torch.DeviceType.CUDA, (byte) r), torch.ScalarType.Long);
}
// Training step
Tensor output = ddp.forward(input);
Tensor loss = torch.cross_entropy(output, target);
optimizer.zero_grad();
loss.backward();
ddp.synchronize();
optimizer.step();
// Verify gradients safely
if (step > 0) {
try {
TensorVector params = model.parameters();
var it = params.begin();
var end = params.end();
while(!it.equals(end)) {
Tensor p = it.get();
if (p != null && !p.isNull() && p.grad().defined()) {
try {
totalCount.incrementAndGet();
float pf = p.grad().norm().item_float();
if (!Float.isNaN(pf) && !Float.isInfinite(pf)) {
successCount.incrementAndGet();
}
} catch (Exception e) {
// Skip invalid gradient
}
}
it.increment();
}
} catch (Exception e) {
// Skip gradient check on error
}
}
loss.close();
input.close();
target.close();
}
ddp.close();
pg.close();
store.close();
return true;
});
} catch (Exception e) {
System.err.println("Error in gradient sync test: " + e.getMessage());
}
});
}
waitForCompletion();
long duration = System.currentTimeMillis() - start;
boolean passed = successCount.get() > 0 && successCount.get() == totalCount.get();
double accuracy = totalCount.get() > 0
? (double) successCount.get() / totalCount.get() * 100
: 0;
return new TestResult(name, passed,
String.format("gradients synced: %d/%d (%.1f%%)", successCount.get(), totalCount.get(), accuracy),
accuracy, duration);
}
static TestResult testLossConvergence(String name, TestConfig config, boolean useGpu) {
long start = System.currentTimeMillis();
List<Double> losses = Collections.synchronizedList(new ArrayList<>());
boolean converged = false;
for (int rank = 0; rank < config.worldSize; rank++) {
final int r = rank;
CompletableFuture.runAsync(() -> {
try (PointerScope scope = new PointerScope()) {
setupAndRun(r, config, useGpu, () -> {
BenchmarkNet model = new BenchmarkNet(config.inputSize, config.hiddenSize, config.numClasses);
DistributedStore.Options storeOptions = new DistributedStore.Options()
.type(DistributedStore.StoreType.FILE)
.masterPort(29600 + r);
DistributedStore store = DistributedStore.create(storeOptions, r, config.worldSize);
ProcessGroupWrapper.Options pgOptions = new ProcessGroupWrapper.Options()
.backend(useGpu ? ProcessGroupWrapper.BackendType.NCCL : ProcessGroupWrapper.BackendType.GLOO);
ProcessGroupWrapper pg = ProcessGroupWrapper.create(pgOptions, r, config.worldSize, store);
DDPTrainer ddp = DDPTrainer.create(model, pg);
Optimizer optimizer = new Adam(model.parameters(), new AdamOptions(1e-3));
double initialLoss = 0;
double finalLoss = 0;
for (int step = 0; step < config.testSteps; step++) {
Tensor input = torch.randn(new long[]{config.batchSize, config.inputSize});
Tensor target = torch.randint(config.numClasses, new long[]{config.batchSize});
if (useGpu) {
input = input.to(new Device(torch.DeviceType.CUDA, (byte) r), torch.ScalarType.Float);
target = target.to(new Device(torch.DeviceType.CUDA, (byte) r), torch.ScalarType.Long);
}
Tensor output = ddp.forward(input);
Tensor loss = torch.cross_entropy(output, target);
// Safe loss extraction
if (step == 0) {
try {
initialLoss = loss.item_double();
} catch (Exception e) {
initialLoss = -1;
}
}
optimizer.zero_grad();
loss.backward();
ddp.synchronize();
optimizer.step();
try {
finalLoss = loss.item_double();
} catch (Exception e) {
finalLoss = -1;
}
if (step >= config.testSteps / 2) {
losses.add(finalLoss);
}
loss.close();
input.close();
target.close();
}
ddp.close();
pg.close();
store.close();
return finalLoss < initialLoss * 1.5;
});
} catch (Exception e) {
System.err.println("Error in convergence test: " + e.getMessage());
}
});
}
waitForCompletion();
long duration = System.currentTimeMillis() - start;
double avgLoss = losses.stream().mapToDouble(d -> d).average().orElse(0);
if (!losses.isEmpty()) {
// Random labels over numClasses classes start near ln(numClasses) (about 2.30 for 10),
// so this threshold mainly catches divergence rather than proving convergence.
converged = avgLoss < 5.0;
}
// Ratio of the last recorded loss to the first one kept (both from the second half of training)
double convergenceRate = losses.isEmpty() ? 0 :
losses.get(losses.size() - 1) / (losses.get(0) + 0.001);
return new TestResult(name, converged,
String.format("avgLoss=%.4f, convergence=%.2f", avgLoss, convergenceRate),
convergenceRate, duration);
}
static TestResult testThroughput(String name, TestConfig config, boolean useGpu) {
long start = System.currentTimeMillis();
AtomicLong totalSamples = new AtomicLong(0);
for (int rank = 0; rank < config.worldSize; rank++) {
final int r = rank;
CompletableFuture.runAsync(() -> {
try (PointerScope scope = new PointerScope()) {
setupAndRun(r, config, useGpu, () -> {
BenchmarkNet model = new BenchmarkNet(config.inputSize, config.hiddenSize, config.numClasses);
DistributedStore.Options storeOptions = new DistributedStore.Options()
.type(DistributedStore.StoreType.FILE)
.masterPort(29700 + r);
DistributedStore store = DistributedStore.create(storeOptions, r, config.worldSize);
ProcessGroupWrapper.Options pgOptions = new ProcessGroupWrapper.Options()
.backend(useGpu ? ProcessGroupWrapper.BackendType.NCCL : ProcessGroupWrapper.BackendType.GLOO);
ProcessGroupWrapper pg = ProcessGroupWrapper.create(pgOptions, r, config.worldSize, store);
DDPTrainer ddp = DDPTrainer.create(model, pg);
Optimizer optimizer = new Adam(model.parameters(), new AdamOptions(1e-3));
long stepStart = System.currentTimeMillis();
for (int step = 0; step < config.testSteps; step++) {
Tensor input = torch.randn(new long[]{config.batchSize, config.inputSize});
Tensor target = torch.randint(config.numClasses, new long[]{config.batchSize});
if (useGpu) {
input = input.to(new Device(torch.DeviceType.CUDA, (byte) r), torch.ScalarType.Float);
target = target.to(new Device(torch.DeviceType.CUDA, (byte) r), torch.ScalarType.Long);
}
Tensor output = ddp.forward(input);
Tensor loss = torch.cross_entropy(output, target);
optimizer.zero_grad();
loss.backward();
ddp.synchronize();
optimizer.step();
totalSamples.addAndGet(config.batchSize);
loss.close();
input.close();
target.close();
}
if (useGpu) torch.cuda_synchronize();
long elapsed = System.currentTimeMillis() - stepStart;
ddp.close();
pg.close();
store.close();
return elapsed > 0;
});
} catch (Exception e) {
System.err.println("Error in throughput test: " + e.getMessage());
}
});
}
waitForCompletion();
long duration = System.currentTimeMillis() - start;
// Wall-clock throughput: duration also covers store, process-group and model setup
double throughput = totalSamples.get() * 1000.0 / duration;
boolean passed = throughput > 0;
return new TestResult(name, passed,
String.format("throughput=%.2f samples/s (incl. setup), total=%d", throughput, totalSamples.get()),
throughput, duration);
}
static TestResult testScaling(String name, TestConfig config, boolean useGpu) {
long start = System.currentTimeMillis();
// Multi-rank scaling requires actual multi-process setup (torchrun).
// In single-process mode, test with worldSize=1 only to avoid GLOO inter-process communication errors.
List<Double> throughputs = Collections.synchronizedList(new ArrayList<>());
int testWs = 1; // Single-process mode only
for (int rank = 0; rank < testWs; rank++) {
final int r = rank;
CompletableFuture.runAsync(() -> {
try (PointerScope scope = new PointerScope()) {
setupAndRun(r, config, useGpu, () -> {
BenchmarkNet model = new BenchmarkNet(config.inputSize, config.hiddenSize, config.numClasses);
DistributedStore.Options storeOptions = new DistributedStore.Options()
.type(DistributedStore.StoreType.FILE)
.masterPort(29800 + r + testWs * 100);
DistributedStore store = DistributedStore.create(storeOptions, r, testWs);
ProcessGroupWrapper.Options pgOptions = new ProcessGroupWrapper.Options()
.backend(useGpu ? ProcessGroupWrapper.BackendType.NCCL : ProcessGroupWrapper.BackendType.GLOO);
ProcessGroupWrapper pg = ProcessGroupWrapper.create(pgOptions, r, testWs, store);
DDPTrainer ddp = DDPTrainer.create(model, pg);
Optimizer optimizer = new Adam(model.parameters(), new AdamOptions(1e-3));
long stepStart = System.currentTimeMillis();
int steps = Math.max(10, config.testSteps / 2);
for (int step = 0; step < steps; step++) {
Tensor input = torch.randn(new long[]{config.batchSize, config.inputSize});
Tensor target = torch.randint(config.numClasses, new long[]{config.batchSize});
if (useGpu) {
input = input.to(new Device(torch.DeviceType.CUDA, (byte) r), torch.ScalarType.Float);
target = target.to(new Device(torch.DeviceType.CUDA, (byte) r), torch.ScalarType.Long);
}
Tensor output = ddp.forward(input);
Tensor loss = torch.cross_entropy(output, target);
optimizer.zero_grad();
loss.backward();
ddp.synchronize();
optimizer.step();
loss.close();
input.close();
target.close();
}
if (useGpu) torch.cuda_synchronize();
long elapsed = System.currentTimeMillis() - stepStart;
throughputs.add((double) config.batchSize * steps / elapsed * 1000);
ddp.close();
pg.close();
store.close();
return true;
});
} catch (Exception e) {
System.err.println("Error: " + e.getMessage());
}
});
}
waitForCompletion();
long duration = System.currentTimeMillis() - start;
double efficiency = 0;
if (!throughputs.isEmpty()) {
efficiency = throughputs.stream().mapToDouble(d -> d).average().orElse(0);
}
// Scaling test passes if single-process throughput is measurable.
// Multi-process scaling requires torchrun/mpirun infrastructure.
boolean passed = efficiency > 0;
return new TestResult(name, passed,
String.format("single-process throughput=%.1f samples/s (multi-rank requires torchrun)", efficiency),
efficiency, duration);
}
static TestResult testMixedPrecision(String name, TestConfig config) {
long start = System.currentTimeMillis();
AtomicInteger successCount = new AtomicInteger(0);
AtomicInteger totalCount = new AtomicInteger(0);
for (int rank = 0; rank < config.worldSize; rank++) {
final int r = rank;
CompletableFuture.runAsync(() -> {
try (PointerScope scope = new PointerScope()) {
DistributedStore.Options storeOptions = new DistributedStore.Options()
.type(DistributedStore.StoreType.FILE)
.masterPort(29900 + r);
DistributedStore store = DistributedStore.create(storeOptions, r, config.worldSize);
ProcessGroupWrapper.Options pgOptions = new ProcessGroupWrapper.Options()
.backend(ProcessGroupWrapper.BackendType.NCCL);
ProcessGroupWrapper pg = ProcessGroupWrapper.create(pgOptions, r, config.worldSize, store);
BenchmarkNet model = new BenchmarkNet(config.inputSize, config.hiddenSize, config.numClasses);
DDPTrainer ddp = DDPTrainer.create(model, pg);
Optimizer optimizer = new Adam(model.parameters(), new AdamOptions(1e-3));
GradScaler scaler = new GradScaler();
for (int step = 0; step < config.testSteps; step++) {
Tensor input = torch.randn(new long[]{config.batchSize, config.inputSize}).to(
new Device(torch.DeviceType.CUDA, (byte) r), torch.ScalarType.Float);
Tensor target = torch.randint(config.numClasses, new long[]{config.batchSize}).to(
new Device(torch.DeviceType.CUDA, (byte) r), torch.ScalarType.Long);
// Only the forward pass and loss run under autocast; backward runs outside it,
// as PyTorch recommends for autocast regions.
Tensor loss;
try (AutoCast ac = AutoCast.cuda(AutoCast.Precision.BF16)) {
Tensor output = ddp.forward(input);
loss = torch.cross_entropy(output, target);
}
Tensor scaledLoss = scaler.scale(loss);
optimizer.zero_grad();
scaledLoss.backward();
ddp.synchronize();
scaler.step(optimizer, model.parameters());
scaler.update();
totalCount.incrementAndGet();
try {
// Check the unscaled loss, not the scaled one
float lossVal = loss.item_float();
if (!Float.isNaN(lossVal) && !Float.isInfinite(lossVal)) {
successCount.incrementAndGet();
}
} catch (Exception e) {
// Skip invalid loss
}
scaledLoss.close();
loss.close();
input.close();
target.close();
}
torch.cuda_synchronize();
ddp.close();
pg.close();
store.close();
} catch (Exception e) {
System.err.println("Error in mixed precision test: " + e.getMessage());
}
});
}
waitForCompletion();
long duration = System.currentTimeMillis() - start;
double successRate = totalCount.get() > 0
? (double) successCount.get() / totalCount.get() * 100
: 0;
boolean passed = successRate > 95;
return new TestResult(name, passed,
String.format("success rate=%.1f%% (%d/%d)", successRate, successCount.get(), totalCount.get()),
successRate, duration);
}
static TestResult testMemoryStability(String name, TestConfig config, boolean useGpu) {
long start = System.currentTimeMillis();
AtomicLong maxMemory = new AtomicLong(0);
AtomicInteger stableIterations = new AtomicInteger(0);
for (int rank = 0; rank < config.worldSize; rank++) {
final int r = rank;
CompletableFuture.runAsync(() -> {
try (PointerScope scope = new PointerScope()) {
setupAndRun(r, config, useGpu, () -> {
BenchmarkNet model = new BenchmarkNet(config.inputSize, config.hiddenSize, config.numClasses);
DistributedStore.Options storeOptions = new DistributedStore.Options()
.type(DistributedStore.StoreType.FILE)
.masterPort(30000 + r);
DistributedStore store = DistributedStore.create(storeOptions, r, config.worldSize);
ProcessGroupWrapper.Options pgOptions = new ProcessGroupWrapper.Options()
.backend(useGpu ? ProcessGroupWrapper.BackendType.NCCL : ProcessGroupWrapper.BackendType.GLOO);
ProcessGroupWrapper pg = ProcessGroupWrapper.create(pgOptions, r, config.worldSize, store);
DDPTrainer ddp = DDPTrainer.create(model, pg);
Optimizer optimizer = new Adam(model.parameters(), new AdamOptions(1e-3));
for (int step = 0; step < config.testSteps; step++) {
Tensor input = torch.randn(new long[]{config.batchSize, config.inputSize});
Tensor target = torch.randint(config.numClasses, new long[]{config.batchSize});
if (useGpu) {
input = input.to(new Device(torch.DeviceType.CUDA, (byte) r), torch.ScalarType.Float);
target = target.to(new Device(torch.DeviceType.CUDA, (byte) r), torch.ScalarType.Long);
}
Tensor output = ddp.forward(input);
Tensor loss = torch.cross_entropy(output, target);
optimizer.zero_grad();
loss.backward();
ddp.synchronize();
optimizer.step();
if (useGpu) {
torch.cuda_synchronize();
// Not a measurement: a static estimate of the Float32 input batch alone
// (batchSize * inputSize * 4 bytes, floored to MB, plus 1); the CUDA allocator is not queried.
long memUsed = (long) config.batchSize * config.inputSize * 4 / (1024 * 1024) + 1;
}
stableIterations.incrementAndGet();
loss.close();
input.close();
target.close();
}
ddp.close();
pg.close();
store.close();
return true;
});
} catch (Exception e) {
System.err.println("Error: " + e.getMessage());
}
});
}
waitForCompletion();
long duration = System.currentTimeMillis() - start;
boolean passed = stableIterations.get() >= config.testSteps * config.worldSize;
return new TestResult(name, passed,
String.format("stable iterations=%d/%d, est. input batch=%dMB",
stableIterations.get(), config.testSteps * config.worldSize, maxMemory.get()),
maxMemory.get(), duration);
}
static TestResult testNumericalStability(String name, TestConfig config, boolean useGpu) {
long start = System.currentTimeMillis();
AtomicInteger stableSteps = new AtomicInteger(0);
List<Double> losses = Collections.synchronizedList(new ArrayList<>());
for (int rank = 0; rank < config.worldSize; rank++) {
final int r = rank;
CompletableFuture.runAsync(() -> {
try (PointerScope scope = new PointerScope()) {
setupAndRun(r, config, useGpu, () -> {
BenchmarkNet model = new BenchmarkNet(config.inputSize, config.hiddenSize, config.numClasses);
DistributedStore.Options storeOptions = new DistributedStore.Options()
.type(DistributedStore.StoreType.FILE)
.masterPort(30100 + r);
DistributedStore store = DistributedStore.create(storeOptions, r, config.worldSize);
ProcessGroupWrapper.Options pgOptions = new ProcessGroupWrapper.Options()
.backend(useGpu ? ProcessGroupWrapper.BackendType.NCCL : ProcessGroupWrapper.BackendType.GLOO);
ProcessGroupWrapper pg = ProcessGroupWrapper.create(pgOptions, r, config.worldSize, store);
DDPTrainer ddp = DDPTrainer.create(model, pg);
var optimizer = new Adam(model.parameters(true), new AdamOptions(1e-3));
for (int step = 0; step < config.testSteps; step++) {
Tensor input = torch.randn(new long[]{config.batchSize, config.inputSize});
Tensor target = torch.randint(config.numClasses, new long[]{config.batchSize});
if (useGpu) {
input = input.to(new Device(torch.DeviceType.CUDA, (byte) r), torch.ScalarType.Float);
target = target.to(new Device(torch.DeviceType.CUDA, (byte) r), torch.ScalarType.Long);
}
Tensor output = ddp.forward(input);
Tensor loss = torch.cross_entropy(output, target);
try {
double lossVal = loss.item_double();
if (!Double.isNaN(lossVal) && !Double.isInfinite(lossVal)) {
losses.add(lossVal);
stableSteps.incrementAndGet();
}
} catch (Exception e) {
// Skip invalid loss
}
optimizer.zero_grad();
loss.backward();
ddp.synchronize();
optimizer.step();
loss.close();
input.close();
target.close();
}
ddp.close();
pg.close();
store.close();
return true;
});
} catch (Exception e) {
System.err.println("Error: " + e.getMessage());
}
});
}
waitForCompletion();
long duration = System.currentTimeMillis() - start;
boolean passed = stableSteps.get() >= config.testSteps * config.worldSize * 0.9;
return new TestResult(name, passed,
String.format("stable steps=%d/%d", stableSteps.get(), config.testSteps * config.worldSize),
(double) stableSteps.get(), duration);
}
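// Placeholder: rank, config and useGpu are not used yet; the task simply runs on whichever thread calls this.
static Boolean setupAndRun(int rank, TestConfig config, boolean useGpu, java.util.function.Supplier<Boolean> task) {
return task.get();
}
// Caution: this sleeps for a fixed 2 s and does not join the CompletableFuture tasks started by the tests,
// so any rank still running when the sleep ends is simply not counted.
static void waitForCompletion() {
try {
Thread.sleep(2000);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
}
}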
static void printResults(List<TestResult> results) {
System.out.println();
System.out.println("=".repeat(90));
System.out.println("DDP Integration Test Results");
System.out.println("=".repeat(90));
for (TestResult r : results) {
System.out.println(r);
}
}
}
```
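In the harness above, `waitForCompletion()` sleeps for a fixed two seconds instead of waiting on the `CompletableFuture.runAsync` tasks, so a rank that is still training (or hung) when the sleep ends is silently counted as a failure, and a rank that throws only prints a message. A minimal sketch of the alternative, assuming the per-rank work can be expressed as a function of the rank that returns `Boolean` (`RankJoinDemo`, `runAllRanks` and `runRank` are hypothetical names standing in for the DDP loop), is to keep the futures and join them:

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.function.IntFunction;

public class RankJoinDemo {

    /**
     * Launches one async task per rank and blocks until every task has finished.
     * runRank is a hypothetical stand-in for the per-rank DDP training loop.
     * Returns the number of ranks whose task completed and reported success.
     */
    static int runAllRanks(int worldSize, IntFunction<Boolean> runRank) {
        List<CompletableFuture<Boolean>> futures = new ArrayList<>();
        for (int rank = 0; rank < worldSize; rank++) {
            final int r = rank;
            futures.add(CompletableFuture.supplyAsync(() -> runRank.apply(r))
                    // Log the failure and count this rank as unsuccessful instead of losing the exception.
                    .exceptionally(e -> {
                        System.err.println("rank " + r + " failed: " + e);
                        return false;
                    }));
        }
        // Wait for every rank, however long it takes, instead of sleeping a fixed 2 s.
        CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])).join();
        int succeeded = 0;
        for (CompletableFuture<Boolean> f : futures) {
            if (f.join()) succeeded++;
        }
        return succeeded;
    }

    public static void main(String[] args) {
        // Rank 2 simulates a crash; the other ranks simulate slow but successful runs.
        int ok = runAllRanks(4, r -> {
            if (r == 2) throw new IllegalStateException("simulated failure");
            try { Thread.sleep(100L * r); } catch (InterruptedException e) { Thread.currentThread().interrupt(); }
            return true;
        });
        System.out.println("successful ranks: " + ok); // prints "successful ranks: 3"
    }
}
```

With this shape the pass/fail check can compare the returned count with `worldSize` directly, and the timing (`System.currentTimeMillis() - start`) reflects the real run rather than the sleep.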
-
```java
// ==================== Test 2: AnyModule + ReLUImpl ====================
static TestResult testAnyModuleReLU(String name, long inputSize, long hiddenSize, long numClasses, boolean useGpu) {
    System.out.println("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
    System.out.println("[Test] " + name);
    System.out.println(" Approach: AnyModule(ReLUImpl).forward(Tensor)");
    System.out.println(" Note: ReLUImpl is one of the types AnyModule has a constructor for, ✅ works");
    try (PointerScope scope = new PointerScope()) {
        Device device = getDevice(useGpu);
        ReLUImpl relu = new ReLUImpl();
        AnyModule am = new AnyModule(relu);
        System.out.println(" AnyModule.is_empty() = " + am.is_empty());
        // ReLU on its own is usually an intermediate step between layers; this is only a simple sanity check
        Tensor input = torch.randn(new long[]{4, 128}).to(device, torch.ScalarType.Float);
        Tensor output = am.forward(input);
        System.out.println(" ReLU forward OK: " + Arrays.toString(output.shape()));
        boolean ok = output.shape()[0] == 4 && output.shape()[1] == 128;
        System.out.println(ok ? " ✅ AnyModule(ReLUImpl) works" : " ❌ unexpected output shape");
        input.close(); output.close();
        am.close(); relu.close();
        return new TestResult(name, ok, "ReLU forward", 1.0);
    } catch (Exception e) {
        System.err.println(" ❌ " + e.getMessage());
        return new TestResult(name, false, e.getMessage(), 0);
    }
}
// ==================== Test 3: AnyModule wrapping multiple Impl types ====================
static TestResult testAnyModuleMultiType(String name, long inputSize, long hiddenSize, long numClasses, boolean useGpu) {
    System.out.println("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
    System.out.println("[Test] " + name);
    System.out.println(" Approach: several AnyModules, each wrapping a LinearImpl, called through one uniform API");
    System.out.println(" Note: AnyModule.forward(Tensor) works the same way for every supported Impl type");
    try (PointerScope scope = new PointerScope()) {
        Device device = getDevice(useGpu);
        // Three modules (all LinearImpl here), each wrapped in an AnyModule
        LinearImpl lin1 = new LinearImpl(inputSize, hiddenSize);
        lin1.to(device, false);
        LinearImpl lin2 = new LinearImpl(hiddenSize, hiddenSize);
        lin2.to(device, false);
        LinearImpl lin3 = new LinearImpl(hiddenSize, numClasses);
        lin3.to(device, false);
        AnyModule am1 = new AnyModule(lin1);
        AnyModule am2 = new AnyModule(lin2);
        AnyModule am3 = new AnyModule(lin3);
        System.out.println(" Created 3 AnyModules (LinearImpl × 3)");
        // Note: only lin3's parameters are optimized; lin1 and lin2 keep their initial weights
        Optimizer optimizer = new Adam(lin3.parameters(), new AdamOptions(1e-3));
        List<Double> losses = new ArrayList<>();
        for (int step = 0; step < 20; step++) {
            Tensor input = mkInput(16, inputSize, device);
            Tensor target = mkTarget(16, numClasses, device);
            // Chained calls: same API, different implementations
            Tensor out = am1.forward(input);
            out = am2.forward(out);
            out = am3.forward(out);
            Tensor loss = torch.cross_entropy(out, target);
            losses.add(loss.item_double());
            optimizer.zero_grad(); loss.backward(); optimizer.step();
            closeAll(loss, out, input, target);
        }
        double finalLoss = losses.get(losses.size() - 1);
        double initialLoss = losses.get(0);
        System.out.printf(" loss: %.4f → %.4f%n", initialLoss, finalLoss);
        System.out.println(" ✅ Chained AnyModule calls succeeded");
        optimizer.close();
        am1.close(); am2.close(); am3.close();
        lin1.close(); lin2.close(); lin3.close();
        return new TestResult(name, true, String.format("loss: %.4f→%.4f", initialLoss, finalLoss), finalLoss);
    } catch (Exception e) {
        System.err.println(" ❌ " + e.getMessage());
        return new TestResult(name, false, e.getMessage(), 0);
    }
}
// ==================== Test 4: SequentialImpl direct forward ====================
static TestResult testSequentialDirect(String name, long inputSize, long hiddenSize, long numClasses, boolean useGpu) {
    System.out.println("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
    System.out.println("[Test] " + name);
    System.out.println(" Approach: call SequentialImpl.forward(Tensor) directly");
    System.out.println(" Note: SequentialImpl is not one of the types AnyModule has a constructor for, so use a direct forward");
    try (PointerScope scope = new PointerScope()) {
        Device device = getDevice(useGpu);
        SequentialImpl seq = new SequentialImpl();
        seq.push_back(new LinearImpl(inputSize, hiddenSize));
        seq.push_back(new ReLUImpl());
        seq.push_back(new LinearImpl(hiddenSize, hiddenSize));
        seq.push_back(new ReLUImpl());
        seq.push_back(new LinearImpl(hiddenSize, numClasses));
        AnyModule am1 = new AnyModule(seq);
        seq.to(device, false);
        // am1.ptr().to(device, false);
        System.out.println(" SequentialImpl built (3 Linear + 2 ReLU)");
        Optimizer optimizer = new Adam(seq.parameters(), new AdamOptions(1e-3));
        List<Double> losses = new ArrayList<>();
        for (int step = 0; step < 20; step++) {
            Tensor input = mkInput(16, inputSize, device);
            Tensor target = mkTarget(16, numClasses, device);
            Tensor output = am1.any_forward(input).getTensor(); // goes through the AnyModule wrapper (a direct call would be seq.forward(input))
            Tensor loss = torch.cross_entropy(output, target);
            losses.add(loss.item_double());
            optimizer.zero_grad(); loss.backward(); optimizer.step();
            closeAll(loss, output, input, target);
        }
        double finalLoss = losses.get(losses.size() - 1);
        double initialLoss = losses.get(0);
        System.out.printf(" loss: %.4f → %.4f%n", initialLoss, finalLoss);
        System.out.println(" ✅ SequentialImpl.forward() training succeeded");
        optimizer.close();
        seq.close();
        return new TestResult(name, true, String.format("loss: %.4f→%.4f", initialLoss, finalLoss), finalLoss);
    } catch (Exception e) {
        System.err.println(" ❌ " + e.getMessage());
        return new TestResult(name, false, e.getMessage(), 0);
    }
}
// ==================== Test 5: reflection + custom Module ====================
static TestResult testReflectionCustomModule(String name, long inputSize, long hiddenSize, long numClasses, boolean useGpu) {
    System.out.println("━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━");
    System.out.println("[Test] " + name);
    System.out.println(" Approach: look up and call forward(Tensor) via Java reflection");
    System.out.println(" Note: FSDPTrainer/DDPTrainer use this approach internally");
    try (PointerScope scope = new PointerScope()) {
        Device device = getDevice(useGpu);
        CustomModuleNet model = new CustomModuleNet(inputSize, hiddenSize, numClasses);
        model.to(device, false);
        // Look up the method once via reflection
        java.lang.reflect.Method fm = findForwardMethod(model);
        System.out.println(" Found via reflection: " + fm.getDeclaringClass().getSimpleName() + ".forward(Tensor)");
        Optimizer optimizer = new Adam(model.parameters(), new AdamOptions(1e-3));
        List<Double> losses = new ArrayList<>();
        for (int step = 0; step < 20; step++) {
            Tensor input = mkInput(16, inputSize, device);
            Tensor target = mkTarget(16, numClasses, device);
            Tensor output = (Tensor) fm.invoke(model, input);
            Tensor loss = torch.cross_entropy(output, target);
            losses.add(loss.item_double());
            optimizer.zero_grad(); loss.backward(); optimizer.step();
            closeAll(loss, output, input, target);
        }
        double finalLoss = losses.get(losses.size() - 1);
        double initialLoss = losses.get(0);
        System.out.printf(" loss: %.4f → %.4f%n", initialLoss, finalLoss);
        System.out.println(" ✅ Training via reflective forward(Tensor) call succeeded");
        optimizer.close();
        model.close();
        return new TestResult(name, true, String.format("loss: %.4f→%.4f", initialLoss, finalLoss), finalLoss);
    } catch (Exception e) {
        System.err.println(" ❌ " + e.getMessage());
        return new TestResult(name, false, e.getMessage(), 0);
    }
}
```
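Test 5, like the reflection-based `DDPTrainer` posted earlier in the thread, calls `Method.invoke` on every step. Any exception thrown inside `forward` then arrives wrapped in an `InvocationTargetException`, so catch blocks that print `e.getMessage()` can end up printing `null` instead of the real error. A small refinement, sketched here under stated assumptions, is to walk the class hierarchy once and turn the `Method` into a `java.lang.invoke.MethodHandle`, whose `invoke` rethrows the callee's exception as-is. `BaseModule`, `ScaleNet` and `DeeperNet` are hypothetical stand-ins for `Module` subclasses, and `double[]` replaces `Tensor` so the sketch runs without the native library:

```java
import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.reflect.Method;

public class ForwardHandleDemo {

    // Hypothetical stand-ins: BaseModule plays org.bytedeco.pytorch.Module and double[] plays Tensor.
    static class BaseModule { }

    static class ScaleNet extends BaseModule {
        double[] forward(double[] x) {
            double[] y = new double[x.length];
            for (int i = 0; i < x.length; i++) y[i] = 2.0 * x[i];
            return y;
        }
    }

    // Declares no forward of its own; the lookup must find the one inherited from ScaleNet.
    static class DeeperNet extends ScaleNet { }

    /** Walks up the class hierarchy once and returns a reusable handle to forward(double[]). */
    static MethodHandle findForward(BaseModule module) {
        for (Class<?> cls = module.getClass(); cls != null && cls != Object.class; cls = cls.getSuperclass()) {
            try {
                Method m = cls.getDeclaredMethod("forward", double[].class);
                m.setAccessible(true);
                return MethodHandles.lookup().unreflect(m);
            } catch (NoSuchMethodException e) {
                // Not declared on this class; keep walking up.
            } catch (IllegalAccessException e) {
                throw new IllegalStateException("cannot access forward on " + cls.getName(), e);
            }
        }
        throw new UnsupportedOperationException(module.getClass().getName() + " has no forward(double[])");
    }

    /** Calls the cached handle; an exception thrown inside forward propagates as itself, not wrapped. */
    static double[] invokeForward(MethodHandle forward, BaseModule module, double[] input) {
        try {
            return (double[]) forward.invoke(module, input);
        } catch (RuntimeException | Error e) {
            throw e;
        } catch (Throwable t) {
            // MethodHandle.invoke is declared to throw Throwable, so checked exceptions are wrapped here
            throw new RuntimeException(t);
        }
    }

    public static void main(String[] args) {
        BaseModule model = new DeeperNet();
        MethodHandle fwd = findForward(model); // look up once, reuse on every step
        double[] out = invokeForward(fwd, model, new double[]{1.0, 2.5});
        System.out.println(out[0] + " " + out[1]); // prints "2.0 5.0"
    }
}
```

This is still reflection under the hood; it only makes errors readable and avoids repeating the lookup, so it is a stopgap rather than a substitute for a real generic `forward`.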
```java
SequentialImpl seq = new SequentialImpl();
seq.push_back(new LinearImpl(inputSize, hiddenSize));
seq.push_back(new ReLUImpl());
seq.push_back(new LinearImpl(hiddenSize, hiddenSize));
seq.push_back(new ReLUImpl());
seq.push_back(new LinearImpl(hiddenSize, numClasses));
AnyModule am1 = new AnyModule(seq);
```
Wrapping a SequentialImpl in an AnyModule this way (as testSequentialDirect does) will crash.
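Since wrapping a SequentialImpl in an AnyModule crashes, one workaround that avoids both AnyModule and reflection is to bind each module's own `forward` with a Java method reference and chain the resulting `UnaryOperator` stages. This assumes each module class exposes a public `forward(Tensor)` returning `Tensor`, which the test descriptions above state for SequentialImpl and LinearImpl. The sketch is hypothetical: `Affine` and `Relu` stand in for LinearImpl and ReLUImpl, and `double[]` stands in for `Tensor` so it runs without the native library. With javacpp-pytorch the list would hold references such as `seq::forward`.

```java
import java.util.List;
import java.util.function.UnaryOperator;

public class ForwardChainDemo {

    // Hypothetical stand-ins for LinearImpl / ReLUImpl; in real code the element type would be Tensor.
    static final class Affine {
        private final double scale, shift;
        Affine(double scale, double shift) { this.scale = scale; this.shift = shift; }
        double[] forward(double[] x) {
            double[] y = new double[x.length];
            for (int i = 0; i < x.length; i++) y[i] = scale * x[i] + shift;
            return y;
        }
    }

    static final class Relu {
        double[] forward(double[] x) {
            double[] y = new double[x.length];
            for (int i = 0; i < x.length; i++) y[i] = Math.max(0.0, x[i]);
            return y;
        }
    }

    /** Applies each stage in order: the uniform "call forward on anything" role AnyModule was used for. */
    static <T> T runChain(List<UnaryOperator<T>> stages, T input) {
        T out = input;
        for (UnaryOperator<T> stage : stages) out = stage.apply(out);
        return out;
    }

    public static void main(String[] args) {
        Affine first = new Affine(2.0, -1.0);
        Relu relu = new Relu();
        Affine last = new Affine(1.0, 0.5);
        // Method references bind each module's own forward; no AnyModule or reflection involved.
        List<UnaryOperator<double[]>> stages = List.of(first::forward, relu::forward, last::forward);
        double[] out = runChain(stages, new double[]{0.0, 1.0});
        System.out.println(out[0] + " " + out[1]); // prints "0.5 1.5"
    }
}
```

If a module class overloads `forward`, the `UnaryOperator` target type picks the single-argument overload at compile time. The trade-off is that the chain only carries the call; parameters, `to(device)` and so on still have to be reached through the original module objects.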
-
Dear JavaCPP team, @saudet, and @HGuillemet,
After developing downstream tools that implement vLLM, transformers, and TRL workflows on top of javacpp-pytorch, I’ve identified three critical, blocking issues that severely hurt the usability, generality, and competitiveness of the library. They prevent core framework logic from being implemented end to end and make javacpp-pytorch feel impractical for real-world LLM and deep learning use cases.
I strongly believe these three issues should be the top priorities to resolve before the official 1.5.14 release.
1. Missing forward() in the Base Module Class
When building generic frameworks like vLLM, transformers, or TRL on top of javacpp-pytorch, we have to use the parent class Module to refer to arbitrary models generically, without knowing their concrete class.
However, the base Module class does NOT have a forward() method. This creates a critical design gap:
- We cannot call model.forward(...) generically on an arbitrary model.
- Core inference/training logic cannot be properly abstracted or implemented end to end.
This limitation makes javacpp-pytorch less generic than it needs to be and significantly less competitive than native PyTorch.
This is a P0 (highest priority) issue. If solving this alone is too difficult, I suggest we use Claude Opus 4.8 to assist with the implementation.
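For model classes that a downstream framework defines itself, one workaround that needs no reflection is a small interface combined with a Java intersection bound, so generic code can use both the Module API and `forward()`. This is only a sketch under assumptions: `BaseModule` stands in for `org.bytedeco.pytorch.Module`, `Forward<T>` and `TinyModel` are hypothetical names, and `double[]` replaces `Tensor` so it runs without the native library. It does not help with native modules such as LinearImpl that the framework did not write. In LibTorch itself, `torch::nn::Module` declares no `forward()` either, because forward signatures differ per module; that is the gap `torch::nn::AnyModule` is meant to fill.

```java
import java.util.ArrayList;
import java.util.List;

public class GenericForwardDemo {

    // Hypothetical stand-in for org.bytedeco.pytorch.Module, which declares no forward().
    static class BaseModule {
        private final List<String> children = new ArrayList<>();
        void registerChild(String name) { children.add(name); }
        List<String> childNames() { return children; }
    }

    /** Hypothetical contract a framework can require of its own model classes; T stands in for Tensor. */
    interface Forward<T> {
        T forward(T input);
    }

    static class TinyModel extends BaseModule implements Forward<double[]> {
        TinyModel() { registerChild("fc"); }

        @Override
        public double[] forward(double[] x) {
            double sum = 0.0;
            for (double v : x) sum += v;
            return new double[]{sum};
        }
    }

    /** Generic framework code: the intersection bound exposes both the Module API and forward(), with no reflection. */
    static <T, M extends BaseModule & Forward<T>> T step(M model, T batch) {
        System.out.println("children: " + model.childNames());
        return model.forward(batch);
    }

    public static void main(String[] args) {
        double[] out = step(new TinyModel(), new double[]{1.0, 2.0, 3.0});
        System.out.println(out[0]); // prints "6.0"
    }
}
```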
2. Missing LibTorch Export/Serialization Headers
The C++ LibTorch headers for serialization and TorchScript are not scanned/included in the current build. In particular, the headers for torch::export, torch::jit::trace, and torch::jit::script are missing.
This means:
- We cannot implement torch.save(), torch.jit.trace(), or torch.jit.script() in Java.
- Models trained in Java cannot be exported to AOT-compiled or TorchScript files.
- The model lifecycle is incomplete (there is no end-to-end save/export path).
Adding the missing serialization headers should resolve this and complete the model save/export path.
3. Kineto Profiler Enablement Issue
The Kineto profiler cannot be properly enabled/configured in the current build. This is required for performance profiling and production debugging.
Summary of Top 3 Priorities for 1.5.14
1. Add forward() to the base Module class (critical for generic frameworks).
2. Add the missing LibTorch export/serialization headers (to enable torch.save, torch.jit.trace, and torch.jit.script).
3. Fix Kineto profiler initialization/enablement.
Resolving these three issues would make javacpp-pytorch complete end to end, production-ready, and competitive with native PyTorch for LLM, transformers, and TRL workloads.
What do you all think?