In [None]:
# Clone the GitHub repo https://github.com/aw31/empirical-ntks to the local
# filesystem first. The project modules imported below (multiqueue_worker,
# utils) are presumably resolved relative to the checkout, and later cells
# write outputs to '../', so the notebook works from inside the repo dir.
import os
# os.chdir is relative: re-running this cell from inside the checkout would
# look for empirical-ntks-main/empirical-ntks-main and raise
# FileNotFoundError. Only change directory when we are not already there.
if os.path.basename(os.getcwd()) != 'empirical-ntks-main':
    os.chdir('empirical-ntks-main')

import copy
import logging
import pathlib
import threading
import time
import torch
import numpy as np
from torch.multiprocessing import Process, Queue
from tqdm.auto import tqdm

from multiqueue_worker import multiqueue_worker
from utils import init_torch, humanize_units

# Thread-local storage object. NOTE(review): `local` is not referenced by any
# visible cell — presumably carried over from the repository's NTK script;
# confirm nothing imports it before removing.
local = threading.local()

In [2]:
import argparse
import pprint
import sys
from torch.multiprocessing import set_start_method, set_sharing_strategy
from utils import init_logging, load_model, load_dataset

# Set up
# Use the "spawn" start method for worker processes. force=True makes this
# cell safe to re-run: without it, a second call raises
# "RuntimeError: context has already been set".
set_start_method("spawn", force=True)
# Select torch's "file_system" tensor-sharing strategy (see the
# torch.multiprocessing docs for how it differs from the default).
set_sharing_strategy("file_system")

# Same CLI as the repository's script. A notebook has no command line, so the
# positional arguments are supplied explicitly to parse_args below; every
# option keeps its default.
parser = argparse.ArgumentParser()
parser.add_argument("dataset", type=str)
parser.add_argument("model", type=str)
parser.add_argument("--datadir", type=str, default="./datasets")
parser.add_argument("--savedir", type=str, default="./ntks")
parser.add_argument("--logdir", type=str)
parser.add_argument("--workers-per-device", type=int, default=1)
parser.add_argument("--grad-chunksize", type=int, default=1900000)
parser.add_argument("--mm-col-chunksize", type=int, default=20000)
parser.add_argument("--ntk-dtype", type=str, default="float32")
parser.add_argument("--loader-batch-size", type=int, default=512)
parser.add_argument("--loader-num-workers", type=int, default=12)
parser.add_argument("--no-pinned-memory", dest="pin_memory", action="store_false")
parser.add_argument("--allow-tf32", action="store_true")
parser.add_argument("--benchmark", action="store_true")
parser.add_argument(
    "--non-deterministic", dest="deterministic", action="store_false"
)
# Edit this list to run the notebook on another dataset/model pair.
args = parser.parse_args(args=['CIFAR-100', 'resnet-50_pretrained'])

In [None]:
# Initialize torch through the project's init_torch helper, driven by the
# --allow-tf32 / --benchmark / --non-deterministic flags parsed above. The
# keyword dict is kept as `init_torch_kwargs`.
init_torch_kwargs = dict(
    allow_tf32=args.allow_tf32,
    benchmark=args.benchmark,
    deterministic=args.deterministic,
)
init_torch(verbose=True, **init_torch_kwargs)

# Load the model and log how many trainable parameters it has, and into how
# many --grad-chunksize chunks (ceiling division) they would be split.
# NOTE(review): nothing visible uses these counts, and a later cell reloads the
# backbone and rebinds `model` — presumably a leftover from the NTK script.
model = load_model(args.model)
param_count = sum(
    map(lambda p: p.numel(), filter(lambda p: p.requires_grad, model.parameters()))
)
param_batches = 1 + (param_count - 1) // args.grad_chunksize
logging.info(f"Splitting {param_count} parameters into {param_batches} batches")

# Initialize datasets. A validation split is loaded only for Flowers-102; for
# any other dataset `val_set` is never assigned, so later cells repeat the
# same dataset check before touching it.
datadir = pathlib.Path(args.datadir)
train_set = load_dataset(datadir, args.dataset, "train")
if args.dataset == 'Flowers-102':
    val_set = load_dataset(datadir, args.dataset, "val")
test_set = load_dataset(datadir, args.dataset, "test")

In [4]:
class FeatureExtractor(torch.nn.Module):
    """Wrap a network with its last top-level child module removed.

    Every child returned by ``pretrained_model.children()`` except the last is
    chained into a ``torch.nn.Sequential`` stored as ``self.features``. The
    child modules are shared with ``pretrained_model``, not copied. Presumably
    the dropped child is the classification head (e.g. ResNet's ``fc``) —
    confirm for other architectures. ``Sequential`` only replays the child
    modules, so any functional ops in the original ``forward`` are skipped.

    Args:
        pretrained_model: the ``torch.nn.Module`` to truncate.
    """

    def __init__(self, pretrained_model):
        super().__init__()
        self.features = torch.nn.Sequential(*list(pretrained_model.children())[:-1])

    def forward(self, x):
        """Run the truncated network and flatten all but the batch dimension.

        Returns:
            Tensor of shape ``(x.size(0)_out, -1)``, where the leading size is
            that of the truncated network's output.
        """
        features = self.features(x)
        # reshape (not view) so non-contiguous outputs — e.g. from a transpose
        # or channels_last memory format — are flattened instead of raising.
        return features.reshape(features.size(0), -1)

