In [1]:
## Demo codes for "Deep Networks Always Grok and Here is Why", ArXiv 2024
## Authors: Ahmed Imtiaz Humayun, Randall Balestriero, Richard Baraniuk
## Website: bit.ly/grok-adversarial
## Wandb Dashboard containing example logs: bit.ly/grok-adv-trak

In [2]:
#@title License
#
#The MIT License (MIT)
#Copyright © 2024 Ahmed Imtiaz Humayun
#Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
#The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#

#@title Setup

%cd /content/
!git clone https://github.com/AhmedImtiazPrio/grok-adversarial.git
%cd 'grok-adversarial'

!pip install ml_collections
!pip install wandb
!pip install numba opencv-python
!pip install einops

In [3]:
#@title Imports
%cd 'grok-adversarial'

import torch as ch
import torchvision
from torchvision import transforms
from torch.nn import CrossEntropyLoss, BCELoss
from torch.optim import SGD, lr_scheduler, AdamW
import numpy as np

import ml_collections
from tqdm import tqdm
import os
import time
import logging

import wandb


from dataloaders import cifar10_dataloaders, cifar10_dataloaders_ffcv, get_LC_samples
from models import make_resnet18k
from utils import flatten_model, add_hooks_preact_resnet18
from attacks import PGD
from local_complexity import get_intersections_for_hulls
from samplers import get_ortho_hull_around_samples, get_ortho_hull_around_samples_w_orig

%cd ../..
import splinecam
%cd 'grokking/grok-adversarial'

  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


/home/niket/grokking/grok-adversarial
/home/niket
/home/niket/grokking/grok-adversarial


  self.shell.db['dhist'] = compress_dhist(dhist)[-100:]


In [4]:
#@title Train and evaluation functions

