In [9]:
import torch
import torch.nn as nn

class SymmetricNet(nn.Module):
    """Small fully connected network: Linear -> ReLU -> Linear.

    With the default arguments the network is identical to the original
    hard-coded version: 10 input features, a hidden layer of 10 units and a
    single output.

    Args:
        in_features: Size of each input sample.
        hidden_features: Number of units in the hidden layer (``fc1`` output).
        out_features: Size of each output sample.
    """

    def __init__(self, in_features=10, hidden_features=10, out_features=1):
        super().__init__()
        # Layers are created in the same order as before so that the default
        # random initialisation consumes the RNG identically.
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        """Apply ``fc1``, a ReLU non-linearity, then ``fc2``.

        Args:
            x: Tensor whose last dimension equals ``fc1.in_features``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``fc2.out_features``.
        """
        hidden = torch.relu(self.fc1(x))
        return self.fc2(hidden)

In [16]:
def init_weights_symmetric(m):
    """Set the weight and bias of an ``nn.Linear`` layer to zero, in place.

    Written to be passed to ``nn.Module.apply``, which calls it once for every
    submodule; anything that is not a Linear layer is left untouched.

    Args:
        m: The module to (possibly) re-initialise.
    """
    if not isinstance(m, nn.Linear):
        return
    nn.init.zeros_(m.weight)
    # A Linear layer built with bias=False has ``bias is None``; the previous
    # unconditional ``m.bias.data.fill_`` raised AttributeError in that case.
    if m.bias is not None:
        nn.init.zeros_(m.bias)

# Build the network, then zero every Linear layer's weight and bias.
# nn.Module.apply calls the given function on each submodule and on the module
# itself; leaving the call as the cell's last expression makes the notebook
# echo the network's repr.
net = SymmetricNet()
net.apply(init_weights_symmetric)

SymmetricNet(
  (fc1): Linear(in_features=10, out_features=10, bias=True)
  (fc2): Linear(in_features=10, out_features=1, bias=True)
)

In [17]:
# Print each trainable parameter tensor next to its name; parameters with
# requires_grad=False are skipped.
for name, param in net.named_parameters():
    if not param.requires_grad:
        continue
    print(name, param.data)

fc1.weight tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
fc1.bias tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
fc2.weight tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
fc2.bias tensor([0.])


In [18]:
# Train the zero-initialised network to show that it cannot learn.
# When the cells are run in order, every weight and bias of `net` is 0 here, so
# fc1(x) is 0 and relu(0) is 0: the hidden activations are all zero. The
# gradient of fc2.weight is proportional to those activations (zero), and the
# gradient reaching fc1 is scaled by fc2.weight (also zero). Only fc2.bias
# receives a non-zero gradient, which matches the printed result below.
# NOTE(review): no random seed is set in the visible cells, so the data and the
# resulting fc2.bias value will differ from run to run.

# Step 1: Create a dataset
inputs = torch.randn(100, 10)
targets = torch.randn(100, 1)

# Step 2: Define a loss function and an optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)

# Step 3: Train the network
# Each iteration is one full-batch gradient step over all 100 samples.
for epoch in range(100):  # loop over the dataset multiple times
    optimizer.zero_grad()  # zero the parameter gradients
    outputs = net(inputs)  # forward pass
    loss = criterion(outputs, targets)  # compute loss
    loss.backward()  # backward pass
    optimizer.step()  # update weights

# Step 4: Display the weights after training
for name, param in net.named_parameters():
    if param.requires_grad:
        print(name, param.data)

fc1.weight tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
fc1.bias tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
fc2.weight tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]])
fc2.bias tensor([-0.0237])


In [19]:
def break_symmetry(m, std=0.01):
    """Add zero-mean Gaussian noise to the weight of an ``nn.Linear`` layer.

    Written to be passed to ``nn.Module.apply``; modules that are not Linear
    layers are left untouched. Biases are not modified.

    Args:
        m: The module to (possibly) perturb in place.
        std: Standard deviation of the added noise. The default of 0.01
            matches the previously hard-coded scale.
    """
    if not isinstance(m, nn.Linear):
        return
    # randn_like draws the noise with the weight's own dtype and device;
    # torch.randn(size) always produced a default-dtype CPU tensor, which
    # cannot be added to a weight living on another device.
    with torch.no_grad():
        m.weight.add_(torch.randn_like(m.weight) * std)

# Perturb the weights of every Linear layer in the already-trained `net` in
# place; as the cell's last expression, the call also echoes the module repr.
net.apply(break_symmetry)

SymmetricNet(
  (fc1): Linear(in_features=10, out_features=10, bias=True)
  (fc2): Linear(in_features=10, out_features=1, bias=True)
)

In [20]:
# Repeat the training experiment on the same `net` after noise was added to its
# Linear weights. With non-zero, non-identical weights the hidden activations
# and fc2.weight are no longer all zero, so the gradients of fc1 and fc2 are no
# longer identically zero and the printed weights differ from unit to unit.
# NOTE(review): this cell is a verbatim copy of the earlier training cell; it
# draws a fresh random dataset and builds a fresh optimizer, while `net` keeps
# the fc2.bias learned in the previous run. No seed is set in the visible
# cells, so the exact numbers vary between runs.

# Step 1: Create a dataset
inputs = torch.randn(100, 10)
targets = torch.randn(100, 1)

# Step 2: Define a loss function and an optimizer
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(net.parameters(), lr=0.01)

# Step 3: Train the network
# Each iteration is one full-batch gradient step over all 100 samples.
for epoch in range(100):  # loop over the dataset multiple times
    optimizer.zero_grad()  # zero the parameter gradients
    outputs = net(inputs)  # forward pass
    loss = criterion(outputs, targets)  # compute loss
    loss.backward()  # backward pass
    optimizer.step()  # update weights

# Step 4: Display the weights after training
for name, param in net.named_parameters():
    if param.requires_grad:
        print(name, param.data)

fc1.weight tensor([[-0.0049, -0.0083,  0.0026,  0.0118, -0.0168, -0.0037,  0.0087,  0.0278,
         -0.0092,  0.0028],
        [-0.0061,  0.0022,  0.0040,  0.0141, -0.0109,  0.0092, -0.0142,  0.0014,
          0.0175, -0.0186],
        [ 0.0154,  0.0003, -0.0015,  0.0139, -0.0163, -0.0200,  0.0095,  0.0049,
         -0.0091,  0.0059],
        [ 0.0142, -0.0185,  0.0069,  0.0168,  0.0153, -0.0005,  0.0013,  0.0054,
          0.0022,  0.0024],
        [-0.0019,  0.0129,  0.0107, -0.0024,  0.0120,  0.0164, -0.0028, -0.0134,
          0.0121,  0.0049],
        [-0.0004, -0.0053, -0.0040, -0.0240, -0.0023, -0.0016, -0.0005,  0.0137,
          0.0024, -0.0041],
        [-0.0052,  0.0102,  0.0134, -0.0116, -0.0111, -0.0033, -0.0110, -0.0127,
         -0.0112,  0.0083],
        [ 0.0096, -0.0079, -0.0024, -0.0140, -0.0100, -0.0041,  0.0018, -0.0055,
         -0.0070,  0.0104],
        [ 0.0147, -0.0085, -0.0007,  0.0235,  0.0084,  0.0034,  0.0020, -0.0086,
         -0.0156,  0.0105],
        