diff --git a/src/main.py b/src/main.py index ea71b1b..383aafc 100644 --- a/src/main.py +++ b/src/main.py @@ -42,17 +42,16 @@ print('===> Building model...') model = Net().to(device) -model_params = model.parameters() l1_loss = nn.L1Loss() -optimizer = optim.Adamax(model_params, lr=0.001) +optimizer = optim.Adamax(model.parameters(), lr=0.001) # ---------------------------------------------------------------------- def detach_all(arg): """Wraps hidden states in new Variables, to detach them from their history.""" - if type(arg) == torch.Tensor: - arg.detach_() # Variable(arg.data) + if type(arg) == nn.Parameter: + arg.data.detach_() # Variable(arg.data) else: for v in arg: detach_all(v)