# Initialize model
# Load a fresh backbone instead of wrapping `model` from the earlier cell:
# this cell rebinds `model`, so wrapping it directly would nest a
# FeatureExtractor inside another one every time the cell is re-run.
pretrain_model = load_model(args.model)
# Create an instance of the feature extractor
model = FeatureExtractor(pretrain_model)

In [5]:

def get_features(model, test_set, batch_size=128, num_workers=12, num_classes=-1):
    """Run ``model.features`` over a dataset and collect flattened outputs.

    Despite its name, ``test_set`` may be any dataset yielding
    ``(image, label)`` pairs; the notebook calls this for every split.

    Side effects: ``model`` is switched to eval mode and moved to CUDA when it
    is available, and both changes persist after the call.

    Args:
        model: module exposing a ``features`` sub-module (e.g. the
            ``FeatureExtractor`` above); its ``forward`` is not called.
        test_set: dataset of ``(image, label)`` pairs with integer labels.
        batch_size: DataLoader batch size.
        num_workers: DataLoader worker processes (formerly hard-coded to 12;
            pass 0 to load in the main process).
        num_classes: width of the one-hot label matrix. The default ``-1``
            lets ``one_hot`` infer ``max(label) + 1`` from *this* dataset
            alone, so a split missing the highest class gets a narrower ``Y``.
            Pass the true class count to keep train/val/test widths equal.

    Returns:
        ``(X, Y)`` NumPy arrays in dataset order (no shuffling):
        - ``X``: features flattened to ``(n, -1)``.
        - ``Y``: one-hot integer labels with one column per class.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.eval()
    model.to(device)

    loader = torch.utils.data.DataLoader(
        test_set, batch_size=batch_size, shuffle=False, num_workers=num_workers
    )

    feature_batches = []
    label_batches = []
    with torch.no_grad():
        for images, labels in tqdm(loader):
            output = model.features(images.to(device))
            # Copy each batch back to the CPU immediately so only the current
            # batch's outputs are held on the device.
            feature_batches.append(output.cpu())
            label_batches.append(labels.cpu())

    X = torch.cat(feature_batches)
    X = X.reshape(X.shape[0], -1).numpy()
    Y = torch.nn.functional.one_hot(
        torch.cat(label_batches), num_classes=num_classes
    ).numpy()
    return X, Y


In [7]:
# Extract backbone features for every split. `val_set` only exists for
# Flowers-102 (see the dataset-loading cell).
X_train, Y_train = get_features(model, train_set)
if args.dataset == 'Flowers-102':
    X_val, Y_val = get_features(model, val_set)
X_test, Y_test = get_features(model, test_set)

# get_features one-hot encodes each split on its own, inferring the number of
# classes from that split's labels. Check that all splits agree before the
# arrays are saved and used together.
assert X_train.shape[1] == X_test.shape[1], (X_train.shape, X_test.shape)
assert Y_train.shape[1] == Y_test.shape[1], (Y_train.shape, Y_test.shape)
if args.dataset == 'Flowers-102':
    assert X_val.shape[1] == X_train.shape[1], (X_val.shape, X_train.shape)
    assert Y_val.shape[1] == Y_train.shape[1], (Y_val.shape, Y_train.shape)

100%|██████████| 391/391 [29:04<00:00,  4.46s/it] 
100%|██████████| 79/79 [07:36<00:00,  5.78s/it] 


In [23]:
import numpy as np

# Bundle the features and one-hot labels of every split into one .npz archive.
# The archive is written one directory above the working directory, i.e.
# outside the repo checkout once the first cell's chdir has run.
res = dict(X_train=X_train, Y_train=Y_train, X_test=X_test, Y_test=Y_test)
if args.dataset == 'Flowers-102':
    res.update(X_val=X_val, Y_val=Y_val)
with open("../data_{}_{}.npz".format(args.dataset, args.model), "wb") as f:
    np.savez(f, **res)

# Subset the CIFAR-100 features (stratified 10% of the training set)

In [None]:
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit

# Reload the full feature set from disk and subset *those* arrays, rather than
# in-memory X_train/Y_train/Y_test from earlier cells, so this cell gives the
# same result on a fresh kernel.
# NOTE(review): expects '../data_<dataset>_<model>_full.npz' to exist already.
# The save cell above writes '../data_<dataset>_<model>.npz' (no '_full'), so
# presumably that file was renamed by hand — confirm.
# NOTE(review): this rebinds `model` (the FeatureExtractor) to a string.
data, model = 'CIFAR-100', 'resnet-50_pretrained'
with open('../data_{}_{}_full.npz'.format(data, model), 'rb') as f:
    dat = np.load(f)
    X = dat['X_train']
    Ys = dat['Y_train']
    X_test = dat['X_test']
    Ys_test = dat['Y_test']
assert X.shape[0] == Ys.shape[0], (X.shape, Ys.shape)

# Stratify on integer class ids recovered from the one-hot labels.
# test_size=0.9 means the first ("train") index array keeps ~10% of the rows,
# with class proportions preserved. Only the first of the n_splits splits is
# used, and random_state=0 makes it reproducible.
sss = StratifiedShuffleSplit(n_splits=10, test_size=0.9, random_state=0)
idx = next(sss.split(X, np.argmax(Ys, axis=1)))[0]

res = {'X_train': X[idx, :], 'Y_train': Ys[idx, :],
       'X_test': X_test, 'Y_test': Ys_test}

# Write to the path built from the same dataset/model names used for reading.
# With the notebook's hard-coded args this is the same file the save cell
# writes, which is overwritten with the subset.
with open("../data_{}_{}.npz".format(data, model), "wb") as f:
    np.savez(f, **res)