In [None]:
!git clone https://github.com/haoweiwang0/Coreset_Prioritization.git
%cd Coreset_Prioritization

In [None]:
!pip install neural-tangents

In [None]:
!pip install -U "jax[cuda12]"

In [None]:
import bilevel_coreset
import loss_utils
import ntk_generator

import numpy as np
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Subset
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
import torch.nn.functional as F

# Coreset size: passed to bc.build_with_representer_proxy_batch when selecting
# the training subset (presumably the number of samples kept — confirm in bilevel_coreset).
subset_size = 50

def same_seeds(seed):
    """Seed every RNG the notebook relies on and request deterministic cuDNN.

    Seeds Python's ``random`` module, NumPy's global RNG and torch (CPU and,
    when available, every CUDA device). Also disables cuDNN autotuning and
    asks cuDNN for deterministic kernels so repeated runs match.

    Args:
        seed: integer seed applied to all generators.
    """
    import random  # local import: the notebook's import cell does not import `random`

    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # manual_seed_all covers the current device too, so a separate
        # torch.cuda.manual_seed call is redundant.
        torch.cuda.manual_seed_all(seed)
    # Autotuning picks kernels by timing, which can differ between runs.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
same_seeds(0)

# ToTensor -> Normalize with the commonly quoted MNIST pixel mean / std.
mnist_transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])
train_dataset = datasets.MNIST(root='data', train=True, transform=mnist_transforms, download=True)
# download=True here as well, so this line does not depend on the train-split
# call above having already fetched the files.
test_dataset = datasets.MNIST(root='data', train=False, transform=mnist_transforms, download=True)

In [None]:
from tqdm.auto import tqdm
import models

def train_model(model, loader, epochs=None, run_device=None):
    """Train ``model`` in place with Adam (lr=5e-4) and cross-entropy loss.

    A tqdm bar over epochs shows the mean per-batch training loss of the
    most recent epoch.

    Args:
        model: torch.nn.Module classifier; moved to the training device.
        loader: iterable of ``(data, target)`` batches.
        epochs: number of passes over ``loader``. Defaults to the module-level
            ``nr_epochs``, read at call time (the original behaviour).
        run_device: device to train on. Defaults to the module-level ``device``.
    """
    if epochs is None:
        epochs = nr_epochs
    if run_device is None:
        run_device = device

    model.to(run_device)
    optimizer = torch.optim.Adam(model.parameters(), lr=5e-4, weight_decay=0)
    criterion = torch.nn.CrossEntropyLoss()
    pbar = tqdm(range(epochs), desc="Training", unit="epoch")

    for _ in pbar:
        model.train()
        training_loss = 0.0
        n_batches = 0

        for data, target in loader:
            data, target = data.to(run_device), target.to(run_device)
            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            training_loss += loss.item()
            optimizer.step()
            n_batches += 1

        # Counting batches explicitly avoids the unbound `batch_idx`
        # NameError the enumerate-based version hit on an empty loader.
        pbar.set_postfix({'loss': training_loss / max(n_batches, 1)})

def test_model(model, loader, run_device=None):
    """Evaluate ``model`` on ``loader`` and return its classification accuracy.

    A tqdm bar over batches shows the running mean cross-entropy loss of the
    batches processed so far.

    Args:
        model: classifier whose output is argmax-ed over dim 1
            (assumed shape (batch, classes)).
        loader: DataLoader yielding ``(data, target)``; must expose ``.dataset``,
            whose length is the accuracy denominator.
        run_device: device to evaluate on. Defaults to the module-level ``device``.

    Returns:
        Fraction of samples whose argmax prediction equals the target, or
        ``nan`` when ``loader.dataset`` is empty (accuracy is undefined).
    """
    if run_device is None:
        run_device = device

    model.to(run_device)
    criterion = torch.nn.CrossEntropyLoss()
    model.eval()

    pbar = tqdm(loader, desc="Testing", unit="batch")
    correct = 0
    testing_loss = 0.0

    with torch.no_grad():
        for n_batches, (data, target) in enumerate(pbar, start=1):
            data, target = data.to(run_device), target.to(run_device)
            output = model(data)
            testing_loss += criterion(output, target).item()
            pred = output.argmax(dim=1)
            correct += (pred == target).sum().item()
            # Running mean over the batches seen so far. Dividing by
            # len(loader) under-reported the loss until the final batch.
            pbar.set_postfix({'loss': testing_loss / n_batches})

    total = len(loader.dataset)
    if total == 0:
        return float('nan')
    return correct / total

