In [1]:
import jax
import jax.numpy as jnp                # JAX NumPy
from jax import nn as jnn              # JAX nn
# from jax.config import config
# config.update('jax_enable_x64', True)
from flax import linen as nn           # The Linen API
from flax.training import train_state  # Useful dataclass to keep train state
import time
from absl import app
from functools import partial
from absl import flags
import numpy as np                     # Ordinary NumPy
import optax                           # Optimizers
from keras.datasets import mnist
from typing import List
import datetime
from tqdm import tqdm
from typing import Any, List, NamedTuple, Callable, Optional, Union
import jax.numpy as jnp
from optax._src import utils
from optax._src import combine
from optax._src import base
from optax._src import alias
ScalarOrSchedule = Union[float, base.Schedule]
import scipy

In [2]:
flags.DEFINE_float('beta1', 0.9, help='Beta1')
flags.DEFINE_float('beta2', 0.999, help='Beta2')
flags.DEFINE_float('lr', 0.0001, help='Learning rate')
flags.DEFINE_float('eps', 1e-8, help='eps')
flags.DEFINE_integer('batch_size',
                     1000, help='Batch size.')
flags.DEFINE_integer('model_size_multiplier',
                     1, help='Multiply model size by a constant')
flags.DEFINE_integer('model_depth_multiplier',
                     1, help='Multiply model depth by a constant')
flags.DEFINE_integer('warmup_epochs', 5, help='Warmup epochs')
flags.DEFINE_integer('epochs', 100, help='#Epochs')
flags.DEFINE_integer('t', 20, help='preconditioner computation frequency')
flags.DEFINE_enum('dtype', 'float32', ['float32', 'bfloat16'], help='dtype')
flags.DEFINE_enum('optimizer', 'tds', ['sgd', 'momentum', 'nesterov', 'adagrad',
  'rmsprop', 'tds', 'shampoo', 'diag_sonew'], help='optimizer')
FLAGS = flags.FLAGS

In [3]:
import sys
from absl import app

# Addresses `UnrecognizedFlagError: Unknown command line flag 'f'`
sys.argv = sys.argv[:1]

# `app.run` calls `sys.exit`
try:
  app.run(lambda argv: None)
except:
  pass

2022-10-19 17:00:02.857850: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudnn.so.8'; dlerror: libcudnn.so.8: cannot open shared object file: No such file or directory
2022-10-19 17:00:02.857897: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1850] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...


In [None]:
# Overwrite FLAGS here

In [4]:
MACHINE_EPS = 1e-7 if FLAGS.dtype=='float32' else 0.0078

In [5]:
num_edges_removed_tds = 0

def ldl2tridiag(Lsub,D):
  # n = D.shape[0]
  Xd = jnp.zeros_like(D)
  Xd = Xd.at[1:].set(D[1:]+Lsub*Lsub*D[:-1])
  Xd = Xd.at[0].set(D[0])
  Xe = Lsub*D[:-1]
  return Xd,Xe

def tridiagKFAC(Sd,Se, eps):
  # given diagonal-Sd and subdiagonal-Se
  # find the inverse of pd completion of this tridiagonal matrix
  # interms of Ldiag(D)L^T decomposition
  # outputs Lsub and D, where Lsub-subdiagonal of L
  Sd = Sd+eps
  psi = Se/Sd[1:]
  condCov = jnp.zeros_like(Sd)
  condCov = condCov.at[:-1].set(Sd[:-1]-Se*(Se/Sd[1:]))
  condCov = condCov.at[-1].set(Sd[-1])
  D = 1/(condCov)
  mask1 = condCov[:-1]<=MACHINE_EPS*Sd[:-1]
  mask2 = condCov <=MACHINE_EPS*Sd
  psi = jnp.where(mask1, 0, psi)
  D = jnp.where(mask2, 1/Sd, D)
  Lsub = -psi
  return ldl2tridiag(Lsub,D)

def logdet_tds(Sd,Se, eps):
  Sd = Sd+eps
  psi = Se/Sd[1:]
  condCov = jnp.zeros_like(Sd)
  condCov = condCov.at[:-1].set(Sd[:-1]-Se*(Se/Sd[1:]))
  condCov = condCov.at[-1].set(Sd[-1])
  D = 1/(condCov)
  mask1 = condCov[:-1]<=MACHINE_EPS*Sd[:-1]
  mask2 = condCov <=MACHINE_EPS*Sd
  psi = jnp.where(mask1, 0, psi)
  D = jnp.where(mask2, 1/Sd, D)
  return (-1*jnp.sum(jnp.log(D.astype(jnp.float32))), jnp.sum(mask2))

In [6]:
"""Sparse preconditioners."""
from typing import NamedTuple, Union

import chex
import jax
import jax.numpy as jnp
import optax

ScalarOrSchedule = Union[float, optax.Schedule]


def _update_moment(updates, moments, decay, order):
  """Compute the exponential moving average of the `order-th` moment."""
  return jax.tree_map(lambda g, t: (1 - decay) * (g**order) + decay * t,
                           updates, moments)


def _bias_correction(moment, decay, count):
  """Perform bias correction. This becomes a no-op as count goes to infinity."""
  beta = 1 - decay**count
  return jax.tree_map(lambda t: t / beta.astype(t.dtype), moment)


