In [1]:
%cd ~/projects/GraB-lib

%load_ext autoreload
%autoreload 2

/home/aris/projects/GraB-lib


In [2]:
from torch import nn

import os
import sys
from functools import partial, reduce
from pathlib import Path
from dataclasses import dataclass, field

import evaluate
import numpy as np
import pandas as pd
import wandb
from tqdm import tqdm
from absl import logging

import torch
import torch.nn as nn
from torchvision import datasets, transforms
from transformers import (
    HfArgumentParser, TrainingArguments, set_seed
)

import torchopt
from torch.func import (
    grad, grad_and_value, vmap, functional_call
)

from grabngo import GraBSampler, BalanceType
from grabngo.utils import EventTimer, pretty_time

In [3]:
device = 'cuda'
batch_size = 16
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(
        (0.1307,), (0.3081,)
    ),
    transforms.Lambda(lambda x: x.view(-1)),
])

# Loading the dataset and preprocessing
train_dataset = datasets.MNIST(
    root='data/external',
    train=True,
    download=True,
    transform=transform
)
# train_dataset = torch.utils.data.Subset(train_dataset, range(0, 6000))
test_dataset = datasets.MNIST(
    root='data/external',
    train=False,
    download=True,
    transform=transform
)

in_dim, num_classes = 784, 10

loss_fn = nn.CrossEntropyLoss().to(device)

train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    # sampler=sampler,
    persistent_workers=False,
    num_workers=1,
    pin_memory=True
)
train_eval_loader = torch.utils.data.DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    persistent_workers=False,
    num_workers=1,
    pin_memory=True
)

test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset,
    batch_size=batch_size,
    persistent_workers=False,
    num_workers=1,
    pin_memory=True
)


In [4]:
from torchinfo import summary

# model = nn.Sequential(
#     nn.Linear(in_dim, 100),
#     nn.ReLU(),
#     nn.Linear(100, 100),
#     nn.ReLU(),
#     nn.Linear(100, num_classes)
# )

model = nn.Linear(in_dim, num_classes).to(device)

summary(model, input_size=(batch_size, in_dim), device=device)

Layer (type:depth-idx)                   Output Shape              Param #
Linear                                   [16, 10]                  7,850
Total params: 7,850
Trainable params: 7,850
Non-trainable params: 0
Total mult-adds (M): 0.13
Input size (MB): 0.05
Forward/backward pass size (MB): 0.00
Params size (MB): 0.03
Estimated Total Size (MB): 0.08

In [None]:
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

In [None]:
for epoch in range(1):
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = loss_fn(output, target)
        loss.backward()
        optimizer.step()

In [6]:
import copy

def make_functional(mod, disable_autograd_tracking=False):
    params_dict = dict(mod.named_parameters())
    params_names = params_dict.keys()
    params_values = tuple(params_dict.values())

    stateless_mod = copy.deepcopy(mod)
    stateless_mod.to('meta')

    def fmodel(new_params_values, *args, **kwargs):
        new_params_dict = {name: value for name, value in zip(params_names, new_params_values)}
        return torch.func.functional_call(stateless_mod, new_params_dict, args, kwargs)

    if disable_autograd_tracking:
        params_values = torch.utils._pytree.tree_map(torch.Tensor.detach, params_values)
    return fmodel, params_values

def make_functional_with_buffers(mod, disable_autograd_tracking=False):
    params_dict = dict(mod.named_parameters())
    params_names = params_dict.keys()
    params_values = tuple(params_dict.values())

    buffers_dict = dict(mod.named_buffers())
    buffers_names = buffers_dict.keys()
    buffers_values = tuple(buffers_dict.values())

    stateless_mod = copy.deepcopy(mod)
    stateless_mod.to('meta')

    def fmodel(new_params_values, new_buffers_values, *args, **kwargs):
        new_params_dict = {name: value for name, value in zip(params_names, new_params_values)}
        new_buffers_dict = {name: value for name, value in zip(buffers_names, new_buffers_values)}
        return torch.func.functional_call(stateless_mod, (new_params_dict, new_buffers_dict), args, kwargs)

    if disable_autograd_tracking:
        params_values = torch.utils._pytree.tree_map(torch.Tensor.detach, params_values)
    return fmodel, params_values, buffers_values

