In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

In [12]:
# Define the model
class Net(nn.Module):
    """Small LeNet-style CNN: two conv/ReLU/max-pool stages, then three linear layers.

    The first conv takes 1 input channel. fc1 expects the conv stage to flatten to
    16 * 5 * 5 = 400 features, which is what a 32x32 image produces:
    32 -conv5-> 28 -pool2-> 14 -conv5-> 10 -pool2-> 5.
    The last layer emits 10 values per sample.
    """

    def __init__(self):
        super(Net, self).__init__()
        # NOTE: the registration order below fixes the order of net.parameters().
        # Conv2d(in_channels, out_channels, kernel_size): 1 -> 6 -> 16 channels, 5x5 kernels.
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)
        # Fully connected (affine, y = Wx + b) head: 400 -> 120 -> 84 -> 10.
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Map a batch ``x`` to the un-activated outputs of fc3."""
        # Feature extractor: conv -> ReLU -> max-pool, twice.
        # A bare 2 is shorthand for the square (2, 2) pooling window.
        for conv in (self.conv1, self.conv2):
            x = F.max_pool2d(F.relu(conv(x)), 2)
        # Flatten every dimension except the leading batch dimension.
        x = x.view(-1, self.num_flat_features(x))
        # ReLU after each linear layer except the final one.
        for layer in (self.fc1, self.fc2):
            x = F.relu(layer(x))
        return self.fc3(x)

    def num_flat_features(self, x):
        """Return the product of all dimensions of ``x`` after the first (batch) one."""
        total = 1
        for extent in x.size()[1:]:
            total = total * extent
        return total
    
# Fix the global torch RNG before the first random draw (the randomly
# initialised layer weights created by Net() and the torch.randn input used
# further down) so Restart & Run All reproduces the same numbers every time.
SEED = 0
torch.manual_seed(SEED)

net = Net()
print(net)
        

Net(
  (conv1): Conv2d (1, 6, kernel_size=(5, 5), stride=(1, 1))
  (conv2): Conv2d (6, 16, kernel_size=(5, 5), stride=(1, 1))
  (fc1): Linear(in_features=400, out_features=120)
  (fc2): Linear(in_features=120, out_features=84)
  (fc3): Linear(in_features=84, out_features=10)
)


In [13]:
# Parameters come back in registration order: weight then bias for conv1,
# conv2, fc1, fc2 and fc3 -> 10 tensors. Index 5 is fc1's bias (120 values).
params = [p for p in net.parameters()]
print(len(params))
print(params[5].size())

10
torch.Size([120])


In [14]:
# A random batch holding one single-channel 32x32 image, the spatial size
# that fc1's 16 * 5 * 5 input features were sized for.
input = Variable(torch.randn(1, 1, 32, 32))  # NB: shadows the builtin input()
out = net(input)  # calling the module dispatches to Net.forward
print(out)

Variable containing:
-0.0536 -0.0281  0.0550  0.1817 -0.0277 -0.0736 -0.0218  0.1565  0.0478 -0.1035
[torch.FloatTensor of size 1x10]



In [15]:
# Backprop a random upstream gradient through `out`. The gradient passed to
# backward() must have the same shape as `out` (1x10 here). The previous
# hard-coded (1, 11) tensor did not match, so build it from out.size().
net.zero_grad()
out.backward(torch.rand(out.size()))

In [16]:
output = net(input)
# Dummy regression target 1..10. With integer arguments, newer PyTorch versions
# make torch.arange return int64, while MSELoss needs a floating-point target.
# Reshape to (1, 10) so the target matches `output` element-for-element instead
# of relying on broadcasting a (10,) tensor.
target = Variable(torch.arange(1, 11).float().view(1, -1))
criterion = nn.MSELoss()

loss = criterion(output, target)
print(loss)

Variable containing:
 38.3689
[torch.FloatTensor of size 1]



In [21]:
# Walk backwards through the autograd graph from `loss`, printing each node.
# Each step follows next_functions[0][0], the node feeding the first input.
grad_node = loss.grad_fn
print(grad_node)
for hop in range(2):
    grad_node = grad_node.next_functions[0][0]
    print(grad_node)

<MseLossBackward object at 0x7f16dc7c95c0>
<AddmmBackward object at 0x7f16dc7c96a0>
<ExpandBackward object at 0x7f16dc7c95c0>


In [22]:
net.zero_grad()  # reset the gradient buffers of every parameter

# Show conv1's bias gradient on either side of a single backward pass.
for heading, do_backward in (('conv1.bias.grad before backward', False),
                             ('conv1.bias.grad after backward', True)):
    if do_backward:
        loss.backward()
    print(heading)
    print(net.conv1.bias.grad)

conv1.bias.grad before backward
Variable containing:
 0
 0
 0
 0
 0
 0
[torch.FloatTensor of size 6]

conv1.bias.grad after backward
Variable containing:
1.00000e-02 *
  0.2416
 -5.9636
 -6.6599
  1.6504
 -6.4466
  0.3700
[torch.FloatTensor of size 6]



In [23]:
# Create the optimizer: SGD over all of net's parameters.
# (`optim` is imported with the other imports at the top of the notebook.)
LEARNING_RATE = 0.01
optimizer = optim.SGD(net.parameters(), lr=LEARNING_RATE)

# One iteration of a training loop:
optimizer.zero_grad()  # clear gradients left in the buffers by earlier backward() calls
output = net(input)
loss = criterion(output, target)
loss.backward()  # populate .grad on every parameter
optimizer.step()  # update the parameters from their .grad using LEARNING_RATE

In [24]:
# Dump each parameter's current gradient tensor, in registration order.
# The gradients are still populated after optimizer.step() in the previous cell.
for param in net.parameters():
    grad_values = param.grad.data
    print(grad_values)


(0 ,0 ,.,.) = 
  0.0214  0.0927 -0.0605  0.0340  0.0788
  0.1073 -0.0275 -0.0552 -0.0006  0.0892
 -0.0348  0.0549 -0.0479 -0.0111  0.0352
  0.0445 -0.0618 -0.0365 -0.0708  0.0223
  0.0604  0.0043 -0.0331  0.0029 -0.1191

(1 ,0 ,.,.) = 
  0.0112  0.0161 -0.0511 -0.0753 -0.0268
  0.0297  0.1225 -0.0665 -0.0116  0.0312
 -0.0586  0.0143 -0.0671  0.0400  0.0035
  0.0159  0.0183  0.0218 -0.0871  0.0082
 -0.0467 -0.0029  0.0390 -0.0512 -0.0982

(2 ,0 ,.,.) = 
 -0.1537 -0.0334 -0.1009 -0.0718  0.0755
 -0.0474 -0.0820 -0.0173 -0.0649 -0.0361
  0.0074  0.0158 -0.0038 -0.0130 -0.0011
 -0.0146 -0.0347  0.0816  0.1114  0.0110
 -0.0856  0.1299  0.0802  0.0403 -0.0672

(3 ,0 ,.,.) = 
 -0.0514  0.0001  0.0612 -0.0241 -0.0026
  0.0450 -0.0949  0.1118 -0.0149 -0.0943
 -0.0222  0.0251  0.0167 -0.0088 -0.0023
  0.0281  0.0493 -0.0301  0.0189  0.0358
 -0.0142  0.0474 -0.0042 -0.0541  0.1050

(4 ,0 ,.,.) = 
  0.0653  0.0523 -0.1035 -0.0418 -0.0364
  0.0929  0.0005  0.0264 -0.1027 -0.0618
 -0.0305 -0.0613  