In [1]:
import os
os.chdir('..')

In [2]:
os.path.abspath('.')

'/data/kewen/IF_project'

In [3]:
class Struct:
    def __init__(self, **entries):
        self.__dict__.update(entries)

In [4]:
import argparse
import os
import random
from tkinter import Image
from matplotlib import pyplot as plt
from sklearn.utils import shuffle
from tqdm import tqdm
import copy

import numpy as np
import torch

from torchvision import models
from src.data_utils.MnistDataset import MnistDataset
from src.utils.utils import save_json
from src.data_utils.Cifar10Dataset import Cifar10Dataset
from src.solver.fenchel_solver import FenchelSolver
from src.modeling.classification_models import CnnCifar, MNIST_1
from src.modeling.influence_models import Net_IF, MNIST_IF_1
from torch.autograd.functional import hessian
from torch.nn.utils import _stateless
from torch.nn import CrossEntropyLoss 
import torch.nn.functional as F
import wandb
import yaml
YAMLPath = 'src/config/MNIST/default.yaml'

def get_single_image_from_dataset(dataset, idx):
        x, y = dataset[idx]
        x = x.unsqueeze(0)
        y = torch.LongTensor([y])
        return x, y
    

    


  from .autonotebook import tqdm as notebook_tqdm


In [5]:
parser = argparse.ArgumentParser()
parser.add_argument("--YAMLPath", type=str)
args, unknown = parser.parse_known_args()
if args.YAMLPath:
    YAMLPath = args.YAMLPath
with open(YAMLPath) as file:
    config = yaml.safe_load(file)
args= Struct(**config)

In [6]:
file_path = os.path.join(args._ckpt_dir, args._pretrain_ckpt_name)
inv_hessian_path = os.path.join(args._ckpt_dir, "numpy_inv_hessian_" + args._pretrain_ckpt_name)

if args.dataset_name == 'cifar10':
    Dataset = Cifar10Dataset
elif args.dataset_name == 'mnist':
    Dataset = MnistDataset
else:
    raise NotImplementedError()

class_label_dict = Dataset.get_class_label_dict()
CLASS_MAP = Dataset.get_class_map()
train_classes = [class_label_dict[c] for c in args.train_classes]
ImageDataset = Dataset(args.dev_original_folder, args.dev_transformed_folder, args.test_original_folder, args.test_transformed_folder, train_classes, args.num_per_class)

train_dataset, train_dataset_no_transform = ImageDataset.get_train()
dev_dataset, dev_dataset_no_transform = ImageDataset.get_dev()
train_dataset_size = len(train_dataset)
if args.classification_model == 'Resnet34':
    classification_model = models.resnet34(pretrained=True).to('cuda')
    classification_model.fc = torch.nn.Linear(
        classification_model.fc.in_features,
        10).to('cuda')
elif args.classification_model == 'CnnCifar':
    classification_model = CnnCifar(10).to('cuda')
elif args.classification_model == 'MNIST_1':
    classification_model = MNIST_1(args._hidden_size_classification, 10).to('cuda')
else:
    raise NotImplementedError()

if os.path.isfile(file_path):
    checkpoint = torch.load(file_path, map_location='cuda')
    global_epoch = checkpoint['epoch']
    classification_model.load_state_dict(checkpoint['model_states']['classification'])
else:
    print("=> no checkpoint found at '{}'".format(file_path))

x_dev, y_dev = get_single_image_from_dataset(
    dev_dataset, args.dev_id_num)
criterion = CrossEntropyLoss()

number of examples with label '0': 100
number of examples with label '1': 100
number of examples with label '2': 100
number of examples with label '3': 100
number of examples with label '4': 100
number of examples with label '5': 100
number of examples with label '6': 100
number of examples with label '7': 100
number of examples with label '8': 100
number of examples with label '9': 100
loaded train_dataset with 1000 samples
loaded dev FolderDataset with 10 files in folder, 


In [7]:
def loss(params):
    print(params.shape)
    params_splitted_reshaped = []
    start = 0
    end = 0
    for shape in shapes:
        end +=  np.prod(shape)
        print(start, end)
        params_splitted_reshaped.append(params[start:end].reshape(shape))
        start = end
    inputs = train_dataset[:][0].to('cuda')
    labels = train_dataset[:][1].to('cuda')
    out: torch.Tensor = _stateless.functional_call(classification_model, \
        {n: p for n, p in zip(names, params_splitted_reshaped)}, inputs)
    loss = criterion(out, labels)
    return loss

