
# Problem 5: Weight Initialization and Training Dynamics (15 points)

### Q5.1: Implement Initialization Schemes (3 points)

In [2]:
import math
import torch

In [3]:
def initialize_weights(shape, method, generator=None):
    """
    Create a 2-D weight matrix using one of several initialization schemes.

    Args:
        shape: tuple of (fan_in, fan_out)
        method: 'zero', 'small_random', 'xavier', 'he'
        generator: optional torch.Generator used for the random draws.
            When None (the default) torch's global RNG is used.
    Returns:
        torch.Tensor of initialized weights with the requested shape:
            'zero'         -> all zeros
            'small_random' -> standard normal scaled by 0.01
            'xavier'       -> standard normal scaled by sqrt(2 / (fan_in + fan_out))
            'he'           -> standard normal scaled by sqrt(2 / fan_in)
    Raises:
        ValueError: if shape does not have exactly two entries, or if
            method is not one of the supported names.
    """

    if len(shape) != 2:
        raise ValueError("Shape must be a tuple of (fan_in, fan_out)")

    fan_in, fan_out = shape

    if method == 'zero':
        return torch.zeros(shape)

    # Every remaining scheme is "standard normal times a scale factor";
    # pick the factor here and draw once at the end.
    if method == 'small_random':
        sigma = 0.01
    elif method == 'xavier':
        fan_sum = fan_in + fan_out
        # A zero fan sum means the tensor is empty, so the scale is
        # irrelevant; avoid dividing by zero.
        sigma = math.sqrt(2 / fan_sum) if fan_sum != 0 else 0.0
    elif method == 'he':
        # fan_in == 0 yields an empty tensor; avoid dividing by zero.
        sigma = math.sqrt(2 / fan_in) if fan_in != 0 else 0.0
    else:
        raise ValueError("Unknown method")

    return torch.randn(shape, generator=generator) * sigma


In [4]:
# Sanity check: show a 10x10 matrix produced by the 'zero' scheme.
print(initialize_weights(shape=(10, 10), method='zero'))

tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])


In [5]:
# Sanity check: show a 10x10 matrix produced by the 'small_random' scheme.
print(initialize_weights(shape=(10, 10), method='small_random'))

tensor([[-7.3728e-03,  2.0317e-02, -9.1637e-03, -7.7929e-03, -1.6863e-02,
          1.9904e-03, -2.0977e-03,  8.7249e-03,  4.7539e-03, -5.0325e-05],
        [ 2.1910e-02, -8.2438e-03,  1.8353e-03,  1.9136e-02,  2.1032e-02,
          2.6045e-02,  8.7373e-03,  4.9453e-03, -5.2550e-04, -9.2108e-03],
        [ 9.8186e-06,  1.1029e-02,  7.2815e-03,  9.2429e-03,  1.6524e-02,
         -2.0089e-02,  1.2951e-02,  1.1343e-02, -1.2502e-03,  1.1907e-02],
        [ 8.5119e-04,  1.0017e-02, -5.3816e-03, -4.1842e-03,  2.4173e-03,
          1.5439e-02, -1.2378e-03,  8.5695e-03, -8.1867e-03,  7.1069e-03],
        [-1.0179e-02, -9.4137e-03,  3.2180e-03,  7.6749e-03, -1.2479e-02,
         -1.1913e-03, -6.4261e-03,  1.7171e-02,  1.0931e-02,  4.0554e-04],
        [-5.6521e-03, -1.6655e-02,  8.7556e-03, -1.1121e-03,  8.4691e-03,
         -2.4977e-02, -1.4781e-02,  3.5186e-03, -3.5608e-03,  3.5375e-03],
        [ 2.7990e-03, -4.6826e-03, -2.6930e-02,  8.3600e-03, -2.8309e-03,
          3.4185e-03, -1.8099e-0

In [6]:
# Sanity check: show a 10x10 matrix produced by the 'xavier' scheme.
print(initialize_weights(shape=(10, 10), method='xavier'))

tensor([[-1.9522e-01,  4.6267e-01, -3.9077e-03,  3.1891e-01, -8.1778e-02,
         -1.0004e-01, -1.6565e-01, -7.0138e-02,  2.3415e-01, -4.3308e-01],
        [-4.9960e-01, -2.1628e-02,  3.5275e-01,  7.1366e-02,  1.9426e-01,
          1.8613e-02, -2.0105e-01,  4.0095e-01, -3.6415e-01, -5.7032e-01],
        [-7.4455e-01, -3.5765e-01,  9.0227e-02, -8.2679e-02,  6.4100e-02,
          4.1779e-01,  5.2513e-01, -4.5788e-02,  2.4235e-01,  8.4194e-02],
        [ 2.3923e-01,  1.2945e-01,  2.3440e-01, -6.0221e-01, -2.0571e-01,
          3.0181e-01, -6.6088e-01, -4.8119e-01, -3.7612e-01,  3.0595e-01],
        [ 6.0871e-01,  3.8558e-01, -4.7836e-01,  7.9041e-02,  3.2196e-01,
          7.2231e-01,  1.3444e-01, -2.1802e-01, -6.0108e-02, -2.9329e-01],
        [ 1.0758e-01,  4.0476e-01,  9.9169e-02, -4.6255e-02, -3.3340e-01,
         -4.6265e-01,  5.3194e-02, -2.7121e-01,  4.1252e-04, -2.3865e-01],
        [ 2.3185e-01, -6.1036e-02,  3.5239e-01,  4.2175e-01,  2.6114e-01,
          1.1305e-01, -4.8434e-0

In [7]:
# Sanity check: show a 10x10 matrix produced by the 'he' scheme.
print(initialize_weights(shape=(10, 10), method='he'))

tensor([[-0.4747,  0.2276,  0.6711, -0.0488,  0.3201,  0.2678, -0.2571,  0.7028,
         -0.1859, -0.0471],
        [-0.5096, -0.1429, -0.4816,  0.4405,  0.1535, -0.0602, -0.3839, -0.2299,
          0.0330,  0.1493],
        [-0.4028, -0.5051, -0.6254, -0.7087, -0.5010,  0.0422, -0.4240,  0.3202,
         -0.3374,  1.3728],
        [-0.4379,  0.1670,  0.0026,  0.0277,  0.4402, -0.0052,  0.4097,  0.5598,
          0.8233, -0.2221],
        [-0.8151,  0.1944,  0.3852,  0.7843,  0.2673,  0.8377,  0.2311, -0.1750,
         -0.2108,  0.8659],
        [-0.8962, -0.3528, -0.2630, -0.0697,  0.1602, -0.8806,  0.6894, -0.4559,
         -0.7360,  0.3275],
        [ 0.4181,  0.1856,  0.0891,  0.6988,  0.2787,  0.0607,  0.1608, -0.4034,
         -0.1173,  0.7348],
        [-0.2997, -0.2791, -0.5637,  0.0288,  0.5149, -0.1368,  0.0627, -0.1159,
         -0.1855,  0.1012],
        [ 0.1542, -0.6843,  0.9115, -0.0068,  0.1636, -0.3744,  0.0169,  0.1341,
         -0.2064,  0.9942],
        [-0.1578, -