def train(model, loaders, config, add_hook_fn, hulls=None):
    """Train `model` and record statistics at the steps listed in `config.log_steps`.

    Args:
        model: network to train. It is moved to CUDA. It must expose
            `input_layer.weight` and `output_layer.weight`, whose norms are
            logged, and must emit one logit per sample because the loss is
            `BCEWithLogitsLoss`.
        loaders: dict with `'train'` and `'test'` iterables of `(inputs, labels)`.
        config: configuration object; the fields read here are the optimizer
            settings, `resume_*`, `model_dir`, `num_steps`, `log_steps`,
            `lr_schedule_flag`, `save_model`, `splinecam*`, `mu_1`/`mu_2`,
            `compute_LC`, `LC_batch_size`, `compute_robust`, `wandb_log`,
            `use_ffcv` and `num_class`.
        add_hook_fn: callable `(model, config) -> (model, layer_names,
            activation_buffer)` that registers the forward hooks used for the
            local-complexity computation.
        hulls: optional dict mapping a name to a batch of neighbourhoods. One
            `<name>_LC` statistic is tracked per key. `None` disables
            local-complexity tracking.

    Returns:
        dict mapping each statistic name to the list of values recorded at the
        logged steps.

    Raises:
        NotImplementedError: if `config.optimizer` is neither `'sgd'` nor `'adam'`.
    """

    model.cuda()

    ## setup optimizer
    if config.optimizer == 'sgd':
        print('Using SGD optimizer')
        opt = SGD(model.parameters(),
                  lr=config.lr,
                  momentum=config.momentum,
                  weight_decay=config.weight_decay)

    elif config.optimizer == 'adam':
        opt = AdamW(model.parameters(),
                    lr=config.lr,
                    weight_decay=config.weight_decay)

    else:
        raise NotImplementedError(f"Unknown optimizer: {config.optimizer!r}")

    ## resume training
    if config.resume_step > 0 and config.resume_dir is not None:

        # The existence check must look in the same directory the checkpoint
        # is loaded from (it previously read an undefined `config.load_dir`).
        resume_path = os.path.join(config.resume_dir,
                                   f'checkpoint-s:{config.resume_step}.pt')
        assert os.path.exists(resume_path), \
            f"Resume checkpoint not found: {resume_path}"

        # 'checkpoint-s:-1.pt' holds the pickled model/optimizer objects saved
        # before training; the step checkpoint only holds their state dicts.
        # NOTE(review): loading whole pickled objects may need
        # `weights_only=False` on recent PyTorch releases -- confirm.
        base_chkpt = ch.load(os.path.join(config.resume_dir,
                                          f'checkpoint-s:{-1}.pt'))
        model = base_chkpt['model']
        opt = base_chkpt['optimizer']
        state_chkpt = ch.load(resume_path)
        model.load_state_dict(state_chkpt['model_state_dict'])
        opt.load_state_dict(state_chkpt['optimizer_state_dict'])

    ### save model and optimizer before training
    ### shortcut to avoid removing hooks later
    ch.save(
        {
            'model': model,
            'optimizer': opt,
        },
        os.path.join(config.model_dir, f'checkpoint-s:{-1}.pt')
    )

    iters_per_epoch = len(loaders['train'])
    epochs = np.floor(config.num_steps / iters_per_epoch)

    if config.lr_schedule_flag:
        # Cyclic LR with single triangle
        print('Using Learning Rate Schedule')
        lr_schedule = np.interp(np.arange((epochs + 1) * iters_per_epoch),
                                [0, config.lr_peak_epoch * iters_per_epoch, epochs * iters_per_epoch],
                                [0, 1, 0])

        # The loop below only stops at an epoch boundary, so the scheduler can
        # be stepped up to len(lr_schedule) times and its last query would
        # index one past the table; clamp to the final entry instead.
        last_idx = len(lr_schedule) - 1
        scheduler = lr_scheduler.LambdaLR(
            opt, lambda step: lr_schedule[min(step, last_idx)]
        )

    loss_fn = ch.nn.BCEWithLogitsLoss()

    train_step = 0 if config.resume_step <= 0 else config.resume_step

    ## stat dict for plotting convenience
    hull_names = list(hulls.keys()) if hulls is not None else []
    stat_names = ['train_acc', 'train_loss', 'test_loss',
                  'test_acc', 'adv_acc', 'train_step', 'l2', 'last_layer_l2'] + \
                 [each + '_LC' for each in hull_names]

    stats = {name: [] for name in stat_names}

    print(f'Logging stats for steps:{config.log_steps}')

    # (layer_names, activation_buffer), filled the first time LC is computed.
    # Hooks are registered once; re-registering at every log step would stack
    # duplicate hooks on each layer.
    hook_state = None

    # Training proceeds in whole epochs: the step budget is only checked
    # between passes over the training loader.
    while train_step <= config.num_steps:

        for ims, labs in tqdm(loaders['train'], desc=f"train_step:{train_step}-{train_step+iters_per_epoch}"):

            ### train step
            ims = ims.cuda()
            labs = labs.cuda()

            if config.use_ffcv and labs.max() > config.num_class - 1:
                labs = ch.clip(labs, 0, config.num_class - 1)  ## weird ffcv bug

            opt.zero_grad()
            out = model(ims)

            loss = loss_fn(out, labs)
            loss.backward()
            opt.step()
            train_step += 1

            if config.lr_schedule_flag:
                scheduler.step()

            ### log step
            if train_step in config.log_steps:
                print('Computing stats...')

                model.eval()

                ### checkpoint before anything
                if config.save_model:
                    ch.save(
                        {
                            'model_state_dict': model.state_dict(),
                            'optimizer_state_dict': opt.state_dict(),
                        },
                        os.path.join(config.model_dir,
                                     f'checkpoint-s:{train_step}.pt')
                    )

                ## evaluate on train and test
                train_acc, train_loss = evaluate(model, loaders['train'], loss_fn)
                test_acc, test_loss = evaluate(model, loaders['test'], loss_fn)
                l2_norm_in = ch.linalg.norm(model.input_layer.weight).item()
                l2_last_layer_norm = ch.linalg.norm(model.output_layer.weight).item()

                stats['train_acc'].append(train_acc)
                stats['test_acc'].append(test_acc)
                stats['train_loss'].append(train_loss)
                stats['test_loss'].append(test_loss)
                stats['train_step'].append(train_step)
                stats['l2'].append(l2_norm_in)
                stats['last_layer_l2'].append(l2_last_layer_norm)

                if config.splinecam:
                    print('Wrapping model with SplineCam...')

                    # NOTE(review): the wrapped network `NN` is not used after
                    # construction, and `model.input_shape` must be provided
                    # by the model -- confirm the intended SplineCam usage.
                    domain = (
                    (-config.mu_1 * config.splinecam_domain, config.mu_1 * config.splinecam_domain),
                    (config.mu_1 * config.splinecam_domain, config.mu_2 * config.splinecam_domain),
                    (config.mu_2 * config.splinecam_domain, -config.mu_2 * config.splinecam_domain),
                    (-config.mu_2 * config.splinecam_domain, config.mu_1 * config.splinecam_domain),
                    )

                    T = splinecam.utils.get_proj_mat(domain)
                    NN = splinecam.wrappers.model_wrapper(
                        model,
                        input_shape=model.input_shape,
                        T = T,
                        dtype = ch.float64,
                        device = 'cuda'
                    )

                ## evaluate local complexity
                if config.compute_LC:

                    # add hooks (first log step only)
                    if hook_state is None:
                        model, layer_names, activation_buffer = add_hook_fn(model, config)
                        hook_state = (layer_names, activation_buffer)
                    layer_names, activation_buffer = hook_state

                    for k in hull_names:

                        # compute number of neurons that intersect hulls
                        # using network activations
                        with ch.no_grad():
                            n_inters, _ = get_intersections_for_hulls(
                                hulls[k],
                                model=model,
                                batch_size=config.LC_batch_size,
                                layer_names=layer_names,
                                activation_buffer=activation_buffer
                            )

                        stats[k + '_LC'].append(n_inters.cpu())

                ## evaluate robustness
                if config.compute_robust:
                    adv_acc = evaluate_adv(model, loaders['test'], config)
                    stats['adv_acc'].append(adv_acc)

                if config.wandb_log:
                    log_entry = {
                        'iter': train_step,
                        'train/acc': stats['train_acc'][-1],
                        'train/loss': stats['train_loss'][-1],
                        'test/acc': stats['test_acc'][-1],
                        'test/loss': stats['test_loss'][-1],
                        'l2': stats['l2'][-1],
                        'last_layer_l2': stats['last_layer_l2'][-1],
                    }

                    # Only report quantities that were computed at this step;
                    # indexing an empty stat list would raise IndexError.
                    if config.compute_LC:
                        for k in hull_names:
                            # the 'rand' hulls keep their 'random/LC' dashboard key
                            tag = 'random' if k == 'rand' else k
                            log_entry[f'{tag}/LC'] = stats[k + '_LC'][-1].sum(1).mean(0)

                    if config.compute_robust:
                        log_entry['adv/acc'] = stats['adv_acc'][-1]

                    wandb.log(log_entry)

                model.train()

    ## save after training is complete
    ch.save(
        {
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': opt.state_dict(),
        },
        os.path.join(config.model_dir, f'checkpoint-s:{train_step}.pt')
    )

    return stats

