In [1]:
import torch
import torchvision
import numpy as np
import matplotlib.pyplot as plt
%matplotlib auto

Using matplotlib backend: Qt5Agg


In [36]:
# Synthetic regression task with far more features than training rows,
# which makes the unregularised model easy to over-fit.
n_train = 20
n_test = 100
num_inputs = 200
true_w = 0.01 * torch.ones(num_inputs, 1)
true_b = 0.05

# Sample every row at once, build the linear targets and add Gaussian noise
# (standard deviation 0.01) on top.
features = torch.randn((n_train + n_test, num_inputs))
labels = features.mm(true_w) + true_b
labels += torch.Tensor(np.random.normal(0, 0.01, size=labels.size()))

# The first n_train rows are the training split, the remainder the test split.
train_features, test_features = torch.split(features, [n_train, n_test])
train_labels, test_labels = torch.split(labels, [n_train, n_test])


In [37]:
def init_params():
    """Create a fresh pair of trainable parameters for the linear model.

    Returns:
        list: ``[w, b]`` where ``w`` is a ``(num_inputs, 1)`` tensor drawn
        from a standard normal distribution and ``b`` is a one-element zero
        tensor. Both have ``requires_grad`` enabled.
    """
    weight = torch.randn(num_inputs, 1).requires_grad_()
    bias = torch.zeros(1).requires_grad_()
    return [weight, bias]

In [38]:
def ls_penalty(w):
    """Return half the sum of the squared entries of ``w`` (an L2 penalty)."""
    sum_of_squares = torch.sum(w * w)
    return sum_of_squares / 2

In [39]:
def net(x,w,b):
    """Linear model: matrix-multiply ``x`` by ``w`` and add the bias ``b``.

    ``torch.mm`` is used, so ``x`` and ``w`` must both be 2-D matrices.
    """
    projection = x.mm(w)
    return projection + b

In [40]:
def squared_loss(y_hat,y):
    """Element-wise halved squared error between predictions and targets.

    ``y`` is viewed with the shape of ``y_hat`` before subtracting, so a flat
    target vector can be compared against column-shaped predictions. The
    result keeps the shape of ``y_hat``; no reduction is applied.
    """
    residual = y_hat - y.view(y_hat.size())
    return residual.pow(2) / 2

In [58]:
def sgd(params,lr,batch_size):
    """Apply one in-place mini-batch SGD step to every tensor in ``params``.

    Each parameter is updated as ``p <- p - lr * p.grad / batch_size``.

    Args:
        params: Iterable of tensors whose ``.grad`` has been populated by a
            preceding ``backward()`` call.
        lr: Learning rate.
        batch_size: Number of samples the accumulated gradient was summed
            over; the gradient is divided by it to obtain the batch average.

    Parameters whose ``.grad`` is ``None`` (they did not take part in the
    last backward pass) are left untouched instead of raising.
    """
    # no_grad() lets us modify leaf tensors in place without recording the
    # update in the autograd graph, and without going through ``.data``.
    with torch.no_grad():
        for p in params:
            if p.grad is None:
                continue
            p -= lr * p.grad / batch_size

In [59]:
# Training hyper-parameters shared by both experiments in this notebook.
batch_size = 1
num_epochs = 100
lr = 0.003

# Shuffled mini-batch iterator over the training split.
dataset = torch.utils.data.TensorDataset(train_features, train_labels)
train_iter = torch.utils.data.DataLoader(
    dataset,
    batch_size=batch_size,
    shuffle=True,
)

In [60]:
def fit_and_plot(ld):
    """Train the from-scratch linear model with an L2 penalty and plot its loss curves.

    Args:
        ld: Coefficient that multiplies ``ls_penalty(w)`` in the training
            objective; ``0`` disables the penalty.

    Side effects:
        Prints the L2 norm of ``w`` after every epoch and draws the mean
        train and test loss per epoch on a log-scaled y axis. Returns ``None``.
    """
    w, b = init_params()
    loss = squared_loss
    train_ls, test_ls = [], []
    for _ in range(num_epochs):
        for x, y in train_iter:
            # The scalar penalty broadcasts onto every per-sample loss before
            # the sum; sgd() divides the gradient by batch_size afterwards.
            l = (loss(net(x, w, b), y) + ld * ls_penalty(w)).sum()

            # backward() accumulates into .grad, so clear the previous step's
            # gradients first (they are None before the very first step).
            for param in (w, b):
                if param.grad is not None:
                    param.grad.zero_()
            l.backward()
            sgd([w, b], lr, batch_size)

        # Evaluation only: no autograd graph is needed for the epoch metrics.
        with torch.no_grad():
            train_ls.append(loss(net(train_features, w, b), train_labels).mean().item())
            test_ls.append(loss(net(test_features, w, b), test_labels).mean().item())
        print('L2 norm of w:', w.norm().item())

    # Plot the collected curves; previously they were computed and discarded.
    epochs = range(1, len(train_ls) + 1)
    fig, ax = plt.subplots()
    ax.semilogy(epochs, train_ls, label='train')
    ax.semilogy(epochs, test_ls, linestyle=':', label='test')
    ax.set_xlabel('epochs')
    ax.set_ylabel('loss')
    ax.legend()
            

