In [1]:
# Switch the kernel's working directory to the GraB-lib checkout; the cells
# below use relative paths (e.g. 'data/external', 'sandbox/...').
# NOTE(review): this path is machine-specific — adjust it for your environment.
%cd ~/projects/GraB-lib

# Enable IPython's autoreload extension in mode 2 so that edited modules are
# re-imported automatically before each cell runs.
%load_ext autoreload
%autoreload 2

/home/gw338/projects/GraB-lib


In [2]:
from torch import nn

import os
import sys
from functools import partial, reduce
from pathlib import Path
from dataclasses import dataclass, field

import evaluate
import numpy as np
import pandas as pd
import wandb
from tqdm import tqdm
from absl import logging

import torch
import torch.nn as nn
from torchvision import datasets, transforms
from transformers import (
    HfArgumentParser, TrainingArguments, set_seed
)

import torchopt
from torch.func import (
    grad, grad_and_value, vmap, functional_call
)

from grabngo import GraBSampler, BalanceType
from grabngo.utils import EventTimer, pretty_time

from experiments.cv.models import LeNet

In [3]:
# Device used for the loss module, and the per-step batch size of the loaders.
device = 'cuda'
batch_size = 1

# Per-image preprocessing: convert to a tensor, then normalise each channel.
# NOTE(review): these mean/std values look like the widely used ImageNet
# channel statistics rather than CIFAR-10's own — confirm this is intended.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )
])

# Loading the dataset and preprocessing
train_dataset = datasets.CIFAR10(
    root='data/external',
    train=True,
    download=True,
    transform=transform
)
test_dataset = datasets.CIFAR10(
    root='data/external',
    train=False,
    download=True,
    transform=transform
)

# Input channels and number of target classes handed to the model constructor.
in_dim, num_classes = 3, 10

loss_fn = nn.CrossEntropyLoss().to(device)

# Keyword arguments shared by every loader below. Keeping them in one place
# guarantees the three loaders cannot silently drift apart.
loader_kwargs = dict(
    batch_size=batch_size,
    persistent_workers=False,
    num_workers=1,
    pin_memory=True,
)

# No custom `sampler` is passed, so each loader iterates in dataset order.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, **loader_kwargs)
train_eval_loader = torch.utils.data.DataLoader(dataset=train_dataset, **loader_kwargs)

test_loader = torch.utils.data.DataLoader(dataset=test_dataset, **loader_kwargs)

Files already downloaded and verified
Files already downloaded and verified


In [4]:
# Pull the whole training set through a single batch so that `data` and
# `targets` hold every example at once.
# The batch size must be the number of *examples* (`len(train_dataset)`).
# `len(train_loader)` is the number of batches and only coincides with it
# while `batch_size == 1`.
loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=len(train_dataset), shuffle=False
)

data, targets = next(iter(loader))

# Surface the INFO-level progress messages emitted through absl logging.
logging.set_verbosity(logging.INFO)

In [6]:
from grabngo.sorter import OfflineMeanBalance

def sort_grads(grads, epochs=10):
    """Balance the rows of ``grads`` offline and return the sorter's result.

    An ``OfflineMeanBalance`` sorter is constructed from the two dimensions
    of ``grads``; its ``orders`` attribute is printed as it stands right
    after construction, and ``grads`` is then handed to ``offline_balance``.

    Args:
        grads: Two-dimensional object exposing ``.shape``.
        epochs: Forwarded to ``offline_balance`` as its ``rounds`` argument.

    Returns:
        Whatever ``OfflineMeanBalance.offline_balance`` returns.
    """
    num_rows, num_cols = grads.shape
    balancer = OfflineMeanBalance(num_rows, num_cols)
    print(balancer.orders)

    return balancer.offline_balance(grads, rounds=epochs)


In [7]:
from grabngo.sorter.beta.functional import compute_kernel
import gc
from matplotlib import pyplot as plt

# Number of independently seeded networks whose kernels are averaged.
num_nets = 64

kernel_dtype = torch.float32
kernel_device = torch.device("cuda")

# Size of the kernel: one row/column per example in `data`.
# Taken from `data` itself rather than from `len(train_loader)`, which is a
# batch count and only matches the example count while `batch_size == 1`.
# (assumes `compute_kernel` returns a len(data) x len(data) matrix — the
# in-place accumulation below relies on it.)
n = len(data)

# Running sum of the per-network kernels; divided by `num_nets` after the loop.
K = torch.zeros(n, n, dtype=kernel_dtype, device=kernel_device)

