In [9]:
import math
import torch

from torch import nn
from typing import Union
from collections import OrderedDict

from torcheval import metrics
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision import datasets

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [10]:
def init_weights_pre_relu(input_dim, output_dim):
    """Sample a He-initialised weight matrix for a layer that feeds a ReLU.

    Every entry is drawn independently from a zero-mean normal distribution whose
    standard deviation is sqrt(2 / input_dim). Biases are not handled here.

    :param input_dim: fan-in of the layer; sets the scale of the distribution
    :param output_dim: fan-out of the layer
    :return: tensor of shape (input_dim, output_dim)

    NOTE(review): the result is laid out as (input_dim, output_dim), whereas
    `nn.Linear.weight` is documented as (out_features, in_features). The two only
    coincide for square layers - confirm before using this with input_dim != output_dim.
    """
    he_std = math.sqrt(2 / input_dim)
    return he_std * torch.randn(input_dim, output_dim)


class SplitLinear(nn.Module):
    """Split the last dimension in two halves and push both through ONE shared Linear+ReLU.

    The input's last dimension (size `input_dim`) is cut into two equal halves. The same
    `Linear(input_dim // 2, input_dim // 2) -> ReLU` sub-network is applied to each half and
    the two results are concatenated back along the last dimension, so the output has the
    same shape as the input.

    :param input_dim: size of the feature (last) dimension; must be even
    :param verbose: when True, `forward` prints the two halves before and after the sub-network
    """
    def __init__(self, input_dim, verbose=False):
        super().__init__()
        self.verbose = verbose
        assert input_dim % 2 == 0, f"input_dim: {input_dim} should be even."
        half_dim = input_dim // 2

        self.network = nn.Sequential(OrderedDict([
            ("l1", nn.Linear(half_dim, half_dim)),
            ("a1", nn.ReLU())
        ]))
        # Replace the default Linear weights with our custom He initialisation.
        # The layer is square, so the (in, out) layout returned by the helper has
        # exactly the shape the Linear layer was created with.
        # nn.Parameter already enables gradients, no need to set requires_grad by hand.
        self.network.l1.weight = nn.Parameter(init_weights_pre_relu(half_dim, half_dim))

    def set_verbose(self, verbose):
        """Turn the debug printing done by `forward` on or off."""
        self.verbose = verbose

    def forward(self, x: torch.Tensor):
        """Apply the shared sub-network to both halves of the last dimension.

        :param x: tensor whose last dimension is even (and matches `input_dim`)
        :return: tensor with the same shape as `x`
        """
        # Validate and size the split on the SAME dimension we actually split on (the last
        # one). Using shape[1] here breaks for anything that is not a 2-D (batch, features)
        # input, while being identical for the 2-D case.
        feature_dim = x.shape[-1]
        assert feature_dim % 2 == 0, f"x.shape[-1]: {feature_dim} should be even."
        x1, x2 = x.split(feature_dim // 2, dim=-1)
        if self.verbose:
            print(f"x1: {x1}\nx2: {x2}")
        out1, out2 = self.network(x1), self.network(x2)
        if self.verbose:
            print(f"out1: {out1}\nout2: {out2}")
        return torch.cat([out1, out2], dim=-1)

In [17]:
def q1():
    """Demo for SplitLinear: run a small random batch through it and compare shapes."""
    batch_size, num_features = 2, 4

    # Build the layer first, then draw the input (keeps the random draws in the same order).
    split_layer = SplitLinear(num_features, verbose=True)
    inputs = torch.rand((batch_size, num_features))

    print(inputs)
    outputs = split_layer(inputs)
    print(outputs)
    for tensor in (inputs, outputs):
        print(tensor.shape)
    print(f"Shapes equal: {inputs.shape == outputs.shape}")

q1()

tensor([[0.4550, 0.0670, 0.9651, 0.1678],
        [0.9308, 0.7598, 0.4741, 0.5297]])
x1: tensor([[0.4550, 0.0670],
        [0.9308, 0.7598]])
x2: tensor([[0.9651, 0.1678],
        [0.4741, 0.5297]])
out1: tensor([[0., 0.],
        [0., 0.]], grad_fn=<ReluBackward0>)
out2: tensor([[0.1036, 0.0000],
        [0.0000, 0.0000]], grad_fn=<ReluBackward0>)
tensor([[0.0000, 0.0000, 0.1036, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000]], grad_fn=<CatBackward0>)
torch.Size([2, 4])
torch.Size([2, 4])
Shapes equal: True


In [11]:
class DropNorm(nn.Module):
    """Dropout (exactly half of the features) followed by a per-sample normalization.

    Training mode: a random mask that zeroes exactly half of the feature positions is
    applied (kept values are scaled by 1 / p), each sample is normalized over all of its
    non-batch dimensions, the learnable affine `gamma * x + beta` is applied, and the
    result is multiplied by the mask again so dropped positions stay at zero.
    Eval mode: only the normalization and the affine transform are applied.

    :param input_dim: shape of a single sample (everything except the batch dimension);
        used as the shape of the learnable `gamma` and `beta`
    """
    def __init__(self, input_dim: Union[tuple, list, int]):
        super().__init__()
        self.eps = 1e-16
        # We init params so that y_i = x_i, similarly to batch norm
        self.gamma = nn.Parameter(torch.ones(input_dim))
        self.beta = nn.Parameter(torch.zeros(input_dim))

    def dropout(self, x: torch.Tensor):
        """Drop exactly half of the feature positions (p is fixed to 0.5).

        One mask of shape `x.shape[1:]` is drawn per call and broadcast over the batch
        dimension, so every sample in the batch shares the same dropped positions.

        :param x: input tensor, first dimension is the batch
        :return: `(out, mask)` in BOTH modes. In training mode `out = x * mask / p`;
            in eval mode `out` is `x` itself and `mask` is all ones.
        :raises AssertionError: in training mode, if the number of features per sample is odd
        """
        # hard set of p to 0.5 like required.
        p = 0.5
        feature_shape = x.shape[1:]
        if not self.training:
            # Identity in eval mode. Still return a (tensor, mask) pair so that callers can
            # always unpack the result the same way; returning a bare tensor here would be
            # unpacked along the batch dimension instead.
            return x, torch.ones(feature_shape, dtype=torch.float, device=x.device)
        ele_num = math.prod(feature_shape)
        # bitwise check for `even` num
        assert ele_num & 1 == 0, f"number of features per sample ({ele_num}) should be even."
        half_ele = ele_num // 2
        # The following process makes sure we're dropping EXACTLY 1/2 of the `neurons`:
        # build a flat tensor with half 1s and half 0s ...
        mask = torch.cat([torch.ones(half_ele, dtype=torch.float, device=x.device),
                          torch.zeros(half_ele, dtype=torch.float, device=x.device)])
        # ... shuffle it with a random permutation ...
        perm = torch.randperm(ele_num, device=x.device)
        # ... and reshape it to the feature shape of a single sample.
        mask = mask[perm].reshape(feature_shape)
        return x * mask / p, mask

    def normalize(self, x):
        """Normalize every sample to zero mean / unit variance over its non-batch dims.

        :param x: input tensor, first dimension is the batch
        :return: `(x - mean) / sqrt(var + eps)` with per-sample mean and (biased) variance
        """
        # We want all dims EXCEPT the batch dim to be included in the statistics,
        # meaning every sample gets its own mean, variance, and eventually norm_x.
        dims = tuple(range(1, x.dim()))
        mew = torch.mean(x, dtype=torch.float32, dim=dims, keepdim=True)
        # std^2 | known also as `variance`
        sig2 = torch.sum((x - mew) ** 2, dim=dims, keepdim=True) / math.prod(x.shape[1:])
        norm_x = (x - mew) / torch.sqrt(sig2 + self.eps)
        return norm_x

    def forward(self, x):
        """ When training, we use dropout -> normalization and we mult with mask as requested
            (we must multiply again with the mask, as beta might not be 0, and we want 0s)
        When not training, we only use normalize(x)*gamma + beta."""
        if self.training:
            out1, mask = self.dropout(x)
            # Note: the statistics are computed on the dropped-out tensor, zeros included.
            out2 = self.normalize(out1)
            # We multiply by the mask again because positions that were zeroed in dropout should stay zeroed
            out2 = (self.gamma * out2 + self.beta) * mask
        else:
            out2 = self.gamma * self.normalize(x) + self.beta
        return out2


class BasicNetwork(nn.Module):
    """
    Small conv net built from the stock `nn.Dropout` and `nn.LayerNorm` layers.

    Three conv stages (32, 64, 128 channels; the first two halve the spatial size with
    max-pooling) followed by a two-layer classifier head.
    """
    def __init__(self, input_shape=(1, 28, 28), num_classes=10):
        """ Spatial size with the default input: 28 => 14 => 7

        :param input_shape: (channels, height, width) of a single input image
        :param num_classes: number of output logits
        """
        super().__init__()
        in_channels, height, width = input_shape
        self.input_shape = input_shape

        # Spatial sizes after the first and after the second 2x2 max-pool.
        h_half, w_half = height // 2, width // 2
        h_quarter, w_quarter = height // 4, width // 4

        feature_layers = [
            # stage 1: conv -> relu -> pool -> dropout -> layernorm
            nn.Conv2d(in_channels, 32, kernel_size=3, stride=1, padding='same'),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(p=0.5),
            nn.LayerNorm([32, h_half, w_half]),
            # stage 2: same pattern, spatial size halves again
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding='same'),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            nn.Dropout(p=0.5),
            nn.LayerNorm([64, h_quarter, w_quarter]),
            # stage 3: no pooling, spatial size is kept
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding='same'),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.LayerNorm([128, h_quarter, w_quarter]),
        ]
        self.backbone = nn.Sequential(*feature_layers)

        flat_features = h_quarter * w_quarter * 128
        head_layers = [
            nn.Flatten(),
            nn.Linear(flat_features, 256),
            nn.ReLU(),
            nn.Dropout(p=0.5),
            nn.LayerNorm(256),
            nn.Linear(256, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Run the conv backbone, then the classifier head; returns the class logits."""
        return self.classifier(self.backbone(x))


class MySpecialNetwork(nn.Module):
    """ Small conv net that uses the custom `DropNorm` layer in every stage.

    Three conv stages (32, 64, 128 channels; the first two halve the spatial size with
    max-pooling) followed by a two-layer classifier head.
    """
    def __init__(self, input_shape=(1, 28, 28), num_classes=10):
        """ Spatial size with the default input: 28 => 14 => 7

        :param input_shape: (channels, height, width) of a single input image
        :param num_classes: number of output logits
        """
        super().__init__()
        in_channels, height, width = input_shape
        self.input_shape = input_shape

        # Spatial sizes after the first and after the second 2x2 max-pool.
        h_half, w_half = height // 2, width // 2
        h_quarter, w_quarter = height // 4, width // 4

        feature_layers = [
            # stage 1: conv -> relu -> pool -> dropnorm
            nn.Conv2d(in_channels, 32, kernel_size=3, stride=1, padding='same'),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            DropNorm([32, h_half, w_half]),
            # stage 2: same pattern, spatial size halves again
            nn.Conv2d(32, 64, kernel_size=3, stride=1, padding='same'),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),
            DropNorm([64, h_quarter, w_quarter]),
            # stage 3: no pooling, spatial size is kept
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding='same'),
            nn.ReLU(),
            DropNorm([128, h_quarter, w_quarter]),
        ]
        self.backbone = nn.Sequential(*feature_layers)

        flat_features = h_quarter * w_quarter * 128
        head_layers = [
            nn.Flatten(),
            nn.Linear(flat_features, 256),
            nn.ReLU(),
            DropNorm(256),
            nn.Linear(256, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Run the conv backbone, then the classifier head; returns the class logits."""
        return self.classifier(self.backbone(x))


In [12]:
def norm_example():
    """Print a small (3, 2, 4) integer tensor and its per-sample normalized version."""
    eps = 1e-16
    x = torch.arange(0, 3 * 2 * 4).reshape(3, 2, 4)
    print(x)

    # Reduce over every dim EXCEPT the batch dim, so each sample gets its own statistics
    feature_dims = tuple(range(1, x.dim()))
    num_features = math.prod(x.shape[1:])

    mean = x.mean(dim=feature_dims, keepdim=True, dtype=torch.float32)
    centered = x - mean
    variance = (centered ** 2).sum(dim=feature_dims, keepdim=True) / num_features

    print(centered / (variance + eps).sqrt())

# Example of the norm implementation
norm_example()

tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7]],

        [[ 8,  9, 10, 11],
         [12, 13, 14, 15]],

        [[16, 17, 18, 19],
         [20, 21, 22, 23]]])
