In [1]:
from fastai2.basics import *
from fastai2.optimizer import *
# Underscore-prefixed names are normally not brought in by `import *`; the `Optimizer` re-implementation cell uses them.
from fastai2.optimizer import _BaseOptimizer, _update

In [2]:
# Bottlenecked MLP 5 -> 100 -> 10 -> 100 -> 1000 with in-place ReLUs between consecutive Linear layers.
# The final 100x1000 Linear holds ~97% of the parameters (101000 of 103710), so its gradient and
# optimizer state dominate the memory readings taken in the cells below.
model = nn.Sequential(*[
    module
    for n_in, n_out in zip((5, 100, 10, 100), (100, 10, 100, 1000))
    for module in (nn.Linear(n_in, n_out), nn.ReLU(inplace=True))
][:-1]).cuda()  # drop the trailing ReLU: the last Linear's output goes straight into the loss
loss_func = nn.CrossEntropyLoss()

In [3]:
class OptimizedOptimizer(Optimizer):
    "`Optimizer` that visits parameters from smallest to largest (by element count)"
    # Intended so that a callback deleting each gradient right after its update (e.g. `zero_grad` in
    # `zero_Adam`) has already released every smaller gradient before the largest parameter is processed.
    def all_params(self, n=slice(None), with_grad=False, sort_key=lambda p: np.prod(p[0].shape)):
        "The `(param, param_group, state, hypers)` entries of `Optimizer.all_params`, sorted ascending by `sort_key`"
        res = super().all_params(n=n, with_grad=with_grad)
        # A list-style `sort` accepts `key` only as a keyword, so `res.sort(sort_key)` raises TypeError.
        # Sort into a new list and re-wrap it in `L` so callers keep `L` helpers such as `.itemgot`.
        return L(sorted(res, key=sort_key))
    def zero_grad(self, clear=False):
        "Zero every existing gradient in place, or drop it altogether (`grad = None`) when `clear` is True"
        for p,*_ in self.all_params(with_grad=True):
            if clear:
                # The gradient is being released, so detaching/zeroing it first would be wasted work.
                p.grad = None
                continue
            p.grad.detach_()
            p.grad.zero_()
def zero_grad(p, **kwargs):
    "Optimizer callback that drops `p.grad`; meant to run last so the gradient is freed right after `p` is updated"
    # Setting `grad` to None releases our reference to the tensor; zeroing it beforehand would be wasted work.
    p.grad = None
    # Nothing to add to the per-parameter state.
    return {}
def zero_Adam(params, lr, mom=0.9, sqr_mom=0.99, eps=1e-5, wd=0., decouple_wd=True):
    "An `OptimizedOptimizer` running Adam with `lr`, `mom`, `sqr_mom`, `eps`, `wd` that frees each gradient after its update"
    # `decouple_wd=True` shrinks the weights directly (`weight_decay`). Otherwise use `l2_reg`.
    # NOTE(review): `l2_reg` is assumed to come from the `fastai2.optimizer` star import and to apply `wd`
    # through the gradient — confirm. Either callback must run before the gradient averages.
    cbs = [weight_decay] if decouple_wd else [l2_reg]
    # `zero_grad` goes last so each gradient is dropped only after `adam_step` has used it.
    cbs += [partial(average_grad, dampening=True), average_sqr_grad, step_stat, adam_step, zero_grad]
    return OptimizedOptimizer(params, cbs, lr=lr, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd)
def clear_grads(params):
    "Delete the `.grad` tensor of every parameter in `params` that currently has one."
    with_grad = (q for q in params if q.grad is not None)
    for param in with_grad:
        del param.grad

In [4]:
# Reference optimizer: stock PyTorch Adam. NOTE(review): per the PyTorch docs its `weight_decay` is coupled
# (added to the gradient) and its eps/beta defaults differ from the fastai-style cells, so compare memory, not numerics.
adam = torch.optim.Adam(params=model.parameters(), weight_decay=1e-4)

