In [11]:
import torch
import torch.utils.data
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
from torchvision.utils import save_image
import os
import random
from sklearn.metrics import confusion_matrix,accuracy_score,precision_score,recall_score

In [21]:
batch_size=256
epochs=300
seed=1
cuda=False and torch.cuda.is_available()
log_interval=10
r_dim=128
#z_dim=128
path_save="results_cnp_class/"

num_class=10
context_target_ratio=0.5 #0.7 context, 0.3 target

In [22]:

torch.manual_seed(seed)
random.seed(seed)
device = torch.device("cpu") #"cuda" if args.cuda else 

kwargs = {'num_workers': 1, 'pin_memory': True} if cuda else {}

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=True, download=True,
                   transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True, **kwargs)

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('../data', train=False, transform=transforms.ToTensor()),
    batch_size=batch_size, shuffle=True, **kwargs)


In [4]:
def get_context_idx(N,batch_size):
    # generate the indeces of the N context points in a flattened image
    idx = random.sample(range(0, batch_size), N)
    idx = torch.tensor(idx, device=device)
    return idx

In [5]:
def generate_grid(h, w):
    rows = torch.linspace(0, 1, h, device=device)
    cols = torch.linspace(0, 1, w, device=device)
    grid = torch.stack([cols.repeat(h, 1).t().contiguous().view(-1), rows.repeat(w)], dim=1)
    grid = grid.unsqueeze(0)
    return grid

In [6]:
def idx_to_y(idx, data):
    # get the [0;1] pixel intensity at each index
    y = torch.index_select(data, dim=1, index=idx)
    return y


In [7]:
def idx_to_x(idx, batch_size):
    # From flat idx to 2d coordinates of the 28x28 grid. E.g. 35 -> (1, 7)
    
    x = torch.index_select(x_grid, dim=1, index=idx)
    
    x = x.expand(batch_size, -1, -1)
    
    return x

In [63]:
class CNPClass(nn.Module):
    def __init__(self, r_dim):#,z_dim
        super(CNPClass, self).__init__()
        self.r_dim = r_dim
        #self.z_dim = z_dim
        self.h_1 = nn.Linear(784+1, 256)
        self.h_2 = nn.Linear(256, 256)
        self.h_3 = nn.Linear(256, self.r_dim)

        self.g_1 = nn.Linear(num_class*self.r_dim + 784, 512)
        self.g_2 = nn.Linear(512,256)
        self.g_3 = nn.Linear(256,256)
        self.g_proba = nn.Linear(256, num_class)
        self.g_std = nn.Linear(256, num_class)
        

    def h(self, x_y):
        x_y = F.relu(self.h_1(x_y))
        x_y = F.relu(self.h_2(x_y))
        x_y = F.relu(self.h_3(x_y))
        return x_y

    def aggregate(self, r,y):
        #r is of size batch_size*r_dim
        r_list=[]
        y=y.view(y.shape[0])
        for i in range(num_class):
            
            r_i=r[y==i]
            if r_i.shape[0]==0:
                r_list.append(torch.zeros(r_dim))
            else:
                r_list.append(torch.mean(r_i, dim=0))
        return  torch.cat(r_list,dim=0)

    def g(self, rep, x_target):
        r_et_x = torch.cat([rep, x_target], dim=1)
        input = F.relu(self.g_1(r_et_x))
        input = F.relu(self.g_2(input))
        input = F.relu(self.g_3(input))
        proba=F.log_softmax(self.g_proba(input),dim=1)
        std=0.1 + 0.9 * F.softplus(self.g_proba(input))
        return proba,std

    def xy_to_r_params(self, x, y):

        x_y = torch.cat([x, y.float()], dim=1)
        r_i = self.h(x_y)        
        r = self.aggregate(r_i,y)
        
        return r

    def forward(self, x_context, y_context,x_target):
        r_context = self.xy_to_r_params(x_context, y_context)  # (mu, logvar) of z
        #expand r to pass in the MLP
        r_expand = r_context.expand(x_target.shape[0], -1)
        # reconstruct the whole image including the provided context points
        proba,std = self.g(r_expand, x_target)
        return proba,std
        

