In [1]:
import torch.nn as nn
import torch.nn.functional as F
import torch

import numpy as np

import matplotlib.pyplot as plt
import os
from torch.utils.data import DataLoader, Dataset, TensorDataset
from torchvision import transforms, datasets

# Data

In [2]:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [3]:
def get_backgrounds():
    backgrounds = []
    for file in os.listdir("./images/train"):
        if file.endswith('.jpg'):
            backgrounds.append(plt.imread(os.path.join("./images/train",file)))
    return np.array(backgrounds)
backgrounds = get_backgrounds()


def compose_image(image):
    image = (image > 0).astype(np.float32)
    image = image.reshape([28,28])*255.0
    
    image = np.stack([image,image,image],axis=2)
    
    background = np.random.choice(backgrounds)
    w,h,_ = background.shape
    dw, dh,_ = image.shape
    x = np.random.randint(0,w-dw)
    y = np.random.randint(0,h-dh)
    
    temp = background[x:x+dw, y:y+dh]
    return np.abs(temp-image).astype(np.uint8)


class MNISTM(Dataset):
            
    def __init__(self, train=True,transform=None):
        if train:
            self.data = datasets.MNIST(root='.data/mnist',train=True, download=True)
        else:
            self.data = datasets.MNIST(root='.data/mnist',train=False, download=True)
        self.backgrounds = get_backgrounds()
        self.transform = transform
    def __getitem__(self,index):
        image = np.array(self.data.__getitem__(index)[0])
        target = self.data.__getitem__(index)[1]
        image = compose_image(image)
        #image = Image.fromarray(image.squeeze(), mode="RGB")
        if self.transform is not None:
            image = self.transform(image)
        return image, target
        
    def __len__(self):
        return len(self.data)
    
    
def get_mnistm_loaders(data_aug = False, batch_size=128,test_batch_size=1000):
    if data_aug:
        train_transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(64),
            transforms.RandomCrop(64,padding=4),
            transforms.ToTensor()
        ])
    else:
        train_transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize(64),
            transforms.ToTensor()
        ])
    test_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize(64),
        transforms.ToTensor()
    ])
    kwargs = {}
    train_loader = DataLoader(
        MNISTM(train=True,transform=train_transform),batch_size=batch_size, shuffle=True, drop_last=True)
    train_eval_loader = DataLoader(
        MNISTM(train=True, transform=test_transform),batch_size=test_batch_size, shuffle=False, drop_last=True)
    test_loader = DataLoader(
        MNISTM(train=False,transform=test_transform),batch_size=test_batch_size, shuffle=False, drop_last=True)
    return train_loader, train_eval_loader, test_loader


def get_mnist_loaders(data_aug = False, batch_size=128,test_batch_size=1000):
    if data_aug:
        train_transform = transforms.Compose(
            [transforms.Resize(64),
            transforms.RandomCrop(64,padding=4),
            transforms.Grayscale(3),
            transforms.ToTensor()
        ])
    else:
        train_transform = transforms.Compose([
            transforms.Resize(64),
            transforms.Grayscale(3),
            transforms.ToTensor()
        ])
    test_transform = transforms.Compose([
        transforms.Resize(64),
        transforms.Grayscale(3),
        transforms.ToTensor()
    ])
    kwargs = {}

    train_loader = DataLoader(
        datasets.MNIST(root='.data/mnist',train=True, download=True,transform=train_transform),batch_size=batch_size, shuffle=True, drop_last=True)
    train_eval_loader = DataLoader(
        datasets.MNIST(root='.data/mnist',train=True, download=True, transform=test_transform),batch_size=test_batch_size, shuffle=False, drop_last=True)
    test_loader = DataLoader(
        datasets.MNIST(root='.data/mnist',train=False, download=True, transform=test_transform),batch_size=test_batch_size, shuffle=False, drop_last=True)
    return train_loader, train_eval_loader, test_loader

  return np.array(backgrounds)


In [4]:
loader_source, mnist_eval_loader, mnist_test_loader = get_mnist_loaders(batch_size=128)
loader_target, mnistm_eval_loader,mnistm_test_loader = get_mnistm_loaders(batch_size=128)

  return np.array(backgrounds)


# Model

In [5]:
class Encoder(nn.Module):
    def __init__(self, input_nc=3):
        super(Encoder, self).__init__()

        # Initial convolution block
        model = [   nn.ReflectionPad2d(3),
                    nn.Conv2d(input_nc, 64, 7),
                    nn.InstanceNorm2d(64),
                    nn.ReLU(inplace=True) ]

        model += [nn.Conv2d(64, 128, 3, stride=2, padding=1),
                  nn.InstanceNorm2d(128),
                  nn.ReLU(inplace=True)]
                  
        model += [nn.Conv2d(128, 256, 3, stride=2, padding=1),
                  nn.InstanceNorm2d(256),
                  nn.ReLU(inplace=True)]

        model += [nn.Conv2d(256, 256, 3, stride=2, padding=1),
                  nn.InstanceNorm2d(256),
                  nn.ReLU(inplace=True)]

        model += [nn.Conv2d(256, 512, 3, stride=2, padding=1),
                  nn.InstanceNorm2d(512),
                  nn.ReLU(inplace=True)]

        model += [nn.Conv2d(512, 1024, 3, stride=2, padding=1),
                  nn.InstanceNorm2d(1024),
                  nn.ReLU(inplace=True)]



        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

