# SP2 for GLM

In [None]:
def run_sp2_glm(train_data, train_target, train_dataloader, epochs, seed=0):
    """Run SP2 specialised to a logistic-regression GLM; return per-epoch history.

    For a single sample ``x`` the loss is treated as a function of the scalar
    ``z = x @ w``.  With ``a`` and ``h`` the first and second derivatives of the
    logistic loss w.r.t. ``z``, the update moves along ``x / ||x||^2`` by the
    smallest-magnitude ``t`` solving ``f - a*t + 0.5*h*t^2 = 0``; when that
    quadratic has no real root the minimiser ``t = a / h`` is used instead.

    Args:
        train_data: full design matrix (tensor), used for per-epoch logging.
        train_target: full label vector; the derivative helpers assume labels
            in {-1, +1}.
        train_dataloader: iterable of ``(batch_data, batch_target)``; each batch
            must hold exactly one sample (the closed-form step is per-sample).
        epochs: number of passes over ``train_dataloader``.
        seed: passed to ``torch.manual_seed`` before training.

    Returns:
        list of ``[loss, grad_norm_sq, accuracy]``, one entry per epoch,
        recorded before that epoch's updates.

    Raises:
        ValueError: if a batch contains more than one sample.

    Relies on module-level ``device``, ``lf``, ``torch`` and ``np``.
    NOTE(review): assumes ``lf.logreg`` is the plain logistic loss matching the
    derivative helpers below (no regulariser) -- confirm.
    """
    torch.manual_seed(seed)

    # parameters
    w = torch.zeros(train_data.shape[1], device=device).requires_grad_()

    # save loss, squared gradient norm and accuracy per epoch
    hist = []

    loss_function = lf.logreg

    def logreg_a(w, X, y):
        # d/dz log(1 + exp(-y z)) = -y * sigmoid(-y z); sigmoid avoids exp overflow.
        return -y * torch.sigmoid(-y * (X @ w))

    def logreg_h(w, X, y):
        # d^2/dz^2 log(1 + exp(-y z)) = sigmoid(-y z) * sigmoid(y z) for y in {-1, +1}.
        margin = y * (X @ w)
        return torch.sigmoid(-margin) * torch.sigmoid(margin)

    # The full data set is only used for logging; move it once.
    full_data = train_data.to(device)
    full_target = train_target.to(device)

    for epoch in range(epochs):

        loss = loss_function(w, full_data, full_target)
        g, = torch.autograd.grad(loss, w)
        grad_norm_sq = torch.linalg.norm(g) ** 2
        acc = (np.sign(train_data @ w.detach().cpu().numpy()) == train_target).sum() / train_target.shape[0]

        print(f"[{epoch}/{epochs}] | Loss: {loss.item()} | GradNorm^2: {grad_norm_sq.item()} | Accuracy: {acc}")
        hist.append([loss.item(), grad_norm_sq.item(), acc])

        for batch_data, batch_target in train_dataloader:
            batch_data = batch_data.to(device)
            batch_target = batch_target.to(device)

            # a and h are closed-form, so no autograd graph is needed here.
            with torch.no_grad():
                loss = loss_function(w, batch_data, batch_target)
                a = logreg_a(w, batch_data, batch_target)
                h = logreg_h(w, batch_data, batch_target)
                if a.numel() != 1:
                    raise ValueError("run_sp2_glm requires one sample per batch")

                x = batch_data.flatten()
                x_norm_sq = torch.dot(x, x)
                if a == 0 or x_norm_sq == 0:
                    # The loss gradient a * x is zero: nothing to move.
                    continue

                det = a * a - 2 * h * loss
                if det >= 0:
                    # Smallest root (a - sign(a) sqrt(det)) / h, rationalised to
                    # 2 f / (a + sign(a) sqrt(det)) so it stays finite when h -> 0
                    # (where it reduces to the Polyak step f / a).
                    t = 2 * loss / (a + torch.sign(a) * torch.sqrt(det))
                else:
                    # No real root (implies h > 0): minimise the quadratic model.
                    t = a / h

                w.sub_(t * x / x_norm_sq)

    return hist


