## Pytorch

In [23]:
import torch
import torch.nn as nn
import torch.optim as optim

In [24]:
class LinearModel(nn.Module):
    def __init__(self, input_dim, output_dim):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim)
        
    def forward(self, x):
        return self.linear(x)

In [25]:
model = LinearModel(1, 1)
optimizer = optim.SGD(model.parameters(), lr=0.01)

def loss_fn(y_pred, y):
    return torch.mean((y_pred - y)**2)

In [26]:
n_epochs = 200
for epoch in range(1, n_epochs + 1):
    inputs = torch.randn(32, 1)
    targets = inputs * 2
    outputs = model(inputs)

    loss = loss_fn(outputs, targets)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    
    rsme  = 1 - torch.sqrt(torch.mean((outputs - targets)**2))

    if epoch % (n_epochs // 10) == 0:
        print(f"[{epoch:3d}/{n_epochs}] loss: {loss.item():.4f} rsme: {rsme.item():.4f}")

[ 20/200] loss: 4.4325 rsme: -1.1053
[ 40/200] loss: 1.5287 rsme: -0.2364
[ 60/200] loss: 0.8847 rsme: 0.0594
[ 80/200] loss: 0.4109 rsme: 0.3590
[100/200] loss: 0.1191 rsme: 0.6549
[120/200] loss: 0.0630 rsme: 0.7491
[140/200] loss: 0.0228 rsme: 0.8489
[160/200] loss: 0.0080 rsme: 0.9105
[180/200] loss: 0.0059 rsme: 0.9230
[200/200] loss: 0.0023 rsme: 0.9526


## Flax.linen

In [1]:
import jax
import jax.numpy as jnp
from flax import nnx
import optax

In [42]:
class LinearModel(nnx.Module):
    def __init__(self, input_dim, output_dim, rngs: nnx.Rngs):
        self.linear = nnx.Linear(input_dim, output_dim, rngs=rngs)
        
    def __call__(self, x):
        return self.linear(x)

In [54]:
key = jax.random.PRNGKey(0)
rngs = nnx.Rngs(key)

model = LinearModel(1, 1, rngs=rngs)
optimizer = nnx.Optimizer(model, optax.sgd(learning_rate=0.01))

def loss_fn(model, x, y):
    y_pred = model(x)
    loss = jnp.mean((y_pred - y)**2)
    return loss, y_pred

@nnx.jit
def train_step(model, optimizer, x, y):
    (loss, outputs), grads = nnx.value_and_grad(loss_fn, has_aux=True)(model, x, y)
    optimizer.update(grads)
    return loss, outputs

In [55]:
n_epochs = 200
for epoch in range(1, n_epochs + 1):
    inputs = jax.random.normal(key, (32, 1))
    targets = inputs * 2
    loss, outputs = train_step(model, optimizer, inputs, targets)
    rsme = 1 - jnp.sqrt(jnp.mean((outputs - targets)**2))
    
    if epoch % (n_epochs // 10) == 0:
        print(f"[{epoch:3d}/{n_epochs}] loss: {loss:.3f} rsme: {rsme:.3f}")

[ 20/200] loss: 0.009 rsme: 0.905
[ 40/200] loss: 0.005 rsme: 0.931
[ 60/200] loss: 0.003 rsme: 0.950
[ 80/200] loss: 0.001 rsme: 0.963
[100/200] loss: 0.001 rsme: 0.973
[120/200] loss: 0.000 rsme: 0.980
[140/200] loss: 0.000 rsme: 0.985
[160/200] loss: 0.000 rsme: 0.989
[180/200] loss: 0.000 rsme: 0.992
[200/200] loss: 0.000 rsme: 0.994


In [24]:
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax

In [25]:
class LinearModel(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Dense(features=1)(x)

In [39]:
key = jax.random.PRNGKey(0)
model = LinearModel()
params = model.init(key, jnp.ones((1, 1)))

optimizer = optax.sgd(learning_rate=0.01)
opt_state = optimizer.init(params)

def loss_fn(params, x, y):
    y_pred = model.apply(params, x)
    loss = jnp.mean((y_pred - y)**2)
    return loss, y_pred

@jax.jit
def train_step(params, opt_state, x, y):
    (loss, outputs), grads = jax.value_and_grad(loss_fn, has_aux=True)(params, x, y)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss, outputs

In [40]:
n_epochs = 200
for epoch in range(1, n_epochs + 1):
    inputs = jax.random.normal(key, (32, 1))
    targets = inputs * 2
    params, opt_state, loss, outputs = train_step(params, opt_state, inputs, targets)
    rsme = 1 - jnp.sqrt(jnp.mean((outputs - targets)**2))
    
    if epoch % (n_epochs // 10) == 0:
        print(f"[{epoch:3d}/{n_epochs}] loss: {loss:.3f} rsme: {rsme:.3f}")

[ 20/200] loss: 1.896 rsme: -0.377
[ 40/200] loss: 0.994 rsme: 0.003
[ 60/200] loss: 0.526 rsme: 0.275
[ 80/200] loss: 0.281 rsme: 0.470
[100/200] loss: 0.151 rsme: 0.611
[120/200] loss: 0.082 rsme: 0.714
[140/200] loss: 0.045 rsme: 0.789
[160/200] loss: 0.024 rsme: 0.844
[180/200] loss: 0.013 rsme: 0.885
[200/200] loss: 0.007 rsme: 0.915
