In [1]:
import torch
import numpy as np

# assigns weights to data points, currently all 1's just to test other functions
def w(x, c, model):
    """Return one weight per data point (placeholder: every weight is 1).

    Args:
        x: Batch of inputs. Not used by this placeholder.
        c: Batch of targets. Only its length and, when it is a tensor, its
            device are used.
        model: Current model. Not used by this placeholder.

    Returns:
        A 1-D tensor of ones with ``len(c)`` entries. When ``c`` is a tensor
        the weights are created on ``c``'s device; otherwise on the default
        device.
    """
    # Create the weights where the targets live so that multiplying them with
    # the residuals does not fail with a device mismatch off the CPU.
    device = c.device if isinstance(c, torch.Tensor) else None
    return torch.ones(len(c), device=device)

# draws a random minibatch of distinct rows
def get_batch(x, c, batch_size):
    """Select ``batch_size`` distinct random positions and return those rows.

    Args:
        x: Indexable collection of inputs (supports indexing by an integer array).
        c: Indexable collection of targets, aligned with ``x``; its length
            defines the population that positions are drawn from.
        batch_size: Number of positions to draw, without replacement.

    Returns:
        Tuple ``(x[picked], c[picked])`` using the same positions for both.

    Note:
        Sampling uses NumPy's global RNG; drawing more positions than
        ``len(c)`` without replacement is rejected by ``np.random.choice``.
    """
    positions = range(len(c))
    picked = np.random.choice(positions, replace=False, size=batch_size)
    batch_x = x[picked]
    batch_c = c[picked]
    return batch_x, batch_c
    
# performs an iteration of the iteratively reweighted least squares algorithm
def step(x, c, model, optimizer, w, batch_size):
    """Run one optimizer update on a weighted squared-error loss over a random minibatch.

    Args:
        x: Full set of inputs; a minibatch is drawn from it with ``get_batch``.
        c: Full set of targets, indexed the same way as ``x``.
        model: Callable mapping a batch of inputs to predictions; its output
            is flattened before being compared with the targets.
        optimizer: Optimizer exposing ``zero_grad()`` and ``step()``.
        w: Callable ``w(x_batch, c_batch, model)`` returning the per-sample
            weights for the minibatch.
        batch_size: Number of samples drawn for this update.

    Returns:
        The scalar loss ``mean(weights * (c_hat - c) ** 2)`` of the minibatch,
        evaluated before the parameter update is applied.
    """
    optimizer.zero_grad()
    # Separate names keep the full data set and the minibatch distinguishable.
    x_batch, c_batch = get_batch(x, c, batch_size)
    c_hat = model(x_batch).flatten()
    # In iteratively reweighted least squares the weights are evaluated at the
    # current parameters and then held fixed for the update. Computing them
    # under no_grad keeps a model-dependent `w` from contributing its own
    # gradient to the step.
    with torch.no_grad():
        weights = w(x_batch, c_batch, model)
    loss = torch.mean(weights * (c_hat - c_batch) ** 2)
    loss.backward()
    optimizer.step()
    return loss

# gets parameters of model as array, used to check convergence later
def get_params(model):
    """Flatten every parameter of ``model`` into one 1-D float64 NumPy array.

    Values appear in the order yielded by ``model.parameters()``, with each
    tensor flattened in row-major order.

    Args:
        model: Object exposing ``parameters()`` that yields tensors.

    Returns:
        1-D ``np.ndarray`` of dtype float64 holding all parameter values;
        an empty array when the model has no parameters.
    """
    # Flatten each parameter in full. Taking only `param[0]` would keep just
    # the first row of every tensor (dropping most of a multi-output layer)
    # and would raise IndexError on a 0-dim parameter.
    chunks = [
        param.detach().cpu().double().numpy().ravel()
        for param in model.parameters()
    ]
    if not chunks:
        # np.concatenate rejects an empty list.
        return np.array([])
    return np.concatenate(chunks)

# training loop that stops when parameters dont change much, we can explore other convergence criteria
def train(x, c, model, optimizer, w, batch_size, eps, max_iter=None):
    """Call ``step`` repeatedly until the parameters stop moving.

    After every update the flattened parameter vector is compared with the
    one from before the update; training stops once the Euclidean norm of the
    difference drops below ``eps``.

    Args:
        x: Full set of inputs, forwarded to ``step``.
        c: Full set of targets, forwarded to ``step``.
        model: Model being trained.
        optimizer: Optimizer bound to the model's parameters.
        w: Callable producing per-sample weights, forwarded to ``step``.
        batch_size: Minibatch size, forwarded to ``step``.
        eps: Convergence threshold on the norm of the parameter change.
        max_iter: Optional cap on the number of updates. ``None`` (default)
            keeps the original behaviour of looping until convergence.

    Returns:
        List with one detached scalar loss tensor per performed update.
    """
    params = get_params(model)
    losses = []
    n_iter = 0
    # The cap guards against never converging, e.g. when minibatch noise keeps
    # the per-step parameter change above eps.
    while max_iter is None or n_iter < max_iter:
        loss = step(x, c, model, optimizer, w, batch_size)
        n_iter += 1
        # Keep only the value: storing the attached tensor would hold on to
        # autograd history for every iteration of the run.
        losses.append(loss.detach())
        new_params = get_params(model)
        if np.linalg.norm(params - new_params) < eps:
            break
        params = new_params
    return losses

In [2]:
# test training loop on a linear least squares problem with no noise, looks like its working

# Seed both RNGs so a fresh kernel reproduces the run: torch drives the model
# initialisation and the synthetic data, numpy drives the minibatch sampling.
torch.manual_seed(0)
np.random.seed(0)

n = 1000
n_features = 10
batch_size = n  # full batch: every step sees all n samples
eps = 1e-12  # stop once the parameter vector moves less than this in one step
model = torch.nn.Sequential(
      torch.nn.Linear(in_features=n_features, out_features=1)
)
start_params = get_params(model)
optimizer = torch.optim.Adam(model.parameters())
# this is the callable function to calculate weights; its third argument is the model
yo = lambda x, c, model: w(x, c, model)
x = torch.rand(n, n_features)
# Ground truth: the first n_features entries are the slopes, the last one is the intercept.
true_params_t = torch.rand(n_features + 1)
c = torch.matmul(x, true_params_t[:-1]) + true_params_t[-1]
losses = train(x, c, model, optimizer, yo, batch_size, eps)
end_params = get_params(model)
# Separate name for the tensor above so `true_params` is only ever bound to the NumPy array.
true_params = true_params_t.numpy()
print("true params are", true_params)
print("found params are", end_params)
print("relative errror is", np.linalg.norm(end_params-true_params)/np.linalg.norm(true_params))

true params are [0.08979118 0.49159217 0.19931328 0.8223423  0.62896633 0.86027324
 0.45375198 0.84065086 0.44945568 0.2564683  0.33904016]
found params are [0.08979066 0.49159139 0.19931288 0.82234102 0.62896508 0.86027211
 0.45375136 0.8406496  0.44945493 0.25646761 0.33904454]
relative errror is 2.8518981599160026e-06
