In [1]:
import sys, os, cv2, torch, re
sys.path.append(f"{os.getcwd()}")
from PIL import Image
from torchvision.transforms import transforms
from style_transfer.networks import CycleGAN, CycleGANConfig
from CycleGAN.data.unaligned_dataset import UnalignedDataset

In [2]:
def train():
    """Run a short CycleGAN training session on the small Impressionism dataset.

    The base options come from ``style_transfer/config/cyclegan_train.yaml``.
    A fixed set of run-specific overrides (paths, logging/saving frequencies,
    visdom settings, ``continue_train``) is then applied while the config is
    unfrozen. The resulting config is printed, and an ``UnalignedDataset`` is
    wrapped in a ``DataLoader`` that is passed to ``CycleGAN.train``.
    """
    config = CycleGANConfig.create("style_transfer/config/cyclegan_train.yaml", phase="train")

    # Run-specific overrides, applied in this order while the config is unfrozen.
    overrides = {
        "checkpoints_dir": "../../../Models/CycleGAN/checkpoints",
        "dataroot": "../../../Datasets/custom/ImpressionismStyleTrainingSmall",
        "name": "test_impressionism",
        "display_freq": 1,
        "print_freq": 1,
        "batch_size": 1,
        "num_threads": 1,
        "serial_batches": False,
        "save_no": 1,
        "save_epoch_freq": 1,
        "continue_train": True,
        "display_server": "http://localhost",
        "display_env": "test_impressionism",
    }
    config.defrost()
    for option, value in overrides.items():
        setattr(config, option, value)
    config.freeze()
    print(config)

    # Shuffling is enabled unless serial_batches is set.
    loader = torch.utils.data.DataLoader(
        UnalignedDataset(config),
        batch_size=config.batch_size,
        shuffle=not config.serial_batches,
        num_workers=int(config.num_threads),
        pin_memory=True,
    )

    CycleGAN(config).train(loader)


In [3]:
def transform(path, direction="AtoB", modelPaths=None):
    """Translate one image with a pretrained CycleGAN and save the result as a PNG.

    The test options are loaded from ``style_transfer/config/cyclegan_test.yaml``
    and the generator weights via ``network.loadModel``. Processing steps:

    - The input image is converted to RGB and resized with ``transforms.Resize(256)``.
    - It is normalized with mean/std 0.5 and run on CUDA through the generator
      selected by ``direction``.
    - The output is mapped back to [0, 1], clamped and written as 8-bit BGR to
      ``<config.results_dir>/test_cyclegan.png``.

    Args:
        path: Path of the input image file.
        direction: ``"AtoB"`` calls ``network.artisticToPhotographic``;
            ``"BtoA"`` calls ``network.photographicToArtistic``.
        modelPaths: Optional mapping of generator name (``"G_A"``/``"G_B"``) to
            checkpoint path, passed to ``network.loadModel``. Defaults to the
            Baroque checkpoints.

    Returns:
        The path of the written PNG file.

    Raises:
        ValueError: If ``direction`` is not ``"AtoB"`` or ``"BtoA"``.
        OSError: If OpenCV fails to write the output file.
    """
    # Validate before the (expensive) model construction; previously an unknown
    # direction silently saved the un-translated input.
    if direction not in ("AtoB", "BtoA"):
        raise ValueError(f"direction must be 'AtoB' or 'BtoA', got {direction!r}")
    if modelPaths is None:
        modelPaths = {
            "G_A": "../../../Models/CycleGAN/baroque/latest_net_G_A.pth",
            "G_B": "../../../Models/CycleGAN/baroque/latest_net_G_B.pth"
        }

    config = CycleGANConfig.create("style_transfer/config/cyclegan_test.yaml", phase="test")

    config.defrost()
    config.results_dir = "../.output/results"   # saves results here.
    config.num_threads = 0   # test code only supports num_threads = 0
    config.batch_size = 1    # test code only supports batch_size = 1
    config.serial_batches = True  # disable data shuffling; comment this line if results on randomly chosen images are needed.
    config.no_flip = True    # no flip; comment this line if results on flipped images are needed.
    config.display_id = -1   # no visdom display
    config.freeze()

    network = CycleGAN(config)
    network.loadModel(modelPaths)

    # Image.open is lazy: convert() forces the pixels to be read while the file
    # is still open, and it returns a NEW image (it does not convert in place).
    with open(path, 'rb') as file:
        image = Image.open(file).convert("RGB")

    tensor = transforms.ToTensor()(image)
    tensor = transforms.Resize(256)(tensor)
    tensor = transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))(tensor)
    batch = tensor.unsqueeze(0).to("cuda")  # add a batch dimension of 1

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        if direction == "AtoB":
            output = network.artisticToPhotographic(batch)
        else:
            output = network.photographicToArtistic(batch)

    # Undo the normalization, clamp to the valid range and quantize to 8 bits
    # so the PNG encoder receives a well-defined uint8 image.
    result = (output[0] * 0.5 + 0.5).clamp(0.0, 1.0)
    result = (result * 255.0).round().to(torch.uint8)
    result = result.permute(1, 2, 0).contiguous().cpu().numpy()  # CHW -> HWC
    result = cv2.cvtColor(result, cv2.COLOR_RGB2BGR)

    # cv2.imwrite returns False (instead of raising) when e.g. the directory is missing.
    os.makedirs(config.results_dir, exist_ok=True)
    outputPath = os.path.join(config.results_dir, "test_cyclegan.png")
    if not cv2.imwrite(outputPath, result):
        raise OSError(f"cv2.imwrite failed to write {outputPath}")
    return outputPath