def scale_by_learning_rate(
    learning_rate: ScalarOrSchedule,
    flip_sign: bool = True) -> optax.GradientTransformation:
  m = -1 if flip_sign else 1
  if callable(learning_rate):
    return optax.scale_by_schedule(lambda count: m * learning_rate(count))
  return optax.scale(m * learning_rate)


def _update_nu(updates, nu_e, nu_d, beta2):
  """Compute the exponential moving average of the tridiagonal structure of the moment."""
  nu_d = jax.tree_map(lambda g, t: (1 - beta2) * (g**2) + beta2 * t,
                           updates, nu_d)
  nu_e = jax.tree_map(
      lambda g, t: (1 - beta2) * (g[:-1] * g[1:]) + beta2 * t, updates, nu_e)
  return nu_e, nu_d


class PreconditionTriDiagonalState(NamedTuple):
  """State for the Adam preconditioner."""
  count: chex.Array  # shape=(), dtype=jnp.int32
  mu: optax.Updates
  nu_e: optax.Updates
  nu_d: optax.Updates
  logdet: optax.Updates
  num_edges_removed: optax.Updates

def precondition_by_tds(
    b1: float = 0.9,
    b2: float = 0.999,
    eps: float = 1e-8,
    transpose: bool = True,
    adam_grafting: bool = False,
    debias: bool = True) -> optax.GradientTransformation:

  def init_fn(params):
    return PreconditionTriDiagonalState(
        count=jnp.zeros([], jnp.int32),
        mu = jax.tree_map(jnp.zeros_like, params),
        logdet = jax.tree_map(jnp.zeros_like, params),
        num_edges_removed = jax.tree_map(lambda g: jnp.array([0]), params),
        nu_e=jax.tree_map(lambda g: jnp.zeros(len(g.reshape(-1))-1, dtype=g.dtype), params),
        nu_d=jax.tree_map(lambda g: jnp.zeros(len(g.reshape(-1)), dtype=g.dtype), params))
  @jax.jit
  def update_fn(updates, state, params):
    updates_hat = jax.tree_map(lambda g: g.T.reshape(-1) if transpose else g.reshape(-1), updates)
    mu = _update_moment(updates, state.mu, b1, 1)
    nu_e, nu_d = _update_nu(updates_hat, state.nu_e, state.nu_d, b2)
    count = state.count + jnp.array(1, dtype=jnp.int32)
    mu_hat = mu if not debias else _bias_correction(mu, b1, count)
    nu_hat_e = nu_e if not debias else _bias_correction(nu_e, b2, count)
    nu_hat_d = nu_d if not debias else _bias_correction(nu_d, b2, count)

    temp = jax.tree_map(lambda d, e:
                             tridiagKFAC(d,e, eps),
                             nu_hat_d, nu_hat_e)
    pre_d = jax.tree_map(lambda h, g: g[0], nu_hat_d, temp)
    pre_e = jax.tree_map(lambda h, g: g[1], nu_hat_d, temp)
    temp = jax.tree_map(lambda d, e: logdet_tds(d,e, eps), nu_hat_d, nu_hat_e)
    logdet = jax.tree_map(lambda h, g: g[0], nu_hat_d, temp)
    num_edges_removed = state.num_edges_removed
    num_edges_removed = jax.tree_map(lambda h, g, l: l.at[0].set(l[0]+g[1]), nu_hat_d, temp, num_edges_removed)
    mu_hat_flat = jax.tree_map(lambda m: m.T.reshape(-1)
                                    if transpose else m.reshape(-1), mu_hat)
    # Multiply gradient with diagonal
    updates = jax.tree_map(lambda m, a: m*a, mu_hat_flat, pre_d)
    # updates[i] = updates[i] + gradient[i-1]*pre_e[i], for i>0
    updates = jax.tree_map(lambda u, m, a: u.at[1:].set(u[1:]+m[:-1]*a),
                                updates, mu_hat_flat, pre_e)
    # updates[i] = updates[i] + gradient[i+1]*pre_e[i], for i<n-1
    updates = jax.tree_map(lambda u, m, a: u.at[:-1].set(u[:-1]+m[1:]*a),
                                updates, mu_hat_flat, pre_e)
    # reshape them to the original param shapes
    updates = jax.tree_map(lambda mf, m: mf.reshape(m.T.shape).T
                                if transpose else mf.reshape(m.shape),
                                updates, mu_hat)
    return updates, PreconditionTriDiagonalState(count=count, mu=mu, nu_e=nu_e,
                                                 nu_d=nu_d, logdet=logdet,
                                                 num_edges_removed=num_edges_removed)

  return optax.GradientTransformation(init_fn, update_fn)

def tds(learning_rate: ScalarOrSchedule, b1=0.9, b2=0.99, eps=1e-8, transpose=True, adam_grafting=False):
    return combine.chain(
      precondition_by_tds(
          b1=b1, b2=b2, eps=eps, transpose=transpose, adam_grafting=adam_grafting),
        alias._scale_by_learning_rate(learning_rate),
    )

In [7]:
num_edges_removed_bds = 0

def _update_nu_banded(updates, nu_e, nu_d, beta2):
  nu_d = jax.tree_map(lambda g, t: (1-beta2) * (g**2) + beta2 * t,
                      updates, nu_d)
  def update_band(g, band, b):
    for i in range(b):
      band = band.at[:-(i+1), i].set((1-beta2)*(g[:-(i+1)]*g[i+1:]) + 
                                     beta2*band[:-(i+1), i])
    return band
  nu_e = jax.tree_map(lambda g, t: update_band(g, t, t.shape[-1]), updates,
                      nu_e)
  return nu_e, nu_d



