In [None]:
import numpy as np

class SGD:
    """Plain (vanilla) stochastic gradient descent.

    Each call to ``update_params`` replaces every entry of ``model.params``
    with ``param - lr * grad``, where ``grad`` is the entry stored under the
    same key in ``model.grads``.

    Args:
        model: object exposing ``params`` and ``grads`` mappings; every key
            of ``params`` must also be present in ``grads``.
        lr: step size.
    """

    def __init__(self, model, lr: float):
        self.model = model
        self.lr = lr

    def update_params(self):
        """Take one descent step over every entry of ``model.params``."""
        params = self.model.params
        grads = self.model.grads
        step = self.lr
        for name in params:
            # Augmented assignment: NumPy arrays are updated in place.
            params[name] -= step * grads[name]
            
class Adam:
    """Adam optimizer with bias-corrected first and second moment estimates.

    For every key in ``model.params``, each call to ``update_params`` does::

        g      = model.grads[key]
        m      = beta1 * m + (1 - beta1) * g
        v      = beta2 * v + (1 - beta2) * g**2
        m_hat  = m / (1 - beta1**k)
        v_hat  = v / (1 - beta2**k)
        param -= lr * m_hat / (sqrt(v_hat) + eps)

    where ``k`` is the number of steps that parameter has taken. For
    parameters present when the optimizer is constructed, ``k == self.t``.

    Args:
        model: object exposing ``params`` and ``grads`` mappings; every key
            of ``params`` must also be present in ``grads``.
        lr: step size.
        beta1: exponential decay rate for the first-moment estimate.
        beta2: exponential decay rate for the second-moment estimate.
        eps: constant added to the denominator to avoid division by zero.
    """

    def __init__(self, model, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
        self.model = model
        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.eps = eps
        self.m = {}  # first-moment estimate per parameter key
        self.v = {}  # second-moment estimate per parameter key
        self.t = 0  # number of update_params() calls so far
        self._t_start = {}  # value of self.t when each key's moments were created
        for key, param in self.model.params.items():
            self._init_state(key, param)

    def _init_state(self, key, param):
        """Create zero moment buffers for ``key`` shaped like ``param``."""
        self.m[key] = np.zeros_like(param)
        self.v[key] = np.zeros_like(param)
        self._t_start[key] = self.t

    def update_params(self):
        """Apply one Adam step to every entry of ``model.params``.

        A key that appears in ``model.params`` after construction gets zero
        moment buffers on first sight instead of raising ``KeyError``. Its
        bias correction counts only the steps that key has taken.
        """
        params = self.model.params
        grads = self.model.grads
        # Initialise late-registered keys before bumping t, so their first
        # update is bias-corrected as step 1 (not as global step t).
        for key in params:
            if key not in self.m:
                self._init_state(key, params[key])
        self.t += 1
        for key in params:
            grad = grads[key]
            step = self.t - self._t_start.get(key, 0)
            self.m[key] = self.beta1 * self.m[key] + (1 - self.beta1) * grad
            self.v[key] = self.beta2 * self.v[key] + (1 - self.beta2) * np.square(grad)
            m_hat = self.m[key] / (1 - self.beta1 ** step)
            v_hat = self.v[key] / (1 - self.beta2 ** step)
            # Augmented assignment: NumPy arrays are updated in place.
            params[key] -= self.lr * m_hat / (np.sqrt(v_hat) + self.eps)
