# Learning Dynamics - Zero Initialization

In [None]:
import random

import numpy as np
import torch
from torch import Generator

from data import get_dataloader, seed_worker
from models import MNISTFNNModel, MNISTCNNModel
from train import Trainer

In [None]:
R_SEED = 4240

# Seed each global RNG the notebook may draw from -- PyTorch's default
# generator, Python's `random` module and NumPy's legacy global state --
# with the same value.
for _seed_fn in (torch.manual_seed, random.seed, np.random.seed):
    _seed_fn(R_SEED)
del _seed_fn

In [None]:
# Uncomment the line below to launch TensorBoard on the logs under `results/`:
# !tensorboard --logdir results

# 1. MNIST dataset

### 1.1 FNN Model

In [None]:
# Number of training epochs used by every FNN run in this section.
FNN_EPOCHS: int = 1

In [None]:
# Dedicated, seeded RNG handed to the training loader.
# `Generator.manual_seed` returns the generator itself, so it can be chained.
generator = Generator().manual_seed(R_SEED)

# Flattened inputs for the fully-connected model; fixed (unshuffled) order.
train_loader = get_dataloader(
    train=True,
    batch_size=100,
    flatten=True,
    shuffle=False,
    num_workers=1,
    worker_init_fn=seed_worker,
    generator=generator,
)
test_loader = get_dataloader(
    train=False,
    batch_size=100,
    flatten=True,
    shuffle=False,
)

In [None]:
# Re-seed every RNG (including the generator handed to the training loader) so
# this run is reproducible on its own, regardless of which cells ran before it,
# and so every initialization variant starts from the same random state.
torch.manual_seed(R_SEED)
random.seed(R_SEED)
np.random.seed(R_SEED)
generator.manual_seed(R_SEED)
# NOTE(review): `generator`, `train_loader` and `test_loader` are rebound by the
# CNN loader cell; re-run the FNN loader cell first if executing out of order.

# Baseline: keep whatever initialization MNISTFNNModel's constructor applies.
model = MNISTFNNModel()

trainer = Trainer(model, model_name="MNIST_FNN")
trainer.train(train_loader, FNN_EPOCHS)
trainer.test(test_loader)

In [None]:
# Re-seed every RNG (including the generator handed to the training loader) so
# this run is reproducible on its own, regardless of which cells ran before it,
# and so every initialization variant starts from the same random state.
torch.manual_seed(R_SEED)
random.seed(R_SEED)
np.random.seed(R_SEED)
generator.manual_seed(R_SEED)
# NOTE(review): `generator`, `train_loader` and `test_loader` are rebound by the
# CNN loader cell; re-run the FNN loader cell first if executing out of order.

model = MNISTFNNModel()
# Re-initialize with the "zero" scheme (defined in models.py -- presumably
# all-zero parameters; confirm there).
model.zero_initialization("zero")

trainer = Trainer(model, model_name="MNIST_FNN_ZERO")
trainer.train(train_loader, FNN_EPOCHS)
trainer.test(test_loader)

In [None]:
# Re-seed every RNG (including the generator handed to the training loader) so
# this run is reproducible on its own, regardless of which cells ran before it,
# and so every initialization variant starts from the same random state.
torch.manual_seed(R_SEED)
random.seed(R_SEED)
np.random.seed(R_SEED)
generator.manual_seed(R_SEED)
# NOTE(review): `generator`, `train_loader` and `test_loader` are rebound by the
# CNN loader cell; re-run the FNN loader cell first if executing out of order.

model = MNISTFNNModel()
# Re-initialize with the "uniform" scheme (defined in models.py; if it samples
# from torch's RNG, the re-seed above makes the draw reproducible).
model.zero_initialization("uniform")

trainer = Trainer(model, model_name="MNIST_FNN_UNIFORM")
trainer.train(train_loader, FNN_EPOCHS)
trainer.test(test_loader)