In [63]:
# Run the from-scratch training with penalty coefficient ld=3; fit_and_plot
# prints one 'L2 norm of w' line per epoch.
fit_and_plot(ld=3)

L2 norm of w: 9.609057426452637
L2 norm of w: 7.93615198135376
L2 norm of w: 6.6049723625183105
L2 norm of w: 5.507937431335449
L2 norm of w: 4.595791339874268
L2 norm of w: 3.8354694843292236
L2 norm of w: 3.201160192489624
L2 norm of w: 2.6718990802764893
L2 norm of w: 2.2302770614624023
L2 norm of w: 1.8617419004440308
L2 norm of w: 1.5542516708374023
L2 norm of w: 1.2976802587509155
L2 norm of w: 1.083652138710022
L2 norm of w: 0.9050969481468201
L2 norm of w: 0.7562064528465271
L2 norm of w: 0.6321134567260742
L2 norm of w: 0.5287802815437317
L2 norm of w: 0.44256022572517395
L2 norm of w: 0.3709110617637634
L2 norm of w: 0.31149089336395264
L2 norm of w: 0.26218631863594055
L2 norm of w: 0.22125643491744995
L2 norm of w: 0.18763087689876556
L2 norm of w: 0.15962284803390503
L2 norm of w: 0.13725747168064117
L2 norm of w: 0.11921142786741257
L2 norm of w: 0.10504072904586792
L2 norm of w: 0.09310078620910645
L2 norm of w: 0.08308923989534378
L2 norm of w: 0.07590818405151367
L2 no

In [64]:
from torch import nn,optim

In [72]:
def fit_and_plot_pytorch(wd):
    """Train an ``nn.Linear`` model with SGD weight decay and plot its loss curves.

    Args:
        wd: Weight-decay coefficient applied to the weight matrix only; the
            bias is optimised without decay. ``0`` disables regularisation.

    Side effects:
        Prints the L2 norm of the weight after every epoch and draws the mean
        train and test loss per epoch on a log-scaled y axis. Returns ``None``.
    """
    net = nn.Linear(num_inputs, 1)
    nn.init.normal_(net.weight, mean=0, std=1)
    nn.init.normal_(net.bias, mean=0, std=1)
    # One optimizer with two parameter groups: only the weight group carries
    # weight_decay, the bias group falls back to the default of 0.
    optimizer = torch.optim.SGD(
        [
            {'params': [net.weight], 'weight_decay': wd},
            {'params': [net.bias]},
        ],
        lr=lr,
    )
    loss = squared_loss

    train_ls, test_ls = [], []
    for _ in range(num_epochs):
        for x, y in train_iter:
            l = loss(net(x), y).mean()

            optimizer.zero_grad()
            l.backward()
            optimizer.step()

        # Evaluation only: no autograd graph is needed for the epoch metrics.
        with torch.no_grad():
            train_ls.append(loss(net(train_features), train_labels).mean().item())
            test_ls.append(loss(net(test_features), test_labels).mean().item())
        print('L2 norm of w:', net.weight.data.norm().item())

    # Plot the collected curves; previously they were computed and discarded.
    epochs = range(1, len(train_ls) + 1)
    fig, ax = plt.subplots()
    ax.semilogy(epochs, train_ls, label='train')
    ax.semilogy(epochs, test_ls, linestyle=':', label='test')
    ax.set_xlabel('epochs')
    ax.set_ylabel('loss')
    ax.legend()
        
            
            
            
            

In [73]:
# wd=0 is passed through as SGD's weight_decay for net.weight, so this run
# trains without any weight decay.
fit_and_plot_pytorch(0)

L2 norm of w: 14.151895523071289
L2 norm of w: 14.063544273376465
L2 norm of w: 14.042744636535645
L2 norm of w: 14.037006378173828
L2 norm of w: 14.035221099853516
L2 norm of w: 14.034656524658203
L2 norm of w: 14.034491539001465
L2 norm of w: 14.034446716308594
L2 norm of w: 14.034435272216797
L2 norm of w: 14.034436225891113
L2 norm of w: 14.034438133239746
L2 norm of w: 14.034439086914062
L2 norm of w: 14.034440040588379
L2 norm of w: 14.034440040588379
L2 norm of w: 14.034440994262695
L2 norm of w: 14.034440994262695
L2 norm of w: 14.034440994262695
L2 norm of w: 14.034440994262695
L2 norm of w: 14.034440994262695
L2 norm of w: 14.034440994262695
L2 norm of w: 14.034440994262695
L2 norm of w: 14.034440994262695
L2 norm of w: 14.034440994262695
L2 norm of w: 14.034440994262695
L2 norm of w: 14.034440994262695
L2 norm of w: 14.034440994262695
L2 norm of w: 14.034440994262695
L2 norm of w: 14.034440994262695
L2 norm of w: 14.034440994262695
L2 norm of w: 14.034440994262695
L2 norm of