In [64]:
def np_loss(proba,std, y,criterion = nn.CrossEntropyLoss()):#, z_all, z_context
    #need improvement
    '''
    batchsize=y.shape[0]
    y=y.view(batchsize).long()
    loss=torch.tensor(0.)
    a=0
    for i in range(num_class):
        loss+=torch.sum(torch.log(proba[y==i,i]))
    '''
    label=torch.zeros_like(y)
    
    return criterion(proba, y)+torch.log(std).sum(dim=0).mean()

In [65]:
model = CNPClass(r_dim).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
x_grid = generate_grid(28, 28)
os.makedirs(path_save, exist_ok=True)

In [66]:
def train(epoch):
    model.train()
    train_loss = 0
    for batch_idx, (img, label) in enumerate(train_loader):
        
        batch_size = img.shape[0]
        img=img.to(device).squeeze().view(batch_size,-1)
        label=label.to(device).view(batch_size,1)
        
        #context_idx = get_context_idx(int(context_target_ratio*batch_size),batch_size)
        #create context set
        #img_context = img[context_idx]
        #label_context = label[context_idx]
        
        
        proba,std= model(img, label,img)#img_context, label_context,
        
        optimizer.zero_grad()
        
        loss = np_loss(proba,std,label.view(batch_size))
        loss.backward()
        train_loss += loss.item()
        optimizer.step()
        
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(label), len(train_loader.dataset),
                       100. * batch_idx / len(train_loader),
                       loss.item() / len(label)))
    print('====> Epoch: {} Average loss: {:.4f}'.format(
        epoch, train_loss / len(train_loader.dataset)))

In [67]:
def test(epoch):
    model.eval()
    test_loss = 0
    label_true=[]
    label_pred=[]
    with torch.no_grad():
        for batch_idx, (img, label) in enumerate(test_loader):
            batch_size = img.shape[0]
            img=img.to(device).squeeze().view(batch_size,-1)
            label=label.to(device).view(batch_size,1)

            #N = int(context_target_ratio*batch_size)
            #context_idx = get_context_idx(N,batch_size)
            #img_context = img[context_idx]
            #label_context = label[context_idx]
            proba,std= model(img, label,img)#_context
            label_true.append(label.view(batch_size))
            label_pred.append(torch.argmax(proba,dim=1).view(batch_size))     
            test_loss += criterion(proba,label.view(batch_size)) #z_all, z_context
            
    
    y_pred=torch.cat(label_pred,dim=0).numpy()
    y_true=torch.cat(label_true,dim=0).numpy()
    
    test_loss /= len(test_loader.dataset)
    print('====> Test set loss: {:.4f}'.format(test_loss))
    print(confusion_matrix(y_true,y_pred))
    print("Accuracy:",accuracy_score(y_true,y_pred))

In [68]:
for epoch in range(1, epochs + 1):
    train(epoch)
    test(epoch)

====> Epoch: 1 Average loss: -2.2465
====> Test set loss: 0.0080
[[ 660   17    0    0    0    0   80   17  206    0]
 [   2 1048    0    1    0    4   38   29   13    0]
 [ 410   97    0   14    0    7  241   52  211    0]
 [ 276  169    0   49    0   29  216   45  226    0]
 [  78  358    0    0    0    5  276   94  155   16]
 [ 152  216    0    0    0   52  216   72  183    1]
 [ 281  123    0    0    0    1  296   57  200    0]
 [  20  500    0    0    0    2  224  157  116    9]
 [ 229   87    0    1    0   12  255   62  327    1]
 [  52  330    0    0    0    3  288  130  147   59]]
