# PyTorch手刻一個多輸入(即有批次)的多層NN架構

In [7]:
%reset -f
import torch

# Tensor type that every tensor in this cell is cast to.
dtype = torch.FloatTensor
# N: number of samples, D_in: input width, H: hidden width, D_out: output width.
N, D_in, H, D_out = 128, 784, 120, 10

# Random inputs and targets; autograd does not track them.
x = torch.randn(N, D_in).type(dtype)
y = torch.randn(N, D_out).type(dtype)

# Trainable parameters. The cast happens first and requires_grad_() last, so
# each parameter is a leaf tensor whose .grad gets filled by backward() even
# when the cast has to copy. (Casting a tensor that already requires grad
# yields a non-leaf result whenever a copy is made.)
w1 = torch.randn(D_in, H).type(dtype).requires_grad_()
b1 = torch.randn(H).type(dtype).requires_grad_()

w2 = torch.randn(H, D_out).type(dtype).requires_grad_()
b2 = torch.randn(D_out).type(dtype).requires_grad_()


In [15]:
# Display the shape of the target tensor y.
y.shape

torch.Size([128, 10])

In [9]:
# Hand-written gradient descent on the two-layer network, using all of x and y
# in every step.
lr = 1e-6
for epoch in range(5000):
    # Forward pass: affine layer, ReLU (clamp at zero), affine layer.
    a1 = torch.mm(x, w1) + b1
    z1 = a1.clamp(min=0)
    a2 = torch.mm(z1, w2) + b2
    # Sum of squared errors over all samples and outputs.
    loss = (y - a2).pow(2).sum()
    loss.backward()

    # Apply the step under no_grad() so the in-place update is not recorded by
    # autograd, then reset each gradient because backward() accumulates.
    with torch.no_grad():
        for param in (w1, w2, b1, b2):
            param -= lr * param.grad
            param.grad.zero_()

    print(epoch, loss)

