In [None]:
#@title Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

In [None]:
# @title Imports
from scipy.optimize import linear_sum_assignment
from functools import partial
import jax
from jax import grad, jit, make_jaxpr, vmap, random, pmap
import jax.numpy as jnp
import numpy as np
from scipy.spatial.distance import cosine as cosine_distance
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import tensorflow_datasets as tfds

import time

from jax.example_libraries import optimizers

import tensorflow as tf
from six.moves import cPickle as pickle #for performance
import optax as optix
import dill

gfile = tf.io.gfile

In [None]:
#@title Optimizers

def nesterov(eta):
  """Returns an optimizer using a Nesterov momentum trace (decay 0.9) scaled by -eta."""
  momentum = optix.trace(decay=0.9, nesterov=True)
  step = optix.scale(-eta)
  return optix.chain(momentum, step)
  
def sgd(eta):
  """Returns a plain SGD optimizer: a zero-decay trace followed by scaling by -eta."""
  no_momentum = optix.trace(decay=0.0, nesterov=False)
  step = optix.scale(-eta)
  return optix.chain(no_momentum, step)

In [None]:
#@title Alternative Updates

@partial(jax.jit, static_argnums=(5,6))
def gha_update(vi, weights, V, opt_state, X, eta=1e-4, opt=sgd):
    """Applies one optimizer step to a single vector using a GHA-style direction.

    The direction is
      sum_j a_j * X^T X v_j  +  sum_j b_j * (vi^T X^T X v_j) * v_j
    where a_j is 1 when weights[j] == 1 (else 0) and b_j is -1 when
    weights[j] is 1 or -1 (else 0), for weights taking values in {-1, 0, 1}.

    Args:
      vi: (d,) vector being updated.
      weights: (k,) weights selecting which rows of V contribute.
      V: (k, d) matrix, vectors on rows.
      opt_state: optimizer state compatible with `opt(eta)`.
      X: (N, d) minibatch.
      eta: step size (static under jit).
      opt: callable mapping a step size to an optimizer (static under jit).

    Returns:
      Tuple (vi_new, opt_state). vi_new is divided by max(||vi_new||, 1), so it
      is only shrunk when its norm exceeds 1.
    """
    Xvi = jnp.dot(X, vi)
    XV = jnp.transpose(jnp.dot(X, jnp.transpose(V)))  # Xvj on row j
    num_vectors = V.shape[0]

    # X^T X v_j for every j; only rows with weight 1 survive the weighting below.
    gs_ii = jnp.array([jnp.dot(jnp.transpose(X), jnp.dot(X, V[j]))
                       for j in range(num_vectors)])
    # (vi^T X^T X v_j) * v_j for every j.
    gs_ij = jnp.array([jnp.dot(Xvi, XV[j]) * V[j] for j in range(num_vectors)])

    weights_ii = (jnp.sign(weights - 0.5) + 1.) / 2.  # maps 1 to 1 else to 0
    weights_ij = (jnp.sign(weights + 0.5) - 1.) / 2.  # maps -1 to -1 else to 0
    weights_ij -= weights_ii  # the weight-1 entry is also subtracted

    grads = (jnp.dot(jnp.transpose(gs_ii), weights_ii) +
             jnp.dot(jnp.transpose(gs_ij), weights_ij))

    # The gradient is negated before it is handed to the optimizer; the
    # optimizers defined in this notebook then scale their input by -eta.
    updates, opt_state = opt(eta).update(-grads, opt_state)
    vi_new = optix.apply_updates(vi, updates)

    # jnp.maximum is equivalent to clip with only a lower bound and avoids the
    # `a_min` keyword, which recent JAX releases deprecate.
    vi_new /= jnp.maximum(jnp.linalg.norm(vi_new), 1.)
    return vi_new, opt_state

@partial(jax.jit, static_argnums=(5,6))
def ojas_deflation_update(vi, weights, V, opt_state, X, eta=1e-4, opt=sgd):
    """Applies one optimizer step to a single vector using Oja's rule with deflation.

    The direction is
      sum_j a_j * X^T X v_j  +  sum_j b_j * (vi^T v_j) * v_j
    where a_j is 1 when weights[j] == 1 (else 0) and b_j is -1 when
    weights[j] == -1 (else 0), for weights taking values in {-1, 0, 1}.

    Args:
      vi: (d,) vector being updated.
      weights: (k,) weights selecting which rows of V contribute.
      V: (k, d) matrix, vectors on rows.
      opt_state: optimizer state compatible with `opt(eta)`.
      X: (N, d) minibatch.
      eta: step size (static under jit).
      opt: callable mapping a step size to an optimizer (static under jit).

    Returns:
      Tuple (vi_new, opt_state). vi_new is divided by max(||vi_new||, 1), so it
      is only shrunk when its norm exceeds 1.
    """
    XV = jnp.transpose(jnp.dot(X, jnp.transpose(V)))  # Xvj on row j
    num_vectors = V.shape[0]

    # X^T (X v_j) for every j; only rows with weight 1 survive the weighting below.
    gs_ii = jnp.array([jnp.dot(jnp.transpose(X), XV[j])
                       for j in range(num_vectors)])
    # (vi^T v_j) * v_j for every j.
    gs_ij = jnp.array([jnp.dot(vi, V[j]) * V[j] for j in range(num_vectors)])

    weights_ii = (jnp.sign(weights - 0.5) + 1.) / 2.  # maps 1 to 1 else to 0
    weights_ij = (jnp.sign(weights + 0.5) - 1.) / 2.  # maps -1 to -1 else to 0

    grads = (jnp.dot(jnp.transpose(gs_ii), weights_ii) +
             jnp.dot(jnp.transpose(gs_ij), weights_ij))

    # The gradient is negated before it is handed to the optimizer; the
    # optimizers defined in this notebook then scale their input by -eta.
    updates, opt_state = opt(eta).update(-grads, opt_state)
    vi_new = optix.apply_updates(vi, updates)

    # jnp.maximum is equivalent to clip with only a lower bound and avoids the
    # `a_min` keyword, which recent JAX releases deprecate.
    vi_new /= jnp.maximum(jnp.linalg.norm(vi_new), 1.)
    return vi_new, opt_state

@partial(jit, static_argnums=(3,4))
def matrix_krasulinas_update(V, opt_state, X, eta=1e-4, opt=sgd):
    """One matrix-Krasulina step followed by QR re-orthonormalisation.

    V holds vectors on rows. The direction X^T X V^T has its component
    V^T V (X^T X V^T) removed before being passed (negated and transposed)
    to the optimizer. The result is orthonormalised with a QR decomposition,
    and each vector's sign is chosen so the matching diagonal entry of R is
    non-negative (a zero diagonal entry maps to +1).

    Returns:
      Tuple (V_new, opt_state) with V_new holding vectors on rows.
    """
    Vt = jnp.transpose(V)
    direction = jnp.dot(jnp.transpose(X), jnp.dot(X, Vt))
    direction = direction - jnp.dot(Vt, jnp.dot(V, direction))

    # This computes and applies updates with optix and updates the opt_state
    updates, opt_state = opt(eta).update(-jnp.transpose(direction), opt_state)
    V_stepped = optix.apply_updates(V, updates)

    Q, R = jnp.linalg.qr(jnp.transpose(V_stepped))
    column_signs = jnp.sign(jnp.sign(jnp.diag(R)) + .5)
    return jnp.transpose(Q * column_signs), opt_state

@partial(jit, static_argnums=(3,4))
def ojas_update(V, opt_state, X, eta=1e-4, opt=sgd):
    """One step of Oja's algorithm followed by QR re-orthonormalisation.

    Args:
      V: (k, d) matrix, vectors on rows (k is the number of eigenvectors).
      opt_state: optimizer state compatible with `opt(eta)`.
      X: (batch_size, d) minibatch (for MNIST, d is 784).
      eta: step size (static under jit).
      opt: callable mapping a step size to an optimizer (static under jit).

    Returns:
      Tuple (V_new, opt_state) with V_new holding orthonormal vectors on rows.
    """
    direction = jnp.dot(jnp.transpose(X), jnp.dot(X, jnp.transpose(V)))

    # This computes and applies updates with optix and updates the opt_state
    updates, opt_state = opt(eta).update(-jnp.transpose(direction), opt_state)
    V_stepped = optix.apply_updates(V, updates)

    # Sign fix: make the diagonal of R non-negative (zero maps to +1).
    Q, R = jnp.linalg.qr(jnp.transpose(V_stepped))
    column_signs = jnp.sign(jnp.sign(jnp.diag(R)) + .5)
    return jnp.transpose(Q * column_signs), opt_state

@jit
def sherman_morrison_woodbury(Apinv, u, v):
  """Rank-one update of an inverse: Apinv - Apinv u v^T Apinv / (1 + v^T Apinv u).

  u and v should be 1-d vectors.
  """
  rank_one = jnp.outer(u, v)
  correction = jnp.dot(Apinv, jnp.dot(rank_one, Apinv))
  denominator = 1 + jnp.dot(v, jnp.dot(Apinv, u))
  return Apinv - correction / denominator

@partial(jit, static_argnums=(3,4))
def implicit_matrix_krasulinas_update(V, Vpinv, X, eta=1e-4, opt=sgd):
    """Implicit matrix-Krasulina step that also maintains a pseudo-inverse.

    Only works with mb_size = 1: X is a single sample of shape (d,).
    V holds vectors on rows (k, d). `opt` is accepted for signature
    compatibility with the other updates and is ignored.

    Returns:
      Tuple (V_new, Vpinv_new).
    """
    del opt
    basis = jnp.transpose(V)            # (d, k)
    basis_pinv = jnp.transpose(Vpinv)

    coeffs = jnp.dot(basis_pinv, X)     # (k,)
    residual = jnp.dot(basis, coeffs) - X  # (d,)
    step = eta / (1 + eta * jnp.sum(coeffs**2.))  # scalar

    basis_delta = -step * jnp.outer(residual, coeffs)  # (d, k)
    V_new = V + jnp.transpose(basis_delta)

    # Rank-one update of the pseudo-inverse matching the rank-one change
    # -step * outer(residual, coeffs) applied to the basis.
    basis_pinv_new = sherman_morrison_woodbury(
        basis_pinv, -step * residual, coeffs)

    return V_new, jnp.transpose(basis_pinv_new)