def loss_grad_at_point(model, x, y):
    # w should be torch.cat(tuple([_.view(-1) for _ in model.parameters()]))
    # train_point should be alpha_train_dataset[0], a tuple of image and label
    out = model(x)
    loss = criterion(out, torch.tensor([y]).to('cuda'))
    loss.backward()
    grad = torch.cat(tuple([_.grad.view(-1) for _ in model.parameters()]))
    for p in classification_model.parameters():
        p.grad = None
    return grad

def calculate_if(x_train, y_train, x_test, y_test, inv_Hessian):
    # test point should be alpha_test_dataset[10]

    test_loss = loss_grad_at_point(classification_model, x_train, y_train).to("cpu").numpy()

    train_loss = loss_grad_at_point(classification_model, x_test, y_test).to("cpu").numpy()

    if_score = -np.matmul(np.matmul(test_loss.T, inv_Hessian), train_loss)

    return if_score

In [8]:
x_dev = x_dev.to('cuda')
y_dev = y_dev.to('cuda')

In [47]:
names = list(n for n, _ in classification_model.named_parameters())
shapes = [p.shape for p in classification_model.parameters()] # a parameter for loss function, not passed explicitly
Hessian = hessian(loss, torch.cat(tuple([_.view(-1) for _ in classification_model.parameters()])))
np_Hessian = Hessian.to("cpu").numpy()/train_dataset_size
damping_matrix = np.diag(np.full(Hessian.shape[0],0.01),0)
damping_hessian = np_Hessian + damping_matrix
inv_hessian = np.linalg.inv(damping_hessian)


torch.Size([25450])
0 25088
25088 25120
25120 25440
25440 25450


KeyboardInterrupt: 

In [41]:
np.save(inv_hessian_path, inv_hessian)

In [9]:
inv_hessian = np.load(inv_hessian_path+".npy")

In [10]:
if_score_list = []
for i in tqdm(range(train_dataset_size)):
    if_score = calculate_if(train_dataset[i][0].to('cuda'), train_dataset[i][1].to('cuda'), x_dev, y_dev, inv_hessian)
    if_score_list.append(if_score)

100%|██████████████████████████████████████████████████████████████| 1000/1000 [03:18<00:00,  5.04it/s]


In [49]:
if_score_list

