In [1]:
from load_raw_dataset import load_raw_dataset
import torch

In [None]:
#pip install quinine
#pip install torchvision
#pip install git+https://github.com/openai/CLIP.git

In [3]:
import torch
from tqdm import tqdm
import pickle
import os
from load_raw_dataset import load_raw_dataset

# Base directory that contains the `eNTK-robustness` checkout.
# Override with the ENTK_ROOT environment variable instead of editing the notebook.
root = os.environ.get("ENTK_ROOT", "/home/ubuntu")

def compute_eNTK(model, dataset_name, split, subsample_size=500000, seed=123):
    """Save a randomly subsampled empirical-NTK gradient vector for every example of a split.

    For each example ``i``, the gradient of ``model(x)[0]`` with respect to every
    trainable parameter is flattened and concatenated. A fixed random subset of
    ``subsample_size`` coordinates (drawn once from ``seed``) is then saved to
    ``{root}/eNTK-robustness/data/ntk_{dataset_name}_{subsample_size}/{split}/ntk_{i}.pt``.
    Finally the split's labels are written with ``store_labels``.

    Args:
        model: torch module; this function switches it to eval mode.
        dataset_name: dataset name passed to ``load_raw_dataset``.
        split: split name passed to ``load_raw_dataset``; also used as the output sub-directory.
        subsample_size: number of gradient coordinates kept per example.
        seed: seed for the torch CPU/CUDA RNGs used to pick the coordinates.
    """
    dataset = load_raw_dataset(dataset_name, split)

    model.eval()
    trainable = [p for p in model.parameters() if p.requires_grad]
    num_params = sum(p.numel() for p in trainable)

    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)

    # Send inputs to wherever the model's parameters live. For a parameterless
    # module, fall back to the original "CUDA if available" rule.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # One permutation shared by every example, so coordinate j means the same
    # parameter entry across all saved files.
    random_index = torch.randperm(num_params)[:subsample_size]

    out_root = f"{root}/eNTK-robustness/data/ntk_{dataset_name}_{subsample_size}"
    split_dir = f"{out_root}/{split}"
    # makedirs also creates missing parents. The former `mkdir` silently failed
    # when out_root did not exist yet, and torch.save then crashed.
    os.makedirs(split_dir, exist_ok=True)

    for i in tqdm(range(len(dataset))):
        model.zero_grad()
        inputs = torch.unsqueeze(dataset[i][0], 0).to(device)
        model(inputs)[0].backward()
        # A trainable parameter not used in forward has grad None. Zero-fill it
        # so the concatenated vector always has num_params entries and
        # random_index keeps pointing at the same coordinates.
        grads = [
            p.grad.flatten()
            if p.grad is not None
            else torch.zeros(p.numel(), device=p.device, dtype=p.dtype)
            for p in trainable
        ]
        eNTK = torch.cat(grads)
        # Indexing with an index tensor already returns a fresh copy (not a
        # view), so only the subsampled entries are serialized.
        ntk_data_point = eNTK[random_index]
        torch.save(ntk_data_point, f"{split_dir}/ntk_{i}.pt")

    labels_dir = f"{out_root}/labels"
    labels_file = f"{labels_dir}/labels_{split}.pkl"
    store_labels(dataset, labels_dir, labels_file)
      
def store_labels(raw_dataset, save_dir, save_file):
    """Collect ``raw_dataset[i][1]`` for every index and pickle the list to ``save_file``.

    Args:
        raw_dataset: indexable dataset supporting ``len``; element ``[1]`` of
            each item is taken as the label (presumably (input, label) pairs).
        save_dir: directory meant to contain ``save_file``; created with parents if missing.
        save_file: path of the pickle file; created or overwritten.
    """
    labels = [raw_dataset[i][1] for i in tqdm(range(len(raw_dataset)))]
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)
    # `with` guarantees the pickle is flushed and the handle closed. The old
    # `touch` was unnecessary because mode 'wb' creates the file.
    with open(save_file, "wb") as labels_fh:
        pickle.dump(labels, labels_fh)
                  

# Build the model from its YAML config, then extract eNTK features for each
# DomainNet validation split.
# TODO: replace the hard-coded config path and split list with an argument parser.
from construct_model import build_model
import quinine

config_path = f"{root}/eNTK-robustness/configs/adaptation/domainnet.yaml"
config = quinine.Quinfig(config_path)
model = build_model(config)

print("Starting to compute NTK")
for split in ("sketch_val", "real_val", "painting_val", "clipart_val"):
    compute_eNTK(model, "domainnet", split)

    

Fine Tuning 38317921 of 102008162 parameters.
Starting to compute NTK


