In [156]:
# Print the interpreter path, user, host name and working directory for this session.
!which python 
!whoami 
!hostname
!pwd

In [157]:
import os
import sys
# Append the parent directory to sys.path (only once) — presumably so the
# project-local `experiments` package imported below can be resolved.
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

import numpy as np

import torch
from torch.utils.data import TensorDataset, DataLoader
import torch.utils.data as data_utils
from torch.optim import SGD, Adam, Adagrad, Adadelta

import matplotlib.pyplot as plt

import experiments.loss_functions as lf
from experiments.utils import get_dataset

import scipy 

# Make float64 the default dtype for floating-point tensors created from here on.
torch.set_default_dtype(torch.float64)

# Load the line_profiler IPython extension.
%load_ext line_profiler

In [158]:
# Everything runs on the CPU; the commented-out line would select CUDA when available.
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('cpu')

In [159]:
from sklearn.datasets import load_svmlight_file

# Seed both RNGs before any sampling so the subsample below (and the torch.rand
# draws later in this cell) are reproducible.
torch.manual_seed(0)
np.random.seed(0)

batch_size = 256
# dataset_name = "covtype.libsvm.binary" 
dataset_name = "mushrooms"
percentage = 1.0  # fraction of rows kept by the subsample below

EPOCHS = 1000
# train_dataloader, train_data, train_target = get_dataset(dataset_name, batch_size, scale_data)
datasets_path = os.getenv("LIBSVM_DIR")
if not datasets_path:
    # Without this guard the f-string below would silently build the path "None/<dataset>".
    raise RuntimeError(
        "LIBSVM_DIR environment variable is not set; it must point to the directory holding the LIBSVM files"
    )
trainX, trainY = load_svmlight_file(f"{datasets_path}/{dataset_name}")

# Draw a subset of row indices without replacement and check that none repeats.
sample = np.random.choice(trainX.shape[0], round(trainX.shape[0] * percentage), replace=False)

assert sample.shape == np.unique(sample).shape

trainX = trainX[sample]
trainY = trainY[sample]

# Densify the sparse feature matrix and wrap features / labels as tensors.
train_data = torch.tensor(trainX.toarray())
train_target = torch.tensor(trainY)

train_load = TensorDataset(train_data, train_target)
train_dataloader = DataLoader(train_load, batch_size=batch_size, shuffle=False)


# Badly-scaled copy of the data: every feature column is multiplied by e**u,
# where u is drawn uniformly between -scale and scale.
scale = 10
r1, r2 = -scale, scale
scaling_vec = torch.pow(torch.e, torch.rand(train_data.shape[1]) * (r1 - r2) + r2)
train_data_scaled = train_data * scaling_vec

train_load_scaled = TensorDataset(train_data_scaled, train_target)
# NOTE(review): this loader shuffles while the unscaled loader above does not — confirm that is intended.
train_dataloader_scaled = DataLoader(train_load_scaled, shuffle=True, batch_size=batch_size)

# (features, labels, loader) triples in the form the run_* helpers unpack.
train = (train_data, train_target, train_dataloader)
train_scaled = (train_data_scaled, train_target, train_dataloader_scaled)

# loss_function = lf.logreg
# loss_grad = lf.grad_logreg
# loss_hessian = lf.hess_logreg

loss_function = lf.nllsq
# loss_grad = lf.grad_nllsq
# loss_hessian = lf.hess_nllsq

# Target encoding per loss: {-1, +1} for logreg, {0, 1} for nllsq.
if loss_function == lf.logreg:
    new_labels = (-1.0, 1.0)
elif loss_function == lf.nllsq:
    new_labels = (0.0, 1.0)
else:
    new_labels = None

if new_labels is not None:
    # Resolve the two original classes and their masks ONCE, before writing anything:
    # re-evaluating unique() after the first write can pick the wrong class
    # (e.g. labels {-1, 0} would collapse to a single class).
    old_labels = train_target.unique()
    assert old_labels.numel() == 2, f"expected a binary target, found labels {old_labels.tolist()}"
    low_mask = train_target == old_labels[0]
    high_mask = train_target == old_labels[1]
    # Written in place on purpose: the TensorDatasets built earlier in this cell
    # were constructed from this same tensor object.
    train_target[low_mask] = new_labels[0]
    train_target[high_mask] = new_labels[1]
    assert torch.equal(train_target.unique(), torch.tensor(new_labels, dtype=train_target.dtype))

train_data.shape, (train_data.min(), train_data.max()), train_target.unique(), torch.linalg.cond(train_data), torch.linalg.cond(train_data_scaled)

In [160]:
# Seed BOTH generators: numpy drives the data below, torch drives the random
# feature scaling later in this cell. Without the torch seed that scaling would
# depend on whatever cells happened to run before this one.
torch.manual_seed(0)
np.random.seed(0)
n = 1000
d = 100
dataset_name = f"synthetic-regression-{n}x{d}"
modified = False  # True replaces the spectrum of A by 1/(k+1)^2
A = np.random.randn(n, d)

