In [1]:
import torch

In [2]:
torch.cuda.is_available()

True

In [3]:
ρ = 0.1 # initialize ro
n = 1000 # sample size
p = 20 # dimension

In [4]:
def generate_dataset(n, p, ρ, r, β, duplication):
    X = torch.zeros((duplication, n, p))
    i = torch.arange(p).view(-1, 1)
    j = torch.arange(p).view(1, -1)
    
    Σ = torch.pow(ρ, torch.abs(i - j).float())
    L = torch.linalg.cholesky(Σ)
    Z = torch.randn(duplication, n, p)
    X = Z @ torch.transpose(L, 1, 0)
    
    σ = (1-r)/r * (β.T @ Σ @ β)
    
    ε = torch.randn(n) * torch.sqrt(σ)
    
    Y = X @ β + ε
    
    return X, Y, {
        'Σ': Σ,
        'ε': ε
    }   

In [5]:
X, Y, other = generate_dataset(n=100, p=10, ρ=0, β=torch.ones(10), r=0.8, duplication=1000)

  σ = (1-r)/r * (β.T @ Σ @ β)


In [6]:
from torch import nn

xtx = X.permute(0, 2, 1)@X

In [230]:
xtx.shape

torch.Size([1000, 10, 10])

In [7]:
Y = Y.unsqueeze(-1)

In [255]:
from torch.optim import Adam, SGD
from torch.nn.functional import mse_loss, l1_loss
from tqdm import tqdm

def lasso(λ, epochs=800):
    β = torch.randn((1000, 10, 1), device='cuda', requires_grad=True)
    optimizer = Adam([β], lr=0.005)
    loss_l = torch.zeros(epochs)
    ones = torch.ones((1000, 1), device='cuda')
    # pbar = tqdm(total=epochs, desc='LASSO')
    for epoch in range(epochs):
        optimizer.zero_grad()
        mse = torch.mean(((Y-X@β)**2), dim=1)
        l1norm = torch.sum(β.abs(), dim=1)
        # print(mse.shape, l1norm.shape)
        loss = mse + λ*l1norm
        loss.backward(ones)
        loss_l[epoch] = loss[-1]
        # pbar.set_postfix_str(f'loss: {loss}')
        # pbar.update()
        optimizer.step()
    rss = torch.sum(((Y-X@β)**2), dim=1).squeeze(-1)
    k = (β.abs() > 1e-5).sum()
    aic = 100 * torch.log(rss*0.01) + torch.full(size=(1000, ), fill_value=2*k, device='cuda')
    bic = 100 * torch.log(rss*0.01) + torch.full(size=(1000, ), fill_value=torch.log(torch.tensor(100))*k, device='cuda')
    
    return β, loss_l, aic, bic, rss

In [259]:
def ridge(λ):
    # closed form solution
    
    global X, xtx, Y
    if xtx.device != 'cuda':
        xtx = xtx.to('cuda')
    if Y.device != 'cuda':
        Y = Y.to('cuda')
    if X.device != 'cuda':
        X = X.to('cuda')
    β = torch.linalg.solve(xtx + λ * torch.eye(xtx.shape[-1], device='cuda'), X.permute(0, 2, 1)@Y)
    
    # report AIC, BIC
    rss = torch.sum(((Y-X@β)**2), dim=1).squeeze(-1)
    k = (β.abs() > 1e-5).sum()
    aic = 100 * torch.log(rss*0.01) + torch.full(size=(1000, ), fill_value=2*k, device='cuda')
    bic = 100 * torch.log(rss*0.01) + torch.full(size=(1000, ), fill_value=torch.log(torch.tensor(100))*k, device='cuda')
    
    return β, aic, bic
    

In [264]:
pbar = tqdm(total=10, desc='λ')
for λ in torch.linspace(0, 0.1, 10):
    ridge_beta, ridge_aic, ridge_bic = ridge(λ)
    lasso_beta, loss, lasso_aic, lasso_bic, rss = lasso(λ)
    pbar.set_postfix_str(f'{λ=}')
    pbar.update()
    
    

λ:   0%|          | 0/10 [00:27<?, ?it/s]
λ: 100%|██████████| 10/10 [00:15<00:00,  1.51s/it, λ=tensor(0.1000)]