def GENP_jax(a, b):
  """Gaussian elimination with no pivoting.

  % input: a is an batch x n x n nonsingular matrix
  %        b is an batch x n x 1 vector
  % return: x is the solution of Ax=b.
  % post-condition: A and b have been modified. 
  """
  n = a.shape[1]
  orig_a = a
  if b.shape[1] != n:
    raise ValueError("Invalid argument: " +
                     "incompatible sizes between A & b.", b.shape[-1], n)
  for pivot_row in range(n-1):
    for row in range(pivot_row+1, n):
      den = a[:, pivot_row, pivot_row]
      den = jnp.where(den == 0, MACHINE_EPS*orig_a[:, pivot_row, pivot_row], den)
      a = a.at[:, pivot_row, pivot_row].set(den)
      multiplier = a[:, row, pivot_row]/den
      multiplier = multiplier.reshape(-1, 1)
      a = a.at[:, row, pivot_row:].set(a[:, row, pivot_row:]-
                                       multiplier*a[:, pivot_row, pivot_row:])
      b = b.at[:, row].set(b[:, row] - multiplier*b[:, pivot_row])
  batches = a.shape[0]
  x = jnp.zeros((batches, n), dtype=a.dtype)
  k = n-1
  a = a.at[:, k, k].set(jnp.where(a[:, k, k] == 0, MACHINE_EPS*orig_a[:, k, k], a[:, k, k]))
  den = a[:, k, k].reshape(-1, 1)
  temp = b[:, k] / den
  temp = temp.reshape(-1)
  x = x.at[:, k].set(temp)
  k = k-1
  while k >= 0:
    first = a[:, k, k+1:].reshape((batches, 1, -1))
    second = x[:, k+1:].reshape((batches, -1, 1))
    second_term = jnp.matmul(first, second)
    temp = second_term.reshape(-1)
    den = a[:, k, k].reshape(-1)
    x = x.at[:, k].set((b[:, k].reshape(-1) - temp.reshape(-1))/den)
    k = k-1
  return x.reshape((batches, -1, 1))


def bandedInv(Sd,subDiags,ind,eps,innerIters):
  # given diagonal-Sd and subdiagonals-subDiags
  # find the inverse of pd completion of this banded matrix
  # interms of Ldiag(D)L^T decomposition
  # outputs Lsub and D, where Lsub-subdiagonals of L

  n = Sd.shape[0]
  b = subDiags.shape[1]

  bandvecs = jnp.concatenate((Sd.reshape(-1, 1), subDiags), axis=1)

  indX,indY = ind
  epsMat = jnp.zeros((b, b+1), dtype=Sd.dtype)
  epsMat = epsMat.at[:,0].set(eps)
  bandWindows = jnp.concatenate((bandvecs, epsMat), axis=0)
  sig22 = bandWindows[indX[:,1:,1:],indY[:,1:,1:]]
  sig21 = bandWindows[indX[:,1:,0],indY[:,1:,0]]

  def A_bmm(X):
    return jnp.matmul(sig22, X)

  diagSig22 = jnp.diagonal(sig22, axis1=1, axis2=2)

  def M_bmm(X):
    return jnp.broadcast_to(jnp.expand_dims(1/diagSig22, axis=-1), X.shape)*X

  # psi, _ = jax.scipy.sparse.linalg.cg(A_bmm, jnp.expand_dims(sig21, axis=-1),
  #                                     tol=1e-8, M=M_bmm, maxiter=innerIters)
  psi = GENP_jax(M_bmm(sig22), M_bmm(jnp.expand_dims(sig21, axis=-1)))
  # psi = GENP_jax(sig22, jnp.expand_dims(sig21, axis=-1))
  psi = psi.squeeze(-1)
  # print("sig22:\n", sig22.reshape(-1))
  # print("sig21:", sig21.reshape(-1))
  # print("psi:", psi.reshape(-1))
  # print("psi.shape:", psi.shape)
  # assert 1==2
  # print("sig22:", sig22)
  # print("sig21:", sig21)
  # print("Sd:", Sd)

  psiSig21 = jnp.matmul(psi.reshape((n,1,b)),
                        sig21.reshape((n, b, 1))).squeeze(-1).squeeze(-1)
  condCov = Sd - psiSig21
  # print("cond_cov:", condCov)

  ##################
  '''
  condCovFail = (condCov<=MACHINE_EPS*Sd).reshape((-1,1))
  condCovFail = jnp.broadcast_to(condCovFail, (condCovFail.shape[0], b))
  condCovFail = condCovFail.at[:-1,0].set(jnp.logical_or(condCovFail[:-1,0],
                                                          condCovFail[1:,0]))
  for i in range(1,b):
    condCovFail = condCovFail.at[:-(i+1),i].set(
        jnp.logical_or(condCovFail[:-(i+1), i], condCovFail[1:-(i), i-1]))
  # global num_edges_removed
  # num_edges_removed+=jnp.sum(condCovFail)
  # print("jnp.sum(condCovFail)", jnp.sum(condCovFail))

  psi = jnp.where(condCovFail, 0.0, psi)
  psiSig21 = jnp.matmul(psi.reshape((n, 1, b)), sig21.reshape((n, b, 1)))
  psiSig21 = psiSig21.squeeze(-1).squeeze(-1)
  condCov = Sd - psiSig21
  '''
  ####################

  def cond(arguments):
    condCov, psi, psiSig21, condCovFail = arguments
    return jnp.any((condCov <= MACHINE_EPS*Sd))

  def body(arguments):
    condCov, psi, psiSig21, condCovFail = arguments
    condCovFail = (condCov<=MACHINE_EPS*Sd).reshape((-1,1))
    condCovFail = jnp.broadcast_to(condCovFail, (condCovFail.shape[0], b))
    condCovFail = condCovFail.at[:-1,0].set(jnp.logical_or(condCovFail[:-1,0],
                                                           condCovFail[1:,0]))
    for i in range(1,b):
      condCovFail = condCovFail.at[:-(i+1),i].set(
          jnp.logical_or(condCovFail[:-(i+1), i], condCovFail[1:-(i), i-1]))

    psi = jnp.where(condCovFail, 0.0, psi)
    psiSig21 = jnp.matmul(psi.reshape((n, 1, b)), sig21.reshape((n, b, 1)))
    psiSig21 = psiSig21.squeeze(-1).squeeze(-1)
    condCov = Sd - psiSig21
    return (condCov, psi, psiSig21, condCovFail)

  condCovFail = (condCov<=MACHINE_EPS*Sd).reshape((-1,1))
  condCovFail = jnp.broadcast_to(condCovFail, (condCovFail.shape[0], b))
  condCovFail = condCovFail.at[:-1,0].set(jnp.logical_or(condCovFail[:-1,0], condCovFail[1:,0]))
  ret = (condCov, psi, psiSig21, condCovFail)
  ret = jax.lax.while_loop(cond, body, ret)
  condCov, psi, psiSig21, condCovFail = ret
  num_edges_removed_bds = jnp.sum(condCovFail)
  D = 1/(condCov)
  return psi.astype(Sd.dtype), D.astype(Sd.dtype), num_edges_removed_bds