# SP2+

In [None]:
def run_sp2(train_data, train_target, train_dataloader, epochs, lr=1.0):
    """Run an iterative SP2-style method; return per-epoch [loss, ||grad||^2, acc] history.

    Each epoch first evaluates the full-data loss ``loss`` and gradient ``g`` at
    the current iterate ``w_tp1`` (these are logged).  Then, for every batch, it
    builds
        q      = loss + <g, wdiff> + 0.5 * <wdiff, H_batch(w_tp1) wdiff>
        grad_q = g + H_batch(w_tp1) wdiff
    with ``wdiff = w_t - w_tp1`` (previous iterate minus current iterate), and
    takes the Polyak step ``w_tp1 <- w_tp1 - lr * q / ||grad_q||^2 * grad_q``.

    NOTE(review): ``loss`` and ``g`` are NOT recomputed per batch -- every batch
    step reuses the full-data values from the start of the epoch.  Confirm this
    is intended.
    NOTE(review): ``wdiff`` is (previous - current); the usual SP2 model uses
    (w - w_t).  Confirm the sign convention against the derivation.

    Relies on module-level ``device``, ``loss_function``, ``torch`` and ``np``.
    ``w_tp1.detach().numpy()`` requires the parameters to live on the CPU.
    """

    w_tp1 = torch.zeros(train_data.shape[1], device=device).requires_grad_()
    # Copy of the initial iterate (non-leaf: built from w_tp1 outside no_grad).
    w_t = w_tp1 * 1.0

    # logging: one [loss, ||g||^2, accuracy] row per epoch
    hist = []

    for epoch in range(epochs):

        # Full-data loss/gradient at the epoch start; reused by every batch step below.
        loss = loss_function(w_tp1, train_data.to(device), train_target.to(device))
        g, = torch.autograd.grad(loss, w_tp1, create_graph=True)
        acc = (np.sign(train_data @ w_tp1.detach().numpy()) == train_target).sum() / train_target.shape[0]
        print(f"[{epoch}/{epochs}] | Loss: {loss.item()} | GradNorm^2: {(torch.linalg.norm(g) ** 2 ).item()} | Acc: {acc}")
        hist.append([loss.item(), (torch.linalg.norm(g) ** 2).item(), acc])

        for i, (batch_data, batch_target) in enumerate(train_dataloader):
            batch_data = batch_data.to(device)
            batch_target = batch_target.to(device)

            closure  = lambda w: loss_function(w, batch_data, batch_target)
            # Difference between the previous and the current iterate.
            wdiff = torch.sub(w_t, w_tp1)
            # Batch Hessian at w_tp1 applied to wdiff (element [1] of the (value, hvp) pair).
            hessvgrad = torch.autograd.functional.hvp(closure, w_tp1, wdiff, create_graph=True)[1]
            with torch.no_grad():
                # Quadratic model value and its gradient.
                q = loss + torch.dot(g, wdiff) + 0.5 * torch.dot(wdiff, hessvgrad)
                nablaq = torch.add(g, hessvgrad)
                nablaqnorm = torch.norm(nablaq)
                # Degenerate model gradient: stop this epoch's batch loop (not training).
                if nablaqnorm < 1e-22:
                    break
                # Remember the current iterate, then take the Polyak step on q.
                w_t = w_tp1 * 1.0
                w_tp1.sub_(nablaq, alpha = lr*q/nablaqnorm**2)

    return hist

