In [1]:
# Set CUDA / PyTorch allocator environment variables for this kernel process.
# NOTE(review): presumably this cell comes first so the variables are in place
# before the torch import below initialises CUDA; confirm.
# CUDA_LAUNCH_BLOCKING=1 is generally used as a debugging aid (synchronous
# kernel launches); verify it is still wanted, since it slows training.
%env CUDA_LAUNCH_BLOCKING 1
# Allocator hint passed to PyTorch's caching allocator (max_split_size_mb:512).
%env PYTORCH_CUDA_ALLOC_CONF max_split_size_mb:512

env: CUDA_LAUNCH_BLOCKING=1
env: PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512


In [2]:
import sys, os, cv2, torch
sys.path.append(f"{os.getcwd()}")
from PIL import Image
from torchvision.transforms import transforms
from style_transfer.networks import StarGAN, StarGANConfig
from StarGAN.core.data_loader import get_train_loader, get_test_loader

In [3]:
def train():
    """Configure and launch a very short StarGAN training run.

    The base configuration is read from ``style_transfer/config/stargan.yaml``
    and then overridden in place (dataset locations, a 5-iteration schedule
    with batch size 1, and the Visdom display settings). A ``StarGAN`` network
    is built from the frozen config and trained with source, reference and
    validation data loaders. Prints ``Finished!`` when training returns.
    """
    # Overrides are applied in insertion order, mirroring one assignment each.
    overrides = {
        # Dataset
        "train_img_dir": "../../../Datasets/custom/StarGANTrainingSmall/train",
        "val_img_dir": "../../../Datasets/custom/StarGANTrainingSmall/val",
        "src_dir": "../../../Datasets/custom/StarGANTrainingSmall",
        "checkpoint_dir": "../../../Models/stargan",
        "result_dir": ".output/results",
        "eval_dir": ".output/results",
        "num_domains": 4,
        # Training
        "mode": "train",
        "num_workers": 4,
        "total_iters": 5,
        "batch_size": 1,
        "val_batch_size": 1,
        "print_every": 1,
        "sample_every": 1,
        "save_every": 1,
        "eval_every": 10000,
        "num_outs_per_domain": 1,
        "continue_training": False,
        # Visdom
        "name": "test_stargan",
        "display_server": "http://localhost",
        "display_port": 8097,
        "display_env": "test_stargan",
    }

    cfg = StarGANConfig.create("style_transfer/config/stargan.yaml")
    cfg.defrost()
    for option, value in overrides.items():
        setattr(cfg, option, value)
    cfg.freeze()

    def make_train_loader(which):
        # The source and reference loaders differ only in the `which` selector.
        return get_train_loader(
            root=cfg.train_img_dir,
            which=which,
            img_size=cfg.img_size,
            batch_size=cfg.batch_size,
            prob=cfg.randcrop_prob,
            num_workers=cfg.num_workers,
        )

    # The network is built before any loader, as in the original statement order.
    model = StarGAN(cfg)
    model.train(
        dataloader_src=make_train_loader('source'),
        dataloader_ref=make_train_loader('reference'),
        dataloader_val=get_test_loader(
            root=cfg.val_img_dir,
            img_size=cfg.img_size,
            batch_size=cfg.val_batch_size,
            shuffle=True,
            num_workers=cfg.num_workers,
        ),
    )

    print("Finished!")

In [4]:
def transform(imagePath, direction="AtoB"):
    """Translate a single image with a pretrained StarGAN and save it as a PNG.

    The image is loaded, forced to RGB, resized to 256x256, normalised to
    [-1, 1] and passed through ``network.imageToStyle`` together with a
    target-domain index chosen by ``direction``. The result is written to
    ``../.output/results/test_StarGAN.png``.

    Args:
        imagePath: Path of the input image file.
        direction: ``"AtoB"`` (target domain index 1) or ``"BtoA"`` (target
            domain index 0).

    Raises:
        ValueError: If ``direction`` is not one of the supported values.
        IOError: If OpenCV reports that the output image could not be written.
    """
    # Validate first so a typo fails fast instead of silently writing the
    # untransformed input back out after an expensive model load.
    target_domains = {"AtoB": 1, "BtoA": 0}
    if direction not in target_domains:
        raise ValueError(
            f"Unknown direction {direction!r}; expected one of {sorted(target_domains)}"
        )

    config = StarGANConfig.create("style_transfer/config/stargan.yaml")
    network = StarGAN(config)
    network.loadModel("../../../Models/afhq", 100000)

    with open(imagePath, 'rb') as file:
        # convert() returns a NEW image, so the result must be kept. Doing it
        # inside the `with` also reads the pixel data while the file is open.
        image = Image.open(file).convert("RGB")

    # Named `preprocess` so it does not shadow this function's own name.
    preprocess = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    ])

    batch = preprocess(image).unsqueeze(0).to("cuda")
    style = torch.tensor([target_domains[direction]]).to("cuda")

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        output = network.imageToStyle(batch, style)

    # Undo the (0.5, 0.5) normalisation, then clamp and quantise explicitly so
    # the PNG writer receives well-defined 8-bit data instead of raw floats.
    pixels = (output[0].detach() * 0.5 + 0.5).clamp(0.0, 1.0)
    pixels = pixels.mul(255.0).round().to(torch.uint8)
    # CHW -> HWC; contiguous so OpenCV gets a plain dense array.
    pixels = pixels.permute(1, 2, 0).contiguous().cpu().numpy()
    pixels = cv2.cvtColor(pixels, cv2.COLOR_RGB2BGR)

    output_dir = "../.output/results/"
    os.makedirs(output_dir, exist_ok=True)
    output_path = os.path.join(output_dir, "test_StarGAN.png")
    # cv2.imwrite signals failure through its return value, not an exception.
    if not cv2.imwrite(output_path, pixels):
        raise IOError(f"Failed to write output image to {output_path}")

In [5]:
# Entry point for this notebook run: start the training session.
train()
# Disabled alternative: run inference on a previously saved result image,
# translating it in the "BtoA" direction instead of training.
# transform("../.output/results/test_StarGAN copy 2.png", "BtoA")

Setting up a new session...


Number of parameters of generator: 33892995
Number of parameters of mapping_network: 4079872
Number of parameters of style_encoder: 20982592
Number of parameters of discriminator: 20853316
Initializing generator...
Initializing mapping_network...
Initializing style_encoder...
Initializing discriminator...
Preparing DataLoader to fetch source images during the training phase...
Preparing DataLoader to fetch reference images during the training phase...
Preparing DataLoader for the generation phase...
Start training...
(iters: 1, time: 7.198) D/latent_real: 0.579 D/latent_fake: 0.307 D/latent_reg: 0.001 D/ref_real: 0.000 D/ref_fake: 0.000 D/ref_reg: 0.004 G/latent_adv: 25.414 G/latent_sty: 1.617 G/latent_ds: 0.769 G/latent_cyc: 0.901 G/ref_adv: 17.313 G/ref_sty: 1.033 G/ref_ds: 0.286 G/ref_cyc: 0.879 G/lambda_ds: 1.000 
(iters: 2, time: 10.462) D/latent_real: 0.715 D/latent_fake: 8.404 D/latent_reg: 0.001 D/ref_real: 0.000 D/ref_fake: 0.000 D/ref_reg: 0.002 G/latent_adv: 27.986 G/latent_

KeyboardInterrupt: 