class Decoder(nn.Module):
    def __init__(self, input_nc=1024, output_nc=3):
        super(Decoder, self).__init__()
        model = [nn.ConvTranspose2d(input_nc, 512, 3, stride=2, padding=1, output_padding=1),
                  nn.InstanceNorm2d(512),
                  nn.ReLU(inplace=True)]

        model += [nn.ConvTranspose2d(512, 256, 3, stride=2, padding=1, output_padding=1),
                  nn.InstanceNorm2d(256),
                  nn.ReLU(inplace=True)]
        
        model += [nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1),
                  nn.InstanceNorm2d(128),
                  nn.ReLU(inplace=True)]

        model += [nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1),
                  nn.InstanceNorm2d(64),
                  nn.ReLU(inplace=True)]

        model += [nn.ConvTranspose2d(64, 64, 3, stride=2, padding=1, output_padding=1),
                  nn.InstanceNorm2d(64),
                  nn.ReLU(inplace=True)]
                  
        
        # Output layer
        model += [  nn.ReflectionPad2d(3),
                    nn.Conv2d(64, output_nc, 7),
                    nn.InstanceNorm2d(3),
                    nn.Tanh() ]

        self.model = nn.Sequential(*model)

    def forward(self, x):
        return self.model(x)

class Identity_Generator(nn.Module):
    def __init__(self, encoder, decoder):
        super(Identity_Generator, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, A, B):
        latentA = self.encoder(A)
        latentB = self.encoder(B)

        reconstructedA = self.decoder(latentA)
        reconstructedB = self.decoder(latentB)
        return reconstructedA, reconstructedB

class Perceptual(nn.Module):
    def __init__(self, encoder, decoder, generator):
        super(Perceptual, self).__init__()

        self.encoder = encoder
        self.decoder = decoder
        self.generator = generator

    def forward(self, A, B):

        reconstructedA, reconstructedB = self.generator(A, B)

        latentA = self.encoder(A)
        latentB = self.encoder(B)

        latentA.detach()
        latentB.detach()

        style = latentA[:, 0:512, : , :]
        content = latentB[:, 512:1024, :, :]
        
        mixed_latent = torch.cat([style, content], dim=1)
        mixed_image = self.decoder(mixed_latent)

        return mixed_image, reconstructedA, reconstructedB

class Discriminator(nn.Module):
    def __init__(self, input_nc = 3):
        super(Discriminator, self).__init__()

        model = [   nn.Conv2d(input_nc, 64, 4, stride=2, padding=1),
                    nn.LeakyReLU(0.2, inplace=True) ]

        model += [  nn.Conv2d(64, 128, 4, padding=1),
                    nn.InstanceNorm2d(128), 
                    nn.LeakyReLU(0.2, inplace=True) ]

        model += [  nn.Conv2d(128, 256, 4, stride=2, padding=1),
                    nn.InstanceNorm2d(256), 
                    nn.LeakyReLU(0.2, inplace=True) ]

        model += [  nn.Conv2d(256, 512, 4, stride=2, padding=1),
                    nn.InstanceNorm2d(256), 
                    nn.LeakyReLU(0.2, inplace=True) ]

        model += [  nn.Conv2d(512, 1, 4, padding=1)]


        self.model = nn.Sequential(*model)

    def forward(self, x):
        x =  self.model(x)
        return F.avg_pool2d(x, x.size()[2:]).view(x.size()[0], -1)


# Load S3gan

In [6]:
perceptual = torch.load("./model_weight/perceptual")

# Generate new dataset

In [7]:
def generate_mixed_image(perceptual, source, target, use_cuda):
    mixed = []
    label = []
    reconstructed_A = []

    data_target_iter = iter(target)
    for _, (source_inputs, source_label) in enumerate(source):
        batch_size = source_inputs.size(0)
        target_inputs, target_label = data_target_iter.next()

        if use_cuda: 
            source_inputs, target_inputs = source_inputs.cuda(), target_inputs.cuda()

        mixed_image, reconstructed_source, _ = perceptual(source_inputs, target_inputs)

        label.extend(source_label.detach().cpu().numpy())
        mixed.extend(mixed_image.detach().cpu().numpy())
        reconstructed_A.extend(reconstructed_source.detach().cpu().numpy())
    
    return mixed, reconstructed_A, label

In [8]:
if (os.path.exists("./dataset/mixed.pt")) == False and os.path.exists("./dataset/recon_A.pt") == False:
    mixed, reconstructionA, label = generate_mixed_image(perceptual, loader_target, loader_source, True)

In [9]:
if (os.path.exists("./dataset/mixed.pt")) == False:
    tensor_mixed = torch.Tensor(mixed)
    tensor_label = torch.Tensor(label)

    mixed_dataset = TensorDataset(tensor_mixed, tensor_label)
    torch.save(mixed_dataset, './dataset/mixed.pt')

In [10]:
if (os.path.exists("./dataset/recon_A.pt")) == False:
    tensor_reconstructionA = torch.Tensor(reconstructionA)
    tensor_label = torch.Tensor(label)

    recon_dataset = TensorDataset(tensor_reconstructionA, tensor_label)
    torch.save(recon_dataset, './dataset/recon_A.pt')

In [11]:
mixed_dataset = torch.load('./dataset/mixed.pt', map_location=torch.device('cpu'))
mixed_dataloader = DataLoader(mixed_dataset, batch_size= 128, shuffle=True, drop_last=True)

In [12]:
recon_dataset = torch.load('./dataset/recon_A.pt', map_location=torch.device('cpu'))
recon_dataloader = DataLoader(mixed_dataset, batch_size= 128, shuffle=True, drop_last=True)

In [3]:
import torch.nn as nn
import torch.nn.functional as F
import torch

In [5]:
t = torch.rand(20, 3, 64, 64)
out = F.interpolate(t, size=32)
out.size()

torch.Size([20, 3, 32, 32])