# PyTorch手刻一個多輸入(即有批次)的多層NN架構

In [7]:
%reset -f
import torch

# Tensor type applied to every tensor created in this cell.
dtype = torch.FloatTensor
# Batch size, input width, hidden width, output width.
N, D_in, H, D_out = 128, 784, 120, 10

# Random inputs and regression targets. Gradient tracking is off by default,
# so no explicit requires_grad=False is needed.
x = torch.randn(N, D_in).type(dtype)
y = torch.randn(N, D_out).type(dtype)

# Trainable parameters. Cast first and only then enable gradient tracking:
# when .type() actually has to convert, it returns a new non-leaf tensor whose
# .grad is never populated, so requires_grad must be set on the final tensor.
w1 = torch.randn(D_in, H).type(dtype).requires_grad_()
b1 = torch.randn(H).type(dtype).requires_grad_()

w2 = torch.randn(H, D_out).type(dtype).requires_grad_()
b2 = torch.randn(D_out).type(dtype).requires_grad_()


In [15]:
# Display the shape of the target tensor y.
y.shape

torch.Size([128, 10])

In [9]:
# Full-batch gradient descent on the two-layer network.
lr = 1e-6
for epoch in range(5000):
    # Forward pass: affine -> ReLU (clamp at 0) -> affine.
    a1 = torch.mm(x, w1) + b1
    z1 = a1.clamp(min=0)
    a2 = torch.mm(z1, w2) + b2
    # Sum of squared errors between targets and predictions.
    loss = (y - a2).pow(2).sum()
    loss.backward()

    # Gradient-descent step. The in-place updates run under no_grad so that
    # autograd does not record them; this replaces mutating `.data`, which
    # bypasses autograd's correctness checks.
    with torch.no_grad():
        for param in (w1, w2, b1, b2):
            param -= lr * param.grad
            # backward() accumulates into .grad, so reset it every step.
            param.grad.zero_()

    print(epoch, loss)

