In [1]:
import random

import torch

In [2]:
class DynamicNet(torch.nn.Module):
    '''Fully connected ReLU network whose depth is re-drawn on every call.

    A single hidden-to-hidden layer is applied between zero and three times
    per forward pass, reusing the same weights on each application.
    '''

    def __init__(self, n_in, n_hidden, n_out):
        '''Create the input, (reusable) middle and output linear layers.

        Args:
            n_in: number of features in each input sample.
            n_hidden: width of the hidden representation.
            n_out: number of features in each output sample.
        '''
        super(DynamicNet, self).__init__()
        self.input_linear = torch.nn.Linear(n_in, n_hidden)
        self.middle_linear = torch.nn.Linear(n_hidden, n_hidden)
        self.output_linear = torch.nn.Linear(n_hidden, n_out)

    def forward(self, x):
        '''Run x through the network, using the middle layer 0-3 times.

        The number of middle-layer applications is drawn from Python's
        `random` module on each call, so two calls with the same input may
        take different paths through the network.

        Args:
            x: input tensor accepted by the first linear layer.

        Returns:
            The output of the final linear layer (no activation applied).
        '''
        # ReLU written as a clamp at zero.
        hidden = torch.clamp(self.input_linear(x), min=0)
        # randint is inclusive on both ends: 0, 1, 2 or 3 extra passes.
        remaining = random.randint(0, 3)
        while remaining > 0:
            hidden = torch.clamp(self.middle_linear(hidden), min=0)
            remaining -= 1
        return self.output_linear(hidden)

In [3]:
# Problem dimensions.
BATCH = 64       # number of samples (rows) in x and y
N_IN = 1000      # input features per sample
N_HIDDEN = 100   # width of the hidden layers
N_OUT = 10       # output features per sample

# Optimisation settings.
ETA = 1e-4       # learning rate handed to the optimizer
EPOCHS = 500     # number of training-loop iterations

# Fix every source of randomness so Restart & Run All reproduces the run:
# `random` picks the network depth on each forward pass, while torch draws
# the synthetic data and the initial layer weights.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)

In [4]:
# Synthetic data drawn from a standard normal: one row per sample.
# x is the model input, y the target the loss compares predictions against.
x = torch.randn((BATCH, N_IN))
y = torch.randn((BATCH, N_OUT))

In [5]:
# Build the network with the layer sizes from the configuration cell.
model = DynamicNet(n_in=N_IN, n_hidden=N_HIDDEN, n_out=N_OUT)

In [6]:
# SGD with momentum over every parameter of the model; ETA is the step size.
optimizer = torch.optim.SGD(params=model.parameters(), lr=ETA, momentum=0.9)
# Squared error summed (not averaged) over all elements.
criterion = torch.nn.MSELoss(reduction='sum')

In [7]:
# Full-batch training: one forward/backward pass and one optimizer step
# per epoch. The model re-draws its depth on every call, so the printed
# loss is expected to jump around rather than fall monotonically.
for i in range(EPOCHS):
    # Drop gradients from the previous step before accumulating new ones.
    optimizer.zero_grad()

    y_pred = model(x)
    loss = criterion(y_pred, y)

    # Report progress on every fifth epoch.
    if not i % 5:
        print('%3d: %10.5f' % (i, loss.item()))

    loss.backward()
    optimizer.step()

  0:  680.01129
  5:  629.60736
 10:  576.00116
 15:  664.12860
 20:  637.16675
 25:  543.65039
 30:  359.29688
 35:  130.62375
 40:   69.90417
 45:   95.13435
 50:  187.87845
 55:  139.01768
 60:   62.25080
 65:   42.61739
 70:   81.49670
 75:   98.60233
 80:  129.66028
 85:   39.84469
 90:   41.74377
 95:   38.56695
100:   33.14603
105:   12.15540
110:   30.01736
115:    9.31846
120:   38.20812
125:    7.92883
130:    6.08496
135:    9.31918
140:    8.62426
145:    4.73302
150:   33.53943
155:    4.34914
160:    2.14262
165:    8.99930
170:    2.36260
175:    3.93586
180:    4.20734
185:    8.45644
190:    1.81362
195:    1.89961
200:    4.87334
205:    2.05342
210:    1.06519
215:    1.29791
220:    1.52088
225:    2.13221
230:    0.57581
235:    2.67284
240:   32.80236
245:   43.07402
250:   14.67170
255:   13.03360
260:   12.75906
265:   21.47322
270:    5.62084
275:    2.44912
280:    2.65326
285:    5.01941
290:    7.14621
295:    3.84539
300:   12.01002
305:    4.95244
310:    