100%|███████████████████████████████████████████████████████████████████████████████| 2399/2399 [01:57<00:00, 20.48it/s]
100%|██████████████████████████████████████████████████████████████████████████████| 2399/2399 [00:15<00:00, 152.48it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 6943/6943 [05:49<00:00, 19.87it/s]
100%|██████████████████████████████████████████████████████████████████████████████| 6943/6943 [00:41<00:00, 166.52it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 2909/2909 [02:31<00:00, 19.20it/s]
100%|██████████████████████████████████████████████████████████████████████████████| 2909/2909 [00:19<00:00, 148.75it/s]
100%|███████████████████████████████████████████████████████████████████████████████| 1616/1616 [01:27<00:00, 18.51it/s]
100%|██████████████████████████████████████████████████████████████████████████████| 1616/1616 [00:08<00:00, 189.10it/s]


In [7]:
import pickle

# NOTE(review): despite its name, `train_labels` holds the *clipart_val* labels here.
# `with` closes the file handle, which the bare open(...) left dangling.
with open(f"{root}/eNTK-robustness/data/ntk_domainnet_500000/labels/labels_clipart_val.pkl", "rb") as labels_fh:
    train_labels = pickle.load(labels_fh)

In [8]:
# Distinct class labels present in the loaded split. set() already
# de-duplicates, so the manual membership check was redundant.
new_labels = set(train_labels)

In [9]:
# Display the distinct labels. The recorded output for clipart_val shows classes 0-39.
new_labels

{0,
 1,
 2,
 3,
 4,
 5,
 6,
 7,
 8,
 9,
 10,
 11,
 12,
 13,
 14,
 15,
 16,
 17,
 18,
 19,
 20,
 21,
 22,
 23,
 24,
 25,
 26,
 27,
 28,
 29,
 30,
 31,
 32,
 33,
 34,
 35,
 36,
 37,
 38,
 39}

In [10]:
import torch

# Precomputed DomainNet kernel (recorded shape: (19404, 5537)). The path is
# built from `root` like every other data path in this notebook.
mega_kernel = torch.load(f"{root}/eNTK-robustness/data/domainnet_mega_kernel.pt")

In [11]:
# Inspect the kernel dimensions (torch.Size; .size() is equivalent to .shape).
mega_kernel.size()

torch.Size([19404, 5537])

In [14]:
# Sanity check: train + sketch_val + real_val + painting_val + clipart_val sizes.
# The result should equal mega_kernel's row count.
sum((5537, 2399, 6943, 2909, 1616))


19404

In [None]:

# Load per-split labels. This cell defines its own inputs: labels_root and
# test_splits were otherwise only defined in a later cell, so it raised
# NameError on a fresh kernel.
labels_root = f"{root}/eNTK-robustness/data/ntk_domainnet_500000/labels"
test_splits = ["sketch_val", "real_val", "painting_val", "clipart_val"]
with open(f"{labels_root}/labels_train.pkl", "rb") as labels_fh:
    train_labels = torch.tensor(pickle.load(labels_fh))
test_labels = {}
for test_split in test_splits:
    with open(f"{labels_root}/labels_{test_split}.pkl", "rb") as labels_fh:
        test_labels[test_split] = torch.tensor(pickle.load(labels_fh))


In [None]:
# Split the precomputed kernel into a train block and one block per
# validation split, and load the matching labels.
# Layout: rows are all points (train first, then the validation splits in
# `test_splits` order); columns are the training points. That is, shape
# (N_TRAIN + total validation size, N_TRAIN).
# NOTE(review): row order is assumed to match the order the per-split eNTKs were stacked — confirm.
N_TRAIN = 5537
test_splits = ["sketch_val","real_val","painting_val","clipart_val"]
# Split sizes as reported by the eNTK extraction progress bars.
TEST_SPLIT_SIZES = {"sketch_val": 2399, "real_val": 6943, "painting_val": 2909, "clipart_val": 1616}
assert mega_kernel.shape == (N_TRAIN + sum(TEST_SPLIT_SIZES.values()), N_TRAIN), mega_kernel.shape

# Slice ROWS, not columns. mega_kernel has only N_TRAIN columns, so the former
# column slices (mega_kernel[:, 5537:...]) were all empty, and
# mega_kernel[:, :5537] was the whole matrix rather than the train block.
train_kernel = mega_kernel[:N_TRAIN]
test_kernels = {}
row_start = N_TRAIN
for test_split in test_splits:
    row_stop = row_start + TEST_SPLIT_SIZES[test_split]
    test_kernels[test_split] = mega_kernel[row_start:row_stop]
    row_start = row_stop

labels_root = f"{root}/eNTK-robustness/data/ntk_domainnet_500000/labels"
with open(f"{labels_root}/labels_train.pkl", "rb") as labels_fh:
    train_labels = torch.tensor(pickle.load(labels_fh))
test_labels = {}
for test_split in test_splits:
    with open(f"{labels_root}/labels_{test_split}.pkl", "rb") as labels_fh:
        test_labels[test_split] = torch.tensor(pickle.load(labels_fh))

# Each kernel block must have one row per labelled example.
assert len(train_labels) == train_kernel.shape[0], (len(train_labels), train_kernel.shape)
for test_split in test_splits:
    assert len(test_labels[test_split]) == test_kernels[test_split].shape[0], test_split


In [18]:
# Total number of points across train and the four validation splits
# (should equal mega_kernel's row count).
sum((5537, 2399, 6943, 2909, 1616))

19404