0 tensor(1665.5991, grad_fn=<SumBackward0>)
1 tensor(1664.1943, grad_fn=<SumBackward0>)
2 tensor(1662.7991, grad_fn=<SumBackward0>)
3 tensor(1661.3693, grad_fn=<SumBackward0>)
4 tensor(1659.9557, grad_fn=<SumBackward0>)
5 tensor(1658.5460, grad_fn=<SumBackward0>)
6 tensor(1657.1495, grad_fn=<SumBackward0>)
7 tensor(1655.7620, grad_fn=<SumBackward0>)
8 tensor(1654.3899, grad_fn=<SumBackward0>)
9 tensor(1653.0225, grad_fn=<SumBackward0>)
10 tensor(1651.6674, grad_fn=<SumBackward0>)
11 tensor(1650.3196, grad_fn=<SumBackward0>)
12 tensor(1648.9834, grad_fn=<SumBackward0>)
13 tensor(1647.6564, grad_fn=<SumBackward0>)
14 tensor(1646.3386, grad_fn=<SumBackward0>)
15 tensor(1645.0283, grad_fn=<SumBackward0>)
16 tensor(1643.7267, grad_fn=<SumBackward0>)
17 tensor(1642.4353, grad_fn=<SumBackward0>)
18 tensor(1641.1531, grad_fn=<SumBackward0>)
19 tensor(1639.8784, grad_fn=<SumBackward0>)
20 tensor(1638.6125, grad_fn=<SumBackward0>)
21 tensor(1637.3533, grad_fn=<SumBackward0>)
22 tensor(1636.1082,

318 tensor(1444.8464, grad_fn=<SumBackward0>)
319 tensor(1444.5040, grad_fn=<SumBackward0>)
320 tensor(1444.1626, grad_fn=<SumBackward0>)
321 tensor(1443.8215, grad_fn=<SumBackward0>)
322 tensor(1443.4833, grad_fn=<SumBackward0>)
323 tensor(1443.1444, grad_fn=<SumBackward0>)
324 tensor(1442.8059, grad_fn=<SumBackward0>)
325 tensor(1442.4705, grad_fn=<SumBackward0>)
326 tensor(1442.1348, grad_fn=<SumBackward0>)
327 tensor(1441.7998, grad_fn=<SumBackward0>)
328 tensor(1441.4663, grad_fn=<SumBackward0>)
329 tensor(1441.1348, grad_fn=<SumBackward0>)
330 tensor(1440.8030, grad_fn=<SumBackward0>)
331 tensor(1440.4717, grad_fn=<SumBackward0>)
332 tensor(1440.1431, grad_fn=<SumBackward0>)
333 tensor(1439.8132, grad_fn=<SumBackward0>)
334 tensor(1439.4851, grad_fn=<SumBackward0>)
335 tensor(1439.1573, grad_fn=<SumBackward0>)
336 tensor(1438.8313, grad_fn=<SumBackward0>)
337 tensor(1438.5067, grad_fn=<SumBackward0>)
338 tensor(1438.1814, grad_fn=<SumBackward0>)
339 tensor(1437.8584, grad_fn=<Sum

656 tensor(1365.1002, grad_fn=<SumBackward0>)
657 tensor(1364.9269, grad_fn=<SumBackward0>)
658 tensor(1364.7539, grad_fn=<SumBackward0>)
659 tensor(1364.5807, grad_fn=<SumBackward0>)
660 tensor(1364.4080, grad_fn=<SumBackward0>)
661 tensor(1364.2352, grad_fn=<SumBackward0>)
662 tensor(1364.0631, grad_fn=<SumBackward0>)
663 tensor(1363.8914, grad_fn=<SumBackward0>)
664 tensor(1363.7192, grad_fn=<SumBackward0>)
665 tensor(1363.5474, grad_fn=<SumBackward0>)
666 tensor(1363.3759, grad_fn=<SumBackward0>)
667 tensor(1363.2042, grad_fn=<SumBackward0>)
668 tensor(1363.0332, grad_fn=<SumBackward0>)
669 tensor(1362.8621, grad_fn=<SumBackward0>)
670 tensor(1362.6909, grad_fn=<SumBackward0>)
671 tensor(1362.5200, grad_fn=<SumBackward0>)
672 tensor(1362.3503, grad_fn=<SumBackward0>)
673 tensor(1362.1815, grad_fn=<SumBackward0>)
674 tensor(1362.0125, grad_fn=<SumBackward0>)
675 tensor(1361.8431, grad_fn=<SumBackward0>)
676 tensor(1361.6740, grad_fn=<SumBackward0>)
677 tensor(1361.5051, grad_fn=<Sum

990 tensor(1316.0808, grad_fn=<SumBackward0>)
991 tensor(1315.9501, grad_fn=<SumBackward0>)
992 tensor(1315.8195, grad_fn=<SumBackward0>)
993 tensor(1315.6887, grad_fn=<SumBackward0>)
994 tensor(1315.5583, grad_fn=<SumBackward0>)
995 tensor(1315.4280, grad_fn=<SumBackward0>)
996 tensor(1315.2979, grad_fn=<SumBackward0>)
997 tensor(1315.1675, grad_fn=<SumBackward0>)
998 tensor(1315.0372, grad_fn=<SumBackward0>)
999 tensor(1314.9072, grad_fn=<SumBackward0>)
1000 tensor(1314.7772, grad_fn=<SumBackward0>)
1001 tensor(1314.6472, grad_fn=<SumBackward0>)
1002 tensor(1314.5172, grad_fn=<SumBackward0>)
1003 tensor(1314.3875, grad_fn=<SumBackward0>)
1004 tensor(1314.2576, grad_fn=<SumBackward0>)
1005 tensor(1314.1279, grad_fn=<SumBackward0>)
1006 tensor(1313.9983, grad_fn=<SumBackward0>)
1007 tensor(1313.8689, grad_fn=<SumBackward0>)
1008 tensor(1313.7395, grad_fn=<SumBackward0>)
1009 tensor(1313.6101, grad_fn=<SumBackward0>)
1010 tensor(1313.4808, grad_fn=<SumBackward0>)
1011 tensor(1313.3517, 

1168 tensor(1293.9333, grad_fn=<SumBackward0>)
1169 tensor(1293.8149, grad_fn=<SumBackward0>)
1170 tensor(1293.6964, grad_fn=<SumBackward0>)
1171 tensor(1293.5780, grad_fn=<SumBackward0>)
1172 tensor(1293.4597, grad_fn=<SumBackward0>)
1173 tensor(1293.3413, grad_fn=<SumBackward0>)
1174 tensor(1293.2233, grad_fn=<SumBackward0>)
1175 tensor(1293.1051, grad_fn=<SumBackward0>)
1176 tensor(1292.9869, grad_fn=<SumBackward0>)
1177 tensor(1292.8690, grad_fn=<SumBackward0>)
1178 tensor(1292.7510, grad_fn=<SumBackward0>)
1179 tensor(1292.6331, grad_fn=<SumBackward0>)
1180 tensor(1292.5154, grad_fn=<SumBackward0>)
1181 tensor(1292.3975, grad_fn=<SumBackward0>)
1182 tensor(1292.2798, grad_fn=<SumBackward0>)
1183 tensor(1292.1622, grad_fn=<SumBackward0>)
1184 tensor(1292.0446, grad_fn=<SumBackward0>)
1185 tensor(1291.9270, grad_fn=<SumBackward0>)
1186 tensor(1291.8096, grad_fn=<SumBackward0>)
1187 tensor(1291.6921, grad_fn=<SumBackward0>)
1188 tensor(1291.5747, grad_fn=<SumBackward0>)
1189 tensor(1

1503 tensor(1257.3894, grad_fn=<SumBackward0>)
1504 tensor(1257.2891, grad_fn=<SumBackward0>)
1505 tensor(1257.1887, grad_fn=<SumBackward0>)
1506 tensor(1257.0886, grad_fn=<SumBackward0>)
1507 tensor(1256.9885, grad_fn=<SumBackward0>)
1508 tensor(1256.8882, grad_fn=<SumBackward0>)
1509 tensor(1256.7882, grad_fn=<SumBackward0>)
1510 tensor(1256.6881, grad_fn=<SumBackward0>)
1511 tensor(1256.5881, grad_fn=<SumBackward0>)
1512 tensor(1256.4882, grad_fn=<SumBackward0>)
1513 tensor(1256.3882, grad_fn=<SumBackward0>)
1514 tensor(1256.2886, grad_fn=<SumBackward0>)
1515 tensor(1256.1887, grad_fn=<SumBackward0>)
1516 tensor(1256.0889, grad_fn=<SumBackward0>)
1517 tensor(1255.9893, grad_fn=<SumBackward0>)
1518 tensor(1255.8896, grad_fn=<SumBackward0>)
1519 tensor(1255.7899, grad_fn=<SumBackward0>)
1520 tensor(1255.6904, grad_fn=<SumBackward0>)
1521 tensor(1255.5909, grad_fn=<SumBackward0>)
1522 tensor(1255.4915, grad_fn=<SumBackward0>)
1523 tensor(1255.3918, grad_fn=<SumBackward0>)
1524 tensor(1

1838 tensor(1226.3300, grad_fn=<SumBackward0>)
1839 tensor(1226.2446, grad_fn=<SumBackward0>)
1840 tensor(1226.1592, grad_fn=<SumBackward0>)
1841 tensor(1226.0737, grad_fn=<SumBackward0>)
1842 tensor(1225.9884, grad_fn=<SumBackward0>)
1843 tensor(1225.9031, grad_fn=<SumBackward0>)
1844 tensor(1225.8179, grad_fn=<SumBackward0>)
1845 tensor(1225.7325, grad_fn=<SumBackward0>)
1846 tensor(1225.6472, grad_fn=<SumBackward0>)
1847 tensor(1225.5620, grad_fn=<SumBackward0>)
1848 tensor(1225.4771, grad_fn=<SumBackward0>)
1849 tensor(1225.3920, grad_fn=<SumBackward0>)
1850 tensor(1225.3069, grad_fn=<SumBackward0>)
1851 tensor(1225.2218, grad_fn=<SumBackward0>)
1852 tensor(1225.1370, grad_fn=<SumBackward0>)
1853 tensor(1225.0520, grad_fn=<SumBackward0>)
1854 tensor(1224.9670, grad_fn=<SumBackward0>)
1855 tensor(1224.8822, grad_fn=<SumBackward0>)
1856 tensor(1224.7975, grad_fn=<SumBackward0>)
1857 tensor(1224.7128, grad_fn=<SumBackward0>)
1858 tensor(1224.6281, grad_fn=<SumBackward0>)
1859 tensor(1

2171 tensor(1199.4645, grad_fn=<SumBackward0>)
2172 tensor(1199.3915, grad_fn=<SumBackward0>)
2173 tensor(1199.3184, grad_fn=<SumBackward0>)
2174 tensor(1199.2456, grad_fn=<SumBackward0>)
2175 tensor(1199.1727, grad_fn=<SumBackward0>)
2176 tensor(1199.0999, grad_fn=<SumBackward0>)
2177 tensor(1199.0271, grad_fn=<SumBackward0>)
2178 tensor(1198.9541, grad_fn=<SumBackward0>)
2179 tensor(1198.8813, grad_fn=<SumBackward0>)
2180 tensor(1198.8087, grad_fn=<SumBackward0>)
2181 tensor(1198.7361, grad_fn=<SumBackward0>)
2182 tensor(1198.6635, grad_fn=<SumBackward0>)
2183 tensor(1198.5907, grad_fn=<SumBackward0>)
2184 tensor(1198.5181, grad_fn=<SumBackward0>)
2185 tensor(1198.4456, grad_fn=<SumBackward0>)
2186 tensor(1198.3730, grad_fn=<SumBackward0>)
2187 tensor(1198.3007, grad_fn=<SumBackward0>)
2188 tensor(1198.2283, grad_fn=<SumBackward0>)
2189 tensor(1198.1558, grad_fn=<SumBackward0>)
2190 tensor(1198.0835, grad_fn=<SumBackward0>)
2191 tensor(1198.0111, grad_fn=<SumBackward0>)
2192 tensor(1

2505 tensor(1176.9004, grad_fn=<SumBackward0>)
2506 tensor(1176.8381, grad_fn=<SumBackward0>)
2507 tensor(1176.7758, grad_fn=<SumBackward0>)
2508 tensor(1176.7135, grad_fn=<SumBackward0>)
2509 tensor(1176.6511, grad_fn=<SumBackward0>)
2510 tensor(1176.5890, grad_fn=<SumBackward0>)
2511 tensor(1176.5266, grad_fn=<SumBackward0>)
2512 tensor(1176.4644, grad_fn=<SumBackward0>)
2513 tensor(1176.4023, grad_fn=<SumBackward0>)
2514 tensor(1176.3402, grad_fn=<SumBackward0>)
2515 tensor(1176.2782, grad_fn=<SumBackward0>)
2516 tensor(1176.2161, grad_fn=<SumBackward0>)
2517 tensor(1176.1541, grad_fn=<SumBackward0>)
2518 tensor(1176.0920, grad_fn=<SumBackward0>)
2519 tensor(1176.0300, grad_fn=<SumBackward0>)
2520 tensor(1175.9680, grad_fn=<SumBackward0>)
2521 tensor(1175.9061, grad_fn=<SumBackward0>)
2522 tensor(1175.8442, grad_fn=<SumBackward0>)
2523 tensor(1175.7823, grad_fn=<SumBackward0>)
2524 tensor(1175.7205, grad_fn=<SumBackward0>)
2525 tensor(1175.6587, grad_fn=<SumBackward0>)
2526 tensor(1

2844 tensor(1157.3508, grad_fn=<SumBackward0>)
2845 tensor(1157.2976, grad_fn=<SumBackward0>)
2846 tensor(1157.2444, grad_fn=<SumBackward0>)
2847 tensor(1157.1913, grad_fn=<SumBackward0>)
2848 tensor(1157.1382, grad_fn=<SumBackward0>)
2849 tensor(1157.0852, grad_fn=<SumBackward0>)
2850 tensor(1157.0321, grad_fn=<SumBackward0>)
2851 tensor(1156.9789, grad_fn=<SumBackward0>)
2852 tensor(1156.9259, grad_fn=<SumBackward0>)
2853 tensor(1156.8729, grad_fn=<SumBackward0>)
2854 tensor(1156.8199, grad_fn=<SumBackward0>)
2855 tensor(1156.7671, grad_fn=<SumBackward0>)
2856 tensor(1156.7141, grad_fn=<SumBackward0>)
2857 tensor(1156.6613, grad_fn=<SumBackward0>)
2858 tensor(1156.6084, grad_fn=<SumBackward0>)
2859 tensor(1156.5555, grad_fn=<SumBackward0>)
2860 tensor(1156.5027, grad_fn=<SumBackward0>)
2861 tensor(1156.4498, grad_fn=<SumBackward0>)
2862 tensor(1156.3971, grad_fn=<SumBackward0>)
2863 tensor(1156.3442, grad_fn=<SumBackward0>)
2864 tensor(1156.2916, grad_fn=<SumBackward0>)
2865 tensor(1

3188 tensor(1140.4407, grad_fn=<SumBackward0>)
3189 tensor(1140.3953, grad_fn=<SumBackward0>)
3190 tensor(1140.3500, grad_fn=<SumBackward0>)
3191 tensor(1140.3047, grad_fn=<SumBackward0>)
3192 tensor(1140.2595, grad_fn=<SumBackward0>)
3193 tensor(1140.2142, grad_fn=<SumBackward0>)
3194 tensor(1140.1691, grad_fn=<SumBackward0>)
3195 tensor(1140.1239, grad_fn=<SumBackward0>)
3196 tensor(1140.0786, grad_fn=<SumBackward0>)
3197 tensor(1140.0336, grad_fn=<SumBackward0>)
3198 tensor(1139.9883, grad_fn=<SumBackward0>)
3199 tensor(1139.9432, grad_fn=<SumBackward0>)
3200 tensor(1139.8982, grad_fn=<SumBackward0>)
3201 tensor(1139.8531, grad_fn=<SumBackward0>)
3202 tensor(1139.8081, grad_fn=<SumBackward0>)
3203 tensor(1139.7631, grad_fn=<SumBackward0>)
3204 tensor(1139.7180, grad_fn=<SumBackward0>)
3205 tensor(1139.6731, grad_fn=<SumBackward0>)
3206 tensor(1139.6282, grad_fn=<SumBackward0>)
3207 tensor(1139.5831, grad_fn=<SumBackward0>)
3208 tensor(1139.5383, grad_fn=<SumBackward0>)
3209 tensor(1

3524 tensor(1126.2501, grad_fn=<SumBackward0>)
3525 tensor(1126.2102, grad_fn=<SumBackward0>)
3526 tensor(1126.1702, grad_fn=<SumBackward0>)
3527 tensor(1126.1302, grad_fn=<SumBackward0>)
3528 tensor(1126.0901, grad_fn=<SumBackward0>)
3529 tensor(1126.0502, grad_fn=<SumBackward0>)
3530 tensor(1126.0101, grad_fn=<SumBackward0>)
3531 tensor(1125.9702, grad_fn=<SumBackward0>)
3532 tensor(1125.9303, grad_fn=<SumBackward0>)
3533 tensor(1125.8906, grad_fn=<SumBackward0>)
3534 tensor(1125.8507, grad_fn=<SumBackward0>)
3535 tensor(1125.8109, grad_fn=<SumBackward0>)
3536 tensor(1125.7710, grad_fn=<SumBackward0>)
3537 tensor(1125.7313, grad_fn=<SumBackward0>)
3538 tensor(1125.6917, grad_fn=<SumBackward0>)
3539 tensor(1125.6519, grad_fn=<SumBackward0>)
3540 tensor(1125.6121, grad_fn=<SumBackward0>)
3541 tensor(1125.5725, grad_fn=<SumBackward0>)
3542 tensor(1125.5328, grad_fn=<SumBackward0>)
3543 tensor(1125.4933, grad_fn=<SumBackward0>)
3544 tensor(1125.4536, grad_fn=<SumBackward0>)
3545 tensor(1

3860 tensor(1113.9976, grad_fn=<SumBackward0>)
3861 tensor(1113.9644, grad_fn=<SumBackward0>)
3862 tensor(1113.9310, grad_fn=<SumBackward0>)
3863 tensor(1113.8976, grad_fn=<SumBackward0>)
3864 tensor(1113.8643, grad_fn=<SumBackward0>)
3865 tensor(1113.8311, grad_fn=<SumBackward0>)
3866 tensor(1113.7977, grad_fn=<SumBackward0>)
3867 tensor(1113.7645, grad_fn=<SumBackward0>)
3868 tensor(1113.7312, grad_fn=<SumBackward0>)
3869 tensor(1113.6981, grad_fn=<SumBackward0>)
3870 tensor(1113.6648, grad_fn=<SumBackward0>)
3871 tensor(1113.6318, grad_fn=<SumBackward0>)
3872 tensor(1113.5985, grad_fn=<SumBackward0>)
3873 tensor(1113.5653, grad_fn=<SumBackward0>)
3874 tensor(1113.5322, grad_fn=<SumBackward0>)
3875 tensor(1113.4990, grad_fn=<SumBackward0>)
3876 tensor(1113.4659, grad_fn=<SumBackward0>)
3877 tensor(1113.4327, grad_fn=<SumBackward0>)
3878 tensor(1113.3998, grad_fn=<SumBackward0>)
3879 tensor(1113.3667, grad_fn=<SumBackward0>)
3880 tensor(1113.3336, grad_fn=<SumBackward0>)
3881 tensor(1

4199 tensor(1103.5422, grad_fn=<SumBackward0>)
4200 tensor(1103.5135, grad_fn=<SumBackward0>)
4201 tensor(1103.4849, grad_fn=<SumBackward0>)
4202 tensor(1103.4564, grad_fn=<SumBackward0>)
4203 tensor(1103.4279, grad_fn=<SumBackward0>)
4204 tensor(1103.3993, grad_fn=<SumBackward0>)
4205 tensor(1103.3707, grad_fn=<SumBackward0>)
4206 tensor(1103.3422, grad_fn=<SumBackward0>)
4207 tensor(1103.3137, grad_fn=<SumBackward0>)
4208 tensor(1103.2853, grad_fn=<SumBackward0>)
4209 tensor(1103.2567, grad_fn=<SumBackward0>)
4210 tensor(1103.2284, grad_fn=<SumBackward0>)
4211 tensor(1103.2000, grad_fn=<SumBackward0>)
4212 tensor(1103.1714, grad_fn=<SumBackward0>)
4213 tensor(1103.1429, grad_fn=<SumBackward0>)
4214 tensor(1103.1147, grad_fn=<SumBackward0>)
4215 tensor(1103.0863, grad_fn=<SumBackward0>)
4216 tensor(1103.0579, grad_fn=<SumBackward0>)
4217 tensor(1103.0295, grad_fn=<SumBackward0>)
4218 tensor(1103.0012, grad_fn=<SumBackward0>)
4219 tensor(1102.9730, grad_fn=<SumBackward0>)
4220 tensor(1

4534 tensor(1094.7046, grad_fn=<SumBackward0>)
4535 tensor(1094.6802, grad_fn=<SumBackward0>)
4536 tensor(1094.6558, grad_fn=<SumBackward0>)
4537 tensor(1094.6315, grad_fn=<SumBackward0>)
4538 tensor(1094.6071, grad_fn=<SumBackward0>)
4539 tensor(1094.5828, grad_fn=<SumBackward0>)
4540 tensor(1094.5585, grad_fn=<SumBackward0>)
4541 tensor(1094.5342, grad_fn=<SumBackward0>)
4542 tensor(1094.5100, grad_fn=<SumBackward0>)
4543 tensor(1094.4856, grad_fn=<SumBackward0>)
4544 tensor(1094.4612, grad_fn=<SumBackward0>)
4545 tensor(1094.4371, grad_fn=<SumBackward0>)
4546 tensor(1094.4127, grad_fn=<SumBackward0>)
4547 tensor(1094.3885, grad_fn=<SumBackward0>)
4548 tensor(1094.3644, grad_fn=<SumBackward0>)
4549 tensor(1094.3401, grad_fn=<SumBackward0>)
4550 tensor(1094.3159, grad_fn=<SumBackward0>)
4551 tensor(1094.2916, grad_fn=<SumBackward0>)
4552 tensor(1094.2676, grad_fn=<SumBackward0>)
4553 tensor(1094.2433, grad_fn=<SumBackward0>)
4554 tensor(1094.2191, grad_fn=<SumBackward0>)
4555 tensor(1

4878 tensor(1086.9346, grad_fn=<SumBackward0>)
4879 tensor(1086.9137, grad_fn=<SumBackward0>)
4880 tensor(1086.8928, grad_fn=<SumBackward0>)
4881 tensor(1086.8718, grad_fn=<SumBackward0>)
4882 tensor(1086.8511, grad_fn=<SumBackward0>)
4883 tensor(1086.8303, grad_fn=<SumBackward0>)
4884 tensor(1086.8093, grad_fn=<SumBackward0>)
4885 tensor(1086.7886, grad_fn=<SumBackward0>)
4886 tensor(1086.7677, grad_fn=<SumBackward0>)
4887 tensor(1086.7468, grad_fn=<SumBackward0>)
4888 tensor(1086.7261, grad_fn=<SumBackward0>)
4889 tensor(1086.7053, grad_fn=<SumBackward0>)
4890 tensor(1086.6846, grad_fn=<SumBackward0>)
4891 tensor(1086.6638, grad_fn=<SumBackward0>)
4892 tensor(1086.6429, grad_fn=<SumBackward0>)
4893 tensor(1086.6222, grad_fn=<SumBackward0>)
4894 tensor(1086.6014, grad_fn=<SumBackward0>)
4895 tensor(1086.5808, grad_fn=<SumBackward0>)
4896 tensor(1086.5601, grad_fn=<SumBackward0>)
4897 tensor(1086.5393, grad_fn=<SumBackward0>)
4898 tensor(1086.5186, grad_fn=<SumBackward0>)
4899 tensor(1

In [1]:
import torch
from torch.autograd import Variable

# Tensor type applied to every tensor created in this cell.
dtype = torch.FloatTensor
# Batch size, input width, hidden width, output width.
N, D_in, H, D_out = 64, 1000, 100, 10

# Random inputs and targets. Plain tensors are used directly: the Variable
# wrapper is deprecated and tensors carry requires_grad themselves.
x = torch.randn(N, D_in).type(dtype)
y = torch.randn(N, D_out).type(dtype)

# Trainable weights (this network has no bias terms). Gradient tracking is
# enabled after the cast so both weights are leaf tensors.
w1 = torch.randn(D_in, H).type(dtype).requires_grad_()
w2 = torch.randn(H, D_out).type(dtype).requires_grad_()

learning_rate = 1e-6
for t in range(500):
    # Forward pass: linear -> ReLU (clamp at 0) -> linear.
    y_pred = x.mm(w1).clamp(min=0).mm(w2)

    # Sum of squared errors.
    loss = (y_pred - y).pow(2).sum()
    print(loss)
    loss.backward()

    # Gradient-descent step, performed outside of autograd's recording.
    with torch.no_grad():
        for param in (w1, w2):
            param -= learning_rate * param.grad
            # backward() accumulates into .grad, so reset it every step.
            param.grad.zero_()



tensor(22746592., grad_fn=<SumBackward0>)
tensor(15246758., grad_fn=<SumBackward0>)
tensor(11661868., grad_fn=<SumBackward0>)
tensor(9748469., grad_fn=<SumBackward0>)
tensor(8605668., grad_fn=<SumBackward0>)
tensor(7810116., grad_fn=<SumBackward0>)
tensor(7102151.5000, grad_fn=<SumBackward0>)
tensor(6389003.5000, grad_fn=<SumBackward0>)
tensor(5609677., grad_fn=<SumBackward0>)
tensor(4811203., grad_fn=<SumBackward0>)
tensor(4013734.5000, grad_fn=<SumBackward0>)
tensor(3281270.7500, grad_fn=<SumBackward0>)
tensor(2628764.5000, grad_fn=<SumBackward0>)
tensor(2084604.1250, grad_fn=<SumBackward0>)
tensor(1637876.1250, grad_fn=<SumBackward0>)
tensor(1286040.6250, grad_fn=<SumBackward0>)
tensor(1010285.4375, grad_fn=<SumBackward0>)
tensor(798795.5625, grad_fn=<SumBackward0>)
tensor(636363.5625, grad_fn=<SumBackward0>)
tensor(512653.6875, grad_fn=<SumBackward0>)
tensor(417641.1250, grad_fn=<SumBackward0>)
tensor(344626.6250, grad_fn=<SumBackward0>)
tensor(287853.6250, grad_fn=<SumBackward0>)


tensor(2.6461, grad_fn=<SumBackward0>)
tensor(2.5521, grad_fn=<SumBackward0>)
tensor(2.4613, grad_fn=<SumBackward0>)
tensor(2.3740, grad_fn=<SumBackward0>)
tensor(2.2898, grad_fn=<SumBackward0>)
tensor(2.2087, grad_fn=<SumBackward0>)
tensor(2.1304, grad_fn=<SumBackward0>)
tensor(2.0550, grad_fn=<SumBackward0>)
tensor(1.9820, grad_fn=<SumBackward0>)
tensor(1.9120, grad_fn=<SumBackward0>)
tensor(1.8444, grad_fn=<SumBackward0>)
tensor(1.7792, grad_fn=<SumBackward0>)
tensor(1.7163, grad_fn=<SumBackward0>)
tensor(1.6555, grad_fn=<SumBackward0>)
tensor(1.5971, grad_fn=<SumBackward0>)
tensor(1.5407, grad_fn=<SumBackward0>)
tensor(1.4863, grad_fn=<SumBackward0>)
tensor(1.4339, grad_fn=<SumBackward0>)
tensor(1.3833, grad_fn=<SumBackward0>)
tensor(1.3346, grad_fn=<SumBackward0>)
tensor(1.2875, grad_fn=<SumBackward0>)
tensor(1.2421, grad_fn=<SumBackward0>)
tensor(1.1984, grad_fn=<SumBackward0>)
tensor(1.1562, grad_fn=<SumBackward0>)
tensor(1.1155, grad_fn=<SumBackward0>)
tensor(1.0763, grad_fn=<S

In [1]:
import torch
import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms

# Network dimensions.
input_size = 28 * 28    # flattened 28 x 28 image -> 784 input features
hidden_size = 500       # number of nodes in the hidden layer
num_classes = 10        # number of output classes (labels 0 through 9)

# Training settings.
num_epochs = 5          # number of passes over the entire dataset
batch_size = 100        # number of samples taken per iteration
learning_rate = 1e-3    # gradient-descent step size
dtype = torch.FloatTensor

In [2]:
# MNIST training split, stored under ./data; download=True fetches the files
# if they are not already present. Images are converted with ToTensor().
train_dataset = dsets.MNIST(
    root='./data',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)

# MNIST test split from the same root directory (no download requested here).
test_dataset = dsets.MNIST(
    root='./data',
    train=False,
    transform=transforms.ToTensor(),
)

In [3]:
# Raw training images as floats, divided by 255 to rescale pixel values.
# `.data` replaces `.train_data`, which torchvision keeps only as a
# deprecated alias that emits a warning.
x = train_dataset.data.float() / 255

def one_hot_embedding(labels, num_classes):
    """Return one one-hot row per entry of ``labels``.

    Builds a ``num_classes`` x ``num_classes`` identity matrix and indexes it
    with ``labels``, so each label selects the row that has a 1 at that
    label's position and 0 elsewhere.
    """
    return torch.eye(num_classes)[labels]

# Flatten every image into one row of input_size (28 * 28) features.
x = x.view(-1, input_size)
# One-hot encode the training labels. `.targets` replaces `.train_labels`,
# which torchvision keeps only as a deprecated alias that emits a warning.
y = one_hot_embedding(train_dataset.targets, num_classes)



In [5]:
# Trainable parameters. Cast first and only then enable gradient tracking:
# when .type() actually has to convert, it returns a new non-leaf tensor whose
# .grad is never populated, so requires_grad must be set on the final tensor.
w1 = torch.randn(input_size, hidden_size).type(dtype).requires_grad_()
b1 = torch.randn(hidden_size).type(dtype).requires_grad_()

w2 = torch.randn(hidden_size, num_classes).type(dtype).requires_grad_()
b2 = torch.randn(num_classes).type(dtype).requires_grad_()


In [19]:
# Display the first layer's pre-activation output: x times w1, plus b1.
torch.mm(x, w1) + b1

tensor([[-23.2466,  -2.3021,  -5.1985,  ...,  -3.1002,  12.0090,  -5.2520],
        [-33.3994,  -3.8443, -14.6838,  ...,  -1.5646,   2.1384,  -8.2969],
        [-10.8907,  -8.1168,  -8.3095,  ...,   3.5389,   6.6156, -12.4017],
        ...,
        [-10.6083,   0.8321,   2.1596,  ..., -12.6337,  13.0580,  -7.1332],
        [ -6.5841,  -0.8262,   1.5488,  ...,  -2.2759,   3.0573,   4.3280],
        [ -6.7206,   3.2759, -19.2771,  ...,  10.6873,  -7.8474,  -3.8390]],
       grad_fn=<AddBackward0>)

In [6]:
# Full-batch gradient descent on the two-layer network.
lr = 1e-3
for epoch in range(500):
    # Forward pass: affine -> ReLU (clamp at 0) -> affine.
    a1 = torch.mm(x, w1) + b1
    z1 = a1.clamp(min=0)
    a2 = torch.mm(z1, w2) + b2
    # Squared error summed over all entries, averaged over the number of
    # samples. The sample count is read from x instead of hard-coding 60000.
    loss = (y - a2).pow(2).sum() / x.shape[0]
    loss.backward()

    # Gradient-descent step. The in-place updates run under no_grad so that
    # autograd does not record them; this replaces mutating `.data`, which
    # bypasses autograd's correctness checks.
    with torch.no_grad():
        for param in (w1, w2, b1, b2):
            param -= lr * param.grad
            # backward() accumulates into .grad, so reset it every step.
            param.grad.zero_()

    print(epoch, loss.item())

0 231828.421875
1 177387184.0
2 2790915.25
3 528.4381103515625
4 526.2119750976562
5 523.9876708984375
6 521.78564453125
7 519.5919799804688
8 517.4118041992188
9 515.241943359375
10 513.0802001953125
11 510.9402770996094
12 508.8017578125
13 506.68145751953125
14 504.5646057128906
15 502.46624755859375
16 500.3731689453125
17 498.29296875
18 496.2196044921875
19 494.162353515625
20 492.1149597167969
21 490.0758056640625
22 488.04620361328125
23 486.0256042480469
24 484.0165100097656
25 482.01806640625
26 480.0275573730469
27 478.04833984375
28 476.0764465332031
29 474.1175231933594
30 472.1629943847656
31 470.2198181152344
32 468.2865295410156
33 466.3604736328125
34 464.44805908203125
35 462.5404968261719
36 460.6458740234375
37 458.753173828125
38 456.8763427734375
39 455.000732421875
40 453.1397399902344
41 451.2851257324219
42 449.44158935546875
43 447.60308837890625
44 445.7754211425781
45 443.9538269042969
46 442.13861083984375
47 440.3373718261719
48 438.5414123535156
49 436.75

375 119.4216537475586
376 118.95648193359375
377 118.4921646118164
378 118.0294189453125
379 117.56939697265625
380 117.11161041259766
381 116.65460205078125
382 116.20045471191406
383 115.7474594116211
384 115.29501342773438
385 114.84513092041016
386 114.39887237548828
387 113.95283508300781
388 113.5086669921875
389 113.06613159179688
390 112.6254653930664
391 112.1878890991211
392 111.7512435913086
393 111.31485748291016
394 110.88218688964844
395 110.45010375976562
396 110.01988220214844
397 109.59175872802734
398 109.16578674316406
399 108.74066925048828
400 108.31889343261719
401 107.89625549316406
402 107.47704315185547
403 107.05873107910156
404 106.64266967773438
405 106.22789764404297
406 105.81465148925781
407 105.40287780761719
408 104.99365234375
409 104.58642578125
410 104.17921447753906
411 103.77462005615234
412 103.37114715576172
413 102.96920776367188
414 102.56930541992188
415 102.17161560058594
416 101.77488708496094
417 101.37995147705078
418 100.98600006103516
41