In [4]:
def _shiftConvBlockIndices(mapping, pattern):
    """In place: drop entries whose captured conv_block index is 4 and shift indices 5-7 down by one."""
    # Pop every matching entry first, then reinsert the renamed ones, so the
    # result does not depend on the order in which keys are visited (an
    # in-place rename could overwrite an entry before it has been moved).
    moved = {}
    for key in list(mapping.keys()):
        match = pattern.fullmatch(key)
        if match is None:
            continue
        value = mapping.pop(key)
        blockIdx = int(match.group(1))
        if blockIdx != 4:
            moved[f"{key[:match.start(1)]}{blockIdx - 1}{key[match.end(1):]}"] = value
    mapping.update(moved)


def convertFromTrainToTestModel(loadPath, savePath, mapLocation=None):
    """Renumber ResNet-block layer indices in a generator checkpoint and save it.

    Every entry whose key starts with ``model.<N>.conv_block.<K>`` and has ``K``
    in 4-7 is rewritten as follows:

    - index 4 is removed;
    - indices 5, 6 and 7 become 4, 5 and 6.

    This is applied to the state-dict tensors and, when present, to the
    per-module ``_metadata`` entries. Presumably index 4 is a dropout layer
    that exists only in the training-time generator — TODO confirm against the
    generator definition. Applying this to an already-converted checkpoint
    would shift it a second time.

    Args:
        loadPath: Checkpoint to read with ``torch.load``.
        savePath: Destination for the converted checkpoint (``torch.save``).
        mapLocation: Optional ``map_location`` forwarded to ``torch.load``,
            e.g. ``"cpu"`` to convert a CUDA checkpoint on a machine without a GPU.
    """
    state_dict = torch.load(loadPath, map_location=mapLocation)
    if hasattr(state_dict, "_metadata"):
        # Metadata keys name modules: "model.<N>.conv_block.<K>".
        _shiftConvBlockIndices(state_dict._metadata, re.compile(r"model\.[0-9]+\.conv_block\.([4-7])"))
    # Parameter/buffer keys carry a suffix: "model.<N>.conv_block.<K>.<name>".
    _shiftConvBlockIndices(state_dict, re.compile(r"model\.[0-9]+\.conv_block\.([4-7])\..+"))
    torch.save(state_dict, savePath)

