In [22]:
import sys, os, cv2, torch
import wandb
import numpy as np
sys.path.append(f"{os.getcwd()}")
from PIL import Image
from torchvision.transforms import transforms
from style_transfer.networks import StarGAN, StarGANConfig
from style_transfer.networks.StarGANVisualizer import StarGANVisualizer
from StarGAN.core.data_loader import get_train_loader, get_test_loader

In [23]:
def _make_train_loader(config, which):
    """Build a training DataLoader over ``config.train_img_dir``.

    ``which`` is forwarded unchanged to ``get_train_loader``; ``train`` calls
    this with 'source' and 'reference'. All other settings (image size, batch
    size, random-crop probability, worker count) come from ``config``.
    """
    return get_train_loader(
        root=config.train_img_dir,
        which=which,
        img_size=config.img_size,
        batch_size=config.batch_size,
        prob=config.randcrop_prob,
        num_workers=config.num_workers)


def train():
    """Configure and run StarGAN training on the 4-domain dataset.

    Loads the base config from ``style_transfer/config/stargan.yaml``, then
    overrides dataset/output paths, training schedule and Visdom display
    settings. It builds source, reference and validation loaders and calls
    ``StarGAN.train``. Prints "Finished!" when training returns.

    The Visdom server and port can be overridden through the ``VISDOM_SERVER``
    and ``VISDOM_PORT`` environment variables. Both default to the values this
    notebook previously hard-coded.
    """
    config = StarGANConfig.create("style_transfer/config/stargan.yaml")
    # The config is frozen after loading; unfreeze it to apply overrides.
    config.defrost()

    # Dataset
    # NOTE(review): these roots are '../../', while transform(),
    # visualizeWithWandB() and the call cell use '../../../'. Confirm which
    # depth matches the notebook's working directory.
    config.train_img_dir = "../../Datasets/StarGANTraining/train"
    config.val_img_dir = "../../Datasets/StarGANTraining/val"
    config.src_dir = "../../Datasets/StarGANTraining"
    config.checkpoint_dir = "../../Models/stargan"
    config.result_dir = "../../Results/stargan"
    config.eval_dir = "../../Results/stargan"
    config.num_domains = 4

    # Training schedule
    config.mode = "train"
    config.num_workers = 4
    config.total_iters = 500000
    config.batch_size = 14
    config.val_batch_size = 1
    config.print_every = 1000
    config.sample_every = 1000
    config.save_every = 10000
    config.eval_every = 50000
    config.save_no = 2
    config.num_outs_per_domain = 1

    # Visdom. The server address is read from the environment when set, so a
    # different host can be used without editing the notebook.
    config.name = "stargan"
    config.display_server = os.environ.get("VISDOM_SERVER", "http://116.203.134.130")
    config.display_port = int(os.environ.get("VISDOM_PORT", 8097))
    config.display_env = "stargan"
    config.freeze()

    network = StarGAN(config)
    network.train(
        dataloader_src=_make_train_loader(config, 'source'),
        dataloader_ref=_make_train_loader(config, 'reference'),
        dataloader_val=get_test_loader(
            root=config.val_img_dir,
            img_size=config.img_size,
            batch_size=config.val_batch_size,
            shuffle=True,
            num_workers=config.num_workers)
    )

    print("Finished!")

