My reimplementation of training #81

Closed · ALLinLLM opened this issue Dec 10, 2020 · 13 comments

@ALLinLLM commented Dec 10, 2020

Following the journal edition of the paper, I wrote the following PyTorch (pseudo)code for training:

    # build the 2 VAE networks and 3 discriminators (but NO transfer network for now)
    # and their optimizers
    vae1, xr_recon_d, z_xr_d, \
        vae2, y_recon_d, \
        optimizer_vae1, optimizer_d1, \
        optimizer_vae2, optimizer_d2 = build_model(opt)
    start_iter = 0
    if opt.load_checkpoint_iter>0:
        checkpoint_path = checkpoint_root + f'/global_checkpoint_{opt.load_checkpoint_iter}.pth'
        if not Path(checkpoint_path).exists():
            print(f"ERROR! checkpoint_path {checkpoint_path} is None")
            exit(-1)
        state_dict = torch.load(checkpoint_path)
        start_iter = state_dict['iter']
        assert state_dict['batch_size'] == opt.batch_size, f"ERROR - batch size changed! load: {state_dict['batch_size']}, but now {opt.batch_size}"
        vae1.load_state_dict(state_dict['vae1'])
        xr_recon_d.load_state_dict(state_dict['xr_recon_d'])
        z_xr_d.load_state_dict(state_dict['z_xr_d'])
        vae2.load_state_dict(state_dict['vae2'])
        y_recon_d.load_state_dict(state_dict['y_recon_d'])
        optimizer_vae1.load_state_dict(state_dict['optimizer_vae1'])
        optimizer_d1.load_state_dict(state_dict['optimizer_d1'])
        optimizer_vae2.load_state_dict(state_dict['optimizer_vae2']) 
        optimizer_d2.load_state_dict(state_dict['optimizer_d2']) 
        print("checkpoint load successfully!")
    # create dataloader
    dataLoaderR, dataLoaderXY = get_dataloader(opt)
    dataLoaderXY_iter = iter(dataLoaderXY)
    dataLoaderR_iter = iter(dataLoaderR)
    start = time.perf_counter()
    print("train start!")
    for ii in range(opt.total_iter - start_iter):
        current_iter = ii + start_iter
        try:
            x, y, path_y = next(dataLoaderXY_iter)
        except StopIteration:
            dataLoaderXY_iter = iter(dataLoaderXY)
            x, y, path_y = next(dataLoaderXY_iter)
        try:
            r, path_r = next(dataLoaderR_iter)
        except StopIteration:
            dataLoaderR_iter = iter(dataLoaderR)
            r, path_r = next(dataLoaderR_iter)
        ### following the practice in U-GAT-IT:
        ### train D and G alternately, without training D more steps than G
        r = r.to(opt.device)
        x = x.to(opt.device)
        y = y.to(opt.device)
        if opt.debug and current_iter%500==0:
            torchvision.utils.save_image(y, 'train_vae_y.png', normalize=True)
            torchvision.utils.save_image(x, 'train_vae_x.png', normalize=True)
            torchvision.utils.save_image(r, 'train_vae_r.png', normalize=True)

        ### vae1 train d
        # save GPU memory: no grads are needed for G while training D
        with torch.no_grad():
            z_x, mean_x, var_x, recon_x = vae1(x)
            z_r, mean_r, var_r, recon_r = vae1(r)
            batch_requires_grad(z_x, mean_x, var_x, recon_x,
                                z_r, mean_r, var_r, recon_r)
        adv_loss_d_x = lsgan_d(xr_recon_d(x), xr_recon_d(recon_x))
        adv_loss_d_r = lsgan_d(xr_recon_d(r), xr_recon_d(recon_r))
        # z_x is treated as "real" and z_r as "fake", to push z_r toward z_x
        adv_loss_d_xr = lsgan_d(z_xr_d(z_x), z_xr_d(z_r))
        loss_1_d = adv_loss_d_x + adv_loss_d_r + adv_loss_d_xr
        loss_1_d.backward()
        optimizer_d1.step()
        optimizer_d1.zero_grad()
        ### vae1 train g
        # since we update G's weights, recompute the forward pass with grad enabled
        z_x, mean_x, var_x, recon_x = vae1(x)
        z_r, mean_r, var_r, recon_r = vae1(r)
        adv_loss_g_x = lsgan_g(xr_recon_d(recon_x))
        adv_loss_g_r = lsgan_g(xr_recon_d(recon_r))
        # z_x is treated as "real" and z_r as "fake", to push z_r toward z_x
        adv_loss_g_xr = lsgan_g(z_xr_d(z_r))
        # note: var_x / var_r are log-variances (see reparameterize), hence the .exp()
        KLDloss_1_x = -0.5 * torch.sum(1 + var_x - mean_x.pow(2) - var_x.exp())  # KLD
        L1loss_1_x  = opt.weight_alpha * F.l1_loss(x, recon_x)
        KLDloss_1_r = -0.5 * torch.sum(1 + var_r - mean_r.pow(2) - var_r.exp())  # KLD
        L1loss_1_r  = opt.weight_alpha * F.l1_loss(r, recon_r)
        loss_1_g = adv_loss_g_x + KLDloss_1_x + L1loss_1_x \
                 + adv_loss_g_r + KLDloss_1_r + L1loss_1_r \
                 + adv_loss_g_xr
        loss_1_g.backward()
        optimizer_vae1.step()
        optimizer_vae1.zero_grad()
        # the G loss above also left grads in xr_recon_d and z_xr_d;
        # clear them so they don't leak into the next D step
        optimizer_d1.zero_grad()

        ### vae2 train d
        # save GPU memory: no grads are needed for G while training D
        with torch.no_grad():
            z_y, mean_y, var_y, recon_y = vae2(y)
            batch_requires_grad(z_y, mean_y, var_y, recon_y)
        adv_loss_d_y = lsgan_d(y_recon_d(y), y_recon_d(recon_y))
        loss_2_d = adv_loss_d_y
        loss_2_d.backward()
        optimizer_d2.step()
        optimizer_d2.zero_grad()
        ### vae2 train g
        # since we update G's weights, recompute the forward pass with grad enabled
        z_y, mean_y, var_y, recon_y = vae2(y)
        adv_loss_g_y = lsgan_g(y_recon_d(recon_y))
        KLDloss_1_y = -0.5 * torch.sum(1 + var_y - mean_y.pow(2) - var_y.exp())  # KLD
        L1loss_1_y  = opt.weight_alpha * F.l1_loss(y, recon_y)
        loss_2_g = adv_loss_g_y + KLDloss_1_y + L1loss_1_y
        loss_2_g.backward()
        optimizer_vae2.step()
        optimizer_vae2.zero_grad()
        optimizer_d2.zero_grad()  # clear the grads the G loss left in y_recon_d
        # debug
        if opt.debug and current_iter%500==0:
            # [print(k, 'channel 0:\n', v[0][0]) for k,v in list(model.named_parameters()) if k in ["netG_A.encoder.13.conv_block.5.weight", "netG_A.decoder.4.conv_block.5.weight"]]
            torchvision.utils.save_image(recon_x, 'train_vae_recon_x.png', normalize=True)
            torchvision.utils.save_image(recon_r, 'train_vae_recon_r.png', normalize=True)
            torchvision.utils.save_image(recon_y, 'train_vae_recon_y.png', normalize=True)
       
        if current_iter%500==0:
            print(f"""STEP {current_iter:06d} {time.perf_counter() - start:.1f} s
            loss_1_d = adv_loss_d_x + adv_loss_d_r + adv_loss_d_xr
            {loss_1_d:.3f} = {adv_loss_d_x:.3f} + {adv_loss_d_r:.3f} + {adv_loss_d_xr:.3f}
            loss_1_g = adv_loss_g_x + KLDloss_1_x + L1loss_1_x \
                 + adv_loss_g_r + KLDloss_1_r + L1loss_1_r \
                 + adv_loss_g_xr
            {loss_1_g:.3f} = {adv_loss_g_x:.3f} + {KLDloss_1_x:.3f} + {L1loss_1_x:.3f} \
                 + {adv_loss_g_r:.3f} + {KLDloss_1_r:.3f} + {L1loss_1_r:.3f} \
                 + {adv_loss_g_xr:.3f}
            """)
        if (current_iter+1)%2000==0:
            # finished the current_iter-th step: e.g. after finishing iter 0, save as 1 and resume from iter 1
            state = {
                'iter': current_iter + 1,
                'batch_size': opt.batch_size,
                #
                'vae1': vae1.state_dict(),
                'xr_recon_d': xr_recon_d.state_dict(),
                'z_xr_d': z_xr_d.state_dict(),
                #
                'vae2': vae2.state_dict(),
                'y_recon_d': y_recon_d.state_dict(),
                #
                'optimizer_vae1': optimizer_vae1.state_dict(),
                'optimizer_d1': optimizer_d1.state_dict(),
                'optimizer_vae2': optimizer_vae2.state_dict(),
                'optimizer_d2': optimizer_d2.state_dict(),
                }
            torch.save(state, checkpoint_root + f'/global_checkpoint_{current_iter+1}.pth')
    print("global", time.perf_counter() - start, ' s')

The lsgan_d and lsgan_g used above are defined as follows:

import torch
import torch.nn.functional as F
### LSGAN losses with a = 0 (fake target) and b = c = 1 (real target)
def lsgan_d(d_logit_real, d_logit_fake):
    # push real logits toward 1 and fake logits toward 0
    return F.mse_loss(d_logit_real, torch.ones_like(d_logit_real)) + d_logit_fake.pow(2).mean()

def lsgan_g(d_logit_fake):
    # push the discriminator's logits on fakes toward 1
    return F.mse_loss(d_logit_fake, torch.ones_like(d_logit_fake))
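
As a quick sanity check (my own snippet, using the lsgan_d / lsgan_g above): a perfect discriminator drives both losses to zero.

    import torch

    # 30x30 "patch" logits, as a PatchGAN-style discriminator would produce
    real_logits = torch.ones(4, 1, 30, 30)
    fake_logits = torch.zeros(4, 1, 30, 30)
    print(lsgan_d(real_logits, fake_logits))  # tensor(0.) -- D is perfect
    print(lsgan_g(real_logits))               # tensor(0.) -- G fully fools D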
@ALLinLLM (Author) commented:

And here is build_model(); you can find all the related networks in Global/models/networks:

def build_model(opt):
    """ stage 1.1  train 2 vae """
    # TODO stage 1.2 train mapping network
    print("build 2 vae and a transfer network")
    model = Pix2PixHDModel_Mapping()
    model.initialize(opt)  # note: `model` itself is not used again below

    ##### define networks
    print("build vae1 and vae2 ...")
    vae1 = networks.GlobalGenerator_DCDCv2(
        opt.input_nc,
        opt.output_nc,
        opt.ngf,
        opt.k_size,
        opt.n_downsample_global,
        networks.get_norm_layer(norm_type=opt.norm),
        opt=opt,
    )
    vae2 = networks.GlobalGenerator_DCDCv2(
        opt.input_nc,
        opt.output_nc,
        opt.ngf,
        opt.k_size,
        opt.n_downsample_global,
        networks.get_norm_layer(norm_type=opt.norm),
        opt=opt,
    )
    vae1.apply(networks.weights_init)
    vae2.apply(networks.weights_init)
    # move the VAEs to the training device, like the discriminators below
    vae1.to(opt.device)
    vae2.to(opt.device)
    print("build vae1 and vae2 finish!")
    print("build D ...")
    xr_recon_d = Z_xr_Discriminator(input_nc=3, ndf=opt.disc_ch, n_layers=opt.disc_layers).to(opt.device)
    z_xr_d = Z_xr_Discriminator(input_nc=opt.feat_dim, ndf=opt.disc_ch, n_layers=opt.disc_layers).to(opt.device)
    y_recon_d = Z_xr_Discriminator(input_nc=3, ndf=opt.disc_ch, n_layers=opt.disc_layers).to(opt.device)
    print("build D finish")
    """ Optim """
    optimizer_vae1 = torch.optim.Adam(vae1.parameters(), 
        lr=opt.lr, betas=(0.5, 0.999), weight_decay=0, eps=1e-6)
    optimizer_d1 = torch.optim.Adam(itertools.chain(
        xr_recon_d.parameters(), 
        z_xr_d.parameters()),
        lr=opt.lr, betas=(0.5, 0.999), weight_decay=0, eps=1e-6)
    optimizer_vae2 = torch.optim.Adam(vae2.parameters(), 
        lr=opt.lr, betas=(0.5, 0.999), weight_decay=0, eps=1e-6)
    optimizer_d2 = torch.optim.Adam(y_recon_d.parameters(), 
        lr=opt.lr, betas=(0.5, 0.999), weight_decay=0, eps=1e-6)
    return vae1, xr_recon_d, z_xr_d, vae2, y_recon_d, \
        optimizer_vae1, optimizer_d1, optimizer_vae2, optimizer_d2

@ALLinLLM (Author) commented:

And the Z_xr_Discriminator:

import torch
import torch.nn as nn


class Z_xr_Discriminator(nn.Module):
    def __init__(self, input_nc, ndf=64, n_layers=5):
        super(Z_xr_Discriminator, self).__init__()
        model = [nn.ReflectionPad2d(1),
                 nn.utils.spectral_norm(
                 nn.Conv2d(input_nc, ndf, kernel_size=4, stride=2, padding=0, bias=True)),
                 nn.LeakyReLU(0.2, True)]

        for i in range(1, n_layers - 2):
            mult = 2 ** (i - 1)
            model += [nn.ReflectionPad2d(1),
                      nn.utils.spectral_norm(
                      nn.Conv2d(ndf * mult, ndf * mult * 2, kernel_size=4, stride=2, padding=0, bias=True)),
                      nn.LeakyReLU(0.2, True)]

        mult = 2 ** (n_layers - 2 - 1)
        model += [nn.ReflectionPad2d(1),
                  nn.utils.spectral_norm(
                  nn.Conv2d(ndf * mult, ndf * mult * 2, kernel_size=4, stride=1, padding=0, bias=True)),
                  nn.LeakyReLU(0.2, True)]

        # Class Activation Map
        mult = 2 ** (n_layers - 2)
        self.gap_fc = nn.utils.spectral_norm(nn.Linear(ndf * mult, 1, bias=False))
        self.gmp_fc = nn.utils.spectral_norm(nn.Linear(ndf * mult, 1, bias=False))
        self.conv1x1 = nn.Conv2d(ndf * mult * 2, ndf * mult, kernel_size=1, stride=1, bias=True)
        self.leaky_relu = nn.LeakyReLU(0.2, True)

        self.pad = nn.ReflectionPad2d(1)
        self.conv = nn.utils.spectral_norm(
            nn.Conv2d(ndf * mult, 1, kernel_size=4, stride=1, padding=0, bias=False))

        self.model = nn.Sequential(*model)

    def forward(self, input, need_each_activation=False):
        each_activations = []
        if need_each_activation:
            x = input
            for i in range(len(self.model)):
                x = self.model[i](x)
                if isinstance(self.model[i], torch.nn.modules.activation.LeakyReLU):
                    each_activations.append(x)
        else:
            x = self.model(input)

        # CAM-style attention from U-GAT-IT: re-weight the feature map by the
        # gap_fc / gmp_fc weights. The pooled tensors are immediately
        # overwritten; they are only needed for the CAM logits, which are
        # commented out below.
        gap = torch.nn.functional.adaptive_avg_pool2d(x, 1)
        gap_weight = list(self.gap_fc.parameters())[0]
        gap = x * gap_weight.unsqueeze(2).unsqueeze(3)

        gmp = torch.nn.functional.adaptive_max_pool2d(x, 1)
        gmp_weight = list(self.gmp_fc.parameters())[0]
        gmp = x * gmp_weight.unsqueeze(2).unsqueeze(3)

        # gap_logit = self.gap_fc(gap.view(x.shape[0], -1))
        # gmp_logit = self.gmp_fc(gmp.view(x.shape[0], -1))
        # cam_logit = torch.cat([gap_logit, gmp_logit], 1)
        x = torch.cat([gap, gmp], 1)
        x = self.leaky_relu(self.conv1x1(x))
        # heatmap = torch.sum(x, dim=1, keepdim=True)
        x = self.pad(x)
        out = self.conv(x)
        if need_each_activation:
            return out, each_activations
        else:
            return out
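
A quick shape check (my own snippet, using the defaults above): the three stride-2 convolutions plus the valid-padding convs reduce a 256x256 input to a 30x30 map of per-patch logits.

    import torch

    d = Z_xr_Discriminator(input_nc=3, ndf=64, n_layers=5)
    out = d(torch.randn(2, 3, 256, 256))
    print(out.shape)  # torch.Size([2, 1, 30, 30]) -- one real/fake logit per patch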

@ALLinLLM (Author) commented Dec 10, 2020

And you need to append some code to class GlobalGenerator_DCDCv2 in Global/models/networks.py.

Note: self.mean_layer and self.var_layer here are just two fully connected layers.

    def reparameterize(self, mu, logvar):
        # sample z = mu + sigma * eps, with sigma = exp(0.5 * logvar)
        std = torch.exp(0.5*logvar)
        eps = torch.randn_like(std)
        return mu + eps*std

    def forward(self, input, flow="enc_dec"):
        if flow == "enc":
            h = self.encoder(input)
            if not self.eval_:
                mean = self.mean_layer(h)
                var = self.var_layer(h)  # note: "var" is actually the log-variance
                h = self.reparameterize(mean, var)
                return h, mean, var
            else:
                return h
        elif flow == "dec":
            return self.decoder(input)
        elif flow == "enc_dec":
            z_x = self.encoder(input)
            if not self.eval_:
                mean = self.mean_layer(z_x)
                var = self.var_layer(z_x)
                z_x = self.reparameterize(mean, var)
                x = self.decoder(z_x)
                return z_x, mean, var, x
            return self.decoder(z_x)
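
The definitions of self.mean_layer and self.var_layer are not shown in this thread. A minimal sketch of one possibility, assuming the encoder output is a (B, C, H, W) feature map and using 1x1 convolutions as per-position fully connected layers (hypothetical; "fully connected" could also mean an nn.Linear over flattened features):

    import torch.nn as nn

    feat_dim = 64  # hypothetical channel count of the encoder output

    # a 1x1 conv is a fully connected layer applied independently at every
    # spatial position, giving per-position mean and log-variance maps
    # without the huge weight matrix a flatten + nn.Linear would need
    mean_layer = nn.Conv2d(feat_dim, feat_dim, kernel_size=1)
    var_layer = nn.Conv2d(feat_dim, feat_dim, kernel_size=1)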

@ALLinLLM (Author) commented:

Here is my training loss; both the D and G losses clearly decrease, which suggests the training works:

[image: training loss curves]

@ALLinLLM (Author) commented Dec 10, 2020

Appending the training code for the mapping net (which transfers z_x to z_y):

    # build the mapping net and its optimizer
    mapping_net, vgg, optimizer_mapping = build_mapping_models(opt)
    # according to the paper, the VAEs are fixed while training the mapping in stage 1-2
    # build the 2 VAE networks, 3 discriminators and their optimizers
    # Notice: we reuse y_recon_d as the D for the adversarial loss of the mapping
    vae1, xr_recon_d, z_xr_d, \
        vae2, y_recon_d, \
        optimizer_vae1, optimizer_d1, \
        optimizer_vae2, optimizer_d2 = build_model(opt)
    start_iter = 0
    if opt.load_vae_checkpoint_iter>0:
        checkpoint_path = checkpoint_root + f'/global_checkpoint_{opt.load_vae_checkpoint_iter}.pth'
        if not Path(checkpoint_path).exists():
            print(f"ERROR! checkpoint_path {checkpoint_path} is None")
            exit(-1)
        state_dict = torch.load(checkpoint_path)
        # start_iter = state_dict['iter']
        # assert state_dict['batch_size'] == opt.batch_size, f"ERROR - batch size changed! load: {state_dict['batch_size']}, but now {opt.batch_size}"
        vae1.load_state_dict(state_dict['vae1'])
        xr_recon_d.load_state_dict(state_dict['xr_recon_d'])
        z_xr_d.load_state_dict(state_dict['z_xr_d'])
        vae2.load_state_dict(state_dict['vae2'])
        y_recon_d.load_state_dict(state_dict['y_recon_d'])
        optimizer_vae1.load_state_dict(state_dict['optimizer_vae1'])
        optimizer_d1.load_state_dict(state_dict['optimizer_d1'])
        optimizer_vae2.load_state_dict(state_dict['optimizer_vae2']) 
        optimizer_d2.load_state_dict(state_dict['optimizer_d2']) 
        print("checkpoint load successfully!")
    # create dataloader
    dataLoaderR, dataLoaderXY = get_dataloader(opt)
    dataLoaderXY_iter = iter(dataLoaderXY)
    dataLoaderR_iter = iter(dataLoaderR)
    start = time.perf_counter()
    print("train start!")
    for ii in range(opt.total_iter - start_iter):
        current_iter = ii + start_iter
        try:
            x, y, path_y = next(dataLoaderXY_iter)
        except StopIteration:
            dataLoaderXY_iter = iter(dataLoaderXY)
            x, y, path_y = next(dataLoaderXY_iter)
        try:
            r, path_r = next(dataLoaderR_iter)
        except StopIteration:
            dataLoaderR_iter = iter(dataLoaderR)
            r, path_r = next(dataLoaderR_iter)
        ### following the practice in U-GAT-IT:
        ### train D and G alternately, without training D more steps than G
        r = r.to(opt.device)
        x = x.to(opt.device)
        y = y.to(opt.device)
        if opt.debug and current_iter%500==0:
            torchvision.utils.save_image(y, 'train_mapping_y.png', normalize=True)
            torchvision.utils.save_image(x, 'train_mapping_x.png', normalize=True)
            torchvision.utils.save_image(r, 'train_mapping_r.png', normalize=True)
        ### each iteration: train the mapping first, then (after warmup) the VAEs
        ### train mapping
        #     train d
        # save GPU memory: no grads are needed for the frozen nets while training D
        with torch.no_grad():
            z_x, mean_x, var_x = vae1(x, flow="enc")
            z_x2y              = mapping_net(z_x)
            recon_x2y          = vae2(z_x2y, flow="dec")
        adv_loss_d = lsgan_d(y_recon_d(r), y_recon_d(recon_x2y))  # note: uses r as the "real" sample; y may be intended, since recon_x2y should look like a clean photo
        loss_map_d = adv_loss_d
        loss_map_d.backward()
        optimizer_d2.step()
        optimizer_d2.zero_grad()
        #     train g
        z_x, mean_x, var_x = vae1(x, flow="enc")
        z_x2y              = mapping_net(z_x)
        recon_x2y          = vae2(z_x2y, flow="dec")
        z_y, mean_y, var_y = vae2(y, flow="enc")
        mapL1_loss = F.l1_loss(z_x2y, z_y)
        perc_loss = calc_loss_perc(vgg(recon_x2y), vgg(y))
        adv_loss_g = lsgan_g(y_recon_d(recon_x2y))
        loss_map_g = opt.weight_lambda1 * mapL1_loss + opt.weight_lambda2 * perc_loss + adv_loss_g
        loss_map_g.backward()
        optimizer_mapping.step()
        optimizer_mapping.zero_grad()
        # backward() above also accumulated grads in vae1, vae2 and y_recon_d;
        # clear them so later optimizer steps don't apply stale gradients
        optimizer_vae1.zero_grad()
        optimizer_vae2.zero_grad()
        optimizer_d2.zero_grad()
        # print mapping loss
        if current_iter%500==0:
            print(f"""STEP {current_iter:06d} {time.perf_counter() - start:.1f} s
            loss_map_d = adv_loss_d
            >>> {loss_map_d:.3f} = {adv_loss_d:.3f}
            loss_map_g = opt.weight_lambda1 * mapL1_loss + opt.weight_lambda2 * perc_loss + adv_loss_g
            >>> {loss_map_g:.3f} = {opt.weight_lambda1:.0f}*{mapL1_loss:.3f} + {opt.weight_lambda2:.0f}*{perc_loss:.3f} + {adv_loss_g:.3f}
            """)
            if opt.debug:
                # [print(k, 'channel 0:\n', v[0][0]) for k,v in list(model.named_parameters()) if k in ["netG_A.encoder.13.conv_block.5.weight", "netG_A.decoder.4.conv_block.5.weight"]]
                torchvision.utils.save_image(recon_x2y, 'train_mapping_recon_x2y.png', normalize=True)
        if (current_iter+1)%2000==0:
            # finished the current_iter-th step: e.g. after finishing iter 0, save as 1 and resume from iter 1
            state = {
                'iter': current_iter + 1,
                'batch_size': opt.batch_size,
                #
                'mapping_net': mapping_net.state_dict(),
                'y_recon_d': y_recon_d.state_dict(),
                #
                'optimizer_mapping': optimizer_mapping.state_dict(),
                'optimizer_d2': optimizer_d2.state_dict(),
                }
            torch.save(state, checkpoint_root + f'/mapping_checkpoint_{current_iter+1}.pth')
        # during warmup, train only the mapping net so that it works first
        if current_iter < opt.mapping_warmup_iters:
            continue
        ### after warmup, also keep training vae1 and vae2, as in stage 1-1
        #  train vae1
        #     train d
        # save GPU memory: no grads are needed for G while training D
        with torch.no_grad():
            z_x, mean_x, var_x, recon_x = vae1(x)
            z_r, mean_r, var_r, recon_r = vae1(r)
            batch_requires_grad(z_x, mean_x, var_x, recon_x,
                                z_r, mean_r, var_r, recon_r)
        adv_loss_d_x = lsgan_d(xr_recon_d(x), xr_recon_d(recon_x))
        adv_loss_d_r = lsgan_d(xr_recon_d(r), xr_recon_d(recon_r))
        # z_x is treated as "real" and z_r as "fake", to push z_r toward z_x
        adv_loss_d_xr = lsgan_d(z_xr_d(z_x), z_xr_d(z_r))
        loss_1_d = adv_loss_d_x + adv_loss_d_r + adv_loss_d_xr
        loss_1_d.backward()
        optimizer_d1.step()
        optimizer_d1.zero_grad()
        ### vae1 train g
        # since we update G's weights, recompute the forward pass with grad enabled
        z_x, mean_x, var_x, recon_x = vae1(x)
        z_r, mean_r, var_r, recon_r = vae1(r)
        adv_loss_g_x = lsgan_g(xr_recon_d(recon_x))
        adv_loss_g_r = lsgan_g(xr_recon_d(recon_r))
        # z_x is treated as "real" and z_r as "fake", to push z_r toward z_x
        adv_loss_g_xr = lsgan_g(z_xr_d(z_r))
        KLDloss_1_x = -0.5 * torch.sum(1 + var_x - mean_x.pow(2) - var_x.exp())  # KLD
        L1loss_1_x  = opt.weight_alpha * F.l1_loss(x, recon_x)
        KLDloss_1_r = -0.5 * torch.sum(1 + var_r - mean_r.pow(2) - var_r.exp())  # KLD
        L1loss_1_r  = opt.weight_alpha * F.l1_loss(r, recon_r)
        loss_1_g = adv_loss_g_x + KLDloss_1_x + L1loss_1_x \
                 + adv_loss_g_r + KLDloss_1_r + L1loss_1_r \
                 + adv_loss_g_xr
        loss_1_g.backward()
        optimizer_vae1.step()
        optimizer_vae1.zero_grad()
        # the G loss above also left grads in xr_recon_d and z_xr_d;
        # clear them so they don't leak into the next D step
        optimizer_d1.zero_grad()

        ### vae2 train d
        # save GPU memory: no grads are needed for G while training D
        with torch.no_grad():
            z_y, mean_y, var_y, recon_y = vae2(y)
            batch_requires_grad(z_y, mean_y, var_y, recon_y)
        adv_loss_d_y = lsgan_d(y_recon_d(y), y_recon_d(recon_y))
        loss_2_d = adv_loss_d_y
        loss_2_d.backward()
        optimizer_d2.step()
        optimizer_d2.zero_grad()
        ### vae2 train g
        # since we update G's weights, recompute the forward pass with grad enabled
        z_y, mean_y, var_y, recon_y = vae2(y)
        adv_loss_g_y = lsgan_g(y_recon_d(recon_y))
        KLDloss_1_y = -0.5 * torch.sum(1 + var_y - mean_y.pow(2) - var_y.exp())  # KLD
        L1loss_1_y  = opt.weight_alpha * F.l1_loss(y, recon_y)
        loss_2_g = adv_loss_g_y + KLDloss_1_y + L1loss_1_y
        loss_2_g.backward()
        optimizer_vae2.step()
        optimizer_vae2.zero_grad()
        optimizer_d2.zero_grad()  # clear the grads the G loss left in y_recon_d
        # debug
        if opt.debug and current_iter%500==0:
            # [print(k, 'channel 0:\n', v[0][0]) for k,v in list(model.named_parameters()) if k in ["netG_A.encoder.13.conv_block.5.weight", "netG_A.decoder.4.conv_block.5.weight"]]
            torchvision.utils.save_image(recon_x, 'train_vae_recon_x.png', normalize=True)
            torchvision.utils.save_image(recon_r, 'train_vae_recon_r.png', normalize=True)
            torchvision.utils.save_image(recon_y, 'train_vae_recon_y.png', normalize=True)
        # print vae loss
        if current_iter%500==0:
            print(f"""STEP {current_iter:06d} {time.perf_counter() - start:.1f} s
            loss_1_d = adv_loss_d_x + adv_loss_d_r + adv_loss_d_xr
            >>> {loss_1_d:.3f} = {adv_loss_d_x:.3f} + {adv_loss_d_r:.3f} + {adv_loss_d_xr:.3f}
            loss_1_g = adv_loss_g_x + KLDloss_1_x + L1loss_1_x \
                 + adv_loss_g_r + KLDloss_1_r + L1loss_1_r \
                 + adv_loss_g_xr
            >>> {loss_1_g:.3f} = {adv_loss_g_x:.3f} + {KLDloss_1_x:.3f} + {L1loss_1_x:.3f} \
                 + {adv_loss_g_r:.3f} + {KLDloss_1_r:.3f} + {L1loss_1_r:.3f} \
                 + {adv_loss_g_xr:.3f}
            """)
        if (current_iter+1)%2000==0:
            # finished the current_iter-th step: e.g. after finishing iter 0, save as 1 and resume from iter 1
            state = {
                'iter': current_iter + 1,
                'batch_size': opt.batch_size,
                #
                'vae1': vae1.state_dict(),
                'xr_recon_d': xr_recon_d.state_dict(),
                'z_xr_d': z_xr_d.state_dict(),
                #
                'vae2': vae2.state_dict(),
                'y_recon_d': y_recon_d.state_dict(),
                #
                'optimizer_vae1': optimizer_vae1.state_dict(),
                'optimizer_d1': optimizer_d1.state_dict(),
                'optimizer_vae2': optimizer_vae2.state_dict(),
                'optimizer_d2': optimizer_d2.state_dict(),
                }
            torch.save(state, checkpoint_root + f'/vaes_checkpoint_{current_iter+1}.pth')
    print("global", time.perf_counter() - start, ' s')

@ghost commented Dec 10, 2020

What is your latent space dimension for VAE_1 and VAE_2? We have tried z_dim = 256, 512, ..., 4096.


@araomv commented Dec 10, 2020

Does self.mean_layer(z_x) involve (1) flattening followed by a fully connected layer, or (2) just a fully connected layer?
If it is (1), what is the output dimension? The number of parameters would be very high (64x64x64 flattened features x 256 outputs ≈ 6.7e7 weights, if the latent dimension is 256).
If it is (2), does it involve permuting the dimensions? An FC transform changes only the last dimension (sharing the same matrix across the others).

@ALLinLLM (Author) commented:

> What is your latent space dimension for the VAE_1 and the VAE_2?

I have tried FC layers to a 64x32x32 latent space, but the VAE did not train successfully; then I tried 1x32x32, and it failed again.

I need to find a third-party VAE GitHub project for face reconstruction to check where I went wrong implementing the reparameterization trick.
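
A standalone sanity check for the reparameterization trick (my own sketch, independent of the repo): with the KLD written over the log-variance as in the code above, mu = 0 and logvar = 0 gives KLD = 0, and the samples should have roughly unit standard deviation.

    import torch

    def reparameterize(mu, logvar):
        std = torch.exp(0.5 * logvar)
        return mu + torch.randn_like(std) * std

    mu = torch.zeros(10000)
    logvar = torch.zeros(10000)
    z = reparameterize(mu, logvar)
    print(z.std())  # ~1.0: samples follow N(0, 1)

    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    print(kld)  # tensor(0.) for a standard normal posterior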

@jmandivarapu1 commented:

> (quotes the build_model() code from above)

Can I get the source code repo of your training re-implementation?

@ALLinLLM (Author) commented:

@jmandivarapu1 I would like to share it, but I failed to train any meaningful result, so the training code might mislead you.

@kileybck commented Dec 29, 2020 via email

@hayanakamura commented Jan 26, 2021

Thanks for your nice code!!
But what is build_mapping_models?

mapping_net, vgg, optimizer_mapping = build_mapping_models(opt)

@zhangmozhe (Contributor) commented:

The training code has just been added. You are welcome to go through the training details.