def run_sp2plus(train_data, train_target, train_dataloader, epochs, lr=1.0):
    """Run SP2+ on mini-batches; return per-epoch ``[loss, ||grad||^2, accuracy]`` history.

    For each batch with loss ``f``, gradient ``g`` and Hessian-gradient product
    ``Hg`` (all at the current ``w``), SP2+ takes two Polyak steps:

    1. ``w <- w - lr * gamma * g`` with ``gamma = f / ||g||^2``;
    2. ``w <- w - 0.5 * gamma^2 * <g, Hg> / ||v||^2 * v`` with ``v = g - gamma * Hg``.

    For ``lr == 1``, step 2 is a Polyak step on the second-order Taylor model.
    After step 1 that model has value ``0.5 * gamma^2 * <g, Hg>`` and gradient ``v``.

    ``lr`` scales only the first step.  Batches with a zero gradient are
    skipped, and step 2 is skipped when ``||v||^2 <= 1e-10``.

    Relies on module-level ``device``, ``loss_function``, ``torch`` and ``np``.
    """
    w = torch.zeros(train_data.shape[1], device=device).requires_grad_()
    # save loss, squared gradient norm and accuracy per epoch
    hist = []

    for epoch in range(epochs):

        loss = loss_function(w, train_data.to(device), train_target.to(device))
        g, = torch.autograd.grad(loss, w)
        grad_norm_sq = torch.linalg.norm(g) ** 2
        acc = (np.sign(train_data @ w.detach().cpu().numpy()) == train_target).sum() / train_target.shape[0]

        print(f"[{epoch}/{epochs}] | Loss: {loss.item()} | GradNorm^2: {grad_norm_sq.item()} | Accuracy: {acc}")
        hist.append([loss.item(), grad_norm_sq.item(), acc])

        for batch_data, batch_target in train_dataloader:

            batch_data = batch_data.to(device)
            batch_target = batch_target.to(device)

            loss = loss_function(w, batch_data, batch_target)
            g, = torch.autograd.grad(loss, w, create_graph=True)
            f_grad = g.detach()
            # Hessian-gradient product via double backward (frees the graph).
            hgp, = torch.autograd.grad(g, w, grad_outputs=f_grad)

            with torch.no_grad():
                gnormsq = torch.dot(f_grad, f_grad)
                if gnormsq == 0:
                    # Stationary point of the batch loss: the Polyak step is undefined.
                    continue
                sps_step = loss.item() / gnormsq
                w.sub_(sps_step * f_grad, alpha=lr)

                v = f_grad - sps_step * hgp
                v_norm_sq = torch.dot(v, v)
                if v_norm_sq > 1e-10:
                    # Numerator is <g, Hg> (twice the model value / gamma^2), not <g, v>.
                    w.sub_(0.5 * sps_step**2 * torch.dot(f_grad, hgp) / v_norm_sq * v)

    return hist


In [None]:
def hvp_from_grad(grads_tuple, list_params, vec_tuple):
    """Return d/dparams of sum_i <grads_tuple[i], vec_tuple[i]> -- a Hessian-vector product.

    ``grads_tuple`` must carry an autograd graph (computed with
    ``create_graph=True``).  It is left intact because the graph is retained, so
    the helper can be called repeatedly.  Only the first entry of the gradient
    tuple (the one for the first parameter) is returned.
    """
    inner = sum((grad * vec).sum() for grad, vec in zip(grads_tuple, vec_tuple))
    grads = torch.autograd.grad(inner, list_params, retain_graph=True)
    return grads[0]


