In [1]:
!git clone https://github.com/anminhhung/small_dog_cat_dataset

Cloning into 'small_dog_cat_dataset'...
remote: Enumerating objects: 2608, done.[K
remote: Total 2608 (delta 0), reused 0 (delta 0), pack-reused 2608[K
Receiving objects: 100% (2608/2608), 55.84 MiB | 13.61 MiB/s, done.
Resolving deltas: 100% (1/1), done.


In [5]:
import os

import numpy as np
import torch
import torch.nn as nn
import torchvision.models as models
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from PIL import Image

from google.colab.patches import cv2_imshow

In [8]:
def top_k(a, k=6):
    """Return the index tuples of the ``k`` largest entries of array ``a``.

    Args:
        a: numpy array of any shape (here: a 2-D feature map).
        k: number of entries to select (default 6). Values larger than
            ``a.size`` are clamped to ``a.size``; ``k <= 0`` selects nothing.

    Returns:
        An integer array of shape ``(k, a.ndim)``. Each row is the
        multi-dimensional index of one selected entry. The order of the rows
        is unspecified (``argpartition`` does not sort the top-k).
    """
    k = min(k, a.size)
    if k <= 0:
        # Nothing to select (k <= 0 or empty input): keep the (k, ndim) shape contract.
        return np.empty((0, a.ndim), dtype=np.intp)
    # Partition so the k largest flat indices end up in the last k slots.
    idx = np.argpartition(a.ravel(), a.size - k)[-k:]
    return np.column_stack(np.unravel_index(idx, a.shape))

def get_attentive_regions(image):
    """Return grid coordinates of the most strongly activated feature-map cells.

    Runs ``image`` through the module-level feature extractor ``model`` and
    ranks the cells of one channel of its final spatial feature map with
    ``top_k``. Grid size depends on the input size: CIFAR-style inputs give
    an 8x8 grid, 224x224 ImageNet-style inputs a 7x7 grid.

    Args:
        image: PIL image, expected to be already resized to the network input
            size. Any mode is accepted; it is converted to RGB.

    Returns:
        Integer array of shape (k, 2) with (row, col) grid coordinates.
    """
    # torchvision's pretrained weights expect 3-channel input normalised with
    # the ImageNet channel mean/std; grayscale/RGBA images would otherwise
    # have the wrong number of channels.
    x = TF.to_tensor(image.convert('RGB'))
    x = TF.normalize(x, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    # Send the input to wherever the model lives instead of forcing CUDA.
    device = next(model.parameters()).device
    x = x.unsqueeze(0).to(device)
    with torch.no_grad():  # inference only: don't build an autograd graph
        output = model(x)
    # NOTE(review): output[0][-1] selects only the *last channel* of the
    # first sample's feature map, not an aggregate over channels (a channel
    # mean/sum is a common alternative) -- confirm this is intended.
    last_feature_map = output[0][-1].cpu().numpy()
    return top_k(last_feature_map)

def replace_attentive_regions(rand_img, image, attentive_regions):
    """Paste selected grid cells of ``image`` onto a copy of ``rand_img``.

    Args:
        rand_img: PIL image that receives the patches. It is converted to a
            numpy copy, so the passed-in image itself is not modified.
        image: PIL image the patches are cut from.
        attentive_regions: iterable of (row, col) grid coordinates.

    Returns:
        A new PIL image built from ``rand_img``'s pixels with every listed
        cell overwritten by the corresponding cell of ``image``.
    """
    canvas = np.array(rand_img)
    source = np.array(image)
    for region in attentive_regions:
        # Writes into `canvas` in place.
        replace_attentive_region(canvas, source, region)
    return Image.fromarray(canvas)

def replace_attentive_region(np_rand_img, np_img, attentive_region):
    """Copy one grid cell from ``np_img`` into ``np_rand_img`` in place.

    The cell at grid position (row, col) = ``attentive_region`` covers rows
    [row*grid_size, (row+1)*grid_size) and columns
    [col*grid_size, (col+1)*grid_size), where ``grid_size`` is the
    module-level cell size in pixels.
    """
    row, col = attentive_region
    row_span = slice(row * grid_size, (row + 1) * grid_size)
    col_span = slice(col * grid_size, (col + 1) * grid_size)
    np_rand_img[row_span, col_span] = np_img[row_span, col_span]

def select_random_image(i):
    """Open a random image from the module-level ``root``/``img_paths``, excluding index ``i``.

    The index is drawn uniformly from all positions of ``img_paths`` other
    than ``i`` (expects 0 <= i < len(img_paths)), using NumPy's global RNG.

    Returns:
        The opened PIL image.

    Raises:
        ValueError: if ``img_paths`` has fewer than two entries, since there
            is then no "other" image to choose.
    """
    n = len(img_paths)
    if n < 2:
        raise ValueError(
            'need at least two images to pick a different one, got {}'.format(n))
    # Sample among the n-1 other indices and skip over `i`. This replaces
    # rejection sampling, which never terminates when n == 1.
    rand_index = np.random.randint(0, n - 1)
    if rand_index >= i:
        rand_index += 1
    return Image.open(os.path.join(root, img_paths[rand_index]))

# Feature extractor: ResNet-50 with its last two children (global average
# pool and fc) removed, so it outputs the final convolutional feature map.
# NOTE: newer torchvision deprecates `pretrained=True` in favour of
# `weights=models.ResNet50_Weights.IMAGENET1K_V1`; kept for compatibility.
model = models.resnet50(pretrained=True)
temp_model = nn.Sequential(*list(model.children())[:-2])
# eval(): BatchNorm must use its pretrained running statistics (and must not
# update them) while the network is only used for inference.
model = temp_model.cuda().eval()
# Pixel size of one feature-map cell: 224 px input / 7x7 output grid = 32.
grid_size = 32

root = 'small_dog_cat_dataset/test/cats/'
# os.listdir order is arbitrary; sort so index i maps to the same file every run.
img_paths = sorted(os.listdir(root))
# Seed NumPy's global RNG (used by select_random_image) for reproducible pairings.
np.random.seed(0)
os.makedirs('mixed_images', exist_ok=True)

# For each image: find its attentive cells, paste them onto a different random
# image (at 224x224 network resolution), restore that image's original size,
# and save the result under mixed_images/ with the source file's name.
for i, rel_path in enumerate(img_paths):
  try:
    img_path = root + rel_path
    # Convert both images to RGB so the pasted patches have matching channel counts.
    img = Image.open(img_path).convert('RGB').resize((224, 224))
    rand_img = select_random_image(i).convert('RGB')
    ori_size = rand_img.size
    # PIL's resize() returns a new image; the result must be reassigned.
    rand_img = rand_img.resize((224, 224))
    attentive_regions = get_attentive_regions(img)
    rand_img = replace_attentive_regions(rand_img, img, attentive_regions)
    rand_img = rand_img.resize(ori_size)
    rand_img.save(os.path.join('mixed_images', rel_path))
  except Exception as e:
    # Best effort: skip unreadable or non-image files, but say so instead of
    # failing silently (and without swallowing KeyboardInterrupt).
    print('skipping {}: {}'.format(rel_path, e))