In [None]:
#@title Data Class Util

class Data(object):
  """Minibatch stream plus reference PCA statistics for a dataset.

  `ds` is either a NumPy array of samples or a dataset object that supports
  `.shuffle` / `.batch` and is iterated with `tfds.as_numpy` (elements are
  dicts with an 'image' entry). Setting `self.ds = None` makes the generator
  read from the cached in-memory matrix `self.X` instead.

  Attributes set by the constructor: mean, std (normalisation applied by
  reshape_data), Sigma, Vh (principal components on rows, descending order),
  Cov, dims, num_samples, X (only on the SVD path), generator.
  """

  def __init__(self, ds, mb_size, k, num_samples=None,
               one_device=False, shuffle_data=True, center=False, unit_var=False,
               svd_by_evd=False, data_stats_dict=None):
    """Builds the minibatch generator and computes the principal components.

    Args:
      ds: NumPy array or batchable dataset (see class docstring).
      mb_size: minibatch size.
      k: number of vectors; must be divisible by the device count unless
        one_device is True.
      num_samples: number of samples; required for non-array datasets and
        for svd_by_evd.
      one_device: if True batches have shape (mb_size, ...), otherwise
        (num_devices, k_per_device, mb_size, ...).
      shuffle_data: whether the generator shuffles.
      center: subtract the per-feature mean.
      unit_var: divide by the per-feature standard deviation (SVD path only).
      svd_by_evd: stream the data and eigendecompose the accumulated
        second-moment matrix instead of running an SVD on the full matrix.
      data_stats_dict: unused.
    """
    self.mb_size = mb_size
    self.ds = ds

    self.k = k
    self.shuffle_data = shuffle_data
    self.one_device = one_device

    self.num_devices = jax.local_device_count()
    self.batch_dims = self.get_batch_dims(k, self.one_device)

    self.X = None
    # reshape_data reads both statistics, and the streaming branch below calls
    # it before the real values are known, so start from the identity transform.
    self.mean = 0
    self.std = 1
    self.generator = self.make_generator(self.ds, self.batch_dims,
                                         self.shuffle_data)

    if svd_by_evd:
      if num_samples is None:
        raise ValueError("num_samples must be passed to Data constructor "
                         "if svd_by_evd=True.")
      batch_size = int(np.prod(self.batch_dims))
      num_batches, remainder = divmod(num_samples, batch_size)
      if num_batches == 0:
        raise ValueError("num_samples is smaller than one full batch.")
      if remainder > 0:
        print("Warning: mb_size does not evenly divide dataset")
        print("num_samples, num_batches, remainder",
              num_samples, num_batches, remainder)
      num_used = num_batches * batch_size  # samples actually visited
      report_every = max(1, num_batches // 10)  # avoid modulo by zero
      mean = 0
      if center:
        for i in range(num_batches):
          if i % report_every == 0:
            print("{:d} / {:d}".format(i, num_batches), flush=True)
          xi = jnp.reshape(self.reshape_data(next(self.generator)),
                           (batch_size, -1))
          # Per-feature mean: reduce over the batch axis before accumulating.
          mean += jnp.sum(xi, axis=0) / float(num_used)
        # The first pass consumed the stream; restart it for the second pass.
        self.generator = self.make_generator(self.ds, self.batch_dims,
                                             self.shuffle_data)
      Cov = 0
      for i in range(num_batches):
        if i % report_every == 0:
          print("{:d} / {:d}".format(i, num_batches), flush=True)
        xi = jnp.reshape(self.reshape_data(next(self.generator)),
                         (batch_size, -1))
        Cov += jnp.dot(jnp.transpose(xi - mean), xi - mean)
      self.mean = mean
      dims = Cov.shape[0]
      print("Computing principal components...", flush=True)
      Sigma2, V = np.linalg.eigh(np.array(Cov))  # returns in ascending order
      # change to descending order
      Sigma2 = Sigma2[::-1]
      V = V[:, ::-1]
      # Round-off can make tiny eigenvalues negative; clamp before the sqrt.
      Sigma = np.sqrt(np.maximum(np.real(Sigma2), 0.))
      Vh = V.T
      # Match the SVD branch, which stores X^T X divided by the sample count.
      Cov = Cov / float(num_used)
    else:
      X = self.load_from_ds(ds, num_samples).astype(np.float32)
      self.mean = np.zeros(X.shape[1])
      if center:
        self.mean = np.mean(X, axis=0)
        X -= self.mean
      self.std = np.ones(X.shape[1])
      if unit_var:
        std = np.std(X, axis=0)
        # Constant features have zero std; leave them unscaled instead of
        # producing NaN/inf.
        self.std = np.where(std > 0, std, 1.)
        X /= self.std

      num_samples = X.shape[0]
      if len(X.shape[1:]) > 1:
        X = np.reshape(X, (num_samples, -1))
      dims = X.shape[1]
      print("Computing principal components...")
      _, Sigma, Vh = np.linalg.svd(X, full_matrices=False,
                                  compute_uv=True)
      Cov = np.dot(X.T, X) / X.shape[0]
      self.X = X

    self.num_samples = num_samples
    self.dims = dims
    self.Sigma = Sigma
    self.Vh = Vh
    self.Cov = Cov

  def load_from_ds(self, ds, num_samples):
    """Returns the dataset as a (num_samples, features) array.

    Arrays are flattened as-is; other datasets are read in one batch of
    num_samples elements and their 'image' entry is divided by 255.
    """
    if isinstance(ds, np.ndarray):
      return np.reshape(ds, (ds.shape[0], -1))
    else:
      if num_samples is None:
        err_msg = ("num_samples must be passed to Data constructor "
                   "if using tfds.")
        raise ValueError(err_msg)
      X = next(self.make_generator(ds, [num_samples], shuffle_data=False))
      X = np.reshape(X['image'], (num_samples, -1)) / 255.
      return X

  def get_batch_dims(self, k, one_device=False):
    """Returns the leading batch dimensions and sets self.k_per_device."""
    if one_device:
      self.k_per_device = k
      return [self.mb_size]
    else:
      self.k_per_device = k // self.num_devices
      err_msg = f"specify a k that num_devices={self.num_devices} divides evenly"
      if self.num_devices * self.k_per_device != k:
        raise ValueError(err_msg)
      return [self.num_devices, self.k_per_device, self.mb_size]

  def set_multi_device(self, multi):
    """Switches between single- and multi-device batching, rebuilding the generator."""
    if self.one_device == multi:
      self.one_device = not self.one_device
      self.batch_dims = self.get_batch_dims(self.k, self.one_device)
      self.generator = self.make_generator(self.ds, self.batch_dims,
                                           self.shuffle_data)

  def set_k(self, k):
    """Changes the number of vectors and rebuilds the generator."""
    self.k = k
    self.batch_dims = self.get_batch_dims(k, self.one_device)
    self.generator = self.make_generator(self.ds, self.batch_dims,
                                         self.shuffle_data)

  def set_mb_size(self, mb_size):
    """Changes the minibatch size and rebuilds the generator."""
    self.mb_size = mb_size
    self.batch_dims = self.get_batch_dims(self.k, self.one_device)
    self.generator = self.make_generator(self.ds, self.batch_dims,
                                         self.shuffle_data)

  def make_generator(self, ds, batch_dims, shuffle_data, seed=None):
    """Loads the dataset as a generator of batches.

    In-memory sources (ds is None -> self.X, or ds is a NumPy array) yield one
    pass of arrays shaped (*batch_dims, features), dropping the incomplete
    tail. Other datasets are shuffled/batched through their own API.
    """
    if ds is None or isinstance(ds, np.ndarray):
      if ds is None:
        if self.X is None:
          raise ValueError("No in-memory data available: ds is None and "
                           "self.X was never populated.")
        source = self.X
      else:
        source = np.reshape(ds, (ds.shape[0], -1))
      inds = np.arange(source.shape[0])
      if shuffle_data:
        # Local RandomState: same permutation as seeding the global RNG with
        # `seed`, without clobbering the global NumPy random state.
        np.random.RandomState(seed).shuffle(inds)
      # A batch holds prod(batch_dims) samples, not just mb_size.
      batch_size = int(np.prod(batch_dims))
      num_batches = source.shape[0] // batch_size
      X = np.reshape(source[inds[:num_batches * batch_size]],
                     [num_batches] + list(batch_dims) + [-1])
      # Iterate over the leading axis so each item is (*batch_dims, features).
      yield from X
    else:
      # Plain Python int: the dataset API should not receive a jax array.
      total_batch_size = int(np.prod(batch_dims))
      if shuffle_data:
          ds = ds.shuffle(10 * total_batch_size, seed=seed)
      for batch_size in reversed(batch_dims):
        ds = ds.batch(batch_size)
      yield from tfds.as_numpy(ds)

  def reset_generator(self, seed=12345):
    """Rebuilds the generator with a fixed shuffle seed."""
    self.generator = self.make_generator(self.ds, self.batch_dims, self.shuffle_data, seed)

  def reshape_data(self, data_input):
    """Flattens per-sample dimensions and applies (x - mean) / std.

    Expected data_input shape = (num_devices, k_per_device, mb_size,
    *data_shape), or (mb_size, *data_shape) when one_device is True. Dict
    inputs are read from their 'image' entry and divided by 255.
    """
    # Only mappings carry an 'image' entry; `in` on an ndarray would compare
    # element-wise instead.
    if isinstance(data_input, dict) and 'image' in data_input:
      data_input = jnp.array(data_input['image']) / 255.
    data_input = data_input.astype(np.float32)
    expected_dims = self.one_device * 2 + (1 - self.one_device) * 4
    if len(data_input.shape) > expected_dims:
      dim_size = np.prod(data_input.shape[expected_dims - 1:])
      shape = list(data_input.shape[:expected_dims - 1]) + [dim_size]
      return (jnp.reshape(data_input, shape) - self.mean) / self.std
    else:
      return (jnp.array(data_input) - self.mean) / self.std

In [None]:
#@title Metrics Class Util

def eval_to_class(ylim=None, bucket=None):
  """Decorator factory wrapping an evaluator function in a callable object.

  The wrapper exposes the function name through its `__str__` attribute (a
  plain string, read by Metrics) together with the plotting hints `ylim` and
  `bucket`, and forwards calls as `evaluator(V, **kwargs)`.
  """
  def inner(evaluator):
    class Eval(object):
      """Callable holder for one evaluator and its plotting hints."""

      def __init__(self, fn, y_limits, bucket_size):
        self.evaluator = fn
        self.__str__ = fn.__name__
        self.ylim = y_limits
        self.bucket = bucket_size

      def __call__(self, V, **kwargs):
        return self.evaluator(V, **kwargs)

    wrapped = Eval(evaluator, ylim, bucket)
    return wrapped
  return inner

class Metrics(object):
  """Runs a set of evaluators on the current vectors and records the results.

  Evaluators are keyed by their `__str__` attribute. All keyword arguments
  given to the constructor are forwarded to every evaluator call; when a
  non-None "X" and a "Cov" are supplied, reference losses (optimal and top-1)
  are precomputed from "Ustar" and "Vh" and forwarded as well.
  """

  def __init__(self, evaluators, **kwargs):
    names = [ev.__str__ for ev in evaluators]
    self.evaluators = {name: ev for name, ev in zip(names, evaluators)}
    self.match_eigvecs = kwargs["match_eigvecs"]  # not used
    self.fixed_kwargs = kwargs
    self.V_last = None

    has_data = kwargs.get("X") is not None
    if has_data and ("Cov" in kwargs):
      X, Ustar, Cov = kwargs["X"], kwargs["Ustar"], kwargs["Cov"]
      v1 = kwargs["Vh"][0]
      Ustar1 = jnp.outer(v1, v1)  # projector onto the top component
      reference = {
          "min_recon_error": jnp.linalg.norm(X - jnp.dot(X, Ustar)),
          "top1_recon_error": jnp.linalg.norm(X - jnp.dot(X, Ustar1)),
          "min_compression_loss": jnp.trace(
              jnp.dot(jnp.eye(Ustar.shape[0]) - Ustar, Cov)),
          "top1_compression_loss": jnp.trace(
              jnp.dot(jnp.eye(Ustar1.shape[0]) - Ustar1, Cov)),
          "ev_error_scale": jnp.linalg.norm(Ustar1 - Ustar),
      }
      self.fixed_kwargs.update(reference)

    self.progress = []
    self.record = {name: [] for name in names}

  def evaluate(self, iteration, epoch, time, V):
    """Evaluates every registered metric on V (vectors on rows) and stores it."""
    self.progress.append((iteration, epoch, time))
    # P = V^T pinv(V^T) is passed to the evaluators as the projector for V.
    Vt_pinv = jnp.array(np.linalg.pinv(np.array(V).T))
    P = jnp.dot(V.transpose(), Vt_pinv)
    for name, evaluator in self.evaluators.items():
      self.record[name].append(
          evaluator(V, P=P, V_last=self.V_last, **self.fixed_kwargs))
    # Only track the previous V when the drift metric is registered.
    if 'norm_of_drift' in self.evaluators:
      self.V_last = V

  def update(self, key, value):
    """Appends a value to the record under `key`, creating the list if needed."""
    self.record.setdefault(key, []).append(value)

def sine_squared(u, v):
  """Returns 1 - cos^2 of the angle between u and v (cosine via SciPy's cosine distance)."""
  cos_uv = 1 - cosine_distance(u, v)
  sin2_uv = 1. - cos_uv**2.
  return sin2_uv

@eval_to_class()
def compression_loss(V, P, Cov, min_compression_loss,
                     top1_compression_loss, **kwargs):
  """Returns trace((I - P) Cov), shifted by the minimum loss and scaled by (top-1 loss - minimum loss)."""
  del kwargs
  residual_projector = jnp.eye(P.shape[0]) - P
  loss = jnp.trace(jnp.dot(residual_projector, Cov))
  return (loss - min_compression_loss) / (
      top1_compression_loss - min_compression_loss)

@eval_to_class()
def recon_error(V, P, X, min_recon_error, top1_recon_error,
                **kwargs):
  """Returns ||X - X P||, shifted by the minimum error and scaled by (top-1 error - minimum error)."""
  del kwargs
  err = jnp.linalg.norm(X - jnp.dot(X, P))
  return (err - min_recon_error) / (top1_recon_error - min_recon_error)

@eval_to_class()
def ev_error(V, Ustar, P, ev_error_scale, **kwargs):
  """Returns ||P - Ustar|| divided by ev_error_scale."""
  del kwargs
  projector_gap = jnp.linalg.norm(P - Ustar)
  return projector_gap / ev_error_scale

@eval_to_class(ylim=[0,1])
def ev_error_to_v1(V, Vh, **kwargs):
  """Returns sin^2 of the angle between each row of V and the first row of Vh."""
  del kwargs
  top_component = Vh[0]
  per_vector = [sine_squared(row, top_component) for row in V]
  return jnp.array(per_vector)

def _individual_ev_error_degrees(V, Vh, match_eigvecs):
  """Angle in degrees between each learned vector and its reference vector.

  With match_eigvecs, rows of V are assigned to rows of Vh by solving a linear
  assignment problem on the sin^2 cost matrix; otherwise rows are paired in
  order.
  """
  if match_eigvecs:
    errs = np.zeros((len(V), len(Vh)))
    for i, vi in enumerate(V):
      for j, vh in enumerate(Vh):
        errs[i, j] = sine_squared(vi, vh)
    row_ind, col_ind = linear_sum_assignment(errs)
    eigvec_err = errs[row_ind, col_ind]
  else:
    eigvec_err = jnp.array([sine_squared(vi, vh) for vi, vh in zip(V, Vh)])
  # Round-off can push 1 - cos^2 slightly outside [0, 1], which would turn
  # arcsin(sqrt(.)) into NaN; clamp first.
  eigvec_err = jnp.clip(eigvec_err, 0., 1.)
  return jnp.arcsin(jnp.sqrt(eigvec_err)) * 180. / np.pi

@eval_to_class(ylim=[0,90], bucket=4)
def individual_ev_error_bucket(V, Vh, match_eigvecs=False, **kwargs):
  """Per-vector angular error in degrees; registered with bucket=4."""
  del kwargs
  return _individual_ev_error_degrees(V, Vh, match_eigvecs)

@eval_to_class(ylim=[0,90], bucket=None)
def individual_ev_error_deg(V, Vh, match_eigvecs=False, **kwargs):
  """Per-vector angular error in degrees; registered without bucketing."""
  del kwargs
  return _individual_ev_error_degrees(V, Vh, match_eigvecs)

@eval_to_class()
def longest_streak_of_correct_eigvecs(V, Vh, **kwargs):
  """Counts leading vectors whose angle to the matching row of Vh is at most pi/8.

  Position i counts only if all positions up to and including i are learned.
  """
  del kwargs
  sin2 = jnp.array([sine_squared(vi, vh) for vi, vh in zip(V, Vh)])
  angles = jnp.arcsin(jnp.sqrt(sin2))
  learned = (angles <= np.pi / 8)
  # cumsum equals the position index exactly while the streak is unbroken.
  in_streak = np.cumsum(learned) >= np.arange(1, len(learned) + 1)
  return np.sum(in_streak)

@eval_to_class()
def neurips_loss(V, Ustar, P, **kwargs):
  """Returns 1 - trace(Ustar P) / k, where k is the number of rows of V."""
  del kwargs
  num_vectors = V.shape[0]
  overlap = np.trace(Ustar.dot(P))
  return 1 - overlap / float(num_vectors)

@eval_to_class()
def ortho_error(V, **kwargs):
  """Returns ||V V^T - I|| divided by sqrt(k (k - 1)), with k the number of rows of V."""
  del kwargs
  k = V.shape[0]
  gram_gap = jnp.dot(V, V.transpose()) - jnp.eye(V.shape[0])
  normaliser = np.sqrt(k * float(k - 1.))
  return jnp.linalg.norm(gram_gap) / normaliser

@eval_to_class()
def norm_of_drift(V, V_last, **kwargs):
  """Returns the per-row norm of V - V_last, or NaNs when there is no previous V."""
  del kwargs
  if V_last is not None:
    return jnp.linalg.norm(V - V_last, axis=1)
  return np.nan * jnp.ones(V.shape[0])

# Full metric set.
evaluators = [
    recon_error,
    ev_error,
    longest_streak_of_correct_eigvecs,
    individual_ev_error_bucket,
    individual_ev_error_deg,
    neurips_loss,
    compression_loss,
    ortho_error,
    ev_error_to_v1,
]
# Reduced metric set; none of these take the data matrix X or Cov as arguments.
evaluators_short = [
    longest_streak_of_correct_eigvecs,
    individual_ev_error_bucket,
    individual_ev_error_deg,
    neurips_loss,
    norm_of_drift,
]

In [None]:
#@title QR Experiment Util

def qr_experiment(V0, k, lr, num_epochs, update,
                  optimizer, evaluators, data, match_eigvecs=False,
                  zero_mean=False, unit_variance=False):
  """Runs a QR-style update rule on minibatches and records metrics.

  Args:
    V0: accepted for interface compatibility; the initial vectors are always
      re-drawn from PRNGKey(1234) and row-normalised.
    k: number of vectors.
    lr: step size passed to `optimizer` and `update`.
    num_epochs: number of passes over data.num_samples samples.
    update: callable (V, opt_state, minibatch, lr, optimizer) ->
      (V, opt_state).
    optimizer: callable mapping a step size to an optimizer.
    evaluators: evaluators handed to Metrics.
    data: Data instance supplying minibatches and reference statistics.
    match_eigvecs: forwarded to Metrics.
    zero_mean: unused.
    unit_variance: unused.

  Returns:
    Tuple (V, metrics) with V of shape (k, dims).
  """
  Ustar = np.dot(data.Vh[:k].T, data.Vh[:k])
  metrics = Metrics(evaluators, X=data.X, Vh=data.Vh, Cov=data.Cov,
                    Ustar=Ustar, match_eigvecs=match_eigvecs)

  key = jax.random.PRNGKey(1234)
  V = jax.random.normal(key, (k, data.dims))
  V = V/jnp.linalg.norm(V, axis=1, keepdims=True)

  opt_state = optimizer(lr).init(V)

  epochs_to_iters = (data.num_samples / data.mb_size)
  num_iters = int(num_epochs * epochs_to_iters)
  # Evaluate roughly ten times per run; with fewer than ten iterations the
  # interval would be zero and the modulo below would raise.
  eval_every = max(1, num_iters // 10)

  start = time.time()

  for i in range(num_iters + 1):
    # retrieve minibatch
    minibatch = data.reshape_data(next(data.generator))

    if i % eval_every == 0:
      # evaluate current vectors
      epoch = i / epochs_to_iters
      print("# of epochs = {:.1f}, # of iterations = {:d}".format(epoch, i))

      metrics.evaluate(i, epoch, time.time() - start, V)

    # update vectors
    V, opt_state = update(V.reshape(k, -1), opt_state, minibatch, lr,
                          optimizer)

  return V.reshape((k, -1)), metrics

In [None]:
#@title Lissa Experiment Util

def lissa_experiment(V0, k, lr, num_epochs, update,
                  optimizer, evaluators, data, match_eigvecs=False,
                  zero_mean=False, unit_variance=False):
  """Runs a LiSSA-style update rule on minibatches and records metrics.

  Differs from qr_experiment in that the optimizer state is initialised on
  V.T and the update threads a PRNG key (started at PRNGKey(42)).

  Args:
    V0: accepted for interface compatibility; the initial vectors are always
      re-drawn from PRNGKey(1234) and row-normalised.
    k: number of vectors.
    lr: step size passed to `optimizer` and `update`.
    num_epochs: number of passes over data.num_samples samples.
    update: callable (V, opt_state, minibatch, key, lr, optimizer) ->
      (V, opt_state, key).
    optimizer: callable mapping a step size to an optimizer.
    evaluators: evaluators handed to Metrics.
    data: Data instance supplying minibatches and reference statistics.
    match_eigvecs: forwarded to Metrics.
    zero_mean: unused.
    unit_variance: unused.

  Returns:
    Tuple (V, metrics) with V of shape (k, dims).
  """
  Ustar = np.dot(data.Vh[:k].T, data.Vh[:k])
  metrics = Metrics(evaluators, X=data.X, Vh=data.Vh, Cov=data.Cov,
                    Ustar=Ustar, match_eigvecs=match_eigvecs)

  key = jax.random.PRNGKey(1234)
  V = jax.random.normal(key, (k, data.dims))
  V = V/jnp.linalg.norm(V, axis=1, keepdims=True)

  opt_state = optimizer(lr).init(V.T)

  epochs_to_iters = (data.num_samples / data.mb_size)
  num_iters = int(num_epochs * epochs_to_iters)
  # Evaluate roughly ten times per run; with fewer than ten iterations the
  # interval would be zero and the modulo below would raise.
  eval_every = max(1, num_iters // 10)

  start = time.time()
  key = jax.random.PRNGKey(42)

  for i in range(num_iters + 1):
    # retrieve minibatch, shape (mb_size, dims)
    minibatch = data.reshape_data(next(data.generator))

    if i % eval_every == 0:
      # evaluate current vectors
      epoch = i / epochs_to_iters
      print("# of epochs = {:.1f}, # of iterations = {:d}".format(epoch, i))

      metrics.evaluate(i, epoch, time.time() - start, V)

    # update vectors
    V, opt_state, key = update(V.reshape(k, -1), opt_state, minibatch, key, lr,
                          optimizer)

  return V.reshape((k, -1)), metrics


# @partial(jit, static_argnums=(4, 5, 6, 7))
# Optimizer instances keyed by (opt, eta). nabla_phi_analytical receives the
# optimizer as a static jit argument, so handing it a freshly built object on
# every step would force a recompilation each time.
_LISSA_OPTIMIZERS = {}


def lissa_update(V, opt_state, X, key, eta=1e-4, opt=sgd,
                estimator='lissa', p=0.1):
    """One LiSSA-based update of the vectors V from a minibatch X.

    Args:
      V: (k, d) matrix, vectors on rows (k is the number of eigenvectors).
      opt_state: optimizer state compatible with `opt(eta)`, initialised on
        an array shaped like V.T.
      X: (batch_size, d) minibatch (for MNIST, d is 784).
      key: PRNG key; a fresh key is returned for the next call.
      eta: step size.
      opt: callable mapping a step size to an optimizer.
      estimator: estimator name forwarded to nabla_phi_analytical.
      p: unused (a geometric number of iterations was previously sampled
        from it; the iteration count is currently fixed).

    Returns:
      Tuple (V_new, opt_state, key).
    """
    # Three-way split kept so the random stream matches earlier runs even
    # though the third key is no longer consumed.
    key, subkey, _ = jax.random.split(key, 3)
    num_lissa_iterations = 10
    _, d = V.shape

    try:
      optim = _LISSA_OPTIMIZERS[(opt, eta)]
    except KeyError:
      optim = _LISSA_OPTIMIZERS[(opt, eta)] = opt(eta)
    except TypeError:
      # Unhashable step size: fall back to building the optimizer directly.
      optim = opt(eta)

    Phi, opt_state, _ = nabla_phi_analytical(
        Phi=V.T,  # (d x k)
        Psi=X.T,  # (d x batch_size)
        key=subkey,
        optim=optim,
        opt_state=opt_state,
        estimator=estimator,
        alpha=0.9,
        use_l2_reg=False,
        reg_coeff=0.0,
        use_penalty=False,
        j=num_lissa_iterations,
        num_rows=d, # This is to be changed.
    )
    V_new = Phi.T
    return V_new, opt_state, key

In [None]:
#@title Lissa code

import optax
import functools

def matrix_estimator(Phi, num_rows, key):  # pylint: disable=invalid-name
  r"""Keeps a random subset of the rows of a matrix and zeroes the rest.

  `num_rows` row indices are drawn uniformly with replacement; the selected
  rows of Phi are returned unchanged and every other row is set to zero. No
  re-weighting by the sampling probability is applied here.

  Args:
    Phi: S times d array
    num_rows: int: number of row indices drawn
    key: prng key
  Returns:
    S times d array
  """
  num_states, _ = Phi.shape
  sampled = jax.random.randint(key, (num_rows,), 0, num_states)
  # states = jax.random.permutation(key, jnp.arange(S))[:num_rows]
  row_mask = jnp.zeros_like(Phi).at[sampled].set(1)
  return Phi * row_mask

@jax.jit
def _russian_roulette(Phi, states, coefficients, alpha):  # pylint: disable=invalid-name
  """Computes Russian roulette given fixed number of iterations.

  Args:
    Phi: S times d array
    states: 1-d integer array of sampled row indices, one per iteration
    coefficients: 1-d array with one weight per iteration
    alpha: scalar step size of the recursion

  Returns:
    d times d array: the coefficient-weighted sum of the per-iteration deltas.
  """
  S, d = Phi.shape  # pylint: disable=invalid-name
  I = jnp.eye(d)  # pylint: disable=invalid-name

  def _lissa_body(carry, state):
    # Recursion: L_new = alpha * I + (I - alpha * S * phi phi^T) @ L_old,
    # with phi the sampled row of Phi.
    lissa_j = alpha * I
    lissa_j += (
        I - alpha * S * jnp.einsum('i,j->ij', Phi[state], Phi[state])) @ carry
    return lissa_j, lissa_j

  lissa_init = alpha * I
  # lissa_estimates stacks the estimate after every step: (len(states), d, d).
  _, lissa_estimates = jax.lax.scan(_lissa_body, lissa_init, states)

  # The einsum contracts `n` over ALL sampled rows, so every estimate m is
  # multiplied by the sum of the sampled outer products.
  # NOTE(review): if each step was meant to pair with its own sampled row,
  # the pattern would be 'mi,mj,mjk->mik' — confirm the intent.
  deltas = I - S * jnp.einsum('ni,nj,mjk->mik', Phi[states], Phi[states],
                              lissa_estimates)
  deltas *= coefficients.reshape(-1, 1, 1)

  return jnp.sum(deltas, axis=0)


def russian_roulette(Phi, p, key, coeff_alpha):  # pylint: disable=invalid-name
  """Computes the Russian roulette estimator from a LISSA sequence.

  The number of iterations is drawn as ceil(log(U) / log(1 - p)) with U
  uniform, one row index is sampled per iteration, and iteration i (starting
  at 1) is weighted by alpha / (1 - p)**i, where alpha is coeff_alpha divided
  by the spectral norm of Phi^T Phi.

  Args:
    Phi: S times d array
    p: parameter of the bernoulli distribution
    key: prng key
    coeff_alpha: float
  Returns:
    array of shape d times d
  """
  num_states, _ = Phi.shape
  spectral_norm = jnp.linalg.norm(Phi.T @ Phi, ord=2)
  alpha = coeff_alpha * 1 / spectral_norm

  # Geometric number of iterations.
  key, iter_key = jax.random.split(key)
  uniform_draw = jax.random.uniform(iter_key)
  iterations = int(jnp.ceil(jnp.log(uniform_draw) / jnp.log1p(-p)))

  # One sampled row index per iteration.
  key, state_key = jax.random.split(key)
  states = jax.random.randint(state_key, (iterations,), 0, num_states)

  # Delta coefficients.
  exponents = jnp.arange(1, iterations + 1)
  coefficients = alpha / ((1 - p) ** exponents)

  return _russian_roulette(Phi, states, coefficients, alpha)

# static_argnums must only name existing positional parameters (indices 0-6
# here); an out-of-range index is rejected by current JAX releases.
@functools.partial(jax.jit, static_argnums=(1, 2, 4, 5, 6))
def lissa(Phi,  # pylint: disable=invalid-name
          j,
          num_rows,
          key,
          coeff_alpha,
          use_penalty=False,
          reg_coeff=0.0):
  """Computes the lissa estimator.

  Runs j steps of the recursion
    A <- alpha * I + (I - alpha * phi^T phi) @ A,   A_0 = alpha * I,
  where phi is a uniformly sampled row of Phi (with `reg_coeff * I` added to
  phi^T phi when use_penalty is set) and alpha is coeff_alpha divided by
  twice the largest squared norm among the sampled rows.

  Args:
    Phi: S times d array
    j: int, index of the lissa estimator (number of recursion steps)
    num_rows: int: number of rows used in estimators. NOTE: the value passed
      in is currently overridden with 1 inside the function.
    key: prng key
    coeff_alpha: float
    use_penalty: bool: whether to add "lambda * Id" term to features
    reg_coeff: float: coeff for reg

  Returns:
    d times d array
  """
  num_rows = 1  # one row is sampled per step, so the 1/num_rows factor is 1
  S, d = Phi.shape  # pylint: disable=invalid-name
  I = jnp.eye(d)  # pylint: disable=invalid-name

  def _neumann_series(carry, state):
    # `alpha` is bound below, before the scan invokes this body.
    A_j = alpha * I  # pylint: disable=invalid-name
    # pylint: disable=invalid-name
    if use_penalty:
      A_j += (I - alpha * (1 / num_rows) *
              (Phi[state, :].T @ Phi[state, :] + reg_coeff * I)) @ carry
    else:
      A_j += (I - alpha *
              (1 / num_rows) * Phi[state, :].T @ Phi[state, :]) @ carry
    return A_j, None

  _, subkey = jax.random.split(key)
  # Shape (j, 1): each scan step indexes a (1, d) row block of Phi.
  states = jax.random.randint(subkey, (j, 1), 0, S)
  norm = 2 * jnp.max(jnp.sum(jnp.square(Phi[states.reshape(j, )]), axis=1))
  alpha = coeff_alpha * 1 / norm
  lissa_init = alpha * I
  lisa_j, _ = jax.lax.scan(_neumann_series, lissa_init, states)
  return lisa_j

def least_square_estimator(Phi,  # pylint: disable=invalid-name
                           Psi,
                           num_rows,
                           j,
                           key,
                           estimator='lissa',
                           alpha=0.9,
                           use_penalty=False,
                           reg_coeff=0.0):  # pylint: disable=invalid-name
  r"""Computes a stochastic least squares estimate.

  $W^*_\Phi = (\Phi^T \Phi)^{-1} \Phi^T \Psi$

  The inverse covariance term is replaced by the LiSSA estimate and
  $\Phi^T \Psi$ by the average over `num_rows` sampled rows.

  Args:
    Phi: S times d array
    Psi: S times T array
    num_rows: int: number of rows used in estimators
    j: int: num of samples for lissa
    key: prng key
    estimator: str: only 'lissa' is implemented
    alpha: float: renormalize covariance term
    use_penalty: bool: whether to add "lambda * Id" term to features
    reg_coeff: float: coeff for reg

  Returns:
    array d times T

  Raises:
    ValueError: if `estimator` is not 'lissa'.
  """
  # Fail early with a clear message instead of hitting an unbound local later.
  if estimator != 'lissa':
    raise ValueError(
        "Unsupported estimator {!r}; only 'lissa' is implemented.".format(
            estimator))
  S, _ = Phi.shape  # pylint: disable=invalid-name
  key, subkey = jax.random.split(key)
  states = jax.random.randint(subkey, (num_rows,), 0, S)
  _, subkey = jax.random.split(key)
  cov_estim = lissa(Phi, j, num_rows, subkey, alpha, use_penalty, reg_coeff)
  return cov_estim @ Phi[states, :].T @ Psi[
      states, :] / num_rows  # we use the same samples here

# static_argnums must only name existing positional parameters (indices 0-11
# here); an out-of-range index is rejected by current JAX releases.
@functools.partial(jax.jit, static_argnums=(3, 5, 6, 7, 8, 9, 10, 11))
def nabla_phi_analytical(  # pylint: disable=invalid-name
    Phi,
    Psi,
    key,
    optim,
    opt_state,
    estimator,
    alpha,
    use_l2_reg,
    reg_coeff,
    use_penalty,
    j,
    num_rows=1):

  r"""Takes one optimizer step on a stochastic estimate of (\Phi W^*_\Phi - \Psi)(W^*_\Phi)^T.

  A single task (column of Psi) is sampled, Phi and that column are row-masked
  with the same key, and two independently keyed least squares estimates of
  W^* are combined into the gradient estimate.

  Args:
    Phi: S times d array
    Psi: S times T array
    key: prng key
    optim: optax optimizer (static under jit, so reuse the same object across
      calls to avoid recompilation)
    opt_state: optimizer state
    estimator: str: estimator name forwarded to least_square_estimator
    alpha: float: used to normalize covariance term
    use_l2_reg: bool: whether to use l2 reg
    reg_coeff: float: coeff for reg
    use_penalty: bool: whether to add "lambda * Id" term to features
    j: int: num of samples for lissa
    num_rows: int: number of rows used in estimators

  Returns:
    Tuple (Phi_new, opt_state, grads): the updated S times d array, the new
    optimizer state and the S times d gradient estimate.
  """
  key, subkey = jax.random.split(key)
  _, T = Psi.shape  # pylint: disable=invalid-name
  task = jax.random.randint(subkey, (1,), 0, T)
  key, subkey = jax.random.split(key)
  # Same subkey for both masks, so Phi and Psi keep the same rows.
  Phi_estim = matrix_estimator(Phi, num_rows, subkey)  # pylint: disable=invalid-name
  Psi_estim = matrix_estimator(Psi[:, task], num_rows, subkey)  # pylint: disable=invalid-name
  key, subkey = jax.random.split(key)
  least_square_estim_1 = least_square_estimator(Phi, Psi[:, task], num_rows, j,
                                                subkey, estimator, alpha,
                                                use_penalty, reg_coeff)

  key, subkey = jax.random.split(key)
  least_square_estim_2 = least_square_estimator(Phi, Psi[:, task], num_rows, j,
                                                subkey, estimator, alpha,
                                                use_penalty, reg_coeff)
  grads = (Phi_estim @ least_square_estim_1 -
           Psi_estim) @ least_square_estim_2.T
  if use_l2_reg:
    grads += Phi_estim * reg_coeff
  updates, opt_state = optim.update(grads, opt_state, Phi)
  # beta = 1 / (1 + 0.1 * epoch)
  beta = 1

  return optax.apply_updates(Phi, beta * updates), opt_state, grads

def train(Phi, Psi, num_epochs, learning_rate, key, estimator, alpha,  # pylint: disable=invalid-name
          optimizer, use_l2_reg, reg_coeff, use_penalty, j, num_rows):
  """Training function.

  Repeatedly applies nabla_phi_analytical, splitting a fresh PRNG key for
  every step.

  Args:
    Phi: S times d array, initial features
    Psi: S times T array
    num_epochs: int: number of update steps
    learning_rate: float
    key: prng key
    estimator: str: estimator name forwarded to nabla_phi_analytical
    alpha: float: used to normalize covariance term
    optimizer: str: 'sgd' or 'adam'
    use_l2_reg: bool: whether to use l2 reg
    reg_coeff: float: coeff for reg
    use_penalty: bool: whether to add "lambda * Id" term to features
    j: int: num of samples for lissa
    num_rows: int: number of rows used in estimators

  Returns:
    Tuple (Phis, grads): stacked iterates (including the initial Phi) and
    stacked gradient estimates.

  Raises:
    ValueError: if `optimizer` is neither 'sgd' nor 'adam'.
  """
  if optimizer not in ('sgd', 'adam'):
    raise ValueError(
        "Unknown optimizer {!r}; expected 'sgd' or 'adam'.".format(optimizer))
  if optimizer == 'sgd':
    optim = optax.sgd(learning_rate)
  else:
    optim = optax.adam(learning_rate)
  opt_state = optim.init(Phi)

  # tqdm is not imported by the cells shown in this notebook; use it when the
  # session provides it and fall back to a plain loop otherwise.
  progress = globals().get('tqdm', iter)

  Phis = [Phi]  # pylint: disable=invalid-name
  grads = []
  for _ in progress(range(num_epochs)):
    key, subkey = jax.random.split(key)
    Phi, opt_state, grad = nabla_phi_analytical(
        Phi, Psi, subkey, optim, opt_state, estimator, alpha, use_l2_reg,
        reg_coeff, use_penalty, j, num_rows)
    Phis.append(Phi)
    grads.append(grad)
  return jnp.stack(Phis), jnp.stack(grads)


# TODO 

Charline
1. Run with a batch size of 1 (number of columns = 1) and compare against the other methods, i.e. on 1 x 784 matrices.
2. Run with a batch size of 1 and a smaller number of rows (d = 128 or fewer) and show that performance is similar to item 1.

In [None]:
#@title General Experiment Util

def init_V(num_devices, k_per_device, dims):
  """Draws unit-norm Gaussian vectors, one (k_per_device, dims) block per device.

  Keys are split from the fixed PRNGKey(1234), so the result is deterministic.
  Returns an array of shape (num_devices, k_per_device, dims).
  """
  def _sample(key):
    return jax.random.normal(key, (k_per_device, dims))

  def _normalize_rows(block):
    return block/jnp.linalg.norm(block, axis=1, keepdims=True)

  device_keys = jax.random.split(jax.random.PRNGKey(1234), num_devices)
  raw = jax.pmap(_sample)(device_keys)
  return jax.pmap(_normalize_rows)(raw)

def get_weights(k):
  """Builds the (num_devices, k_per_device, k) weight tensor.

  The underlying k x k matrix has 1 on the diagonal, -1 below it and 0 above
  it; its rows are split evenly across the local devices.
  """
  num_devices = jax.local_device_count()
  k_per_device = k // num_devices
  lower_tri = np.eye(k) - np.tril(np.ones((k, k)), -1)
  return jnp.reshape(jnp.array(lower_tri), [num_devices, k_per_device, k])

def run_sweep(experiment_runner, V0, k, lr, num_epochs, update,
              optimizer, evaluators, data, num_trials=10, match_eigvecs=False):
  """Repeats one experiment `num_trials` times and gathers its metrics.

  Args:
    experiment_runner: callable invoked once per trial as
      `experiment_runner(V0, k, lr, num_epochs, update=..., optimizer=...,
      evaluators=..., data=..., match_eigvecs=...)`; its second return value
      must expose `.record` (dict) and `.progress`.
    V0, k, lr, num_epochs, update, optimizer, evaluators, match_eigvecs:
      forwarded unchanged to `experiment_runner`.
    data: forwarded to `experiment_runner`; `data.reset_generator()` is called
      once before the first trial.
    num_trials: number of repetitions.

  Returns:
    A triple `(progress, sweep_record, avg_time)`: `progress` of the last
    trial's metrics, a dict mapping each metric name to the list of its
    per-trial values, and the mean wall-clock seconds per trial.
  """
  banner = "*" * 20
  records_by_metric = dict()
  elapsed_total = 0.0
  data.reset_generator()
  for trial in range(num_trials):
    print("\n{:s}TRIAL = {:d}{:s}\n".format(banner, trial, banner))
    tic = time.time()
    _, metrics = experiment_runner(
        V0, k, lr, num_epochs, update=update, optimizer=optimizer,
        evaluators=evaluators, data=data, match_eigvecs=match_eigvecs)
    elapsed_total += time.time() - tic
    # One list per metric name, one entry appended per trial.
    for metric_name, value in metrics.record.items():
      records_by_metric.setdefault(metric_name, []).append(value)
  mean_seconds = elapsed_total / float(num_trials)
  return metrics.progress, records_by_metric, mean_seconds

def run_lr_sweep(data, V0, k, num_epochs, lrs, num_trials=10):
  """Sweeps learning rates for Oja's update and collects the results.

  Args:
    data: data object; switched to single-device mode and passed to run_sweep.
    V0, k, num_epochs: forwarded to `run_sweep`.
    lrs: iterable of learning rates given as log10 exponents.
    num_trials: trials per learning rate.

  Returns:
    Dict `{"Ojas": {str(loglr): [progress, sweep_record, avg_time]}}`.
  """
  data.set_multi_device(False)
  # Only "Ojas" is swept here. Disabled (formerly commented-out) alternatives:
  #   "EigenGame", "GHA", "EGGHA": game_experiment with update / gha_update /
  #       eggha_update respectively,
  #   "MatKras": qr_experiment with matrix_krasulinas_update and
  #       match_eigvecs=True,
  # all with optimizer=sgd and evaluators=evaluators_short. To re-enable one,
  # add its name to `names` and mirror the Oja block below.
  names = ["Ojas"]
  xs_lrs = {name: dict() for name in names}
  rule = "#" * 50
  pad = "\t"*5

  for loglr in lrs:
    lr = 10.**float(loglr)
    print(rule)
    print("LEARNING RATE = {:g}".format(lr))
    print(rule)

    print("[START - 0] ojas" + pad)
    ojas_result = run_sweep(
        qr_experiment, V0, k, lr, num_epochs, update=ojas_update,
        optimizer=sgd, evaluators=evaluators_short, data=data,
        num_trials=num_trials)
    xs_lrs["Ojas"][str(loglr)] = list(ojas_result)
    print("[END - 0] ojas" + pad, ojas_result[2])

  return xs_lrs

In [None]:
#@title Lissa Sweep

def run_lissa_sweep(data, V0, k, num_epochs, lrs, num_trials=10):
  """Sweeps learning rates for the Lissa update and collects the results.

  Args:
    data: data object; switched to single-device mode and passed to run_sweep.
    V0, k, num_epochs: forwarded to `run_sweep`.
    lrs: iterable of learning rates given as log10 exponents.
    num_trials: trials per learning rate.

  Returns:
    Dict `{"Lissa": {str(loglr): [progress, sweep_record, avg_time]}}`.
  """
  data.set_multi_device(False)
  names = ["Lissa"]
  xs_lrs = {name: dict() for name in names}
  rule = "#" * 50
  pad = "\t"*5

  for loglr in lrs:
    lr = 10.**float(loglr)
    print(rule)
    print("LEARNING RATE = {:g}".format(lr))
    print(rule)

    print("[START - 0] Lissa" + pad)
    lissa_result = run_sweep(
        lissa_experiment, V0, k, lr, num_epochs, update=lissa_update,
        optimizer=optax.adam, evaluators=evaluators_short, data=data,
        num_trials=num_trials)
    xs_lrs["Lissa"][str(loglr)] = list(lissa_result)
    print("[END - 0] Lissa" + pad, lissa_result[2])

  return xs_lrs


def run_rr_sweep(data, V0, k, num_epochs, lrs, num_trials=10):
  """Sweeps learning rates for the Russian-roulette variant of the Lissa update.

  The update is `lissa_update` with `estimator='russian_roulette'`; it is run
  through `lissa_experiment` with `optax.adam`.

  Args:
    data: data object; switched to single-device mode and passed to run_sweep.
    V0, k, num_epochs: forwarded to `run_sweep`.
    lrs: iterable of learning rates given as log10 exponents.
    num_trials: trials per learning rate.

  Returns:
    Dict `{"RR": {str(loglr): [progress, sweep_record, avg_time]}}`.
  """
  data.set_multi_device(False)
  names = ["RR"]
  empties = [dict() for _ in names]
  xs_lrs = dict(zip(names, empties))

  update = functools.partial(lissa_update, estimator='russian_roulette')

  for loglr in lrs:
    lr = 10.**float(loglr)
    print("#" * 50)
    print("LEARNING RATE = {:g}".format(lr))
    print("#" * 50)

    # Progress lines are tagged "RR" to match the key the results are stored
    # under (they previously said "Lissa", copied from run_lissa_sweep).
    print("[START - 0] RR" + "\t"*5)
    rr_x, rr_sweep, rr_time = run_sweep(
        lissa_experiment, V0, k, lr, num_epochs, update=update,
        optimizer=optax.adam, evaluators=evaluators_short, data=data,
        num_trials=num_trials)
    xs_lrs["RR"].update({str(loglr): [rr_x, rr_sweep, rr_time]})
    print("[END - 0] RR" + "\t"*5, rr_time)

  return xs_lrs

- *run_lr_sweep* returns a dictionary mapping algorithm names (hard coded within run_lr_sweep at top) to results dictionaries

- each algorithm's results dictionary is keyed by learning rate, stored as the string of its log10 value (e.g., '-4'); each key maps to a three-element list: [metrics.progress (see the Metrics class), sweep_record, average runtime over trials (scalar)]

- sweep_record: dictionary mapping metric names to lists of values over training, e.g., xs_lrs['EGGHAgrad']['-4'][1]['individual_ev_error_deg'] is of shape num_trials x (num_epochs + 1) x k

In [None]:
# MNIST training split, cached after the first pass and repeated indefinitely.
ds = tfds.load("mnist:3.*.*", split="train").cache().repeat()
# Sample count handed to Data(...) and to the plot-axis conversion below;
# presumably the size of the MNIST training split — TODO confirm.
num_samples = 60000

In [None]:
# MNIST experiment configuration (read by the Data setup and sweep cells below).
k = 16  # passed to Data(...), data.set_k and every sweep function
V0 = None  # forwarded to the experiment runners; presumably None lets them choose their own init — TODO confirm
num_epochs = 10  # epochs per trial; a later cell overrides this with 10000
mb_size = 256  # minibatch size applied through data.set_mb_size
mb_size_for_svd = 15  # second positional argument of Data(...); presumably the batch size used for the reference SVD — TODO confirm
svd_by_evd = False  # forwarded as Data(svd_by_evd=...)

In [None]:
# Wrap the MNIST stream in the project's Data helper. The message printed next
# suggests the constructor computes the reference principal components —
# confirm against the Data class.
data = Data(ds, mb_size_for_svd, k, num_samples=num_samples, svd_by_evd=svd_by_evd, center=True, unit_var=False, one_device=True)
print("Computed principal components.")
# Re-apply k and switch from the SVD batch size to the training minibatch size.
data.set_k(k)
data.set_mb_size(mb_size)

In [None]:
# Scratch note: the tuples below are literal constants, not computed from
# `data`. Presumably S = 784 input dims, T = 60000 samples and D = 16 = k —
# TODO confirm.
# S X T
(784, 60000)

# S X D
(784, 16)

In [None]:
# Override the minibatch size with 60000, the same value as num_samples above
# (i.e., one batch spans the whole sample count).
data.set_mb_size(60000)

In [None]:
# Overrides the num_epochs = 10 set in the configuration cell; the sweep cells
# below use this value.
num_epochs = 10000

In [None]:
# Single-trial Lissa sweep at lr = 10**-1; results are stored under
# xs_lrs_lissa['Lissa']['-1'].
lrs = [-1]  # specified in log_10 space
xs_lrs_lissa = run_lissa_sweep(data, V0, k, num_epochs, lrs, num_trials=1)

In [None]:
# Single-trial Russian-roulette sweep at lr = 10**-1. Stored under its own
# name: run_rr_sweep returns {"RR": ...}, so assigning it to xs_lrs_lissa would
# overwrite the Lissa results that the comparison plot below reads via
# xs_lrs_lissa['Lissa'].
lrs = [-1]  # specified in log_10 space
xs_lrs_rr = run_rr_sweep(data, V0, k, num_epochs, lrs, num_trials=1)

In [None]:
# Single-trial Oja sweep at lr = 10**-3; results are stored under
# xs_lrs_oja['Ojas']['-3'].
lrs = [-3]  # specified in log_10 space
xs_lrs_oja = run_lr_sweep(data, V0, k, num_epochs, lrs, num_trials=1)

In [None]:
# Display the Oja sweep_record (entry 1 of [progress, sweep_record, avg_time]).
xs_lrs_oja['Ojas']['-3'][1]

In [None]:
# Overlay the first trial's 'neurips_loss' curve of the Lissa run (blue) and
# the Oja run (red) on the current matplotlib axes.
to_plot = xs_lrs_lissa['Lissa']['-1'][1]['neurips_loss'][0]
plt.plot(to_plot, color='blue')
to_plot = xs_lrs_oja['Ojas']['-3'][1]['neurips_loss'][0]
plt.plot(to_plot, color='red')

In [None]:
## Batch size

# NOTE(review): `xs_lrs` is not assigned at top level in the cells visible
# here (the sweeps above store into xs_lrs_lissa / xs_lrs_oja) — confirm it is
# defined earlier in the notebook, otherwise this raises NameError on a fresh
# kernel.
xs_lrs['Ojas']['-3'][1]

In [None]:
#@title Save and Collect Results

# Directory prefix (must end with a separator) that save_data prepends to the
# file name.
base_file_path = "/tmp/"

def save_data(data, name):
  """Pickles `data` to `base_file_path + name + '.pkl'` through gfile.

  Args:
    data: any picklable object.
    name: file stem (string); the '.pkl' suffix is appended here.
  """
  def _refuse_seek(*_):
    raise AttributeError('seek() is disabled on this object.')

  target_path = ''.join((base_file_path, name, '.pkl'))
  with gfile.GFile(target_path, 'wb') as sink:
    # Replace seek on this handle so any seek attempt during the dump raises
    # AttributeError; presumably a workaround for the pickler seeking on a
    # gfile stream — TODO confirm.
    sink.seek = _refuse_seek
    pickle.dump(data, sink)

def load_data(file_path):
  """Reads back an object stored at `file_path` using dill.

  Args:
    file_path: path readable through gfile.

  Returns:
    The deserialised object.
  """
  # NOTE: unpickling can execute arbitrary code; only load trusted files.
  with gfile.GFile(file_path, 'rb') as source:
    payload = dill.load(source)
  return payload

def collect_best_results(xs_lrs, metric, methods, sign=1.):
  """Picks, for every method, the learning rate with the best final score.

  The score of a learning rate is the trial-mean of `metric` at the last
  recorded step, multiplied by `sign`; the largest score wins.

  Args:
    xs_lrs: dict mapping method name to a dict that maps learning-rate key to
      `[progress, sweep_record, avg_time]`.
    metric: key into each sweep_record.
    methods: ordered method names; fixes the positions in the returned lists.
      Every method present in `xs_lrs` must appear here.
    sign: +1. when larger metric values are better, -1. when smaller are.

  Returns:
    A triple `(best_lrs, best_xs, best_sweeps)` of lists aligned with
    `methods`. Entries stay `np.nan` for methods absent from `xs_lrs` and for
    methods where no learning rate produced a comparable score (all NaN or
    -inf, or no learning rates at all).
  """
  print(methods)
  method_map = dict(zip(methods, np.arange(len(methods))))
  best_lrs = [np.nan for _ in methods]
  best_xs = [np.nan for _ in methods]
  best_sweeps = [np.nan for _ in methods]
  for method, xs_lr in xs_lrs.items():
    best_lr = -np.inf
    best_score = -np.inf
    found_valid_lr = False
    for lr, results in xs_lr.items():
      # Mean over trials (axis 0), then the value at the last recorded step.
      score = np.mean(results[1][metric], axis=0)[-1]
      print("current best {} vs candidate {}".format(best_score, sign * score))
      if sign * score > best_score:
        best_score = sign * score
        best_lr = lr
        found_valid_lr = True
    slot = method_map[method]
    print("method_map[method] = {}".format(slot))
    print("best lr = {}".format(best_lr))
    if not found_valid_lr:
      # Diverged runs (NaN scores) never beat -inf; keep the NaN placeholders
      # instead of failing on xs_lr[-inf].
      continue
    best_lrs[slot] = best_lr
    best_xs[slot] = xs_lr[best_lr][0]
    best_sweeps[slot] = xs_lr[best_lr][1]
  return best_lrs, best_xs, best_sweeps

def collect_lr_results(xs_lrs, metric, methods, lrs):
  """Gathers each method's results at a caller-chosen learning rate.

  Args:
    xs_lrs: dict mapping method name to a dict that maps learning-rate key to
      `[progress, sweep_record, avg_time]`.
    metric: unused; kept so the signature mirrors collect_best_results.
    methods: ordered method names; fixes the positions in the returned lists.
    lrs: learning-rate keys aligned with `methods`.

  Returns:
    A triple `(lrs, xs, sweeps)`: `lrs` is returned unchanged; `xs` and
    `sweeps` are lists aligned with `methods`, holding `np.nan` for methods
    absent from `xs_lrs`.
  """
  slot_of = {method: slot for slot, method in enumerate(methods)}
  xs = [np.nan] * len(methods)
  sweeps = [np.nan] * len(methods)
  for method, per_lr in xs_lrs.items():
    slot = slot_of[method]
    chosen = per_lr[lrs[slot]]
    xs[slot] = chosen[0]
    sweeps[slot] = chosen[1]
  return lrs, xs, sweeps

In [None]:
# Persist the sweep results to base_file_path + "mnist" + ".pkl".
# NOTE(review): `xs_lrs` is not assigned at top level in the cells visible
# here (the sweeps store into xs_lrs_lissa / xs_lrs_oja) — confirm which
# results object should be saved.
save_data(xs_lrs, "mnist")

In [None]:
#@title Plotting Util

# Shared styling constants read by set_plot_frame and plot_metrics.
axis_label_fs = 18  # font size of axis labels
tick_label_fs = 18  # font size of tick labels
title_fs = 18  # font size of titles / suptitles
legend_fs = 18  # font size of legends (scaled by 0.75 for text annotations)
lw = 4  # line width used by plot_metrics

def plot_statistic(ax, x, data, legend_label, color,
                   lw_mean, lw_fill, fill_alpha=0.5,
                   log_scale_y=False, legend_format="{:.1f}",
                   ls='-', skip=1):
  """Draws the trial-mean curve with a +/- standard-error band on `ax`.

  Args:
    ax: matplotlib axes to draw on.
    x: x values, one per column of `data`.
    data: sequence of trials, each a sequence of values aligned with `x`.
    legend_label: base text of the legend entry.
    color: colour of both the line and the band.
    lw_mean: line width of the mean curve.
    lw_fill: line width of the band outline.
    fill_alpha: opacity of the band.
    log_scale_y: draw the mean with `ax.semilogy` instead of `ax.plot`.
    legend_format: format applied to the final mean value and appended to the
      label as ": <value>"; None leaves the label untouched.
    ls: line style of the mean curve.
    skip: stride used to subsample the points.
  """
  xs = x[::skip]
  center = np.mean(data, axis=0)[::skip]
  # Standard error of the mean across trials.
  half_width = np.std(data, axis=0)[::skip] / np.sqrt(len(data))
  if legend_format is not None:
    curve_label = legend_label + ": " + legend_format.format(center[-1])
  else:
    curve_label = legend_label
  draw_mean = ax.semilogy if log_scale_y else ax.plot
  draw_mean(xs, center, color=color, ls=ls, lw=lw_mean, label=curve_label)
  ax.fill_between(xs, center - half_width, center + half_width, lw=lw_fill,
                  color=color, alpha=fill_alpha)

def set_plot_frame(ax, xlabel, ylabel, title, ylim, incl_legend=True,
                   label_locs=None, legend_labels=None,
                   legend_colors=None, excluded_methods=None,
                   num_epochs=50, mb_size=1024, num_samples=60000,
                   x2factor=1e6, x2label='millions', decimals=0,
                   plot_x_axis=True):
  """Applies labels, limits, a legend and an optional second x-axis to `ax`.

  Args:
    ax: matplotlib axes to decorate.
    xlabel: label of the primary x-axis (only set when `plot_x_axis`).
    ylabel: label of the y-axis.
    title: axes title.
    ylim: y-axis limits passed to `ax.set_ylim`.
    incl_legend: draw a legend to the right of the axes; ignored when
      `label_locs` is given.
    label_locs: optional list of (x, y) positions; when given, each method is
      labelled with a boxed text annotation instead of a legend. Requires
      `legend_labels` and `legend_colors` of matching length.
    legend_labels: annotation texts aligned with `label_locs`.
    legend_colors: annotation colours aligned with `label_locs`.
    excluded_methods: indices into `label_locs` to skip; None means none.
    num_epochs, mb_size, num_samples: define the secondary-axis conversion
      factor `num_epochs * num_samples / mb_size`.
    x2factor: divisor applied to the converted tick values.
    x2label: unit name shown in the secondary-axis label.
    decimals: 0 renders secondary tick labels as ints, otherwise they are
      rounded to this many decimals.
    plot_x_axis: when False, skip the x label and the secondary axis.
  """
  # The default None means "exclude nothing"; `i in None` would raise
  # TypeError as soon as label_locs is supplied.
  if excluded_methods is None:
    excluded_methods = ()
  ax.set_ylabel(ylabel, fontsize=axis_label_fs)
  ax.axes.tick_params(labelsize=tick_label_fs)
  ax.set_ylim(ylim)
  ax.set_title(title, fontsize=title_fs)
  if label_locs:
    for i, (label_loc, legend_label, color) in enumerate(zip(label_locs, legend_labels, legend_colors)):
      if i in excluded_methods:
        continue
      ax.text(*label_loc, legend_label, fontsize=int(legend_fs*.75),
              color=color,
              bbox=dict(boxstyle='round', facecolor='white',
                        edgecolor=color, alpha=0.75))
  elif incl_legend:
    # Legend sits outside the axes, centred on the right edge.
    ax.legend(loc='center left', bbox_to_anchor=(1, 0.5),
              prop={"size": legend_fs})
  if not plot_x_axis:
    return
  ax.set_xlabel(xlabel, fontsize=axis_label_fs)
  ax2 = ax.twiny()
  plt.gcf().subplots_adjust(bottom=0.5)
  # Move twinned axis ticks and label from top to bottom
  ax2.xaxis.set_ticks_position("bottom")
  ax2.xaxis.set_label_position("bottom")
  # Offset the twin axis below the host
  ax2.spines["bottom"].set_position(("axes", -0.4))
  # Turn on the frame for the twin axis, but then hide all
  # but the bottom spine
  ax2.set_frame_on(True)
  ax2.patch.set_visible(False)
  for sp in ax2.spines.values():
    sp.set_visible(False)
  ax2.spines["bottom"].set_visible(True)
  # Reuse the host's tick positions and relabel them in iteration units.
  x = ax.get_xticks()
  ax2.set_xticks(x)
  conversion = num_epochs * num_samples / float(mb_size)
  if decimals == 0:
    ax2.set_xticklabels([int(xi * conversion / x2factor) for xi in x])
  else:
    ax2.set_xticklabels([np.round(xi * conversion / x2factor, decimals=decimals) for xi in x])
  ax2.set_xlim(ax.get_xlim())
  ax2.axes.tick_params(labelsize=tick_label_fs)
  ax2.set_xlabel("Iterations ({:s})".format(x2label), fontsize=axis_label_fs)

def plot_metrics(metrics, title, mean_instead=False,
                 separate_plots=False, inc_ylabel=True):
  """Plots every recorded metric of one run, one axes per metric.

  Args:
    metrics: object exposing `record` (dict: metric name -> list of values,
      one per recorded step), `epochs` (x values) and `evaluators` (dict:
      metric name -> object with `ylim` and `bucket` attributes).
    title: figure title (or per-axes title when `separate_plots`).
    mean_instead: plot a flat line at the mean of each series instead of the
      series itself.
    separate_plots: create one figure per metric instead of one row of
      subplots.
    inc_ylabel: whether to set y-axis labels.
  """
  num_metrics = len(metrics.record)
  if separate_plots:
    axs = [plt.subplots()[1] for _ in metrics.record.keys()]
  else:
    # squeeze=False keeps a 2-D array of axes even for a single metric;
    # without it plt.subplots returns a bare Axes and axs[i] below fails.
    _, axs = plt.subplots(1, num_metrics, figsize=(5 * num_metrics, 4),
                          squeeze=False)
    axs = axs[0]
    plt.suptitle(title, fontsize=title_fs)
  for i, (metric, data) in enumerate(metrics.record.items()):
    x = metrics.epochs
    ylabel = metric.replace("_", " ").title()
    if "Longest" in ylabel:
      ylabel = "Longest Correct\nEigenvector Streak"
      axs[i].set_ylim([0, 16])
    elif ylabel == "Neurips Loss":
      ylabel = "Normalized\nRank Coverage"
    if inc_ylabel:
      axs[i].set_ylabel(ylabel, fontsize=axis_label_fs)
    if separate_plots:
      axs[i].set_title(title, fontsize=title_fs)
    # An evaluator-provided ylim overrides the default set above.
    if metric in metrics.evaluators and metrics.evaluators[metric].ylim:
      axs[i].set_ylim(metrics.evaluators[metric].ylim)
    if len(data[0].shape) < 1:
      # Scalar metric: a single curve with the final value in the legend.
      if mean_instead:
        axs[i].plot(x, np.mean(data) * np.ones_like(x), "-o",
                    label="Final: {:g}".format(data[-1]), lw=lw,
                    color="blue")
      else:
        axs[i].plot(x, data, "-o",
                    label="Final: {:g}".format(data[-1]), lw=lw,
                    color="blue")
      if metric in metrics.evaluators and metrics.evaluators[metric].ylim:
        axs[i].set_ylim(metrics.evaluators[metric].ylim)
      axs[i].legend(prop={"size": legend_fs})
    else:
      # Vector metric: one grayscale curve per component.
      stacked = jnp.transpose(jnp.stack(data, axis=0))  # num vecs x num iters
      label = "V"
      if metric in metrics.evaluators:
        ylim = metrics.evaluators[metric].ylim
        bucket = metrics.evaluators[metric].bucket
        if ylim and bucket:
          # Replace per-vector curves by the fraction of vectors falling in
          # each of `bucket` equal-width bins spanning ylim.
          if inc_ylabel:
            axs[i].set_ylabel("Percentiles")
          ptp = ylim[1] - ylim[0]
          width = float(ptp) / float(bucket)
          intervals = ylim[0] + jnp.cumsum(jnp.ones(bucket) * width)  # num buckets
          intervals = jnp.reshape(intervals, (1, 1, bucket))
          stacked = jnp.reshape(stacked, stacked.shape + (1,))
          bucketted = (stacked < intervals) * (stacked >= (intervals - width))
          stacked = jnp.transpose(jnp.mean(bucketted, axis=0))
          label = "%%"
          axs[i].set_ylim([0, 1])
      for j, datum in enumerate(stacked):
        color_idx = float(j) / stacked.shape[0]
        color = plt.get_cmap("gray")(color_idx)
        if mean_instead:
          axs[i].plot(x, np.mean(datum) * np.ones_like(x), "-o", color=color,
                      label=r"${:s}_{{{:d}}}$".format(label, j), lw=lw)
        else:
          axs[i].plot(x, datum, "-o", color=color,
                      label=r"${:s}_{{{:d}}}$".format(label, j), lw=lw)
      # No legend is drawn for vector metrics (one entry per vector would
      # crowd the axes); earlier attempts:
      # axs[i].legend(loc='center left', bbox_to_anchor=(1, 0.5))
      # axs[i].legend(loc='upper center', bbox_to_anchor=(0.5, -0.05), ncol=8)
    axs[i].set_xlabel("Epochs", fontsize=axis_label_fs)
    axs[i].axes.tick_params(labelsize=tick_label_fs)
    axs[i].grid()

In [None]:
# Minibatch size shown in the plot title and used for the iteration axis.
mb = mb_size
# Placeholder path: point this at a file written by save_data before running.
xs_lrs_loaded = load_data('<Specify file path here>')

In [None]:
# Shared settings for the two-panel MNIST figure built below.
xlabel = "Epochs"
title = "MNIST (Minibatch = {})".format(mb)
lw_mean = 4
lw_fill = 1
ls = '-'
incl_legend = True

# Secondary x-axis shows iterations divided by x2factor, labelled with x2label.
x2factor = 1e3
x2label = 'thousands'
decimals = 0

# The lists below are aligned by index: method key in the loaded results,
# its label-position index, its legend text and its colour.
methods = ["EigenGame", "GHA", "MatKras", "Ojas", "EGGHA"]
legend_loc_mapping = [0, 1, 2, 3, 4]
legend_labels_base = [r"$\alpha$-EG", "GHA", "Krasulinas", "Ojas", r"$\mu$-EG"]
colors = ["blue", "purple", "green", "black", "red"]
# Indices into `methods` to leave out of the figure.
excluded_methods = []

included_methods = [i for i in range(len(methods)) if i not in excluded_methods]
colors = [colors[i] for i in included_methods]

# Learning-rate key (string of the log10 value) plotted for each method.
use_lrs = ['-4', '-4', '-4', '-4', '-4']

# --- Top panel: longest streak of correct eigenvectors -----------------------
key = "longest_streak_of_correct_eigvecs"
# The best-lr search is only printed for reference; the curves plotted below
# use the fixed learning rates in use_lrs (second call overwrites the first).
best_lrs, xs, sweeps = collect_best_results(xs_lrs_loaded, key, methods, sign=1.)
print("best_lrs streak", best_lrs)
best_lrs, xs, sweeps = collect_lr_results(xs_lrs_loaded, key, methods, use_lrs)
# Legend text per included method; the last element of each results list (the
# average runtime stored by the sweep functions) is appended when it is a float.
legend_labels = []
for idx, (label_base, method, lr) in enumerate(zip(legend_labels_base, methods, use_lrs)):
  if idx not in excluded_methods:
    if not isinstance(xs_lrs_loaded[method][lr][-1], float):
      legend_labels += [label_base]
      continue
    legend_labels += [label_base + " ({:.0f})".format(xs_lrs_loaded[method][lr][-1])]
ylabel = "Longest Correct\nEigenvector Streak"

ylim = [0, 17]
# Two stacked panels sharing the x-axis; axs[1] is filled in further below.
fig, axs = plt.subplots(nrows=2, ncols=1, sharex=True, figsize=(6, 6))
axs[0].set_yticks([0, 8, 16])
axs[0].set_yticklabels([0, 8, 16])


# Data-coordinate positions of the per-method text labels (all identical here).
label_locs = [(-1,1), (-1,1), (-1,1), (-1,1), (-1,1)]
if legend_loc_mapping:
  label_locs = [label_locs[llm] for llm in legend_loc_mapping]

xs = [xs[i] for i in included_methods]
sweeps = [sweeps[i] for i in included_methods]
for i, (color, legend_label, x, sweep) in enumerate(zip(colors, legend_labels, xs, sweeps)):
  # Column 1 of the progress array is used as the x values.
  plot_statistic(axs[0], np.array(x)[:,1], sweep[key], legend_label,
                 color, lw_mean, lw_fill, fill_alpha=0.5,
                 legend_format=None, ls=ls)

# plot_x_axis=False: the x label and iteration axis are drawn on the bottom
# panel only.
set_plot_frame(axs[0], xlabel, ylabel, title, ylim, incl_legend=True,
               label_locs=[label_locs[i] for i in included_methods],
               legend_labels=legend_labels,
               legend_colors=colors, excluded_methods=[],
               num_epochs=num_epochs, mb_size=mb, num_samples=num_samples,
               x2factor=x2factor, x2label=x2label, decimals=decimals,
               plot_x_axis=False)
fig.tight_layout(rect=[0, 0.02, 1, 1])
path = "ls_k16_mb.pdf"
fig.savefig(path, dpi=300)
# Notebook magic (not Python); presumably environment-specific — TODO confirm
# it is available where this notebook runs.
%download_file ls_k16_mb.pdf

# --- Bottom panel: neurips_loss (labelled "Subspace Distance") ---------------
key = "neurips_loss"
# sign=-1. makes collect_best_results prefer the smallest final value. As for
# the top panel, the search is only printed; the plotted curves use use_lrs.
best_lrs, xs, sweeps = collect_best_results(xs_lrs_loaded, key, methods, sign=-1.)
print("best_lrs subspace_dist", best_lrs)
best_lrs, xs, sweeps = collect_lr_results(xs_lrs_loaded, key, methods, use_lrs)
ylabel = "Subspace Distance"
ylim = [0, 1]


# Data-coordinate positions of the per-method text labels (all identical here).
label_locs = [(-1,10**(-1)), (-1,10**(-1)), (-1,10**(-1)), (-1,10**(-1)), (-1,10**(-1))]
if legend_loc_mapping:
  label_locs = [label_locs[llm] for llm in legend_loc_mapping]

xs = [xs[i] for i in included_methods]
sweeps = [sweeps[i] for i in included_methods]

for i, (color, legend_label, x, sweep) in enumerate(zip(colors, legend_labels, xs, sweeps)):
  # Dashed line for a "gray" curve; no colour in the list above is gray, so
  # every curve currently uses the default line style.
  if color == "gray":
    lsi = '--'
  else:
    lsi = ls
  plot_statistic(axs[1], np.array(x)[:,1], sweep[key], legend_label,
                 color, lw_mean, lw_fill, fill_alpha=0.5,
                 log_scale_y=True, legend_format=None, ls=lsi)

# plot_x_axis is left at its default (True), so this panel gets the x label
# and the secondary iteration axis.
set_plot_frame(axs[1], xlabel, ylabel, "", ylim, incl_legend=incl_legend,
               label_locs=[label_locs[i] for i in included_methods],
               legend_labels=legend_labels,
               legend_colors=colors, excluded_methods=[],
               num_epochs=num_epochs, mb_size=mb, num_samples=num_samples,
               x2factor=x2factor, x2label=x2label, decimals=decimals)

fig.tight_layout(rect=[0.04, 0.04, 1, 1])
path = "nl_k16_mb.pdf"
fig.savefig(path, dpi=300)
# Notebook magic (not Python); presumably environment-specific — TODO confirm
# it is available where this notebook runs.
%download_file nl_k16_mb.pdf