Training:   0%|          | 0/50 [00:00<?, ?epoch/s]

  return F.conv2d(input, weight, bias, self.stride,


Testing:   0%|          | 0/469 [00:00<?, ?batch/s]

Uniform sample - Train accuracy 0.99365


Testing:   0%|          | 0/79 [00:00<?, ?batch/s]

Uniform sample - Test accuracy 0.9914


In [None]:
# Configuration shared with the helpers defined above: train_model reads
# `device` and `nr_epochs`, and test_model reads `device`, as module globals.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
nr_classes = 10
batch_size = 128
nr_epochs = 50

# Baseline: ResNet-18 trained on the full MNIST training split.
resnet18 = models.ResNet18().to(device)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

train_model(resnet18, train_loader)
print('Original ResNet - Train accuracy', test_model(resnet18, train_loader))
print('Original ResNet - Test accuracy', test_model(resnet18, test_loader))

In [None]:
proxy_kernel_fn = lambda x, y: ntk_generator.generate_resnet_ntk(x.view(-1, 28, 28, 1).numpy(), y.view(-1, 28, 28, 1).numpy())

In [None]:
# Coreset selection runs on a prefix of the training set to keep the kernel
# computation affordable.
limit = 10000

# With shuffle=False and batch_size=limit, the single batch below is exactly
# train_dataset[0:limit]. Indices returned by the coreset builder are
# therefore also valid indices into train_dataset.
loader = torch.utils.data.DataLoader(train_dataset, batch_size=limit, shuffle=False)
X, y = next(iter(loader))

bc = bilevel_coreset.BilevelCoreset(
    outer_loss_fn=loss_utils.cross_entropy,
    inner_loss_fn=loss_utils.cross_entropy,
    out_dim=10,
    max_outer_it=10,
    outer_lr=0.05,
    max_inner_it=200,
)
coreset_inds, _ = bc.build_with_representer_proxy_batch(
    X, y, subset_size, proxy_kernel_fn,
    cache_kernel=True,
    start_size=10,
    inner_reg=1e-7,
)

In [None]:
# Train a fresh ResNet-18 on the selected coreset only. The subset is small
# (presumably `subset_size` samples), so far more epochs are used.
# train_model reads this module-level value.
nr_epochs = 1000

coreset_net = models.ResNet18().to(device)
coreset_subset = Subset(train_dataset, coreset_inds)
# NOTE: this rebinds `train_loader` / `test_loader` from the baseline cell.
train_loader = DataLoader(coreset_subset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

train_model(coreset_net, train_loader)
print('Coreset ResNet - Train accuracy', test_model(coreset_net, train_loader))
print('Coreset ResNet - Test accuracy', test_model(coreset_net, test_loader))

In [None]:
def test_with_logits(model, coreset_model, loader):
    model.eval()
    logits = []
    with torch.no_grad():
        for data, target in loader:
            data, target = data.to(device), target.to(device)
            output = F.softmax(model(data), dim=1)
            coreset_output = F.softmax(coreset_model(data), dim=1)
            logits.append((data, output, coreset_output, target))

    return logits

# Softmax outputs of the full-data model and the coreset model on every test
# batch. They are stored under the name `logits` but are probabilities.
logits = test_with_logits(resnet18, coreset_net, test_loader)

In [None]:
# Placeholder only: reassigned to the output of test_case_prioritization(logits) in a later cell.
prioritized_test_data = []

# calculate the similarity score between the logits from
# original DNN and coreset-trained DNN
# reorder test cases based on the similarity score
def test_case_prioritization(logits):
    result = []
    probs = torch.cat([item[1] for item in logits], dim=0).cpu().numpy()
    coreset_probs = torch.cat([item[2] for item in logits], dim=0).cpu().numpy()
    inputs = torch.cat([item[0] for item in logits], dim=0).cpu().numpy()
    targets = torch.cat([item[3] for item in logits], dim=0).cpu().numpy()
    for i in range(probs.shape[0]):
        item = probs[i]
        coreset_item = coreset_probs[i]
        similarity_score = F.cosine_similarity(torch.tensor(item).unsqueeze(0), torch.tensor(coreset_item).unsqueeze(0))
        result.append({'similarity_score': similarity_score,
                       'probabilities': item,
                       'input': inputs[i],
                       'target': targets[i]})
    result = sorted(result, key=lambda x: x['similarity_score'], reverse=False)
    return result

In [None]:
# Test inputs reordered so those where the two models' outputs are least similar come first.
prioritized_test_data = test_case_prioritization(logits)

In [None]:
from torch.utils.data import TensorDataset

# Rebuild the test set as tensors, in prioritized order.
inputs = torch.from_numpy(np.stack([item['input'] for item in prioritized_test_data]))
targets = torch.from_numpy(np.array([item['target'] for item in prioritized_test_data]))
tensor_dataset = TensorDataset(inputs, targets)
# shuffle=False preserves the prioritized order; batch_size=1 yields one test
# input per batch.
prioritized_loader = DataLoader(tensor_dataset, batch_size=1, shuffle=False)

In [None]:
def test_with_APFD(model, loader):
    model.to(device)
    model.eval()

    TFs = 10 * [0]

    with torch.no_grad():
        for ids, (data, target) in enumerate(loader):
            data, target = data.to(device), target.to(device)
            output = model(data)
            category = output.argmax(dim=1, keepdim=True)
            if category != target and TFs[target] == 0:
                TFs[target] = ids + 1

            if all(TFs):
                break

    APFD = 1 - (sum(TFs) / (10 * len(loader))) + 1 / (2 * len(loader))

    return APFD

# APFD of the full-data ResNet-18 on the prioritized test order. Higher means
# its misclassified inputs appear earlier in the order.
print('ResNet - APFD', test_with_APFD(resnet18, prioritized_loader))