In [11]:
import sys, os, cv2, torch, re
sys.path.append(f"{os.getcwd()}")
from PIL import Image
from torchvision.transforms import transforms
from style_transfer.networks import CycleGAN, CycleGANConfig
from CycleGAN.data.unaligned_dataset import UnalignedDataset

In [12]:
def train():
    """Train a CycleGAN on the small Impressionism style dataset.

    Loads ``style_transfer/config/cyclegan_train.yaml`` for the "train" phase,
    applies the notebook-specific overrides listed below, prints the final
    config, then trains ``CycleGAN`` on an ``UnalignedDataset`` loader.
    """
    # Notebook-specific settings layered on top of the YAML defaults.
    overrides = {
        "checkpoints_dir": "../../../Models/CycleGANtrain/checkpoints",
        "dataroot": "../../../Datasets/custom/ImpressionismStyleTrainingSmall",
        "name": "test_impressionism",
        "display_freq": 10,
        "print_freq": 10,
        "batch_size": 1,
        "num_threads": 1,
        "serial_batches": False,
        "save_no": 1,
        "save_epoch_freq": 1,
        "n_epochs": 100,
        "n_epochs_decay": 100,
        "continue_train": True,  # resume from checkpoints already in checkpoints_dir
        "display_server": "http://localhost",
        "display_env": "test_impressionism",
    }

    cfg = CycleGANConfig.create("style_transfer/config/cyclegan_train.yaml", phase="train")
    cfg.defrost()
    for option, value in overrides.items():
        setattr(cfg, option, value)
    cfg.freeze()
    print(cfg)

    train_set = UnalignedDataset(cfg)
    loader_options = dict(
        batch_size=cfg.batch_size,
        shuffle=not cfg.serial_batches,  # serial_batches=True keeps file order
        num_workers=int(cfg.num_threads),
        pin_memory=True,
    )
    loader = torch.utils.data.DataLoader(train_set, **loader_options)

    model = CycleGAN(cfg)
    model.train(loader)


In [13]:
def transform(path, direction="AtoB", model_dir="../../../Models/CycleGAN/baroque",
              output_name="test_cyclegan.png"):
    """Translate one image with a pretrained CycleGAN generator pair and save it.

    Args:
        path: input image file. It is converted to 3-channel RGB before use.
        direction: ``"AtoB"`` calls ``network.artisticToPhotographic``;
            ``"BtoA"`` calls ``network.photographicToArtistic``.
        model_dir: directory holding ``latest_net_G_A.pth`` and
            ``latest_net_G_B.pth`` (defaults to the baroque models).
        output_name: file name of the result inside ``config.results_dir``.

    Returns:
        The path the translated image was written to.

    Raises:
        ValueError: if ``direction`` is neither ``"AtoB"`` nor ``"BtoA"``.
        IOError: if OpenCV could not write the output file.
    """
    if direction not in ("AtoB", "BtoA"):
        raise ValueError(f"direction must be 'AtoB' or 'BtoA', got {direction!r}")

    config = CycleGANConfig.create("style_transfer/config/cyclegan_test.yaml", phase="test")

    config.defrost()
    config.results_dir = "../.output/results"   # saves results here.
    config.num_threads = 0   # test code only supports num_threads = 0
    config.batch_size = 1    # test code only supports batch_size = 1
    config.serial_batches = True  # disable data shuffling; comment this line if results on randomly chosen images are needed.
    config.no_flip = True    # no flip; comment this line if results on flipped images are needed.
    config.display_id = -1   # no visdom display
    config.freeze()

    network = CycleGAN(config)
    network.loadModel({
        "G_A": f"{model_dir}/latest_net_G_A.pth",
        "G_B": f"{model_dir}/latest_net_G_B.pth",
    })

    with open(path, "rb") as file:
        # convert() returns a NEW image (it does not modify in place); calling it
        # here also forces PIL to read the pixel data while the file is still open.
        image = Image.open(file).convert("RGB")

    tensor = transforms.ToTensor()(image)
    # Optional: transforms.Resize(256)(tensor) to downscale large inputs.
    tensor = transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))(tensor)
    batch = tensor.unsqueeze(0).to("cuda")  # add a batch dimension of size 1

    with torch.no_grad():  # inference only: no autograd graph is needed
        if direction == "AtoB":
            output = network.artisticToPhotographic(batch)
        else:
            output = network.photographicToArtistic(batch)

    # Undo the [-1, 1] normalization (assumes the generator output uses the same
    # range as its input — TODO confirm), clamp any overshoot, quantize to 8-bit HWC.
    rgb = (output[0] * 0.5 + 0.5).clamp(0.0, 1.0)
    rgb_u8 = (rgb * 255.0).round().to(torch.uint8).permute(1, 2, 0).contiguous().cpu().numpy()
    bgr = cv2.cvtColor(rgb_u8, cv2.COLOR_RGB2BGR)  # OpenCV expects BGR channel order

    # cv2.imwrite returns False instead of raising when the target dir is missing.
    os.makedirs(config.results_dir, exist_ok=True)
    out_path = os.path.join(config.results_dir, output_name)
    if not cv2.imwrite(out_path, bgr):
        raise IOError(f"cv2.imwrite failed to write {out_path}")
    return out_path