In [None]:
# fastai-style Adam built from per-parameter callbacks. Which `Adam` this binds to depends on whether
# the `def Adam` cell further down has already been run in this kernel.
adam = Adam(params=model.parameters(), lr=1e-3, wd=1e-4)

In [None]:
# Memory-saving variant: smallest-first parameter order plus a callback that deletes each gradient during `step`.
# 2732544: presumably a `max_memory_allocated` reading recorded for this optimizer — confirm.
adam = zero_Adam(params=model.parameters(), lr=1e-3, wd=1e-4)

In [5]:
# A single random sample with 5 features, matching the first `nn.Linear(5,100)`; drawn on CPU, then moved to the GPU.
x = torch.randn(1, 5).cuda()

In [6]:
# Two forward/backward passes with one optimizer step in between, then (current, peak) CUDA bytes.
# Keep this exact statement order: `pred`/`loss` stay alive until overwritten, and that is part of what is
# measured, so wrapping this in a function would change the numbers.
# NOTE(review): per PyTorch docs `max_memory_allocated` is a running peak unless reset; nothing here resets it,
# so the peak also reflects earlier cells — confirm when comparing optimizers.
pred = model(x)
loss = loss_func(pred,torch.ones([1],dtype=torch.long).cuda())
loss.backward()
# For the fastai-style callbacks, the first step creates the `grad_avg`/`sqr_avg` buffers.
adam.step()
pred = model(x)
loss = loss_func(pred,torch.ones([1],dtype=torch.long).cuda())
loss.backward()
#clear_grads(model.parameters())
torch.cuda.memory_allocated(),torch.cuda.max_memory_allocated()
# Presumably the reading with `clear_grads` enabled: its 416256-byte gap to the recorded output matches one
# float32 copy of all parameters under 512-byte allocation rounding — confirm.
#(1253888, 2463232)

(1670144, 2463232)

In [None]:
# Two full train steps followed by an explicit `zero_grad`, then (current, peak) CUDA bytes.
# Statement order matters for the measurement; do not refactor.
pred = model(x)
loss = loss_func(pred,torch.ones([1],dtype=torch.long).cuda())
loss.backward()
adam.step()
pred = model(x)
loss = loss_func(pred,torch.ones([1],dtype=torch.long).cuda())
loss.backward()
adam.step()
# NOTE(review): only `OptimizedOptimizer.zero_grad` accepts `clear=`. The `Optimizer.zero_grad` re-implemented below
# takes no arguments, and torch optimizers presumably don't accept it either, so this line fails unless the
# `zero_Adam` cell was run. With `zero_Adam`, the `zero_grad` callback already deleted every gradient during
# `step`, so this call visits no parameters.
adam.zero_grad(clear=False)
torch.cuda.memory_allocated(),torch.cuda.max_memory_allocated()
# Recorded readings (which optimizer produced which is not noted):
#(1670144, 2471424)
#(1253888, 2455040)

In [None]:
# Two full train steps without clearing gradients afterwards, then (current, peak) CUDA bytes.
# Statement order matters for the measurement; do not refactor.
pred = model(x)
loss = loss_func(pred,torch.ones([1],dtype=torch.long).cuda())
loss.backward()
adam.step()
pred = model(x)
loss = loss_func(pred,torch.ones([1],dtype=torch.long).cuda())
loss.backward()
adam.step()
torch.cuda.memory_allocated(),torch.cuda.max_memory_allocated()
# Readings from several runs; presumably different optimizer cells were active — confirm which is which.
#(1670144, 2471424)
#(1253888, 2471424)
#(1253888, 2455040)
#(1670144, 2471424)

In [None]:
# Peak CUDA bytes allocated so far in this process.
torch.cuda.max_memory_allocated()
# NOTE(review): these are ~10x the peaks recorded in the other cells, so they presumably come from a
# larger configuration than the model defined here — confirm before comparing.
#pytorch 25123840
#fastai 25119744