def bandedMult(psi, D, vecv):
  b = psi.shape[1]
  update = vecv
  for i in range(b):
    update = update.at[:-i-1].set(update[:-i-1] - vecv[i+1:]*psi[:-i-1, i])
  update = update*D
  vecv2 = update
  for i in range(b):
    update = update.at[i+1:].set(update[i+1:] - vecv2[:-i-1]*psi[:-i-1, i])
  return update

def getl1norm(Sd, Se):
  bandSize = Se.shape[1]
  n = Sd.shape[0]
  temp = Sd
  for b in range(bandSize):
    temp = temp.at[:-(b+1)].set(Sd[:-(b+1)] + jnp.abs(Se[:, b][:-(b+1)]))
    temp = temp.at[(b+1):].set(Sd[(b+1):] + jnp.abs(Se[:, b][:-(b+1)]))
  return jnp.max(temp)

def bandedUpdates(Sd, subDiags, ind, eps, innerIters, mu):
  # innerIters = subDiags.shape[1]*13
  l1norm = getl1norm(Sd, subDiags)
  # psi, D = bandedInv(Sd+(eps*l1norm), subDiags, ind, eps, innerIters)
  psi, D, num_edges_removed_bds = bandedInv(Sd+eps, subDiags, ind, eps, innerIters)
  return (-1*jnp.sum(jnp.log(D.astype(jnp.float32))), num_edges_removed_bds)
  # return bandedMult(psi, D, mu)

def createInd(n,b):
  b1 = b+1
  offsetX = jnp.broadcast_to(jnp.expand_dims(jnp.arange(b1), axis=-1), (b1,b1))
  offsetX = jnp.triu(offsetX)+jnp.transpose(jnp.triu(offsetX,1), (1,0))

  offsetY = jnp.array(scipy.linalg.toeplitz(np.arange(b1)))

  indX = jnp.broadcast_to(jnp.expand_dims(jnp.expand_dims(jnp.arange(n), axis=-1), axis=-1), (n,b1,b1))
  indY = jnp.broadcast_to(jnp.expand_dims(jnp.expand_dims(jnp.zeros(n, dtype=jnp.int32), axis=-1), axis=-1), (n,b1,b1))

  indX = indX+jnp.expand_dims(offsetX, 0)
  indY = indY+jnp.expand_dims(offsetY, 0)

  return jnp.array([indX, indY])

class PreconditionBandedDiagonalState(NamedTuple):
  """State for the Adam preconditioner."""
  count: chex.Array  # shape=(), dtype=jnp.int32
  mu: optax.Updates
  nu_e: optax.Updates
  nu_d: optax.Updates
  ind: optax.Updates
  diag: optax.Updates
  logdet: optax.Updates
  num_edges_removed: optax.Updates