In [24]:
def transform(imagePath, direction="AtoB", output_dir="../.output/results/", num_domains=4):
    """Render one image into each StarGAN domain and save every result as a PNG.

    Loads the checkpoint tagged "100000" from "../../../Models/stargan". The
    image is resized to 256x256 and normalized to [-1, 1]. For each domain
    index ``i`` in ``range(num_domains)``, the image goes through
    ``network.imageToStyle`` and the output is written to
    ``<output_dir>/test_StarGAN_<i>.png``.

    Args:
        imagePath: Path to the input image. Any PIL-readable mode is accepted
            and converted to RGB.
        direction: Unused; kept so existing calls such as
            ``transform(path, "AtoB")`` keep working.
        output_dir: Directory the PNGs are written to. It is created if missing.
        num_domains: Number of domain labels to render (0 .. num_domains-1).
            Defaults to 4, the value ``train()`` configures.
    """
    config = StarGANConfig.create("style_transfer/config/stargan.yaml")
    network = StarGAN(config)
    network.loadModel("../../../Models/stargan", "100000")

    # Image.convert returns a NEW image, so its result must be kept. Doing the
    # conversion inside the `with` also forces the pixels to load before the
    # file handle is closed.
    with open(imagePath, 'rb') as file:
        image = Image.open(file).convert("RGB")

    # Named `preprocess` so it does not shadow this function's own name.
    preprocess = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])

    # Add a batch dimension of size 1.
    batch = preprocess(image).unsqueeze(0).to("cuda")

    # cv2.imwrite returns False instead of raising when the directory is missing.
    os.makedirs(output_dir, exist_ok=True)

    with torch.no_grad():
        for i in range(num_domains):
            style = torch.tensor([i]).to("cuda")
            styled_image = network.imageToStyle(batch, style)

            # Undo Normalize(mean=0.5, std=0.5), mapping [-1, 1] back to [0, 1].
            styled_image = styled_image[0] * 0.5 + 0.5
            # CHW -> HWC, then RGB -> BGR for OpenCV.
            styled_image = styled_image.detach().cpu().numpy().transpose(1, 2, 0)
            styled_image = cv2.cvtColor(styled_image, cv2.COLOR_RGB2BGR)

            # Convert to 8-bit explicitly (rounded and saturated) instead of
            # relying on imwrite's implicit float conversion.
            pixels = np.clip(np.rint(styled_image * 255.0), 0, 255).astype(np.uint8)
            cv2.imwrite(os.path.join(output_dir, f"test_StarGAN_{i}.png"), pixels)

In [25]:
def visualizeWithWandB(imagePath):
    """Export the EMA StarGAN generator to ONNX and upload it to Weights & Biases.

    Inside a wandb run in project "pytorch-demo", this function:

    1. Loads the ``*_mode_collapse.ckpt`` checkpoints.
    2. Preprocesses ``imagePath`` into a 1x3x256x256 batch.
    3. Samples a domain-0 style code from the EMA mapping network.
    4. Exports the EMA generator to "model.onnx" in the current directory,
       using (image, style code) as example inputs.
    5. Saves that file to the run.

    Args:
        imagePath: Path to the example input image. Any PIL-readable mode is
            accepted and converted to RGB.
    """
    wandb.login()
    with wandb.init(project="pytorch-demo"):

        config = StarGANConfig.create("style_transfer/config/stargan_train.yaml")
        network = StarGAN(config)
        network.loadModel({
            "nets": "../../../Models/stargan/failed/nets_mode_collapse.ckpt",
            "nets_ema": "../../../Models/stargan/failed/nets_ema_mode_collapse.ckpt",
            "optims": "../../../Models/stargan/failed/optims_mode_collapse.ckpt"
        })

        # Image.convert returns a NEW image, so its result must be kept.
        # Converting inside the `with` also loads the pixels before the file
        # closes.
        with open(imagePath, 'rb') as file:
            image = Image.open(file).convert("RGB")

        preprocess = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ])

        # Use one device for every export input. Previously the image and
        # label used a hard-coded "cuda" while the noise used config.device.
        device = config.device
        image = preprocess(image).unsqueeze(0).to(device)

        # The style code is only an example input for tracing the export,
        # so no autograd graph is needed.
        with torch.no_grad():
            noise = torch.randn(1, config.latent_dim).to(device)
            styleEnc = network.model.nets_ema.mapping_network(noise, torch.tensor([0]).to(device))

        # `.module` presumably unwraps a DataParallel-style wrapper; TODO confirm.
        torch.onnx.export(network.model.nets_ema.generator.module, (image, styleEnc), "model.onnx")
        wandb.save("model.onnx")


In [27]:
# Notebook entry points: uncomment the one to run.
# train()
sample_image_path = "../../../Datasets/coco/images/train2017/000000568462.jpg"
transform(sample_image_path, direction="AtoB")
# visualizeWithWandB("../../../Datasets/custom/StarGANTrainingSmall/train/2 - baroque/15_2_anthony-van-dyck_portrait-of-the-princes-palatine-charles-louis-i-and-his-brother-robert-1637.jpg")

Number of parameters of generator: 33892995
Number of parameters of mapping_network: 4079872
Number of parameters of style_encoder: 20982592
Number of parameters of discriminator: 20853316
Initializing generator...
Initializing mapping_network...
Initializing style_encoder...
Initializing discriminator...
loading the model from ../../../Models/stargan\nets_ema_100000.ckpt...
