In [1]:
import sys, os, cv2, torch
sys.path.append(f"{os.getcwd()}")
from PIL import Image
from torchvision.transforms import transforms
from style_transfer.networks import CycleGAN, CycleGANConfig
from CycleGAN.data.unaligned_dataset import UnalignedDataset

In [4]:
def train():
    """Train (or continue training) the "baroque" CycleGAN.

    Builds the base training config from
    ``style_transfer/config/cyclegan_train.yaml`` and applies the run-specific
    overrides listed below. It prints the resulting config, then hands a
    DataLoader over an ``UnalignedDataset`` to ``CycleGAN.train``.
    """
    # Run-specific settings, applied on top of the YAML defaults in this order.
    overrides = {
        "checkpoints_dir": "../../Models/checkpoints",
        "dataroot": "../../Datasets/BaroqueStyleTraining",
        "name": "baroque",
        "lr": 0.0001168,
        "display_freq": 50,
        "print_freq": 50,
        "batch_size": 7,
        "num_threads": 1,
        "serial_batches": False,
        "save_no": 1,
        "save_epoch_freq": 1,
        "continue_train": True,
        "display_server": "http://116.203.134.130",
        "display_env": "baroque",
    }

    cfg = CycleGANConfig.create("style_transfer/config/cyclegan_train.yaml", phase="train")
    # The config is frozen after creation; unfreeze only for the overrides.
    cfg.defrost()
    for key, value in overrides.items():
        setattr(cfg, key, value)
    cfg.freeze()
    print(cfg)

    loader = torch.utils.data.DataLoader(
        UnalignedDataset(cfg),
        batch_size=cfg.batch_size,
        num_workers=int(cfg.num_threads),
        # serial_batches=True means "keep dataset order", i.e. no shuffling.
        shuffle=not cfg.serial_batches,
    )
    CycleGAN(cfg).train(loader)


In [2]:
def transform(path, direction="AtoB", model_paths=None, output_name="test_cyclegan.png"):
    """Stylize a single image with a pretrained CycleGAN generator and save it.

    Args:
        path: Path of the input image. It is converted to 3-channel RGB
            before inference, so RGBA / grayscale inputs are accepted.
        direction: "AtoB" runs ``network.transformFromPhotographicToArtistic``;
            "BtoA" runs ``network.transformFromArtisticToPhotographic``.
        model_paths: Optional mapping passed to ``CycleGAN.loadModel``
            (keys "G_A" / "G_B"). Defaults to the Monet generator weights.
        output_name: File name written inside ``config.results_dir``.

    Returns:
        Path of the written PNG file.

    Raises:
        ValueError: If ``direction`` is neither "AtoB" nor "BtoA".
        OSError: If OpenCV fails to write the output image.
    """
    # Validate before the expensive config/model setup. Previously an unknown
    # direction silently saved the *untransformed* input.
    if direction not in ("AtoB", "BtoA"):
        raise ValueError(f"direction must be 'AtoB' or 'BtoA', got {direction!r}")
    if model_paths is None:
        model_paths = {
            "G_A": "../../../Models/CycleGAN/monet_A_B.pth",
            "G_B": "../../../Models/CycleGAN/monet_B_A.pth",
        }

    config = CycleGANConfig.create("style_transfer/config/cyclegan_test.yaml", phase="test")

    config.defrost()
    config.results_dir = "../.output/results"   # saves results here.
    config.num_threads = 0   # test code only supports num_threads = 0
    config.batch_size = 1    # test code only supports batch_size = 1
    config.serial_batches = True  # disable data shuffling; comment this line if results on randomly chosen images are needed.
    config.no_flip = True    # no flip; comment this line if results on flipped images are needed.
    config.display_id = -1   # no visdom display
    config.freeze()

    network = CycleGAN(config)
    network.loadModel(model_paths)

    # Image.open is lazy, and convert() returns a NEW image. The original code
    # discarded that result, so non-RGB inputs kept the wrong channel count.
    # Converting inside the `with` also forces the pixels to load while the
    # file is still open.
    with open(path, "rb") as file:
        image = Image.open(file).convert("RGB")

    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(256),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ])
    # Add a batch dimension of 1. NOTE(review): the device is hard-coded to
    # "cuda", as before; confirm it matches the device CycleGAN places its nets on.
    batch = preprocess(image).unsqueeze(0).to("cuda")

    with torch.no_grad():  # inference only: no autograd graph needed
        if direction == "AtoB":
            output = network.transformFromPhotographicToArtistic(batch)
        else:
            output = network.transformFromArtisticToPhotographic(batch)

    # Undo Normalize(0.5, 0.5): [-1, 1] -> [0, 1]. Clamp so the uint8
    # conversion below saturates instead of wrapping around.
    result = (output[0] * 0.5 + 0.5).clamp(0.0, 1.0)
    result = result.detach().cpu().numpy().transpose(1, 2, 0)  # CHW -> HWC
    result = cv2.cvtColor(result, cv2.COLOR_RGB2BGR)  # OpenCV expects BGR
    result = (result * 255.0).round().astype("uint8")

    # cv2.imwrite does not create directories and signals failure only by
    # returning False, so make the directory and check the result explicitly.
    os.makedirs(config.results_dir, exist_ok=True)
    output_path = os.path.join(config.results_dir, output_name)
    if not cv2.imwrite(output_path, result):
        raise OSError(f"Failed to write stylized image to {output_path}")
    return output_path