def precondition_by_bds(beta1: float = 0.9,
                        beta2: float = 0.999,
                        eps: float = 1e-8,
                        graft_eps: float = 1e-8,
                        graft_type: int = 0,
                        transpose: bool = True,
                        ridge_epsilon: float = 1e-12,
                        b: int = 3,
                        innerIters = 15,
                        debias: bool = True) -> optax.GradientTransformation:
  def init_fn(params):
    diag = None
    return PreconditionBandedDiagonalState(
        count=jnp.zeros([], jnp.int32),
        mu=jax.tree_map(jnp.zeros_like, params),
        logdet=jax.tree_map(jnp.zeros_like, params),
        num_edges_removed = jax.tree_map(lambda g: jnp.array([0]), params),
        nu_e=jax.tree_map(lambda g: jnp.zeros((len(g.reshape(-1)), b),
                                              dtype=g.dtype), params),
        nu_d=jax.tree_map(lambda g: jnp.zeros(len(g.reshape(-1)),
                                              dtype=g.dtype), params),
        ind=jax.tree_map(lambda g: createInd(len(g.reshape(-1)), b), params),
        diag=diag)
  @jax.jit
  def update_fn(updates, state, params):
    del params
    diag = state.diag
    mu = state.mu
    updates_hat = jax.tree_map(
        lambda g: g.T.reshape(-1) if transpose else g.reshape(-1), updates)
    nu_e, nu_d = _update_nu_banded(updates_hat, state.nu_e, state.nu_d, beta2)
    count = state.count + jnp.array(1, dtype=jnp.int32)
    mu_hat = mu if not debias else _bias_correction(mu, beta1, count)
    nu_hat_e = nu_e if not debias else _bias_correction(nu_e, beta2, count)
    nu_hat_d = nu_d if not debias else _bias_correction(nu_d, beta2, count)

    mu_hat_flat = jax.tree_map(lambda m: m.T.reshape(-1)
                               if transpose else m.reshape(-1), mu_hat)
    temp = jax.tree_map(lambda d, e, g, ind: bandedUpdates(d, e, ind, eps,
                                                             innerIters, g),
                           nu_hat_d, nu_hat_e, mu_hat_flat, state.ind)
    logdet = jax.tree_map(lambda h, g: g[0], nu_hat_d, temp)
    num_edges_removed = state.num_edges_removed
    num_edges_removed = jax.tree_map(lambda h, g, l: l.at[0].set(l[0]+g[1]), nu_hat_d, temp, num_edges_removed)
    
    return updates, PreconditionBandedDiagonalState(
        count=count, mu=mu, nu_e=nu_e, nu_d=nu_d, ind=state.ind, diag=diag,
        logdet=logdet, num_edges_removed=num_edges_removed)

  return optax.GradientTransformation(init_fn, update_fn)


def bds(learning_rate: ScalarOrSchedule,
        beta1: float = 0.9,
        beta2: float = 0.99,
        eps: float = 1e-8,
        graft_eps: float = 1e-8,
        graft_type: int = 0,
        weight_decay: float = 0.0,
        ridge_epsilon: float = 1e-12,
        b: int = 3,
        transpose: bool = True) -> optax.GradientTransformation:
  return optax.chain(
      precondition_by_bds(beta1=beta1, beta2=beta2, eps=eps,
                          graft_type=graft_type, innerIters=20,
                          graft_eps=graft_eps, ridge_epsilon=ridge_epsilon, b=b,
                          transpose=transpose),
      scale_by_learning_rate(learning_rate),
  )


In [10]:
class Autoencoder(nn.Module):
  enc_hidden_states: List[int]
  dec_hidden_states: List[int]
  dtype: Any
  param_dtype: Any

  @nn.compact
  def __call__(self, x):
    for i in range(len(self.enc_hidden_states)):
      x = nn.Dense(features = self.enc_hidden_states[i],
                   kernel_init=jnn.initializers.glorot_uniform(),
                   dtype=self.dtype, param_dtype=self.param_dtype)(x)
      if i<len(self.enc_hidden_states)-1:
        x = nn.tanh(x)
    for i in range(len(self.dec_hidden_states)):
      x = nn.Dense(features = self.dec_hidden_states[i],
                   kernel_init=jnn.initializers.glorot_uniform(),
                   dtype=self.dtype, param_dtype=self.param_dtype)(x)
      x = nn.tanh(x)
    x = nn.Dense(features = 784,
                 kernel_init=jnn.initializers.glorot_uniform(),
                 dtype=self.dtype, param_dtype=self.param_dtype)(x)
    return x

  def __hash__(self):
    return id(self)

In [11]:
def get_optimizer(opt, learning_rate):
  print("using tds optimizer to generate gradients")
  return tds(learning_rate, b1=FLAGS.beta1, b2=FLAGS.beta2, eps=FLAGS.eps, transpose=True, adam_grafting=False)

In [12]:
def create_train_state(params, model, opt, learning_rate):
  """Creates initial `TrainState`."""
  tx = get_optimizer(opt, learning_rate)
  return train_state.TrainState.create(
      apply_fn=model.apply, params=params, tx=tx)

@partial(jax.jit, static_argnums=0)
def train_step(model, state, x):
  def loss_fn(params):
    logits = model.apply(params, x)
    loss = optax.sigmoid_binary_cross_entropy(logits, x).mean(0).sum()
    return loss
  grad_fn = jax.value_and_grad(loss_fn)
  loss, grads = grad_fn(state.params)
  state = state.apply_gradients(grads=grads)
  return state, loss, grads

@partial(jax.jit, static_argnums=0)
def eval_step(model, state, x):
  logits = model.apply(state.params, x)
  loss = optax.sigmoid_binary_cross_entropy(logits, x)
  return loss.astype(jnp.float32).mean(0).sum()