if modified:
    # Keep the singular vectors of A but impose the singular values 1/(k+1)^2.
    U, S, VH = np.linalg.svd(A)
    S = np.asarray([1/((x+1)**2) for x in range(S.shape[0])])
    A = np.dot(U[:, :S.shape[0]] * S, VH)
    dataset_name += "-modified"

# Noise-free linear targets: b = A @ xopt.
xopt = np.random.randn(d)
b = A @ xopt 
train_data = torch.tensor(A, dtype=torch.get_default_dtype())
train_target = torch.tensor(b, dtype=torch.get_default_dtype())
xopt = torch.tensor(xopt, dtype=torch.get_default_dtype())

batch_size = 1000  # equals n, so each epoch is a single batch
EPOCHS = 100

train_load = data_utils.TensorDataset(train_data, train_target)
train_dataloader = DataLoader(train_load, batch_size=batch_size, shuffle=False)


# Badly-scaled copy of the data: every feature column is multiplied by e**u,
# where u is drawn uniformly between -scale and scale.
scale = 1
r1, r2 = -scale, scale
scaling_vec = torch.pow(torch.e, torch.rand(train_data.shape[1]) * (r1 - r2) + r2)
train_data_scaled = train_data * scaling_vec

train_load_scaled = data_utils.TensorDataset(train_data_scaled, train_target)
train_dataloader_scaled = DataLoader(train_load_scaled, shuffle=False, batch_size=batch_size)

# [features, targets, loader] triples in the form the run_* helpers unpack.
train = [train_data, train_target, train_dataloader]
train_scaled = [train_data_scaled, train_target, train_dataloader_scaled]

# Mean-squared-error loss together with its gradient and Hessian helpers.
loss_function, loss_grad, loss_hessian = lf.mse, lf.grad_mse, lf.hess_mse

train_data.shape, torch.linalg.cond(train_data), torch.linalg.cond(train_data_scaled)

In [161]:
# Reproducible synthetic binary classification problem.
torch.manual_seed(0)
np.random.seed(0)

n, d = 1000, 100
batch_size = 1000
EPOCHS = 100
dataset_name = f"synthetic-classification-{n}x{d}"

# Gaussian features and a Gaussian ground-truth weight vector (drawn in this order).
train_data = np.random.randn(n, d)
w_star = np.random.randn(d)

# U, S, VH = np.linalg.svd(train_data)
# S *= 0.0
# S = np.asarray([1/((x+1)**2) for x in range(S.shape[0])])
# train_data = np.dot(U[:, :S.shape[0]] * S, VH)

# Labels come from the sign of the linear score: negative -> 0, positive -> 1.
train_target = train_data @ w_star
train_target = np.where(train_target < 0.0, 0.0, train_target)  # -1.0
train_target = np.where(train_target > 0.0, 1.0, train_target)

train_data = torch.Tensor(train_data)
train_target = torch.Tensor(train_target)

train_load = TensorDataset(train_data, train_target)
train_dataloader = DataLoader(train_load, shuffle=False, batch_size=batch_size)


# Badly-scaled copy of the data: every feature column is multiplied by e**u,
# where u is drawn uniformly between -scale and scale.
scale = 5
r1, r2 = -scale, scale
scaling_vec = torch.pow(torch.e, torch.rand(train_data.shape[1]) * (r1 - r2) + r2)
train_data_scaled = train_data * scaling_vec

train_load_scaled = TensorDataset(train_data_scaled, train_target)
train_dataloader_scaled = DataLoader(train_load_scaled, shuffle=False, batch_size=batch_size)

# (features, labels, loader) triples in the form the run_* helpers unpack.
train = (train_data, train_target, train_dataloader)
train_scaled = (train_data_scaled, train_target, train_dataloader_scaled)


loss_function = lf.nllsq
# loss_grad = lf.grad_logreg
# loss_hessian = lf.hess_logreg


# Target encoding per loss: {-1, +1} for logreg, {0, 1} for nllsq.
if loss_function == lf.logreg:
    new_labels = (-1.0, 1.0)
elif loss_function == lf.nllsq:
    new_labels = (0.0, 1.0)
else:
    new_labels = None

if new_labels is not None:
    # Resolve the two original classes and their masks ONCE, before writing anything:
    # re-evaluating unique() after the first write can select the wrong class.
    old_labels = train_target.unique()
    assert old_labels.numel() == 2, f"expected a binary target, found labels {old_labels.tolist()}"
    low_mask = train_target == old_labels[0]
    high_mask = train_target == old_labels[1]
    # Written in place on purpose: the TensorDatasets built earlier in this cell
    # were constructed from this same tensor object.
    train_target[low_mask] = new_labels[0]
    train_target[high_mask] = new_labels[1]
    assert torch.equal(train_target.unique(), torch.tensor(new_labels, dtype=train_target.dtype))


train_data.shape, (train_data.min(), train_data.max()), train_target.unique(), torch.linalg.cond(train_data), torch.linalg.cond(train_data_scaled)