def custom_sp2plus(train_data, train_target, train_dataloader, epochs, lr=1.0):
    """Run SP2+ with eps-regularised denominators; return per-epoch history.

    For each batch with loss ``f``, gradient ``g`` and Hessian-gradient product
    ``Hg``, the update is::

        gamma = f / (||g||^2 + eps)
        v     = g - gamma * Hg
        w    <- w - lr * (gamma * g + 0.5 * gamma^2 * <Hg, g> / (||v||^2 + eps) * v)

    Returns:
        list of ``[loss, grad_norm_sq, accuracy]``, one entry per epoch,
        recorded before that epoch's updates.

    Relies on module-level ``device``, ``loss_function``, ``torch`` and ``np``.
    """
    w = torch.zeros(train_data.shape[1], device=device).requires_grad_()
    # save loss, squared gradient norm and accuracy per epoch
    hist = []

    eps = 1e-8

    for epoch in range(epochs):

        loss = loss_function(w, train_data.to(device), train_target.to(device))
        g, = torch.autograd.grad(loss, w)
        grad_norm_sq = torch.linalg.norm(g) ** 2
        acc = (np.sign(train_data @ w.detach().cpu().numpy()) == train_target).sum() / train_target.shape[0]

        print(f"[{epoch}/{epochs}] | Loss: {loss.item()} | GradNorm^2: {grad_norm_sq.item()} | Accuracy: {acc}")
        hist.append([loss.item(), grad_norm_sq.item(), acc])

        for batch_data, batch_target in train_dataloader:

            batch_data = batch_data.to(device)
            batch_target = batch_target.to(device)

            loss = loss_function(w, batch_data, batch_target)
            g, = torch.autograd.grad(loss, w, create_graph=True)
            f_grad = g.detach()
            # Hessian-gradient product via double backward; the graph is freed afterwards.
            hgp, = torch.autograd.grad(g, w, grad_outputs=f_grad)

            # Build the step from detached values so no autograd graph is recorded.
            with torch.no_grad():
                polyak = loss.detach() / (torch.dot(f_grad, f_grad) + eps)
                v = f_grad - (hgp * polyak)
                v_norm_sq = torch.dot(v, v)
                step = (polyak * f_grad) + (0.5 * polyak**2 * (torch.dot(hgp, f_grad) / (v_norm_sq + eps)) * v)
                w.sub_(step, alpha=lr)

    return hist

In [None]:
class SP2Plus(torch.optim.Optimizer):
    """SP2+ optimizer: a Polyak step followed by a Polyak step on the quadratic model.

    For each parameter ``p`` with gradient ``g``, let ``Hg`` be the product of
    ``p``'s own Hessian block with ``g``.  The update is::

        gamma = f / (||g||^2 + eps)
        v     = g - gamma * Hg
        p    <- p - lr * (gamma * g + 0.5 * gamma^2 * <Hg, g> / (||v||^2 + eps) * v)

    Here ``f`` is the loss returned by the closure.  Cross-parameter Hessian
    terms are ignored because ``Hg`` is computed per parameter.

    ``step`` requires a closure that zeroes gradients, computes the loss and calls
    ``loss.backward(create_graph=True)``, so that ``p.grad`` can be
    differentiated again.

    Args:
        params: iterable of parameters or parameter groups.
        lr: multiplier on the whole SP2+ step.
        eps: added to the denominators ``||g||^2`` and ``||v||^2``.
    """

    def __init__(
            self,
            params,
            lr=1.0,
            eps=1e-8):

        defaults = dict(lr=lr, eps=eps)

        super().__init__(params, defaults)

    def step(self, closure=None):
        """Perform one SP2+ update and return the closure's loss.

        Raises:
            TypeError: if ``closure`` is None (the loss value is required).
        """
        if closure is None:
            raise TypeError("SP2Plus.step requires a closure that returns the loss")
        with torch.enable_grad():
            loss = closure()
        loss_value = loss.detach() if torch.is_tensor(loss) else loss

        # Compute every step before modifying any parameter.  An in-place update
        # of one parameter may invalidate tensors saved in the shared graph that
        # the remaining parameters' Hessian-vector products still need.
        updates = []
        for group in self.param_groups:
            eps = group["eps"]
            lr = group["lr"]
            for p in group["params"]:
                if p.grad is None:
                    continue
                # Differentiate p.grad w.r.t. p itself (not a flattened view, which
                # would not be part of the graph).
                hgp = torch.autograd.grad(p.grad, p, grad_outputs=p.grad, retain_graph=True)[0]
                with torch.no_grad():
                    grad_flat = p.grad.flatten()
                    hgp_flat = hgp.flatten()
                    grad_norm_sq = torch.dot(grad_flat, grad_flat)
                    polyak = loss_value / (grad_norm_sq + eps)
                    v = grad_flat - (hgp_flat * polyak)
                    v_norm_sq = torch.dot(v, v)
                    step = (polyak * grad_flat) + (0.5 * polyak**2 * (torch.dot(hgp_flat, grad_flat) / (v_norm_sq + eps)) * v)
                # Reshape the flat step back to p's shape before applying it.
                updates.append((p, step.view_as(p), lr))

        with torch.no_grad():
            for p, step, lr in updates:
                p.sub_(step, alpha=lr)

        return loss