In [13]:
def train_epoch(state, model, train_ds, batch_size, epoch, rng, lrVec):
  train_ds_size = len(train_ds)
  steps_per_epoch = train_ds_size // batch_size
  print("epoch:", epoch,"and lr going to be used:", lrVec[epoch])
  grads_list = []

  perms = jax.random.permutation(rng, train_ds_size)
  perms = perms[:steps_per_epoch * batch_size]  # skip incomplete batch
  perms = perms.reshape((steps_per_epoch, batch_size))
  batch_metrics = []
  for perm in perms:
    train_x = train_ds[perm]
    state, loss, grads = train_step(model, state, train_x)
    # print("loss:", loss, loss.dtype)
    batch_metrics.append(loss.item())
    grads_list.append(grads)

  batch_metrics_np = jax.device_get(batch_metrics)
  epoch_metrics_np = np.mean(batch_metrics_np)

  # print('train epoch: %d, loss: %.4f' % (epoch, epoch_metrics_np))

  return state, grads_list

In [14]:
global_grads_list = []
def main(argv):
  train_start = time.time()
  #Get random keys
  rng = jax.random.PRNGKey(0)
  rng, key1 = jax.random.split(rng)
  rng, key2 = jax.random.split(rng)
  rng, key3 = jax.random.split(rng)

  #Get dtype
  if FLAGS.dtype=="float32":
    dtype = jnp.float32
  elif FLAGS.dtype=="bfloat16":
    dtype = jnp.bfloat16
  else:
      raise NotImplementedError

  print("dtype is:", dtype)
  #Generate data
  (train_inputs, _), (test_inputs, test_labels) = mnist.load_data()
  train_inputs = jnp.array(train_inputs).astype(jnp.float32)
  test_inputs = jnp.array(test_inputs).astype(jnp.float32)

  # Rescale input images to [0, 1]
  train_inputs = jnp.reshape(train_inputs, [-1, 784]) / 255.0
  test_inputs = jnp.reshape(test_inputs, [-1, 784]) / 255.0

  train_inputs = train_inputs.astype(dtype)
  test_inputs = test_inputs.astype(dtype)

  num_train_examples = train_inputs.shape[0]
  num_test_examples = test_inputs.shape[0]
  print('MNIST dataset:')
  print('Num train examples: ' + str(num_train_examples))
  print('Num test examples: ' + str(num_test_examples))

  batch_size = FLAGS.batch_size

  encoder_sizes = [1000] +  [500] * FLAGS.model_depth_multiplier + [250, 30]
  decoder_sizes = [250] +  [500] * FLAGS.model_depth_multiplier + [1000]

  encoder_sizes = [FLAGS.model_size_multiplier * e for e in encoder_sizes]
  decoder_sizes = [FLAGS.model_size_multiplier * e for e in decoder_sizes]
  encoder_decoder_sizes = encoder_sizes, decoder_sizes

  input_image_batch = np.random.normal(size=(batch_size,784))
  input_image_batch = jnp.array(input_image_batch).astype(dtype)

  #Set learning rate schedule array
  num_epochs = FLAGS.epochs
  warmup_epochs = FLAGS.warmup_epochs
  lr = FLAGS.lr
  lrVec = np.concatenate([np.linspace(0,lr,warmup_epochs),
                          np.linspace(lr,0,num_epochs-warmup_epochs+2)[1:-1]],
                         axis=0)
  lrVec = jnp.array(lrVec).astype(dtype)
  def autoencoder_shedule(lrVec):
    def schedule(count):
      bucket = count//60
      return lrVec[bucket]
    return schedule

  train_loss_val_=[]
  model = Autoencoder(encoder_sizes, decoder_sizes, dtype=dtype, param_dtype=dtype)
  params = model.init(key3, input_image_batch)
  state = create_train_state(params, model, FLAGS.optimizer, autoencoder_shedule(lrVec))
  print("Initialized model and optimizer!")
  global global_grads_list
  for i in range(num_epochs):
    rng, key = jax.random.split(rng)
    epoch_start = time.time()
    state, grads = train_epoch(state, model, train_inputs, FLAGS.batch_size, i, key, lrVec)
    print("this epoch time:", time.time()-epoch_start)
    global_grads_list += grads
    train_loss_val = eval_step(model, state, train_inputs)
    train_loss_val_.append(train_loss_val)
    print("epoch: " + str(i) +", train_loss_val: " + str(train_loss_val))
    print("")
  print("training time:", time.time()-train_start)
#   return global_grads_list

In [15]:
############# GET GRADIENTS USING tds OPTIMIZER ################
app.run(main)

I1019 17:00:34.867702 140604106622784 xla_bridge.py:345] Unable to initialize backend 'tpu_driver': NOT_FOUND: Unable to find driver in registry given worker: 
I1019 17:00:36.429756 140604106622784 xla_bridge.py:345] Unable to initialize backend 'rocm': NOT_FOUND: Could not find registered platform with name: "rocm". Available platform names are: Host Interpreter CUDA
I1019 17:00:36.431886 140604106622784 xla_bridge.py:345] Unable to initialize backend 'tpu': module 'jaxlib.xla_extension' has no attribute 'get_tpu_client'


dtype is: <class 'jax.numpy.float32'>
MNIST dataset:
Num train examples: 60000
Num test examples: 10000
using tds optimizer to generate gradients
Initialized model and optimizer!
epoch: 0 and lr going to be used: 0.0


  abs_value_flat = jax.tree_leaves(abs_value)
  value_flat = jax.tree_leaves(value)


this epoch time: 16.74345064163208
epoch: 0, train_loss_val: 544.74426

epoch: 1 and lr going to be used: 2.5e-05
this epoch time: 0.2474362850189209
epoch: 1, train_loss_val: 90.72567