In [162]:
def run_optimizer(optimizer, dataset, EPOCHS, seed=0, **kwargs_optimizer):
    """Train a zero-initialised weight vector with a ``torch.optim`` optimizer.

    Uses the notebook-level ``loss_function`` and ``device``.

    Args:
        optimizer: optimizer class; it is instantiated as ``optimizer([w], **kwargs_optimizer)``.
        dataset: ``(data, target, dataloader)`` triple.
        EPOCHS: number of passes over ``dataloader``.
        seed: value passed to ``torch.manual_seed`` before training.
        **kwargs_optimizer: forwarded to the optimizer constructor (e.g. ``lr``).

    Returns:
        A list with one ``[loss, grad_norm_sq, acc]`` entry per epoch, measured on the
        full ``data`` / ``target`` *before* that epoch's updates.
    """
    data, target, dataloader = dataset

    torch.manual_seed(seed)

    # parameters
    w = torch.zeros(data.shape[1], device=device).requires_grad_()
    opt = optimizer([w], **kwargs_optimizer)

    # logging
    hist = []

    for epoch in range(EPOCHS):

        # Full-data metrics. The gradient is only used for its norm, so no
        # higher-order graph (create_graph) is needed here.
        loss = loss_function(w, data.to(device), target.to(device))
        g, = torch.autograd.grad(loss, w)
        grad_norm_sq = (torch.linalg.norm(g) ** 2).item()
        # NOTE(review): np.sign yields -1/0/+1, so with {0, 1} targets only the
        # positive class can ever match — confirm this is the intended accuracy.
        acc = (np.sign(data @ w.detach().cpu().numpy()) == target).sum() / target.shape[0]
        print(f"[{epoch}/{EPOCHS}] | Loss: {loss.item()} | GradNorm^2: {grad_norm_sq} | Acc: {acc}")
        hist.append([loss.item(), grad_norm_sq, acc])

        for batch_data, batch_target in dataloader:
            batch_data = batch_data.to(device)
            batch_target = batch_target.to(device)
            opt.zero_grad()
            # Populate w.grad from the mini-batch loss, then take one optimizer step.
            loss_function(w, batch_data, batch_target).backward()
            opt.step()

    return hist



def save_results(result, dataset_name, percentage, scale, batch_size, epochs, loss_function_name, optimizer_name, lr, 
                 precond_method, pcg_method, hutch_init_iters, seed):
    """Save one run's history under a directory derived from its configuration.

    The target directory is rooted at the ``RESULTS_DIR`` environment variable and
    encodes every configuration argument in its path. Three files are written with
    ``torch.save``: ``loss``, ``grad_norm_sq`` and ``acc``, holding columns 0, 1 and 2
    of ``result`` respectively.

    Args:
        result: sequence of ``[loss, grad_norm_sq, acc]`` rows.
        dataset_name, percentage, scale, batch_size, epochs, loss_function_name,
        optimizer_name, lr, precond_method, pcg_method, hutch_init_iters, seed:
            configuration values; used only to build the directory name.

    Raises:
        RuntimeError: if ``RESULTS_DIR`` is unset or empty.
    """
    results_path = os.getenv("RESULTS_DIR")
    if not results_path:
        # Previously an unset variable produced a literal "None/..." directory.
        raise RuntimeError("RESULTS_DIR environment variable is not set; cannot decide where to save results")

    directory = f"{results_path}/{dataset_name}/percentage_{percentage}/scale_{scale}/bs_{batch_size}" \
    f"/epochs_{epochs}/{loss_function_name}/{optimizer_name}/lr_{lr}/precond_{precond_method}/pcg_method_{pcg_method}/hutch_init_iters_{hutch_init_iters}/seed_{seed}"

    print(directory)
    # exist_ok avoids the check-then-create race and allows re-running the same configuration.
    os.makedirs(directory, exist_ok=True)

    for column, file_name in enumerate(("loss", "grad_norm_sq", "acc")):
        torch.save([row[column] for row in result], f"{directory}/{file_name}")

# PSPS2 Rank 1 Scaling
$$
\begin{aligned}
w^* &= \arg\min_{w\in\mathbb{R}^d} \frac{1}{2} \|w - w_t\|^2_{B_t} \\
&\quad \text{s.t.} \quad f_i(w_t) + \langle \nabla f_i(w_t), w - w_t\rangle + \frac{1}{2}\langle B_t(w - w_t), w - w_t \rangle \leq 0 \\
B_t &= \frac{yy^T}{s^Ty} \\
B_t^{+} &= \frac{ss^T}{s^Ty} \\
\text{where} \quad s &= \Big( \nabla^2 f_i(w_t) \Big)^{-1} \nabla f_i(w_t) \\
y &= \nabla^2 f_i(w_t)\, s = \nabla f_i(w_t) \\
\text{Update rule:} \quad w_{t+1} &= w_t - \frac{\alpha}{1 + \alpha} B_t^{+} \nabla f_i(w_t)
\end{aligned}
$$

In [163]:
def rademacher_old(weights):
    """Return a random tensor of -1 / +1 entries with the shape and dtype of ``weights``."""
    # Uniform draws are rounded to 0 or 1, then mapped to -1 or +1.
    uniform = torch.rand_like(weights)
    return uniform.round().mul(2).sub(1)

