In [1]:
## Change this value when you get a CUDA out of memory error.
## max_split_size_mb stops PyTorch's CUDA caching allocator from splitting cached blocks
## larger than this size (in MB), which can reduce fragmentation-related OOMs.
## The variable is read when the allocator initializes, so this cell must run before any
## GPU work (it sits above the `import torch` cell).
%env PYTORCH_CUDA_ALLOC_CONF max_split_size_mb:512

env: PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512


In [2]:
import sys, os, cv2, torch
sys.path.append(f"{os.getcwd()}")
from PIL import Image
from torchvision.transforms import transforms
from yacs.config import CfgNode as CN
from style_transfer.networks import UGATIT, UGATITConfig
from CycleGAN.data.unaligned_dataset import UnalignedDataset

In [3]:
def train(
    dataset_name="BaroqueStyleTrainingSmall",
    data_dir="../../../Datasets/custom",
    result_dir=".output/results",
    epochs=3,
    batch_size=1,
    config_path="style_transfer/config/ugatit.yaml",
    seed=None,
):
    """Train UGATIT on an unaligned two-domain image dataset.

    Every argument defaults to the value this notebook previously hard-coded,
    so calling ``train()`` with no arguments behaves as before.

    Args:
        dataset_name: Dataset folder name under ``data_dir`` (stored in ``config.dataset``).
        data_dir: Root directory containing ``dataset_name``.
        result_dir: Stored in ``config.result_dir`` for the network to use.
        epochs: Stored in ``config.epoch``.
        batch_size: Used for both ``config.batch_size`` and the DataLoader.
        config_path: UGATIT YAML config loaded via ``UGATITConfig.create``.
        seed: If not ``None``, seeds Python's ``random`` and torch so runs are
            repeatable. ``None`` (default) leaves the RNGs untouched.

    Raises:
        ValueError: If the dataset yields no samples.
    """
    if seed is not None:
        import random

        random.seed(seed)
        torch.manual_seed(seed)

    config = UGATITConfig.create(config_path)
    # Unlock the loaded config, override the run-specific values, then lock it again.
    config.defrost()
    # Dataset
    config.dataset = dataset_name
    config.data_dir = data_dir
    config.result_dir = result_dir
    # Training
    config.epoch = epochs
    # Marked deprecated, but the training log still prints it as "iteration per epoch".
    config.iteration = 10000
    config.batch_size = batch_size
    config.log_freq = 1
    config.print_freq = 1
    config.save_freq = 1
    # Visdom
    config.name = "test_ugatit"
    config.display_server = "http://localhost"
    config.display_env = "test"
    config.freeze()

    # Options object for the CycleGAN-style UnalignedDataset.
    dataset_config = CN()
    dataset_config.dataroot = os.path.join(config.data_dir, config.dataset)
    dataset_config.phase = "train"
    dataset_config.max_dataset_size = float("inf")
    dataset_config.preprocess = "resize"
    # Resize to the model's configured input size.
    dataset_config.load_size = config.img_size
    dataset_config.no_flip = True
    dataset_config.direction = "AtoB"
    dataset_config.input_nc = 3
    dataset_config.output_nc = 3
    dataset_config.serial_batches = False

    train_dataset = UnalignedDataset(dataset_config)
    # Fail early with a readable message instead of the DataLoader's
    # "num_samples should be a positive integer" error.
    if len(train_dataset) == 0:
        raise ValueError(f"No training images found under {dataset_config.dataroot!r}")

    dataloader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
    )

    network = UGATIT(config)
    network.train(dataloader)

    print("Finished!")

In [4]:
def transform(
    path,
    direction="AtoB",
    config_path="style_transfer/config/ugatit.yaml",
    model_dir="style_transfer/model",
    output_name="test_ugatit.png",
):
    """Translate one image with a trained UGATIT model and save the result.

    Args:
        path: Input image file (any format PIL can open; converted to RGB).
        direction: ``"AtoB"`` calls ``network.transformFromPhotographicToArtistic``,
            ``"BtoA"`` calls ``network.transformFromArtisticToPhotographic``.
        config_path: UGATIT YAML config (same default file as ``train()``).
        model_dir: Directory passed to ``network.loadModel``.
        output_name: File name written inside ``config.result_dir``.

    Returns:
        The path of the written image.

    Raises:
        ValueError: If ``direction`` is not ``"AtoB"`` or ``"BtoA"``.
        IOError: If OpenCV fails to write the output image.
    """
    # Validate before the expensive model load / GPU transfer.
    if direction not in ("AtoB", "BtoA"):
        raise ValueError(f"direction must be 'AtoB' or 'BtoA', got {direction!r}")

    config = UGATITConfig.create(config_path)
    network = UGATIT(config)
    network.loadModel(model_dir)

    with open(path, "rb") as file:
        # convert() returns a new, fully loaded image (safe to use after the file
        # closes) and forces 3 channels for grayscale / RGBA / palette inputs.
        image = Image.open(file).convert("RGB")

    image = transforms.ToTensor()(image)
    # Match the input size the model was trained with (train() uses config.img_size).
    image = transforms.Resize((config.img_size, config.img_size))(image)
    # Map [0, 1] -> [-1, 1].
    image = transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))(image)
    image = image.unsqueeze(0).to("cuda")  # add a batch dimension

    # Inference only: skip autograd bookkeeping to save GPU memory.
    with torch.no_grad():
        if direction == "AtoB":
            image = network.transformFromPhotographicToArtistic(image)
        else:
            image = network.transformFromArtisticToPhotographic(image)

    # Undo the normalization ([-1, 1] -> [0, 1]) and quantize to 8-bit explicitly
    # rather than relying on cv2.imwrite's implicit float conversion.
    image = (image[0] * 0.5 + 0.5).clamp(0.0, 1.0).mul(255.0).round().to(torch.uint8)
    image = image.permute(1, 2, 0).contiguous().cpu().numpy()  # CHW -> HWC
    image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)  # OpenCV writes BGR

    # cv2.imwrite returns False (no exception) when the directory is missing.
    if config.result_dir:
        os.makedirs(config.result_dir, exist_ok=True)
    output_path = os.path.join(config.result_dir, output_name)
    if not cv2.imwrite(output_path, image):
        raise IOError(f"cv2.imwrite failed to write {output_path}")
    return output_path

In [5]:
# Runs the full training loop. This is long-running: the saved output below ends in a
# KeyboardInterrupt, i.e. the run was interrupted rather than finishing.
train()
# For inference instead of training, comment out train() above and uncomment the call below.
# Relative paths resolve against the kernel's working directory.
# NOTE(review): this path climbs two ".." levels while train()'s data_dir climbs three — confirm.
# transform("../../Datasets/human-art/images/2D_virtual_human/oil_painting/000000000000.jpg", "BtoA")


##### Information #####
# light :  False
# dataset :  BaroqueStyleTrainingSmall
# batch_size :  1
# iteration per epoch :  10000

##### Generator #####
# residual blocks :  4

##### Discriminator #####
# discriminator layer :  6

##### Weight #####
# adv_weight :  1
# cycle_weight :  10
# identity_weight :  10
# cam_weight :  1000


Setting up a new session...


training start !


KeyboardInterrupt: 