# optim explain

torch.optim uses parameters() in order to update parameters in the network

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

## 0.loss


One common misbelief is that backward() is done by nn.Module layers. However, .backward() operations are actually defined on autograd tensors. If we do backward() on nn.Parameters(), we actually is still doing backward on tensors.


Since loss is also a subclass from nn.Module, we can self-built one loss in order to do backward operations on tensors.


We have one computation graph and this computation graph has one output. We want to update and improve the parameters in the model, which is model.paramters(). 

1st step : we need to clear the grad of model.paramerters() = 0 or None (in order to save memory)

2nd step : we need to do autograd backward operations on tensors

3rd step : we need to update data based on optimizer and their gradient

In [2]:
torch.random.seeds = 100
class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.params = nn.ParameterList([nn.Parameter(torch.randn(10, 10)) for i in range(2)])
        self.relu = nn.ReLU()
        self.linear = nn.Linear(10,1, bias=False)
        
    def forward(self, x):
        # ParameterList can act as an iterable, or be indexed using ints
        for i, p in enumerate(self.params):
            x = self.params[i].mm(x)
        x = self.linear(x)
        #print(self.linear.weight)
        #print(self.linear.weight.grad)
        return x

# the first way to define MyLoss, must use loss.forword(input, target)
class MyLoss():
    def __init__(self):
        super(MyLoss, self).__init__()
        pass
    
    def forward(self, input, target):
        loss = torch.mean((input - target)**2)
        loss.requires_grad_(True)
        return loss
        
# the second way to define MyLoss, must use loss.forword(input, target)
class MyLoss2():
    def __init__(self):
        super(MyLoss2, self).__init__()
        pass
    def forward(self, input, target):
        loss = torch.mean((input - target)**2)
        loss.requires_grad_(True)
        return loss
    
model = MyModule()
loss = MyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

for epoch in range(10):
    for i in range(100):
        input  = torch.ones((10,10))*(i)
        target = torch.ones((10,1))*(i)
        output = model(input)
        myloss = loss.forward(output, target)
        print("myloss : {}".format(myloss))
        optimizer.zero_grad() # grad = 0
        myloss.backward()     # grad = grad + updated_grad (updated_grad is decided by tensor.data, tensor.grad_fn)
        optimizer.step()      # data = opt(data, grad)

print("=====para=====")
for para in model.parameters():
    print(para.shape)

myloss : 0.0
myloss : 5.621788024902344
myloss : 13.37084674835205
myloss : 16.056596755981445
myloss : 13.52783203125
myloss : 11.072798728942871
myloss : 15.6159029006958
myloss : 25.566362380981445
myloss : 29.602481842041016
myloss : 25.083662033081055
myloss : 18.73093605041504
myloss : 24.32271957397461
myloss : 28.89194679260254
myloss : 19.19350814819336
myloss : 11.239212989807129
myloss : 14.178762435913086
myloss : 14.452031135559082
myloss : 8.23505973815918
myloss : 8.467677116394043
myloss : 10.004246711730957
myloss : 3.4879825115203857
myloss : 4.503981113433838
myloss : 8.41679573059082
myloss : 6.1118693351745605
myloss : 9.45297908782959
myloss : 7.2576470375061035
myloss : 6.498615264892578
myloss : 7.841006278991699
myloss : 5.343003273010254
myloss : 6.115416526794434
myloss : 2.21388578414917
myloss : 3.4894416332244873
myloss : 2.847900390625
myloss : 3.5140254497528076
myloss : 4.292602062225342
myloss : 2.177769422531128
myloss : 4.655512809753418
myloss : 3.8

myloss : 0.0003607015241868794
myloss : 0.002261731307953596
myloss : 0.004982230719178915
myloss : 0.007009559310972691
myloss : 0.007043556310236454
myloss : 0.004904492758214474
myloss : 0.0019105372484773397
myloss : 0.00012931835954077542
myloss : 0.0008078068494796753
myloss : 0.003141882596537471
myloss : 0.004765854682773352
myloss : 0.003950241021811962
myloss : 0.0015121221076697111
myloss : 9.062158642336726e-05
myloss : 0.0011204227339476347
myloss : 0.002969837747514248
myloss : 0.0029755127616226673
myloss : 0.0011191427474841475
myloss : 0.00010778941214084625
myloss : 0.001246505999006331
myloss : 0.0023291627876460552
myloss : 0.0013867386151105165
myloss : 0.00014555016241502017
myloss : 0.0007553307805210352
myloss : 0.001685968367382884
myloss : 0.0009215561440214515
myloss : 0.00011323792568873614
myloss : 0.0008505579316988587
myloss : 0.0011717334855347872
myloss : 0.0002890427422244102
myloss : 0.0003307477163616568
myloss : 0.0009412263752892613
myloss : 0.0004

