Motivation: BatchNorm was one of the first NN modification (2015) which helped `stabilize` training process of deep NN. It normalizes outputs (usually from Dense or Conv layers - layers which have multiplication operation (x*w + b) inside of it that can lead to quite extreme values) before feeding them to activation function.

In [1]:
import torch
import torch.nn.functional as F

In [2]:
with open('unique_english_words.txt', 'r') as f:
    words = [word.rstrip() for word in f.readlines()] 

SPECIAL_TOKEN = '.'
unique_symbols = sorted(list(set(list(''.join(words)))))
unique_symbols.append(SPECIAL_TOKEN)
print(f'Num of unique symbols: {len(unique_symbols)}')

stoi = {s: i for i, s in enumerate(unique_symbols)}
itos = {i: s for s, i in stoi.items()}

Num of unique symbols: 28


In [3]:
import random

random.seed(23)
random.shuffle(words)

train_split = int(0.8 * len(words))
val_split = int(0.9 * len(words))

train_words = words[:train_split]
val_words = words[train_split:val_split]
test_words = words[val_split:]

print('train:', len(train_words))
print('val:', len(val_words))
print('test:', len(test_words))

train: 758
val: 95
test: 95


In [4]:
# build dataset on full data
block_size = 3 # how many characters we are gonna use to predict the next one

def build_dataset(words):
    X = []
    y = []
    
    for word in words:
        word = SPECIAL_TOKEN * block_size + word + SPECIAL_TOKEN
        for i in range(block_size, len(word)):
            context = word[i - block_size:i]
            ch_to_predict = word[i]
    
            X.append([stoi[ch] for ch in context])
            y.append(stoi[ch_to_predict])
    
    X = torch.tensor(X)
    y = torch.tensor(y)

    return X, y


X_train, y_train = build_dataset(train_words)
X_val, y_val = build_dataset(val_words)
X_test, y_test = build_dataset(test_words)

print(X_train.shape, y_train.shape)
print(X_val.shape, y_val.shape)
print(X_test.shape, y_test.shape)

torch.Size([7182, 3]) torch.Size([7182])
torch.Size([904, 3]) torch.Size([904])
torch.Size([931, 3]) torch.Size([931])


In [5]:
g = torch.Generator().manual_seed(23)

vector_dim = 2
C = torch.randn((len(unique_symbols), vector_dim),             generator = g)

hidden_layer_size = 100
W1 = torch.randn((block_size * vector_dim, hidden_layer_size), generator = g)
b1 = torch.randn((hidden_layer_size,),                         generator = g)
W2 = torch.randn((hidden_layer_size, len(unique_symbols)),     generator = g)
b2 = torch.randn((len(unique_symbols),),                       generator = g)

## add new trainable params
bn_gain = torch.ones((1, hidden_layer_size))
bn_bias = torch.zeros((1, hidden_layer_size))


params = [C, W1, b1, W2, b2, bn_gain, bn_bias]
for p in params:
    p.requires_grad = True

n_iter = 50_000
losses_train = []
losses_val = []
for i in range(n_iter):

    batch_size = 64
    rand_indecies = torch.randint(0, X_train.shape[0], (batch_size, ), generator = g)

    # ---------------forward pass---------------

    emb = C[X_train[rand_indecies, ...]] # shape: [batch_size, block_size, vector_dim]
    h_preact = emb.view(-1, block_size * vector_dim) @ W1 + b1

    ## BATCH NORM
    h_preact_norm = (h_preact - h_preact.mean(dim = 0, keepdim = True)) / h_preact.std(dim = 0, keepdim = True)
    h_scaled = bn_gain * h_preact_norm + bn_bias
    
    h = torch.tanh(h_scaled)
    logits = h @ W2 + b2
    
    loss = F.cross_entropy(logits, y_train[rand_indecies])
    losses_train.append(loss.item())
    
    # ---------------backward pass---------------
    for p in params:
        p.grad = None
    
    loss.backward()
    
    # update
    lr = 0.1 if i < 30_000 else 0.01
    for p in params:
        p.data += -lr * p.grad

    if i % 10_000 == 0:
        print(f'{i}/{n_iter}: {loss:.4f}')

0/50000: 12.2216
10000/50000: 2.5290
20000/50000: 2.6898
30000/50000: 2.4662
40000/50000: 2.4155


In [6]:
# calibrate the batch norm in the end of training
# (now we cannot evaluate or get prediction for a single example 
# because NN expects getting a batch to calculate its mean and std)

