In [16]:
import sys, os, cv2, torch, re
sys.path.append(f"{os.getcwd()}")
from PIL import Image
from torchvision.transforms import transforms
from style_transfer.networks import CycleGAN, CycleGANConfig
from CycleGAN.data.unaligned_dataset import UnalignedDataset

In [17]:
def train():
    """Run CycleGAN training with this notebook's overrides applied to the YAML config.

    The base options come from ``style_transfer/config/cyclegan_train.yaml``
    (phase ``"train"``). The config is unfrozen, the overrides listed below are
    written onto it, and it is frozen again and printed. An ``UnalignedDataset``
    built from that config is wrapped in a ``DataLoader`` and handed to
    ``CycleGAN.train``.

    Takes no arguments and returns nothing.
    """
    cfg = CycleGANConfig.create("style_transfer/config/cyclegan_train.yaml", phase="train")

    # Notebook-specific settings, applied in this order on top of the YAML values.
    overrides = {
        "checkpoints_dir": "../../../Models/CycleGAN/checkpoints",
        "dataroot": "../../../Datasets/custom/BaroqueStyleTrainingSmall",
        "name": "test_baroque",
        "display_freq": 1,
        "print_freq": 1,
        "batch_size": 1,
        "num_threads": 1,
        "serial_batches": False,
        "save_no": 1,
        "save_epoch_freq": 1,
        "continue_train": True,
        "display_server": "http://localhost",
        "display_env": "test_baroque",
    }

    # The config only accepts writes between defrost() and freeze().
    cfg.defrost()
    for option, value in overrides.items():
        setattr(cfg, option, value)
    cfg.freeze()
    print(cfg)

    # Shuffling is simply the inverse of the serial_batches option.
    loader = torch.utils.data.DataLoader(
        UnalignedDataset(cfg),
        batch_size=cfg.batch_size,
        shuffle=not cfg.serial_batches,
        num_workers=int(cfg.num_threads),
    )

    model = CycleGAN(cfg)
    model.train(loader)


In [30]:
def transform(path, direction="AtoB"):
    """Run a single image through one of the CycleGAN generators and save the result.

    The image at ``path`` is loaded as RGB, converted to a tensor, resized with
    ``Resize(256)``, normalised with mean/std 0.5 per channel, and moved to
    ``"cuda"`` (so a CUDA device is required). The output is written to
    ``<config.results_dir>/test_cyclegan.png``; every call overwrites that file.

    Args:
        path: Path of the input image file.
        direction: ``"AtoB"`` calls ``network.transformFromArtisticToPhotographic``;
            ``"BtoA"`` calls ``network.transformFromPhotographicToArtistic``.

    Raises:
        ValueError: If ``direction`` is neither ``"AtoB"`` nor ``"BtoA"``.
        OSError: If OpenCV reports that the result image could not be written.
    """
    # Fail before any model loading: an unknown direction used to fall through
    # both branches and save the untouched input as if it were a result.
    if direction not in ("AtoB", "BtoA"):
        raise ValueError(f"direction must be 'AtoB' or 'BtoA', got {direction!r}")

    config = CycleGANConfig.create("style_transfer/config/cyclegan_test.yaml", phase="test")

    config.defrost()
    config.results_dir = "../.output/results"   # saves results here.
    config.num_threads = 0   # test code only supports num_threads = 0
    config.batch_size = 1    # test code only supports batch_size = 1
    config.serial_batches = True  # disable data shuffling; comment this line if results on randomly chosen images are needed.
    config.no_flip = True    # no flip; comment this line if results on flipped images are needed.
    config.display_id = -1   # no visdom display
    config.freeze()

    network = CycleGAN(config)
    network.loadModel({
        "G_A": "../../../Models/CycleGAN/baroque/latest_net_G_A.pth",
        "G_B": "../../../Models/CycleGAN/baroque/latest_net_G_B.pth"
    })

    with open(path, 'rb') as file:
        # Image.convert returns a NEW image, so its result must be kept; this
        # guarantees three channels for the 3-channel Normalize below. It also
        # forces the pixel data to be read while the file is still open.
        image = Image.open(file).convert("RGB")

    image = transforms.ToTensor()(image)
    image = transforms.Resize(256)(image)
    image = transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))(image)
    image = torch.stack((image,))  # add the batch dimension: (C, H, W) -> (1, C, H, W)
    image = image.to("cuda")

    if direction == "AtoB":
        image = network.transformFromArtisticToPhotographic(image)
    else:
        image = network.transformFromPhotographicToArtistic(image)

    # Invert the Normalize above (x * 0.5 + 0.5) on the first batch element.
    # Assumes the generator output uses the same [-1, 1] convention - TODO confirm.
    image = image[0] * 0.5 + 0.5
    image = image.detach().cpu().numpy()
    image = image.transpose(1, 2, 0)  # channels-first -> channels-last for OpenCV
    image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)

    # cv2.imwrite does not create missing directories and signals failure only
    # through its return value, so create the folder and check the result.
    os.makedirs(config.results_dir, exist_ok=True)
    output_path = os.path.join(config.results_dir, "test_cyclegan.png")
    if not cv2.imwrite(output_path, image * 255.0):
        raise OSError(f"could not write result image to {output_path}")