for seed in range(1, num_nets + 1):
    logging.info(f"Running seed {seed}")
    # Seed before constructing the model — presumably so that every
    # iteration gets its own reproducible initialisation.
    set_seed(seed)
    model = LeNet(in_dim, num_classes).cuda()
    params = dict(model.named_parameters())
    buffers = dict(model.named_buffers())

    # The parameters are passed explicitly to `compute_kernel`; switch off
    # autograd tracking on the module's own tensors.
    for param in params.values():
        param.requires_grad = False
    dK = compute_kernel(
        model,
        params,
        buffers,
        loss_fn,
        data,
        targets,
        batch_size=128,
        kernel_dtype=kernel_dtype,
        kernel_device=kernel_device,
        centered_feature_map=True,
    )
    K += dK
K /= num_nets

# NOTE(review): `d` is not referenced by any visible cell; presumably the
# model's parameter count — confirm, or remove if unused.
d = 62006

# Make sure the output directory exists before saving; `torch.save` does not
# create missing parent directories.
kernel_path = Path(f"sandbox/gary/0shot/ntk_lenet_{num_nets}.pt")
kernel_path.parent.mkdir(parents=True, exist_ok=True)
torch.save(K, kernel_path)

INFO:absl:Running seed 1
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 2
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 3
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 4
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 5
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 6
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 7
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 8
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 9
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 10
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 11
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 12
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 13
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 14
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 15
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 16
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 17
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 18
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 19
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 20
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 21
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 22
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 23
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 24
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 25
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 26
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 27
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 28
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 29
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 30
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 31
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 32
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 33
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 34
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 35
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 36
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 37
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 38
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 39
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 40
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 41
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 42
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 43
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 44
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 45
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 46
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 47
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 48
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 49
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 50
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 51
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 52
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 53
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 54
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 55
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 56
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 57
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 58
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 59
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 60
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 61
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 62
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 63
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...
INFO:absl:Running seed 64
INFO:absl:Computing the kernel matrix, this may take a while...


  0%|          | 0/391 [00:00<?, ?it/s]

INFO:absl:Finish computing the feature map, now computing the kernel matrix...


In [7]:
# Ask LOBPCG for a single eigenpair (k=1) at the small end of K's spectrum
# (largest=False); the cell displays the (eigenvalues, eigenvectors) tuple.
# NOTE(review): torch.lobpcg is documented for symmetric positive definite
# input — confirm K satisfies this closely enough for the result to be trusted.
torch.lobpcg(K, k=1, largest=False, method="ortho")

(tensor([-0.3519], device='cuda:0'),
 tensor([[0.0034],
         [0.0036],
         [0.0049],
         ...,
         [0.0035],
         [0.0037],
         [0.0038]], device='cuda:0'))

In [11]:
# Ask LOBPCG for a single eigenpair (k=1) at the large end of K's spectrum
# (largest=True); the cell displays the (eigenvalues, eigenvectors) tuple.
torch.lobpcg(K, k=1, largest=True, method="ortho")

(tensor([11395.5449], device='cuda:0'),
 tensor([[ 0.0015],
         [ 0.0095],
         [ 0.0097],
         ...,
         [ 0.0096],
         [-0.0103],
         [-0.0093]], device='cuda:0'))

In [10]:
# Scratch calculation: square root of the ratio of two hard-coded constants.
# NOTE(review): the provenance of 12022 and 0.0073 is not shown in this
# notebook — name them or derive them from computed values.
np.sqrt(12022 / 0.0073)

1283.2962694048842

In [13]:
# Full eigendecomposition of K, keeping both eigenvalues and eigenvectors.
# NOTE(review): torch.linalg.eigh assumes a symmetric (Hermitian) matrix —
# presumably K is symmetric by construction; confirm.
eigvals, eigvecs = torch.linalg.eigh(K)

In [9]:
# Show the dimensions of the kernel matrix.
K.shape

torch.Size([50000, 50000])

In [11]:
# Run the offline balancing helper for 30 rounds, passing the kernel matrix
# K as its `grads` argument.
# NOTE(review): K is a kernel matrix, not a matrix of raw gradients —
# presumably its rows are used as stand-in vectors on purpose; confirm.
orders = sort_grads(K, epochs=30)

tensor([  396, 49608, 44015,  ..., 43435, 25587, 49115])


In [13]:
# Display the value returned by sort_grads.
orders

tensor([ 9668,  3118, 16256,  ...,  8040, 10601, 28892])

In [14]:
# Persist the computed ordering. The network count in the file name is taken
# from `num_nets` (as the kernel file name already does) so the two cannot
# drift apart when the number of averaged networks changes.
# NOTE(review): the "rounds_30" suffix still mirrors the literal `epochs=30`
# used when `orders` was computed — keep the two in sync.
torch.save(orders, f'lenet_cifar10_{num_nets}_mean_epoch_0_rounds_30.pt')