def diag_estimate_old(weights, grad, iters):
    """Estimate a Hessian diagonal by averaging ``z * (H z)`` over random sign vectors.

    Args:
        weights: tensor that ``grad`` was differentiated with respect to.
        grad: gradient whose graph is still available (it is differentiated again here).
        iters: number of random probes to average.

    Returns:
        The element-wise mean of the ``iters`` probe products.
    """
    def one_probe():
        probe = rademacher_old(weights)
        with torch.no_grad():
            # Hessian-vector product; the graph is retained so it can be reused by later probes.
            hess_vec, = torch.autograd.grad(grad, weights, grad_outputs=probe, retain_graph=True)
        return hess_vec * probe

    samples = [one_probe() for _ in range(iters)]
    return torch.stack(samples).mean(dim=0)

def run_psps2(dataset, epochs, precond_method, pcg_method="none", seed=0, **kwargs):
    """Run the PSPS2 update with a selectable curvature approximation.

    Uses the notebook-level ``loss_function`` and ``device`` (plus ``loss_hessian`` for
    the ``"hess_diag"`` and ``"scipy_cg"`` modes, and ``diag_estimate_old`` for the
    Hutchinson modes).

    For every mini-batch a direction ``s`` is obtained from the mini-batch gradient
    ``f_grad`` according to ``precond_method``; then, with ``q = <f_grad, s>``, the step
    is ``w -= step_size * s`` where ``step_size = 1 - sqrt(1 - 2*loss/q)`` if
    ``2*loss <= q`` and ``1.0`` otherwise.

    Args:
        dataset: ``(data, target, dataloader)`` triple.
        epochs: number of passes over ``dataloader``.
        precond_method: one of ``"none"``, ``"scaling_vec"``, ``"hutch"``, ``"adam"``,
            ``"adam_m"``, ``"adagrad"``, ``"adagrad_m"``, ``"hess_diag"``,
            ``"true_hessian"``, ``"scipy_cg"``, ``"pcg"``.
        pcg_method: diagonal preconditioner used inside the ``"pcg"`` mode: ``"none"``,
            ``"hutch"``, ``"adam"``, ``"adam_m"``, ``"adagrad"`` or ``"adagrad_m"``.
        seed: value passed to ``torch.manual_seed``.
        **kwargs: ``eps`` (default ``1e-6``) for the adam/adagrad preconditioners;
            ``hutch_init_iters`` (required by the Hutchinson modes);
            ``scaling_vec`` (required by ``"scaling_vec"``).

    Returns:
        A list with one ``[loss, grad_norm_sq, acc]`` entry per epoch, measured on the
        full data before that epoch's updates.

    Raises:
        ValueError: for an unknown ``precond_method``, or an unknown ``pcg_method``
            when ``precond_method == "pcg"``.
    """
    precond_choices = ("none", "hutch", "pcg", "scaling_vec", "adam", "adam_m",
                       "adagrad", "adagrad_m", "hess_diag", "true_hessian", "scipy_cg")
    pcg_choices = ("none", "hutch", "adam", "adam_m", "adagrad", "adagrad_m")
    if precond_method not in precond_choices:
        raise ValueError(f"unknown precond_method {precond_method!r}; expected one of {precond_choices}")
    if precond_method == "pcg" and pcg_method not in pcg_choices:
        raise ValueError(f"unknown pcg_method {pcg_method!r}; expected one of {pcg_choices}")

    torch.manual_seed(seed)

    data, target, dataloader = dataset

    eps = kwargs.get("eps", 1e-6)

    # parameters
    w = torch.zeros(data.shape[1], device=device).requires_grad_()

    # save loss and grad size to history
    hist = []

    # Gradient at the starting point; its graph is needed by the Hutchinson initialisation.
    loss = loss_function(w, data.to(device), target.to(device))
    g, = torch.autograd.grad(loss, w, create_graph=True)

    # ---- state of the main preconditioner ----
    if precond_method == "none":
        D = torch.ones_like(w)
    elif precond_method == "hutch":
        alpha = 0.1    # lower bound applied to |diagonal estimate|
        beta = 0.999   # smoothing factor of the running diagonal estimate
        Dk = diag_estimate_old(w, g, kwargs["hutch_init_iters"])
    elif precond_method == "pcg":
        # Iteration cap for CG: twice the problem dimension (taken from the
        # dataset actually passed in, not from a notebook global).
        MAX_ITER = data.shape[1] * 2
    elif precond_method == "scaling_vec":
        D = (1 / kwargs["scaling_vec"]) ** 2
    elif precond_method in ("adam", "adam_m"):
        v = torch.zeros_like(g)
        step_t = torch.tensor(0.)
        betas = (0.9, 0.999)
    elif precond_method in ("adagrad", "adagrad_m"):
        v = torch.zeros_like(g)

    # ---- state of the preconditioner used inside PCG ----
    if pcg_method == "hutch":
        alpha = 0.1
        beta = 0.999
        Dk_pcg = diag_estimate_old(w, g, kwargs["hutch_init_iters"])
    elif pcg_method in ("adam", "adam_m"):
        v_pcg = torch.zeros_like(g)
        step_t_pcg = torch.tensor(0.)
        betas = (0.9, 0.999)
    elif pcg_method in ("adagrad", "adagrad_m"):
        v_pcg = torch.zeros_like(g)
    elif pcg_method == "none":
        D_pcg = torch.ones_like(g)

    for epoch in range(epochs):

        # Full-data metrics, logged before this epoch's updates.
        loss = loss_function(w, data.to(device), target.to(device))
        g, = torch.autograd.grad(loss, w)
        grad_norm_sq = torch.linalg.norm(g) ** 2
        # NOTE(review): np.sign yields -1/0/+1, so with {0, 1} targets only the
        # positive class can ever match — confirm this is the intended accuracy.
        acc = (np.sign(data @ w.detach().cpu().numpy()) == target).sum() / target.shape[0]

        print(f"[{epoch}/{epochs}] | Loss: {loss.item()} | GradNorm^2: {grad_norm_sq.item()} | Accuracy: {acc}")
        hist.append([loss.item(), grad_norm_sq.item(), acc])

        for batch_data, batch_target in dataloader:
            batch_data = batch_data.to(device)
            batch_target = batch_target.to(device)

            loss = loss_function(w, batch_data, batch_target)
            # create_graph keeps the graph of g so Hessian-vector products can be taken below.
            g, = torch.autograd.grad(loss, w, create_graph=True)
            f_grad = g.detach().clone()

            if precond_method == "hess_diag":
                hess = loss_hessian(w, batch_data, batch_target)
                hess_diag_inv = 1 / torch.diag(hess)
                s = hess_diag_inv * f_grad

            elif precond_method == "true_hessian":
                closure = lambda w: loss_function(w, batch_data, batch_target)
                hess = torch.autograd.functional.hessian(closure, w)
                s = torch.linalg.solve(hess, f_grad)

            elif precond_method == "scaling_vec":
                s = D * f_grad

            elif precond_method in ("adam", "adam_m"):
                step_t += 1
                # Accumulate from the detached gradient: using the graph-attached `g`
                # would chain every step's autograd graph onto `v` and grow memory.
                v = betas[1] * v + (1 - betas[1]) * f_grad.square()
                v_hat = v / (1 - torch.pow(betas[1], step_t))  # bias correction

                if precond_method == "adam":
                    D = 1 / (torch.sqrt(v_hat) + eps)
                else:
                    D = 1 / (v_hat + eps)
                s = D * f_grad

            elif precond_method in ("adagrad", "adagrad_m"):
                v.add_(f_grad.square())  # detached for the same reason as above
                if precond_method == "adagrad":
                    D = 1 / (torch.sqrt(v) + eps)
                else:
                    D = 1 / (v + eps)
                s = D * f_grad

            elif precond_method == "scipy_cg":
                # NOTE(review): recent SciPy releases renamed `tol` to `rtol` for cg — verify
                # against the installed version.
                A = scipy.sparse.csc_matrix(loss_hessian(w, batch_data, batch_target).detach().numpy())
                s, exit_code = scipy.sparse.linalg.cg(A, f_grad.numpy(), tol=1e-10)
                s = torch.tensor(s)

            elif precond_method == "none":
                s = D * f_grad

            elif precond_method == "hutch":
                vk = diag_estimate_old(w, g, 1)

                # Smoothing and Truncation
                Dk = beta * Dk + (1 - beta) * vk
                Dk_hat = torch.abs(Dk)
                Dk_hat[Dk_hat < alpha] = alpha

                D = 1 / Dk_hat
                s = D * f_grad

            elif precond_method == "pcg":

                if pcg_method == "hutch":
                    vk_pcg = diag_estimate_old(w, g, 1)
                    # Smoothing and Truncation
                    Dk_pcg = beta * Dk_pcg + (1 - beta) * vk_pcg
                    Dk_hat = torch.abs(Dk_pcg)
                    Dk_hat[Dk_hat < alpha] = alpha
                    D_pcg = 1 / Dk_hat

                elif pcg_method in ("adam", "adam_m"):
                    # "adam_m" previously had its state initialised but never updated here,
                    # leaving an all-zero preconditioner (0/0 in the first CG step).
                    # NOTE(review): both names use the non-square-root form 1/(v_hat + 1e-6).
                    step_t_pcg += 1
                    v_pcg = betas[1] * v_pcg + (1 - betas[1]) * f_grad.square()
                    v_hat = v_pcg / (1 - torch.pow(betas[1], step_t_pcg))
                    D_pcg = 1 / (v_hat + 1e-6)

                elif pcg_method in ("adagrad", "adagrad_m"):
                    # Same remark as for "adam_m": "adagrad_m" was initialised but never updated.
                    v_pcg.add_(f_grad.square())
                    D_pcg = 1 / (v_pcg + 1e-6)

                hess_diag_inv = D_pcg.clone()
                # Preconditioned CG on  H s = f_grad, with H accessed through
                # Hessian-vector products of the mini-batch loss.
                s = torch.zeros_like(w)  # s = H_inv * grad
                r = f_grad.clone()
                z = hess_diag_inv * r
                p = z.detach().clone()

                for cg_step in range(MAX_ITER):
                    hvp = torch.autograd.grad(g, w, grad_outputs=p, retain_graph=True)[0]
                    alpha_k = torch.dot(r, z) / torch.dot(p, hvp)

                    s = s + alpha_k * p
                    r_prev = r.clone()
                    r = r - alpha_k * hvp
                    r_hat = torch.autograd.grad(g, w, grad_outputs=r, retain_graph=True)[0]
                    z_prev = z.clone()
                    z = hess_diag_inv * r
                    # Stop once the residual is negligible in the Hessian-weighted product r^T H r.
                    if torch.dot(r, r_hat) < 1e-14:
                        break

                    beta_k = torch.dot(r, z) / torch.dot(r_prev, z_prev)
                    p = z + beta_k * p

            grad_norm_sq_scaled = torch.dot(f_grad, s)

            if 2 * loss <= grad_norm_sq_scaled:
                c = loss / grad_norm_sq_scaled
                det = 1 - 2 * c
                step_size = 1 - torch.sqrt(det)
            else:
                step_size = 1.0

            with torch.no_grad():
                w.sub_(step_size * s)

    return hist