epoch: 2 and lr going to be used: 5e-05
this epoch time: 0.2123701572418213
epoch: 2, train_loss_val: 75.64782

epoch: 3 and lr going to be used: 7.5e-05
this epoch time: 0.21440744400024414
epoch: 3, train_loss_val: 72.10297

epoch: 4 and lr going to be used: 1e-04
this epoch time: 0.21770763397216797
epoch: 4, train_loss_val: 71.94248

epoch: 5 and lr going to be used: 9.8958335e-05
this epoch time: 0.21609210968017578
epoch: 5, train_loss_val: 69.91649

epoch: 6 and lr going to be used: 9.7916665e-05
this epoch time: 0.21682119369506836
epoch: 6, train_loss_val: 103.98162

epoch: 7 and lr going to be used: 9.6875e-05
this epoch time: 0.23123979568481445
epoch: 7, train_loss_val: 78.250496

epoch: 8 and lr going to be used: 9.583333e-05
this epoch time: 0.23268604278564453
epoch: 8, train_loss_val: 72

this epoch time: 0.2052001953125
epoch: 68, train_loss_val: 58.379463

epoch: 69 and lr going to be used: 3.2291668e-05
this epoch time: 0.20019030570983887
epoch: 69, train_loss_val: 58.167492

epoch: 70 and lr going to be used: 3.125e-05
this epoch time: 0.2182753086090088
epoch: 70, train_loss_val: 57.94629

epoch: 71 and lr going to be used: 3.0208334e-05
this epoch time: 0.224928617477417
epoch: 71, train_loss_val: 57.91194

epoch: 72 and lr going to be used: 2.9166667e-05
this epoch time: 0.22822093963623047
epoch: 72, train_loss_val: 57.74287

epoch: 73 and lr going to be used: 2.8125e-05
this epoch time: 0.23056745529174805
epoch: 73, train_loss_val: 57.61487

epoch: 74 and lr going to be used: 2.7083333e-05
this epoch time: 0.21734929084777832
epoch: 74, train_loss_val: 57.46028

epoch: 75 and lr going to be used: 2.6041667e-05
this epoch time: 0.22044610977172852
epoch: 75, train_loss_val: 57.213398

epoch: 76 and lr going to be used: 2.5e-05
this epoch time: 0.21688270568847

2022-10-19 17:01:45.134419: W external/org_tensorflow/tensorflow/core/common_runtime/bfc_allocator.cc:479] Allocator (GPU_0_bfc) ran out of memory trying to allocate 408.33MiB (rounded to 428160000)requested by op 
2022-10-19 17:01:45.228100: W external/org_tensorflow/tensorflow/core/common_runtime/bfc_allocator.cc:491] ****************************************************************************************************
2022-10-19 17:01:45.228488: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2129] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 428160000 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:  190.27MiB
              constant allocation:         4B
        maybe_live_out allocation:         4B
     preallocated temp allocation:  408.33MiB
  preallocated temp fragmentation:         0B (0.00%)
                 total allocation:  598.59MiB
Peak buffe

ValueError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 428160000 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:  190.27MiB
              constant allocation:         4B
        maybe_live_out allocation:         4B
     preallocated temp allocation:  408.33MiB
  preallocated temp fragmentation:         0B (0.00%)
                 total allocation:  598.59MiB
Peak buffers:
	Buffer 1:
		Size: 228.88MiB
		Operator: op_name="jit(eval_step)/jit(main)/Autoencoder/tanh" source_file="/tmp/ipykernel_219857/3037788847.py" source_line=19
		XLA Label: tanh
		Shape: f32[60000,1000]
		==========================

	Buffer 2:
		Size: 179.44MiB
		Operator: op_name="jit(eval_step)/jit(main)/Autoencoder/Dense_7/add" source_file="/home/devvrit/anaconda3/envs/env/lib/python3.7/site-packages/flax/linen/linear.py" source_line=200
		XLA Label: broadcast
		Shape: f32[60000,784]
		==========================

	Buffer 3:
		Size: 179.44MiB
		Entry Parameter Subshape: f32[60000,784]
		==========================

	Buffer 4:
		Size: 2.99MiB
		Entry Parameter Subshape: f32[1000,784]
		==========================

	Buffer 5:
		Size: 2.99MiB
		Entry Parameter Subshape: f32[784,1000]
		==========================

	Buffer 6:
		Size: 1.91MiB
		Entry Parameter Subshape: f32[500,1000]
		==========================

	Buffer 7:
		Size: 1.91MiB
		Entry Parameter Subshape: f32[1000,500]
		==========================

	Buffer 8:
		Size: 488.3KiB
		Entry Parameter Subshape: f32[250,500]
		==========================

	Buffer 9:
		Size: 488.3KiB
		Entry Parameter Subshape: f32[500,250]
		==========================

	Buffer 10:
		Size: 29.3KiB
		Entry Parameter Subshape: f32[30,250]
		==========================

	Buffer 11:
		Size: 29.3KiB
		Entry Parameter Subshape: f32[250,30]
		==========================

	Buffer 12:
		Size: 3.9KiB
		Entry Parameter Subshape: f32[1000]
		==========================

	Buffer 13:
		Size: 3.9KiB
		Entry Parameter Subshape: f32[1000]
		==========================

	Buffer 14:
		Size: 3.1KiB
		Entry Parameter Subshape: f32[784]
		==========================

	Buffer 15:
		Size: 2.0KiB
		Entry Parameter Subshape: f32[500]
		==========================



