
js-torch

PyTorch in JavaScript

  • JS-Torch is a Deep Learning JavaScript library built from scratch, designed to closely follow PyTorch's syntax.
  • It contains a fully functional Tensor object that can track gradients, a collection of Deep Learning layers and functions, and an Automatic Differentiation engine.
  • Feel free to try out the Web Demo!


1. Project Structure

  • assets/ : Folder to store images and the Demo.
  • src/ : Source code of the framework (TypeScript files).
    • src/tensor.ts: The Tensor class and all of the tensor operations.
    • src/utils.ts: Operations and helper functions.
    • src/layers.ts: Submodule of the framework. Contains the Deep Learning layers.
    • src/optim.ts: Submodule of the framework. Contains the Adam optimizer.
  • tests/: Folder with unit tests. Contains test.ts.
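In practice you don't import these files individually; the build exposes them through the single torch namespace used throughout this README. A minimal sketch of how the structure maps to the public API (assuming the published js-pytorch package, as in the examples below):

import { torch } from "js-pytorch";

// src/tensor.ts -> tensor creation and operations:
const t = torch.randn([2, 3]);

// src/layers.ts -> layers, exposed under torch.nn:
const layer = new torch.nn.Linear(3, 4);

// src/optim.ts -> the Adam optimizer, exposed under torch.optim:
const opt = new torch.optim.Adam(layer.parameters(), 1e-3);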

2. Running it Yourself

Simple Autograd Example:

import { torch } from "js-pytorch";

// Instantiate Tensors:
let x = torch.randn([8, 4, 5]);
let w = torch.randn([8, 5, 4], true /* requires_grad */);
let b = torch.tensor([0.2, 0.5, 0.1, 0.0], true /* requires_grad */);

// Make calculations:
let out = torch.matmul(x, w);
out = torch.add(out, b);

// Compute gradients on whole graph:
out.backward();

// Get gradients from specific Tensors:
console.log(w.grad);
console.log(b.grad);
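Calling out.backward() runs the backward pass from the output through the whole graph, accumulating gradients into every tensor created with requires_grad. As a sanity check, each gradient should have the same shape as its tensor (a minimal sketch, assuming tensors and their gradients expose a shape property):

console.log(w.shape, w.grad.shape); // shapes should match
console.log(b.shape, b.grad.shape); // shapes should match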

Complex Autograd Example (Transformer):

import { torch } from "js-pytorch";
const nn = torch.nn;
const optim = torch.optim;

class Transformer extends nn.Module {
  constructor(vocab_size, hidden_size, n_timesteps, n_heads, p) {
    super();
    // Instantiate Transformer's Layers:
    this.embed = new nn.Embedding(vocab_size, hidden_size);
    this.pos_embed = new nn.PositionalEmbedding(n_timesteps, hidden_size);
    this.b1 = new nn.Block(
      hidden_size,
      hidden_size,
      n_heads,
      n_timesteps,
      p /* dropout_p */
    );
    this.b2 = new nn.Block(
      hidden_size,
      hidden_size,
      n_heads,
      n_timesteps,
      p /* dropout_p */
    );
    this.ln = new nn.LayerNorm(hidden_size);
    this.linear = new nn.Linear(hidden_size, vocab_size);
  }

  forward(x) {
    let z;
    z = torch.add(this.embed.forward(x), this.pos_embed.forward(x));
    z = this.b1.forward(z);
    z = this.b2.forward(z);
    z = this.ln.forward(z);
    z = this.linear.forward(z);
    return z;
  }
}

// Example hyper-parameters (illustrative values):
const vocab_size = 52;
const hidden_size = 32;
const n_timesteps = 16;
const n_heads = 8;
const dropout_p = 0;
const batch_size = 4;

// Instantiate your custom nn.Module:
const model = new Transformer(
  vocab_size,
  hidden_size,
  n_timesteps,
  n_heads,
  dropout_p
);

// Define loss function and optimizer:
const loss_func = new nn.CrossEntropyLoss();
const optimizer = new optim.Adam(model.parameters(), 5e-3 /* lr */, 0 /* reg */);

// Instantiate sample input and output:
let x = torch.randint(0, vocab_size, [batch_size, n_timesteps, 1]);
let y = torch.randint(0, vocab_size, [batch_size, n_timesteps]);
let loss;

// Training Loop:
for (let i = 0; i < 40; i++) {
  // Forward pass through the Transformer:
  let z = model.forward(x);

  // Get loss:
  loss = loss_func.forward(z, y);

  // Backpropagate the loss using torch.tensor's backward() method:
  loss.backward();

  // Update the weights:
  optimizer.step();

  // Reset the gradients to zero after each training step:
  optimizer.zero_grad();
}
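If training works, the loss should fall toward zero over the 40 iterations. A quick way to check is to inspect the final loss after the loop (a minimal sketch, assuming the loss tensor exposes its raw value through a data property):

// After training, the loss should be much smaller than at the start:
console.log(loss.data);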

Note: You can install the package locally with: npm install js-pytorch
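Since the build also produces CJS modules, the equivalent of the import statements above in a CommonJS project would be (the examples in this README use ESM syntax):

const { torch } = require("js-pytorch");
const nn = torch.nn;
const optim = torch.optim;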


3. Distribution & Devtools

  • To build for distribution, run npm run build. CJS and ESM modules and an index.d.ts will be output in the dist/ folder.
  • To check the code with ESLint at any time, run npm run lint.
  • To improve code formatting with Prettier, run npm run prettier.

4. Results

  • The models implemented in the unit tests all converged to near-zero losses.
  • Run them with npm test!
  • This package is not yet as optimized as PyTorch, but I tried to make it more interpretable. Efficiency improvements are incoming!
  • Hope you enjoy!

5. Benchmarks

  • Performance benchmarks are also included and tracked in the tests/benchmarks/ directory.
  • Run all benchmarks with npm run bench.
  • Save new benchmarks with npm run bench:update and add the updated files to your commit.
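For a quick ad-hoc measurement outside the tracked suite, a micro-benchmark can be as simple as timing an operation in a loop (a minimal sketch using only torch.randn and torch.matmul from the examples above; absolute numbers will vary by machine):

import { torch } from "js-pytorch";

// Time a batched matrix multiplication over several repetitions:
const a = torch.randn([8, 64, 64]);
const b = torch.randn([8, 64, 64]);

const reps = 100;
const start = performance.now();
for (let i = 0; i < reps; i++) {
  torch.matmul(a, b);
}
const elapsed = performance.now() - start;
console.log(`matmul: ${(elapsed / reps).toFixed(3)} ms per call`);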
