In [1]:
import torch
import numpy as np
from tqdm.notebook import tqdm

In [2]:
import matplotlib.pyplot as plt

import matplotlib

%config InlineBackend.figure_format = 'retina'
matplotlib.rcParams.update({
        "font.family": "serif",
       "font.serif": ["DejaVu Serif", "Bitstream Vera Serif", "Computer Modern Roman", "New Century Schoolbook", "Century Schoolbook L", "Utopia", "ITC Bookman", "Bookman", "Nimbus Roman No9 L", "Times New Roman", "Times", "Palatino", "Charter", "serif"],
        "axes.labelsize": 18,
        "font.size": 18,
        "legend.fontsize": 16,
        "xtick.labelsize": 18,
        "ytick.labelsize": 18,
})

In [3]:
def sgd_mom(W_0, data, b, lr=0.02, beta=0.9, n_iter=10000):
    """Minimise ``((W @ data)**2).sum() / N`` with mini-batch SGD with momentum.

    ``N`` is the number of columns of ``data``. Each step draws ``b`` distinct
    columns of ``data``, back-propagates the squared-output loss on that
    batch, updates an exponential moving average of the gradient and applies
    it divided by ``1 - beta**t``.

    Args:
        W_0: 2-D initial weight tensor. It is cloned, so the caller's tensor
            is left untouched.
        data: 2-D tensor whose columns are the samples; its number of rows
            must equal the number of columns of ``W_0``.
        b: mini-batch size (columns drawn without replacement, so it must not
            exceed the number of columns of ``data``).
        lr: step size.
        beta: momentum decay factor.
        n_iter: number of update steps.

    Returns:
        ``(W, losses)`` where ``W`` is the final weight tensor and ``losses``
        is a list of ``n_iter + 1`` floats: the full-data loss before the
        first step and after every step.
    """
    W = W_0.clone().detach().requires_grad_(True)
    # Batches index the columns of `data`, so the sampling range and the loss
    # normalisation must come from `data`, not from the shape of W.
    n_samples = data.size(1)
    mom = torch.zeros_like(W)

    def full_loss():
        # Monitoring only: no autograd graph is needed for this value.
        with torch.no_grad():
            return (((W @ data) ** 2).sum() / n_samples).item()

    losses = [full_loss()]
    for t in range(1, n_iter + 1):
        batch_idx = np.random.choice(n_samples, (b,), replace=False)
        loss = ((W @ data[:, batch_idx]) ** 2).sum() / b
        loss.backward()
        with torch.no_grad():
            mom = beta * mom + (1 - beta) * W.grad
            # Dividing by (1 - beta**t) compensates for the zero-initialised
            # moving average during the first steps.
            W -= lr * mom / (1 - beta**t)
            W.grad.zero_()
        losses.append(full_loss())
    return W, losses

def sgd_mom_galore(W_0, data, b, r, lr=0.1, beta=0.9, n_iter=10000, T=10):
    """GaLore-like SGD with momentum kept in a rank-``r`` projected space.

    Minimises ``((W @ data)**2).sum() / N`` where ``N`` is the number of
    columns of ``data``. Every ``T`` steps (starting at the first one) the
    projection ``U`` is reset to the first ``r`` left singular vectors of the
    current mini-batch gradient. The momentum buffer stores ``U.T @ grad``
    averages and is NOT transformed when ``U`` changes: its coordinates are
    simply reinterpreted in the new basis.

    Args:
        W_0: 2-D initial weight tensor of shape ``(m, n)``; cloned, so the
            caller's tensor is left untouched.
        data: 2-D tensor whose columns are the samples; it must have ``n``
            rows.
        b: mini-batch size (columns drawn without replacement, so it must not
            exceed the number of columns of ``data``).
        r: projection rank. The reduced SVD yields ``min(m, n)`` left
            singular vectors, so ``r`` should not exceed that.
        lr: step size.
        beta: momentum decay factor.
        n_iter: number of update steps.
        T: number of steps between projection refreshes.

    Returns:
        ``(W, losses)`` where ``W`` is the final weight tensor and ``losses``
        is a list of ``n_iter + 1`` floats: the full-data loss before the
        first step and after every step.
    """
    W = W_0.clone().detach().requires_grad_(True)
    # Batches index the columns of `data`, so the sampling range and the loss
    # normalisation must come from `data`, not from the shape of W.
    n_samples = data.size(1)
    U = None
    # Projected gradients U.T @ grad have shape (r, number of columns of W).
    mom = torch.zeros(r, W.size(1), dtype=W.dtype, device=W.device)

    def full_loss():
        # Monitoring only: no autograd graph is needed for this value.
        with torch.no_grad():
            return (((W @ data) ** 2).sum() / n_samples).item()

    losses = [full_loss()]
    for t in range(1, n_iter + 1):
        batch_idx = np.random.choice(n_samples, (b,), replace=False)
        loss = ((W @ data[:, batch_idx]) ** 2).sum() / b
        loss.backward()
        with torch.no_grad():
            # `t % T == 1` never holds for T == 1, hence the explicit check
            # that makes T == 1 refresh on every step.
            if t % T == 1 or T == 1:
                U, _, _ = torch.linalg.svd(W.grad, full_matrices=False)
                U = U[:, :r]
            mom = beta * mom + (1 - beta) * (U.T @ W.grad)
            # Lift the projected momentum back to the weight space; the
            # (1 - beta**t) factor compensates for the zero-initialised
            # moving average.
            W -= lr / (1 - beta**t) * U @ mom
            W.grad.zero_()
        losses.append(full_loss())
    return W, losses

