### Least squares with PyTorch

In [None]:
import numpy as np

import torch
from torch import nn

In [None]:
class LeastSquares(nn.Module):
    """Least-squares model ``A @ x`` with a fixed matrix ``A`` and a trainable ``x``.

    Both factors are stored as ``nn.Linear`` layers:

    * ``self.A`` holds ``A`` as a frozen weight with a frozen zero bias.
    * ``self.x`` holds the unknown vector in its ``(n, 1)`` weight; its bias
      is frozen at zero.

    Calling the module returns ``A @ x`` as a 1-D tensor of length ``m``.

    Args:
        A: 2-D array-like of shape ``(m, n)``; copied into a float32 tensor.
    """

    def __init__(self, A):
        super().__init__()
        # Own float32 copy, so later edits to the caller's array don't alter the model.
        weight = torch.as_tensor(A, dtype=torch.float32).clone()
        m, n = weight.shape

        # Frozen linear map v -> A @ v (zero bias, no gradient).
        self.A = nn.Linear(n, m)
        self.A.weight = nn.Parameter(weight, requires_grad=False)
        self.A.bias = nn.Parameter(torch.zeros(m), requires_grad=False)

        # The unknown x lives in the (n, 1) weight of a Linear(1, n) layer, so
        # feeding it the constant input [1.0] returns x itself.
        # Start from x = 0 instead of the layer's random init, because:
        # - runs are deterministic;
        # - gradient descent from zero converges to the minimum-norm solution
        #   (the one np.linalg.lstsq returns) when A is rank-deficient.
        self.x = nn.Linear(1, n)
        self.x.weight = nn.Parameter(torch.zeros(n, 1))
        self.x.bias = nn.Parameter(torch.zeros(n), requires_grad=False)

        # Constant input for self.x. As a buffer it follows .to(device) calls;
        # non-persistent so it stays out of state_dict.
        self.register_buffer("_unit", torch.ones(1), persistent=False)

    def forward(self):
        """Return ``A @ x`` as a 1-D tensor of length ``m``."""
        return self.A(self.x(self._unit))
    
def solve(A, y, lr=1e-3, epochs=200):
    """Solve the least-squares problem ``min_x ||A @ x - y||^2`` by gradient descent.

    ``torch.optim.SGD`` is applied to the full residual at every step, with no
    mini-batching, so this is plain full-batch gradient descent.

    Args:
        A: array-like of shape ``(m, n)``.
        y: array-like with ``m`` entries; a column vector ``(m, 1)`` is flattened.
        lr: step size. The loss is the un-normalized sum of squares, so the
            iteration only converges for ``lr < 1 / sigma_max(A)**2``.
        epochs: number of gradient steps.

    Returns:
        float32 NumPy array of shape ``(n,)`` holding the estimated ``x``.
    """
    ls = LeastSquares(A)
    opt = torch.optim.SGD(ls.parameters(), lr=lr)
    # Flatten so a column-vector y cannot broadcast against the (m,) prediction
    # into an (m, m) residual (which made backward() fail on a non-scalar loss).
    y = torch.as_tensor(y, dtype=torch.float32).reshape(-1)
    for _ in range(epochs):
        opt.zero_grad()
        # Tensor.sum() reduces in one op. The builtin sum() looped over the m
        # residuals in Python and built an m-node autograd graph every step.
        loss = ((ls() - y) ** 2).sum()
        loss.backward()
        opt.step()

    return ls.x.weight.detach().cpu().numpy().flatten()

In [None]:
# Seed both NumPy (problem data) and torch (any model-side randomness) so the
# printed error is reproducible from run to run.
rng = np.random.default_rng(0)
torch.manual_seed(0)

m, n = 100, 20
A, y = rng.standard_normal((m, n)), rng.standard_normal(m)
x = solve(A, y)

# Reference solution from NumPy. The remaining gap comes from solve() working
# in float32 and taking a fixed, finite number of gradient steps.
x_ls = np.linalg.lstsq(A, y, rcond=None)[0]
print('err:', np.linalg.norm(x_ls - x))