In [None]:
import torch as ch\n
import torchvision\n
from torchvision import transforms\n
from torch.nn import CrossEntropyLoss\n
from torch.optim import SGD, lr_scheduler, AdamW\n
from torch.func import jacfwd, jacrev\n
from torch import vmap\n
import numpy as np\n
from torch import nn\n
import torch as ch\n
import einops\n
\n
from attacks import PGD\n
\n
import ml_collections\n
from tqdm import tqdm\n
import os\n
import time\n
import logging\n
\n
import wandb

# Configuration Settings

In [None]:
def get_config():\n
  \
\
\
\n
  # Creates an ml_collections.ConfigDict and fills in optimizer, batch-size,\n
  # logging, wandb, MLP and local-complexity settings.\n
  # NOTE(review): the return statement and any remaining fields are not visible in this excerpt.\n
  config = ml_collections.ConfigDict()\n
\n
  config.optimizer = 'adam'\n
  config.lr = 1e-4\n
  # NOTE(review): optimizer is 'adam'; whether momentum is read anywhere is not visible here -- confirm\n
  config.momentum = 0.01\n
  config.bias_decay = False\n
\n
  config.train_batch_size = 64 #256\n
  config.test_batch_size = 64 #256\n
\n
  config.num_steps = 1000000                       # number of training steps\n
  config.weight_decay = 0.01\n
\n
  config.label_smoothing = 0.0\n
\n
  # 50 log-spaced points from 1 to num_steps, truncated to int, clipped to\n
  # [0, num_steps], then sorted and de-duplicated by np.unique\n
  config.log_steps = np.unique(\n
      np.logspace(0,np.log10(config.num_steps),50).astype(int).clip(\n
          0,config.num_steps\n
          )\n
      )\n
  \n
  config.seed = 42\n
\n
  #config.dmax = 1\n
  #config.dmin = 0\n
\n
  config.wandb_proj = 'LRLC_study_MNIST_MLP'\n
  config.wandb_pref = 'MNIST-MLP'\n
\n
  config.resume_step = 0\n
\n
  ## mlp params\n
  config.input_dim = 784\n
  config.output_dim = 10\n
  config.hidden_dim = 200\n
  config.n_layers = 4\n
  config.input_weights_init_scale = np.sqrt(2) #sets initialization standard deviation to be = (input_weights_init_scale)/sqrt(input_dim) \n
  # presumably the analogous scale for the output layer -- verify against the MLP initialisation\n
  config.output_weights_init_scale = np.sqrt(2)\n
\n
\n
  ## local complexity approx. parameters\n
  config.compute_LC = True\n
  config.n_batches = 2\n
  config.sigma = 0.01\n
  config.n_iters_LC = 1\n
\n
  ## adv robustness parameters

# Data Generation

In [None]:
from torch.utils.data import Subset\n
import random\n
\n
train_dataset = torchvision.datasets.MNIST('../mnist_data',\n
                                           # training split rooted at ../mnist_data, with downloading enabled\n
                                           # NOTE(review): the rest of this call is not visible in this excerpt\n
                                           download=True,\n
                                           train=True,\n
                                           transform=transforms.Compose([\n
                                               transforms.ToTensor(), # first, convert image to PyTorch tensor\n
                                               transforms.Normalize((0.1307,), (0.3081,)), # normalize inputs\n
                                               lambda x: einops.rearrange(x, 'c h w -> (c h w)') # flatten the input images\n

# Model Definition

In [None]:
class MLP(nn.Module):\n
    # Multi-layer perceptron module; only the start of its constructor is visible in this excerpt.\n
    def __init__(self, input_dim, hidden_dim, output_dim, n_layers, input_std_init=1, output_std_init=1):\n
        # NOTE(review): hidden_dim, input_std_init and output_std_init are not used in the\n
        # visible lines -- presumably consumed when the layers are built below; confirm.\n
        super(MLP, self).__init__()\n
        self.input_dim = input_dim\n
        self.output_dim = output_dim\n
        self.n_layers = n_layers\n
\n
        # Define input layer\n
        # ModuleList container, so layers appended to it are registered as submodules\n
        self.layers = nn.ModuleList()\n

# Eval Functions

In [None]:
@ch.no_grad\n
def evaluate(model, dloader, loss_fn=None):\n
  # Puts `model` in eval mode and iterates over every (inputs, targets) batch of\n
  # `dloader`, accumulating the count of correct argmax predictions and, when\n
  # `loss_fn` is given, the sum of per-batch loss values.\n
  # NOTE(review): the decorator is applied without parentheses; presumably this needs a\n
  # PyTorch release that accepts bare `@no_grad` -- confirm, or write `@ch.no_grad()`.\n
  # NOTE(review): the rest of the function (nsamples update, return value) is not\n
  # visible in this excerpt.\n
  \n
  model.eval()\n
\n
  acc = 0\n
  loss = 0\n
  nsamples = 0\n
  nbatch = 0\n
  \n
  for inputs, targets in dloader:\n
      \n
      # batches are moved with .cuda(), so a CUDA device is required\n
      inputs = inputs.cuda()\n
      targets = targets.cuda()\n
      outputs = model(inputs)\n
\n
      if loss_fn is not None:\n
        # .item() turns the batch loss into a Python number; nbatch only counts batches where a loss was computed\n
        loss += loss_fn(outputs, targets).item()\n
        nbatch += 1\n
              \n
      # predicted class = argmax over the last dimension; after this line acc is a CPU tensor\n
      acc += ch.sum(targets == outputs.argmax(dim=-1)).cpu()\n

# Training Loop

In [None]:
def train(model, loaders, config):\n
    # Training entry point. The visible part prints the logging schedule, moves `model`\n
    # to the GPU and begins iterating over its named parameters.\n
    # NOTE(review): the rest of the function is not visible in this excerpt; how `loaders`\n
    # is used cannot be confirmed from here.\n
    print('Training....')\n
    print(f'Logging at steps: {config.log_steps}')\n
\n
    model.cuda()\n
\n
    # No Weight Decay on Biases\n
    # two dicts prepared for the parameter split; how they are filled is not visible here\n
    decay = dict()\n
    no_decay = dict()\n
    for name, param in model.named_parameters():\n
        print('checking {}'.format(name))\n
        # substring test on the parameter name; the branch body is cut off in this excerpt\n
        if 'weight' in name:\n

# Run

In [None]:
wandb_project = config.wandb_proj\n
# NOTE(review): `config` is read above but its assignment is not visible in this excerpt --\n
# confirm it is created (e.g. from get_config()) before this cell runs.\n
# time.ctime() string with spaces replaced by underscores\n
timestamp = time.ctime().replace(' ','_')\n
# NOTE(review): the f-string below appears truncated in this copy of the notebook -- restore it from the original\n
wandb_run_name = f\
\n
wandb.init(project=wandb_project, name=wandb_run_name, config=config)