# Implementing neural networks

Today we'll work through implementing a neural network for a simple regression problem together. Then I'll turn you loose to adapt this methodology to the MNIST problem.

We're going to use a new library to implement this network. This library is called pytorch, and you can install it by following the instructions found [here](https://pytorch.org/get-started/locally/). Why are we not using numpy? We'll return to that in a moment. In the meantime, note that torch behaves a lot like numpy. For example, we can generate a synthetic dataset (similar to the one from Homework 2) using some familiar commands:


In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt

# Generate the features
x = torch.linspace(-2,2,101).reshape(-1,1)

# Generate the response variables 
y_obs = x**2 + x + torch.cos(2*np.pi*x) + torch.randn_like(x)*0.3

plt.plot(x,y_obs,'k.')
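As a quick illustration of that similarity, here is a minimal sketch (not part of the original exercise) showing that elementwise arithmetic looks the same in both libraries, and that `torch.from_numpy` shares memory with the source array, so editing one changes the other.

```python
import numpy as np
import torch

# The same column vector built in each library (float64 in both)
a_np = np.linspace(-2, 2, 5).reshape(-1, 1)
a_t = torch.linspace(-2, 2, 5, dtype=torch.float64).reshape(-1, 1)

# Elementwise arithmetic and broadcasting look identical
print(np.allclose(a_np**2 + 1, (a_t**2 + 1).numpy()))  # True

# torch.from_numpy does not copy: the tensor views the array's memory
shared = torch.from_numpy(a_np)
a_np[0, 0] = 100.0
print(shared[0, 0].item())  # 100.0, because the memory is shared
```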

We'd like to find a function that fits this data. One way to do this is, of course, linear regression, but that requires specifying the form of the design matrix. As we saw in lecture, we'd rather learn the design matrix from the data, and we'll use a neural network to do so. Algebraically, the network is the following sequence of functions, where $\sigma$ is the sigmoid function applied elementwise and the bias rows are broadcast across all $m$ rows:
$$
\underbrace{z}_{m\times p} = \underbrace{x}_{m\times 1} \underbrace{W^{(1)}}_{1\times p} + \underbrace{b^{(1)}}_{1\times p}
$$
$$
h = \sigma(z)
$$
$$
\underbrace{y}_{m\times 1} = \underbrace{h}_{m\times p} \underbrace{W^{(2)}}_{p \times 1} + \underbrace{b^{(2)}}_{1\times 1}
$$
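To make the shapes concrete, here is a minimal sketch that evaluates these three equations directly with random weights, using the same sizes as the example ($m = 101$, $p = 20$). It also illustrates the "learned design matrix" view: for fixed $W^{(1)}$ and $b^{(1)}$, the last equation is ordinary linear regression with design matrix $[h, 1]$, so `torch.linalg.lstsq` can fit $W^{(2)}$ and $b^{(2)}$. The target `x**2` is an arbitrary choice for illustration.

```python
import torch

torch.manual_seed(0)
m, p = 101, 20  # m samples, p hidden nodes (same sizes as the example)

x = torch.linspace(-2, 2, m).reshape(-1, 1)     # m x 1
W1, b1 = torch.randn(1, p), torch.randn(1, p)   # 1 x p each
W2, b2 = torch.randn(p, 1), torch.randn(1, 1)   # p x 1 and 1 x 1

z = x @ W1 + b1       # (m x 1)(1 x p) + (1 x p) -> m x p; b1 broadcast over rows
h = torch.sigmoid(z)  # elementwise, still m x p
y = h @ W2 + b2       # (m x p)(p x 1) + (1 x 1) -> m x 1
print(z.shape, h.shape, y.shape)

# With W1, b1 held fixed, the output layer is plain linear regression:
# append a column of ones to h and solve least squares for [W2; b2].
y_target = x**2
design = torch.cat([h, torch.ones(m, 1)], dim=1)      # m x (p + 1)
coef = torch.linalg.lstsq(design, y_target).solution  # (p + 1) x 1
W2_ls, b2_ls = coef[:p], coef[p:]
```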

Because this is getting a bit complicated, let's create a class for this neural network that holds the weights and applies the functions in order.

In [None]:
class NeuralNet(object):
    
    def __init__(self,n,p,N):
        self.n = n   # Number of features (1 for univariate problem)
        self.p = p   # Number of nodes in the hidden layer
        self.N = N   # Number of outputs (1 for the regression problem)
        
        # Instantiate weight matrices 
        self.W_1 = torch.randn(n,p)*10
        self.W_2 = torch.randn(p,N)/np.sqrt(p)*1
        
        # Instantiate bias vectors (Why do we need this?)
        self.b_1 = torch.randn(1,p)*10
        self.b_2 = torch.randn(1,N)/np.sqrt(p)*1
               
    def forward(self,X):
        # Applies the neural network model
        ## All of these self. prefixes save calculation results
        ## as class variables - we can inspect them later if we
        ## wish to
        self.X = X
        self.z = self.X @ self.W_1 + self.b_1  # First linear 
        self.h = torch.sigmoid(self.z)         # Activation
        self.y = self.h @ self.W_2 + self.b_2  # Second linear
        
        return self.y


You'll notice that we're initializing the weights randomly. Let's see what kinds of functions this model produces before training. It's instructive to experiment with the variance of the initial weights (for example, the `*10` factors on `W_1` and `b_1`) and see how the resulting functions change.

In [None]:
# Sample 10 random neural nets
for i in range(10):
    
    # Create the neural network
    net = NeuralNet(1,20,1)
    
    # Make a prediction
    y_pred = net.forward(x)
    
    # Plot the predictions
    plt.plot(x,y_pred)
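One way to quantify the effect of the initialization variance is to count how many hidden pre-activations $z$ land where the sigmoid is nearly flat. The sketch below assumes a threshold of $|z| > 4$ (where $\sigma'(z) = \sigma(z)(1-\sigma(z)) < 0.02$), which is an arbitrary choice, and uses a hypothetical helper `saturated_fraction`. Larger initial scales push more units into saturation, where each behaves like a sharp step.

```python
import torch

def saturated_fraction(scale, p=20, m=101, seed=0):
    """Hypothetical helper: fraction of hidden pre-activations with |z| > 4.

    `scale` multiplies standard-normal W_1 and b_1, mirroring the `*10`
    factors in the NeuralNet constructor. The threshold 4 is an assumption.
    """
    gen = torch.Generator().manual_seed(seed)  # same draws for every scale
    x = torch.linspace(-2, 2, m).reshape(-1, 1)
    W_1 = torch.randn(1, p, generator=gen) * scale
    b_1 = torch.randn(1, p, generator=gen) * scale
    z = x @ W_1 + b_1
    return (z.abs() > 4).float().mean().item()

for scale in [0.1, 1.0, 10.0]:
    print(scale, saturated_fraction(scale))
```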


Of course, this isn't all that interesting on its own. We now need to train the network. We'll do this using gradient descent, and herein lies the power of pytorch: it is a framework for *automatic differentiation*. What does this mean? It means that pytorch keeps a record of all of the operations used to produce the output of a given function. It can then *automatically* apply the chain rule to compute derivatives of that output with respect to anything that was used to compute it. Here, we want the gradient with respect to the weights and biases. We tell pytorch that we'll need these gradients by calling each parameter's `requires_grad_()` method, which sets its `requires_grad` flag to `True`. This is also why we aren't using numpy: numpy has no automatic differentiation, so we would have to derive and code every gradient by hand.
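Before wiring this into the network, here is a minimal sketch of what automatic differentiation does for a scalar function. For $f(w) = w^2 + 2w$ the derivative is $2w + 2$, which equals 8 at $w = 3$; `backward()` recovers this from the recorded operations.

```python
import torch

# A leaf tensor that autograd should track
w = torch.tensor(3.0, requires_grad=True)

# Build a small computation: f(w) = w**2 + 2*w
f = w**2 + 2 * w

# backward() applies the chain rule through the recorded operations
f.backward()
print(w.grad)  # tensor(8.), since df/dw = 2*w + 2 = 8 at w = 3

# requires_grad_() flips the flag in place on an existing tensor
v = torch.randn(2, 3)
v.requires_grad_()
print(v.requires_grad)  # True
```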

In [None]:
class NeuralNet(object):
    
    def __init__(self,n,p,N):
        self.n = n   # Number of features (1 for univariate problem)
        self.p = p   # Number of nodes in the hidden layer
        self.N = N   # Number of outputs (1 for the regression problem)
        
        # Instantiate weight matrices 
        self.W_1 = torch.randn(n,p)*10
        self.W_2 = torch.randn(p,N)/np.sqrt(p)
        
        # Instantiate bias vectors (Why do we need this?)
        self.b_1 = torch.randn(1,p)*10
        self.b_2 = torch.randn(1,N)/np.sqrt(p)
        
        ### CHANGE FROM ABOVE ###  
        # Collect the model parameters, and tell pytorch to
        # collect gradient information about them.
        self.parameters = [self.W_1,self.W_2,self.b_1,self.b_2]
        for param in self.parameters:
            param.requires_grad_()

    def forward(self,X):
        # Applies the neural network model
        ## All of these self. prefixes save calculation results
        ## as class variables - we can inspect them later if we
        ## wish to
        self.X = X
        self.z = self.X @ self.W_1 + self.b_1  # First linear 
        self.h = torch.sigmoid(self.z)         # Activation
        self.y = self.h @ self.W_2 + self.b_2  # Second linear
        return self.y
    
    def zero_grad(self):
        ### Each parameter has an additional array associated
        ### with it to store its gradient.  This is not 
        ### automatically cleared, so we have a method to
        ### clear it.
        for param in self.parameters:
            try:
                param.grad.data[:] = 0.0
            except AttributeError:
                pass
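Why does the `zero_grad` method matter? The sketch below shows that each call to `backward()` *adds* to the `.grad` buffer rather than overwriting it, so skipping the reset would mix gradients from different iterations. It uses `Tensor.zero_()`, a built-in in-place alternative to `grad.data[:] = 0.0`.

```python
import torch

w = torch.tensor(1.0, requires_grad=True)

# Two backward passes without clearing the buffer in between
for _ in range(2):
    loss = 2 * w       # d(loss)/dw = 2
    loss.backward()    # adds 2 to whatever is already stored in w.grad
accumulated = w.grad.item()
print(accumulated)     # 4.0: the two gradients were summed

# Clear the buffer in place, then take a single backward pass
w.grad.zero_()
loss = 2 * w
loss.backward()
print(w.grad.item())   # 2.0
```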
        

One thing we still need is something to minimize. Since this is a regression problem, we'll use the mean squared error:

In [None]:
def mse(y_pred,y_obs):
    m = y_pred.shape[0]
    return 1./m*((y_pred-y_obs)**2).sum()
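As a sanity check, the hand-rolled `mse` can be compared against pytorch's built-in `torch.nn.functional.mse_loss`, whose default `reduction='mean'` averages over all elements. For an $m\times 1$ prediction the two agree; for outputs with more than one column, `mse` divides only by $m$, so it comes out larger by a factor of the number of columns.

```python
import torch
import torch.nn.functional as F

def mse(y_pred, y_obs):
    m = y_pred.shape[0]
    return 1. / m * ((y_pred - y_obs) ** 2).sum()

# Single-output case: (0^2 + 0.5^2 + 2^2) / 3 = 4.25 / 3
y_pred = torch.tensor([[1.0], [2.0], [3.0]])
y_obs = torch.tensor([[1.0], [2.5], [5.0]])
print(mse(y_pred, y_obs).item(), F.mse_loss(y_pred, y_obs).item())

# Two-output case: the hand-rolled version sums over columns, mse_loss averages
y2_pred, y2_obs = torch.zeros(3, 2), torch.ones(3, 2)
print(mse(y2_pred, y2_obs).item(), F.mse_loss(y2_pred, y2_obs).item())  # 2.0 1.0
```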

Now, the code for gradient descent becomes strikingly simple:

In [None]:
net = NeuralNet(1,20,1)  # Instantiate network
eta = 1e-1               # Set learning rate (empirically derived)
for t in range(50000):   # run for 50000 epochs
    y_pred = net.forward(x)   # Make a prediction
    L = mse(y_pred,y_obs)     # Compute mse
    net.zero_grad()           # Clear gradient buffer
    L.backward()              # MAGIC: compute dL/d parameter
    for param in net.parameters:            # update parameters w/
        param.data -= eta*param.grad.data   # GD
        
    if t%100==0:         # Print loss    
        print(t,L.item())
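The update `param.data -= eta*param.grad.data` is plain gradient descent. The sketch below, on a toy one-weight model assumed for illustration, checks that it matches one step of `torch.optim.SGD` with default settings (no momentum, no weight decay). The manual branch uses `with torch.no_grad():`, which avoids `.data` and is the more common idiom in current pytorch code.

```python
import torch

x = torch.linspace(-2, 2, 11).reshape(-1, 1)
y_obs = x**2
eta = 1e-1

def one_step(use_optim):
    torch.manual_seed(1)  # identical initial weights for both variants
    W = torch.randn(1, 1, requires_grad=True)
    b = torch.zeros(1, 1, requires_grad=True)
    params = [W, b]

    L = ((x @ W + b - y_obs) ** 2).mean()
    if use_optim:
        opt = torch.optim.SGD(params, lr=eta)
        opt.zero_grad()
        L.backward()
        opt.step()                 # p <- p - lr * p.grad
    else:
        L.backward()               # grads start as None, so nothing to clear
        with torch.no_grad():      # update without recording it in the graph
            for param in params:
                param -= eta * param.grad
    return W.detach().clone(), b.detach().clone()

W_manual, b_manual = one_step(use_optim=False)
W_optim, b_optim = one_step(use_optim=True)
print(torch.allclose(W_manual, W_optim), torch.allclose(b_manual, b_optim))
```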
    
        


Now we can plot our model's predictions against the observations. Pretty good, and with no manual selection of basis functions.

In [None]:
plt.plot(x.detach().squeeze(),y_pred.detach().squeeze())
plt.plot(x.detach().squeeze(),y_obs.detach().squeeze(),'k.')

## Applying an MLP to MNIST:
Train a neural network on MNIST using pytorch. You should use the above code as a template. Things you'll need to change: $n$ will no longer be 1, but rather 784 (one feature per pixel of a $28\times 28$ image). $N$ will no longer be 1, but rather 10 (one output per digit class). You'll want to adjust $p$, the number of hidden-layer nodes, and you'll likely need to adjust the learning rate. Finally, and most importantly, you'll need to use a different loss function: replace our hand-rolled MSE code with [this](https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html). IMPORTANT NOTE: this loss expects *logits* as inputs, which is to say that it applies the (log-)softmax for you internally. As such, the architecture of your network should be more or less the same as above, with no softmax on the output layer.
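To see what "expects logits" means, here is a sketch with random logits and labels (stand-ins, not MNIST data): `CrossEntropyLoss` equals the mean negative log-softmax of the true class, and feeding it softmax probabilities instead gives a different number, because softmax then gets applied twice. Targets are integer class indices of dtype `torch.long`.

```python
import torch

torch.manual_seed(0)
logits = torch.randn(4, 10)           # 4 samples, 10 classes: raw network outputs
targets = torch.tensor([3, 0, 9, 3])  # class indices, dtype torch.long (int64)

loss_fn = torch.nn.CrossEntropyLoss()  # default reduction='mean'
L_builtin = loss_fn(logits, targets)

# Same quantity by hand: log-softmax, pick the true class, negate, average
log_probs = torch.log_softmax(logits, dim=1)
L_manual = -log_probs[torch.arange(4), targets].mean()
print(L_builtin.item(), L_manual.item())

# Passing probabilities instead of logits applies softmax a second time
L_wrong = loss_fn(torch.softmax(logits, dim=1), targets)
print(L_wrong.item())
```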

In [1]:
from sklearn.datasets import fetch_openml
X, y = fetch_openml('mnist_784', version=1, return_X_y=True, as_frame=False)
X = X[:5000]
y = y[:5000]

In [2]:
X -= X.mean()
X /= X.std()

In [3]:

class NeuralNet(object):
    
    def __init__(self,n,p,N):
        self.n = n   # Number of features (784 pixels for MNIST)
        self.p = p   # Number of nodes in the hidden layer
        self.N = N   # Number of outputs (10 digit classes for MNIST)
        
        # Instantiate weight matrices 
        self.W_1 = torch.randn(n,p)*1e-3
        self.W_2 = torch.randn(p,N)*1e-3
        
        # Instantiate bias vectors (Why do we need this? They play the role
        # of the appended column of ones in a linear-regression design matrix.)
        self.b_1 = torch.randn(1,p)
        self.b_2 = torch.randn(1,N)
        
        ### CHANGE FROM ABOVE ###  
        # Collect the model parameters, and tell pytorch to
        # collect gradient information about them.
        self.parameters = [self.W_1,self.W_2,self.b_1,self.b_2]
        for param in self.parameters:
            param.requires_grad_()
            
    def forward(self,X):
        # Applies the neural network model
        ## All of these self. prefixes save calculation results
        ## as class variables - we can inspect them later if we
        ## wish to
        self.X = X
        self.a = self.X @ self.W_1 + self.b_1  # First linear 
        self.z = torch.sigmoid(self.a)         # Activation
        self.y = self.z @ self.W_2 + self.b_2  # Second linear
        return self.y
    
    def zero_grad(self):
        ### Each parameter has an additional array associated
        ### with it to store its gradient.  This is not 
        ### automatically cleared, so we have a method to
        ### clear it. (Set gradient to zero before calculating again)
        for param in self.parameters:
            try:
                param.grad.data[:] = 0.0
            except AttributeError:
                pass
        

In [4]:
import torch
X_train = torch.from_numpy(X).to(torch.float)
y_train = torch.from_numpy(y.astype(int)).to(torch.long)
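The accuracy printed during training is computed on the same 5000 images the network is fit to, so it measures how well the model fits rather than how well it generalizes. A minimal sketch of holding out a test set follows; `split_indices` and `accuracy` are hypothetical helpers, and random stand-in tensors with MNIST-like shapes replace `X_train`, `y_train` so the block runs on its own.

```python
import torch

def split_indices(n, test_frac=0.2, seed=0):
    """Hypothetical helper: random, disjoint train/test index tensors."""
    gen = torch.Generator().manual_seed(seed)
    perm = torch.randperm(n, generator=gen)
    n_test = int(n * test_frac)
    return perm[n_test:], perm[:n_test]

def accuracy(logits, labels):
    # Fraction of rows whose largest logit is the true class
    return (logits.argmax(dim=1) == labels).float().mean().item()

# Stand-in data with MNIST-like shapes; substitute X_train, y_train from above
X_all = torch.randn(5000, 784)
y_all = torch.randint(0, 10, (5000,))

train_idx, test_idx = split_indices(len(y_all))
X_tr, y_tr = X_all[train_idx], y_all[train_idx]
X_te, y_te = X_all[test_idx], y_all[test_idx]

# After training on (X_tr, y_tr) only, evaluate without building a graph:
# with torch.no_grad():
#     print(accuracy(net.forward(X_te), y_te))
```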


In [None]:
net = NeuralNet(784,200,10)  # Instantiate network
eta = 1e-2               # Set learning rate (empirically derived)
loss = torch.nn.CrossEntropyLoss()
for t in range(50000):   # run for 50000 epochs
    y_pred = net.forward(X_train)   # Make a prediction
    L = loss(y_pred,y_train)     # Compute categorical cross entropy
    net.zero_grad()           # Clear gradient buffer
    L.backward()              # MAGIC: compute dL/d parameter
    for param in net.parameters:            # update parameters w/
        param.data -= eta*param.grad.data   # GD
        
    if t%10==0:         # Print loss and training accuracy
        print(t,L.item(),(torch.argmax(y_pred,axis=1)==y_train).sum()/len(y_train))
        
    

0 2.4497275352478027 tensor(0.0958)
10 2.3451731204986572 tensor(0.0958)
20 2.3133301734924316 tensor(0.1100)
30 2.3032419681549072 tensor(0.1100)
40 2.299819231033325 tensor(0.1100)
50 2.2984297275543213 tensor(0.1128)
60 2.297654867172241 tensor(0.1184)
70 2.297062397003174 tensor(0.1126)
80 2.296522855758667 tensor(0.1126)
90 2.295994758605957 tensor(0.1126)
100 2.2954659461975098 tensor(0.1126)
110 2.2949328422546387 tensor(0.1126)
120 2.294393539428711 tensor(0.1126)
130 2.293846368789673 tensor(0.1126)
140 2.293290376663208 tensor(0.1126)
150 2.2927260398864746 tensor(0.1126)
160 2.2921512126922607 tensor(0.1128)
170 2.29156494140625 tensor(0.1130)
180 2.2909674644470215 tensor(0.1132)
190 2.2903568744659424 tensor(0.1134)
200 2.2897329330444336 tensor(0.1140)
210 2.2890942096710205 tensor(0.1142)
220 2.2884387969970703 tensor(0.1142)
230 2.2877678871154785 tensor(0.1146)
240 2.2870776653289795 tensor(0.1152)
250 2.2863683700561523 tensor(0.1160)

2150 0.8959529995918274 tensor(0.7632)
2160 0.8910797834396362 tensor(0.7646)
2170 0.8862506151199341 tensor(0.7650)
2180 0.8814650177955627 tensor(0.7664)
2190 0.8767220973968506 tensor(0.7678)
2200 0.8720213770866394 tensor(0.7696)
2210 0.8673626184463501 tensor(0.7712)
2220 0.8627450466156006 tensor(0.7738)
2230 0.8581682443618774 tensor(0.7760)
2240 0.8536316156387329 tensor(0.7774)
2250 0.8491347432136536 tensor(0.7790)
2260 0.8446772694587708 tensor(0.7802)
2270 0.8402584195137024 tensor(0.7814)
2280 0.8358781337738037 tensor(0.7828)
2290 0.8315359354019165 tensor(0.7840)
2300 0.8272313475608826 tensor(0.7852)
2310 0.8229638934135437 tensor(0.7876)
2320 0.8187333345413208 tensor(0.7888)
2330 0.8145394325256348 tensor(0.7900)
2340 0.8103814721107483 tensor(0.7920)
2350 0.8062594532966614 tensor(0.7934)
2360 0.802172839641571 tensor(0.7938)
2370 0.7981215119361877 tensor(0.7956)
2380 0.7941048741340637 tensor(0.7970)
2390 0.7901228666305542 tensor(0.7978)

4250 0.4205106496810913 tensor(0.8956)
4260 0.4196189045906067 tensor(0.8958)
4270 0.41873282194137573 tensor(0.8962)
4280 0.4178522825241089 tensor(0.8962)
4290 0.4169773459434509 tensor(0.8964)
4300 0.4161078631877899 tensor(0.8964)
4310 0.41524383425712585 tensor(0.8964)
4320 0.4143851101398468 tensor(0.8966)
4330 0.4135317802429199 tensor(0.8968)
4340 0.41268372535705566 tensor(0.8970)
4350 0.4118409752845764 tensor(0.8970)
4360 0.4110034108161926 tensor(0.8972)
4370 0.41017091274261475 tensor(0.8976)
4380 0.40934354066848755 tensor(0.8976)
4390 0.40852120518684387 tensor(0.8978)
4400 0.4077039062976837 tensor(0.8978)
4410 0.4068915545940399 tensor(0.8982)
4420 0.4060841500759125 tensor(0.8982)
4430 0.40528157353401184 tensor(0.8982)
4440 0.4044838845729828 tensor(0.8984)
4450 0.4036909341812134 tensor(0.8986)
4460 0.4029027223587036 tensor(0.8988)
4470 0.4021192789077759 tensor(0.8988)
4480 0.40134039521217346 tensor(0.8988)
4490 0.40056630969047546 tensor(0.8988)

6330 0.3082233667373657 tensor(0.9188)
6340 0.30789250135421753 tensor(0.9188)
6350 0.3075627088546753 tensor(0.9192)
6360 0.30723416805267334 tensor(0.9192)
6370 0.30690670013427734 tensor(0.9194)
6380 0.30658039450645447 tensor(0.9194)
6390 0.30625516176223755 tensor(0.9194)
6400 0.305931031703949 tensor(0.9198)
6410 0.30560800433158875 tensor(0.9198)
6420 0.30528607964515686 tensor(0.9198)
6430 0.3049652874469757 tensor(0.9198)
6440 0.3046455383300781 tensor(0.9200)
6450 0.3043268322944641 tensor(0.9202)
6460 0.30400922894477844 tensor(0.9204)
6470 0.30369269847869873 tensor(0.9204)
6480 0.30337727069854736 tensor(0.9204)
6490 0.3030627965927124 tensor(0.9204)
6500 0.3027493953704834 tensor(0.9204)
6510 0.30243703722953796 tensor(0.9206)
6520 0.3021257221698761 tensor(0.9208)
6530 0.301815390586853 tensor(0.9208)
6540 0.30150604248046875 tensor(0.9210)
6550 0.3011978566646576 tensor(0.9210)
6560 0.30089056491851807 tensor(0.9210)
6570 0.30058425664901733 tensor(0.9212)

8410 0.25660091638565063 tensor(0.9308)
8420 0.2564108967781067 tensor(0.9308)
8430 0.256221204996109 tensor(0.9310)
8440 0.25603199005126953 tensor(0.9310)
8450 0.2558431029319763 tensor(0.9312)
8460 0.2556546628475189 tensor(0.9312)
8470 0.2554665505886078 tensor(0.9314)
8480 0.2552787959575653 tensor(0.9314)
8490 0.25509148836135864 tensor(0.9314)
8500 0.25490450859069824 tensor(0.9314)
8510 0.25471797585487366 tensor(0.9316)
8520 0.25453174114227295 tensor(0.9316)
8530 0.2543458640575409 tensor(0.9318)
8540 0.25416040420532227 tensor(0.9320)
8550 0.2539752721786499 tensor(0.9320)
8560 0.25379058718681335 tensor(0.9322)
8570 0.2536062002182007 tensor(0.9322)
8580 0.25342220067977905 tensor(0.9322)
8590 0.2532385587692261 tensor(0.9322)
8600 0.25305524468421936 tensor(0.9322)
8610 0.2528723180294037 tensor(0.9322)
8620 0.25268980860710144 tensor(0.9324)
8630 0.2525075376033783 tensor(0.9324)
8640 0.25232571363449097 tensor(0.9326)
8650 0.2521441876888275 tensor(0.9326)

10470 0.22382287681102753 tensor(0.9404)
10480 0.2236880660057068 tensor(0.9404)
10490 0.22355341911315918 tensor(0.9404)
10500 0.2234189659357071 tensor(0.9404)
10510 0.22328469157218933 tensor(0.9404)
10520 0.2231506109237671 tensor(0.9402)
10530 0.22301669418811798 tensor(0.9402)
10540 0.22288298606872559 tensor(0.9402)
10550 0.22274944186210632 tensor(0.9402)
10560 0.22261609137058258 tensor(0.9402)
10570 0.22248288989067078 tensor(0.9404)
10580 0.2223498821258545 tensor(0.9404)
10590 0.22221703827381134 tensor(0.9404)
10600 0.22208437323570251 tensor(0.9406)
10610 0.22195185720920563 tensor(0.9406)
10620 0.22181957960128784 tensor(0.9406)
10630 0.221687451004982 tensor(0.9408)
10640 0.22155548632144928 tensor(0.9408)
10650 0.22142373025417328 tensor(0.9408)
10660 0.22129210829734802 tensor(0.9408)
10670 0.22116069495677948 tensor(0.9410)
10680 0.22102941572666168 tensor(0.9410)
10690 0.2208983451128006 tensor(0.9410)
10700 0.22076743841171265 tensor(0.9412)

12490 0.19968371093273163 tensor(0.9466)
12500 0.19957709312438965 tensor(0.9466)
12510 0.1994706243276596 tensor(0.9466)
12520 0.19936421513557434 tensor(0.9468)
12530 0.19925794005393982 tensor(0.9468)
12540 0.19915176928043365 tensor(0.9468)
12550 0.19904567301273346 tensor(0.9468)
12560 0.19893968105316162 tensor(0.9468)
12570 0.19883380830287933 tensor(0.9468)
12580 0.19872808456420898 tensor(0.9468)
12590 0.19862240552902222 tensor(0.9468)
12600 0.198516845703125 tensor(0.9468)
12610 0.19841139018535614 tensor(0.9468)
12620 0.19830606877803802 tensor(0.9466)
12630 0.19820082187652588 tensor(0.9466)
12640 0.1980956792831421 tensor(0.9466)
12650 0.19799062609672546 tensor(0.9466)
12660 0.19788570702075958 tensor(0.9466)
12670 0.19778083264827728 tensor(0.9466)
12680 0.19767612218856812 tensor(0.9466)
12690 0.1975715160369873 tensor(0.9466)
12700 0.19746698439121246 tensor(0.9466)
12710 0.19736255705356598 tensor(0.9468)
12720 0.19725823402404785 tensor(0.9468)

14510 0.180024653673172 tensor(0.9524)
14520 0.1799355298280716 tensor(0.9524)
14530 0.17984645068645477 tensor(0.9524)
14540 0.1797574758529663 tensor(0.9524)
14550 0.17966850101947784 tensor(0.9524)
14560 0.17957964539527893 tensor(0.9524)
14570 0.179490864276886 tensor(0.9524)
14580 0.17940214276313782 tensor(0.9524)
14590 0.17931349575519562 tensor(0.9524)
14600 0.1792249083518982 tensor(0.9524)
14610 0.17913638055324554 tensor(0.9524)
14620 0.17904792726039886 tensor(0.9524)
14630 0.17895956337451935 tensor(0.9524)
14640 0.1788712590932846 tensor(0.9524)
14650 0.17878304421901703 tensor(0.9524)
14660 0.17869484424591064 tensor(0.9524)
14670 0.178606778383255 tensor(0.9524)
14680 0.17851874232292175 tensor(0.9524)
14690 0.17843078076839447 tensor(0.9524)
14700 0.17834286391735077 tensor(0.9524)
14710 0.17825505137443542 tensor(0.9524)
14720 0.17816728353500366 tensor(0.9524)
14730 0.17807962000370026 tensor(0.9524)
14740 0.17799198627471924 tensor(0.9524)

16530 0.16330432891845703 tensor(0.9568)
16540 0.16322733461856842 tensor(0.9568)
16550 0.1631503850221634 tensor(0.9568)
16560 0.16307350993156433 tensor(0.9568)
16570 0.16299669444561005 tensor(0.9568)
16580 0.16291992366313934 tensor(0.9568)
16590 0.1628432273864746 tensor(0.9568)
16600 0.16276654601097107 tensor(0.9568)
16610 0.1626899391412735 tensor(0.9568)
16620 0.16261336207389832 tensor(0.9568)
16630 0.1625368595123291 tensor(0.9568)
16640 0.16246041655540466 tensor(0.9568)
16650 0.16238397359848022 tensor(0.9568)
16660 0.16230762004852295 tensor(0.9572)
16670 0.16223131120204926 tensor(0.9574)
16680 0.16215503215789795 tensor(0.9574)
16690 0.1620788276195526 tensor(0.9578)
16700 0.16200266778469086 tensor(0.9578)
16710 0.16192655265331268 tensor(0.9580)
16720 0.16185049712657928 tensor(0.9580)
16730 0.16177451610565186 tensor(0.9580)
16740 0.16169854998588562 tensor(0.9580)
16750 0.16162264347076416 tensor(0.9580)
16760 0.16154679656028748 tensor(0.9580)

18550 0.14871907234191895 tensor(0.9622)
18560 0.14865128695964813 tensor(0.9626)
18570 0.1485835760831833 tensor(0.9626)
18580 0.14851586520671844 tensor(0.9626)
18590 0.14844824373722076 tensor(0.9626)
18600 0.14838065207004547 tensor(0.9626)
18610 0.14831310510635376 tensor(0.9626)
18620 0.14824557304382324 tensor(0.9626)
18630 0.1481780707836151 tensor(0.9626)
18640 0.14811065793037415 tensor(0.9626)
18650 0.14804324507713318 tensor(0.9628)
18660 0.14797590672969818 tensor(0.9628)
18670 0.14790856838226318 tensor(0.9628)
18680 0.14784128963947296 tensor(0.9628)
18690 0.14777405560016632 tensor(0.9628)
18700 0.14770686626434326 tensor(0.9628)
18710 0.1476396918296814 tensor(0.9628)
18720 0.1475725769996643 tensor(0.9628)
18730 0.147505521774292 tensor(0.9628)
18740 0.14743846654891968 tensor(0.9628)
18750 0.14737145602703094 tensor(0.9628)
18760 0.14730452001094818 tensor(0.9628)
18770 0.1472375988960266 tensor(0.9628)
18780 0.14717072248458862 tensor(0.9628)

20570 0.13579995930194855 tensor(0.9660)
20580 0.13573959469795227 tensor(0.9660)
20590 0.13567925989627838 tensor(0.9662)
20600 0.13561899960041046 tensor(0.9662)
20610 0.13555873930454254 tensor(0.9662)
20620 0.1354985237121582 tensor(0.9662)
20630 0.13543833792209625 tensor(0.9662)
20640 0.1353781670331955 tensor(0.9666)
20650 0.13531805574893951 tensor(0.9666)
20660 0.13525798916816711 tensor(0.9666)
20670 0.13519792258739471 tensor(0.9670)
20680 0.1351378858089447 tensor(0.9670)
20690 0.13507792353630066 tensor(0.9670)
20700 0.13501796126365662 tensor(0.9670)
20710 0.13495801389217377 tensor(0.9670)
20720 0.1348981261253357 tensor(0.9670)
20730 0.13483826816082 tensor(0.9672)
20740 0.13477842509746552 tensor(0.9672)
20750 0.13471867144107819 tensor(0.9672)
20760 0.13465888798236847 tensor(0.9672)
20770 0.13459916412830353 tensor(0.9672)
20780 0.13453947007656097 tensor(0.9674)
20790 0.1344798058271408 tensor(0.9674)
20800 0.13442017138004303 tensor(0.9674)

22590 0.12424907088279724 tensor(0.9718)
22600 0.1241949200630188 tensor(0.9718)
22610 0.12414082884788513 tensor(0.9718)
22620 0.12408673018217087 tensor(0.9718)
22630 0.12403266876935959 tensor(0.9718)
22640 0.1239786371588707 tensor(0.9718)
22650 0.123924620449543 tensor(0.9718)
22660 0.12387064099311829 tensor(0.9718)
22670 0.12381669878959656 tensor(0.9720)
22680 0.12376279383897781 tensor(0.9720)
22690 0.12370888888835907 tensor(0.9720)
22700 0.12365501374006271 tensor(0.9720)
22710 0.12360119819641113 tensor(0.9722)
22720 0.12354739010334015 tensor(0.9722)
22730 0.12349360436201096 tensor(0.9722)
22740 0.12343985587358475 tensor(0.9722)
22750 0.12338612973690033 tensor(0.9722)
22760 0.1233324334025383 tensor(0.9722)
22770 0.12327877432107925 tensor(0.9724)
22780 0.1232251226902008 tensor(0.9724)
22790 0.12317152321338654 tensor(0.9724)
22800 0.12311794608831406 tensor(0.9724)
22810 0.12306440621614456 tensor(0.9726)
22820 0.12301087379455566 tensor(0.9726)

24600 0.11391030251979828 tensor(0.9750)
24610 0.11386148631572723 tensor(0.9750)
24620 0.11381268501281738 tensor(0.9752)
24630 0.11376389116048813 tensor(0.9752)
24640 0.11371516436338425 tensor(0.9752)
24650 0.11366641521453857 tensor(0.9752)
24660 0.11361772567033768 tensor(0.9752)
24670 0.11356904357671738 tensor(0.9752)
24680 0.11352039873600006 tensor(0.9752)
24690 0.11347176879644394 tensor(0.9752)
24700 0.1134231686592102 tensor(0.9752)
24710 0.11337458342313766 tensor(0.9752)
24720 0.11332602798938751 tensor(0.9752)
24730 0.11327751725912094 tensor(0.9752)
24740 0.11322900652885437 tensor(0.9752)
24750 0.11318051815032959 tensor(0.9752)
24760 0.1131320521235466 tensor(0.9752)
24770 0.113083615899086 tensor(0.9752)
24780 0.11303521692752838 tensor(0.9752)
24790 0.11298682540655136 tensor(0.9752)
24800 0.11293846368789673 tensor(0.9752)
24810 0.11289010941982269 tensor(0.9752)
24820 0.11284180730581284 tensor(0.9752)
24830 0.11279352754354477 tensor(0.9752)

26610 0.10457123816013336 tensor(0.9782)
26620 0.1045270636677742 tensor(0.9782)
26630 0.10448291152715683 tensor(0.9782)
26640 0.10443876683712006 tensor(0.9782)
26650 0.10439465194940567 tensor(0.9782)
26660 0.10435055941343307 tensor(0.9782)
26670 0.10430648177862167 tensor(0.9782)
26680 0.10426244884729385 tensor(0.9782)
26690 0.10421842336654663 tensor(0.9782)
26700 0.10417440533638 tensor(0.9782)
26710 0.10413041710853577 tensor(0.9782)
26720 0.10408645868301392 tensor(0.9782)
26730 0.10404251515865326 tensor(0.9782)
26740 0.10399860888719559 tensor(0.9782)
26750 0.10395470261573792 tensor(0.9782)
26760 0.10391084104776382 tensor(0.9782)
26770 0.10386697947978973 tensor(0.9782)
26780 0.10382314771413803 tensor(0.9782)
26790 0.10377933084964752 tensor(0.9782)
26800 0.1037355363368988 tensor(0.9784)
26810 0.10369177162647247 tensor(0.9784)
26820 0.10364800691604614 tensor(0.9784)
26830 0.10360429435968399 tensor(0.9784)
26840 0.10356059670448303 tensor(0.9784)

28620 0.09610918909311295 tensor(0.9814)
28630 0.0960690900683403 tensor(0.9814)
28640 0.09602903574705124 tensor(0.9814)
28650 0.09598897397518158 tensor(0.9814)
28660 0.0959489569067955 tensor(0.9814)
28670 0.09590896964073181 tensor(0.9814)
28680 0.09586896747350693 tensor(0.9814)
28690 0.09582899510860443 tensor(0.9814)
28700 0.09578904509544373 tensor(0.9814)
28710 0.09574911743402481 tensor(0.9814)
28720 0.0957091897726059 tensor(0.9814)
28730 0.09566929936408997 tensor(0.9814)
28740 0.09562942385673523 tensor(0.9814)
28750 0.09558956325054169 tensor(0.9814)
28760 0.09554972499608994 tensor(0.9814)
28770 0.09550990909337997 tensor(0.9816)
28780 0.09547010809183121 tensor(0.9816)
28790 0.09543032944202423 tensor(0.9816)
28800 0.09539058059453964 tensor(0.9816)
28810 0.09535083919763565 tensor(0.9816)
28820 0.09531112015247345 tensor(0.9816)
28830 0.09527140110731125 tensor(0.9816)
28840 0.09523171931505203 tensor(0.9818)
28850 0.0951920598745346 tensor(0.9818)

30630 0.08842246979475021 tensor(0.9834)
30640 0.08838602155447006 tensor(0.9834)
30650 0.0883495882153511 tensor(0.9834)
30660 0.08831316977739334 tensor(0.9834)
30670 0.08827678859233856 tensor(0.9834)
30680 0.08824038505554199 tensor(0.9834)
30690 0.08820401877164841 tensor(0.9834)
30700 0.08816766738891602 tensor(0.9834)
30710 0.08813133835792542 tensor(0.9834)
30720 0.08809502422809601 tensor(0.9834)
30730 0.08805873245000839 tensor(0.9834)
30740 0.08802244067192078 tensor(0.9834)
30750 0.08798617124557495 tensor(0.9834)
30760 0.08794993907213211 tensor(0.9834)
30770 0.08791369944810867 tensor(0.9834)
30780 0.08787749707698822 tensor(0.9834)
30790 0.08784129470586777 tensor(0.9834)
30800 0.08780509978532791 tensor(0.9834)
30810 0.08776894956827164 tensor(0.9834)
30820 0.08773281425237656 tensor(0.9836)
30830 0.08769667148590088 tensor(0.9836)
30840 0.08766056597232819 tensor(0.9836)
30850 0.08762446790933609 tensor(0.9836)
30860 0.08758839219808578 tensor(0.9836)

32640 0.08142770826816559 tensor(0.9844)
32650 0.08139451593160629 tensor(0.9844)
32660 0.08136136084794998 tensor(0.9844)
32670 0.08132820576429367 tensor(0.9844)
32680 0.08129504323005676 tensor(0.9844)
32690 0.08126193284988403 tensor(0.9844)
32700 0.0812288150191307 tensor(0.9844)
32710 0.08119572699069977 tensor(0.9844)
32720 0.08116264641284943 tensor(0.9844)
32730 0.08112958073616028 tensor(0.9844)
32740 0.08109652996063232 tensor(0.9844)
32750 0.08106350153684616 tensor(0.9844)
32760 0.0810304805636406 tensor(0.9844)
32770 0.08099746704101562 tensor(0.9844)
32780 0.08096449077129364 tensor(0.9844)
32790 0.08093152195215225 tensor(0.9844)
32800 0.08089857548475266 tensor(0.9844)
32810 0.08086562156677246 tensor(0.9844)
32820 0.08083267509937286 tensor(0.9844)
32830 0.08079978078603745 tensor(0.9844)
32840 0.08076688647270203 tensor(0.9844)
32850 0.0807340145111084 tensor(0.9844)
32860 0.08070115000009537 tensor(0.9844)
32870 0.08066828548908234 tensor(0.9844)

34650 0.0750584751367569 tensor(0.9868)
34660 0.07502825558185577 tensor(0.9870)
34670 0.07499805837869644 tensor(0.9870)
34680 0.07496786117553711 tensor(0.9870)
34690 0.07493768632411957 tensor(0.9870)
34700 0.07490753382444382 tensor(0.9870)
34710 0.07487738877534866 tensor(0.9870)
34720 0.07484724372625351 tensor(0.9870)
34730 0.07481713593006134 tensor(0.9870)
34740 0.07478702813386917 tensor(0.9870)
34750 0.07475694268941879 tensor(0.9870)
34760 0.07472686469554901 tensor(0.9872)
34770 0.07469680160284042 tensor(0.9872)
34780 0.07466675341129303 tensor(0.9872)
34790 0.07463672757148743 tensor(0.9872)
34800 0.07460670173168182 tensor(0.9874)
34810 0.07457670569419861 tensor(0.9876)
34820 0.0745466947555542 tensor(0.9876)
34830 0.07451671361923218 tensor(0.9876)
34840 0.07448675483465195 tensor(0.9876)
34850 0.07445679605007172 tensor(0.9876)
34860 0.07442685216665268 tensor(0.9876)
34870 0.07439693063497543 tensor(0.9876)
34880 0.07436703145503998 tensor(0.9876)

In [None]:
# Visualize the learned input weights of hidden unit 6 as a 28x28 image
import matplotlib.pyplot as plt
plt.imshow(net.W_1[:,6].reshape(28,28).detach())
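To look at more than one learned feature at a time, the sketch below tiles the first 16 columns of a weight matrix into a $4\times 4$ grid of $28\times 28$ images. `tile_hidden_units` is a hypothetical helper and `W_demo` is a random stand-in for `net.W_1`; tile (row 1, column 2) of the grid, counting from zero, is hidden unit 6, the same unit plotted above.

```python
import torch

def tile_hidden_units(W_1, n_rows=4, n_cols=4, side=28):
    """Hypothetical helper: arrange the first n_rows*n_cols columns of W_1
    as an (n_rows*side) x (n_cols*side) image, one tile per hidden unit."""
    k = n_rows * n_cols
    imgs = W_1[:, :k].detach().T.reshape(k, side, side)  # k x 28 x 28, unit order
    grid = imgs.reshape(n_rows, n_cols, side, side)       # rows x cols x 28 x 28
    # Interleave so pixel rows of each tile row sit next to each other
    return grid.permute(0, 2, 1, 3).reshape(n_rows * side, n_cols * side)

W_demo = torch.randn(784, 200)  # stand-in for net.W_1 (784 inputs, 200 hidden units)
grid = tile_hidden_units(W_demo)
# plt.imshow(grid, cmap="gray") would display 16 hidden units at once
```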