In [2]:
import os

# GPU selection for this notebook: enumerate devices by PCI bus ID and expose
# only devices 0-3. These are set in the first cell, before torch is imported
# in the cells that follow.
os.environ.update(
    {
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": "0,1,2,3",
    }
)

In [None]:
import torchvision.transforms as transforms

# Per-channel statistics are written on the 0-255 scale and divided by 255 so
# that they are expressed on a 0-1 scale before being handed to Normalize.
mean = [channel / 255 for channel in (125.3, 123.0, 113.9)]
std = [channel / 255 for channel in (63.0, 62.1, 66.7)]

# Preprocessing pipeline, applied in order:
#   1. resize to 32 pixels,
#   2. convert to a tensor,
#   3. normalise each channel with the statistics above.
preprocessing_steps = [
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize(mean=mean, std=std),
]
trans = transforms.Compose(preprocessing_steps)

In [7]:
import torch
from torchvision.datasets import CIFAR10

# One image per batch: the processing loop later in this notebook reads
# element 0 of every batch, so it handles a single sample at a time.
batch_size = 1

# CIFAR-10 training split, stored under ./data and preprocessed with `trans`.
train_dataset = CIFAR10(
    root="data",
    train=True,
    download=True,
    transform=trans,
)

# Iterate in dataset order (no shuffling) using 16 worker processes.
loader_options = dict(batch_size=batch_size, shuffle=False, num_workers=16)
train_loader = torch.utils.data.DataLoader(train_dataset, **loader_options)

print("Size of training dataset:", len(train_dataset))

Files already downloaded and verified
Size of training dataset: 50000


In [9]:
from pytorch_ood.model import WideResNet
from torchcam.methods import LayerCAM

# Dataset / input geometry used by the model and by the CAM extractor.
num_classes = 10
img_size = 32
input_shape = (3, img_size, img_size)

# Pick the device once instead of calling .cuda() unconditionally, so this
# cell agrees with the processing loop (which only uses CUDA when it is
# available) and does not fail on a CPU-only machine.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Pretrained classifier in eval mode; LayerCAM is attached to its `block3`
# layer and is used later to localise the class-relevant image region.
model = WideResNet(num_classes=num_classes, pretrained="cifar10-pt").to(device).eval()
target_layer = model.block3
localize_net = LayerCAM(model, target_layer=target_layer, input_shape=input_shape)

In [10]:
import cv2
from tqdm import tqdm
import numpy as np

# Synthesise OOD training images: for every training image, mask the region
# whose class activation is >= `cam_lambda` and fill it in by in-painting.
# Three PNGs are written per image (original, mask, in-painted result) into
# `save_dir/<class index>/`.
cam_lambda = 0.3
save_dir = "./cifar10_KIRBY/ood_training_images"

# Create one output folder per class index.
for label in range(num_classes):
    os.makedirs(os.path.join(save_dir, str(label)), exist_ok=True)


def _write_png(path, image):
    """Write `image` to `path` with OpenCV and raise if the write fails.

    cv2.imwrite reports failure by returning False, which would otherwise go
    unnoticed and leave files missing.
    """
    if not cv2.imwrite(path, image):
        raise OSError(f"cv2.imwrite failed for {path}")


# Loop-invariant values, computed once.
model_device = next(model.parameters()).device  # inputs must live where the model lives
mean_arr = np.array(mean).reshape([1, 1, 1, 3])
std_arr = np.array(std).reshape([1, 1, 1, 3])

