In [11]:
import sys, os, cv2, torch
import wandb
import numpy as np
sys.path.append(f"{os.getcwd()}")
from PIL import Image
from torchvision.transforms import transforms
from style_transfer.networks import StarGAN, StarGANConfig
from StarGAN.core.data_loader import get_train_loader, get_test_loader

In [12]:
def train():
    """Run a short StarGAN training job on the small custom dataset.

    Loads the base YAML config and overrides three groups of settings:
    dataset/output locations, a tiny smoke-test training schedule (5 iterations,
    batch size 1, logging/sampling/saving/evaluating every iteration), and the
    Visdom display settings. It then trains with source and reference training
    loaders plus a shuffled validation loader, and prints "Finished!".
    """
    settings = {
        # Dataset
        "train_img_dir": "../../../Datasets/custom/StarGANTrainingSmall/train",
        "val_img_dir": "../../../Datasets/custom/StarGANTrainingSmall/val",
        "src_dir": "../../../Datasets/custom/StarGANTrainingSmall",
        "checkpoint_dir": "../../../Models/stargan",
        "result_dir": "../.output/results",
        "eval_dir": "../.output/results",
        "num_domains": 4,
        # Training
        "mode": "train",
        "num_workers": 4,
        "total_iters": 5,
        "batch_size": 1,
        "val_batch_size": 1,
        "print_every": 1,
        "sample_every": 1,
        "save_every": 1,
        "eval_every": 1,
        "num_outs_per_domain": 1,
        "continue_training": False,
        # Visdom
        "name": "test_stargan",
        "display_server": "http://localhost",
        "display_port": 8097,
        "display_env": "test_stargan",
    }

    config = StarGANConfig.create("style_transfer/config/stargan.yaml")
    # The config must be unfrozen to accept overrides, then re-frozen before use.
    config.defrost()
    for option, value in settings.items():
        setattr(config, option, value)
    config.freeze()

    def makeTrainLoader(which):
        # Source and reference loaders differ only in the `which` selector.
        return get_train_loader(
            root=config.train_img_dir,
            which=which,
            img_size=config.img_size,
            batch_size=config.batch_size,
            prob=config.randcrop_prob,
            num_workers=config.num_workers)

    model = StarGAN(config)
    model.train(
        dataloader_src=makeTrainLoader('source'),
        dataloader_ref=makeTrainLoader('reference'),
        dataloader_val=get_test_loader(
            root=config.val_img_dir,
            img_size=config.img_size,
            batch_size=config.val_batch_size,
            shuffle=True,
            num_workers=config.num_workers),
    )

    print("Finished!")

In [13]:
def transform(imagePath, direction="AtoB", modelDir="../../../Models/afhq",
              iteration=100000, outputPath="../.output/results/test_StarGAN.png"):
    """Translate one image with a pretrained StarGAN checkpoint and save it.

    Args:
        imagePath: Path to the input image file. Any PIL-readable mode is
            accepted; it is converted to RGB.
        direction: "AtoB" conditions the generator on domain index 1,
            "BtoA" on domain index 0.
        modelDir: Checkpoint directory passed to ``network.loadModel``.
        iteration: Checkpoint iteration passed to ``network.loadModel``.
        outputPath: File the translated image is written to (via OpenCV).
            Parent directories are created if missing.

    Returns:
        ``outputPath``, the path the image was written to.

    Raises:
        ValueError: If ``direction`` is not "AtoB" or "BtoA". This is checked
            before the model is loaded.
        IOError: If OpenCV reports that writing ``outputPath`` failed.
    """
    import os

    styleIndexByDirection = {"AtoB": 1, "BtoA": 0}
    if direction not in styleIndexByDirection:
        # Previously an unknown direction silently saved the untranslated input.
        raise ValueError(f"direction must be 'AtoB' or 'BtoA', got {direction!r}")

    config = StarGANConfig.create("style_transfer/config/stargan.yaml")
    network = StarGAN(config)
    network.loadModel(modelDir, iteration)

    # Image.open is lazy: convert() forces the pixel data to be read while the
    # file is still open. convert() returns a new image, so its result must be kept.
    with open(imagePath, 'rb') as file:
        image = Image.open(file).convert("RGB")

    preprocess = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])
    batch = preprocess(image).unsqueeze(0).to("cuda")

    # Pure inference: no autograd graph is needed.
    with torch.no_grad():
        style = torch.tensor([styleIndexByDirection[direction]]).to("cuda")
        translated = network.imageToStyle(batch, style)

    # Undo Normalize(mean=0.5, std=0.5), clip to the displayable range, and
    # reorder CHW -> HWC for OpenCV.
    rgb = (translated[0] * 0.5 + 0.5).clamp(0.0, 1.0).cpu().numpy().transpose(1, 2, 0)
    rgb8 = np.round(rgb * 255.0).astype(np.uint8)
    bgr8 = cv2.cvtColor(rgb8, cv2.COLOR_RGB2BGR)

    os.makedirs(os.path.dirname(outputPath) or ".", exist_ok=True)
    if not cv2.imwrite(outputPath, bgr8):
        raise IOError(f"cv2.imwrite failed to write {outputPath!r}")
    return outputPath