In [None]:
class Custom(torch.optim.Optimizer):
    """Preconditioned SP2-style optimizer.

    Each step computes a preconditioned gradient ``d``.  By default ``d`` comes
    from a conjugate-gradient solve of ``H d = g``; identity and AdaGrad
    variants are also provided.  The parameter then moves by ``p <- p - t * d``.

    The step size ``t`` depends on ``a = g^T d``:
    * ``t = 1 - sqrt(1 - 2 f / (a + eps))`` when ``2 f <= a``;
    * ``t = 1`` otherwise.

    This is the smaller root of ``f - t a + t^2 a / 2 = 0``, which is the
    quadratic model along ``-d`` when ``d = H^{-1} g``.

    ``step`` requires a closure that zeroes gradients, computes the loss and calls
    ``loss.backward(create_graph=True)``.  The CG preconditioner forms
    Hessian-vector products from ``p.grad``.
    """

    def __init__(
            self,
            params,
            eps=1e-8):

        defaults = dict(eps=eps)

        super().__init__(params, defaults)

        # Preconditioner used by step(); can be swapped for
        # _update_precond_grad_identity or _update_precond_grad_adagrad.
        self._update_precond_grad = self._update_precond_grad_cg

        for group in self.param_groups:
            for p in group["params"]:
                # Running sum of squared gradients (AdaGrad preconditioner).
                self.state[p]["v"] = torch.flatten(torch.zeros_like(p))

        self._step_t = 0

    def step(self, closure=None):
        """Run the closure, refresh the preconditioned gradients, update params; return the loss.

        Raises:
            TypeError: if ``closure`` is None (the loss value is required).
        """
        if closure is None:
            raise TypeError("Custom.step requires a closure that returns the loss")
        with torch.enable_grad():
            loss = closure()

        self._step_t += 1
        self._update_precond_grad()
        self.update(loss=loss)

        return loss

    def update(self, loss):
        """Apply ``p <- p - step_size * precond_grad`` to every parameter with a gradient."""
        for group in self.param_groups:
            eps = group['eps']
            for p in group['params']:
                if p.grad is None:
                    continue
                precond_grad = self.state[p]["precond_grad"]
                flat_grad = torch.flatten(p.grad.detach().clone())
                grad_norm_sq = torch.dot(flat_grad, precond_grad)
                if 2 * loss <= grad_norm_sq:
                    det = 1 - 2 * (loss / (grad_norm_sq + eps))
                    if det < 0.0:
                        group["step_size"] = 1.0
                    else:
                        group["step_size"] = 1 - torch.sqrt(det).item()
                else:
                    # No real root: take the full preconditioned step.
                    group["step_size"] = 1.0

                with torch.no_grad():
                    p.sub_(precond_grad.view_as(p), alpha=group['step_size'])

    def _update_precond_grad_identity(self):
        """Use the raw (detached, flattened) gradient as the preconditioned gradient."""
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                self.state[p]["precond_grad"] = torch.flatten(p.grad.detach())

    def _update_precond_grad_cg(self, tol=1e-4):
        """Set ``precond_grad`` to an approximate solution ``s`` of ``H s = g`` via CG.

        ``H z`` is obtained by differentiating ``p.grad`` (which must carry a
        graph) against ``z``.  Stops when the residual norm drops below ``tol``
        or after ``2 * p.numel()`` iterations.
        """
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad_flat = torch.flatten(p.grad.detach().clone())

                s = torch.zeros_like(grad_flat)  # s ~= H^{-1} grad
                r = grad_flat.clone()
                b = r.clone()
                rr = torch.dot(r, r)
                # CG needs at most numel iterations in exact arithmetic; 2x leaves
                # room for rounding.  numel (not shape[0]) covers multi-dim params.
                max_iter = 2 * p.numel()

                for _ in range(max_iter):
                    if torch.sqrt(rr) < tol:
                        # Converged; also catches a zero gradient, which would
                        # otherwise give 0/0 in alpha_k.
                        break
                    hvp = torch.flatten(torch.autograd.grad(p.grad, p, grad_outputs=b.view_as(p), retain_graph=True)[0])
                    alpha_k = rr / torch.dot(b, hvp)
                    s = s + alpha_k * b
                    r = r - alpha_k * hvp
                    rr_new = torch.dot(r, r)
                    b = r + (rr_new / rr) * b
                    rr = rr_new

                self.state[p]["precond_grad"] = s

    def _update_precond_grad_adagrad(self):
        """Precondition by ``1 / (sqrt(sum of squared gradients) + 1e-10)``."""
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                state = self.state[p]
                flat_grad = torch.flatten(p.grad.detach().clone())
                state["v"] = state["v"] + torch.square(flat_grad)
                precond = 1 / (torch.sqrt(state["v"]) + 1e-10)
                state["precond_grad"] = torch.mul(precond, flat_grad)