myloss : 71.74571228027344
myloss : 63.281280517578125
myloss : 19.715560913085938
myloss : 18.052204132080078
myloss : 48.19956970214844
myloss : 42.877098083496094
myloss : 28.05245590209961
myloss : 46.2333984375
myloss : 36.725120544433594
myloss : 14.6099853515625
myloss : 27.695262908935547
myloss : 22.482837677001953
myloss : 14.499183654785156
myloss : 28.332576751708984
myloss : 12.987123489379883
myloss : 12.441163063049316
myloss : 15.960851669311523
myloss : 8.43346118927002
myloss : 20.51154327392578
myloss : 10.014842987060547
myloss : 13.660514831542969
myloss : 5.930948257446289
myloss : 6.5006256103515625
myloss : 7.793520927429199
myloss : 8.854846954345703
myloss : 10.520951271057129
myloss : 6.824660301208496
myloss : 4.795111656188965
myloss : 2.863884687423706
myloss : 3.41192889213562
myloss : 7.1213483810424805
myloss : 4.520137310028076
myloss : 6.609102725982666
myloss : 0.9606568217277527
myloss : 3.281348466873169
myloss : 3.0185065269470215
myloss : 2.44253

myloss : 2.7840142138302326e-06
myloss : 2.456011134199798e-05
myloss : 6.489304269052809e-06
myloss : 1.1570047718123533e-05
myloss : 1.8510030713514425e-05
myloss : 1.561746444167511e-06
myloss : 1.4064152310311329e-05
myloss : 1.4145474779070355e-05
myloss : 1.5115889482331113e-06
myloss : 9.553064955980517e-06
myloss : 1.569796586409211e-05
myloss : 5.0281582844036166e-06
myloss : 8.609553105998202e-07
myloss : 1.0710122296586633e-05
myloss : 1.4806905710429419e-05
myloss : 9.86928353086114e-06
myloss : 2.909015165641904e-06
myloss : 1.2229429557919502e-07
myloss : 3.5653997656481806e-06
myloss : 1.0247994396195281e-05
myloss : 2.3528922611149028e-05
myloss : 5.0696391554083675e-05
myloss : 0.00012364305439405143
myloss : 0.0003517211298458278
myloss : 0.0011785590322688222
myloss : 0.004564945586025715
myloss : 0.02010202407836914
myloss : 0.09937895089387894
myloss : 0.5410152673721313
=====para=====
torch.Size([10, 10])
torch.Size([10, 10])
torch.Size([1, 10])


## 1. optimizer

In the source code, different optimizers are inherited from optimzier. 

In the base class optimizer, we have functions including __init__ / state_dict / load_state_dict / zero_grad / add_param_group (are used for modifying param_group).

In the concrete optimizer, we define step function in order to define how to update parameters in the network.

In the step function, we use adam function in /functional/adam in order to do data update


nn.parameters are stored in the params in the step function.

**NOTICE**

When using optimizer.step(), we need to use the decorator **@torch.no_grad()** since we need to guarantee that the modification of data do not cause extra modification of **grad** and **grad_fn**

**NOTICE**

param_groups in optimizer are dict for param_groups

Here is the init for base class Optimizer.


    def __init__(self, params, defaults):

        torch._C._log_api_usage_once("python.optimizer")
        
        self.defaults = defaults
        
        self._hook_for_profile()

        if isinstance(params, torch.Tensor):
            raise TypeError("params argument given to the optimizer should be "
                            "an iterable of Tensors or dicts, but got " +
                            torch.typename(params))

        self.state = defaultdict(dict)
        self.param_groups = []

        param_groups = list(params)
        if len(param_groups) == 0:
            raise ValueError("optimizer got an empty parameter list")
        if not isinstance(param_groups[0], dict):
            param_groups = [{'params': param_groups}]

        for param_group in param_groups:
            self.add_param_group(param_group)
            