In [None]:
# One more optimizer step with no new backward pass.
# It only visits parameters that still have a gradient. After `zero_Adam` that is none (its `zero_grad`
# callback deleted them); otherwise the existing gradients are applied again.
adam.step()
torch.cuda.memory_allocated(),torch.cuda.max_memory_allocated()
#(1465344, 2662400)
#(1706496, 2499584)

In [None]:
# For the fastai-style optimizers, `zero_grad()` zeroes gradients in place (`OptimizedOptimizer` defaults to
# `clear=False`), so `memory_allocated` presumably does not drop. `torch.optim` behaviour is version-dependent — confirm.
adam.zero_grad()
torch.cuda.memory_allocated(),torch.cuda.max_memory_allocated()

In [None]:
# Ratio (~0.326) of 8196608 bytes to the 25123840-byte "pytorch" peak recorded in the peak-memory cell.
# TODO: record where 8196608 was measured.
8196608/25123840

In [None]:
def weight_decay(p, lr, wd, do_wd=True, **kwargs):
    "Weight decay applied to the weights themselves: scale `p` in place by `1 - lr*wd` (no-op if `do_wd` is off or `wd` is 0)"
    if not do_wd or wd == 0:
        return None
    p.data.mul_(1 - lr*wd)

In [None]:
def average_grad(p, mom, dampening=False, grad_avg=None, **kwargs):
    "Keeps track of the avg grads of `p` in `state` with `mom`."
    # First call: start from zeros with the gradient's shape/dtype/device.
    if grad_avg is None: grad_avg = torch.zeros_like(p.grad.data)
    # With dampening this is an exponential moving average (weights `mom` / `1-mom`);
    # without it the new gradient is added at full weight.
    damp = 1-mom if dampening else 1.
    # Keyword `alpha=` replaces the deprecated positional `add_(Number, Tensor)` overload.
    grad_avg.mul_(mom).add_(p.grad.data, alpha=damp)
    return {'grad_avg': grad_avg}

In [None]:
def average_sqr_grad(p, sqr_mom, dampening=True, sqr_avg=None, **kwargs):
    "Keeps track of the avg squared grads of `p` in `state` with `sqr_mom`."
    # First call: start from zeros with the gradient's shape/dtype/device.
    if sqr_avg is None: sqr_avg = torch.zeros_like(p.grad.data)
    damp = 1-sqr_mom if dampening else 1.
    # Keyword `value=` replaces the deprecated positional `addcmul_(Number, Tensor, Tensor)` overload.
    sqr_avg.mul_(sqr_mom).addcmul_(p.grad.data, p.grad.data, value=damp)
    return {'sqr_avg': sqr_avg}

In [None]:
def step_stat(p, step=0, **kwargs):
    "Count the optimizer steps taken on `p`: returns the incremented counter to store under `step`"
    return {'step': step + 1}

In [None]:
def adam_step(p, lr, mom, step, sqr_mom, grad_avg, sqr_avg, eps, **kwargs):
    "Step for Adam with `lr` on `p`"
    # Bias corrections for both running averages, each with its own dampening (1-mom, 1-sqr_mom), matching the
    # dampened averages used by `Adam`. `debias` is presumably provided by the `fastai2.optimizer` star import.
    debias1 = debias(mom,     1-mom,     step)
    debias2 = debias(sqr_mom, 1-sqr_mom, step)
    # p <- p - (lr/debias1) * grad_avg / (sqrt(sqr_avg/debias2) + eps).
    # Keyword `value=` replaces the deprecated positional `addcdiv_(Number, Tensor, Tensor)` overload.
    p.data.addcdiv_(grad_avg, (sqr_avg/debias2).sqrt() + eps, value=-lr / debias1)
    return p