with torch.no_grad():
    emb = C[X_train]
    h_preact = emb.view(-1, block_size * vector_dim) @ W1 + b1

    # measure mean and std of the whole X_train dataset
    h_preact_mean = h_preact.mean(dim = 0, keepdim = True)
    h_preact_std = h_preact.std(dim = 0, keepdim = True)

In [7]:
@torch.no_grad()
def eval(split):
    x, y = {'train' : [X_train, y_train], 
            'val' : [X_val, y_val], 
            'test' : [X_test, y_test]}[split]
    
    emb = C[x]
    h_preact = emb.view(-1, block_size * vector_dim) @ W1 + b1
    #h_preact_norm = (h_preact - h_preact.mean(dim = 0, keepdim = True)) / h_preact.std(dim = 0, keepdim = True)
    h_preact_norm = (h_preact - h_preact_mean) / h_preact_std
    h_scaled = bn_gain * h_preact_norm + bn_bias
    h = torch.tanh(h_scaled)
    
    logits = h @ W2 + b2
    loss = F.cross_entropy(logits, y)
    return loss

In [8]:
eval('train')

tensor(2.3901)

In [9]:
eval('val')

tensor(2.5102)

## Calculating h_preact_mean and h_preact_std during training

But instead of calculating `h_preact_mean` and `h_preact_std` after training as kind of an additional training step we can calculate both these values during training. 

In [10]:
g = torch.Generator().manual_seed(23)

vector_dim = 2
C = torch.randn((len(unique_symbols), vector_dim),             generator = g)

hidden_layer_size = 100
W1 = torch.randn((block_size * vector_dim, hidden_layer_size), generator = g)
#b1 = torch.randn((hidden_layer_size,),                         generator = g)
W2 = torch.randn((hidden_layer_size, len(unique_symbols)),     generator = g)
b2 = torch.randn((len(unique_symbols),),                       generator = g)

## add new trainable params
bn_gain = torch.ones((1, hidden_layer_size))
bn_bias = torch.zeros((1, hidden_layer_size))

bn_mean_running = torch.zeros((1, hidden_layer_size)) # in the beggining means are roughly 0 
bn_std_running = torch.ones((1, hidden_layer_size))# and stds are roughly 1

params = [C, W1, W2, b2, bn_gain, bn_bias]
for p in params:
    p.requires_grad = True

n_iter = 50_000
losses_train = []
losses_val = []
for i in range(n_iter):

    batch_size = 64
    rand_indecies = torch.randint(0, X_train.shape[0], (batch_size, ), generator = g)

    # ---------------forward pass---------------

    emb = C[X_train[rand_indecies, ...]] # shape: [batch_size, block_size, vector_dim]
    h_preact = emb.view(-1, block_size * vector_dim) @ W1 #+ b1

    ## BATCH NORM
    h_preact_mean_batch = h_preact.mean(dim = 0, keepdim = True)
    h_preact_std_batch = h_preact.std(dim = 0, keepdim = True)
    
    with torch.no_grad():
        bn_mean_running = 0.999 * bn_mean_running + 0.001 * h_preact_mean_batch
        bn_std_running = 0.999 * bn_std_running + 0.001 * h_preact_std_batch
        
    h_preact_norm = (h_preact - h_preact_mean_batch) / h_preact_std_batch
    h_scaled = bn_gain * h_preact_norm + bn_bias
    
    h = torch.tanh(h_scaled)
    logits = h @ W2 + b2
    
    loss = F.cross_entropy(logits, y_train[rand_indecies])
    losses_train.append(loss.item())
    
    # ---------------backward pass---------------
    for p in params:
        p.grad = None
    
    loss.backward()
    
    # update
    lr = 0.1 if i < 30_000 else 0.01
    for p in params:
        p.data += -lr * p.grad

    if i % 10_000 == 0:
        print(f'{i}/{n_iter}: {loss:.4f}')

0/50000: 12.4702
10000/50000: 2.3545
20000/50000: 2.6113
30000/50000: 2.2846
40000/50000: 2.3991


In [11]:
# compare with the previous calculations (h_preact_mean)
h_preact_mean

