In [1]:
import import_ipynb
from Reservoir import Reservoir

import os
import torch
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
import torch.nn.functional as F
from tensorboardX import SummaryWriter

# Ignore warnings
import warnings
import time
warnings.filterwarnings("ignore")


from tqdm.notebook import tqdm_notebook

plt.ion()   # interactive mode

importing Jupyter notebook from Reservoir.ipynb


In [2]:
if __name__ == "__main__":
    dataset = Reservoir("./data/training/images/train-images-idx3-ubyte.gz",
                        "./data/training/labels/train-labels-idx1-ubyte.gz",
                        "./data/testing/images/t10k-images-idx3-ubyte.gz",
                        "./data/testing/labels/t10k-labels-idx1-ubyte.gz")

Total training stacks 5400
Total validation stacks 600
Total testing stacks 10000


In [4]:
class ConvNet(nn.Module):
    def __init__(self,outputs,image_shape):
        super(ConvNet, self).__init__()
        img_size = list(image_shape)
        img_size = torch.Size([1] + img_size)
        empty = torch.zeros(img_size)
        
        channels = 3
        kernel = 3
        padding = 1
        self.conv1 = nn.Sequential(nn.Conv2d(image_shape[0],
                                             out_channels = channels,
                                             kernel_size = kernel,
                                             padding = padding),
                                  nn.BatchNorm2d(channels),
                                  nn.MaxPool2d(2),
                                  nn.ReLU())
        units = self.conv1(empty).numel()
        print("units after conv", units)
        self.fc = nn.Sequential(nn.Linear(units, outputs))
        print("fc parameters: ",sum(p.numel() for p in self.fc.parameters()))
    
    def forward(self, x):
        #x: batch, channel, height, width
        batch_size = len(x)
        out = self.conv1(x)
        out = out.reshape((batch_size,-1))
        out = self.fc(out)
        return out
        
    def load_weights(self,path):
        checkpoint = torch.load(path)
        self.load_state_dict(checkpoint['model_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.start_epoch = checkpoint['epoch']
        
    def save_weights(self,optimizer,epoch,path):
        torch.save({
            'epoch': epoch,
            'model_state_dict': self.state_dict(),
            'optimizer_state_dict': optimizer.state_dict()}, 
            path)
       
    def current_snapshot_name():
        from time import gmtime, strftime
        import socket

        hostname = socket.gethostname()

        date = strftime("%b%d_", gmtime())
        clock = strftime("%X", gmtime())
        now = clock.split(":")
        now = date+'-'.join(now)

        name = now+"_"+hostname
        return name
    
if __name__ == "__main__":
    image = dataset.training_pool[0]['image']
    image = image.reshape((1,28,28))
    net = ConvNet(1,image.shape)

(1, 28, 28)
units after conv 588
fc parameters:  589


In [7]:
if __name__ == "__main__":
    dataloader = dataset.training_pool.dataloader
    for i, batch in enumerate(dataloader):
        if i > 0:
            break
        
        imgs = batch['image'].float()
        imgs = np.transpose(imgs,(0,3,1,2))
        print(imgs[0].shape)
        print("input", imgs.shape)
        out = net(imgs)
        print("output", out)

torch.Size([1, 28, 28])
input torch.Size([10, 1, 28, 28])
output tensor([[-0.8686],
        [-1.2079],
        [-1.0248],
        [-1.2757],
        [-1.0928],
        [-0.3885],
        [-0.8365],
        [-0.8118],
        [-1.5839],
        [-1.0070]], grad_fn=<AddmmBackward>)
