In [7]:
import import_ipynb
from Reservoir import Reservoir

import os
import numpy as np

import torch
import torch.optim as optim
import torch.nn as nn
from tqdm.notebook import tqdm_notebook
from tensorboardX import SummaryWriter

# Ignore warnings
import warnings
import time
warnings.filterwarnings("ignore")

os.chdir("./MNIST")
from MNIST_model import ConvNet
os.chdir("./..")

In [8]:
class Model():
    
    def __init__(self,device):
        outputs = 10
        image_shape = (1,28,28)
        self.device = device
        self.network = ConvNet(outputs, image_shape).to(device)
        self.optimizer = optim.Adam(self.network.parameters(), amsgrad=True, weight_decay=0.01)
        self.criterion = nn.CrossEntropyLoss()
        self.start_epoch = 0
        
    def load_weights(self, load_path):
        checkpoint = torch.load(load_path)
        self.network.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.start_epoch = checkpoint['epoch']
        
    def save_weights(self,epoch, save_path):
        torch.save({
            'epoch': epoch,
            'model_state_dict': self.network.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict()}, 
            save_path)
       
    def current_snapshot_name(self):
        from time import gmtime, strftime
        import socket

        hostname = socket.gethostname()

        date = strftime("%b%d_", gmtime())
        clock = strftime("%X", gmtime())
        now = clock.split(":")
        now = date+'-'.join(now)

        name = now+"_"+hostname
        return name

In [10]:
if __name__ == "__main__":
    dataset = Reservoir("./data/training/images/train-images-idx3-ubyte.gz",
                        "./data/training/labels/train-labels-idx1-ubyte.gz",
                        "./data/testing/images/t10k-images-idx3-ubyte.gz",
                        "./data/testing/labels/t10k-labels-idx1-ubyte.gz")
    dataloader = dataset.training_pool.dataloader
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Model(device)
    for i, batch in enumerate(dataloader):
        if i > 0:
            break
        imgs = batch['image'].float()
        print(imgs[0].shape)
        print("input", imgs.shape)
        out = model.network(imgs)
        print("output", out)

Total training stacks 5400
Total validation stacks 600
Total testing stacks 10000
units after conv 588
fc parameters:  5890
torch.Size([1, 28, 28])
input torch.Size([10, 1, 28, 28])
output tensor([[-0.0343,  0.5896, -1.9539, -0.6913,  1.4539,  0.0883,  1.5611, -0.1624,
          1.1815, -0.5465],
        [-0.1154, -0.8287, -0.3412,  0.2032,  0.2519, -0.6069,  1.1414,  0.8998,
         -0.4664, -0.6901],
        [ 0.6060, -0.2796, -0.5201,  0.7641,  0.1841,  0.3456,  0.8123,  0.3900,
          0.3624, -0.3430],
        [ 1.2073, -0.3711, -0.2898, -0.1989,  0.2594, -0.1146,  1.1477,  0.0136,
          0.4205, -0.1950],
        [ 0.5829,  0.1388, -0.4576,  0.2663,  0.8481, -0.6088,  0.4881, -0.2709,
         -0.2681,  0.0840],
        [ 0.2852,  0.2535, -0.5762,  0.6259,  0.5749, -0.9295,  0.4146, -0.0435,
         -0.3615, -0.3470],
        [ 0.0051, -0.0673, -0.3772, -0.4821, -0.0168, -0.0228,  0.9538, -0.1893,
         -1.0441, -0.3765],
        [ 0.5193, -0.4930, -0.9534,  0.1548,  0.