def sgd_mom_proper_galore(W_0, data, b, r, lr=0.1, beta=0.9, n_iter=10000, T=10):
    """GaLore-like SGD with momentum that is re-projected on basis refresh.

    Minimises ``((W @ data)**2).sum() / N`` where ``N`` is the number of
    columns of ``data``. Every ``T`` steps (starting at the first one) the
    projection ``U`` is reset to the first ``r`` left singular vectors of the
    current mini-batch gradient. Unlike a plain basis swap, the momentum is
    first lifted to the weight space with the old ``U`` and then projected
    onto the new one. A scalar ``mass`` tracks the accumulated averaging
    weight (``beta * mass + (1 - beta)`` per step) and is rescaled by the
    ratio of momentum norms after/before the re-projection; the update is
    the momentum divided by ``mass``.

    Args:
        W_0: 2-D initial weight tensor of shape ``(m, n)``; cloned, so the
            caller's tensor is left untouched.
        data: 2-D tensor whose columns are the samples; it must have ``n``
            rows.
        b: mini-batch size (columns drawn without replacement, so it must not
            exceed the number of columns of ``data``).
        r: projection rank. The reduced SVD yields ``min(m, n)`` left
            singular vectors, so ``r`` should not exceed that.
        lr: step size.
        beta: momentum decay factor.
        n_iter: number of update steps.
        T: number of steps between projection refreshes.

    Returns:
        ``(W, losses)`` where ``W`` is the final weight tensor and ``losses``
        is a list of ``n_iter + 1`` floats: the full-data loss before the
        first step and after every step.
    """
    W = W_0.clone().detach().requires_grad_(True)
    # Batches index the columns of `data`, so the sampling range and the loss
    # normalisation must come from `data`, not from the shape of W.
    n_samples = data.size(1)
    U = None
    mass = 0
    # Projected gradients U.T @ grad have shape (r, number of columns of W).
    mom = torch.zeros(r, W.size(1), dtype=W.dtype, device=W.device)

    def full_loss():
        # Monitoring only: no autograd graph is needed for this value.
        with torch.no_grad():
            return (((W @ data) ** 2).sum() / n_samples).item()

    losses = [full_loss()]
    for t in range(1, n_iter + 1):
        batch_i = np.random.choice(n_samples, (b,), replace=False)
        loss = ((W @ data[:, batch_i]) ** 2).sum() / b
        loss.backward()
        with torch.no_grad():
            # `t % T == 1` never holds for T == 1, hence the explicit check
            # that makes T == 1 refresh on every step.
            if t % T == 1 or T == 1:
                if t > 1:
                    old_norm = mom.norm()
                    mom = U @ mom  # lift to the original weight space
                U, _, _ = torch.linalg.svd(W.grad, full_matrices=False)
                U = U[:, :r]
                if t > 1:
                    mom = U.T @ mom  # project onto the new basis
                    # Recalibrate mass by the fraction of momentum norm that
                    # survived the re-projection. A zero momentum would give
                    # 0/0 = NaN and poison W, so leave mass as is then.
                    if old_norm > 0:
                        mass = mass * (mom.norm() / old_norm)
            mom = beta * mom + (1 - beta) * (U.T @ W.grad)
            mass = beta * mass + (1 - beta) * 1
            mom_hat = lr * mom / mass
            W -= U @ mom_hat
            W.grad.zero_()
        losses.append(full_loss())
    return W, losses


In [None]:
import matplotlib.pyplot as plt

# Experiment settings.
b = 2           # mini-batch size
n_iter = 200    # update steps per run
T = 10          # steps between projection refreshes
n_reps = 5      # repetitions averaged per curve
ranks = [3, 6]  # projection ranks to compare
lr = 0.1

m, n = 10, 10
seed = 1
torch.manual_seed(seed)
np.random.seed(seed)
W_0 = torch.randn(m, n)  # initialize W randomly; shared by every run

# Identity data: the loss ((W @ data)**2).sum() / n is the squared Frobenius
# norm of W divided by n.
data = torch.eye(n)

fig, ax = plt.subplots(figsize=(20, 15))
steps = np.arange(n_iter + 1)

for r in ranks:
    losses_sgd_galore = []
    losses_sgd_proper_galore = []

    # The two optimisers are run alternately so that they consume the global
    # NumPy random stream in an interleaved order.
    for _ in range(n_reps):
        _, loss = sgd_mom_galore(W_0, data, b, r, n_iter=n_iter, T=T, lr=lr)
        losses_sgd_galore.append(loss)
        _, loss = sgd_mom_proper_galore(W_0, data, b, r, n_iter=n_iter, T=T, lr=lr)
        losses_sgd_proper_galore.append(loss)

    # Rows are repetitions, columns are iterations.
    losses_sgd_galore = np.array(losses_sgd_galore)
    losses_sgd_proper_galore = np.array(losses_sgd_proper_galore)

    # NOTE(review): the second label starts with a space; kept as is because
    # it is rendered text, but it looks unintentional.
    curves = [
        (f'GaLore-like SGDM, rank {r}', losses_sgd_galore),
        (f' GaLore-like SGDM with Momentum re-projection, rank {r}', losses_sgd_proper_galore),
    ]
    # Plot the mean curve with a +/- one standard deviation band.
    for curve_label, curve_runs in curves:
        curve_mean = curve_runs.mean(0)
        curve_std = curve_runs.std(0)
        ax.plot(curve_mean, label=curve_label)
        ax.fill_between(steps, curve_mean - curve_std, curve_mean + curve_std, alpha=0.2)

ax.legend(fontsize=24)
ax.set_xlabel('Iteration', fontsize=32)
ax.set_ylabel('Loss', fontsize=32)
ax.set_yscale('log')

fig.tight_layout()
fig.savefig('reproj_toy.pdf', bbox_inches='tight')