### Least squares with PyTorch

In [None]:
import numpy as np
from numpy.linalg import norm, lstsq

import torch
from torch import nn

### Formulate a least-squares problem in several different ways

In [None]:
class LeastSquares1(nn.Module):
    """Least-squares model assembled from two ``nn.Linear`` layers.

    ``self.A`` is a linear layer whose weight is the given matrix and whose
    bias is zero; both are frozen (``requires_grad=False``).  ``self.x`` is a
    ``1 -> n`` linear layer with a frozen zero bias, so its ``(n, 1)`` weight
    is the only trainable parameter and plays the role of the unknown vector.
    """

    def __init__(self, A):
        super().__init__()
        rows, cols = A.shape

        # Fixed layer: multiplies its input by A and adds a zero bias.
        fixed = nn.Linear(cols, rows)
        fixed.weight = nn.Parameter(torch.FloatTensor(A), requires_grad=False)
        fixed.bias = nn.Parameter(torch.zeros(rows), requires_grad=False)
        self.A = fixed

        # Trainable layer: fed the constant 1, its output is simply the weight
        # column (the bias is frozen at zero).
        unknown = nn.Linear(1, cols)
        unknown.weight = nn.Parameter(torch.zeros(cols, 1))
        unknown.bias = nn.Parameter(torch.zeros(cols), requires_grad=False)
        self.x = unknown

    def forward(self):
        """Return the output of layer ``A`` applied to the current ``x``."""
        unit_input = torch.ones(1)
        current_x = self.x(unit_input)
        return self.A(current_x)

    def solution(self):
        """Return the current estimate of ``x`` as a flat NumPy array."""
        weights = self.x.weight.data
        return weights.numpy().flatten()

class LeastSquares2(nn.Module):
    """Implement least-squares using nn.Parameter.

    The matrix ``A`` is held as a non-persistent buffer and the unknown
    vector ``x`` as a trainable ``nn.Parameter`` initialised to zeros.

    Args:
        A: 2-D array-like of shape ``(m, n)``; converted to a float32 tensor.
    """

    def __init__(self, A):
        super().__init__()
        _, n = A.shape
        # Register A as a buffer so module-wide conversions (``.to()``,
        # ``.double()``, ...) are applied to it together with ``x``.
        # ``persistent=False`` keeps it out of ``state_dict``, so the saved
        # keys are the same as when A was a plain attribute.
        self.register_buffer("A", torch.FloatTensor(A), persistent=False)
        self.x = nn.Parameter(torch.zeros(n))

    def forward(self):
        """Evaluate the model, i.e. return ``A @ x``."""
        return self.A @ self.x

    def solution(self):
        """Return the current estimate of ``x`` as a flat NumPy array."""
        # detach() drops the autograd graph; cpu() is a no-op for CPU tensors
        # and makes the NumPy conversion valid for other devices too.
        return self.x.detach().cpu().numpy().flatten()
    
class LeastSquares3(nn.Module):
    """Least-squares model that stores plain tensors only.

    Neither ``A`` nor ``x`` is registered with the module: ``A`` is an
    ordinary float tensor and ``x`` is an ordinary tensor created with
    ``requires_grad=True``, so ``parameters()`` yields nothing.
    """

    def __init__(self, A):
        super().__init__()
        _, num_unknowns = A.shape
        self.A = torch.FloatTensor(A)
        # Leaf tensor tracked by autograd, but not an nn.Parameter.
        self.x = torch.zeros(num_unknowns, requires_grad=True)

    def forward(self):
        """Return the matrix-vector product of ``A`` and the current ``x``."""
        return torch.matmul(self.A, self.x)

    def solution(self):
        """Return the current estimate of ``x`` as a flat NumPy array."""
        raw = self.x.data
        return raw.numpy().flatten()

### Solve a least-squares problem.

In [None]:
def solve(model, A, y, lr=1e-3, epochs=200):
    """Solve a least-squares problem `y=A*x` using torch.optim.SGD.

    Args:
        model: Callable (e.g. one of the ``LeastSquares*`` classes) that is
            called as ``model(A)`` and returns an object providing
            ``parameters()``, ``forward()`` and ``solution()``.
        A: Matrix passed to ``model``.
        y: Right-hand side; converted to a float32 tensor.
        lr: SGD learning rate.
        epochs: Number of gradient steps.

    Returns:
        Whatever ``solution()`` of the constructed model returns.
    """
    ls = model(A)
    opt = torch.optim.SGD(ls.parameters(), lr=lr)

    y = torch.FloatTensor(y)
    for _ in range(epochs):
        opt.zero_grad()
        residual = ls.forward() - y
        # Tensor reduction rather than the builtin ``sum``: one fused op
        # instead of a Python-level loop over elements, and always a scalar
        # tensor (also for empty or 2-D residuals) so ``backward()`` is valid.
        loss = residual.pow(2).sum()
        loss.backward()
        opt.step()

    return ls.solution()

def solve_manual(model, A, y, lr=1e-3, epochs=200):
    """Solve a least-squares problem `y=A*x` manually.

    Runs plain gradient descent on ``sum((A x - y)**2)`` by updating
    ``ls.x`` in place, and on every step cross-checks the autograd gradient
    against the closed form ``2 A^T A x - 2 A^T y``.

    Args:
        model: Callable that is called as ``model(A)`` and returns an object
            exposing tensors ``A`` and ``x`` (``x`` tracked by autograd) plus
            ``forward()`` and ``solution()``.
        A: Matrix passed to ``model``.
        y: Right-hand side; converted to a float32 tensor.
        lr: Step size.
        epochs: Number of gradient steps.

    Returns:
        Whatever ``solution()`` of the constructed model returns.

    Raises:
        AssertionError: If the autograd and closed-form gradients disagree.
    """
    ls = model(A)

    y = torch.FloatTensor(y)
    At = ls.A.t()
    h = 2 * At @ y  # constant part of the closed-form gradient
    for _ in range(epochs):
        # Tensor reduction rather than the builtin ``sum`` so the loss is
        # always a scalar tensor and no Python-level loop is involved.
        loss = (ls.forward() - y).pow(2).sum()
        loss.backward()

        # Nothing below needs to be differentiated, so keep it off the graph.
        with torch.no_grad():
            grad = 2 * At @ (ls.A @ ls.x) - h
            err = (grad - ls.x.grad).norm()
            # Tolerance relative to the gradient magnitude: a fixed absolute
            # bound is too strict for float32 once gradients get large.
            tol = 1e-4 * max(1.0, float(grad.norm()))
            if not (err < tol):
                raise AssertionError(
                    "autograd gradient differs from closed-form gradient"
                )

            ls.x -= lr * ls.x.grad
            ls.x.grad.zero_()

    return ls.solution()

### Test

In [None]:
# Random over-determined system: m equations, n unknowns.
m, n = 100, 20
A = np.random.randn(m, n)
y = np.random.randn(m)

# Reference solution from NumPy's least-squares solver.
x_ls = lstsq(A, y, rcond=None)[0]

# Each variant is compared against the reference by Euclidean distance.
x1 = solve(LeastSquares1, A, y)
err1 = norm(x_ls - x1)
print('[x1] err:', err1)

x2 = solve(LeastSquares2, A, y)
err2 = norm(x_ls - x2)
print('[x2] err:', err2)

x3 = solve_manual(LeastSquares3, A, y)
err3 = norm(x_ls - x3)
print('[x3] err:', err3)