Plot benchmark speed against pytorch #20
Labels: documentation (Improvements or additions to documentation)
Rust code for benching:

```rust
use dfdx::prelude::*;
use rand::{prelude::StdRng, SeedableRng};
use rand_distr::StandardNormal;
use std::time::{Duration, Instant};

fn main() {
    let mut rng = StdRng::seed_from_u64(0);
    let mut l: Linear<512, 256> = Default::default();
    l.randomize(&mut rng, &StandardNormal);
    let mut opt = Adam::default();

    const N: usize = 10000;
    let mut total = Duration::default();
    for _ in 0..N {
        let x: Tensor2D<32, 512> = Tensor2D::randn(&mut rng);
        let y = l.forward(x.traced());
        let loss = y.square().mean();

        // Only the backward pass and optimizer update are timed.
        let start = Instant::now();
        let gradients = loss.backward();
        opt.update(&mut l, gradients);
        total += start.elapsed();
    }
    println!("{:?} batches per s", N as f32 / total.as_secs_f32());
}
```
Python code for benching:

```python
from datetime import datetime, timedelta

import torch

torch.manual_seed(0)
l = torch.nn.Linear(512, 256)
opt = torch.optim.Adam(l.parameters())

total = timedelta()
N = 10000
for _ in range(N):
    x = torch.randn(32, 512)
    y = l(x)
    loss = y.square().mean()

    # Only the backward pass and optimizer update are timed.
    start = datetime.now()
    opt.zero_grad()
    loss.backward()
    opt.step()
    total += datetime.now() - start

print(N / total.total_seconds())
```
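Since the issue asks for a plot of benchmark speed, here is a minimal sketch of how the two throughput numbers could be charted once measured. The values below are hypothetical placeholders, not measured results:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so no display is needed
import matplotlib.pyplot as plt

# Hypothetical placeholder numbers in batches/second; substitute the
# printed output of the two benchmark scripts above.
throughput = {"dfdx": 1000.0, "pytorch": 900.0}

fig, ax = plt.subplots()
ax.bar(throughput.keys(), throughput.values())
ax.set_ylabel("batches per second")
ax.set_title("backward + optimizer update, Linear 512->256, batch 32")
fig.savefig("bench.png")
```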
Closed
Both the dfdx version and the torch version should use flush-to-zero (#60)
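On the PyTorch side, flush-to-zero can be requested with `torch.set_flush_denormal`; a minimal sketch (this only affects CPU math, and returns whether the CPU supports the flag):

```python
import torch

# Ask PyTorch to treat denormal floats as zero before benchmarking,
# per the flush-to-zero suggestion (#60). Returns True only if the
# CPU supports flush-to-zero.
supported = torch.set_flush_denormal(True)
print("flush-to-zero enabled:", supported)
```

The dfdx side would need the equivalent CPU control register set separately (e.g. via x86 intrinsics), which this sketch does not cover.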
Closing - might do this in the future, but benchmarking will continue to be ad-hoc for now