In [181]:
# Overrides the EPOCHS value assigned by the dataset cell above.
# NOTE(review): this cell's execution count (181) is out of sequence with its neighbours,
# so it was run out of order; confirm which EPOCHS value the saved results used.
EPOCHS = 500

In [None]:
# Baseline sweep: each torch optimizer x 5 seeds x {unscaled, scaled} data; every run is saved.
lr = 0.001
for opt in (Adam, Adagrad, Adadelta):
    for seed in range(5):
        # The unscaled run is stored under scale 0, the scaled one under the actual scale.
        for t, s in ((train, 0), (train_scaled, scale)):
            hist = run_optimizer(opt, t, EPOCHS, seed=seed, lr=lr)
            save_results(
                hist,
                dataset_name=dataset_name,
                percentage=1.0,
                scale=s,
                batch_size=batch_size,
                epochs=EPOCHS,
                loss_function_name=loss_function.__name__,
                optimizer_name=opt.__name__.lower(),
                lr=lr,
                precond_method="none",
                pcg_method="none",
                hutch_init_iters=1000,
                seed=seed,
            )


In [None]:
# PSPS2 sweep with the "adagrad_m" preconditioner: 5 seeds x {unscaled, scaled} data,
# using a different eps for each of the two datasets.
for precond in ("adagrad_m",):
    for seed in range(5):
        for t, s, eps in ((train, 0, 1e-6), (train_scaled, scale, 1e-9)):
            hist = run_psps2(t, EPOCHS, precond, pcg_method="none", seed=seed, eps=eps)
            save_results(
                hist,
                dataset_name=dataset_name,
                percentage=1.0,
                scale=s,
                batch_size=batch_size,
                epochs=EPOCHS,
                loss_function_name=loss_function.__name__,
                optimizer_name="psps2",
                lr=1.0,
                precond_method=precond,
                pcg_method="none",
                hutch_init_iters=1000,
                seed=seed,
            )