In [19]:
def convertFromTrainToTestModel(loadPath, savePath):
    """Renumber ``conv_block`` entries of a saved state dict and save the result.

    Every key of the form ``model.<n>.conv_block.<i>.<rest>`` with ``i`` in 4..7
    is rewritten: entries with ``i == 4`` are dropped and entries with
    ``i`` in 5..7 are renamed to ``i - 1``. The same renumbering is applied to
    the keys of ``state_dict._metadata`` (form ``model.<n>.conv_block.<i>``)
    when that attribute exists. All other entries are kept unchanged.

    Presumably index 4 is a layer that exists only in the training-time
    network (e.g. dropout) - confirm against the generator definition.

    The rename is done in a single pass over a snapshot of the entries, so the
    result does not depend on the order in which keys appear. The conversion is
    not idempotent: running it on an already converted file shifts the indices
    again.

    Args:
        loadPath: Path of the checkpoint to read with ``torch.load``.
        savePath: Path the converted state dict is written to with ``torch.save``.
    """

    def _renumber(name, pattern):
        """Return the converted name, or None when the entry must be dropped."""
        match = pattern.fullmatch(name)
        if match is None:
            return name
        block_idx = int(match.group(1))
        if block_idx == 4:
            return None
        return f"{name[:match.start(1)]}{block_idx - 1}{name[match.end(1):]}"

    def _rewrite_in_place(mapping, pattern):
        """Apply _renumber to every key of mapping, keeping the same object."""
        # Compute every new (key, value) pair from the ORIGINAL contents first,
        # so a shifted entry can never be read back as if it were an original.
        renamed = []
        for name, value in mapping.items():
            new_name = _renumber(name, pattern)
            if new_name is not None:
                renamed.append((new_name, value))
        # clear() + update() keeps the mapping's type and instance attributes
        # (such as _metadata on the state dict) intact.
        mapping.clear()
        mapping.update(renamed)

    state_dict = torch.load(loadPath)

    metadata = getattr(state_dict, "_metadata", None)
    if metadata is not None:
        _rewrite_in_place(metadata, re.compile(r"model\.[0-9]+\.conv_block\.([4-7])"))

    _rewrite_in_place(state_dict, re.compile(r"model\.[0-9]+\.conv_block\.([4-7])\..+"))

    torch.save(state_dict, savePath)

In [31]:
# Driver cell: exactly one step is active at a time; the others stay commented out.
#   - train() runs training,
#   - transform(...) converts the sample painting below,
#   - convertFromTrainToTestModel(...) rewrites the two trained generator checkpoints.
# train()
transform(
    path="../../../Datasets/custom/BaroqueStyleTraining/trainA/15_123_rembrandt_the-young-rembrandt-as-democritus-the-laughing-philosopher-1629.jpg",
    direction="AtoB",
)
# convertFromTrainToTestModel("../../../Models/CycleGAN/baroque/train/latest_net_G_A.pth", "../../../Models/CycleGAN/baroque/latest_net_G_A.pth")
# convertFromTrainToTestModel("../../../Models/CycleGAN/baroque/train/latest_net_G_B.pth", "../../../Models/CycleGAN/baroque/latest_net_G_B.pth")

initialize network with normal
initialize network with normal
loading the model from ../../../Models/CycleGAN/baroque/latest_net_G_A.pth
loading the model from ../../../Models/CycleGAN/baroque/latest_net_G_B.pth