@ch.no_grad()
def evaluate(model, dloader, loss_fn=None):
    """Compute binary accuracy (and optionally the mean batch loss) of `model`.

    The model is expected to return one raw logit per sample (this file trains
    with `BCEWithLogitsLoss`), so a sample is predicted positive when its logit
    is greater than 0, i.e. when its sigmoid probability exceeds 0.5.

    Args:
        model: callable mapping a batch of inputs to per-sample logits.
        dloader: iterable of `(inputs, labels)` batches.
        loss_fn: optional callable `(logits, labels) -> scalar loss`.

    Returns:
        `(accuracy, mean_loss)`. `accuracy` is the fraction of correctly
        classified samples; `mean_loss` is the loss averaged over batches, or
        `None` when `loss_fn` is not given.
    """

    # Run on whatever device the model lives on; fall back to CUDA (the
    # previous hard-coded behavior) when the model exposes no parameters.
    try:
        device = next(model.parameters()).device
    except (StopIteration, AttributeError):
        device = ch.device('cuda')

    acc = 0
    loss = 0
    nsamples = 0
    nbatch = 0

    for ims, labs in dloader:

        ims = ims.to(device)
        labs = labs.to(device)

        outs = model(ims)

        if loss_fn is not None:
            loss += loss_fn(outs, labs)
            nbatch += 1

        # logits are thresholded at 0 (equivalent to probability 0.5)
        acc += ch.sum(labs == (outs > 0)).cpu()
        nsamples += outs.shape[0]

    # Without a loss function there is nothing to average (avoids 0/0).
    mean_loss = loss / nbatch if nbatch > 0 else None

    return acc / nsamples, mean_loss