In [16]:
len(global_grads_list) # should be 60*FLAGS.epochs

5460

In [17]:
num_edges_removed_tds = 0
optimizer_tds = tds(0.0, b1=FLAGS.beta1, b2=FLAGS.beta2, eps=FLAGS.eps, transpose=True, adam_grafting=False)
tds_state = optimizer_tds.init(global_grads_list[0])
log_det_tds = jax.tree_map(lambda g: jnp.zeros(len(global_grads_list)), global_grads_list[0])
eps = FLAGS.eps
for i, grads in tqdm(enumerate(global_grads_list)):
  updates, new_opt_state = optimizer_tds.update(grads, tds_state, None)
  log_det_tds = jax.tree_map(lambda l, v: l.at[i].set(v), log_det_tds, new_opt_state[0].logdet)
  tds_state = new_opt_state


5460it [02:29, 36.48it/s]


In [18]:
num_edges_removed_bds = 0
optimizer_bds = bds(0.0, beta1=FLAGS.beta1, beta2=FLAGS.beta2, eps=FLAGS.eps, graft_eps=0.0, weight_decay=0.0, b=4, transpose=True, graft_type=0)
bds_state = optimizer_bds.init(global_grads_list[0])
log_det_bds = jax.tree_map(lambda g: jnp.zeros(len(global_grads_list)), global_grads_list[0])
eps = FLAGS.eps
for i, grads in tqdm(enumerate(global_grads_list)):
  updates, new_opt_state = optimizer_bds.update(grads, bds_state, grads)
  log_det_bds = jax.tree_map(lambda l, v: l.at[i].set(v), log_det_bds, new_opt_state[0].logdet)
  bds_state = new_opt_state


2022-10-19 17:07:33.894007: W external/org_tensorflow/tensorflow/core/common_runtime/bfc_allocator.cc:479] Allocator (GPU_0_bfc) ran out of memory trying to allocate 149.54MiB (rounded to 156800000)requested by op 
2022-10-19 17:07:34.016604: W external/org_tensorflow/tensorflow/core/common_runtime/bfc_allocator.cc:491] ****************************************************************************************************
2022-10-19 17:07:34.016765: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2129] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 156800000 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:  149.54MiB
              constant allocation:         0B
        maybe_live_out allocation:  149.54MiB
     preallocated temp allocation:         0B
                 total allocation:  299.07MiB
              total fragmentation:         0B (0.00%)
Peak buffe

XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 156800000 bytes.
BufferAssignment OOM Debugging.
BufferAssignment stats:
             parameter allocation:  149.54MiB
              constant allocation:         0B
        maybe_live_out allocation:  149.54MiB
     preallocated temp allocation:         0B
                 total allocation:  299.07MiB
              total fragmentation:         0B (0.00%)
Peak buffers:
	Buffer 1:
		Size: 149.54MiB
		Operator: op_name="jit(concatenate)/jit(main)/concatenate[dimension=0]" source_file="/tmp/ipykernel_219857/543851162.py" source_line=197
		XLA Label: concatenate
		Shape: s32[2,784000,5,5]
		==========================

	Buffer 2:
		Size: 74.77MiB
		Entry Parameter Subshape: s32[1,784000,5,5]
		==========================

	Buffer 3:
		Size: 74.77MiB
		Entry Parameter Subshape: s32[1,784000,5,5]
		==========================



In [None]:
######### PLOTTING #############

In [None]:
def flatten(p, label=None):
  if isinstance(p, dict):
    for k, v in p.items():
      yield from flatten(v, k if label is None else f"{label}.{k}")
  else:
    yield (label, p)

In [None]:
def plot(name, a, b, summary_writer_true, summary_writer_false):
  with summary_writer_true.as_default():
    for step, i, in enumerate(a):
      tf.summary.scalar(name, i.item(), step=step)
  with summary_writer_false.as_default():
    for step, i, in enumerate(b):
      tf.summary.scalar(name, i.item(), step=step)

In [None]:
# !mkdir logs
# !rm -rf logs
log_det_tds_overall = jnp.zeros(len(global_grads_list))
log_det_bds_overall = jnp.zeros(len(global_grads_list))
import tensorflow as tf
# current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = 'logs/tds'
summary_writer_tds = tf.summary.create_file_writer(log_dir)

# current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
log_dir = 'logs/bds'
summary_writer_bds = tf.summary.create_file_writer(log_dir)

try:
  log_det_tds = dict(flatten(log_det_tds.unfreeze()))
except:
  log_det_tds = dict(flatten(log_det_tds))
try:
  log_det_bds = dict(flatten(log_det_bds.unfreeze()))
except:
  log_det_bds = dict(flatten(log_det_bds))
for k, v in log_det_tds.items():
  plot(k, log_det_tds[k], log_det_bds[k], summary_writer_tds, summary_writer_bds)
  log_det_tds_overall+=log_det_tds[k]
  log_det_bds_overall+=log_det_bds[k]

plot("overall_logdet", log_det_tds_overall, log_det_bds_overall, summary_writer_tds, summary_writer_bds)

In [None]:
%load_ext tensorboard

In [None]:
!tensorboard --logdir logs --port 6006

In [None]:
!kill 198880

In [None]:
tds_state[0].num_edges_removed

In [None]:
# for b=1
bds_state[0].num_edges_removed

In [None]:
# !rm -rf logs

In [None]:
# for b=4
bds_state[0].num_edges_removed