In [None]:
def Adam(params, lr, mom=0.9, sqr_mom=0.99, eps=1e-5, wd=0., decouple_wd=True):
    "An `Optimizer` for Adam with `lr`, `mom`, `sqr_mom`, `eps` and `params`"
    # `decouple_wd=True` shrinks the weights directly (`weight_decay`). Otherwise use `l2_reg`.
    # NOTE(review): `l2_reg` is assumed to come from the `fastai2.optimizer` star import and to apply `wd`
    # through the gradient — confirm. Either callback must run before the gradient averages.
    cbs = [weight_decay] if decouple_wd else [l2_reg]
    cbs += [partial(average_grad, dampening=True), average_sqr_grad, step_stat, adam_step]
    return Optimizer(params, cbs, lr=lr, mom=mom, sqr_mom=sqr_mom, eps=eps, wd=wd)

In [None]:
class Optimizer(_BaseOptimizer):
    "Base optimizer class for the fastai library, updating `params` with `cbs`"
    # State entries kept by `clear_state`; everything else is discarded.
    _keep_on_clear = ['force_train', 'do_wd']
    def __init__(self, params, cbs, train_bn=True, **defaults):
        params = L(params)
        # `defaultdict(dict)`: a parameter that has never been stepped reads as empty state.
        self.cbs,self.state,self.train_bn = L(cbs),defaultdict(dict),train_bn
        # Defaults attached to the callbacks (`cb.defaults`), then the keyword arguments (merged last, so presumably they win).
        defaults = merge(*self.cbs.attrgot('defaults'), defaults)
        # A list of lists/`L`s defines several parameter groups; a flat list is a single group.
        self.param_groups = L(L(p) for p in params) if isinstance(params[0], (L,list)) else L([params])
        # One hyper-parameter dict per group, presumably filled by `set_hypers`.
        self.hypers = L({} for _ in range_of(self.param_groups))
        self.set_hypers(**defaults)
        self.frozen_idx = 0

    def zero_grad(self):
        "Zero every existing gradient in place (the tensors are kept, not freed)"
        for p,*_ in self.all_params(with_grad=True):
            p.grad.detach_()
            p.grad.zero_()

    def step(self):
        "Run every callback, in order, on each parameter that has a gradient, and store the resulting state"
        for p,pg,state,hyper in self.all_params(with_grad=True):
            # Callbacks receive state and hypers as keyword arguments (hypers win on a name clash);
            # `_update` presumably folds a returned dict back into `state` before the next callback.
            for cb in self.cbs: state = _update(state, cb(p, **{**state, **hyper}))
            self.state[p] = state

    def clear_state(self):
        "Reset per-parameter state, keeping only the `_keep_on_clear` entries"
        for p,pg,state,hyper in self.all_params():
            self.state[p] = {k: state[k] for k in self._keep_on_clear if k in state}

    def state_dict(self):
        "Per-parameter state in `all_params` order, plus the hypers"
        state = [self.state[p] for p,*_ in self.all_params()]
        return {'state': state, 'hypers': self.hypers}

    def load_state_dict(self, sd):
        "Restore a `state_dict`; per-parameter state is matched to parameters by position in `all_params`"
        assert len(sd["hypers"]) == len(self.param_groups)
        assert len(sd["state"])  == sum([len(pg) for pg in self.param_groups])
        self.hypers = sd['hypers']
        # Keep the `defaultdict(dict)` type from `__init__` so lookups of parameters without saved state still yield `{}`.
        self.state = defaultdict(dict, zip(self.all_params().itemgot(0), sd['state']))

In [None]:
SortedOptimizer.all_params??

In [None]:
# Iterator over the model's parameters from fewest to most elements (the same ordering
# `OptimizedOptimizer.all_params` uses), for stepping through them with `next`.
params = iter(sorted(model.parameters(), key=lambda t: t.numel()))

In [None]:
# Show the next-smallest remaining parameter; raises StopIteration once `params` is exhausted.
next(params)

In [None]:
fastai_Adam.all_params??

In [None]:
L??