def evaluate_adv(model, dloader, config):
    """Accuracy of `model` on PGD-perturbed versions of the batches in `dloader`.

    Args:
        model: network under attack.
        dloader: iterable of `(inputs, labels)` batches.
        config: supplies the attack settings `atk_eps`, `atk_alpha`,
            `atk_itrs`, `dmin` and `dmax`, which are passed to `PGD`.

    Returns:
        Fraction of adversarial samples that are classified correctly.
    """

    atk = PGD(model,
              eps=config.atk_eps,
              alpha=config.atk_alpha,
              steps=config.atk_itrs,
              dmin=config.dmin,
              dmax=config.dmax
              )

    # Run on whatever device the model lives on; fall back to CUDA (the
    # previous hard-coded behavior) when the model exposes no parameters.
    try:
        device = next(model.parameters()).device
    except (StopIteration, AttributeError):
        device = ch.device('cuda')

    acc = 0
    nsamples = 0
    for ims, labs in tqdm(dloader, desc=f"Computing robust acc for eps:{config.atk_eps:.3f}"):

        ims = ims.to(device)
        labs = labs.to(device)

        adv_images = atk(ims, labs)

        with ch.no_grad():
            logits = model(adv_images)

        if logits.shape == labs.shape:
            # One logit per sample (binary classifier): threshold at 0.
            # An argmax over the last axis here would collapse the whole
            # batch into a single index.
            adv_pred = logits > 0
        else:
            # One score per class: predicted label is the arg-max class.
            adv_pred = logits.argmax(dim=-1)

        acc += ch.sum(labs == adv_pred).cpu()
        nsamples += len(labs)

    return acc / nsamples

In [5]:
def add_hooks_MLP(model, config, verbose=False):
    """
    Register forward hooks on the `nn.Linear` layers of an MLP.

    Each hook stores the detached output of its layer in a shared dict, keyed
    by the layer's name as returned by `flatten_model`.

    Args:
        model: network whose linear layers should be hooked.
        config: unused here; kept so the signature matches the `add_hook_fn`
            argument expected by `train`.
        verbose: print the names of the hooked layers.

    Returns:
        `(model, layer_names, activation)`: the same model object, the sorted
        array of hooked layer names, and the dict the hooks write into.
    """

    names, modules = flatten_model(model)
    assert len(names) == len(modules)

    ## add hooks to linear layers
    # dtype=int keeps the index array valid even when no layer matches
    # (an empty list would otherwise become a float array and fail to index).
    layer_ids = np.asarray(
        [i for i, each in enumerate(modules) if type(each) is ch.nn.Linear],
        dtype=int,
    )

    activation = {}

    def get_activation(name):
        def hook(model, input, output):
            activation[name] = output.detach()
        return hook

    for each in layer_ids:
        modules[each].register_forward_hook(get_activation(names[each]))

    # Names are sorted alphabetically, not in forward order.
    layer_names = np.sort(np.asarray(names)[layer_ids])

    if verbose:
        print('Adding Hook to', layer_names)

    return model, layer_names, activation

