In [1]:
import os

import matplotlib.pyplot as plt
import torch
import torchvision
from torch.autograd import Variable
from torch.utils.data import DataLoader
from tqdm import tqdm

import model
from crypko_data import crypkoFace as cy

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
class WGAN(object):
    """Weight-clipping Wasserstein GAN trainer.

    Wraps a generator/critic pair built by the project's ``model`` module,
    one RMSprop optimizer per network, and a shuffled ``DataLoader`` over the
    ``crypkoFace`` dataset.  Call :meth:`train` to run the full schedule;
    per-epoch sample grids, Wasserstein-estimate plots and checkpoints are
    written below the current working directory.
    """

    def __init__(self, init_channel=100, batchsize=64, lr=1e-4, max_epoch=10, params_range=0.01, distraintimes=5):
        """Build the networks, optimizers and data loader.

        Args:
            init_channel: Number of channels of the latent noise fed to the generator.
            batchsize: Mini-batch size for the data loader and for sampled noise.
            lr: Learning rate shared by both RMSprop optimizers.
            max_epoch: Number of passes over the dataset performed by ``train``.
            params_range: Critic weights are clipped to ``[-params_range, params_range]``.
            distraintimes: Critic updates performed per generator update.
        """
        # Prefer CUDA, but fall back to the CPU so the class stays usable without a GPU.
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        #models
        self.G = model.generator(init_channel).to(self.device)
        self.D = model.discriminator().to(self.device)
        #hyperparameters
        self.init_channel = init_channel
        self.batch_size = batchsize
        self.lr = lr
        self.max_epoch = max_epoch
        self.diss_train_times = distraintimes
        self.params_range = params_range
        #optimizers
        self.gen_opt = torch.optim.RMSprop(self.G.parameters(), lr=self.lr)
        self.dis_opt = torch.optim.RMSprop(self.D.parameters(), lr=self.lr)
        #dataloader
        dataset = cy()
        self.dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True, num_workers=2)

    def _sample_noise(self, count):
        """Return ``count`` standard-normal latent tensors of shape (count, init_channel, 1, 1)."""
        return torch.randn(count, self.init_channel, 1, 1, device=self.device)

    def _set_critic_requires_grad(self, flag):
        """Enable or disable gradient tracking on every critic parameter."""
        for p in self.D.parameters():
            p.requires_grad = flag

    def _clip_critic_weights(self):
        """Clamp every critic parameter into [-params_range, params_range] in place."""
        # In-place edits of leaf parameters must happen outside autograd tracking.
        with torch.no_grad():
            for p in self.D.parameters():
                p.clamp_(-self.params_range, self.params_range)

    def _critic_step(self, real, fake):
        """Run one critic update on the given batches and return the loss as a float.

        The loss is ``mean(D(fake)) - mean(D(real))``; its negation is the
        Wasserstein estimate tracked during training.
        """
        self.D.zero_grad()
        d_loss = self.D(fake.detach()).mean() - self.D(real.detach()).mean()
        d_loss.backward()
        self.dis_opt.step()
        # Clip after the update so the critic used afterwards is always inside the range.
        self._clip_critic_weights()
        return d_loss.item()

    def _generator_step(self):
        """Run one generator update on freshly sampled noise and return the loss as a float."""
        self.G.zero_grad()
        fake = self.G(self._sample_noise(self.batch_size))
        # The generator tries to raise the critic's score on its samples.
        g_loss = -self.D(fake).mean()
        g_loss.backward()
        self.gen_opt.step()
        return g_loss.item()

    def _train_epoch(self):
        """Run one pass over the data loader; return the per-batch Wasserstein estimates."""
        w_loss = []
        for data in tqdm(self.dataloader):
            #prepare real data and fake data
            real = data.to(self.device, non_blocking=True)
            # These samples only feed the critic, so no generator graph is needed.
            with torch.no_grad():
                fake = self.G(self._sample_noise(self.batch_size))

            #train the critic several times on this batch
            self._set_critic_requires_grad(True)
            d_loss = None
            for _ in range(self.diss_train_times):
                d_loss = self._critic_step(real, fake)
            #track the Wasserstein estimate (skipped if no critic step was run)
            if d_loss is not None:
                w_loss.append(-d_loss)

            #train the generator once with the critic frozen
            self._set_critic_requires_grad(False)
            self._generator_step()
        return w_loss

    def train(self):
        """Train for ``max_epoch`` epochs, writing progress artifacts after each one.

        After every epoch a 10x10 grid generated from a fixed noise batch is
        saved to ``progress_check/pics`` and the epoch's Wasserstein-estimate
        curve to ``progress_check/w_loss``.  On even-numbered epochs the
        state dicts of both networks are saved to ``savepoint``.  Missing
        output directories are created.
        """
        #turning models into training mode
        self.G.train()
        self.D.train()

        pics_dir = os.path.join('progress_check', 'pics')
        loss_dir = os.path.join('progress_check', 'w_loss')
        save_dir = 'savepoint'
        for out_dir in (pics_dir, loss_dir, save_dir):
            os.makedirs(out_dir, exist_ok=True)

        # Fixed noise so the sample grids are comparable across epochs.
        check_noise = self._sample_noise(100)

        #training start
        for e in range(self.max_epoch):
            w_loss = self._train_epoch()

            #progress check every epoch
            #generate 100 pics from same noise
            with torch.no_grad():
                # Shift/scale by (x + 1) / 2 -- presumably maps a [-1, 1] generator
                # output to [0, 1] for saving; confirm against model.generator.
                fake_sample = (self.G(check_noise) + 1) / 2.0
            torchvision.utils.save_image(fake_sample, os.path.join(pics_dir, f'epoch_{e}.jpg'), nrow=10)
            #track the Wasserstein loss
            fig, ax = plt.subplots()
            ax.plot(w_loss)
            fig.savefig(os.path.join(loss_dir, f'epoch_{e}.jpg'))
            # Close the figure so figures do not accumulate across epochs.
            plt.close(fig)

            #save checkpoint every 2 epochs
            if e % 2 == 0:
                torch.save(self.G.state_dict(), os.path.join(save_dir, f'epoch_{e}_G.pth'))
                torch.save(self.D.state_dict(), os.path.join(save_dir, f'epoch_{e}_D.pth'))

In [6]:
# Build the trainer with a 4-epoch schedule; every other hyperparameter keeps its default.
net = WGAN(max_epoch=4)