In [14]:
def visualizeWithWandB(imagePath, onnxPath="model.onnx"):
    """Export the EMA StarGAN generator to ONNX and upload it to Weights & Biases.

    Loads the "mode_collapse" checkpoints (nets, nets_ema, optims) into a
    StarGAN built from ``stargan_train.yaml``. It builds example inputs: the
    preprocessed image plus a style code from the EMA mapping network for
    random noise and domain index 0. It traces the EMA generator to ONNX with
    those inputs and saves the file to the active wandb run.

    Args:
        imagePath: Image used as the example input for ONNX tracing.
        onnxPath: Where the ONNX file is written before ``wandb.save``.
    """
    wandb.login()
    with wandb.init(project="pytorch-demo"):
        config = StarGANConfig.create("style_transfer/config/stargan_train.yaml")
        network = StarGAN(config)
        network.loadModel({
            "nets": "../../../Models/stargan/failed/nets_mode_collapse.ckpt",
            "nets_ema": "../../../Models/stargan/failed/nets_ema_mode_collapse.ckpt",
            "optims": "../../../Models/stargan/failed/optims_mode_collapse.ckpt"
        })

        # Image.open is lazy and convert() is not in-place: read and convert
        # while the file is still open, and keep the converted image.
        with open(imagePath, 'rb') as file:
            image = Image.open(file).convert("RGB")

        preprocess = transforms.Compose([
            transforms.Resize((256, 256)),
            transforms.ToTensor(),
            transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
        ])

        # Put every input on the same device. The noise already used
        # config.device, so the networks must live there too.
        device = config.device
        example = preprocess(image).unsqueeze(0).to(device)

        with torch.no_grad():
            noise = torch.randn(1, config.latent_dim).to(device)
            domain = torch.tensor([0]).to(device)
            styleEnc = network.model.nets_ema.mapping_network(noise, domain)

        # `.module` assumes the generator is wrapped (e.g. DataParallel) — TODO confirm.
        torch.onnx.export(network.model.nets_ema.generator.module, (example, styleEnc), onnxPath)
        wandb.save(onnxPath)


In [15]:
# Entry cell. Uncomment train() for a short fine-tuning run, or transform(...)
# for single-image inference with the AFHQ checkpoint. By default, export the
# generator to ONNX and upload it to wandb.
# train()
# transform("../../../Datasets/afhq/train/cat/flickr_cat_000018.jpg", "AtoB")
sampleImagePath = "../../../Datasets/custom/StarGANTrainingSmall/train/2 - baroque/15_2_anthony-van-dyck_portrait-of-the-princes-palatine-charles-louis-i-and-his-brother-robert-1637.jpg"
visualizeWithWandB(sampleImagePath)



Number of parameters of generator: 33892995
Number of parameters of mapping_network: 4079872
Number of parameters of style_encoder: 20982592
Number of parameters of discriminator: 20853316
Initializing generator...
Initializing mapping_network...
Initializing style_encoder...
Initializing discriminator...
loading the model from ../../../Models/stargan/failed/nets_mode_collapse.ckpt...
loading the model from ../../../Models/stargan/failed/nets_ema_mode_collapse.ckpt...
loading the model from ../../../Models/stargan/failed/optims_mode_collapse.ckpt...




verbose: False, log level: Level.ERROR