In [7]:
from grabngo.sorter.beta import NTKBalance

params = dict(model.named_parameters())
buffers = dict(model.named_buffers())

n = 6000
d = sum(p.numel() for p in params.values())
loss_fn = nn.CrossEntropyLoss().to(device)

data = train_dataset.data[:n].reshape(n, -1).float() / 255
target = train_dataset.targets[:n]

sorter = NTKBalance(
    n, d, model, params, buffers, loss_fn, data, target
)

  0%|          | 0/47 [00:00<?, ?it/s]

In [13]:
K = sorter.K

torch.linalg.matrix_rank(K, atol=1e-8, hermitian=True)

tensor(6000, device='cuda:0')

In [21]:
%timeit torch.linalg.eigh(K)

512 ms ± 15.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [26]:
torch.lobpcg(K, largest=True)

(tensor([44706.7305], device='cuda:0', grad_fn=<LOBPCGAutogradFunctionBackward>),
 tensor([[ 0.0009],
         [-0.0405],
         [ 0.0007],
         ...,
         [ 0.0057],
         [ 0.0023],
         [ 0.0036]], device='cuda:0', grad_fn=<LOBPCGAutogradFunctionBackward>))

In [23]:
torch.linalg.eigvalsh(K)

tensor(44706.7891, device='cuda:0', grad_fn=<MaxBackward1>)

In [45]:

from torch.func import grad, grad_and_value, vmap, functional_call, jacrev, vjp, jvp

from functorch import make_functional

model = nn.Linear(in_dim, num_classes).to(device)

fnet, params = make_functional(model)

def fnet_single(params, x):
    return fnet(params, x.unsqueeze(0)).squeeze(0)

# https://pytorch.org/functorch/stable/notebooks/neural_tangent_kernels.html
def empirical_ntk_jacobian_contraction(fnet_single, params, x):
    # Compute J(x1)
    jac = vmap(jacrev(fnet_single), (None, 0))(params, x)
    jac = [j.flatten(2) for j in jac]

    # Compute J(x1) @ J(x2).T
    result = torch.stack(
        [torch.einsum('Naf,Mbf->NMab', j1, j2) for j1, j2 in zip(jac, jac)])
    result = result.sum(0)
    return result

def empirical_ntk_ntk_vps(func, params, x1, x2, compute='full'):
    def get_ntk(x1, x2):
        def func_x1(params):
            return func(params, x1)

        def func_x2(params):
            return func(params, x2)

        output, vjp_fn = vjp(func_x1, params)

        def get_ntk_slice(vec):
            # This computes vec @ J(x2).T
            # `vec` is some unit vector (a single slice of the Identity matrix)
            vjps = vjp_fn(vec)
            # This computes J(X1) @ vjps
            _, jvps = jvp(func_x2, (params,), vjps)
            return jvps

        # Here's our identity matrix
        basis = torch.eye(output.numel(), dtype=output.dtype, device=output.device).view(output.numel(), -1)
        return vmap(get_ntk_slice)(basis)

    # get_ntk(x1, x2) computes the NTK for a single data point x1, x2
    # Since the x1, x2 inputs to empirical_ntk_ntk_vps are batched,
    # we actually wish to compute the NTK between every pair of data points
    # between {x1} and {x2}. That's what the vmaps here do.
    result = vmap(vmap(get_ntk, (None, 0)), (0, None))(x1, x2)

    if compute == 'full':
        return result
    if compute == 'trace':
        return torch.einsum('NMKK->NM', result)
    if compute == 'diagonal':
        return torch.einsum('NMKK->NMK', result)


@torch.no_grad()
def get_0shot_order(
        model,
        params,
        dataset,
        construct_kernel_matrix: bool = True,
        largest_eigval: bool = True,
        ascending: bool = True,
        device: torch.device = torch.device("cuda"),
):
    # Try to load data at once
    try:
        x = dataset.data.to(device=device).reshape(-1, 784)  # (n, ...)
        x = x.float() / 255

        ntk = empirical_ntk_ntk_vps(
            fnet_single, params, x, x
        )

        print(ntk.shape)
    except torch.cuda.OutOfMemoryError as err:
        raise err
        print("Failed to load data at once, loading data one by one")
        ...

get_0shot_order(model, params, train_dataset)