In [165]:
# PSPS2 sweep with the "adam_m" preconditioner on the scaled data only, 5 seeds.
for precond in ("adam_m",):
    for seed in range(5):
        # for t, s, eps in zip([train, train_scaled], [0, scale], [1e-6, 1e-9]):
        t, s = train_scaled, scale
        hist = run_psps2(t, EPOCHS, precond, pcg_method="none", seed=seed, eps=1e-8)
        save_results(
            hist,
            dataset_name=dataset_name,
            percentage=1.0,
            scale=s,
            batch_size=batch_size,
            epochs=EPOCHS,
            loss_function_name=loss_function.__name__,
            optimizer_name="psps2",
            lr=1.0,
            precond_method=precond,
            pcg_method="none",
            hutch_init_iters=1000,
            seed=seed,
        )

In [166]:
# Baseline histories for the plot: each optimizer is run on the unscaled data first,
# then on the scaled data, with the same learning rate.
hist_adam, hist_adam_scaled = (
    run_optimizer(Adam, ds, EPOCHS, lr=0.001) for ds in (train, train_scaled)
)

hist_adagrad, hist_adagrad_scaled = (
    run_optimizer(Adagrad, ds, EPOCHS, lr=0.001) for ds in (train, train_scaled)
)

In [167]:
# PSPS2 with the "adam_m" preconditioner on the unscaled data (seed 3, eps=1e-6).
hist_psps2_adam = run_psps2(train, EPOCHS, "adam_m", "none", seed=3, eps=1e-6)

In [168]:
# Same run on the scaled data, with a smaller eps (1e-8).
hist_psps2_adam_scaled = run_psps2(train_scaled, EPOCHS, "adam_m", "none", seed=3, eps=1e-8)

In [169]:
# Loss per epoch (column 0 of each history row) on a log scale.
# Solid lines: unscaled data; dashed lines: scaled data. The dashed curves carry a
# "(scaled)" suffix so the legend entries are distinguishable.