In [3]:
torch.random.seeds = 100
class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.params = nn.ParameterList([nn.Parameter(torch.randn(10, 10)) for i in range(2)])
        self.linear = nn.Linear(10,1, bias=False)
        
    def forward(self, x):
        # ParameterList can act as an iterable, or be indexed using ints
        for i, p in enumerate(self.params):
            x = self.params[i].mm(x)
        x = self.linear(x)
        return x

model = MyModule()
loss = nn.MSELoss()
optimizer = optim.Adam(model.parameters())

for epoch in range(10):
    for i in range(100):
        input  = torch.ones((10,10))*(i)
        target = torch.ones((10,1))*(i)
        output = model(input)
        myloss = loss.forward(output, target)
        optimizer.zero_grad()        # grad = 0
        myloss.backward()            # grad = grad + updated_grad (updated_grad is decided by tensor.data, tensor.grad_fn)
        optimizer.step()             # data = opt(data, grad)

# the first elememt in this param_groups[0] is params of the input (model.parameters())
# the following elements in this param_groups[0] is 'lr', 'betas', 'decay' ...
print(optimizer.param_groups[0])

{'params': [Parameter containing:
tensor([[-1.0673e+00, -1.7266e+00,  1.1235e-04, -1.0110e+00, -9.3613e-02,
          6.4645e-01, -2.6980e+00,  1.2736e-01,  4.3433e-03, -4.3700e-01],
        [ 4.1189e-01, -3.0783e-01,  2.8894e-01,  2.7649e-03, -4.9187e-01,
         -5.7965e-01,  1.3223e+00,  1.2666e+00,  7.4824e-01, -5.6755e-02],
        [-1.0052e+00, -4.9082e-01, -1.9141e+00,  1.3645e+00,  1.4591e+00,
         -1.1250e+00, -2.6680e-01,  1.7784e+00, -9.7213e-01,  2.2495e-01],
        [-1.0383e+00, -5.7842e-01,  9.8913e-01,  2.6514e-01, -1.5930e+00,
          2.7621e-01,  8.1757e-01,  1.0836e+00, -1.4204e-01,  1.2197e+00],
        [ 7.8965e-02,  6.8238e-02,  1.0282e+00,  2.2101e+00,  1.3505e-01,
          4.1777e-01,  9.7451e-01, -6.6441e-01, -8.3558e-01, -4.3806e-01],
        [ 3.8864e-01,  3.4815e-01,  1.3808e+00, -4.5133e-01,  5.4479e-01,
         -4.4370e-01, -1.4232e+00,  1.4735e+00, -1.0317e+00,  2.4796e-01],
        [ 1.6025e-01,  2.0164e+00, -7.2563e-02,  1.3049e+00,  3.2508e-01

**NOTICE**

step(closure), when doing optimizer.step(), if we do have closure param, we need to do get the loss by doing function of the closure(). Since some algorithm requires us to calculate the model for many times like the CG or LBFGS algo which means that **these algo woulde have inside coding that requires repeatly using closure() function, so we need to provide them with this function**.

It means that we need to first rerun the whole model again to get the loss and continue to do the following updates.

    loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()


In [4]:
# using LBGFS that requires closure()

torch.random.seeds = 100
class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.params = nn.ParameterList([nn.Parameter(torch.randn(10, 10)) for i in range(2)])
        self.linear = nn.Linear(10,1, bias=False)
        
    def forward(self, x):
        # ParameterList can act as an iterable, or be indexed using ints
        for i, p in enumerate(self.params):
            x = self.params[i].mm(x)
        x = self.linear(x)
        return x

model = MyModule()
loss = nn.MSELoss()
optimizer = optim.LBFGS(model.parameters())

for epoch in range(10):
    for i in range(100):
        def closure():
            optimizer.zero_grad()
            output = model(input)
            myloss = loss.forward(output, target)
            myloss.backward()
            return myloss
        input  = torch.ones((10,10))*(i)
        target = torch.ones((10,1))*(i)
        output = model(input)
        myloss = loss.forward(output, target)
        print("myloss : {}".format(myloss))
        optimizer.zero_grad()        # grad = 0
        myloss.backward()            # grad = grad + updated_grad (updated_grad is decided by tensor.data, tensor.grad_fn)
        optimizer.step(closure)      # data = opt(data, grad)

myloss : 0.0
myloss : 18.4127140045166
myloss : 0.7174724340438843
myloss : 0.0007368209771811962
myloss : 9.619463980925502e-07
myloss : 1.2819375569961267e-06
myloss : 1.8452686845193966e-06
myloss : 2.8558133635669947e-10
myloss : 2.495653406153764e-10
myloss : 4.05634636990726e-10
myloss : 4.3146428096996203e-10
myloss : 6.057234713807702e-10
myloss : 7.137714308491638e-10
myloss : 8.803908713161945e-10
myloss : 1.0884833034907615e-09
myloss : 9.327777439338547e-10
myloss : 1.0200892353040558e-09
myloss : 2.5946065296267307e-09
myloss : 1.3489624972606862e-09
myloss : 1.4479155652225018e-09
myloss : 1.6472767594422066e-09
myloss : 1.807347826954242e-09
myloss : 2.2293533685768807e-09
myloss : 2.0867445549299646e-09
myloss : 2.5669577574660707e-09
myloss : 3.0209776902268004e-09
myloss : 4.0978194171259474e-09
myloss : 2.7532223167980874e-09
myloss : 3.568129569586631e-09
myloss : 3.655441060956832e-09
myloss : 3.236345857970946e-09
myloss : 3.9639416193892885e-09
myloss : 4.2084140

myloss : 4.4237821739123717e-10
myloss : 3.2829121643374037e-09
myloss : 3.4691765016248155e-09
myloss : 5.587935669737476e-10
myloss : 3.958121053138086e-10
myloss : 2.793967834868738e-10
myloss : 1.0477378964424133e-09
myloss : 1.2572854979353565e-09
myloss : 5.587935669737476e-10
myloss : 1.2572854979353565e-09
myloss : 1.5832484212552345e-09
myloss : 9.08039532454552e-10
myloss : 1.5832484212552345e-09
myloss : 8.149072527885437e-10
myloss : 0.0
myloss : 3.979039252493925e-14
myloss : 1.59161570099757e-13
myloss : 1.2278178690774966e-12
myloss : 6.36646280399028e-13
myloss : 3.910827391095939e-12
myloss : 4.911271476309986e-12
myloss : 1.5643309564383756e-11
myloss : 2.546585121596112e-12
myloss : 6.912159646738081e-12
myloss : 1.5643309564383756e-11
myloss : 6.039044958550122e-11
myloss : 1.9645085905239945e-11
myloss : 1.8917489102987517e-11
myloss : 6.257323825753502e-11
myloss : 1.6007107098148232e-11
myloss : 1.0186340486384449e-11
myloss : 7.130438600677635e-11
myloss : 2.764

myloss : 2.0954757373736754e-10
myloss : 1.1641532182693481e-10
myloss : 3.0267982564780027e-10
myloss : 1.8626451769865326e-10
myloss : 1.6298144778215118e-10
myloss : 1.1641532182693481e-10
myloss : 2.1187589460680556e-09
myloss : 9.313225884932663e-11
myloss : 5.820766091346741e-10
myloss : 9.778886589373315e-10
myloss : 2.3748725208605492e-09
myloss : 2.0023436242411208e-09
myloss : 3.0267982564780027e-10
myloss : 7.916242106276172e-10
myloss : 1.2572854979353565e-09
myloss : 1.3737008197622913e-09
myloss : 1.1408701761084217e-09
myloss : 2.2584571990336144e-09
myloss : 9.778886589373315e-10
myloss : 4.190951474747351e-10
myloss : 4.4237821739123717e-10
myloss : 7.450580707946131e-10
myloss : 1.1641532182693481e-10
myloss : 2.2584571990336144e-09
myloss : 2.3283064365386963e-10
myloss : 5.122274271407434e-10
myloss : 6.519257911286047e-10
myloss : 6.752088888006824e-10
myloss : 2.3050232833554674e-09
myloss : 2.514570995870713e-09
myloss : 6.286427489676782e-10
myloss : 4.423782173

myloss : 3.2014214196296464e-11
myloss : 1.4697434935762033e-10
myloss : 1.2514647651507005e-10
myloss : 4.511093582015846e-11
myloss : 5.384208703884674e-11
myloss : 5.2386893434341886e-11
myloss : 2.0954757373736754e-10
myloss : 3.899913170180014e-10
myloss : 1.1059455434780929e-10
myloss : 1.1641532182693481e-10
myloss : 2.3283064712331658e-11
myloss : 1.1641532356165829e-11
myloss : 2.27009883113638e-10
myloss : 1.7462298274040222e-10
myloss : 5.878973974304813e-10
myloss : 1.4551915228366852e-10
myloss : 1.3387761732541748e-10
myloss : 1.0710209386033398e-09
myloss : 5.2386893434341886e-11
myloss : 2.0954757373736754e-10
myloss : 3.7834979593753815e-10
myloss : 1.2805685678518586e-10
myloss : 6.111804395914078e-10
myloss : 5.878973974304813e-10
myloss : 1.5716068724191956e-10
myloss : 5.005859060602802e-10
myloss : 6.984919587171845e-11
myloss : 1.8044374328063384e-10
myloss : 1.3387761732541748e-10
myloss : 2.1536834815538697e-10
myloss : 1.0011718121205604e-09
myloss : 2.0954757

## 2. lr_scheduler

provide several methods to adjust the learning rate based on the number of epoches

We should pay special attention that before PyTorch1.1.0, scheduler.step() should be put before optimizer.step(). 

However, after PyTorch1.1.0, scheduler.step() woudl be put after optimizer.step() [after data has been updated, we can adjust our learning rate]

In [5]:
torch.random.seeds = 100
class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.params = nn.ParameterList([nn.Parameter(torch.randn(10, 10)) for i in range(2)])
        self.relu = nn.ReLU()
        self.linear = nn.Linear(10,1, bias=False)
        
    def forward(self, x):
        # ParameterList can act as an iterable, or be indexed using ints
        for i, p in enumerate(self.params):
            x = self.params[i].mm(x)
        x = self.linear(x)
        #print(self.linear.weight)
        #print(self.linear.weight.grad)
        return x

# the first way to define MyLoss, must use loss.forword(input, target)
class MyLoss():
    def __init__(self):
        super(MyLoss, self).__init__()
        pass
    
    def forward(self, input, target):
        loss = torch.mean((input - target)**2)
        loss.requires_grad_(True)
        return loss
        
# the second way to define MyLoss, must use loss.forword(input, target)
class MyLoss2():
    def __init__(self):
        super(MyLoss2, self).__init__()
        pass
    def forward(self, input, target):
        loss = torch.mean((input - target)**2)
        loss.requires_grad_(True)
        return loss

In [6]:
# WITHOUT LR_scheduler

model = MyModule()
loss = MyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

for epoch in range(10):
    for i in range(100):
        input  = torch.ones((10,10))*(i)
        target = torch.ones((10,1))*(i)
        output = model(input)
        myloss = loss.forward(output, target)
        #print("myloss : {}".format(myloss))
        optimizer.zero_grad() # grad = 0
        myloss.backward()     # grad = grad + updated_grad (updated_grad is decided by tensor.data, tensor.grad_fn)
        optimizer.step()      # data = opt(data, grad)
    print(optimizer.param_groups[0]['lr'])

print("=====para=====")
for para in model.parameters():
    print(para.shape)

0.01
0.01
0.01
0.01
0.01
0.01
0.01
0.01
0.01
0.01
=====para=====
torch.Size([10, 10])
torch.Size([10, 10])
torch.Size([1, 10])


In [7]:
# WITH LR_scheduler

model = MyModule()
loss = MyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)
lambda1 = lambda epoch: epoch // 30
lambda2 = lambda epoch: 0.95 ** epoch
# NOTICE: we can define a number of lambda functions
# BUT we only need to use one lambda functions when doing lr adjustment
# normally lambda has input of epoch in order to adjust lr based on epochs
scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda2)
for epoch in range(10):
    for i in range(100):
        input  = torch.ones((10,10))*(i)
        target = torch.ones((10,1))*(i)
        output = model(input)
        myloss = loss.forward(output, target)
        #print("myloss : {}".format(myloss))
        optimizer.zero_grad() # grad = 0
        myloss.backward()     # grad = grad + updated_grad (updated_grad is decided by tensor.data, tensor.grad_fn)
        optimizer.step()      # data = opt(data, grad)
    print(optimizer.param_groups[0]['lr'])
    # should be added after optimizer.step()
    scheduler.step()
    print(scheduler.get_last_lr()[0])

print("=====para=====")
for para in model.parameters():
    print(para.shape)

0.01
0.0095
0.0095
0.009025
0.009025
0.00857375
0.00857375
0.0081450625
0.0081450625
0.007737809374999998
0.007737809374999998
0.007350918906249998
0.007350918906249998
0.006983372960937498
0.006983372960937498
0.006634204312890623
0.006634204312890623
0.006302494097246091
0.006302494097246091
0.005987369392383787
=====para=====
torch.Size([10, 10])
torch.Size([10, 10])
torch.Size([1, 10])