In [6]:
def get_config():
  """Build the hyperparameter configuration for the XOR-cluster MLP experiment.

  Returns:
      An `ml_collections.ConfigDict` with optimizer, logging, local-complexity,
      robustness and dataset settings, plus the two cluster centres
      `mu_1` / `mu_2`, which are orthogonal vectors of norm `mu_norm` drawn
      reproducibly from `seed`.
  """
  config = ml_collections.ConfigDict()

  config.optimizer = 'sgd'
  config.lr = 1e-3
  config.momentum = 0.00
  config.lr_schedule_flag = False

  config.train_batch_size = 100
  config.test_batch_size = 100
  config.num_steps = 500000                       # number of training steps
  config.weight_decay = 0.1
  config.label_smoothing = 0.
  # ~50 log-spaced logging steps between 1 and num_steps
  config.log_steps = np.unique(
      np.logspace(0,5.7,50).astype(int).clip(
          0,config.num_steps
          )
      )
  config.seed = 42
  config.use_aug = False
  config.normalize = True                        # rescale cifar10 to have mean 0 std 1.25

  if config.normalize:
      config.dmax = 2.7537                       # precomputed data max/min needed for PGD
      config.dmin = -2.4291
  else:
      config.dmax = 1
      config.dmin = 0


  config.save_model = False                      # save every model checkpoint
  config.wandb_log = True                        # log using wandb
  config.wandb_proj = 'grok-adv-MLP-XOR'
  config.wandb_pref = 'XOR-MLP'
  config.use_ffcv = True

  ## resnet params
  config.k = 16                                  # Resnet width parameter, number of filters in first layer
  config.num_class = 10
  config.use_bn = False
  config.resume_dir = None                       # resume directory absolute path
  config.resume_step = -1                        # time step to resume, from resume directory

  ## local complexity approx. parameters
  config.compute_LC = True
  config.approx_n = 256                         # number of samples to use for approximation
  config.n_frame = 100                            # number of vertices for neighborhood
  config.r_frame = 0.0001                         # radius of \ell-1 ball neighborhood
  config.LC_batch_size = 256
  config.inc_centroid = False                     # include original sample as neighborhood vertex


  ## adv robustness parameters
  config.compute_robust = True                   # note that if normalize==True, data is not bounded between [0,1]
  config.atk_eps = 8/255   ## 8/255
  config.atk_alpha = 10/255  ## 2/255
  config.atk_itrs = 5

  config.splinecam = True
  config.splinecam_domain = 10

  ## XOR dataset / MLP parameters
  config.input_dimension = 4000

  config.mu_norm = 5
  config.train_flip_prob = 0.01
  config.dataset_size_4 = 1000
  config.test_dataset_size_4 = 1000

  config.hidden_dim = 50000

  # Two orthonormal directions (QR of a Gaussian matrix) scaled to mu_norm.
  # Drawn from a generator seeded with config.seed so that the cluster
  # centres are identical across runs.
  rng = np.random.RandomState(config.seed)
  A = rng.randn(config.input_dimension, 2)
  Q, _ = np.linalg.qr(A)
  config.mu_1 = Q[:,0] * config.mu_norm
  config.mu_2 = Q[:,1] * config.mu_norm

  return config


config = get_config()