In [14]:
def convertFromTrainToTestModel(loadPath, savePath):
    """Re-index a generator checkpoint so it loads into the test-time network layout.

    Inside every ``model.<i>.conv_block``, the layer at index 4 is removed and
    indices 5-7 move down by one (5->4, 6->5, 7->6). This applies to
    parameter/buffer keys (``model.<i>.conv_block.<k>.<name>``) and, when present,
    to the ``_metadata`` entries (``model.<i>.conv_block.<k>``). All other entries
    are copied unchanged and in their original order.

    Presumably index 4 is the Dropout layer of the training-time residual block
    and the test model is built without it — TODO confirm against the generator
    definition.

    Args:
        loadPath: path of the state-dict checkpoint to read.
        savePath: path the converted state dict is written to with ``torch.save``.
    """
    from collections import OrderedDict

    param_pattern = re.compile(r"model\.[0-9]+\.conv_block\.([4-7])\..+")
    meta_pattern = re.compile(r"model\.[0-9]+\.conv_block\.([4-7])")

    def _reindex(entries, pattern):
        """Return a new OrderedDict with the matched conv_block indices shifted down."""
        result = OrderedDict()
        for name, value in entries.items():
            match = pattern.fullmatch(name)
            if match is None:
                result[name] = value
                continue
            block_idx = int(match.group(1))
            if block_idx == 4:
                continue  # the removed layer: drop its entries
            # Writing into a fresh dict (rather than renaming in place) makes the
            # result independent of the order in which old/new names are visited.
            result[f"{name[:match.start(1)]}{block_idx - 1}{name[match.end(1):]}"] = value
        return result

    # map_location="cpu" lets checkpoints saved from a GPU be converted on a CPU-only machine.
    state_dict = torch.load(loadPath, map_location="cpu")
    converted = _reindex(state_dict, param_pattern)
    if hasattr(state_dict, "_metadata"):
        converted._metadata = _reindex(state_dict._metadata, meta_pattern)
    torch.save(converted, savePath)

In [15]:
# Choose what this cell runs; by default only training is executed.
RUN_TRAIN = True
RUN_TRANSFORM = False
RUN_CONVERT = False

# (style, checkpoint epoch) of trained generators to convert into test-time models.
MODELS_ROOT = "../../../Models/CycleGAN"
CONVERT_JOBS = [("baroque", 2000), ("impressionism", 750), ("renaissance", 500)]

if RUN_TRAIN:
    train()

if RUN_TRANSFORM:
    transform("../../../Datasets/coco/images/train2017/000000000241.jpg", "AtoB")