for batch_idx, (data, target) in enumerate(tqdm(train_loader)):
    # Only the image goes to the model's device; the label is read on the host.
    data = data.to(model_device)

    # Forward the original image.
    out = model(data)
    # Class Activation Map (CAM) for the predicted class.
    activation_map = localize_net(out.squeeze(0).argmax().item(), out)
    activation_map = activation_map[0].squeeze().detach().cpu().numpy()

    # Bring the activation map to the image resolution if it differs.
    if activation_map.shape[0] != img_size:
        activation_map = cv2.resize(activation_map, (img_size, img_size))

    # Undo the normalisation to recover the 8-bit image (first batch element).
    # Round before casting: plain truncation turns values such as 254.99998
    # (float error from normalise -> de-normalise) into 254 instead of 255.
    x_data_array = np.transpose(data.detach().cpu().numpy(), [0, 2, 3, 1])
    origin_x_data = x_data_array * std_arr + mean_arr
    origin_x_data = np.clip(np.rint(origin_x_data * 255), 0, 255).astype(np.uint8)[0]

    # 1 where the activation is below the threshold (pixels to keep), else 0.
    background_mask = np.uint8(activation_map < cam_lambda)
    # Zero out the highly activated region.
    remove_image = origin_x_data * np.expand_dims(background_mask, axis=-1)
    # Complement of the background mask: 1 marks the pixels to in-paint.
    target_mask = 1 - background_mask
    # Fill the removed region with OpenCV's Telea in-painting (radius 5).
    inpaint = cv2.inpaint(remove_image, target_mask, 5, cv2.INPAINT_TELEA)

    # Label of the (single) image in this batch.
    class_idx = int(target.flatten()[0])

    # Files go into the folder named after the integer class index, i.e. the
    # folders created above. The previous code looked the folder name up in
    # `idx_to_class`, which is not defined anywhere in this notebook.
    # NOTE(review): if these folders are later read with torchvision's
    # ImageFolder, confirm that its folder-name ordering matches the class
    # indices for datasets with 10 or more classes.
    class_dir = os.path.join(save_dir, str(class_idx))
    file_prefix = f"{batch_idx}_{class_idx}_{str(cam_lambda)}"
    save_original_train_path = os.path.join(class_dir, f"{file_prefix}_original.png")
    save_ood_train_path = os.path.join(class_dir, f"{file_prefix}_train.png")
    save_ood_mask_path = os.path.join(class_dir, f"{file_prefix}_mask.png")

    # Reverse the channel axis (RGB -> BGR) for OpenCV before writing.
    _write_png(save_original_train_path, np.ascontiguousarray(origin_x_data[..., ::-1]))
    _write_png(save_ood_mask_path, target_mask * 255)
    _write_png(save_ood_train_path, np.ascontiguousarray(inpaint[..., ::-1]))

100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 50000/50000 [25:30<00:00, 32.68it/s]


In [18]:
import os

# Sanity check: report how many regular files each class folder contains.
for i in range(num_classes):
    directory = f"{save_dir}/{i}"
    # Count directory entries that are files (sub-directories are skipped).
    cnt = sum(
        1
        for entry in os.listdir(directory)
        if os.path.isfile(os.path.join(directory, entry))
    )
    print(f"The number of file: {cnt} for the label {i}")

The number of file: 15000 for the label 0
The number of file: 15000 for the label 1
The number of file: 15000 for the label 2
The number of file: 15000 for the label 3
The number of file: 15000 for the label 4
The number of file: 15000 for the label 5
The number of file: 15000 for the label 6
The number of file: 15000 for the label 7
The number of file: 15000 for the label 8
The number of file: 15000 for the label 9


In [20]:
# Report the total disk usage of the generated image folder
# (-h: human-readable units, -s: a single summary line).
!du -hs ./cifar10_KIRBY/ood_training_images

593M	./cifar10_KIRBY/ood_training_images


#### <b>(Optional) Remove Unnecessary Files</b>

<pre>
import os

dir_name = "./cifar10_KIRBY/ood_training_images"

for i in range(num_classes):
    test = os.listdir(dir_name + "/" + str(i))
    for item in test:
        if item.endswith("_original.png") or item.endswith("_mask.png"):
            os.remove(os.path.join(dir_name + "/" + str(i) + "/" + item))
</pre>

In [None]:
# Report the folder's disk usage again, e.g. to see the effect of the optional
# cleanup snippet shown above (-h: human-readable units, -s: summary only).
!du -hs ./cifar10_KIRBY/ood_training_images

In [None]:
# Recursively archive the generated image folder into output.zip.
!zip -r output.zip ./cifar10_KIRBY/ood_training_images