[-0.20505717373774218,
 -0.13675323938152464,
 1.4618436422043177,
 -0.5783401893822859,
 0.29210114835316525,
 -2.9287083337700874,
 -2.396179142163878,
 -0.0644562128974485,
 -19.26733133274052,
 -0.33017337495914173,
 -0.048070874162933014,
 -0.3448035323548823,
 12.35140544530747,
 -3.5751157135686222,
 0.2589196204763886,
 -7.994136526653134,
 -0.4977308160684129,
 -0.3000613195430474,
 -4.1247700115402655,
 2.1935133508446367,
 -0.03815414746735604,
 -1.2846074971721193,
 0.35292067823664053,
 -0.03627522050488789,
 -7.01181707883838,
 -43.54676256936907,
 -0.07032658716526022,
 -2.357623081620661,
 5.219364308945186,
 -0.8521883975233764,
 -0.0940683571280373,
 -1.5727824198900024,
 -28.901703810919532,
 -0.09899149752045851,
 -3.6186159343362254,
 -10.889226175547146,
 -1.1497048673995158,
 -4.639582988899004,
 -1.3576957689389293,
 -0.3339848332074453,
 1.0026317520344112,
 -0.08931735805444588,
 0.0009021972074156311,
 -2.063734903857029,
 -1.4200960209459565,
 0.111275183200

In [80]:
Hessian.shape

torch.Size([25450, 25450])

In [126]:
device = 'cpu'
def CE_loss_new_all2(w, hidden_size=32):
    divide1 = hidden_size*784
    divide2 = divide1 + hidden_size
    divide3 = divide2 + hidden_size*10
    
    images = train_dataset[:][0]
    labels =  train_dataset[:][1]

    out1 = F.linear(images.reshape(-1, 784).to(device), w[:divide1].reshape(hidden_size,784), w[divide1:divide2])
    out2 = F.relu(out1)
    out3 = F.linear(out2, w[divide2:divide3].reshape(10,hidden_size), w[divide3:])
    loss = criterion(out3, labels.to(device))
    print(criterion)
    print(labels.shape)
    print(out3.shape)
    print(loss.shape)
    return loss
names = list(n for n, _ in classification_model.named_parameters())
shapes = [p.shape for p in classification_model.parameters()]
Hessian2 = hessian(CE_loss_new_all2, torch.cat(tuple([_.view(-1) for _ in classification_model.parameters()])))

CrossEntropyLoss()
torch.Size([1000])
torch.Size([1000, 10])
torch.Size([])


# my implemetation loss grad

In [273]:
def loss_grad_at_point(model, x, y):
    # w should be torch.cat(tuple([_.view(-1) for _ in model.parameters()]))
    # train_point should be alpha_train_dataset[0], a tuple of image and label
    out = model(x)
    print(out)
    loss = criterion(out, torch.tensor([y]).to(device))
    print(loss)
    loss.backward()
    grad = torch.cat(tuple([_.grad.view(-1) for _ in model.parameters()]))
    for p in classification_model.parameters():
        p.grad = None
    return grad

In [274]:
a= loss_grad_at_point(classification_model, x_dev, y_dev)
a

tensor([[-3.5083,  5.2741,  0.5708, -0.1582, -0.4546, -0.0932,  0.7715,  0.6345,
          0.1679, -0.3296]], grad_fn=<AddmmBackward0>)
tensor(0.0507, grad_fn=<NllLossBackward0>)


tensor([-0.0022, -0.0022, -0.0022,  ...,  0.0092,  0.0058,  0.0035])

# Xuhui implement loss grad

In [277]:
def loss_grad_at_point(w, train_point, hidden_size = 32):
    # w should be torch.cat(tuple([_.view(-1) for _ in model.parameters()]))
    # train_point should be alpha_train_dataset[0], a tuple of image and label

    w = w.clone().detach().requires_grad_(True)

    divide1 = hidden_size*784
    divide2 = divide1 + hidden_size
    divide3 = divide2 + hidden_size*10

    out1 = F.linear(train_point[0].reshape(-1, 784).to(device), w[:divide1].reshape(hidden_size,784), w[divide1:divide2])
    out2 = F.relu(out1)
    out3 = F.linear(out2, w[divide2:divide3].reshape(10,hidden_size), w[divide3:])
    print(out3)
    loss = criterion(out3, torch.tensor([train_point[1]]).to(device))
    print(loss)
    loss.backward()
    for p in classification_model.parameters():
        p.grad = None
    return w.grad

In [278]:
b = loss_grad_at_point(torch.cat(tuple([_.view(-1) for _ in classification_model.parameters()])), [x_dev, y_dev])
b

tensor([[-3.5083,  5.2741,  0.5708, -0.1582, -0.4546, -0.0932,  0.7715,  0.6345,
          0.1679, -0.3296]], grad_fn=<AddmmBackward0>)
tensor(0.0507, grad_fn=<NllLossBackward0>)


tensor([-0.0022, -0.0022, -0.0022,  ...,  0.0092,  0.0058,  0.0035])

In [283]:
torch.sum(a != b)

tensor(0)

# they are the same

In [171]:
a[np.argwhere(a != b)]

tensor([[ 0.9000, -8.1000,  0.9000,  0.9000,  0.9000,  0.9000,  0.9000,  0.9000,
          0.9000,  0.9000]])

In [210]:
b[np.argwhere(a != b)]

tensor([[-0.0022, -0.0022, -0.0022,  ...,  0.0092,  0.0058,  0.0035]])

In [183]:
classification_model

MNIST_1(
  (l1): Linear(in_features=784, out_features=32, bias=True)
  (relu): ReLU()
  (l2): Linear(in_features=32, out_features=10, bias=True)
)

In [181]:
[x.shape for x in list(classification_model.parameters())]

[torch.Size([32, 784]),
 torch.Size([32]),
 torch.Size([10, 32]),
 torch.Size([10])]

In [198]:
torch.cat(tuple([_.view(-1) for _ in classification_model.parameters()]))[-10:]

tensor([ 0.0513, -0.0209,  0.1288, -0.0991,  0.0932, -0.0194, -0.0962, -0.0468,
         0.1214,  0.0078], grad_fn=<SliceBackward0>)