Accuracy: 0.2648
====> Epoch: 2 Average loss: -2.2961
====> Test set loss: 0.0050
[[ 817    0  116   15    3   22    0    0    6    1]
 [   0 1079    7    2   46    0    0    0    0    1]
 [  29   14  882   11   66    4    0    6   16    4]
 [  45   13  356  209  125  140    0   18   56   48]
 [   6  125   46    1  522    3    0   62   37  180]
 [  35    6  116   83  129  323    0   17  128   55]
 [  

====> Epoch: 5 Average loss: -2.3008
====> Test set loss: 0.0016
[[ 941    0    6    2    3    7   11    1    9    0]
 [   0 1106    2    6    0    1    5    0   14    1]
 [   8   13  906   13   14    1   20    6   47    4]
 [   1    3   31  857    1   50    1   15   39   12]
 [   2    8    2    0  851    0   12    1   10   96]
 [   7    2    8   76   18  712   16    4   37   12]
 [   7    3   10    0   18   19  897    0    4    0]
 [   2   29   12    2    6    0    0  851   13  113]
 [   6   12   13   13   18   33   10    6  829   34]
 [   7    3    0    5   48    7    0   23   16  900]]
Accuracy: 0.885
====> Epoch: 6 Average loss: -2.3011
====> Test set loss: 0.0014
[[ 951    0    6    1    1    5   11    1    4    0]
 [   0 1105    1    6    0    1    4    1   16    1]
 [  11    9  918   26   10    1   17    7   30    3]
 [   1    1   25  919    1   24    0   13   22    4]
 [   2   10    6    0  893    0   13    3   10   45]
 [  10    2    6   91   15  708   19    4   31    6]
 [  1

====> Epoch: 10 Average loss: -2.3017
====> Test set loss: 0.0009
[[ 971    0    0    1    1    2    3    1    1    0]
 [   0 1099    3    2    1    1    5    1   22    1]
 [  22    1  949    5    6    0   12    6   29    2]
 [   5    1   15  924    0   28    1    7   19   10]
 [   4    3    3    0  923    0    9    3    5   32]
 [  15    1    2   27    4  801   17    2   16    7]
 [  14    3    3    0   17   10  907    0    4    0]
 [   2   13   13    1    8    1    0  922    6   62]
 [  12    4    4   11   15   14   16    4  875   19]
 [  13    4    1    6   22    4    0   10    5  944]]
Accuracy: 0.9315
====> Epoch: 11 Average loss: -2.3018
====> Test set loss: 0.0008
[[ 968    0    1    1    0    6    3    1    0    0]
 [   0 1106    3    3    1    1    4    1   15    1]
 [  19    1  955    4    6    0    5    5   35    2]
 [   2    0    9  943    1   30    0    7   17    1]
 [   2    2    4    0  943    0    6    2    4   19]
 [   8    1    0   19    2  840    6    0   14    2]
 [

====> Epoch: 14 Average loss: -2.3020
====> Test set loss: 0.0006
[[ 948    0    3    2    0    7   17    0    2    1]
 [   0 1121    3    2    0    1    4    1    3    0]
 [   4    2  994    7    4    0    5    5   10    1]
 [   0    1    4  989    2    3    0    4    6    1]
 [   0    3    4    0  948    1   10    1    3   12]
 [   4    1    0   28    1  835   11    0    8    4]
 [   3    3    2    0   10    8  931    0    1    0]
 [   1   13   17   11    7    2    0  930    3   44]
 [   2    7    8   26   11    9   12    3  888    8]
 [   5    9    0   13   17    8    1    7    3  946]]
Accuracy: 0.953
====> Epoch: 15 Average loss: -2.3021
====> Test set loss: 0.0006
[[ 967    0    1    2    0    1    7    0    2    0]
 [   0 1118    3    1    0    1    4    2    6    0]
 [  16    3  978    1    4    0    5    9   15    1]
 [   0    1    2  949    1   29    0    8   17    3]
 [   0    2    2    0  954    1    9    1    2   11]
 [   8    1    0    6    1  857    9    0    7    3]
 [ 

====> Epoch: 19 Average loss: -2.3022
====> Test set loss: 0.0004
[[ 965    0    1    4    0    3    4    1    2    0]
 [   0 1125    2    2    0    1    1    1    3    0]
 [   4    1  995   10    3    0    1   10    7    1]
 [   0    0    2  996    0    5    0    4    3    0]
 [   0    0    4    0  965    1    4    3    0    5]
 [   4    0    0   23    2  852    5    0    4    2]
 [   7    3    3    0   13    6  924    0    2    0]
 [   0   12    6    7    2    2    0  981    0   18]
 [   3    0    4   18    8    8    5    4  921    3]
 [   4    6    0   16   12    6    0    9    5  951]]
Accuracy: 0.9675
====> Epoch: 20 Average loss: -2.3023
====> Test set loss: 0.0005
[[ 962    0    1    2    0    2    6    0    6    1]
 [   0 1110    3    0    0    1    2    3   16    0]
 [   3    2  970    5    1    0    2   12   36    1]
 [   0    1    1  960    0    8    0    6   17   17]
 [   0    0    2    0  925    1    5    4    4   41]
 [   3    0    0    7    1  858    2    0   11   10]
 [

====> Epoch: 23 Average loss: -2.3023
====> Test set loss: 0.0004
[[ 961    0    2    3    1    2   10    0    1    0]
 [   0 1125    3    1    0    1    1    1    3    0]
 [   2    2 1005    6    2    0    2    8    4    1]
 [   0    0    3  995    0    3    0    6    2    1]
 [   0    0    1    0  965    0    8    2    0    6]
 [   2    0    0   39    2  835    7    0    4    3]
 [   2    3    3    1    7    4  937    0    1    0]
 [   1    6    9    5    1    0    0  983    0   23]
 [   2    0   11   23    9    6    8    3  901   11]
 [   3    3    0   12   11    2    0    5    1  972]]
Accuracy: 0.9679
====> Epoch: 24 Average loss: -2.3024
====> Test set loss: 0.0005
[[ 960    0    6    3    0    2    6    0    3    0]
 [   0 1116    4    1    0    1    4    1    8    0]
 [   0    0 1013    4    1    0    1    1   12    0]
 [   0    0    4  994    0    3    0    2    7    0]
 [   1    0    5    0  949    0    8    3    4   12]
 [   3    0    0   27    1  851    3    0    6    1]
 [

====> Epoch: 28 Average loss: -2.3024
====> Test set loss: 0.0005
[[ 970    0    0    2    0    2    3    0    3    0]
 [   0 1118    4    2    0    2    3    1    5    0]
 [  15    0  999    7    1    0    2    3    5    0]
 [   0    0    3  974    0   27    0    4    2    0]
 [   2    0    1    0  951    0   17    2    3    6]
 [   4    0    0    5    0  880    3    0    0    0]
 [   9    3    1    0    2   12  931    0    0    0]
 [   3    6   14    4    1    1    0  995    4    0]
 [   9    0    3    7    1   19    7    4  923    1]
 [   6    4    0   10    5   15    2   10   18  939]]
Accuracy: 0.968
====> Epoch: 29 Average loss: -2.3024
====> Test set loss: 0.0005
[[ 958    0    6    3    0    3    2    0    5    3]
 [   0 1124    3    3    0    1    0    1    3    0]
 [   0    2 1007   10    2    0    1    2    8    0]
 [   0    0    2 1003    0    0    0    1    2    2]
 [   0    0    4    0  956    0    4    1    1   16]
 [   2    0    0   49    1  827    0    0    6    7]
 [ 

====> Epoch: 32 Average loss: -2.3025
====> Test set loss: 0.0004
[[ 952    1    5    2    0    6    8    2    3    1]
 [   0 1126    3    2    0    1    1    0    2    0]
 [   1    2 1005    5    3    1    1   10    4    0]
 [   0    2    4  980    0   14    0    5    2    3]
 [   0    0    3    0  961    0    8    1    0    9]
 [   0    0    0    8    1  880    1    0    0    2]
 [   1    3    4    0   12    8  928    0    2    0]
 [   0    8    8    3    1    1    0  996    1   10]
 [   1    3    3    7    3   11    9    3  923   11]
 [   3    3    0    4    5    4    1    5    2  982]]
Accuracy: 0.9733
====> Epoch: 33 Average loss: -2.3025
====> Test set loss: 0.0004
[[ 961    0    3    2    0    2    3    4    4    1]
 [   0 1124    4    1    0    1    3    0    2    0]
 [   1    3 1005    5    3    0    2    7    6    0]
 [   0    1    3  996    0    3    0    3    4    0]
 [   0    0    2    0  967    0    9    1    1    2]
 [   1    0    0   23    2  848    7    4    5    2]
 [

====> Epoch: 37 Average loss: -2.3025
====> Test set loss: 0.0005
[[ 972    0    0    3    0    0    2    0    3    0]
 [   0 1121    4    3    0    1    1    0    5    0]
 [   6    1 1000   12    1    0    2    3    7    0]
 [   0    0    3  996    0    7    0    1    2    1]
 [   4    1    1    1  945    0   17    4    1    8]
 [   3    0    0   12    1  870    4    0    2    0]
 [   7    3    4    0    4    4  935    0    1    0]
 [   2    8   13    8    1    2    0  982    3    9]
 [   4    1    1   11    2    8    5    2  937    3]
 [   4    3    0   13    3    4    1    4   11  966]]
Accuracy: 0.9724
====> Epoch: 38 Average loss: -2.3025
====> Test set loss: 0.0005
[[ 967    0    1    3    0    3    3    0    3    0]
 [   0 1120    6    3    0    1    0    2    3    0]
 [   2    1 1004   10    1    0    1    5    8    0]
 [   0    0    4  994    0    2    0    4    2    4]
 [   0    0    3    0  956    1    6    5    2    9]
 [   1    0    0   18    1  865    0    3    3    1]
 [

====> Epoch: 41 Average loss: -2.3025
====> Test set loss: 0.0005
[[ 970    0    0    2    0    2    2    0    4    0]
 [   0 1127    3    1    0    1    0    0    3    0]
 [   5    5  993    7    4    0    1    6   11    0]
 [   0    0    4  992    0    4    0    3    4    3]
 [   1    0    1    0  965    0    6    1    2    6]
 [   2    0    0   13    3  860    5    1    7    1]
 [   4    3    1    0   11    4  932    0    3    0]
 [   1    8    9    3    6    1    0  982    7   11]
 [   3    1    2    3    2    2    2    2  954    3]
 [   3    2    0    7   10    2    1    3   15  966]]
Accuracy: 0.9741
====> Epoch: 42 Average loss: -2.3025
====> Test set loss: 0.0006
[[ 971    1    0    0    0    1    3    1    3    0]
 [   0 1127    3    1    0    1    1    0    2    0]
 [   5    3 1005    6    2    0    1    1    9    0]
 [   1    2    4  986    0    8    0    4    3    2]
 [   1    4    1    0  945    0   14    5    1   11]
 [   3    1    0    8    0  872    3    2    2    1]
 [

====> Epoch: 46 Average loss: -2.3025
====> Test set loss: 0.0006
[[ 970    0    0    1    0    1    4    1    3    0]
 [   0 1122    3    4    0    1    1    0    4    0]
 [   9    2  996    6    1    0    2    5   11    0]
 [   0    0    3  987    0   16    0    2    1    1]
 [   1    1    1    0  950    0   17    5    1    6]
 [   2    0    0    6    1  875    6    0    2    0]
 [   2    3    2    0    2    5  943    0    1    0]
 [   1    7    9    5    0    2    0  997    3    4]
 [   5    1    2    5    2    9    5    2  940    3]
 [   6    5    0    8    5    7    1    8    7  962]]
Accuracy: 0.9742
====> Epoch: 47 Average loss: -2.3025
====> Test set loss: 0.0005
[[ 957    0    2    2    0    2    8    4    4    1]
 [   0 1121    2    3    0    0    5    0    4    0]
 [   0    3  987   11    6    0    1   11   12    1]
 [   0    1    3  980    0   12    0    6    3    5]
 [   0    0    2    0  963    0    6    3    0    8]
 [   1    0    0    6    1  875    2    0    5    2]
 [

====> Epoch: 50 Average loss: -2.3025
====> Test set loss: 0.0005
[[ 971    1    1    0    0    1    2    1    3    0]
 [   0 1125    7    0    0    0    0    0    3    0]
 [   2    0 1017    0    2    0    1    3    7    0]
 [   0    2    8  972    0   13    0    5    3    7]
 [   1    0    4    0  962    0    6    3    2    4]
 [   2    0    0    6    1  875    2    0    2    4]
 [   5    2    5    0    8    7  928    0    3    0]
 [   1    6   12    0    2    1    0  991    5   10]
 [   2    1    5    1    3    4    3    2  948    5]
 [   4    3    0    4    7    2    1    4    8  976]]
Accuracy: 0.9765
====> Epoch: 51 Average loss: -2.3025
====> Test set loss: 0.0006
[[ 968    1    0    2    0    1    4    0    4    0]
 [   0 1127    4    0    0    1    0    0    3    0]
 [   6    0 1009    5    3    0    2    0    7    0]
 [   0    0    4  987    0   14    0    2    2    1]
 [   0    1    2    0  960    0   10    1    1    7]
 [   2    0    0   13    1  869    5    0    1    1]
 [

====> Epoch: 55 Average loss: -2.3026
====> Test set loss: 0.0006
[[ 962    1    2    1    0    3    2    4    4    1]
 [   0 1128    2    1    0    0    0    0    4    0]
 [   0    6  998    7    4    0    1    9    7    0]
 [   0    2    3  990    0    3    0    4    1    7]
 [   0    1    2    0  965    0    3    1    1    9]
 [   0    2    0   17    1  862    1    2    5    2]
 [   5    3    3    0   17   17  908    0    5    0]
 [   1    8    7    1    3    0    0  997    3    8]
 [   1    1    2    3    2    2    1    3  952    7]
 [   3    2    0    4    6    2    0    5    4  983]]
Accuracy: 0.9745
====> Epoch: 56 Average loss: -2.3026
====> Test set loss: 0.0006
[[ 971    0    0    2    0    1    2    0    4    0]
 [   0 1117    3    3    0    0    2    1    9    0]
 [   6    0 1001    5    2    0    1    0   17    0]
 [   0    0    4  989    0    4    0    2    4    7]
 [   2    0    1    0  933    0    7    4    3   32]
 [   1    0    0    9    1  873    2    0    3    3]
 [

====> Epoch: 59 Average loss: -2.3026
====> Test set loss: 0.0006
[[ 969    1    1    2    0    2    2    0    3    0]
 [   0 1127    3    1    0    0    0    0    4    0]
 [   2    0 1018    1    2    0    1    2    6    0]
 [   0    1    5  990    0    4    0    2    4    4]
 [   0    0    4    0  961    0    9    1    0    7]
 [   1    0    0   12    0  860    9    0    6    4]
 [   3    3    3    1    6    3  937    0    2    0]
 [   1    7   13    4    3    0    0  986    3   11]
 [   0    1    3    4    2    1    4    2  952    5]
 [   4    2    1    7    9    1    1    4    6  974]]
Accuracy: 0.9774
====> Epoch: 60 Average loss: -2.3026
====> Test set loss: 0.0006
[[ 968    1    1    1    0    3    2    1    3    0]
 [   0 1127    2    1    0    1    0    1    3    0]
 [   2    1 1010    6    3    2    1    4    2    1]
 [   0    0    4  992    0    7    0    3    0    4]
 [   1    0    1    0  968    0    5    0    1    6]
 [   0    0    0   10    1  877    2    0    1    1]
 [

====> Epoch: 64 Average loss: -2.3026
====> Test set loss: 0.0006
[[ 970    1    1    1    0    0    3    1    3    0]
 [   0 1124    3    1    0    1    2    1    3    0]
 [   1    0 1022    1    2    0    1    3    2    0]
 [   0    0    6  986    0    4    0    5    1    8]
 [   0    0    2    0  965    0    9    1    0    5]
 [   1    0    0    9    1  859   14    1    3    4]
 [   5    2    4    0    5    2  939    0    1    0]
 [   1    8   14    0    3    0    0  988    5    9]
 [   1    1    3    5    1    4    9    2  943    5]
 [   2    2    0    3    6    1    1    3    2  989]]
Accuracy: 0.9785
====> Epoch: 65 Average loss: -2.3026
====> Test set loss: 0.0006
[[ 969    1    0    1    0    1    3    1    3    1]
 [   0 1126    2    1    0    1    0    1    4    0]
 [   1    5 1012    4    3    0    1    0    6    0]
 [   0    1    4  989    0    8    0    3    3    2]
 [   0    0    2    0  962    0    5    2    3    8]
 [   1    0    0   12    1  870    4    0    3    1]
 [

====> Epoch: 68 Average loss: -2.3025
====> Test set loss: 0.0007
[[ 973    1    0    0    0    0    2    1    3    0]
 [   0 1127    3    1    0    0    1    0    3    0]
 [   1    0 1015    2    0    0    1    0   13    0]
 [   1    2    7  980    0   11    0    4    4    1]
 [   0    0    4    0  966    0    7    1    2    2]
 [   2    2    0    9    1  864    7    2    5    0]
 [   5    3    2    0    8    6  932    0    2    0]
 [   1    8   14    2    2    0    0  993    3    5]
 [   0    2    2    3    1    2    4    2  955    3]
 [   2    4    1    3   12    6    1   10   14  956]]
Accuracy: 0.9761
====> Epoch: 69 Average loss: -2.3026
====> Test set loss: 0.0007
[[ 965    1    0    2    0    5    2    0    4    1]
 [   0 1127    2    1    0    1    1    0    3    0]
 [   1    3  988   13    4    0    1    3   19    0]
 [   0    0    3  982    0   16    0    2    3    4]
 [   0    0    1    0  972    0    5    0    2    2]
 [   0    0    0    5    1  883    1    0    2    0]
 [

====> Epoch: 73 Average loss: -2.3026
====> Test set loss: 0.0007
[[ 970    1    0    1    0    1    2    1    3    1]
 [   0 1125    2    2    0    0    1    1    4    0]
 [   3    0 1013    4    2    0    1    3    6    0]
 [   0    0    4  989    0    5    0    4    4    4]
 [   1    0    1    0  958    0    8    2    2   10]
 [   1    0    0   10    1  875    1    1    3    0]
 [   5    3    1    0    4    7  936    0    2    0]
 [   1    7    9    3    2    0    0  995    2    9]
 [   3    1    3    5    2    5    3    2  945    5]
 [   2    2    0    6    7    3    1    4    6  978]]
Accuracy: 0.9784
====> Epoch: 74 Average loss: -2.3026
====> Test set loss: 0.0007
[[ 971    1    0    1    0    0    2    1    3    1]
 [   0 1127    2    1    0    0    1    1    3    0]
 [   3    1 1011    4    2    0    1    4    6    0]
 [   0    0    4  988    0    3    0    5    4    6]
 [   1    0    1    0  966    0    7    1    1    5]
 [   1    0    0   10    1  874    2    1    3    0]
 [

====> Epoch: 77 Average loss: -2.3026
====> Test set loss: 0.0007
[[ 970    1    0    1    0    1    2    1    3    1]
 [   0 1127    2    1    0    0    1    1    3    0]
 [   3    0 1011    4    2    0    1    5    6    0]
 [   0    0    4  988    0    3    0    5    3    7]
 [   1    0    1    0  967    0    6    1    1    5]
 [   1    0    0   11    1  873    3    1    2    0]
 [   3    3    2    0    7    5  936    0    2    0]
 [   1    7    9    1    2    0    0  997    2    9]
 [   3    1    3    4    3    5    3    2  944    6]
 [   2    2    0    4    8    2    1    5    5  980]]
Accuracy: 0.9793
====> Epoch: 78 Average loss: -2.3026
====> Test set loss: 0.0007
[[ 970    1    0    1    0    1    2    1    3    1]
 [   0 1127    2    1    0    0    1    1    3    0]
 [   3    1 1012    4    2    0    1    4    5    0]
 [   0    0    4  989    0    2    0    5    4    6]
 [   1    0    1    0  967    0    6    1    1    5]
 [   1    0    0   12    1  872    3    1    2    0]
 [

====> Epoch: 82 Average loss: -2.3026
====> Test set loss: 0.0007
[[ 970    1    0    1    0    1    2    1    3    1]
 [   0 1127    2    1    0    0    1    1    3    0]
 [   3    0 1012    4    2    0    1    4    6    0]
 [   0    0    4  988    0    3    0    5    4    6]
 [   1    0    1    0  965    0    7    1    1    6]
 [   1    0    0   12    1  871    3    1    3    0]
 [   3    3    2    0    6    5  936    0    3    0]
 [   1    7    9    1    1    0    0  998    2    9]
 [   2    1    3    4    3    4    3    2  947    5]
 [   2    2    0    4    9    2    1    5    7  977]]
Accuracy: 0.9791
====> Epoch: 83 Average loss: -2.3026
====> Test set loss: 0.0008
[[ 970    1    0    1    0    1    2    1    3    1]
 [   0 1126    3    1    0    0    1    1    3    0]
 [   3    0 1012    4    2    0    1    5    5    0]
 [   0    0    4  989    0    2    0    5    4    6]
 [   1    0    1    0  965    0    7    1    1    6]
 [   1    0    0   12    1  871    3    1    3    0]
 [

====> Epoch: 86 Average loss: -2.3026
====> Test set loss: 0.0007
[[ 970    1    0    1    0    1    2    1    3    1]
 [   0 1127    2    1    0    0    1    1    3    0]
 [   3    0 1012    4    2    0    1    4    6    0]
 [   0    0    4  989    0    3    0    4    4    6]
 [   1    0    1    0  965    0    7    1    1    6]
 [   1    0    0   12    1  871    3    1    3    0]
 [   3    3    2    0    6    5  937    0    2    0]
 [   1    7    9    1    1    0    0  998    2    9]
 [   3    1    4    4    3    4    3    2  945    5]
 [   2    2    0    4    9    2    1    5    6  978]]
Accuracy: 0.9792
====> Epoch: 87 Average loss: -2.3026
====> Test set loss: 0.0008
[[ 969    1    0    1    0    1    3    1    3    1]
 [   0 1127    2    1    0    0    1    1    3    0]
 [   3    0 1012    4    2    0    1    4    6    0]
 [   0    0    4  988    0    3    0    5    4    6]
 [   1    0    1    0  965    0    7    1    1    6]
 [   1    0    0   12    1  871    3    1    3    0]
 [

====> Epoch: 91 Average loss: -2.3026
====> Test set loss: 0.0008
[[ 970    1    0    1    0    1    2    1    3    1]
 [   0 1126    3    1    0    0    1    1    3    0]
 [   3    0 1013    3    2    0    1    4    6    0]
 [   0    0    4  988    0    3    0    5    4    6]
 [   1    0    1    0  965    0    7    1    1    6]
 [   1    0    0   12    1  871    3    1    3    0]
 [   3    3    2    0    7    5  935    0    3    0]
 [   1    7    9    1    2    0    0  997    2    9]
 [   3    1    3    4    2    3    3    2  948    5]
 [   2    2    0    4    9    2    1    5    6  978]]
Accuracy: 0.9791
====> Epoch: 92 Average loss: -2.3026
====> Test set loss: 0.0009
[[ 969    1    0    1    0    1    3    1    3    1]
 [   0 1127    2    1    0    0    1    1    3    0]
 [   3    0 1013    3    2    0    1    4    6    0]
 [   0    0    4  987    0    4    0    5    4    6]
 [   1    0    1    0  965    0    7    1    1    6]
 [   1    0    0   12    1  871    3    1    3    0]
 [



KeyboardInterrupt: 

In [69]:
# Print model's state_dict
print("Model's state_dict:")
for param_tensor in model.state_dict():
    print(param_tensor, "\t", model.state_dict()[param_tensor].size())

# Print optimizer's state_dict
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():
    print(var_name, "\t", optimizer.state_dict()[var_name])

Model's state_dict:
h_1.weight 	 torch.Size([256, 785])
h_1.bias 	 torch.Size([256])
h_2.weight 	 torch.Size([256, 256])
h_2.bias 	 torch.Size([256])
h_3.weight 	 torch.Size([128, 256])
h_3.bias 	 torch.Size([128])
g_1.weight 	 torch.Size([512, 2064])
g_1.bias 	 torch.Size([512])
g_2.weight 	 torch.Size([256, 512])
g_2.bias 	 torch.Size([256])
g_3.weight 	 torch.Size([256, 256])
g_3.bias 	 torch.Size([256])
g_proba.weight 	 torch.Size([10, 256])
g_proba.bias 	 torch.Size([10])
g_std.weight 	 torch.Size([10, 256])
g_std.bias 	 torch.Size([10])
Optimizer's state_dict:
state 	 {140044872416496: {'step': 22319, 'exp_avg': tensor([[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00, -5.8337e-12],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00,  7.0658e-10],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...,  0.0000e+00,
          0.0000e+00, -8.9056e-16],
        ...,
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  ...

In [70]:
#save the model
torch.save(model.state_dict(), path_save+"model_param.pt")