0 tensor(1665.5991, grad_fn=<SumBackward0>)
1 tensor(1664.1943, grad_fn=<SumBackward0>)
2 tensor(1662.7991, grad_fn=<SumBackward0>)
3 tensor(1661.3693, grad_fn=<SumBackward0>)
4 tensor(1659.9557, grad_fn=<SumBackward0>)
5 tensor(1658.5460, grad_fn=<SumBackward0>)
6 tensor(1657.1495, grad_fn=<SumBackward0>)
7 tensor(1655.7620, grad_fn=<SumBackward0>)
8 tensor(1654.3899, grad_fn=<SumBackward0>)
9 tensor(1653.0225, grad_fn=<SumBackward0>)
10 tensor(1651.6674, grad_fn=<SumBackward0>)
11 tensor(1650.3196, grad_fn=<SumBackward0>)
12 tensor(1648.9834, grad_fn=<SumBackward0>)
13 tensor(1647.6564, grad_fn=<SumBackward0>)
14 tensor(1646.3386, grad_fn=<SumBackward0>)
15 tensor(1645.0283, grad_fn=<SumBackward0>)
16 tensor(1643.7267, grad_fn=<SumBackward0>)
17 tensor(1642.4353, grad_fn=<SumBackward0>)
18 tensor(1641.1531, grad_fn=<SumBackward0>)
19 tensor(1639.8784, grad_fn=<SumBackward0>)
20 tensor(1638.6125, grad_fn=<SumBackward0>)
21 tensor(1637.3533, grad_fn=<SumBackward0>)
22 tensor(1636.1082,

318 tensor(1444.8464, grad_fn=<SumBackward0>)
319 tensor(1444.5040, grad_fn=<SumBackward0>)
320 tensor(1444.1626, grad_fn=<SumBackward0>)
321 tensor(1443.8215, grad_fn=<SumBackward0>)
322 tensor(1443.4833, grad_fn=<SumBackward0>)
323 tensor(1443.1444, grad_fn=<SumBackward0>)
324 tensor(1442.8059, grad_fn=<SumBackward0>)
325 tensor(1442.4705, grad_fn=<SumBackward0>)
326 tensor(1442.1348, grad_fn=<SumBackward0>)
327 tensor(1441.7998, grad_fn=<SumBackward0>)
328 tensor(1441.4663, grad_fn=<SumBackward0>)
329 tensor(1441.1348, grad_fn=<SumBackward0>)
330 tensor(1440.8030, grad_fn=<SumBackward0>)
331 tensor(1440.4717, grad_fn=<SumBackward0>)
332 tensor(1440.1431, grad_fn=<SumBackward0>)
333 tensor(1439.8132, grad_fn=<SumBackward0>)
334 tensor(1439.4851, grad_fn=<SumBackward0>)
335 tensor(1439.1573, grad_fn=<SumBackward0>)
336 tensor(1438.8313, grad_fn=<SumBackward0>)
337 tensor(1438.5067, grad_fn=<SumBackward0>)
338 tensor(1438.1814, grad_fn=<SumBackward0>)
339 tensor(1437.8584, grad_fn=<Sum

656 tensor(1365.1002, grad_fn=<SumBackward0>)
657 tensor(1364.9269, grad_fn=<SumBackward0>)
658 tensor(1364.7539, grad_fn=<SumBackward0>)
659 tensor(1364.5807, grad_fn=<SumBackward0>)
660 tensor(1364.4080, grad_fn=<SumBackward0>)
661 tensor(1364.2352, grad_fn=<SumBackward0>)
662 tensor(1364.0631, grad_fn=<SumBackward0>)
663 tensor(1363.8914, grad_fn=<SumBackward0>)
664 tensor(1363.7192, grad_fn=<SumBackward0>)
665 tensor(1363.5474, grad_fn=<SumBackward0>)
666 tensor(1363.3759, grad_fn=<SumBackward0>)
667 tensor(1363.2042, grad_fn=<SumBackward0>)
668 tensor(1363.0332, grad_fn=<SumBackward0>)
669 tensor(1362.8621, grad_fn=<SumBackward0>)
670 tensor(1362.6909, grad_fn=<SumBackward0>)
671 tensor(1362.5200, grad_fn=<SumBackward0>)
672 tensor(1362.3503, grad_fn=<SumBackward0>)
673 tensor(1362.1815, grad_fn=<SumBackward0>)
674 tensor(1362.0125, grad_fn=<SumBackward0>)
675 tensor(1361.8431, grad_fn=<SumBackward0>)
676 tensor(1361.6740, grad_fn=<SumBackward0>)
677 tensor(1361.5051, grad_fn=<Sum

990 tensor(1316.0808, grad_fn=<SumBackward0>)
991 tensor(1315.9501, grad_fn=<SumBackward0>)
992 tensor(1315.8195, grad_fn=<SumBackward0>)
993 tensor(1315.6887, grad_fn=<SumBackward0>)
994 tensor(1315.5583, grad_fn=<SumBackward0>)
995 tensor(1315.4280, grad_fn=<SumBackward0>)
996 tensor(1315.2979, grad_fn=<SumBackward0>)
997 tensor(1315.1675, grad_fn=<SumBackward0>)
998 tensor(1315.0372, grad_fn=<SumBackward0>)
999 tensor(1314.9072, grad_fn=<SumBackward0>)
1000 tensor(1314.7772, grad_fn=<SumBackward0>)
1001 tensor(1314.6472, grad_fn=<SumBackward0>)
1002 tensor(1314.5172, grad_fn=<SumBackward0>)
1003 tensor(1314.3875, grad_fn=<SumBackward0>)
1004 tensor(1314.2576, grad_fn=<SumBackward0>)
1005 tensor(1314.1279, grad_fn=<SumBackward0>)
1006 tensor(1313.9983, grad_fn=<SumBackward0>)
1007 tensor(1313.8689, grad_fn=<SumBackward0>)
1008 tensor(1313.7395, grad_fn=<SumBackward0>)
1009 tensor(1313.6101, grad_fn=<SumBackward0>)
1010 tensor(1313.4808, grad_fn=<SumBackward0>)
1011 tensor(1313.3517, 

1168 tensor(1293.9333, grad_fn=<SumBackward0>)
1169 tensor(1293.8149, grad_fn=<SumBackward0>)
1170 tensor(1293.6964, grad_fn=<SumBackward0>)
1171 tensor(1293.5780, grad_fn=<SumBackward0>)
1172 tensor(1293.4597, grad_fn=<SumBackward0>)
1173 tensor(1293.3413, grad_fn=<SumBackward0>)
1174 tensor(1293.2233, grad_fn=<SumBackward0>)
1175 tensor(1293.1051, grad_fn=<SumBackward0>)
1176 tensor(1292.9869, grad_fn=<SumBackward0>)
1177 tensor(1292.8690, grad_fn=<SumBackward0>)
1178 tensor(1292.7510, grad_fn=<SumBackward0>)
1179 tensor(1292.6331, grad_fn=<SumBackward0>)
1180 tensor(1292.5154, grad_fn=<SumBackward0>)
1181 tensor(1292.3975, grad_fn=<SumBackward0>)
1182 tensor(1292.2798, grad_fn=<SumBackward0>)
1183 tensor(1292.1622, grad_fn=<SumBackward0>)
1184 tensor(1292.0446, grad_fn=<SumBackward0>)
1185 tensor(1291.9270, grad_fn=<SumBackward0>)
1186 tensor(1291.8096, grad_fn=<SumBackward0>)
1187 tensor(1291.6921, grad_fn=<SumBackward0>)
1188 tensor(1291.5747, grad_fn=<SumBackward0>)
1189 tensor(1

1503 tensor(1257.3894, grad_fn=<SumBackward0>)
1504 tensor(1257.2891, grad_fn=<SumBackward0>)
1505 tensor(1257.1887, grad_fn=<SumBackward0>)
1506 tensor(1257.0886, grad_fn=<SumBackward0>)
1507 tensor(1256.9885, grad_fn=<SumBackward0>)
1508 tensor(1256.8882, grad_fn=<SumBackward0>)
1509 tensor(1256.7882, grad_fn=<SumBackward0>)
1510 tensor(1256.6881, grad_fn=<SumBackward0>)
1511 tensor(1256.5881, grad_fn=<SumBackward0>)
1512 tensor(1256.4882, grad_fn=<SumBackward0>)
1513 tensor(1256.3882, grad_fn=<SumBackward0>)
1514 tensor(1256.2886, grad_fn=<SumBackward0>)
1515 tensor(1256.1887, grad_fn=<SumBackward0>)
1516 tensor(1256.0889, grad_fn=<SumBackward0>)
1517 tensor(1255.9893, grad_fn=<SumBackward0>)
1518 tensor(1255.8896, grad_fn=<SumBackward0>)
1519 tensor(1255.7899, grad_fn=<SumBackward0>)
1520 tensor(1255.6904, grad_fn=<SumBackward0>)
1521 tensor(1255.5909, grad_fn=<SumBackward0>)
1522 tensor(1255.4915, grad_fn=<SumBackward0>)
1523 tensor(1255.3918, grad_fn=<SumBackward0>)
1524 tensor(1

1838 tensor(1226.3300, grad_fn=<SumBackward0>)
1839 tensor(1226.2446, grad_fn=<SumBackward0>)
1840 tensor(1226.1592, grad_fn=<SumBackward0>)
1841 tensor(1226.0737, grad_fn=<SumBackward0>)
1842 tensor(1225.9884, grad_fn=<SumBackward0>)
1843 tensor(1225.9031, grad_fn=<SumBackward0>)
1844 tensor(1225.8179, grad_fn=<SumBackward0>)
1845 tensor(1225.7325, grad_fn=<SumBackward0>)
1846 tensor(1225.6472, grad_fn=<SumBackward0>)
1847 tensor(1225.5620, grad_fn=<SumBackward0>)
1848 tensor(1225.4771, grad_fn=<SumBackward0>)
1849 tensor(1225.3920, grad_fn=<SumBackward0>)
1850 tensor(1225.3069, grad_fn=<SumBackward0>)
1851 tensor(1225.2218, grad_fn=<SumBackward0>)
1852 tensor(1225.1370, grad_fn=<SumBackward0>)
1853 tensor(1225.0520, grad_fn=<SumBackward0>)
1854 tensor(1224.9670, grad_fn=<SumBackward0>)
1855 tensor(1224.8822, grad_fn=<SumBackward0>)
1856 tensor(1224.7975, grad_fn=<SumBackward0>)
1857 tensor(1224.7128, grad_fn=<SumBackward0>)
1858 tensor(1224.6281, grad_fn=<SumBackward0>)
1859 tensor(1

2171 tensor(1199.4645, grad_fn=<SumBackward0>)
2172 tensor(1199.3915, grad_fn=<SumBackward0>)
2173 tensor(1199.3184, grad_fn=<SumBackward0>)
2174 tensor(1199.2456, grad_fn=<SumBackward0>)
2175 tensor(1199.1727, grad_fn=<SumBackward0>)
2176 tensor(1199.0999, grad_fn=<SumBackward0>)
2177 tensor(1199.0271, grad_fn=<SumBackward0>)
2178 tensor(1198.9541, grad_fn=<SumBackward0>)
2179 tensor(1198.8813, grad_fn=<SumBackward0>)
2180 tensor(1198.8087, grad_fn=<SumBackward0>)
2181 tensor(1198.7361, grad_fn=<SumBackward0>)
2182 tensor(1198.6635, grad_fn=<SumBackward0>)
2183 tensor(1198.5907, grad_fn=<SumBackward0>)
2184 tensor(1198.5181, grad_fn=<SumBackward0>)
2185 tensor(1198.4456, grad_fn=<SumBackward0>)
2186 tensor(1198.3730, grad_fn=<SumBackward0>)
2187 tensor(1198.3007, grad_fn=<SumBackward0>)
2188 tensor(1198.2283, grad_fn=<SumBackward0>)
2189 tensor(1198.1558, grad_fn=<SumBackward0>)
2190 tensor(1198.0835, grad_fn=<SumBackward0>)
2191 tensor(1198.0111, grad_fn=<SumBackward0>)
2192 tensor(1

2505 tensor(1176.9004, grad_fn=<SumBackward0>)
2506 tensor(1176.8381, grad_fn=<SumBackward0>)
2507 tensor(1176.7758, grad_fn=<SumBackward0>)
2508 tensor(1176.7135, grad_fn=<SumBackward0>)
2509 tensor(1176.6511, grad_fn=<SumBackward0>)
2510 tensor(1176.5890, grad_fn=<SumBackward0>)
2511 tensor(1176.5266, grad_fn=<SumBackward0>)
2512 tensor(1176.4644, grad_fn=<SumBackward0>)
2513 tensor(1176.4023, grad_fn=<SumBackward0>)
2514 tensor(1176.3402, grad_fn=<SumBackward0>)
2515 tensor(1176.2782, grad_fn=<SumBackward0>)
2516 tensor(1176.2161, grad_fn=<SumBackward0>)
2517 tensor(1176.1541, grad_fn=<SumBackward0>)
2518 tensor(1176.0920, grad_fn=<SumBackward0>)
2519 tensor(1176.0300, grad_fn=<SumBackward0>)
2520 tensor(1175.9680, grad_fn=<SumBackward0>)
2521 tensor(1175.9061, grad_fn=<SumBackward0>)
2522 tensor(1175.8442, grad_fn=<SumBackward0>)
2523 tensor(1175.7823, grad_fn=<SumBackward0>)
2524 tensor(1175.7205, grad_fn=<SumBackward0>)
2525 tensor(1175.6587, grad_fn=<SumBackward0>)
2526 tensor(1

2844 tensor(1157.3508, grad_fn=<SumBackward0>)
2845 tensor(1157.2976, grad_fn=<SumBackward0>)
2846 tensor(1157.2444, grad_fn=<SumBackward0>)
2847 tensor(1157.1913, grad_fn=<SumBackward0>)
2848 tensor(1157.1382, grad_fn=<SumBackward0>)
2849 tensor(1157.0852, grad_fn=<SumBackward0>)
2850 tensor(1157.0321, grad_fn=<SumBackward0>)
2851 tensor(1156.9789, grad_fn=<SumBackward0>)
2852 tensor(1156.9259, grad_fn=<SumBackward0>)
2853 tensor(1156.8729, grad_fn=<SumBackward0>)
2854 tensor(1156.8199, grad_fn=<SumBackward0>)
2855 tensor(1156.7671, grad_fn=<SumBackward0>)
2856 tensor(1156.7141, grad_fn=<SumBackward0>)
2857 tensor(1156.6613, grad_fn=<SumBackward0>)
2858 tensor(1156.6084, grad_fn=<SumBackward0>)
2859 tensor(1156.5555, grad_fn=<SumBackward0>)
2860 tensor(1156.5027, grad_fn=<SumBackward0>)
2861 tensor(1156.4498, grad_fn=<SumBackward0>)
2862 tensor(1156.3971, grad_fn=<SumBackward0>)
2863 tensor(1156.3442, grad_fn=<SumBackward0>)
2864 tensor(1156.2916, grad_fn=<SumBackward0>)
2865 tensor(1

3188 tensor(1140.4407, grad_fn=<SumBackward0>)
3189 tensor(1140.3953, grad_fn=<SumBackward0>)
3190 tensor(1140.3500, grad_fn=<SumBackward0>)
3191 tensor(1140.3047, grad_fn=<SumBackward0>)
3192 tensor(1140.2595, grad_fn=<SumBackward0>)
3193 tensor(1140.2142, grad_fn=<SumBackward0>)
3194 tensor(1140.1691, grad_fn=<SumBackward0>)
3195 tensor(1140.1239, grad_fn=<SumBackward0>)
3196 tensor(1140.0786, grad_fn=<SumBackward0>)
3197 tensor(1140.0336, grad_fn=<SumBackward0>)
3198 tensor(1139.9883, grad_fn=<SumBackward0>)
3199 tensor(1139.9432, grad_fn=<SumBackward0>)
3200 tensor(1139.8982, grad_fn=<SumBackward0>)
3201 tensor(1139.8531, grad_fn=<SumBackward0>)
3202 tensor(1139.8081, grad_fn=<SumBackward0>)
3203 tensor(1139.7631, grad_fn=<SumBackward0>)
3204 tensor(1139.7180, grad_fn=<SumBackward0>)
3205 tensor(1139.6731, grad_fn=<SumBackward0>)
3206 tensor(1139.6282, grad_fn=<SumBackward0>)
3207 tensor(1139.5831, grad_fn=<SumBackward0>)
3208 tensor(1139.5383, grad_fn=<SumBackward0>)
3209 tensor(1

3524 tensor(1126.2501, grad_fn=<SumBackward0>)
3525 tensor(1126.2102, grad_fn=<SumBackward0>)
3526 tensor(1126.1702, grad_fn=<SumBackward0>)
3527 tensor(1126.1302, grad_fn=<SumBackward0>)
3528 tensor(1126.0901, grad_fn=<SumBackward0>)
3529 tensor(1126.0502, grad_fn=<SumBackward0>)
3530 tensor(1126.0101, grad_fn=<SumBackward0>)
3531 tensor(1125.9702, grad_fn=<SumBackward0>)
3532 tensor(1125.9303, grad_fn=<SumBackward0>)
3533 tensor(1125.8906, grad_fn=<SumBackward0>)
3534 tensor(1125.8507, grad_fn=<SumBackward0>)
3535 tensor(1125.8109, grad_fn=<SumBackward0>)
3536 tensor(1125.7710, grad_fn=<SumBackward0>)
3537 tensor(1125.7313, grad_fn=<SumBackward0>)
3538 tensor(1125.6917, grad_fn=<SumBackward0>)
3539 tensor(1125.6519, grad_fn=<SumBackward0>)
3540 tensor(1125.6121, grad_fn=<SumBackward0>)
3541 tensor(1125.5725, grad_fn=<SumBackward0>)
3542 tensor(1125.5328, grad_fn=<SumBackward0>)
3543 tensor(1125.4933, grad_fn=<SumBackward0>)
3544 tensor(1125.4536, grad_fn=<SumBackward0>)
3545 tensor(1

3860 tensor(1113.9976, grad_fn=<SumBackward0>)
3861 tensor(1113.9644, grad_fn=<SumBackward0>)
3862 tensor(1113.9310, grad_fn=<SumBackward0>)
3863 tensor(1113.8976, grad_fn=<SumBackward0>)
3864 tensor(1113.8643, grad_fn=<SumBackward0>)
3865 tensor(1113.8311, grad_fn=<SumBackward0>)
3866 tensor(1113.7977, grad_fn=<SumBackward0>)
3867 tensor(1113.7645, grad_fn=<SumBackward0>)
3868 tensor(1113.7312, grad_fn=<SumBackward0>)
3869 tensor(1113.6981, grad_fn=<SumBackward0>)
3870 tensor(1113.6648, grad_fn=<SumBackward0>)
3871 tensor(1113.6318, grad_fn=<SumBackward0>)
3872 tensor(1113.5985, grad_fn=<SumBackward0>)
3873 tensor(1113.5653, grad_fn=<SumBackward0>)
3874 tensor(1113.5322, grad_fn=<SumBackward0>)
3875 tensor(1113.4990, grad_fn=<SumBackward0>)
3876 tensor(1113.4659, grad_fn=<SumBackward0>)
3877 tensor(1113.4327, grad_fn=<SumBackward0>)
3878 tensor(1113.3998, grad_fn=<SumBackward0>)
3879 tensor(1113.3667, grad_fn=<SumBackward0>)
3880 tensor(1113.3336, grad_fn=<SumBackward0>)
3881 tensor(1

4199 tensor(1103.5422, grad_fn=<SumBackward0>)
4200 tensor(1103.5135, grad_fn=<SumBackward0>)
4201 tensor(1103.4849, grad_fn=<SumBackward0>)
4202 tensor(1103.4564, grad_fn=<SumBackward0>)
4203 tensor(1103.4279, grad_fn=<SumBackward0>)
4204 tensor(1103.3993, grad_fn=<SumBackward0>)
4205 tensor(1103.3707, grad_fn=<SumBackward0>)
4206 tensor(1103.3422, grad_fn=<SumBackward0>)
4207 tensor(1103.3137, grad_fn=<SumBackward0>)
4208 tensor(1103.2853, grad_fn=<SumBackward0>)
4209 tensor(1103.2567, grad_fn=<SumBackward0>)
4210 tensor(1103.2284, grad_fn=<SumBackward0>)
4211 tensor(1103.2000, grad_fn=<SumBackward0>)
4212 tensor(1103.1714, grad_fn=<SumBackward0>)
4213 tensor(1103.1429, grad_fn=<SumBackward0>)
4214 tensor(1103.1147, grad_fn=<SumBackward0>)
4215 tensor(1103.0863, grad_fn=<SumBackward0>)
4216 tensor(1103.0579, grad_fn=<SumBackward0>)
4217 tensor(1103.0295, grad_fn=<SumBackward0>)
4218 tensor(1103.0012, grad_fn=<SumBackward0>)
4219 tensor(1102.9730, grad_fn=<SumBackward0>)
4220 tensor(1

4534 tensor(1094.7046, grad_fn=<SumBackward0>)
4535 tensor(1094.6802, grad_fn=<SumBackward0>)
4536 tensor(1094.6558, grad_fn=<SumBackward0>)
4537 tensor(1094.6315, grad_fn=<SumBackward0>)
4538 tensor(1094.6071, grad_fn=<SumBackward0>)
4539 tensor(1094.5828, grad_fn=<SumBackward0>)
4540 tensor(1094.5585, grad_fn=<SumBackward0>)
4541 tensor(1094.5342, grad_fn=<SumBackward0>)
4542 tensor(1094.5100, grad_fn=<SumBackward0>)
4543 tensor(1094.4856, grad_fn=<SumBackward0>)
4544 tensor(1094.4612, grad_fn=<SumBackward0>)
4545 tensor(1094.4371, grad_fn=<SumBackward0>)
4546 tensor(1094.4127, grad_fn=<SumBackward0>)
4547 tensor(1094.3885, grad_fn=<SumBackward0>)
4548 tensor(1094.3644, grad_fn=<SumBackward0>)
4549 tensor(1094.3401, grad_fn=<SumBackward0>)
4550 tensor(1094.3159, grad_fn=<SumBackward0>)
4551 tensor(1094.2916, grad_fn=<SumBackward0>)
4552 tensor(1094.2676, grad_fn=<SumBackward0>)
4553 tensor(1094.2433, grad_fn=<SumBackward0>)
4554 tensor(1094.2191, grad_fn=<SumBackward0>)
4555 tensor(1

4878 tensor(1086.9346, grad_fn=<SumBackward0>)
4879 tensor(1086.9137, grad_fn=<SumBackward0>)
4880 tensor(1086.8928, grad_fn=<SumBackward0>)
4881 tensor(1086.8718, grad_fn=<SumBackward0>)
4882 tensor(1086.8511, grad_fn=<SumBackward0>)
4883 tensor(1086.8303, grad_fn=<SumBackward0>)
4884 tensor(1086.8093, grad_fn=<SumBackward0>)
4885 tensor(1086.7886, grad_fn=<SumBackward0>)
4886 tensor(1086.7677, grad_fn=<SumBackward0>)
4887 tensor(1086.7468, grad_fn=<SumBackward0>)
4888 tensor(1086.7261, grad_fn=<SumBackward0>)
4889 tensor(1086.7053, grad_fn=<SumBackward0>)
4890 tensor(1086.6846, grad_fn=<SumBackward0>)
4891 tensor(1086.6638, grad_fn=<SumBackward0>)
4892 tensor(1086.6429, grad_fn=<SumBackward0>)
4893 tensor(1086.6222, grad_fn=<SumBackward0>)
4894 tensor(1086.6014, grad_fn=<SumBackward0>)
4895 tensor(1086.5808, grad_fn=<SumBackward0>)
4896 tensor(1086.5601, grad_fn=<SumBackward0>)
4897 tensor(1086.5393, grad_fn=<SumBackward0>)
4898 tensor(1086.5186, grad_fn=<SumBackward0>)
4899 tensor(1

In [1]:
import torch
from torch.autograd import Variable

# Tensor type that every tensor in this cell is cast to.
dtype = torch.FloatTensor
# N: number of samples, D_in: input width, H: hidden width, D_out: output width.
N, D_in, H, D_out = 64, 1000, 100, 10

# Random inputs and targets; plain tensors do not require grad by default, so
# the deprecated Variable wrapper is not needed.
x = torch.randn(N, D_in).type(dtype)
y = torch.randn(N, D_out).type(dtype)

# Trainable weights (this example has no bias terms). requires_grad_() is
# called after the cast so both stay leaf tensors with a populated .grad.
w1 = torch.randn(D_in, H).type(dtype).requires_grad_()
w2 = torch.randn(H, D_out).type(dtype).requires_grad_()

learning_rate = 1e-6
for t in range(500):
    # Forward pass: linear, ReLU (clamp at zero), linear.
    y_pred = x.mm(w1).clamp(min=0).mm(w2)

    # Sum of squared errors between prediction and target.
    loss = (y_pred - y).pow(2).sum()
    print(loss)
    loss.backward()

    # Gradient step under no_grad() so the update itself is not traced, then
    # clear the gradients because backward() adds to whatever is stored.
    with torch.no_grad():
        w1 -= learning_rate * w1.grad
        w2 -= learning_rate * w2.grad

        w1.grad.zero_()
        w2.grad.zero_()



tensor(22746592., grad_fn=<SumBackward0>)
tensor(15246758., grad_fn=<SumBackward0>)
tensor(11661868., grad_fn=<SumBackward0>)
tensor(9748469., grad_fn=<SumBackward0>)
tensor(8605668., grad_fn=<SumBackward0>)
tensor(7810116., grad_fn=<SumBackward0>)
tensor(7102151.5000, grad_fn=<SumBackward0>)
tensor(6389003.5000, grad_fn=<SumBackward0>)
tensor(5609677., grad_fn=<SumBackward0>)
tensor(4811203., grad_fn=<SumBackward0>)
tensor(4013734.5000, grad_fn=<SumBackward0>)
tensor(3281270.7500, grad_fn=<SumBackward0>)
tensor(2628764.5000, grad_fn=<SumBackward0>)
tensor(2084604.1250, grad_fn=<SumBackward0>)
tensor(1637876.1250, grad_fn=<SumBackward0>)
tensor(1286040.6250, grad_fn=<SumBackward0>)
tensor(1010285.4375, grad_fn=<SumBackward0>)
tensor(798795.5625, grad_fn=<SumBackward0>)
tensor(636363.5625, grad_fn=<SumBackward0>)
tensor(512653.6875, grad_fn=<SumBackward0>)
tensor(417641.1250, grad_fn=<SumBackward0>)
tensor(344626.6250, grad_fn=<SumBackward0>)
tensor(287853.6250, grad_fn=<SumBackward0>)


tensor(2.6461, grad_fn=<SumBackward0>)
tensor(2.5521, grad_fn=<SumBackward0>)
tensor(2.4613, grad_fn=<SumBackward0>)
tensor(2.3740, grad_fn=<SumBackward0>)
tensor(2.2898, grad_fn=<SumBackward0>)
tensor(2.2087, grad_fn=<SumBackward0>)
tensor(2.1304, grad_fn=<SumBackward0>)
tensor(2.0550, grad_fn=<SumBackward0>)
tensor(1.9820, grad_fn=<SumBackward0>)
tensor(1.9120, grad_fn=<SumBackward0>)
tensor(1.8444, grad_fn=<SumBackward0>)
tensor(1.7792, grad_fn=<SumBackward0>)
tensor(1.7163, grad_fn=<SumBackward0>)
tensor(1.6555, grad_fn=<SumBackward0>)
tensor(1.5971, grad_fn=<SumBackward0>)
tensor(1.5407, grad_fn=<SumBackward0>)
tensor(1.4863, grad_fn=<SumBackward0>)
tensor(1.4339, grad_fn=<SumBackward0>)
tensor(1.3833, grad_fn=<SumBackward0>)
tensor(1.3346, grad_fn=<SumBackward0>)
tensor(1.2875, grad_fn=<SumBackward0>)
tensor(1.2421, grad_fn=<SumBackward0>)
tensor(1.1984, grad_fn=<SumBackward0>)
tensor(1.1562, grad_fn=<SumBackward0>)
tensor(1.1155, grad_fn=<SumBackward0>)
tensor(1.0763, grad_fn=<S

In [9]:
import torch
import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms

# Settings for the MNIST experiment.
input_size = 28 * 28    # one 28 x 28 image flattened to 784 values
hidden_size = 500       # number of nodes in the hidden layer
num_classes = 10        # output classes, the digits 0 to 9
num_epochs = 5          # number of passes over the entire dataset
batch_size = 100        # number of samples taken per iteration
learning_rate = 1e-3    # learning-rate setting
dtype = torch.FloatTensor

In [10]:
# Training split of MNIST stored under ./data; download=True requests that the
# files be fetched. Samples are converted with ToTensor().
train_dataset = dsets.MNIST(
    root='./data',
    train=True,
    download=True,
    transform=transforms.ToTensor(),
)

# Test split from the same root directory (no download requested here).
test_dataset = dsets.MNIST(
    root='./data',
    train=False,
    transform=transforms.ToTensor(),
)

In [21]:
# Training images as floats divided by 255 (presumably mapping 8-bit pixel
# values into [0, 1]). `.data` replaces the deprecated `.train_data` alias.
x = train_dataset.data.float() / 255

def one_hot_embedding(labels, num_classes):
    """Return one-hot rows for ``labels``.

    A ``num_classes`` x ``num_classes`` identity matrix is indexed with
    ``labels``, so every label picks the row whose single 1 sits at that
    label's position.
    """
    return torch.eye(num_classes)[labels]

# Flatten every image into a single row of 28 * 28 values.
x = x.view(-1, 28 * 28)
# One-hot encode the training labels over 10 classes. `.targets` replaces the
# deprecated `.train_labels` alias.
y = one_hot_embedding(train_dataset.targets, 10)



In [22]:
# Trainable parameters of the two-layer network. Each tensor is cast first and
# marked with requires_grad_() last, so it is a leaf tensor whose .grad is
# filled by backward() even when the cast has to copy.
w1 = torch.randn(input_size, hidden_size).type(dtype).requires_grad_()
b1 = torch.randn(hidden_size).type(dtype).requires_grad_()

w2 = torch.randn(hidden_size, num_classes).type(dtype).requires_grad_()
b2 = torch.randn(num_classes).type(dtype).requires_grad_()


In [19]:
# Display the first layer's affine output (x times w1, plus b1) for all of x.
torch.mm(x, w1) + b1

tensor([[-23.2466,  -2.3021,  -5.1985,  ...,  -3.1002,  12.0090,  -5.2520],
        [-33.3994,  -3.8443, -14.6838,  ...,  -1.5646,   2.1384,  -8.2969],
        [-10.8907,  -8.1168,  -8.3095,  ...,   3.5389,   6.6156, -12.4017],
        ...,
        [-10.6083,   0.8321,   2.1596,  ..., -12.6337,  13.0580,  -7.1332],
        [ -6.5841,  -0.8262,   1.5488,  ...,  -2.2759,   3.0573,   4.3280],
        [ -6.7206,   3.2759, -19.2771,  ...,  10.6873,  -7.8474,  -3.8390]],
       grad_fn=<AddBackward0>)

In [23]:
# Hand-written gradient descent on the MNIST tensors, using all of x and y in
# every step.
lr = 1e-3
for epoch in range(5000):
    # Forward pass: affine layer, ReLU (clamp at zero), affine layer.
    a1 = torch.mm(x, w1) + b1
    z1 = a1.clamp(min=0)
    a2 = torch.mm(z1, w2) + b2
    # Summed squared error divided by the number of rows in x (instead of a
    # hard-coded sample count).
    loss = (y - a2).pow(2).sum() / x.shape[0]
    loss.backward()

    # Apply the step under no_grad() so the in-place update is not recorded by
    # autograd, then reset each gradient because backward() accumulates.
    with torch.no_grad():
        for param in (w1, w2, b1, b2):
            param -= lr * param.grad
            param.grad.zero_()

    print(epoch, loss.item())

0 211739.140625
1 177583936.0
2 5411841.0
3 464.275390625
4 462.3185729980469
5 460.3740539550781
6 458.4357604980469
7 456.5135803222656
8 454.59405517578125
9 452.6898193359375
10 450.7911376953125
11 448.9046936035156
12 447.0278015136719
13 445.1597595214844
14 443.3017272949219
15 441.4555969238281
16 439.6151428222656
17 437.78515625
18 435.95904541015625
19 434.1485290527344
20 432.34381103515625
21 430.5506286621094
22 428.7627258300781
23 426.98602294921875
24 425.2127685546875
25 423.4541015625
26 421.70147705078125
27 419.9590148925781
28 418.2230224609375
29 416.49273681640625
30 414.77374267578125
31 413.0616455078125
32 411.357666015625
33 409.66278076171875
34 407.9742736816406
35 406.294189453125
36 404.62188720703125
37 402.95751953125
38 401.30487060546875
39 399.6538391113281
40 398.0116882324219
41 396.3793029785156
42 394.751708984375
43 393.13037109375
44 391.5211181640625
45 389.91973876953125
46 388.3201599121094
47 386.7317810058594
48 385.1478576660156
49 383.

374 105.86263275146484
375 105.4548110961914
376 105.04875946044922
377 104.64427185058594
378 104.24080657958984
379 103.8393325805664
380 103.44005584716797
381 103.04203796386719
382 102.64535522460938
383 102.24983978271484
384 101.85653686523438
385 101.4645004272461
386 101.07395935058594
387 100.68546295166016
388 100.29740905761719
389 99.91327667236328
390 99.52819061279297
391 99.14594268798828
392 98.76520538330078
393 98.38568878173828
394 98.00720977783203
395 97.6313247680664
396 97.2559814453125
397 96.88291931152344
398 96.51074981689453
399 96.14097595214844
400 95.77173614501953
401 95.40460205078125
402 95.03890228271484
403 94.6739730834961
404 94.31096649169922
405 93.94966125488281
406 93.58953094482422
407 93.2302474975586
408 92.87330627441406
409 92.51715087890625
410 92.16256713867188
411 91.81046295166016
412 91.45878601074219
413 91.1086654663086
414 90.76010131835938
415 90.41199493408203
416 90.0667953491211
417 89.72239685058594
418 89.37901306152344
419 

745 26.923458099365234
746 26.830795288085938
747 26.73836326599121
748 26.646116256713867
749 26.554718017578125
750 26.463336944580078
751 26.372356414794922
752 26.281742095947266
753 26.19159507751465
754 26.10171127319336
755 26.01220703125
756 25.92289924621582
757 25.834238052368164
758 25.745609283447266
759 25.657636642456055
760 25.569778442382812
761 25.482501983642578
762 25.395275115966797
763 25.308595657348633
764 25.22208595275879
765 25.136045455932617
766 25.05036163330078
767 24.9649715423584
768 24.879823684692383
769 24.79521942138672
770 24.71084213256836
771 24.626733779907227
772 24.543136596679688
773 24.459508895874023
774 24.37659454345703
775 24.293964385986328
776 24.211458206176758
777 24.129371643066406
778 24.047657012939453
779 23.966184616088867
780 23.88509178161621
781 23.804237365722656
782 23.723766326904297
783 23.64360809326172
784 23.563779830932617
785 23.484281539916992
786 23.405048370361328
787 23.3260555267334
788 23.24768829345703
789 23.1

1104 9.024632453918457
1105 9.001672744750977
1106 8.978886604309082
1107 8.956180572509766
1108 8.933509826660156
1109 8.91097354888916
1110 8.888535499572754
1111 8.866081237792969
1112 8.843852996826172
1113 8.821640014648438
1114 8.799483299255371
1115 8.777350425720215
1116 8.755374908447266
1117 8.733550071716309
1118 8.711692810058594
1119 8.68997859954834
1120 8.668377876281738
1121 8.646815299987793
1122 8.625375747680664
1123 8.604070663452148
1124 8.582748413085938
1125 8.561507225036621
1126 8.540349006652832
1127 8.5192289352417
1128 8.498244285583496
1129 8.477330207824707
1130 8.456439971923828
1131 8.435797691345215
1132 8.415094375610352
1133 8.394468307495117
1134 8.373964309692383
1135 8.353569984436035
1136 8.333187103271484
1137 8.312948226928711
1138 8.292668342590332
1139 8.272546768188477
1140 8.252551078796387
1141 8.232536315917969
1142 8.212686538696289
1143 8.19284725189209
1144 8.173080444335938
1145 8.153348922729492
1146 8.133818626403809
1147 8.114334106

1461 4.504082679748535
1462 4.4978814125061035
1463 4.4917402267456055
1464 4.485625743865967
1465 4.479520320892334
1466 4.47341251373291
1467 4.467335224151611
1468 4.4612908363342285
1469 4.455262660980225
1470 4.449274063110352
1471 4.44326114654541
1472 4.437319755554199
1473 4.431396961212158
1474 4.425468444824219
1475 4.419562339782715
1476 4.413666725158691
1477 4.407799243927002
1478 4.401949882507324
1479 4.396126747131348
1480 4.390321731567383
1481 4.384539604187012
1482 4.378816604614258
1483 4.37307596206665
1484 4.36733865737915
1485 4.361620903015137
1486 4.355916976928711
1487 4.350260257720947
1488 4.344624996185303
1489 4.339017391204834
1490 4.333378314971924
1491 4.327813625335693
1492 4.322251796722412
1493 4.316709041595459
1494 4.311168670654297
1495 4.30565071105957
1496 4.300169944763184
1497 4.294674396514893
1498 4.289173603057861
1499 4.283782482147217
1500 4.278370380401611
1501 4.272965431213379
1502 4.2676191329956055
1503 4.262252330780029
1504 4.25688

1813 3.2134766578674316
1814 3.2114298343658447
1815 3.2093875408172607
1816 3.207385301589966
1817 3.2053234577178955
1818 3.2033092975616455
1819 3.201296091079712
1820 3.19928240776062
1821 3.197263240814209
1822 3.195251941680908
1823 3.193258285522461
1824 3.191274881362915
1825 3.1892995834350586
1826 3.187309741973877
1827 3.185328483581543
1828 3.1833510398864746
1829 3.1813955307006836
1830 3.17942476272583
1831 3.1774797439575195
1832 3.17553448677063
1833 3.173588752746582
1834 3.17166805267334
1835 3.1697299480438232
1836 3.16780161857605
1837 3.165874481201172
1838 3.163940668106079
1839 3.162031412124634
1840 3.160111427307129
1841 3.1582090854644775
1842 3.156320333480835
1843 3.15444278717041
1844 3.152543067932129
1845 3.1506595611572266
1846 3.1487808227539062
1847 3.146914482116699
1848 3.145063877105713
1849 3.143178701400757
1850 3.1413137912750244
1851 3.1394736766815186
1852 3.13763165473938
1853 3.1357929706573486
1854 3.1339468955993652
1855 3.132105827331543
1

2163 2.731987476348877
2164 2.731055498123169
2165 2.730133295059204
2166 2.729219436645508
2167 2.728292226791382
2168 2.7273659706115723
2169 2.7264485359191895
2170 2.725537061691284
2171 2.724637746810913
2172 2.7237191200256348
2173 2.722801446914673
2174 2.721891164779663
2175 2.720979690551758
2176 2.720087766647339
2177 2.7191853523254395
2178 2.718278646469116
2179 2.717369794845581
2180 2.7164647579193115
2181 2.7155628204345703
2182 2.7146859169006348
2183 2.7137880325317383
2184 2.712890625
2185 2.7119851112365723
2186 2.7111129760742188
2187 2.7102134227752686
2188 2.7093255519866943
2189 2.708448648452759
2190 2.707568645477295
2191 2.706666946411133
2192 2.7057905197143555
2193 2.704914093017578
2194 2.70402455329895
2195 2.7031571865081787
2196 2.7022907733917236
2197 2.7014000415802
2198 2.7005348205566406
2199 2.6996495723724365
2200 2.6987900733947754
2201 2.6979284286499023
2202 2.6970794200897217
2203 2.6961989402770996
2204 2.6953258514404297
2205 2.69447326660156

2513 2.4802749156951904
2514 2.4796979427337646
2515 2.4791247844696045
2516 2.478562831878662
2517 2.4779775142669678
2518 2.4774277210235596
2519 2.476862668991089
2520 2.4762752056121826
2521 2.4757254123687744
2522 2.4751670360565186
2523 2.4745805263519287
2524 2.474011182785034
2525 2.4734385013580322
2526 2.4728899002075195
2527 2.4723196029663086
2528 2.471755266189575
2529 2.4712042808532715
2530 2.470625638961792
2531 2.4700822830200195
2532 2.4695122241973877
2533 2.468956470489502
2534 2.4683871269226074
2535 2.4678335189819336
2536 2.4672763347625732
2537 2.4666998386383057
2538 2.4661521911621094
2539 2.465595006942749
2540 2.4650304317474365
2541 2.4644932746887207
2542 2.4639344215393066
2543 2.4633679389953613
2544 2.462817668914795
2545 2.462275266647339
2546 2.461702823638916
2547 2.4611494541168213
2548 2.4606122970581055
2549 2.4600439071655273
2550 2.459496259689331
2551 2.458946943283081
2552 2.4584133625030518
2553 2.4578604698181152
2554 2.4573018550872803
2555

2861 2.3092284202575684
2862 2.3087990283966064
2863 2.3083765506744385
2864 2.3079473972320557
2865 2.3075201511383057
2866 2.3070924282073975
2867 2.306666135787964
2868 2.30625319480896
2869 2.3058226108551025
2870 2.305398941040039
2871 2.3049700260162354
2872 2.3045480251312256
2873 2.3041326999664307
2874 2.3037047386169434
2875 2.3032755851745605
2876 2.3028464317321777
2877 2.3024280071258545
2878 2.3020167350769043
2879 2.301591396331787
2880 2.301164150238037
2881 2.3007419109344482
2882 2.300326108932495
2883 2.2999017238616943
2884 2.2994935512542725
2885 2.2990708351135254
2886 2.298647403717041
2887 2.2982258796691895
2888 2.297804355621338
2889 2.2974026203155518
2890 2.2969775199890137
2891 2.296558380126953
2892 2.296140670776367
2893 2.295717477798462
2894 2.295297145843506
2895 2.2948741912841797
2896 2.294473886489868
2897 2.2940382957458496
2898 2.293628454208374
2899 2.2932112216949463
2900 2.292804479598999
2901 2.2923977375030518
2902 2.291978359222412
2903 2.29

3210 2.175581216812134
3211 2.175229549407959
3212 2.174887180328369
3213 2.174532651901245
3214 2.174205780029297
3215 2.1738529205322266
3216 2.173513412475586
3217 2.173172950744629
3218 2.172818899154663
3219 2.172476053237915
3220 2.172131061553955
3221 2.1717941761016846
3222 2.1714484691619873
3223 2.1711061000823975
3224 2.170764684677124
3225 2.1704256534576416
3226 2.1700809001922607
3227 2.1697373390197754
3228 2.169393301010132
3229 2.1690566539764404
3230 2.16870379447937
3231 2.1683671474456787
3232 2.168012857437134
3233 2.167670249938965
3234 2.1673245429992676
3235 2.166987657546997
3236 2.16664719581604
3237 2.166308641433716
3238 2.1659679412841797
3239 2.16563081741333
3240 2.165287971496582
3241 2.164956569671631
3242 2.164609909057617
3243 2.164271354675293
3244 2.163925886154175
3245 2.1635901927948
3246 2.163254976272583
3247 2.162904977798462
3248 2.162571907043457
3249 2.162231206893921
3250 2.161888837814331
3251 2.1615452766418457
3252 2.1612133979797363
325

3560 2.0649473667144775
3561 2.0646543502807617
3562 2.064359188079834
3563 2.0640738010406494
3564 2.063786745071411
3565 2.0634987354278564
3566 2.0632033348083496
3567 2.0629208087921143
3568 2.0626296997070312
3569 2.06233811378479
3570 2.062046766281128
3571 2.061760187149048
3572 2.0614731311798096
3573 2.0611796379089355
3574 2.060896873474121
3575 2.0606095790863037
3576 2.0603256225585938
3577 2.0600368976593018
3578 2.059749126434326
3579 2.0594680309295654
3580 2.0591840744018555
3581 2.058892011642456
3582 2.0586042404174805
3583 2.058314561843872
3584 2.058026075363159
3585 2.0577383041381836
3586 2.057446002960205
3587 2.05716872215271
3588 2.056879758834839
3589 2.0565948486328125
3590 2.056316375732422
3591 2.0560288429260254
3592 2.055752754211426
3593 2.055454730987549
3594 2.0551767349243164
3595 2.0548946857452393
3596 2.054616689682007
3597 2.054325819015503
3598 2.054043769836426
3599 2.053755521774292
3600 2.0534629821777344
3601 2.0531797409057617
3602 2.0528984

3908 1.9717037677764893
3909 1.971449851989746
3910 1.9712039232254028
3911 1.9709556102752686
3912 1.9707112312316895
3913 1.970456600189209
3914 1.9702069759368896
3915 1.9699642658233643
3916 1.9697129726409912
3917 1.9694677591323853
3918 1.9692294597625732
3919 1.9689809083938599
3920 1.968734622001648
3921 1.9684937000274658
3922 1.9682464599609375
3923 1.9680023193359375
3924 1.9677480459213257
3925 1.9675010442733765
3926 1.9672560691833496
3927 1.9670073986053467
3928 1.9667621850967407
3929 1.9665154218673706
3930 1.9662672281265259
3931 1.9660249948501587
3932 1.9657715559005737
3933 1.9655303955078125
3934 1.9652864933013916
3935 1.9650384187698364
3936 1.964792013168335
3937 1.9645473957061768
3938 1.964300274848938
3939 1.9640575647354126
3940 1.9638121128082275
3941 1.9635705947875977
3942 1.9633252620697021
3943 1.9630851745605469
3944 1.9628398418426514
3945 1.9625983238220215
3946 1.9623610973358154
3947 1.9621155261993408
3948 1.9618730545043945
3949 1.96163201332092

4253 1.8918956518173218
4254 1.8916794061660767
4255 1.8914642333984375
4256 1.891244888305664
4257 1.8910261392593384
4258 1.8908121585845947
4259 1.890592098236084
4260 1.8903768062591553
4261 1.8901662826538086
4262 1.8899531364440918
4263 1.8897407054901123
4264 1.8895231485366821
4265 1.8893158435821533
4266 1.8891019821166992
4267 1.8888835906982422
4268 1.8886679410934448
4269 1.8884541988372803
4270 1.888240098953247
4271 1.8880259990692139
4272 1.8878042697906494
4273 1.8875921964645386
4274 1.8873794078826904
4275 1.8871634006500244
4276 1.8869473934173584
4277 1.8867344856262207
4278 1.886517882347107
4279 1.8862963914871216
4280 1.8860825300216675
4281 1.8858684301376343
4282 1.8856563568115234
4283 1.8854392766952515
4284 1.885225772857666
4285 1.8850139379501343
4286 1.8847932815551758
4287 1.8845854997634888
4288 1.8843741416931152
4289 1.8841614723205566
4290 1.8839482069015503
4291 1.8837416172027588
4292 1.8835270404815674
4293 1.8833143711090088
4294 1.88310444355010

4598 1.822086215019226
4599 1.8218982219696045
4600 1.8217090368270874
4601 1.821520209312439
4602 1.82133150100708
4603 1.821141004562378
4604 1.8209550380706787
4605 1.8207684755325317
4606 1.82057523727417
4607 1.8203871250152588
4608 1.820194959640503
4609 1.8200054168701172
4610 1.8198176622390747
4611 1.819630742073059
4612 1.8194500207901
4613 1.8192607164382935
4614 1.819071888923645
4615 1.8188841342926025
4616 1.818697452545166
4617 1.8185049295425415
4618 1.818312644958496
4619 1.8181216716766357
4620 1.8179351091384888
4621 1.8177486658096313
4622 1.817558765411377
4623 1.817370057106018
4624 1.8171838521957397
4625 1.8169960975646973
4626 1.8168094158172607
4627 1.8166217803955078
4628 1.8164399862289429
4629 1.8162537813186646
4630 1.8160665035247803
4631 1.8158762454986572
4632 1.8156895637512207
4633 1.8155031204223633
4634 1.8153159618377686
4635 1.8151272535324097
4636 1.8149422407150269
4637 1.8147567510604858
4638 1.8145712614059448
4639 1.8143808841705322
4640 1.81

4943 1.7606958150863647
4944 1.7605289220809937
4945 1.7603609561920166
4946 1.760194182395935
4947 1.7600284814834595
4948 1.7598614692687988
4949 1.759695053100586
4950 1.7595332860946655
4951 1.7593662738800049
4952 1.7592005729675293
4953 1.7590291500091553
4954 1.7588634490966797
4955 1.7586958408355713
4956 1.7585285902023315
4957 1.7583632469177246
4958 1.7581965923309326
4959 1.7580312490463257
4960 1.75786554813385
4961 1.7576984167099
4962 1.7575318813323975
4963 1.7573652267456055
4964 1.757197380065918
4965 1.7570325136184692
4966 1.7568665742874146
4967 1.7567001581192017
4968 1.7565343379974365
4969 1.7563691139221191
4970 1.756205439567566
4971 1.756040096282959
4972 1.7558742761611938
4973 1.7557039260864258
4974 1.7555387020111084
4975 1.7553728818893433
4976 1.7552084922790527
4977 1.755043387413025
4978 1.7548774480819702
4979 1.7547134160995483
4980 1.754549264907837
4981 1.7543787956237793
4982 1.7542190551757812
4983 1.7540544271469116
4984 1.7538880109786987
4985