# plt.semilogy([x[0] for x in hist_pcg_adam], label="PSPS2 PCG Adam")
# plt.semilogy([x[0] for x in hist_pcg_adam_scaled], linestyle="--", label="PSPS2 PCG Adam (scaled)")

# plt.semilogy([x[0] for x in hist_pcg_adagrad], label="PSPS2 PCG Adagrad")
# plt.semilogy([x[0] for x in hist_pcg_adagrad_scaled], linestyle="--", label="PSPS2 PCG Adagrad (scaled)")

# plt.semilogy([x[0] for x in hist_psps2_newton], label="PSPS2 Newton")
# plt.semilogy([x[0] for x in hist_psps2_newton_scaled], linestyle="--", label="PSPS2 Newton (scaled)")

plt.semilogy([x[0] for x in hist_psps2_adam], label="PSPS2 Adam")
plt.semilogy([x[0] for x in hist_psps2_adam_scaled], linestyle="--", label="PSPS2 Adam (scaled)")

# plt.semilogy([x[0] for x in hist_psps2_adagrad], label="PSPS2 Adagrad")
# plt.semilogy([x[0] for x in hist_psps2_adagrad_scaled], linestyle="--", label="PSPS2 Adagrad (scaled)")

plt.semilogy([x[0] for x in hist_adam], label="Adam")
plt.semilogy([x[0] for x in hist_adam_scaled], linestyle="--", label="Adam (scaled)")

plt.semilogy([x[0] for x in hist_adagrad], label="Adagrad")
plt.semilogy([x[0] for x in hist_adagrad_scaled], linestyle="--", label="Adagrad (scaled)")


# plt.ylim(bottom=1e-2)

plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend();
# plt.savefig(f"experiments/plots/testtest.jpeg", format="jpeg", dpi=200)

# Non-Convex PSPS2 CG

In [170]:
def rademacher_old(weights):
    """Return a random tensor of -1 / +1 entries with the shape and dtype of ``weights``.

    This re-defines the function of the same name from an earlier cell with the same behavior.
    """
    # Uniform draws are rounded to 0 or 1, then mapped to -1 or +1.
    uniform = torch.rand_like(weights)
    return uniform.round().mul(2).sub(1)

def diag_estimate_old(weights, grad, iters):
    """Estimate a Hessian diagonal by averaging ``z * (H z)`` over random sign vectors.

    This re-defines the function of the same name from an earlier cell with the same behavior.

    Args:
        weights: tensor that ``grad`` was differentiated with respect to.
        grad: gradient whose graph is still available (it is differentiated again here).
        iters: number of random probes to average.

    Returns:
        The element-wise mean of the ``iters`` probe products.
    """
    def one_probe():
        probe = rademacher_old(weights)
        with torch.no_grad():
            # Hessian-vector product; the graph is retained so it can be reused by later probes.
            hess_vec, = torch.autograd.grad(grad, weights, grad_outputs=probe, retain_graph=True)
        return hess_vec * probe

    samples = [one_probe() for _ in range(iters)]
    return torch.stack(samples).mean(dim=0)

