In [1]:
from keras.models import Model
from keras.layers import Input, Concatenate, Conv2D, LeakyReLU, BatchNormalization, Flatten, Dense, UpSampling2D, Reshape
from keras.preprocessing import image
from keras.optimizers import Adam
from keras import activations
import numpy as np
from PIL import Image



In [2]:
target_size = (900, 400)
head_x, head_y = 200, 150  # Replace with the actual coordinates

# Normalize the coordinates
normalized_head_x = head_x / target_size[1]
normalized_head_y = head_y / target_size[0]

# Use the normalized coordinates as input to the model
input_head_position = np.array([[normalized_head_x, normalized_head_y]])

In [3]:
# Function to load and preprocess an image
def load_and_preprocess_image(image_path):
    img = Image.open(image_path)
    img = img.convert('RGB')
    img = img.resize(target_size[::-1], Image.Resampling.LANCZOS)
    # img = image.load_img(image_path, target_size=target_size)
    img_array = image.img_to_array(img)
    img_array = np.expand_dims(img_array, axis=0)
    img_array /= 255.0  # Normalize pixel values to between 0 and 1
    return img_array

In [4]:
# Function to create a segmentation mask (1 for foreground, 0 for background)
def create_segmentation_mask(img_array):
    # Assuming a simple segmentation mask based on pixel intensity.
    mask = np.mean(img_array, axis=-1, keepdims=True) > 0.5
    # mask = image.array_to_img(mask[0, :, :, 0])
    # mask = mask.resize(target_size, Image.Resampling.LANCZOS)
    # mask = image.img_to_array(mask)
    # mask /= 255.0
    return mask

In [5]:
# Load two input images
input_person_image = load_and_preprocess_image('person_image.png')  # Replace with the actual path
input_dress_image = load_and_preprocess_image('dress_image.png')  # Replace with the actual path

In [6]:
input_head_position.shape

(1, 2)

In [7]:
input_dress_image.shape

(1, 900, 400, 3)

In [8]:
# Create segmentation masks for the person and dress images
person_mask = create_segmentation_mask(input_person_image)
dress_mask = create_segmentation_mask(input_dress_image)
# Resize the masks to match the input image dimensions
person_mask = np.array(Image.fromarray(person_mask[0, :, :, 0]).resize((target_size[1], target_size[0])))
dress_mask = np.array(Image.fromarray(dress_mask[0, :, :, 0]).resize((target_size[1], target_size[0])))

# Expand dimensions to match the input shape
person_mask = np.expand_dims(person_mask, axis=-1)
dress_mask = np.expand_dims(dress_mask, axis=-1)

In [9]:
person_mask.shape

(900, 400, 1)

In [10]:
# Define input layers
# input_person = Input(shape=(target_size[0]*4, target_size[1]*4, 3), name='input_person')
# input_dress = Input(shape=(target_size[0]*4, target_size[1]*4, 3), name='input_dress')
# input_head_position = Input(shape=(2,), name='input_head_position')  # Example: (x, y) coordinates

In [11]:
# Concatenate the person and dress images along the channels axis
# concatenated_input = Concatenate(axis=-1)([input_person, input_dress])

In [12]:
 # Add a Dense layer to incorporate head position information
# head_position_embedding = Dense(64, activation='relu')(input_head_position)
# head_position_embedding = Dense(64, activation='relu')(head_position_embedding)
# head_position_embedding = Dense(128, activation='relu')(head_position_embedding)
# expanded_head_position = Dense(target_size[0] * target_size[1], activation='relu')(input_head_position)
# expanded_head_position = Reshape((target_size[0], target_size[1], 1))(expanded_head_position)

In [13]:
# Concatenate the expanded head position with the concatenated input
# concatenated_input_with_position = Concatenate(axis=-1)([concatenated_input, expanded_head_position])

In [14]:
# Generator network
def build_generator():
    input_person = Input(shape=(target_size[0], target_size[1], 3), name='input_person')
    input_dress = Input(shape=(target_size[0], target_size[1], 3), name='input_dress')
    # input_head_position = Input(shape=(1,2), name='input_head_position')  # Example: (x, y) coordinates
    # Concatenate the person and dress images along the channels axis
    concatenated_input = Concatenate(axis=-1)([input_person, input_dress])
    print(concatenated_input.shape)
    # Add a Dense layer to incorporate head position information
    # head_position_embedding = Dense(64, activation='relu')(input_head_position)
    # head_position_embedding = Dense(64, activation='relu')(head_position_embedding)
    # head_position_embedding = Dense(128, activation='relu')(head_position_embedding)
    # expanded_head_position = Dense(target_size[0] * target_size[1], activation='relu')(input_head_position)
    # expanded_head_position = Reshape((target_size[0], target_size[1], 1))(expanded_head_position)
    # # Concatenate the expanded head position with the concatenated input
    # concatenated_input_with_position = Concatenate(axis=-1)([concatenated_input, expanded_head_position])
    gen = Conv2D(64, (3, 3), activation=activations.gelu, padding='same')(concatenated_input)
    # gen = UpSampling2D((2, 2))(gen)
    gen = Conv2D(32, (3, 3), activation=activations.gelu, padding='same')(gen)
    gen = Conv2D(64, (3, 3), activation=activations.gelu, padding='same')(gen)
    gen = Conv2D(64, (3, 3), activation=activations.gelu, padding='same')(gen)
    gen = Conv2D(64, (3, 3), activation=activations.gelu, padding='same')(gen)
    gen = Conv2D(32, (3, 3), activation=activations.gelu, padding='same')(gen)
    # gen = UpSampling2D((2, 2))(gen)
    output_img = Conv2D(3, (3, 3), activation=activations.sigmoid, padding='same')(gen)
    # Apply segmentation masks to preserve the background
    print(output_img.shape)
    print(person_mask.shape)
    print(input_person.shape)
    output_img = output_img * (1 - person_mask) + input_person * person_mask
    output_img = output_img * (1 - dress_mask) + input_dress * dress_mask
    # Create the generator model
    generator = Model(inputs=[input_person, input_dress], outputs=output_img)
    return generator

