In [None]:
from torchvision import transforms
import sys
import os
sys.path.append(os.path.abspath(".."))
from models import model_dict
from utils import NormalizeByChannelMeanStd
import numpy as np
from torchvision.datasets import CIFAR10, CIFAR100, ImageFolder
from dataset import TinyImageNetDataset


# Experiment selection: which dataset and which architecture the rest of this
# cell builds. The trailing comments list the values handled further down.
dataset = "TinyImagenet"  # "cifar10" "cifar100" "TinyImagenet"
arch = "resnet18"  # "resnet18" "resnet50" "vgg16_bn"

# TinyImagenet is read from its own directory; every other dataset shares
# the generic "data" root.
if dataset == "TinyImagenet":
    data_dir = "tiny-imagenet-200"
else:
    data_dir = "data"



# Build everything that depends on the selected `dataset`:
#   classes        - number of output classes for the model head
#   normalization  - per-channel mean/std normalizer (attached to the model later)
#   train_set      - training split with random-crop / horizontal-flip augmentation
#   test_set       - evaluation split without augmentation
# After the branch, both splits get their `targets` converted to NumPy arrays.
if dataset == "cifar10":
    classes = 10
    data_dir = data_dir + '/cifar10'
    normalization = NormalizeByChannelMeanStd(
        mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616]
    )
    train_transform = transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]
    )
    test_transform = transforms.Compose(
        [
            transforms.ToTensor(),
        ]
    )
    # download=False: the data must already be present under `data_dir`.
    train_set = CIFAR10(data_dir, train=True, transform=train_transform, download=False)
    test_set = CIFAR10(data_dir, train=False, transform=test_transform, download=False)
elif dataset == "cifar100":
    classes = 100
    data_dir = data_dir + '/cifar100'
    normalization = NormalizeByChannelMeanStd(
        mean=[0.5071, 0.4866, 0.4409], std=[0.2673, 0.2564, 0.2762]
    )
    train_transform = transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]
    )
    test_transform = transforms.Compose(
        [
            transforms.ToTensor(),
        ]
    )
    # download=False: the data must already be present under `data_dir`.
    train_set = CIFAR100(data_dir, train=True, transform=train_transform, download=False)
    test_set = CIFAR100(data_dir, train=False, transform=test_transform, download=False)
elif dataset == "TinyImagenet":
    classes = 200
    normalization = NormalizeByChannelMeanStd(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )
    # No ToTensor() in these pipelines, unlike the CIFAR branches.
    # NOTE(review): presumably TinyImageNetDataset performs the tensor
    # conversion itself - confirm against dataset.py.
    train_transform = transforms.Compose(
        [
            transforms.RandomCrop(64, padding=4),
            transforms.RandomHorizontalFlip(),
        ]
    )
    test_transform = transforms.Compose([])
    train_path = os.path.join(data_dir, "train/")
    test_path = os.path.join(data_dir, "test/")
    train_set = ImageFolder(train_path, transform=train_transform)
    train_set = TinyImageNetDataset(train_set)
    test_set = ImageFolder(test_path, transform=test_transform)
    test_set = TinyImageNetDataset(test_set)
else:
    # Fail here with a clear message instead of a NameError on `classes` later.
    raise ValueError(
        f"Unsupported dataset {dataset!r}; expected 'cifar10', 'cifar100' or 'TinyImagenet'"
    )

# Store labels as NumPy arrays for BOTH splits. This is done once, after the
# branch, so every dataset is treated the same way (the TinyImagenet branch
# used to convert train_set.targets twice and never touched test_set.targets).
train_set.targets = np.array(train_set.targets)
test_set.targets = np.array(test_set.targets)

# Look up the constructor registered for `arch`, build the network with one
# output per class, and attach the dataset-specific normalization object.
make_model = model_dict[arch]
model = make_model(num_classes=classes)
setattr(model, "normalize", normalization)

# Run the first training sample through the model and evaluate the shape of
# the output.
model(train_set[0][0]).shape

# """
# cifar10: 50000
# cifar100: 50000
# TinyImagenet: 99999
# """

# train_set.targets.shape



In [None]:
import pickle

# Generate three forget/retain index splits over the training set and pickle
# them, one pair of files per seed. Relies on `train_set`, `dataset`, `np`
# and `os` from the first cell.
len_forget_set = 3000  # number of training samples in each forget set
num_train = len(train_set.targets)

# Make sure the output directory exists so open(..., "wb") cannot fail on a
# fresh checkout.
idx_dir = "assets/unlearn_set_idxs"
os.makedirs(idx_dir, exist_ok=True)

for seed in range(3):
    # Seed a dedicated generator so the split saved under `seed` is
    # reproducible across kernel restarts (previously `seed` only named the
    # file and the draw came from the unseeded global RNG).
    rng = np.random.default_rng(seed)
    # Forget set: `len_forget_set` distinct indices, in draw order.
    fgt_set = rng.choice(num_train, size=len_forget_set, replace=False)
    # Retain set: every remaining index (setdiff1d returns them sorted).
    rtn_set = np.setdiff1d(np.arange(num_train), fgt_set)
    with open(f"{idx_dir}/{dataset}_forget_set_idx_{seed}.pkl", "wb") as f:
        pickle.dump(fgt_set, f)
    with open(f"{idx_dir}/{dataset}_retain_set_idx_{seed}.pkl", "wb") as f:
        pickle.dump(rtn_set, f)


In [None]:
import torch

# Save freshly initialized weights for each (dataset, architecture) pair
# under three seeds. Relies on `model_dict`, `NormalizeByChannelMeanStd` and
# `os` from the first cell.
#
# NOTE: this loop rebinds the notebook-level names `dataset`, `arch`,
# `classes`, `normalization` and `model` set in the first cell. Re-run the
# first cell before re-running any cell that depends on them.
init_dir = "assets/init_model"
# Make sure the output directory exists so torch.save cannot fail on a fresh
# checkout.
os.makedirs(init_dir, exist_ok=True)

for seed in range(3):
    for dataset, arch in zip(["cifar10", "cifar100", "TinyImagenet", "TinyImagenet"], ["resnet18", "resnet50", "resnet18", "vgg16_bn"]):

        if dataset == "cifar10":
            classes = 10
            normalization = NormalizeByChannelMeanStd(
                mean=[0.4914, 0.4822, 0.4465], std=[0.2470, 0.2435, 0.2616]
            )
        elif dataset == "cifar100":
            classes = 100
            normalization = NormalizeByChannelMeanStd(
                mean=[0.5071, 0.4866, 0.4409], std=[0.2673, 0.2564, 0.2762]
            )
        elif dataset == "TinyImagenet":
            classes = 200
            normalization = NormalizeByChannelMeanStd(
                mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
            )
        else:
            # Guard against silently reusing `classes`/`normalization` from a
            # previous iteration if the list above gains an unknown name.
            raise ValueError(f"Unsupported dataset {dataset!r}")

        # Seed immediately before construction so the weights saved under
        # `seed` depend only on (seed, arch, classes), not on loop order or
        # prior RNG use (previously `seed` only named the file).
        torch.manual_seed(seed)
        model = model_dict[arch](num_classes=classes)
        # Attached before saving so anything the normalizer contributes to
        # state_dict() is included in the checkpoint.
        model.normalize = normalization
        torch.save(model.state_dict(), f"{init_dir}/{dataset}_{arch}_init_weights_{seed}.pth")