In [5]:
# Notebook entry point: runs CycleGAN training via train().
# To stylize a single image instead, replace the call below with e.g.
#   transform("../../../Datasets/custom/Drama/8_1/000000001240.jpg", "AtoB")
train()

batch_size: 1
beta1: 0.5
checkpoints_dir: ../../../Models/CycleGAN/checkpoints
continue_train: True
crop_size: 256
dataroot: ../../../Datasets/custom/cezanne2photo
dataset_mode: unaligned
direction: AtoB
display_env: test_mount
display_freq: 2
display_id: 1
display_ncols: 4
display_port: 8097
display_server: http://116.203.134.130
display_winsize: 256
epoch: latest
epoch_count: 1
gan_mode: lsgan
gpu_ids: [0]
init_gain: 0.02
init_type: normal
input_nc: 3
isTrain: True
lambda_A: 10.0
lambda_B: 10.0
lambda_identity: 0.5
load_size: 286
lr: 0.0002
lr_decay_iters: 50
lr_policy: linear
max_dataset_size: inf
model: train
n_epochs: 100
n_epochs_decay: 100
n_layers_D: 3
name: cezanne
ndf: 64
netD: basic
netG: resnet_9blocks
ngf: 64
no_dropout: False
no_flip: True
norm: instance
num_threads: 4
output_nc: 3
phase: train
pool_size: 50
preprocess: resize_and_crop
print_freq: 2
save_epoch_freq: 1
save_no: 1
serial_batches: True
suffix: 
verbose: True
initialize network with normal
initialize network 

Setting up a new session...


initialize network with normal
The number of training images = 5
loading the model from ../../../Models/CycleGAN/checkpoints\cezanne\23_net_G_A.pth
loading the model from ../../../Models/CycleGAN/checkpoints\cezanne\23_net_G_B.pth
loading the model from ../../../Models/CycleGAN/checkpoints\cezanne\23_net_D_A.pth
loading the model from ../../../Models/CycleGAN/checkpoints\cezanne\23_net_D_B.pth




learning rate 0.0002000 -> 0.0002000
(epoch: 24, iters: 2, time: 1.295, data: 3.197) D_A: 1.402 G_A: 2.076 cycle_A: 1.687 idt_A: 1.158 D_B: 1.031 G_B: 1.496 cycle_B: 3.807 idt_B: 0.717 
(epoch: 24, iters: 4, time: 1.220, data: 0.063) D_A: 0.630 G_A: 0.788 cycle_A: 4.727 idt_A: 1.211 D_B: 0.697 G_B: 1.003 cycle_B: 3.288 idt_B: 2.098 
saving the model at the end of epoch 24, iters 5
End of epoch 24 / 200 	 Time Taken: 7 sec
learning rate 0.0002000 -> 0.0002000
(epoch: 25, iters: 1, time: 1.299, data: 0.077) D_A: 0.196 G_A: 0.585 cycle_A: 1.931 idt_A: 0.814 D_B: 0.375 G_B: 0.498 cycle_B: 1.958 idt_B: 0.917 
(epoch: 25, iters: 3, time: 1.108, data: 0.064) D_A: 0.278 G_A: 0.394 cycle_A: 2.683 idt_A: 0.928 D_B: 0.323 G_B: 0.612 cycle_B: 1.803 idt_B: 1.250 
(epoch: 25, iters: 5, time: 1.144, data: 0.059) D_A: 0.234 G_A: 0.455 cycle_A: 1.767 idt_A: 0.899 D_B: 0.284 G_B: 0.416 cycle_B: 2.115 idt_B: 0.774 
saving the model at the end of epoch 25, iters 10
End of epoch 25 / 200 	 Time Taken: 7 se

KeyboardInterrupt: 