def run_psps2_nc(train_data, train_target, train_dataloader, epochs, precond_method="cg", seed=0, **kwargs):
    """PSPS2 variant that obtains its direction from CG with a negative-curvature exit.

    Uses the notebook-level ``loss_function`` and ``device``.

    For every mini-batch, plain conjugate gradient is run on ``H z = f_grad`` (``H`` is
    accessed only through Hessian-vector products of the mini-batch loss):

    * If a search direction ``p`` with ``p^T H p <= 0`` is met, the step direction is a
      sign-corrected mix of the current CG iterate and ``p``, and the step size is
      ``min(loss / ||s||^2, 50)``.
    * Otherwise the CG iterate is used, both when the residual norm drops below ``1e-4``
      and when the iteration cap is reached. The step size is then
      ``1 - sqrt(1 - 2*loss/<f_grad, s>)`` if ``2*loss <= <f_grad, s>`` and ``1.0`` otherwise.

    Args:
        train_data, train_target: full data set, used for logging.
        train_dataloader: yields the mini-batches used for the updates.
        epochs: number of passes over ``train_dataloader``.
        precond_method: only ``"cg"`` is implemented.
        seed: value passed to ``torch.manual_seed``.
        **kwargs: accepted for call compatibility; not read.

    Returns:
        A list with one ``[loss, grad_norm_sq, acc]`` entry per epoch, measured on the
        full data before that epoch's updates.

    Raises:
        ValueError: if ``precond_method`` is not ``"cg"``.
    """
    if precond_method != "cg":
        # Any other value used to fail later with an UnboundLocalError on step_size.
        raise ValueError(f"run_psps2_nc only implements precond_method='cg', got {precond_method!r}")

    torch.manual_seed(seed)

    # parameters
    w = torch.zeros(train_data.shape[1], device=device).requires_grad_()

    # save loss and grad size to history
    hist = []

    MAX_ITER = train_data.shape[1] * 2   # CG iteration cap: twice the dimension
    gamma = 0.9                          # weight of the CG iterate in the negative-curvature mix

    def full_data_metrics():
        # Loss, squared gradient norm and accuracy on the whole data set. Kept in a helper
        # with its own local names so that logging can never overwrite the mini-batch
        # `loss` / `g` that the update step relies on.
        full_loss = loss_function(w, train_data.to(device), train_target.to(device))
        full_grad, = torch.autograd.grad(full_loss, w)
        full_grad_norm_sq = torch.linalg.norm(full_grad) ** 2
        # NOTE(review): np.sign yields -1/0/+1, so with {0, 1} targets only the
        # positive class can ever match — confirm this is the intended accuracy.
        accuracy = (np.sign(train_data @ w.detach().cpu().numpy()) == train_target).sum() / train_target.shape[0]
        return full_loss, full_grad_norm_sq, accuracy

    def polyak_type_step(loss, f_grad, s):
        # Step size used whenever the direction comes from the CG iterate.
        grad_norm_sq_scaled = torch.dot(f_grad, s)
        if 2 * loss <= grad_norm_sq_scaled:
            c = loss / grad_norm_sq_scaled
            det = 1 - 2 * c
            if det < 0.0:
                return 1.0
            return 1 - torch.sqrt(det)
        return 1.0

    for epoch in range(epochs):

        full_loss, grad_norm_sq, acc = full_data_metrics()
        print(f"[{epoch}/{epochs}] | Loss: {full_loss.item()} | GradNorm^2: {grad_norm_sq.item()} | Accuracy: {acc}")
        hist.append([full_loss.item(), grad_norm_sq.item(), acc])

        for i, (batch_data, batch_target) in enumerate(train_dataloader):

            if i % 64 == 0:
                # Periodic progress line (not stored in hist).
                full_loss, grad_norm_sq, acc = full_data_metrics()
                print(f"[{epoch}][{i}] | Loss: {full_loss.item()} | GradNorm^2: {grad_norm_sq.item()} | Accuracy: {acc}")

            loss = loss_function(w, batch_data, batch_target)
            # create_graph keeps the graph of g so Hessian-vector products can be taken below.
            g, = torch.autograd.grad(loss, w, create_graph=True)
            f_grad = g.detach().clone()

            # CG is here
            z = torch.zeros_like(w)   # running CG iterate, approximates H_inv * grad
            r = f_grad.clone()
            p = r.clone()
            for cg_step in range(MAX_ITER):
                hvp = torch.autograd.grad(g, w, grad_outputs=p, retain_graph=True)[0]
                curvature = torch.dot(p, hvp)

                if curvature <= 0:
                    # Negative curvature: mix the iterate and the offending direction,
                    # each flipped so that it has a non-negative product with the gradient.
                    s = gamma * z * torch.sign(torch.dot(z, f_grad)) + (1 - gamma) * p * torch.sign(torch.dot(p, f_grad))
                    step_size = torch.clamp(loss.detach() / torch.dot(s, s), max=50.0)
                    break

                alpha_k = torch.dot(r, r) / curvature
                z = z + alpha_k * p
                r_prev = r.clone()
                r = r - alpha_k * hvp
                if torch.norm(r) < 1e-4:
                    s = z
                    step_size = polyak_type_step(loss, f_grad, s)
                    break

                beta_k = torch.dot(r, r) / torch.dot(r_prev, r_prev)
                p = r + beta_k * p
            else:
                # Iteration cap reached without either exit above: use the current CG
                # iterate. Previously `s` / `step_size` were left over from the previous
                # batch here (or undefined on the very first batch).
                s = z
                step_size = polyak_type_step(loss, f_grad, s)

            with torch.no_grad():
                w.sub_(step_size * s)

    return hist


In [171]:
# Non-convex PSPS2 with CG: unscaled data first, then scaled data, same seed.
hist_cg, hist_cg_scaled = (
    run_psps2_nc(features, train_target, loader, EPOCHS, precond_method="cg", seed=2)
    for features, loader in (
        (train_data, train_dataloader),
        (train_data_scaled, train_dataloader_scaled),
    )
)

In [172]:
# Adam baselines for the CG comparison.
# run_optimizer takes the data set as ONE (data, target, dataloader) triple; passing the
# three pieces as separate positional arguments raises a TypeError with the current signature.
hist_adam = run_optimizer(Adam, (train_data, train_target, train_dataloader), EPOCHS, lr=0.1)
hist_adam_scaled = run_optimizer(Adam, (train_data_scaled, train_target, train_dataloader_scaled), EPOCHS, lr=0.01)

In [173]:
# Loss per epoch (column 0 of each history row) on a log scale.
# Solid lines: unscaled data; dashed lines: scaled data. The dashed curves carry a
# "(scaled)" suffix so the legend entries are distinguishable.
plt.semilogy([x[0] for x in hist_cg], label="CG")
plt.semilogy([x[0] for x in hist_adam], label="Adam")

plt.semilogy([x[0] for x in hist_cg_scaled], linestyle="--", label="CG (scaled)")
plt.semilogy([x[0] for x in hist_adam_scaled], linestyle="--", label="Adam (scaled)")

plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend();

# plt.savefig(f"experiments/plots/non-convex-cg_vs_adam-gamma-0_1.jpeg", format="jpeg", dpi=200)
