In [1]:
import PIL
from PIL import Image
from PIL import ImageDraw
from PIL import ImageFilter

import matplotlib.pyplot as plt
import numpy as np
import os
import cv2
import random
from pascal_voc_writer import Writer
from skimage.util import random_noise


def RandomPaste(origin, img, bndbox, overlapping):
    """Randomly scale, rotate and paste ``img`` onto ``origin`` (in place).

    Args:
        origin: background PIL image; modified in place by ``Image.paste``.
        img: object PIL image to paste. It is converted to RGBA first, so
            inputs without an alpha channel (RGB, L, P, ...) also work; its
            alpha channel drives both the tight crop and the paste mask.
        bndbox: list of ``(x, y, w, h)`` boxes already placed on ``origin``;
            the box of the pasted object is added to it in place.
        overlapping: if True, always paste and merge the new box with any
            boxes it intersects (via ``combine_boxes``); if False, skip the
            paste when the new box would intersect an existing one.

    Returns:
        True if the object was pasted, False if it was skipped.
    """
    width, height = origin.size
    img = img.convert('RGBA')

    # Resize to a random fraction (10%-70%) of the background size; max(1, ...)
    # avoids a zero-sized resize on very small backgrounds.
    proportion = random.uniform(0.1, 0.7)
    target_size = (max(1, int(proportion * width)), max(1, int(proportion * height)))
    img = img.resize(target_size, Image.LANCZOS)

    # Rotate by a random angle; expand=True keeps the whole object, and in RGBA
    # mode the newly exposed corners are transparent.
    img = img.rotate(random.randint(0, 360), expand=True)

    # Shrink (aspect ratio preserved) to at most half the background per axis.
    img.thumbnail((max(1, width // 2), max(1, height // 2)), Image.LANCZOS)

    # Crop away transparent margins. Using the alpha channel (instead of the
    # non-black RGB pixels) keeps dark object pixels that touch the border.
    visible = img.getchannel('A').getbbox()
    if visible is None:
        # Fully transparent object: nothing to paste or annotate.
        return False
    img = img.crop(visible)

    # Random top-left corner that keeps the object fully inside the background
    # (the thumbnail step guarantees the object is not larger than it).
    img_width, img_height = img.size
    img_x = random.randint(0, width - img_width)
    img_y = random.randint(0, height - img_height)
    newbox = (img_x, img_y, img_width, img_height)

    if not overlapping and is_overlapping(bndbox, newbox):
        return False

    origin.paste(img, (img_x, img_y), img.getchannel('A'))
    if overlapping:
        combine_boxes(bndbox, newbox)  # updates bndbox in place
    else:
        bndbox.append(newbox)
    return True

In [2]:
def union(a, b):
    """Return the smallest ``(x, y, w, h)`` box enclosing both boxes ``a`` and ``b``."""
    left = min(a[0], b[0])
    top = min(a[1], b[1])
    right = max(a[0] + a[2], b[0] + b[2])
    bottom = max(a[1] + a[3], b[1] + b[3])
    return (left, top, right - left, bottom - top)

In [3]:
def intersection(a, b):
    """Return the overlap of two ``(x, y, w, h)`` boxes, or ``()`` if they are disjoint.

    Boxes that merely touch along an edge yield a zero-width/height box,
    which is still a non-empty (truthy) tuple.
    """
    left = max(a[0], b[0])
    top = max(a[1], b[1])
    width = min(a[0] + a[2], b[0] + b[2]) - left
    height = min(a[1] + a[3], b[1] + b[3]) - top
    if width >= 0 and height >= 0:
        return (left, top, width, height)
    return ()

In [4]:
def combine_boxes(boxes, new):
    """Add box ``new`` to ``boxes``, merging it with every box it (transitively) intersects.

    Each box in ``boxes`` that intersects the growing merged box is removed
    and absorbed with ``union``; this repeats until no remaining box
    intersects it, and the merged box is then appended.

    Args:
        boxes: list of ``(x, y, w, h)`` boxes; modified in place.
        new: the ``(x, y, w, h)`` box to add.

    Returns:
        ``boxes`` (the same list object), for convenience.
    """
    merged = new
    while True:
        # Re-scan from the start after every merge: the grown box may now
        # intersect boxes that the smaller box did not.
        hit = next((idx for idx, box in enumerate(boxes) if intersection(box, merged)), None)
        if hit is None:
            break
        # pop by index (not list.remove, which drops the first *equal* box).
        merged = union(boxes.pop(hit), merged)
    boxes.append(merged)
    return boxes

In [5]:
def is_overlapping(boxes, new):
    """Return True if box ``new`` intersects or touches any box in ``boxes``."""
    return any(intersection(box, new) for box in boxes)

In [6]:
def add_noise(img):
    """Return a new PIL image with a randomly chosen kind of noise applied to ``img``.

    The mode is drawn uniformly from the list below. 'gaussian' is listed four
    times, and 'speckle' is also mapped to Gaussian noise. That means Gaussian
    noise (with a random standard deviation in [0, 0.1]) is used about half of
    the time. Any other mode is applied with ``random_noise``'s default
    parameters.
    """
    modes = ['gaussian', 'localvar', 'poisson', 'salt', 'pepper', 's&p',
             'speckle', 'gaussian', 'gaussian', 'gaussian']
    mode = random.choice(modes)
    img_arr = np.asarray(img)

    if mode in ('gaussian', 'speckle'):
        # NOTE(review): 'speckle' is routed to Gaussian noise here; confirm
        # this is intended rather than mode='speckle'.
        sigma = random.uniform(0, 0.1)
        noisy = random_noise(img_arr, mode='gaussian', var=sigma ** 2)
    else:
        noisy = random_noise(img_arr, mode=mode)

    # random_noise returns floats (clipped to [0, 1] for uint8 input). Round
    # rather than truncate, so values such as 127.9999 map to 128 instead of 127
    # and the image is not biased darker.
    noise_img = np.rint(noisy * 255).astype(np.uint8)
    return Image.fromarray(noise_img)

In [7]:
def convert_temp(img):
    """Return an RGB copy of ``img`` tinted towards a random colour temperature.

    A temperature from 1000 K to 10000 K (step 500 K) is chosen uniformly.
    Its reference white ``(r, g, b)`` from the table below scales the R, G
    and B channels by r/255, g/255 and b/255 through a 12-value colour matrix.
    Inputs that are not already RGB (e.g. 'L', 'P', 'RGBA') are converted to
    RGB first, so the matrix conversion always sees an RGB source.
    """
    kelvin_table = {
        1000: (255, 56, 0),
        1500: (255, 109, 0),
        2000: (255, 137, 18),
        2500: (255, 161, 72),
        3000: (255, 180, 107),
        3500: (255, 196, 137),
        4000: (255, 209, 163),
        4500: (255, 219, 186),
        5000: (255, 228, 206),
        5500: (255, 236, 224),
        6000: (255, 243, 239),
        6500: (255, 249, 253),
        7000: (245, 243, 255),
        7500: (235, 238, 255),
        8000: (227, 233, 255),
        8500: (220, 229, 255),
        9000: (214, 225, 255),
        9500: (208, 222, 255),
        10000: (204, 219, 255)}
    r, g, b = random.choice(list(kelvin_table.values()))
    matrix = (r / 255.0, 0.0, 0.0, 0.0,
              0.0, g / 255.0, 0.0, 0.0,
              0.0, 0.0, b / 255.0, 0.0)
    if img.mode != 'RGB':
        img = img.convert('RGB')
    return img.convert('RGB', matrix)

In [8]:
def generatePictures(num_pair, num_obj=None, noise_type=None, deviation=False, overlapping=False):
    '''
    Generate synthetic (background, augmented) image pairs with Pascal VOC annotations.

    For each pair k = 1..num_pair:
      * a random background from ./img/background is converted to RGB and
        saved as ./data/images/k.jpg;
      * random objects from ./img/object are pasted onto it, then noise and a
        random colour-temperature shift are applied, and the result is saved
        as ./data/images/k_A.jpg;
      * the pasted boxes are written to ./data/XML/k.xml (filename field "k.jpg").

    num_pair: The number of pairs generated.
    num_obj: The number of paste attempts per pair. If None, a new random
        value in [1, 5] is drawn for every pair. With overlapping=False a
        paste that would overlap an existing object is skipped, so fewer
        objects may end up in the picture.
    noise_type: Not implemented yet; add_noise always picks the noise type at random.
    deviation: Not implemented yet; ignored.
    overlapping: If True, objects may overlap and intersecting boxes are
        merged into one; if False, overlapping pastes are skipped.
    '''
    bg_path = "./img/background"
    obj_path = "./img/object"
    images_path = "./data/images"
    xml_path = "./data/XML"

    os.makedirs(images_path, exist_ok=True)
    os.makedirs(xml_path, exist_ok=True)

    # List the source folders once; skip hidden entries such as .DS_Store or
    # .ipynb_checkpoints, and sort so a seeded run is reproducible.
    bg_names = sorted(n for n in os.listdir(bg_path) if not n.startswith('.'))
    obj_names = sorted(n for n in os.listdir(obj_path) if not n.startswith('.'))

    for count in range(1, num_pair + 1):
        bndbox = []
        # JPEG cannot store alpha/palette modes, and the noise and colour
        # steps expect RGB pixel values, so normalise the background to RGB.
        with Image.open(os.path.join(bg_path, random.choice(bg_names))) as bg_file:
            background = bg_file.convert('RGB')
        background.save(os.path.join(images_path, str(count) + '.jpg'))

        # Use a local count: assigning to num_obj would freeze the first
        # random draw for every later pair.
        n_objects = num_obj if num_obj is not None else random.randint(1, 5)
        for _ in range(n_objects):
            with Image.open(os.path.join(obj_path, random.choice(obj_names))) as obj:
                RandomPaste(background, obj, bndbox, overlapping)

        background = add_noise(background)
        background = convert_temp(background)
        background.save(os.path.join(images_path, str(count) + '_A.jpg'))

        # Boxes are stored as (x, y, w, h); VOC wants (xmin, ymin, xmax, ymax).
        width, height = background.size
        writer = Writer(str(count) + ".jpg", width, height)
        for x, y, w, h in bndbox:
            writer.addObject('True', x, y, x + w, y + h)
        writer.save(os.path.join(xml_path, str(count) + ".xml"))

In [9]:
# Generate 5000 background/augmented image pairs plus their VOC XML annotations.
NUM_PAIRS = 5000
generatePictures(NUM_PAIRS)