In [7]:
class XORDataset(ch.utils.data.Dataset):
    """Four unit-variance Gaussian clusters with XOR-style binary labels.

    Samples are drawn around `+mu_1`, `-mu_1` (label 0) and `+mu_2`, `-mu_2`
    (label 1), in that order, with `n_samples_per_center` points per cluster.
    Each label is then flipped independently with probability
    `eta_label_flipping_prob`.

    Args:
        mu_1: centre of the first label-0 cluster, length `input_dimension`.
        mu_2: centre of the first label-1 cluster, length `input_dimension`.
        input_dimension: dimensionality of every sample.
        n_samples_per_center: number of samples drawn around each centre.
        eta_label_flipping_prob: probability of flipping each label.
        seed: optional integer seed for reproducible sampling. With `None`
            (default) the global NumPy random state is used.

    Raises:
        ValueError: if a centre does not have shape `(input_dimension,)`.
    """

    def __init__(self, mu_1, mu_2, input_dimension, n_samples_per_center,
                 eta_label_flipping_prob, seed=None):
        self.mu_1 = mu_1
        self.mu_2 = mu_2

        centre_1 = np.asarray(mu_1, dtype=np.float64)
        centre_2 = np.asarray(mu_2, dtype=np.float64)
        for centre in (centre_1, centre_2):
            if centre.shape != (input_dimension,):
                raise ValueError(
                    f"cluster centre has shape {centre.shape}, "
                    f"expected ({input_dimension},)"
                )

        rng = np.random if seed is None else np.random.RandomState(seed)

        # N(mu, I) is just mu plus standard normal noise; sampling it this way
        # avoids factorising a dense (input_dimension x input_dimension)
        # identity covariance for every cluster.
        clusters = [
            centre + rng.randn(n_samples_per_center, input_dimension)
            for centre in (centre_1, -centre_1, centre_2, -centre_2)
        ]
        self.data = ch.as_tensor(np.concatenate(clusters), dtype=ch.float32)

        clean_labels = ch.cat([ch.zeros(n_samples_per_center * 2),
                               ch.ones(n_samples_per_center * 2)])

        # flip == 1 marks the samples whose label is inverted
        flip = ch.as_tensor(
            rng.rand(n_samples_per_center * 4) < eta_label_flipping_prob,
            dtype=ch.float32,
        )
        self.labels = flip * (1 - clean_labels) + (1 - flip) * clean_labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]

# Training split: labels are flipped with probability config.train_flip_prob.
train_dataset = XORDataset(
    mu_1=config.mu_1,
    mu_2=config.mu_2,
    input_dimension=config.input_dimension,
    n_samples_per_center=config.dataset_size_4,
    eta_label_flipping_prob=config.train_flip_prob,
)

# Test split: same cluster centres, no label noise.
test_dataset = XORDataset(
    mu_1=config.mu_1,
    mu_2=config.mu_2,
    input_dimension=config.input_dimension,
    n_samples_per_center=config.test_dataset_size_4,
    eta_label_flipping_prob=0.0,
)

In [8]:
## load data
train_loader = ch.utils.data.DataLoader(train_dataset,
                                           batch_size=config.train_batch_size,
                                           shuffle=True)

test_loader = ch.utils.data.DataLoader(test_dataset,
                                           batch_size=config.test_batch_size,
                                           shuffle=True)


## initialize neighborhood sampler
## (one extra vertex is requested when the original sample is part of the hull)
sampler_params = {'n' : config.n_frame if not config.inc_centroid \
                    else config.n_frame+1,
                  'r' : config.r_frame, 'seed':config.seed}

sampler = get_ortho_hull_around_samples_w_orig if config.inc_centroid \
            else get_ortho_hull_around_samples

## select samples for neighborhood computation
train_LC_batch, _ = get_LC_samples(train_loader,config)
## the test neighborhoods must be centred on test samples
## (this previously drew a second batch from train_loader)
test_LC_batch, _ = get_LC_samples(test_loader,config)
if config.normalize:
  ## uniform samples in [-2.8, 2.8)
  ## NOTE(review): this range matches the normalized-data bounds stored in
  ## config.dmin/dmax; confirm it is the intended domain for the XOR data
  rand_LC_batch = ch.rand_like(test_LC_batch)*2.8*2 - 2.8
else:
  rand_LC_batch = ch.rand_like(test_LC_batch)   ## Data domain [0,1]