In [None]:
# Re-seed every RNG (including the generator handed to the training loader) so
# this run is reproducible on its own, regardless of which cells ran before it,
# and so every initialization variant starts from the same random state.
torch.manual_seed(R_SEED)
random.seed(R_SEED)
np.random.seed(R_SEED)
generator.manual_seed(R_SEED)
# NOTE(review): `generator`, `train_loader` and `test_loader` are rebound by the
# CNN loader cell; re-run the FNN loader cell first if executing out of order.

model = MNISTFNNModel()
# Re-initialize with the "normal" scheme (defined in models.py; if it samples
# from torch's RNG, the re-seed above makes the draw reproducible).
model.zero_initialization("normal")

trainer = Trainer(model, model_name="MNIST_FNN_NORMAL")
trainer.train(train_loader, FNN_EPOCHS)
trainer.test(test_loader)

### 1.2 CNN Model

In [None]:
# Number of training epochs used by every CNN run in this section.
CNN_EPOCHS: int = 1

In [None]:
# Dedicated, seeded RNG handed to the training loader (this rebinds the
# `generator`, `train_loader` and `test_loader` names used by the FNN section).
# `Generator.manual_seed` returns the generator itself, so it can be chained.
generator = Generator().manual_seed(R_SEED)

# Un-flattened inputs for the convolutional model; fixed (unshuffled) order.
train_loader = get_dataloader(
    train=True,
    batch_size=100,
    flatten=False,
    shuffle=False,
    num_workers=1,
    worker_init_fn=seed_worker,
    generator=generator,
)
test_loader = get_dataloader(
    train=False,
    batch_size=100,
    flatten=False,
    shuffle=False,
)

In [None]:
# Re-seed every RNG (including the generator handed to the training loader) so
# this run is reproducible on its own, regardless of which cells ran before it,
# and so every initialization variant starts from the same random state.
torch.manual_seed(R_SEED)
random.seed(R_SEED)
np.random.seed(R_SEED)
generator.manual_seed(R_SEED)

# Baseline: keep whatever initialization MNISTCNNModel's constructor applies.
model = MNISTCNNModel()

trainer = Trainer(model, model_name="MNIST_CNN")
trainer.train(train_loader, CNN_EPOCHS)
trainer.test(test_loader)

In [None]:
# Re-seed every RNG (including the generator handed to the training loader) so
# this run is reproducible on its own, regardless of which cells ran before it,
# and so every initialization variant starts from the same random state.
torch.manual_seed(R_SEED)
random.seed(R_SEED)
np.random.seed(R_SEED)
generator.manual_seed(R_SEED)

model = MNISTCNNModel()
# Re-initialize with the "zero" scheme (defined in models.py -- presumably
# all-zero parameters; confirm there).
model.zero_initialization("zero")

trainer = Trainer(model, model_name="MNIST_CNN_ZERO")
trainer.train(train_loader, CNN_EPOCHS)
trainer.test(test_loader)

In [None]:
# Re-seed every RNG (including the generator handed to the training loader) so
# this run is reproducible on its own, regardless of which cells ran before it,
# and so every initialization variant starts from the same random state.
torch.manual_seed(R_SEED)
random.seed(R_SEED)
np.random.seed(R_SEED)
generator.manual_seed(R_SEED)

model = MNISTCNNModel()
# Re-initialize with the "uniform" scheme (defined in models.py; if it samples
# from torch's RNG, the re-seed above makes the draw reproducible).
model.zero_initialization("uniform")

trainer = Trainer(model, model_name="MNIST_CNN_UNIFORM")
trainer.train(train_loader, CNN_EPOCHS)
trainer.test(test_loader)

In [None]:
# Re-seed every RNG (including the generator handed to the training loader) so
# this run is reproducible on its own, regardless of which cells ran before it,
# and so every initialization variant starts from the same random state.
torch.manual_seed(R_SEED)
random.seed(R_SEED)
np.random.seed(R_SEED)
generator.manual_seed(R_SEED)

model = MNISTCNNModel()
# Re-initialize with the "normal" scheme (defined in models.py; if it samples
# from torch's RNG, the re-seed above makes the draw reproducible).
model.zero_initialization("normal")

trainer = Trainer(model, model_name="MNIST_CNN_NORMAL")
trainer.train(train_loader, CNN_EPOCHS)
trainer.test(test_loader)