
Plot benchmark speed against pytorch #20

Closed
4 tasks
coreylowman opened this issue May 20, 2022 · 5 comments
Labels: documentation (Improvements or additions to documentation)

Comments

coreylowman (Owner) commented May 20, 2022

  • Linear batched forward (matmul & broadcast add)
  • Backprop algorithm
  • Optimizer updates
  • Forward with tape & without tape (see the sketch after this list)
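For the last item, here is a minimal sketch of what a with-tape vs. without-tape timing comparison could look like. It reuses the dfdx 0.x API from the benchmark snippet below (Linear, Tensor2D, .traced()); this is an illustrative assumption, not code from the issue:

use dfdx::prelude::*;
use rand::{prelude::StdRng, SeedableRng};
use std::time::{Duration, Instant};

fn main() {
    let mut rng = StdRng::seed_from_u64(0);
    let l: Linear<512, 256> = Default::default();

    const N: usize = 10000;
    let mut with_tape = Duration::default();
    let mut without_tape = Duration::default();
    for _ in 0..N {
        let x1: Tensor2D<32, 512> = Tensor2D::randn(&mut rng);
        let x2: Tensor2D<32, 512> = Tensor2D::randn(&mut rng);

        // Forward pass that records operations onto a gradient tape.
        let start = Instant::now();
        let _y = l.forward(x1.traced());
        with_tape += start.elapsed();

        // Plain forward pass with no gradient bookkeeping.
        let start = Instant::now();
        let _y = l.forward(x2);
        without_tape += start.elapsed();
    }
    println!("with tape: {:?}, without tape: {:?}", with_tape, without_tape);
}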
coreylowman (Owner, Author) commented:

Rust code for benchmarking:

use dfdx::prelude::*;
use rand::{prelude::StdRng, SeedableRng};
use rand_distr::StandardNormal;
use std::time::{Duration, Instant};

fn main() {
    let mut rng = StdRng::seed_from_u64(0);

    // 512 -> 256 linear layer with randomly initialized parameters.
    let mut l: Linear<512, 256> = Default::default();
    l.randomize(&mut rng, &StandardNormal);

    let mut opt = Adam::default();

    const N: usize = 10000;
    let mut total = Duration::default();
    for _ in 0..N {
        let x: Tensor2D<32, 512> = Tensor2D::randn(&mut rng);
        let y = l.forward(x.traced());
        let loss = y.square().mean();
        // Only the backward pass and the optimizer update are timed;
        // data generation and the forward pass are excluded.
        let start = Instant::now();
        let gradients = loss.backward();
        opt.update(&mut l, gradients);
        total += start.elapsed();
    }
    println!("{:?} batches/s", N as f32 / total.as_secs_f32());
}

coreylowman (Owner, Author) commented:

Python code for benchmarking:

from datetime import datetime, timedelta
import torch

torch.manual_seed(0)

# 512 -> 256 linear layer, matching the Rust benchmark.
l = torch.nn.Linear(512, 256)
opt = torch.optim.Adam(l.parameters())

total = timedelta()
N = 10000
for _ in range(N):
    x = torch.randn(32, 512)
    y = l(x)
    loss = y.square().mean()
    # Time zero_grad, backward, and the optimizer step only.
    start = datetime.now()
    opt.zero_grad()
    loss.backward()
    opt.step()
    total += datetime.now() - start

print(N / total.total_seconds())  # batches/s
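Note that the two timed regions are not quite identical: the Rust snippet times backward() plus the optimizer update, while the Python snippet additionally times opt.zero_grad().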

coreylowman added the documentation label on May 26, 2022
coreylowman mentioned this issue on May 26, 2022
coreylowman (Owner, Author) commented:

Both the dfdx version and the torch version should use flush-to-zero (#60)
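On the PyTorch side this can be done with torch.set_flush_denormal(True). For the Rust side, here is a minimal sketch of enabling flush-to-zero on x86_64 via the MXCSR register; the helper name is hypothetical and this is only an assumption about how #60 could be addressed, not the change that landed:

// Hypothetical helper: enable flush-to-zero on x86_64 before benchmarking.
fn enable_flush_to_zero() {
    #[cfg(target_arch = "x86_64")]
    unsafe {
        use std::arch::x86_64::{_MM_FLUSH_ZERO_ON, _MM_SET_FLUSH_ZERO_MODE};
        // Sets the FTZ bit of the MXCSR register so denormal results flush to zero.
        _MM_SET_FLUSH_ZERO_MODE(_MM_FLUSH_ZERO_ON);
    }
}

fn main() {
    enable_flush_to_zero();
    // ... run the benchmark loop from the snippet above ...
}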

coreylowman added a commit referencing this issue on Jun 28, 2022: "For simplicity of example"
coreylowman mentioned this issue on Oct 19, 2022 (5 tasks)
coreylowman (Owner, Author) commented:

Closing. Might do this in the future, but benchmarking will continue to be ad hoc for now.
