# 使用 tensorflow & keras 实现基于图像语义分割的“魔法换天”

In [16]:
import os
import copy
import time
import queue
import sys

# numpy for matrix compute
import numpy as np

# using tensorflow 2.1.0+
import tensorflow as tf
from tensorflow.keras import initializers
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import *
from tensorflow.keras.losses import categorical_crossentropy
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.regularizers import l2
from tensorflow.keras.callbacks import ModelCheckpoint, TensorBoard

# pillow for image data process
from PIL import Image, ImageFile

# set LOAD_TRUNCATED_IMAGES to True to load 8KB below images
ImageFile.LOAD_TRUNCATED_IMAGES = True

In [2]:
# input images format
IMAGE_WIDTH = 224
IMAGE_HEIGHT = 224
IMAGE_SHAPE = (IMAGE_WIDTH, IMAGE_HEIGHT, 3)

# segmentation classes: sky and non-sky
N_CLASSES = 2

In [3]:
# ResNet101 Model
class Scale(Layer):
    def __init__(self, weights=None, axis=-1, momentum = 0.9, beta_init='zero', gamma_init='one', **kwargs):
        self.momentum = momentum
        self.axis = axis
        self.beta_init = initializers.get(beta_init)
        self.gamma_init = initializers.get(gamma_init)
        self.initial_weights = weights
        super(Scale, self).__init__(**kwargs)

    def build(self, input_shape):
        self.input_spec = [InputSpec(shape=input_shape)]
        shape = (int(input_shape[self.axis]),)

        self.gamma = tf.keras.backend.variable(self.gamma_init(shape), name='{}_gamma'.format(self.name))
        self.beta = tf.keras.backend.variable(self.beta_init(shape), name='{}_beta'.format(self.name))
        if self.initial_weights is not None:
            self.set_weights(self.initial_weights)
            del self.initial_weights

    def call(self, x, mask=None):
        input_shape = self.input_spec[0].shape
        broadcast_shape = [1] * len(input_shape)
        broadcast_shape[self.axis] = input_shape[self.axis]

        out = tf.keras.backend.reshape(self.gamma, broadcast_shape) * x + tf.keras.backend.reshape(self.beta, broadcast_shape)
        return out

    def get_config(self):
        config = {"momentum": self.momentum, "axis": self.axis}
        base_config = super(Scale, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

def identity_block(input_tensor, kernel_size, filters, stage, block):
    eps = 1.1e-5
    nb_filter1, nb_filter2, nb_filter3 = filters
    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'
    scale_name_base = 'scale' + str(stage) + block + '_branch'

    x = Conv2D(nb_filter1, (1, 1), name=conv_name_base + '2a', use_bias=False)(input_tensor)
    x = BatchNormalization(epsilon=eps, axis=bn_axis, name=bn_name_base + '2a')(x)
    x = Scale(axis=bn_axis, name=scale_name_base + '2a')(x)
    x = Activation('relu', name=conv_name_base + '2a_relu')(x)

    x = ZeroPadding2D((1, 1), name=conv_name_base + '2b_zeropadding')(x)
    x = Conv2D(nb_filter2, kernel_size,
                      name=conv_name_base + '2b', use_bias=False)(x)
    x = BatchNormalization(epsilon=eps, axis=bn_axis, name=bn_name_base + '2b')(x)
    x = Scale(axis=bn_axis, name=scale_name_base + '2b')(x)
    x = Activation('relu', name=conv_name_base + '2b_relu')(x)

    x = Conv2D(nb_filter3, (1, 1), name=conv_name_base + '2c', use_bias=False)(x)
    x = BatchNormalization(epsilon=eps, axis=bn_axis, name=bn_name_base + '2c')(x)
    x = Scale(axis=bn_axis, name=scale_name_base + '2c')(x)

    x = Add(name='res' + str(stage) + block)([x, input_tensor])
    x = Activation('relu', name='res' + str(stage) + block + '_relu')(x)
    return x

def conv_block(input_tensor, kernel_size, filters, stage, block, strides=(2, 2)):
    eps = 1.1e-5
    nb_filter1, nb_filter2, nb_filter3 = filters
    conv_name_base = 'res' + str(stage) + block + '_branch'
    bn_name_base = 'bn' + str(stage) + block + '_branch'
    scale_name_base = 'scale' + str(stage) + block + '_branch'

    x = Conv2D(nb_filter1, (1, 1), strides=strides,
                      name=conv_name_base + '2a', use_bias=False)(input_tensor)
    x = BatchNormalization(epsilon=eps, axis=bn_axis, name=bn_name_base + '2a')(x)
    x = Scale(axis=bn_axis, name=scale_name_base + '2a')(x)
    x = Activation('relu', name=conv_name_base + '2a_relu')(x)

    x = ZeroPadding2D((1, 1), name=conv_name_base + '2b_zeropadding')(x)
    x = Conv2D(nb_filter2, kernel_size,
                      name=conv_name_base + '2b', use_bias=False)(x)
    x = BatchNormalization(epsilon=eps, axis=bn_axis, name=bn_name_base + '2b')(x)
    x = Scale(axis=bn_axis, name=scale_name_base + '2b')(x)
    x = Activation('relu', name=conv_name_base + '2b_relu')(x)

    x = Conv2D(nb_filter3, (1, 1), name=conv_name_base + '2c', use_bias=False)(x)
    x = BatchNormalization(epsilon=eps, axis=bn_axis, name=bn_name_base + '2c')(x)
    x = Scale(axis=bn_axis, name=scale_name_base + '2c')(x)

    shortcut = Conv2D(nb_filter3, (1, 1), strides=strides,
                             name=conv_name_base + '1', use_bias=False)(input_tensor)
    shortcut = BatchNormalization(epsilon=eps, axis=bn_axis, name=bn_name_base + '1')(shortcut)
    shortcut = Scale(axis=bn_axis, name=scale_name_base + '1')(shortcut)

    x = Add(name='res' + str(stage) + block)([x, shortcut])
    x = Activation('relu', name='res' + str(stage) + block + '_relu')(x)
    return x

def resnet101_model(input_shape):
    eps = 1.1e-5

    global bn_axis
    bn_axis = 3
    img_input = Input(shape=input_shape, name='data')

    x = ZeroPadding2D((3, 3), name='conv1_zeropadding')(img_input)
    x = Conv2D(64, (7, 7), strides=(2, 2), name='conv1', use_bias=False)(x)
    x = BatchNormalization(epsilon=eps, axis=bn_axis, name='bn_conv1')(x)
    x = Scale(axis=bn_axis, name='scale_conv1')(x)
    x = Activation('relu', name='conv1_relu')(x)
    x = MaxPooling2D((3, 3), strides=(2, 2), name='pool1', padding = 'same')(x)
    
    x = conv_block(x, (3,3), [64, 64, 256], stage=2, block='a', strides=(1,1)) #conv2_1
    x = identity_block(x, (3,3), [64, 64, 256], stage=2, block='b') #conv2_2
    block_1_out = identity_block(x, (3,3), [64, 64, 256], stage=2, block='c') #conv2_3

    x = conv_block(block_1_out, (3,3), [128, 128, 512], stage=3, block='a') #conv3_1
    for i in range(1,3):
      x = identity_block(x, (3,3), [128, 128, 512], stage=3, block='b'+str(i)) #conv3_2-3
    block_2_out = identity_block(x, (3,3), [128, 128, 512], stage=3, block='b3') #conv3_4

    x = conv_block(block_2_out, (3,3), [256, 256, 1024], stage=4, block='a') #conv4_1
    for i in range(1,22):
      x = identity_block(x, (3,3), [256, 256, 1024], stage=4, block='b'+str(i)) #conv4_2-22
    block_3_out = identity_block(x, (3,3), [256, 256, 1024], stage=4, block='b22') #conv4_23

    x = conv_block(block_3_out, (3,3), [512, 512, 2048], stage=5, block='a') #conv5_1
    x = identity_block(x, (3,3), [512, 512, 2048], stage=5, block='b') #conv5_2
    block_4_out = identity_block(x, (3,3), [512, 512, 2048], stage=5, block='c') #conv5_3
    
    model = Model(inputs = [img_input], outputs = [block_4_out, block_3_out, block_2_out, block_1_out])
    model.load_weights("./resnet101_weights_tf.h5", by_name=True)

    return model

In [4]:
# refinenet on keras for semantic-segmentation
def ResidualConvUnit(input_layer, filters=256, kernel_size=(3, 3)):
    net = ReLU()(input_layer)
    net = Conv2D(filters, kernel_size, padding="same", kernel_initializer="he_normal", kernel_regularizer=l2(1e-5))(net)
    net = ReLU()(net)
    net = Conv2D(filters, kernel_size, padding="same", kernel_initializer="he_normal", kernel_regularizer=l2(1e-5))(net)
    net = Add()([net, input_layer])
    return net

kern_init = "he_normal"
kern_reg = l2(1e-5)


def ChainedResidualPooling(inputs,n_filters=256,name=''):
    net = ReLU()(inputs)
    net_out_1 = net
    
    net = Conv2D(n_filters, 3, padding='same', kernel_initializer=kern_init, kernel_regularizer=kern_reg)(net)
    net = BatchNormalization()(net)
    net = MaxPool2D(pool_size = (5,5), strides = 1, padding = 'same', data_format='channels_last')(net)
    net_out_2 = net
    
    net = Conv2D(n_filters, 3, padding='same', kernel_initializer=kern_init, kernel_regularizer=kern_reg)(net)
    net = BatchNormalization()(net)
    net = MaxPool2D(pool_size = (5,5), strides = 1, padding = 'same', data_format='channels_last')(net)
    net_out_3 = net
    
    net = Conv2D(n_filters, 3, padding='same', kernel_initializer=kern_init, kernel_regularizer=kern_reg)(net)
    net = BatchNormalization()(net)
    net = MaxPool2D(pool_size = (5,5), strides = 1, padding = 'same', data_format='channels_last')(net)
    net_out_4 = net
    
    net = Conv2D(n_filters, 3, padding='same', name=name+'conv4', kernel_initializer=kern_init, kernel_regularizer=kern_reg)(net)
    net = BatchNormalization()(net)
    net = MaxPool2D(pool_size = (5,5), strides = 1, padding = 'same', data_format='channels_last')(net)
    net_out_5 = net
    
    net = Add()([net_out_1, net_out_2, net_out_3, net_out_4, net_out_5])

    return net


def MultiResolutionFusion(high_inputs=None,low_inputs=None,n_filters=256,name=''):
    if low_inputs is None: # RefineNet block 4
        return high_inputs

    else:
        conv_low = Conv2D(n_filters, 3, padding='same', name=name+'conv_lo', kernel_initializer=kern_init, kernel_regularizer=kern_reg)(low_inputs)
        conv_low = BatchNormalization()(conv_low)
        conv_high = Conv2D(n_filters, 3, padding='same', name=name+'conv_hi', kernel_initializer=kern_init, kernel_regularizer=kern_reg)(high_inputs)
        conv_high = BatchNormalization()(conv_high)
        
        conv_low_up = UpSampling2D(size=2, interpolation='bilinear', name=name+'up')(conv_low)
        
        return Add()([conv_low_up, conv_high])


def RefineBlock(high_inputs=None, low_inputs=None, block=0):
    if low_inputs is None:
        rcu_high = ResidualConvUnit(high_inputs, filters=512)
        rcu_high = ResidualConvUnit(rcu_high, filters=512)
        fuse = MultiResolutionFusion(high_inputs = rcu_high,
                                     low_inputs = None,
                                     n_filters = 512,
                                     name = 'rb_{}_mrf_'.format(block))
        
        fuse_pooling = ChainedResidualPooling(fuse, n_filters = 512, name='rb_{}_crp_'.format(block))
        
        output = ResidualConvUnit(fuse, filters = 512)
        return output
    else:
        high_n = tf.keras.backend.int_shape(high_inputs)[-1]
        low_n = tf.keras.backend.int_shape(low_inputs)[-1]
        
        rcu_high = ResidualConvUnit(high_inputs, filters = high_n)
        rcu_high = ResidualConvUnit(rcu_high,filters = high_n)
        
        rcu_low = ResidualConvUnit(low_inputs, filters = low_n)
        rcu_low = ResidualConvUnit(rcu_low, filters = low_n)

        fuse = MultiResolutionFusion(high_inputs = rcu_high,
                                     low_inputs = rcu_low,
                                     n_filters = 256,
                                     name = 'rb_{}_mrf_'.format(block))
        fuse_pooling = ChainedResidualPooling(fuse, n_filters = 256, name='rb_{}_crp_'.format(block))
        output = ResidualConvUnit(fuse_pooling, filters = 256)
        return output

model_base = resnet101_model(IMAGE_SHAPE)

high_layers = model_base.output
low_layers = [None, None, None]

high_layers[0] = Conv2D(512, 1, padding='same', kernel_initializer="he_normal", kernel_regularizer=l2(1e-5))(high_layers[0])
high_layers[1] = Conv2D(256, 1, padding='same', kernel_initializer="he_normal", kernel_regularizer=l2(1e-5))(high_layers[1])
high_layers[2] = Conv2D(256, 1, padding='same', kernel_initializer="he_normal", kernel_regularizer=l2(1e-5))(high_layers[2])
high_layers[3] = Conv2D(256, 1, padding='same', kernel_initializer="he_normal", kernel_regularizer=l2(1e-5))(high_layers[3])

for h in high_layers:
    h = BatchNormalization()(h)

low_layers[0] = RefineBlock(high_inputs = high_layers[0], low_inputs = None, block = 4)
low_layers[1] = RefineBlock(high_inputs = high_layers[1], low_inputs = low_layers[0], block = 3)
low_layers[2] = RefineBlock(high_inputs = high_layers[2], low_inputs = low_layers[1], block = 2)
net = RefineBlock(high_inputs = high_layers[3], low_inputs = low_layers[2], block = 1)
net = ResidualConvUnit(net)
net = ResidualConvUnit(net)

net = UpSampling2D(size=4, interpolation='bilinear', name='rf_up_o')(net)
net = Conv2D(N_CLASSES, 1, activation = 'softmax', name='rf_pred')(net)
model = Model(model_base.input, net)

In [5]:
model.compile(optimizer=Adam(1e-4),
              loss="categorical_crossentropy",
              metrics=['accuracy'])

model.summary()

Model: "model_1"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
data (InputLayer)               [(None, 224, 224, 3) 0                                            
__________________________________________________________________________________________________
conv1_zeropadding (ZeroPadding2 (None, 230, 230, 3)  0           data[0][0]                       
__________________________________________________________________________________________________
conv1 (Conv2D)                  (None, 112, 112, 64) 9408        conv1_zeropadding[0][0]          
__________________________________________________________________________________________________
bn_conv1 (BatchNormalization)   (None, 112, 112, 64) 256         conv1[0][0]                      
____________________________________________________________________________________________

scale4b15_branch2b (Scale)      (None, 14, 14, 256)  512         bn4b15_branch2b[0][0]            
__________________________________________________________________________________________________
res4b15_branch2b_relu (Activati (None, 14, 14, 256)  0           scale4b15_branch2b[0][0]         
__________________________________________________________________________________________________
res4b15_branch2c (Conv2D)       (None, 14, 14, 1024) 262144      res4b15_branch2b_relu[0][0]      
__________________________________________________________________________________________________
bn4b15_branch2c (BatchNormaliza (None, 14, 14, 1024) 4096        res4b15_branch2c[0][0]           
__________________________________________________________________________________________________
scale4b15_branch2c (Scale)      (None, 14, 14, 1024) 2048        bn4b15_branch2c[0][0]            
__________________________________________________________________________________________________
res4b15 (A

add_7 (Add)                     (None, 7, 7, 512)    0           conv2d_20[0][0]                  
                                                                 add_6[0][0]                      
__________________________________________________________________________________________________
conv2d_16 (Conv2D)              (None, 14, 14, 256)  590080      re_lu_10[0][0]                   
__________________________________________________________________________________________________
rb_3_mrf_conv_lo (Conv2D)       (None, 7, 7, 256)    1179904     add_7[0][0]                      
__________________________________________________________________________________________________
add_5 (Add)                     (None, 14, 14, 256)  0           conv2d_16[0][0]                  
                                                                 add_4[0][0]                      
__________________________________________________________________________________________________
batch_norm

In [6]:
# prepare training data
with open("./train.txt", "r") as file:
    lines = file.readlines()

# shuffle the input data
np.random.seed(998244353)
np.random.shuffle(lines)
np.random.seed(None)

# split test dataset size and validation dataset size
validation_size = int(len(lines) * 0.25)
train_size = len(lines) - validation_size

# make a custom data generator for keras
def data_generator(lines, batch_size):
    n = len(lines)
    i = 0
    while 1:
        X_train = []
        Y_train = []
        
        for idx in range(batch_size):
            name = lines[i].split(';')[0]
            img = Image.open(os.path.join("./jpg", name))
            img = img.resize((IMAGE_WIDTH, IMAGE_HEIGHT))
            img = np.array(img)
            img = img / 255
            
            X_train.append(img)

            name = (lines[i].split(';')[1]).replace("\n", "")
            
            img = Image.open(os.path.join("./png", name))
            img = img.resize((int(IMAGE_WIDTH), int(IMAGE_HEIGHT)))
            img = np.array(img)
            seg_labels = np.zeros((int(IMAGE_HEIGHT),int(IMAGE_WIDTH),N_CLASSES))
            for c in range(N_CLASSES):
                seg_labels[: , : , c ] = (img[:,:,0] == c ).astype(int)
            Y_train.append(seg_labels)

            i = (i + 1) % n
            
        yield (np.array(X_train),np.array(Y_train))

In [39]:
# train the model!
BATCH_SIZE = 10
EPOCH = 5
TENSORBOARD_LOG_NAME = 'semantic-segmentation-{}'.format(int(time.time()))
tensorboard = TensorBoard(log_dir='./tf_dir/{}'.format(TENSORBOARD_LOG_NAME))

model.fit(data_generator(lines[:train_size], BATCH_SIZE),
          steps_per_epoch=max(1, train_size // BATCH_SIZE),
          validation_data=data_generator(lines[train_size:], BATCH_SIZE),
          validation_steps=max(1, validation_size // BATCH_SIZE),
          epochs=EPOCH,
          callbacks=[tensorboard])

  ...
    to  
  ['...']
  ...
    to  
  ['...']
Train for 1034 steps, validate for 344 steps
Epoch 1/5
Epoch 2/5
Epoch 3/5
Epoch 4/5
Epoch 5/5


<tensorflow.python.keras.callbacks.History at 0x7fa4f94e9e80>

In [40]:
# save the model as tensorflow format (.pb) & keras format (.h5)
tf.saved_model.save(model, "retrained_model")
model.save_weights("retrain.h5")

INFO:tensorflow:Assets written to: retrained_model/assets


In [6]:
model.load_weights("retrain.h5")    # load pretrained model weight

In [7]:
def model_predict_as_array(image_path):
    img = Image.open(image_path)
    array = np.array(img.resize((IMAGE_WIDTH, IMAGE_HEIGHT)))
    array = array / 255
    array = array.reshape(-1, IMAGE_HEIGHT, IMAGE_WIDTH, 3)
    
    pr = model.predict(array)[0].reshape((IMAGE_HEIGHT, IMAGE_WIDTH, N_CLASSES)).argmax(axis=-1)
    
    segment_img = np.zeros((IMAGE_HEIGHT, IMAGE_WIDTH, 3))
    for c in range(N_CLASSES):
        segment_img[:,:,0] = ((pr[:, : ] == c) * 255).astype('uint8')
        segment_img[:,:,1] = ((pr[:, : ] == c) * 255).astype('uint8')
        segment_img[:,:,2] = ((pr[:, : ] == c) * 255).astype('uint8')       
    return segment_img

In [10]:
# evaluate
evaluate_path = "./evaluate"
predict_save_path = "./predict"
evaluate_images = os.listdir(evaluate_path)

for filename in evaluate_images:
    if not filename.endswith(".jpg"):
        continue
    img = Image.open(os.path.join(evaluate_path, filename))
    original = copy.deepcopy(img)
    original_arr = np.array(original)
    original_height = original_arr.shape[0]
    original_width = original_arr.shape[1]

    img = np.array(img.resize((IMAGE_WIDTH, IMAGE_HEIGHT)))
    img = img / 255
    img = img.reshape(-1, IMAGE_HEIGHT, IMAGE_WIDTH, 3)

    prediction = model.predict(img)[0]
    pr = prediction.reshape((int(IMAGE_HEIGHT), int(IMAGE_WIDTH), N_CLASSES)).argmax(axis=-1)

    colors = [[0,0,0],[0,255,0]]

    seg_img = np.zeros((int(IMAGE_HEIGHT), int(IMAGE_WIDTH), 3))
    for c in range(N_CLASSES):
        seg_img[:,:,0] += ((pr[:, : ] == c) * (colors[c][0])).astype('uint8')
        seg_img[:,:,1] += ((pr[:,: ] == c )*( colors[c][1] )).astype('uint8')
        seg_img[:,:,2] += ((pr[:,: ] == c )*( colors[c][2] )).astype('uint8')
    seg_img = Image.fromarray(np.uint8(seg_img)).resize((original_width, original_height))
    image = Image.blend(original, seg_img, 0.3)
    image.save(os.path.join(predict_save_path, "%s.predict.png" % filename))

In [41]:
# convert to tensorflow-lite model
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.post_training_quantize = True
converter.optimizations = [tf.lite.Optimize.DEFAULT]

tflite_model = converter.convert()

open('model_tflite.tflite', 'wb').write(tflite_model)

96776920

In [59]:
process_file = "./evaluate/1.jpg"

origin = Image.open(process_file)
origin_pixels = np.asarray(origin)
(original_width, original_height) = (origin_pixels.shape[1], origin_pixels.shape[0])
print(original_width, original_height)

alpha_array = model_predict_as_array(process_file)
alpha_img = Image.fromarray(np.uint8(alpha_array)).resize((original_width, original_height))
alpha_pixels = np.asarray(alpha_img)

(r_average, g_average, b_average) = (0, 0, 0)
predict_pixel_count = 0

for y in range(original_height):
    for x in range(original_width):
        if alpha_pixels[y][x][0] == 255:
            predict_pixel_count += 1
            r_average += origin_pixels[y][x][0]
            g_average += origin_pixels[y][x][1]
            b_average += origin_pixels[y][x][2]

r_average //= predict_pixel_count
g_average //= predict_pixel_count
b_average //= predict_pixel_count

print(r_average, g_average, b_average)

p_img = np.zeros((original_height, original_width, 4))
visited = np.zeros((original_height, original_width))

direction = [[-1, 0], [1, 0], [0, -1], [0, 1]]


for y in range(original_height):
    for x in range(original_width):
        p_img[y][x] = [origin_pixels[y][x][0], origin_pixels[y][x][1], origin_pixels[y][x][2], 255]

def mark_useful_pixels(y, x):
    q = queue.Queue(maxsize = original_width * original_height)
    q.put((y, x))
    
    visited[y][x] = 1

    while q.qsize() != 0:
        (cy, cx) = q.get()
        r = origin_pixels[cy][cx][0]
        g = origin_pixels[cy][cx][1]
        b = origin_pixels[cy][cx][2]
    
        if (alpha_pixels[cy][cx][0] >= 30 and (math.sqrt(((r - r_average) ** 2 + (g - g_average) ** 2 + (b - b_average) ** 2) / (255 ** 2 * 3)) <= 0.25) or (r < 150 and g < 150 and b >= 150)) \
        or (alpha_pixels[cy][cx][0] == 0 and math.sqrt(((r - r_average) ** 2 + (g - g_average) ** 2 + (b - b_average) ** 2) / (255 ** 2 * 3)) <= 0.15):
            p_img[cy][cx] = [255, 255, 255, 0]        
            for i in range(0, 4):
                ny = cy + direction[i][0]
                nx = cx + direction[i][1]
                if ny >= 0 and nx >= 0 and ny < original_height and nx < original_width:
                    if not visited[ny][nx]:
                        visited[ny][nx] = 1
                        q.put((ny, nx))
                    
for y in range(original_height):
    for x in range(original_width):
#         if alpha_pixels[y][x][0] != 255 and alpha_pixels[y][x][0] != 0:
#             print(alpha_pixels[y][x])
        if not visited[y][x] and alpha_pixels[y][x][0] >= 30:
            mark_useful_pixels(y, x)

p_img_object = Image.fromarray(np.uint8(p_img))
p_img_object.save("./predict/test.png")

1440 1080
230 218 238


In [60]:
# composite
background = Image.open("./rosy.jpg").resize((original_width, original_height)).convert('RGBA')
background_pixels = np.asarray(background)
composite_pixels = np.zeros((original_height, original_width, 4))
for y in range(original_height):
    for x in range(original_width):
        if p_img[y][x][3] == 255:
            composite_pixels[y][x] = p_img[y][x]
        else:
            composite_pixels[y][x] = background_pixels[y][x]

composite = Image.fromarray(np.uint8(composite_pixels))
composite.save("./predict/composite.png")