In [15]:
# Discriminator
def build_discriminator(input_shape):
    input_img = Input(shape=input_shape)
    x = Conv2D(64, (3, 3), strides=(2, 2), padding='same')(input_img)
    x = LeakyReLU(alpha=0.2)(x)
    x = BatchNormalization()(x)

    # Add more layers as needed

    x = Flatten()(x)
    validity = Dense(1, activation='sigmoid')(x)

    discriminator = Model(inputs=input_img, outputs=validity)
    return discriminator

In [16]:
# Combined model (GAN)
def build_gan(generator, discriminator):
    # Disable training of the discriminator during GAN training
    discriminator.trainable = False

    # Build GAN by chaining the generator and discriminator
    input_person = Input(shape=(target_size[0], target_size[1], 3), name='input_person')
    input_dress = Input(shape=(target_size[0], target_size[1], 3), name='input_dress')
    # input_head_position = Input(shape=(1,2), name='input_head_position')  # Example: (x, y) coordinates

    # Generate an image using the generator
    generated_img = generator([input_person, input_dress])

    # Discriminator's decision on the generated image
    validity = discriminator(generated_img)

    # Combined GAN model
    gan = Model(inputs=[input_person, input_dress], outputs=validity)
    return gan

In [17]:
# Build the generator, discriminator, and GAN
generator = build_generator()
discriminator = build_discriminator(input_shape=(target_size[0], target_size[1], 3))
# Compile the discriminator model
discriminator.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])

# Compile the generator model
generator.compile(optimizer='adam', loss='binary_crossentropy', metrics=['accuracy'])
gan = build_gan(generator, discriminator)
gan.compile(optimizer=Adam(learning_rate=0.0002, beta_1=0.5), loss='binary_crossentropy')



(None, 900, 400, 6)
(None, 900, 400, 3)
(900, 400, 1)
(None, 900, 400, 3)


In [None]:
person_images = np.full((1000, 900, 400, 3), input_person_image)
dress_images = np.full((1000, 900, 400, 3), input_dress_image)
# Set the number of training iterations
num_iterations = 10

# Set the batch size
batch_size = 320

# Set the update ratio (how many times to update the discriminator per generator update)
update_ratio = 1

# Training loop
for iteration in range(num_iterations):
    # ---------------------
    # Train Discriminator
    # ---------------------
    for _ in range(update_ratio):
        # Select a random batch of real person and dress images
        idx = np.random.randint(0, person_images.shape[0], batch_size)
        real_person_batch = person_images[idx]
        real_dress_batch = dress_images[idx]
        print(real_person_batch.shape)
        print(real_dress_batch.shape)
        print(input_head_position.shape)

        # Generate a batch of fake images using the current generator
        generated_images = generator.predict([real_person_batch, real_dress_batch])

        # Create labels for real and fake samples
        real_labels = np.ones((batch_size, 1))
        fake_labels = np.zeros((batch_size, 1))

        # Train the discriminator on real samples
        d_loss_real = discriminator.train_on_batch([generated_images], real_labels)

        # Train the discriminator on fake samples
        d_loss_fake = discriminator.train_on_batch([generated_images], fake_labels)

        # Calculate total discriminator loss
        d_loss = 0.5 * np.add(d_loss_real, d_loss_fake)

    # ---------------------
    # Train Generator
    # ---------------------
    # Select a new batch of real person and dress images
    idx = np.random.randint(0, person_images.shape[0], batch_size)
    real_person_batch = person_images[idx]
    real_dress_batch = dress_images[idx]

    # Create labels for the generator (tricking the discriminator)
    valid_labels = np.ones((batch_size, 1))

    # Train the generator to minimize the discriminator's loss
    g_loss = gan.train_on_batch([real_person_batch, real_dress_batch], valid_labels)

    # Print progress and losses (optional)
    # if iteration % 100 == 0:
    print(f"Iteration {iteration}, D Loss: {d_loss[0]}, G Loss: {g_loss}")

(320, 900, 400, 3)
(320, 900, 400, 3)
(1, 2)


In [None]:
# Display the model summary
generator.summary()

In [18]:
# Generate the merged image
merged_image_array = generator.predict([input_person_image, input_dress_image])



In [19]:
# Save the merged image as a PNG file
merged_image_path = 'merged_image.png'
image.save_img(merged_image_path, merged_image_array[0])