In [18]:
import os
import pickle
import argparse
import time
import glob
import math

import numpy as np
from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F

import wandb


class BnModel(torch.nn.Module):
    def __init__(self):
        super().__init__()

        self.conv1 = nn.Conv2d(1, 16, kernel_size=3)
        self.bn1 = nn.BatchNorm2d(16)
        self.relu1 = nn.ReLU()
        self.pool1 = nn.MaxPool2d((2, 2), 2)

        self.conv2 = nn.Conv2d(16, 32, kernel_size=3)
        self.bn2 = nn.BatchNorm2d(32)
        self.relu2 = nn.ReLU()
        self.pool2 = nn.MaxPool2d((2, 2), 2)

        self.conv3 = nn.Conv2d(32, 64, kernel_size=3)
        self.bn3 = nn.BatchNorm2d(64)
        self.relu3 = nn.ReLU()
        self.pool3 = nn.MaxPool2d((2, 2), 2)

        self.linear = nn.Linear(800, 10)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu1(x)
        x = self.pool1(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu2(x)
        x = self.pool2(x)

        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu3(x)
        x = self.pool3(x)

        x = x.view(x.shape[0], -1)

        x = self.linear(x)

        return x

In [None]:
from torchvision import datasets, transforms

In [None]:
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=True, download=True,
                    transform=transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize((0.1307,), (0.3081,))
                    ])),
    batch_size=16, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('./data', train=False, transform=transforms.Compose([
                        transforms.ToTensor(),
                        transforms.Normalize((0.1307,), (0.3081,))
                    ])),
    batch_size=16, shuffle=True)

In [19]:
model = BnModel()

In [20]:
model(torch.rand((16, 1, 28, 28)))

tensor([[ 1.6677, -0.6971,  0.8865, -0.7557, -0.4095,  0.0749, -0.4439, -0.7791,
         -0.1379,  0.6308],
        [ 1.2749, -0.5110,  0.0318, -0.6762,  0.0443,  0.0099, -0.6583, -0.1688,
         -0.0811, -0.8272],
        [ 0.6621, -0.3624, -0.4092, -0.6143, -0.0137, -0.3406, -0.4109, -0.9980,
          0.4365,  0.2887],
        [ 0.7273, -1.0744, -0.0892, -0.9806, -0.2492, -0.5900, -0.1294,  0.2931,
         -0.0916, -0.2543],
        [ 1.7217, -0.7580,  0.1072, -0.3150, -1.0767, -0.9710, -0.3666,  0.3251,
         -0.7251,  0.4477],
        [ 1.3920, -1.2116, -0.0640, -0.8764, -0.7541,  0.2285, -0.6057, -0.2635,
         -0.7584,  0.8313],
        [ 0.8907, -0.5653,  0.5109, -0.2593, -0.7790, -0.0075, -0.8630,  0.2548,
         -0.0629,  0.2575],
        [ 0.4035, -0.8398, -0.6404, -0.1402, -0.7999, -0.2887,  0.1102,  0.1142,
         -0.0723,  0.2295],
        [ 0.4865, -0.7103,  0.5456, -0.9891, -0.5754, -0.3103, -0.0683, -0.0714,
         -0.3162,  0.3721],
        [ 1.2179, -