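Run the cell below exactly once, before running any other cell. It moves the Tiny ImageNet validation images into per-class folders under `val/` and `test/`.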

In [None]:
import os
import glob
from shutil import move
# Root of the extracted Tiny ImageNet folder, relative to the current working
# directory; organize_data() reorganizes the val/ and test/ folders inside it.
data_dir = "tiny-imagenet-200"

def organize_data(root=None, val_per_class=25):
    """Split Tiny ImageNet's flat ``val/images`` folder into per-class val/test sets.

    The existing ``test/`` folder (if any) is renamed to ``test_original/`` and a
    fresh ``test/`` is created. Every image in ``val/images`` is then moved into
    ``val/<wnid>/images/`` until that class holds ``val_per_class`` images; the
    remaining images of the class go to ``test/<wnid>/images/``. Afterwards the
    emptied ``val/images`` folder is removed, so both ``val/`` and ``test/``
    contain only one sub-folder per class.

    Nothing is done when ``test_original/`` already exists, so a completed run
    is not repeated.

    Args:
        root: Dataset root directory; defaults to the module-level ``data_dir``.
        val_per_class: Maximum number of images per class kept in ``val/``.

    Raises:
        FileNotFoundError: if ``val/val_annotations.txt`` is missing.
        KeyError: if an image in ``val/images`` has no annotation line.
    """
    if root is None:
        root = data_dir
    val_path = os.path.join(root, "val")
    test_path = os.path.join(root, "test")
    test_original_path = os.path.join(root, "test_original")
    val_images_path = os.path.join(val_path, "images")

    # A completed run leaves test_original/ behind; never reorganize twice.
    if os.path.exists(test_original_path):
        return

    # val_annotations.txt is tab-separated: column 0 is the image file name and
    # column 1 the class folder (wnid) it belongs to.
    wnid_of = {}
    with open(os.path.join(val_path, "val_annotations.txt"), "r") as f:
        for line in f:
            fields = line.rstrip("\n").split("\t")
            if len(fields) >= 2:  # skip blank / malformed lines
                wnid_of[fields[0]] = fields[1]

    # Sorted so the val/test split does not depend on directory listing order.
    paths = sorted(glob.glob(os.path.join(val_images_path, "*")))
    # Resolve every label before touching the filesystem, so a missing
    # annotation (KeyError) aborts without renaming or moving anything.
    labelled = [(path, wnid_of[os.path.basename(path)]) for path in paths]

    if os.path.exists(test_path):
        os.rename(test_path, test_original_path)
    # Create test/ even when there was nothing to rename; class folders go inside.
    os.makedirs(test_path, exist_ok=True)

    val_counts = {}  # wnid -> number of images currently in val/<wnid>/images
    for wnid in sorted({wnid for _, wnid in labelled}):
        class_val_dir = os.path.join(val_path, wnid, "images")
        os.makedirs(class_val_dir, exist_ok=True)
        os.makedirs(os.path.join(test_path, wnid, "images"), exist_ok=True)
        # Start from whatever is already in the class folder (normally nothing).
        val_counts[wnid] = len(glob.glob(os.path.join(class_val_dir, "*")))

    for path, wnid in labelled:
        file_name = os.path.basename(path)
        if val_counts[wnid] < val_per_class:
            dest_dir = os.path.join(val_path, wnid, "images")
            val_counts[wnid] += 1
        else:
            dest_dir = os.path.join(test_path, wnid, "images")
        move(path, os.path.join(dest_dir, file_name))

    # val/images has been emptied by the moves above; drop it so val/ only
    # holds class folders.
    if os.path.isdir(val_images_path):
        os.rmdir(val_images_path)
# One-time reorganization: organize_data skips all work once
# <data_dir>/test_original exists (created when the original test/ is renamed).
organize_data()

In [None]:
from torchvision import transforms
import sys
import os
sys.path.append(os.path.abspath(".."))
from models import model_dict
from utils import NormalizeByChannelMeanStd
import numpy as np
from torchvision.datasets import CIFAR10, CIFAR100, ImageFolder
from dataset import TinyImageNetDataset
from torch.utils.data import DataLoader, Dataset, Subset
import torch
import pickle
from itertools import cycle

# ---- Experiment configuration ----
dataset = "TinyImagenet"  # one of: "cifar10", "cifar100", "TinyImagenet"
arch = "resnet18"  # one of: "resnet18", "resnet50", "vgg16_bn"
# Tiny ImageNet has its own extracted folder; every other dataset uses "data".
data_dir = {"TinyImagenet": "tiny-imagenet-200"}.get(dataset, "data")
model_init_seed = 1  # part of the init-weights filename; also seeds loader workers
unlearn_data_seed = 0  # part of the forget/retain index filenames
batch_size = 256


# CIFAR variants differ only in dataset class, class count and per-channel
# mean/std; everything else (transforms, folder layout) is shared.
_CIFAR_SPECS = {
    "cifar10": (CIFAR10, 10, [0.4914, 0.4822, 0.4465], [0.2470, 0.2435, 0.2616]),
    "cifar100": (CIFAR100, 100, [0.5071, 0.4866, 0.4409], [0.2673, 0.2564, 0.2762]),
}

if dataset in _CIFAR_SPECS:
    cifar_cls, classes, mean, std = _CIFAR_SPECS[dataset]
    data_dir = data_dir + "/" + dataset  # e.g. "data/cifar10"
    normalization = NormalizeByChannelMeanStd(mean=mean, std=std)
    # Augment only the training split: pad-and-crop back to 32x32 plus flips.
    train_transform = transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
        ]
    )
    test_transform = transforms.Compose(
        [
            transforms.ToTensor(),
        ]
    )
    train_set = cifar_cls(data_dir, train=True, transform=train_transform, download=False)
    test_set = cifar_cls(data_dir, train=False, transform=test_transform, download=False)
elif dataset == "TinyImagenet":
    classes = 200
    normalization = NormalizeByChannelMeanStd(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )
    # NOTE(review): no ToTensor() in these transforms -- presumably
    # TinyImageNetDataset handles conversion/caching via the .pt files below;
    # confirm in dataset.py.
    train_transform = transforms.Compose(
        [
            transforms.RandomCrop(64, padding=4),
            transforms.RandomHorizontalFlip(),
        ]
    )
    test_transform = transforms.Compose([])
    train_path = os.path.join(data_dir, "train/")
    # test/ is the per-class split built from val/ by organize_data() in the
    # first cell.
    test_path = os.path.join(data_dir, "test/")
    train_set = ImageFolder(train_path, transform=train_transform)
    train_set = TinyImageNetDataset(train_set, cache_file='assets/tinyimagenet_preprocess/train.pt')
    test_set = ImageFolder(test_path, transform=test_transform)
    test_set = TinyImageNetDataset(test_set, cache_file='assets/tinyimagenet_preprocess/test.pt')
else:
    # Fail here with a clear message instead of a NameError on `classes` later.
    raise ValueError(
        f"Unknown dataset {dataset!r}; expected 'cifar10', 'cifar100' or 'TinyImagenet'"
    )

# Labels as NumPy arrays (every branch above needs this).
train_set.targets = np.array(train_set.targets)
test_set.targets = np.array(test_set.targets)

# Build the network with an output head sized for the dataset and attach the
# dataset's channel statistics as its ``normalize`` attribute.
# NOTE(review): presumably the model applies ``normalize`` inside forward();
# confirm in models/.
model = model_dict[arch](num_classes=classes)
model.normalize = normalization

# ---- orig: start from the saved initial weights for this dataset/arch/seed ----
init_weights_file = f'assets/init_model/{dataset}_{arch}_init_weights_{model_init_seed}.pth'
init_state = torch.load(init_weights_file, weights_only=True)
model.load_state_dict(init_state)

def _init_fn(worker_id):
    """Seed NumPy's global RNG inside a DataLoader worker process.

    Each worker gets a distinct but reproducible seed
    (``model_init_seed + worker_id``). Seeding every worker with the same value
    would make any NumPy-based randomness identical across workers.

    Args:
        worker_id: Index of the worker process, supplied by ``DataLoader``.
    """
    np.random.seed(int(model_init_seed) + worker_id)

# Settings shared by both loaders; only shuffling differs between the splits.
_loader_opts = dict(
    batch_size=batch_size,
    worker_init_fn=_init_fn,
    num_workers=4,
    pin_memory=True,
)
train_loader = DataLoader(train_set, shuffle=True, **_loader_opts)
test_loader = DataLoader(test_set, shuffle=False, **_loader_opts)

# Smoke check: push the first training batch (if there is one) through the model.
# NOTE(review): the model is presumably still in training mode here, so if it
# contains BatchNorm layers this forward pass updates their running statistics
# and records an autograd graph; use model.eval() / torch.no_grad() if this is
# only meant as a shape check.
first_batch = next(iter(train_loader), None)
if first_batch is not None:
    x, y = first_batch
    model(x)

######## unlearn


# Index lists that partition the training set into forget / retain subsets.
# SECURITY NOTE: pickle.load can execute arbitrary code; only load index files
# generated locally for this project, never ones from an untrusted source.
with open(f"assets/unlearn_set_idxs/{dataset}_forget_set_idx_{unlearn_data_seed}.pkl", "rb") as f:
    fgt_set_idx = pickle.load(f)
with open(f"assets/unlearn_set_idxs/{dataset}_retain_set_idx_{unlearn_data_seed}.pkl", "rb") as f:
    rtn_set_idx = pickle.load(f)


forget_set = Subset(train_set, fgt_set_idx)
retain_set = Subset(train_set, rtn_set_idx)


forget_loader = DataLoader(
    forget_set,
    batch_size=batch_size,
    shuffle=True,
    worker_init_fn=_init_fn,
    num_workers=4,
    pin_memory=True
)

retain_loader = DataLoader(
    retain_set,
    batch_size=batch_size,
    shuffle=True,
    worker_init_fn=_init_fn,
    num_workers=4,
    pin_memory=True
)


def _repeat_loader(loader):
    """Yield batches from *loader* endlessly, starting a fresh pass each time.

    Unlike ``itertools.cycle``, which stores every batch of the first pass and
    replays those exact tensors, this re-iterates the loader, so each pass is
    reshuffled and re-augmented and no batches are held in memory. Stops at
    once if *loader* yields nothing, so an empty forget set cannot loop forever.
    """
    while True:
        produced = False
        for batch in loader:
            produced = True
            yield batch
        if not produced:
            return


# Pair every retain batch with a forget batch; the forget loader is restarted
# as needed, and the loop ends when retain_loader is exhausted.
for (fgt_x, fgt_y), (rtn_x, rtn_y) in zip(_repeat_loader(forget_loader), retain_loader):
    model(fgt_x)
    model(rtn_x)
    break