tensor([[ 2.3351e+00,  1.2273e+00, -2.8906e+00,  3.0119e+00, -4.5581e-01,
         -2.8225e+00, -9.9405e-01,  4.6019e-01,  2.5960e+00,  4.0440e+00,
          3.5605e-01,  4.9826e-01,  2.5685e+00,  2.2120e+00, -1.8808e+00,
          2.3697e+00,  1.2427e+00, -3.1145e+00,  2.2045e+00, -1.3877e-01,
         -1.9311e+00, -2.9764e+00,  3.5898e+00, -3.4530e-01,  3.6254e+00,
         -1.4647e+00,  3.7733e+00,  1.4405e+00,  9.5061e-01, -2.9558e+00,
          1.2995e+00, -2.8438e+00, -2.6863e+00,  6.8943e-01, -4.4990e-01,
          3.2145e+00,  1.5520e+00,  2.5491e+00,  1.4607e+00,  1.2974e-01,
          1.6177e+00,  5.1302e-01,  1.4454e+00,  3.1738e+00, -5.6435e-02,
          2.9770e+00, -2.1937e+00, -9.7860e-01, -3.2245e+00,  3.5236e+00,
          1.0023e+00, -9.0487e-01, -7.2602e-01, -2.5360e+00, -9.0247e-01,
          1.4985e+00,  3.3589e+00,  3.4745e+00,  7.3291e-04,  6.8970e-01,
         -3.0321e+00,  2.7868e+00,  1.1086e+00, -3.0644e-01,  2.3946e+00,
         -1.3173e+00,  1.0737e+00,  3.

In [12]:
bn_mean_running

tensor([[ 3.1771e+00,  1.7308e+00, -2.4754e+00,  9.4514e-01, -9.6870e-01,
         -2.9707e+00, -1.4417e+00, -1.1528e+00,  2.2199e+00,  1.5296e+00,
          7.7523e-01, -1.8869e+00, -5.2390e-02,  9.1345e-01, -1.2627e+00,
          2.4405e+00,  2.0469e+00, -2.6711e+00,  9.5091e-01, -2.2353e-02,
         -1.4059e+00, -1.8211e+00,  1.7934e+00, -1.4754e+00,  1.1878e+00,
         -9.7796e-01,  6.7046e-01,  2.0725e+00,  1.9278e+00, -1.0178e+00,
          1.4137e-01, -1.2035e+00,  8.9007e-01, -9.7842e-01, -9.2034e-01,
          2.0171e+00,  1.2965e+00,  1.0301e+00,  6.2415e-01,  7.0876e-01,
         -1.2189e-01,  5.0645e-01,  8.0470e-01,  6.2983e-01,  6.2483e-02,
          1.8009e+00, -8.6541e-01, -2.4259e+00,  1.8633e+00,  1.8625e+00,
          6.2264e-01, -8.9219e-01, -8.1688e-01, -1.4510e+00, -1.3373e-01,
         -1.6592e+00,  2.8821e+00,  2.2213e+00, -8.2503e-01,  5.8001e-01,
         -1.2410e+00,  1.2783e+00, -7.9412e-01, -1.3878e+00, -5.3117e-01,
          1.1548e+00,  4.5306e-01, -7.

In [13]:
# prev std
h_preact_std

tensor([[2.9363, 1.7412, 3.9182, 2.8966, 2.6707, 3.2518, 1.7189, 2.2780, 3.1494,
         3.1384, 2.0034, 2.2811, 2.5869, 2.7062, 1.8783, 2.7020, 3.0690, 2.9032,
         2.9443, 1.9038, 2.7443, 2.6148, 2.7489, 2.6894, 3.4459, 2.2151, 2.8710,
         2.2993, 2.7605, 2.6212, 2.8322, 3.9083, 2.9984, 1.8093, 3.1868, 2.7740,
         3.4719, 2.2591, 2.9926, 1.9526, 2.0880, 2.4787, 3.0552, 3.1882, 2.2891,
         3.3970, 1.9313, 2.8811, 1.7758, 4.3282, 2.1878, 2.5333, 2.3449, 2.0383,
         3.3019, 2.3171, 3.6218, 3.5178, 1.6871, 2.6941, 2.6976, 2.5505, 2.9485,
         1.9415, 2.6900, 1.4374, 2.1330, 2.3417, 3.2730, 1.8795, 2.3437, 2.8062,
         3.9952, 2.3768, 2.0716, 2.4659, 2.9683, 2.7501, 3.4833, 2.1703, 2.4132,
         2.4932, 2.4337, 2.3812, 2.5581, 1.9627, 3.2943, 1.9530, 2.4780, 2.8287,
         2.1785, 3.3986, 2.2511, 2.8852, 3.3438, 2.1855, 4.6578, 2.7077, 2.9900,
         3.3061]])

In [14]:
bn_std_running

tensor([[3.5856, 2.4069, 4.3643, 3.0153, 2.8490, 3.7777, 2.7438, 2.7699, 3.5920,
         2.7576, 2.4465, 1.9967, 2.1019, 2.7547, 2.6853, 3.2136, 3.8016, 3.5354,
         2.8172, 2.2360, 2.9652, 2.0305, 2.9129, 3.2002, 2.9596, 2.4447, 2.6560,
         2.2383, 2.7925, 2.3356, 3.0917, 3.8815, 2.1013, 1.6304, 3.4772, 3.3397,
         3.4730, 2.3091, 3.4957, 1.8519, 1.2227, 2.3584, 3.3831, 2.9458, 2.5272,
         3.6508, 2.0844, 3.3121, 2.7163, 4.4587, 2.4109, 2.3644, 2.7221, 2.2900,
         3.4975, 3.1479, 4.1558, 3.9742, 2.3133, 2.6170, 2.6881, 2.1931, 3.4315,
         2.6275, 2.2457, 1.9830, 2.2481, 2.3449, 4.2280, 2.4107, 2.8586, 3.2970,
         4.4662, 2.1058, 2.5666, 2.8058, 3.4165, 2.2385, 3.5048, 2.2309, 2.9584,
         2.8954, 2.7225, 2.7810, 2.5457, 2.5616, 2.8408, 1.6021, 2.4618, 3.0471,
         2.9297, 3.9872, 2.0485, 3.3071, 3.4425, 1.7467, 5.2150, 2.6105, 3.3066,
         3.4044]])

In [15]:
# update eval func as well
@torch.no_grad()
def eval(split):
    x, y = {'train' : [X_train, y_train], 
            'val' : [X_val, y_val], 
            'test' : [X_test, y_test]}[split]
    
    emb = C[x]
    h_preact = emb.view(-1, block_size * vector_dim) @ W1 + b1
    h_preact_norm = (h_preact - bn_mean_running) / bn_std_running
    h_scaled = bn_gain * h_preact_norm + bn_bias
    h = torch.tanh(h_scaled)
    
    logits = h @ W2 + b2
    loss = F.cross_entropy(logits, y)
    return loss

In [16]:
eval('train')

tensor(3.4249)

In [17]:
eval('val')

tensor(3.4396)

In [18]:
# gradients of b1 tend to 0 because of batch norm (whaen we calculate mean and then subtract it we basically subtract this bias)
# so if we use batch norm bias becomes useless and we can skip it
b1.grad

tensor([-1.7462e-10, -6.9849e-10, -2.5029e-09, -6.9849e-10,  1.1642e-10,
         0.0000e+00,  5.2387e-10, -1.8626e-09,  6.7521e-09,  1.8626e-09,
        -9.3132e-10,  7.4506e-09, -6.5193e-09,  0.0000e+00,  2.7940e-09,
         1.8626e-09, -2.3283e-10,  5.8208e-10, -1.8626e-09, -9.3132e-10,
        -9.3132e-10, -3.7253e-09,  2.3283e-09, -6.0536e-09,  9.3132e-10,
         1.8626e-09,  6.9849e-10, -9.3132e-10,  1.8626e-09,  1.8626e-09,
        -4.6566e-09,  1.1642e-10,  1.3970e-09,  6.9849e-10, -4.6566e-10,
         2.7940e-09,  4.3656e-11,  9.3132e-10, -4.6566e-10, -1.8626e-09,
         5.5879e-09,  5.8208e-10,  9.3132e-10,  0.0000e+00, -1.0477e-09,
         5.5879e-09, -8.1491e-10,  3.2596e-09,  2.3283e-10, -2.9104e-10,
        -4.6566e-10,  0.0000e+00,  1.1642e-09,  4.6566e-10,  2.7649e-10,
         1.5716e-09, -4.6566e-10, -1.8626e-09, -1.1642e-09, -9.3132e-10,
        -1.8626e-09,  9.3132e-10, -3.7253e-09,  2.3283e-10, -1.8626e-09,
        -4.6566e-10,  9.3132e-10,  1.8626e-09,  1.8

In [19]:
b2.grad # for comparison

tensor([ 0.0012, -0.0330,  0.0163, -0.0162, -0.0298, -0.0261,  0.0201, -0.0042,
         0.0071,  0.0103,  0.0035, -0.0087,  0.0094,  0.0143,  0.0140,  0.0248,
        -0.0177,  0.0034,  0.0271, -0.0194, -0.0123,  0.0361,  0.0077, -0.0071,
         0.0037,  0.0264,  0.0096, -0.0605])