## sample hulls/neighborhoods
train_hulls = sampler(
  train_LC_batch.cuda(),
  **sampler_params
    ).cpu()
test_hulls = sampler(
    test_LC_batch.cuda(),
    **sampler_params
).cpu()
rand_hulls = sampler(
    rand_LC_batch.cuda(),
    **sampler_params
).cpu()


## make hull dict. the keys of this dict will be used for logging
## you can add hulls separately for different classes as well to
## keep track of classwise statistics
hulls = {
    'train' : train_hulls,
    'test' : test_hulls,
    'rand' : rand_hulls
}

In [9]:
loaders = {
    'train' : train_loader,
    'test' : test_loader
}

## directory for saving logs and models
timestamp = time.ctime().replace(' ','_')
config.model_dir = os.path.join(f'./models/{timestamp}')
config.log_dir = os.path.join(f'./logs/{timestamp}')
## makedirs also creates the missing ./models and ./logs parents, and
## exist_ok lets the cell be re-run within the same second without failing
os.makedirs(config.model_dir, exist_ok=True)
os.makedirs(config.log_dir, exist_ok=True)

In [10]:
# Start a Weights & Biases run named after the configured prefix and the
# timestamp of this session (only when wandb logging is enabled).
if config.wandb_log:
  wandb_project, wandb_run_name = (
      config.wandb_proj,
      f"{config.wandb_pref}-{timestamp}",
  )
  wandb.init(
      project=wandb_project,
      name=wandb_run_name,
      config=config,
  )

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mniketpatel[0m ([33mniket[0m). Use [1m`wandb login --relogin`[0m to force relogin


In [11]:
from torch import nn
import einops

class MLP(nn.Module):
    """Fully connected ReLU network.

    Layout: `input_layer` -> ReLU -> [Linear -> ReLU for each further hidden
    width] -> `output_layer`.

    Args:
        input_dim: number of input features.
        hidden_dims: non-empty sequence of hidden layer widths.
        output_dim: number of outputs. When it is 1 the trailing singleton
            axis is removed, so a `(batch, input_dim)` input yields a
            `(batch,)` tensor of logits.

    Raises:
        ValueError: if `hidden_dims` is empty.
    """

    def __init__(self, input_dim, hidden_dims, output_dim):
        super(MLP, self).__init__()

        if len(hidden_dims) == 0:
            raise ValueError("hidden_dims must contain at least one layer width")

        self.input_dim = input_dim
        self.hidden_dims = hidden_dims
        self.n_layers = len(hidden_dims)
        self.output_dim = output_dim

        # Define input layer
        self.input_layer = nn.Linear(input_dim, hidden_dims[0])
        self.input_activation = nn.ReLU()

        # Define hidden layers (Linear and ReLU modules alternate in the list)
        self.hidden_layers = nn.ModuleList()
        for i in range(len(hidden_dims) - 1):
            layer = nn.Linear(hidden_dims[i], hidden_dims[i+1])
            self.hidden_layers.append(layer)
            activation = nn.ReLU()
            self.hidden_layers.append(activation)

        # Define output layer
        self.output_layer = nn.Linear(hidden_dims[-1], output_dim)

    def forward(self, x):
        """Return the network output for `x` as a float32 tensor."""
        x = self.input_activation(self.input_layer(x))
        for layer in self.hidden_layers:
            x = layer(x)
        x = self.output_layer(x)
        # Drop the singleton output axis only when there is one; the previous
        # fixed "x 1 -> x" rearrange failed for any output_dim other than 1.
        if self.output_dim == 1:
            x = x.squeeze(-1)
        return x.float()

In [12]:
## build the single-hidden-layer MLP with one output logit
model = MLP(config.input_dimension, [config.hidden_dim], 1)

## run training; hooks for local complexity are added by add_hooks_MLP
stats = train(model, loaders, config, add_hooks_MLP, hulls)

NameError: name 'input_dimension' is not defined