if RUN_CONVERT:
    for style, epoch in CONVERT_JOBS:
        for net in ("G_A", "G_B"):
            convertFromTrainToTestModel(
                f"{MODELS_ROOT}/{style}/train/{epoch}_net_{net}.pth",
                f"{MODELS_ROOT}/{style}/latest_net_{net}.pth",
            )

batch_size: 1
beta1: 0.5
checkpoints_dir: ../../../Models/CycleGANtrain/checkpoints
continue_train: True
crop_size: 256
dataroot: ../../../Datasets/custom/ImpressionismStyleTrainingSmall
dataset_mode: unaligned
direction: AtoB
display_env: test_impressionism
display_freq: 10
display_id: 1
display_ncols: 4
display_port: 8097
display_server: http://localhost
display_winsize: 256
epoch: latest
epoch_count: 1
gan_mode: lsgan
gpu_ids: [0]
init_gain: 0.02
init_type: normal
input_nc: 3
isTrain: True
lambda_A: 10.0
lambda_B: 10.0
lambda_identity: 0.5
load_size: 286
lr: 0.0002
lr_decay_iters: 50
lr_policy: linear
max_dataset_size: inf
model: train
models_dir: checkpoints
n_epochs: 100
n_epochs_decay: 100
n_layers_D: 3
name: test_impressionism
ndf: 64
netD: basic
netG: resnet_9blocks
ngf: 64
no_dropout: False
no_flip: True
norm: instance
num_threads: 1
output_nc: 3
phase: train
pool_size: 50
preprocess: resize_and_crop
print_freq: 10
save_epoch_freq: 1
save_no: 1
serial_batches: False
suffix: 
v

Setting up a new session...


learning rate 0.0002000 -> 0.0002000
saving the model at the end of epoch 7, iters 2
End of epoch 7 / 200 	 Time Taken: 21 sec
learning rate 0.0002000 -> 0.0002000
saving the model at the end of epoch 8, iters 4
End of epoch 8 / 200 	 Time Taken: 4 sec
learning rate 0.0002000 -> 0.0002000
saving the model at the end of epoch 9, iters 6
End of epoch 9 / 200 	 Time Taken: 3 sec
learning rate 0.0002000 -> 0.0002000
saving the model at the end of epoch 10, iters 8
End of epoch 10 / 200 	 Time Taken: 4 sec
learning rate 0.0002000 -> 0.0002000
results
loses
(epoch: 11, iters: 2, time: 0.641, data: 2.345) D_A: 0.637 G_A: 0.922 cycle_A: 5.063 idt_A: 1.405 D_B: 0.646 G_B: 0.970 cycle_B: 3.070 idt_B: 2.422 
save
saving the model at the end of epoch 11, iters 10
End of epoch 11 / 200 	 Time Taken: 3 sec
learning rate 0.0002000 -> 0.0002000
saving the model at the end of epoch 12, iters 12
End of epoch 12 / 200 	 Time Taken: 3 sec
learning rate 0.0002000 -> 0.0002000
saving the model at the end of

KeyboardInterrupt: 

[WinError 10054] De externe host heeft een verbinding verbroken
on_close() takes 1 positional argument but 3 were given
[WinError 10061] Kan geen verbinding maken omdat de doelcomputer de verbinding actief heeft geweigerd
on_close() takes 1 positional argument but 3 were given
[WinError 10061] Kan geen verbinding maken omdat de doelcomputer de verbinding actief heeft geweigerd
on_close() takes 1 positional argument but 3 were given
[WinError 10061] Kan geen verbinding maken omdat de doelcomputer de verbinding actief heeft geweigerd
on_close() takes 1 positional argument but 3 were given
[WinError 10061] Kan geen verbinding maken omdat de doelcomputer de verbinding actief heeft geweigerd
on_close() takes 1 positional argument but 3 were given
[WinError 10061] Kan geen verbinding maken omdat de doelcomputer de verbinding actief heeft geweigerd
on_close() takes 1 positional argument but 3 were given
[WinError 10061] Kan geen verbinding maken omdat de doelcomputer de verbinding actief heef