In [2]:
import torch
from torch import nn, optim

In [3]:
class SGD(nn.Module):
  """Minimal SGD optimizer with multiplicative weight decay.

  Each call to `step` first shrinks every parameter that has a gradient by
  the factor `1 - lr * wd` (`reg_step`), then applies the gradient update
  (`opt_step`). Subclasses change the update rule by overriding `opt_step`.

  Args:
    params: iterable of tensors to optimise. It is materialised into a list,
      so a one-shot generator such as `model.parameters()` can be passed.
    lr: learning rate.
    wd: weight-decay coefficient; 0 disables the decay step.
  """
  def __init__(self, params, lr, wd=0.1):
    super().__init__()
    self.params = list(params)
    self.lr, self.wd = lr, wd

  def step(self):
    """Apply weight decay and one update to every parameter with a gradient."""
    with torch.no_grad():
      for p in self.params:
        # Parameters that took no part in the backward pass have grad None;
        # skip them instead of crashing (torch.optim optimizers do the same).
        if p.grad is None: continue
        self.reg_step(p)
        self.opt_step(p)

  def opt_step(self, p):
    """Plain gradient-descent update, in place: p <- p - lr * p.grad."""
    p -= p.grad * self.lr

  def reg_step(self, p):
    """Scale p in place by (1 - lr * wd) when wd is non-zero."""
    if self.wd != 0: p *= 1 - self.lr * self.wd

  # Backward-compatible alias for the method's original (misspelled) name.
  ref_step = reg_step

  def zero_grad(self):
    """Zero every existing gradient; parameters with no gradient are skipped."""
    with torch.no_grad():
      for p in self.params:
        if p.grad is not None: p.grad.zero_()

In [4]:
class Momentum(SGD):
  """SGD variant that steps along an exponential moving average of gradients.

  The running average lives on each parameter as `p.grad_avg` and is
  refreshed as `grad_avg * mom + grad * (1 - mom)` before every update.

  Args:
    params, lr, wd: forwarded unchanged to `SGD`.
    mom: weight given to the previous average (0.9 by default).
  """
  def __init__(self, params, lr, wd=0.1, mom=0.9):
    super().__init__(params, lr=lr, wd=wd)
    self.mom = mom

  def opt_step(self, p):
    """Refresh p.grad_avg from p.grad, then move p against it in place."""
    grad, mom = p.grad, self.mom
    # Lazily allocate a zero average the first time this parameter is seen.
    if not hasattr(p, 'grad_avg'):
      p.grad_avg = torch.zeros_like(grad)
    smoothed = p.grad_avg * mom + grad * (1 - mom)
    p.grad_avg = smoothed
    p -= self.lr * smoothed