In [6]:
# Entry point of the notebook: currently runs training only.
# The commented-out calls are alternative steps:
#   - transform(...): translate a single Baroque painting with the generators
#     loaded from ../../../Models/CycleGAN/baroque/.
#   - convertFromTrainToTestModel(...): rewrite the training checkpoints
#     (train/latest_net_G_A.pth, train/latest_net_G_B.pth) into the files that
#     transform() loads.
train()
# transform("../../../Datasets/custom/BaroqueStyleTraining/trainA/15_123_rembrandt_the-young-rembrandt-as-democritus-the-laughing-philosopher-1629.jpg", "AtoB")
# convertFromTrainToTestModel("../../../Models/CycleGAN/baroque/train/latest_net_G_A.pth", "../../../Models/CycleGAN/baroque/latest_net_G_A.pth")
# convertFromTrainToTestModel("../../../Models/CycleGAN/baroque/train/latest_net_G_B.pth", "../../../Models/CycleGAN/baroque/latest_net_G_B.pth")

batch_size: 1
beta1: 0.5
checkpoints_dir: ../../../Models/CycleGAN/checkpoints
continue_train: True
crop_size: 256
dataroot: ../../../Datasets/custom/ImpressionismStyleTrainingSmall
dataset_mode: unaligned
direction: AtoB
display_env: test_impressionism
display_freq: 1
display_id: 1
display_ncols: 4
display_port: 8097
display_server: http://localhost
display_winsize: 256
epoch: latest
epoch_count: 1
gan_mode: lsgan
gpu_ids: [0]
init_gain: 0.02
init_type: normal
input_nc: 3
isTrain: True
lambda_A: 10.0
lambda_B: 10.0
lambda_identity: 0.5
load_size: 286
lr: 0.0002
lr_decay_iters: 50
lr_policy: linear
max_dataset_size: inf
model: train
models_dir: checkpoints
n_epochs: 100
n_epochs_decay: 100
n_layers_D: 3
name: test_impressionism
ndf: 64
netD: basic
netG: resnet_9blocks
ngf: 64
no_dropout: False
no_flip: True
norm: instance
num_threads: 1
output_nc: 3
phase: train
pool_size: 50
preprocess: resize_and_crop
print_freq: 1
save_epoch_freq: 1
save_no: 1
serial_batches: False
suffix: 
verbose:

Setting up a new session...


initialize network with normal
initialize network with normal
initialize network with normal
The number of training images = 2




learning rate 0.0002000 -> 0.0002000
(epoch: 1, iters: 1, time: 23.190, data: 2.694) D_A: 2.138 G_A: 2.855 cycle_A: 6.121 idt_A: 3.341 D_B: 1.587 G_B: 2.206 cycle_B: 6.690 idt_B: 3.044 
(epoch: 1, iters: 2, time: 0.425, data: 0.001) D_A: 8.370 G_A: 1.812 cycle_A: 6.831 idt_A: 3.171 D_B: 1.962 G_B: 1.305 cycle_B: 6.468 idt_B: 3.339 
saving the model at the end of epoch 1, iters 2
End of epoch 1 / 200 	 Time Taken: 27 sec
learning rate 0.0002000 -> 0.0002000
(epoch: 2, iters: 1, time: 0.534, data: 2.948) D_A: 1.566 G_A: 1.521 cycle_A: 5.185 idt_A: 3.425 D_B: 6.210 G_B: 1.656 cycle_B: 6.836 idt_B: 2.543 
(epoch: 2, iters: 2, time: 0.405, data: 0.000) D_A: 1.499 G_A: 1.278 cycle_A: 6.106 idt_A: 3.002 D_B: 1.959 G_B: 1.313 cycle_B: 6.232 idt_B: 2.982 
saving the model at the end of epoch 2, iters 4
End of epoch 2 / 200 	 Time Taken: 4 sec
learning rate 0.0002000 -> 0.0002000
(epoch: 3, iters: 1, time: 0.362, data: 3.263) D_A: 1.602 G_A: 1.649 cycle_A: 4.911 idt_A: 2.934 D_B: 6.584 G_B: 1.52

KeyboardInterrupt: 