In [None]:
def rademacher_old(weights):
    """Return a random tensor shaped like ``weights`` whose entries are -1.0 or +1.0."""
    bits = torch.rand_like(weights).round()  # each entry is 0. or 1.
    return bits * 2 - 1

def diag_estimate_old(weights, grad, iters):
    """Estimate the Hessian diagonal with Hutchinson probes.

    Averages ``z * (H z)`` over ``iters`` Rademacher vectors ``z``.  ``H z`` is
    obtained by differentiating ``grad`` against ``z`` w.r.t. ``weights``, so
    ``grad`` must carry an autograd graph (e.g. computed with
    ``create_graph=True``).  The graph is retained for reuse.
    """
    def _probe_product():
        # One Hutchinson sample: z elementwise-times H z.
        probe = rademacher_old(weights)
        with torch.no_grad():
            h_probe = torch.autograd.grad(grad, weights, grad_outputs=probe, retain_graph=True)[0]
        return h_probe * probe

    samples = [_probe_product() for _ in range(iters)]
    return torch.stack(samples).mean(dim=0)

def run_psps(data, target, dataloader, EPOCHS, precond_method="none", seed=0, scaling_vec=None):
    """Run preconditioned stochastic Polyak steps; return per-epoch history.

    Each mini-batch update is ``w <- w - f / (g^T D g) * D g``.  The
    preconditioner ``D`` is chosen by ``precond_method``:

    * ``"none"``        -- identity.
    * ``"scaling_vec"`` -- fixed ``diag(1 / scaling_vec**2)``.
    * ``"hess_diag"``   -- ``diag(1 / diag(loss_hessian(w, batch)))``.
    * ``"hutch"``       -- ``diag(1 / Dk_hat)``, where ``Dk_hat`` is the
      Hutchinson diagonal estimate, smoothed by an EMA (beta=0.999) and with
      ``|.|`` floored at 0.1.
    * ``"adam"``        -- ``diag(1 / (sqrt(v_hat) + 1e-8))`` with Adam's
      bias-corrected second moment.
    * ``"adagrad"``     -- ``diag(1 / (sqrt(sum g^2) + 1e-10))``.
    * ``"cg"``          -- ``D g`` is replaced by ``s ~= H^{-1} g`` from
      conjugate gradient, using the step ``f / (g^T s + 1e-12) * s``.

    Returns:
        list of ``[loss, grad_norm_sq, accuracy]``, one entry per epoch,
        recorded before that epoch's updates.

    Raises:
        ValueError: for an unknown ``precond_method``.

    Relies on module-level ``device``, ``loss_function``, ``diag_estimate_old``,
    ``torch``, ``np`` and (for ``"hess_diag"``) ``loss_hessian``.
    """
    valid_methods = {"none", "hutch", "cg", "scaling_vec", "adam", "adagrad", "hess_diag"}
    if precond_method not in valid_methods:
        raise ValueError(f"Unknown precond_method: {precond_method!r}")

    torch.manual_seed(seed)

    alpha = 0.1       # hutch: floor on |Dk|
    beta = 0.999      # hutch: EMA factor for Dk
    eps = 1e-12       # cg: denominator guard
    cg_max_iter = 1000
    cg_tol = 1e-10

    w = torch.zeros(data.shape[1], device=device).requires_grad_()
    hist = []

    # Initial gradient with a graph: needed for the Hutchinson warm start.
    loss = loss_function(w, data, target)
    g, = torch.autograd.grad(loss, w, create_graph=True)

    if precond_method == "none":
        D = torch.eye(w.shape[0], device=w.device)
    elif precond_method == "hutch":
        Dk = diag_estimate_old(w, g, 100)
    elif precond_method == "scaling_vec":
        D = torch.diag((1 / scaling_vec)**2)
    elif precond_method == "adam":
        v = torch.zeros_like(g)
        step_t = torch.tensor(0.)
        betas = (0.9, 0.999)
    elif precond_method == "adagrad":
        v = torch.zeros_like(g)

    for epoch in range(EPOCHS):

        loss = loss_function(w, data, target)
        g, = torch.autograd.grad(loss, w)
        acc = (np.sign(data @ w.detach().cpu().numpy()) == target).sum() / target.shape[0]
        grad_norm_sq = (torch.linalg.norm(g) ** 2).item()

        print(f"[{epoch}/{EPOCHS}] | Loss: {loss.item()} | GradNorm^2: {grad_norm_sq} | Accuracy: {acc}")
        hist.append([loss.item(), grad_norm_sq, acc])

        for batch_data, batch_target in dataloader:
            loss = loss_function(w, batch_data, batch_target)
            g, = torch.autograd.grad(loss, w, create_graph=True)
            # Detached copies: optimizer state and updates must not hold the graph.
            f_grad = g.detach()
            loss_val = loss.detach()

            if precond_method == "cg":
                # Solve H s = f_grad with conjugate gradient (Hessian-vector
                # products obtained by differentiating g).
                s = torch.zeros_like(f_grad)
                r = f_grad.clone()
                p = r.clone()
                rr = torch.dot(r, r)
                for _ in range(cg_max_iter):
                    if torch.sqrt(rr) < cg_tol:
                        # Converged; also avoids 0/0 for a zero gradient.
                        break
                    hvp = torch.autograd.grad(g, w, grad_outputs=p, retain_graph=True)[0]
                    alpha_k = rr / torch.dot(p, hvp)
                    s = s + alpha_k * p
                    r = r - alpha_k * hvp
                    rr_new = torch.dot(r, r)
                    p = r + (rr_new / rr) * p
                    rr = rr_new

                with torch.no_grad():
                    w.sub_((loss_val / (torch.dot(f_grad, s) + eps)) * s)
                continue

            if precond_method == "hess_diag":
                # NOTE(review): assumes loss_hessian returns the dense batch Hessian.
                H = loss_hessian(w, batch_data, batch_target)
                D = torch.diag(1 / torch.diag(H))

            elif precond_method == "adam":
                step_t += 1
                # Accumulate the detached gradient: using g here would chain every
                # step's autograd graph into v and grow memory without bound.
                v = betas[1] * v + (1 - betas[1]) * f_grad.square()
                v_hat = v / (1 - torch.pow(betas[1], step_t))
                D = torch.diag(1 / (torch.sqrt(v_hat) + 1e-8))

            elif precond_method == "adagrad":
                # Detached for the same reason as adam.
                v.add_(torch.square(f_grad))
                D = torch.diag(1 / (torch.sqrt(v) + 1e-10))

            elif precond_method == "hutch":
                vk = diag_estimate_old(w, g, 1)

                # Smoothing and truncation
                Dk = beta * Dk + (1 - beta) * vk
                Dk_hat = torch.abs(Dk)
                Dk_hat[Dk_hat < alpha] = alpha

                D = torch.diag(1 / Dk_hat)

            with torch.no_grad():
                Dg = D @ f_grad
                w.sub_((loss_val / f_grad.dot(Dg)) * Dg)

    return hist