tensor([[[-1.5275, -1.0911, -0.6547, -0.2182],
         [ 0.2182,  0.6547,  1.0911,  1.5275]],

        [[-1.5275, -1.0911, -0.6547, -0.2182],
         [ 0.2182,  0.6547,  1.0911,  1.5275]],

        [[-1.5275, -1.0911, -0.6547, -0.2182],
         [ 0.2182,  0.6547,  1.0911,  1.5275]]])


In [13]:
def validation_loop(model, val_loader, loss_fn) -> (float, float):
    """ Evaluate `model` on `val_loader` (copied from maman13 with few modifications).

    The model is switched to eval mode and is left in eval mode when this returns.

    :param model: network to evaluate
    :param val_loader: iterable of (inputs, targets) batches; must support `len()`
    :param loss_fn: callable computing the loss from (predictions, targets)
    :return: (mean of the per-batch losses, result of the accuracy metric's `compute()`)
    """
    accuracy = metrics.MulticlassAccuracy(device=DEVICE)
    total_loss = 0.

    model.eval()
    with torch.no_grad():
        for inputs, targets in val_loader:
            inputs = inputs.to(DEVICE)
            targets = targets.to(DEVICE)
            logits = model(inputs)
            batch_loss = loss_fn(logits, targets)
            accuracy.update(logits, targets)
            total_loss += batch_loss.item()

    # Average over the number of batches (not over the number of samples)
    return total_loss / len(val_loader), accuracy.compute()


