# In [None]:
# PyTorch version of your example:
# Learn f(x) = 2x for 4D inputs, then test on [1,2,3,4] and [2,3,4,5].

import torch
import torch.nn as nn

# Reproducibility
torch.manual_seed(0)

# Problem definition and hyperparameters, named so they are easy to change.
IN_DIM = 4                  # dimensionality of x (and of f(x))
TARGET_SCALE = 2.0          # target function is f(x) = TARGET_SCALE * x
X_LOW, X_HIGH = -5.0, 5.0   # training inputs are drawn uniformly from [X_LOW, X_HIGH)
STEPS = 2000                # number of SGD updates
LR = 1e-2                   # SGD learning rate
BATCH_SIZE = 32             # samples per SGD step (1 == pure per-sample SGD)
LOG_EVERY = 200             # print the training loss every LOG_EVERY steps

# Model: a single affine map R^d -> R^d. The target is linear, so this model
# can represent it exactly (weight = TARGET_SCALE * I, bias = 0).
net = nn.Linear(IN_DIM, IN_DIM, bias=True)

# Optimizer and loss
optimizer = torch.optim.SGD(net.parameters(), lr=LR)
criterion = nn.MSELoss()

# Training loop
for step in range(STEPS):
    # Draw a fresh minibatch each step. MSELoss averages over all elements, so
    # the expected update matches per-sample SGD, but the gradient (and the
    # logged loss) is far less noisy. The batched matmul costs about the same.
    x = torch.rand(BATCH_SIZE, IN_DIM) * (X_HIGH - X_LOW) + X_LOW  # shape: (B, d)
    y = TARGET_SCALE * x                                            # shape: (B, d)

    # forward
    y_pred = net(x)

    # loss
    loss = criterion(y_pred, y)

    # backward + update
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Log periodically, and always on the last step so the final loss is shown.
    if step % LOG_EVERY == 0 or step == STEPS - 1:
        print(step, loss.item())

# Inference / testing (no gradients)
with torch.no_grad():
    # Fixed 4-D test vectors from the original example (these require IN_DIM == 4).
    test1 = torch.tensor([1.0, 2.0, 3.0, 4.0])
    test2 = torch.tensor([2.0, 3.0, 4.0, 5.0])

    pred1 = net(test1)
    pred2 = net(test2)

    print("test1:", test1.tolist())
    print("pred1:", pred1.tolist())
    print("test2:", test2.tolist())
    print("pred2:", pred2.tolist())

    # Compare against the known target, so convergence is checked rather than eyeballed.
    for name, inp, out in (("test1", test1, pred1), ("test2", test2, pred2)):
        err = (out - TARGET_SCALE * inp).abs().max().item()
        print(f"{name} max abs error:", err)