def train_model(
        model: nn.Module,
        train_loader: DataLoader,
        valid_loader: DataLoader,
        loss_fn: nn.Module,
        epochs: int = 10,
        verbose: int = 1,
        verbose_batch: int = 1,
        lr: float = 1e-4,
        wd: float = 0.05) -> nn.Module:
    """
    Train loop (copied from maman13 with few modifications).

    Optimizes `model` with Adam, and after every epoch runs `validation_loop` on
    `valid_loader` and (optionally) prints train / validation loss and accuracy.

    :param model: network to train (trained in place)
    :param train_loader: batches of (inputs, targets) used for optimization
    :param valid_loader: batches of (inputs, targets) used for the end-of-epoch validation
    :param loss_fn: callable computing the loss from (predictions, targets)
    :param epochs: number of passes over `train_loader`
    :param verbose: [0, 1, 2] Level of printing information (0 None, 2 Max)
    :param verbose_batch: if verbose is 2, how many batches between printing the batch loss.
    :param lr: learning rate
    :param wd: weight decay (passed to Adam as `weight_decay`)
    :return: the same model instance that was passed in
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    train_accuracy = metrics.MulticlassAccuracy(device=DEVICE)

    for epoch_idx in range(epochs):
        epoch_no = epoch_idx + 1
        epoch_loss = 0.
        model.train()  # validation leaves the model in eval mode, so re-enable every epoch
        train_accuracy.reset()

        for step, (inputs, targets) in enumerate(train_loader):
            inputs, targets = inputs.to(DEVICE), targets.to(DEVICE)

            optimizer.zero_grad()
            logits = model(inputs)
            batch_loss = loss_fn(logits, targets)
            train_accuracy.update(logits, targets)
            batch_loss.backward()
            optimizer.step()

            batch_loss_value = batch_loss.item()
            epoch_loss += batch_loss_value

            # Print every `verbose_batch` batches
            if verbose >= 2 and step % verbose_batch == 0:
                print(f"Epoch [{epoch_no}/{epochs}], "
                      f"Step [{step}/{len(train_loader)}], "
                      f"Loss: {batch_loss_value:.4f}")

        # End of epoch. Run validation and print outcomes
        val_loss, val_acc = validation_loop(model, valid_loader, loss_fn)
        if verbose < 1:
            continue
        print(f"Epoch [{epoch_no:4}/{epochs}]", end=", ")
        print(f"trn los: {epoch_loss / len(train_loader):8.4f},", f"trn acc: {train_accuracy.compute():6.4f}",
              end=', ')
        print(f"val loss: {val_loss:8.4f}, val acc: {val_acc:6.4f}")

    return model

In [14]:
# Load data 

# Convert images to tensors, then standardize with the MNIST mean / std
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
])

# Train split first, then test split; both share the same root folder and transform
train_data, test_data = (
    datasets.MNIST('../MNIST_data', download=True, train=is_train, transform=transform)
    for is_train in (True, False)
)

# Only the training batches are shuffled
train_loader = DataLoader(dataset=train_data, batch_size=50, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=100)

# Set loss fn
loss_fn = nn.CrossEntropyLoss()

In [15]:
# Baseline run: the network built from the stock Dropout + LayerNorm layers.
model = BasicNetwork()
model = model.to(DEVICE)
train_model(
    model=model,
    train_loader=train_loader,
    valid_loader=test_loader,
    loss_fn=loss_fn,
    verbose=1,
    verbose_batch=100,
)

Epoch [   1/10], trn los:   0.6348, trn acc: 0.8238, val loss:   0.1216, val acc: 0.9626
Epoch [   2/10], trn los:   0.2371, trn acc: 0.9432, val loss:   0.0746, val acc: 0.9765
Epoch [   3/10], trn los:   0.1894, trn acc: 0.9592, val loss:   0.0535, val acc: 0.9850
Epoch [   4/10], trn los:   0.1812, trn acc: 0.9663, val loss:   0.0531, val acc: 0.9837
Epoch [   5/10], trn los:   0.1905, trn acc: 0.9700, val loss:   0.0476, val acc: 0.9871
Epoch [   6/10], trn los:   0.2167, trn acc: 0.9718, val loss:   0.0554, val acc: 0.9879
Epoch [   7/10], trn los:   0.2575, trn acc: 0.9736, val loss:   0.0676, val acc: 0.9869
Epoch [   8/10], trn los:   0.3001, trn acc: 0.9745, val loss:   0.0796, val acc: 0.9862
Epoch [   9/10], trn los:   0.3108, trn acc: 0.9742, val loss:   0.0765, val acc: 0.9870
Epoch [  10/10], trn los:   0.2993, trn acc: 0.9711, val loss:   0.0640, val acc: 0.9884


BasicNetwork(
  (backbone): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Dropout(p=0.5, inplace=False)
    (4): LayerNorm((32, 14, 14), eps=1e-05, elementwise_affine=True)
    (5): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (6): ReLU()
    (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (8): Dropout(p=0.5, inplace=False)
    (9): LayerNorm((64, 7, 7), eps=1e-05, elementwise_affine=True)
    (10): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (11): ReLU()
    (12): Dropout(p=0.5, inplace=False)
    (13): LayerNorm((128, 7, 7), eps=1e-05, elementwise_affine=True)
  )
  (classifier): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=6272, out_features=256, bias=True)
    (2): ReLU()
    (3): Dropout(p=0.5, inplace=False)
    (4):

In [16]:
# Same training setup as the baseline, but with the custom DropNorm layers.
model = MySpecialNetwork()
model = model.to(DEVICE)
train_model(
    model=model,
    train_loader=train_loader,
    valid_loader=test_loader,
    loss_fn=loss_fn,
    verbose=1,
    verbose_batch=100,
)

Epoch [   1/10], trn los:   0.6143, trn acc: 0.8345, val loss:   0.1317, val acc: 0.9580
Epoch [   2/10], trn los:   0.2272, trn acc: 0.9473, val loss:   0.0861, val acc: 0.9737
Epoch [   3/10], trn los:   0.1853, trn acc: 0.9608, val loss:   0.0819, val acc: 0.9747
Epoch [   4/10], trn los:   0.1768, trn acc: 0.9680, val loss:   0.0716, val acc: 0.9778
Epoch [   5/10], trn los:   0.1849, trn acc: 0.9728, val loss:   0.0749, val acc: 0.9787
Epoch [   6/10], trn los:   0.2105, trn acc: 0.9737, val loss:   0.0741, val acc: 0.9817
Epoch [   7/10], trn los:   0.2533, trn acc: 0.9747, val loss:   0.0847, val acc: 0.9824
Epoch [   8/10], trn los:   0.2933, trn acc: 0.9757, val loss:   0.0968, val acc: 0.9830
Epoch [   9/10], trn los:   0.3083, trn acc: 0.9750, val loss:   0.0918, val acc: 0.9824
Epoch [  10/10], trn los:   0.2986, trn acc: 0.9749, val loss:   0.0820, val acc: 0.9844


MySpecialNetwork(
  (backbone): Sequential(
    (0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (1): ReLU()
    (2): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): DropNorm()
    (4): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (5): ReLU()
    (6): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (7): DropNorm()
    (8): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=same)
    (9): ReLU()
    (10): DropNorm()
  )
  (classifier): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): Linear(in_features=6272, out_features=256, bias=True)
    (2): ReLU()
    (3): DropNorm()
    (4): Linear(in_features=256, out_features=10, bias=True)
  )
)