In [None]:
%reload_ext autoreload
%autoreload 2

In [None]:
from fastai import *
from fastai.vision import *
# NOTE(review): GPU index 3 is specific to the original author's multi-GPU machine.
# On a machine with fewer GPUs this call will likely fail; change or remove it before running.
torch.cuda.set_device(3)

## LSUN bedroom data

For this lesson, we'll be using the bedrooms from the [LSUN dataset](http://lsun.cs.princeton.edu/2017/). The full dataset is a bit too large so we'll use a sample from [kaggle](https://www.kaggle.com/jhoward/lsun_bedroom).

In [None]:
# Root folder for the LSUN bedroom sample. It is created if missing so that the
# (optional) download/unzip cells below have somewhere to write.
path = Path('data') / 'bedroom'
path.mkdir(exist_ok=True, parents=True)
path.ls()

Uncomment the next commands to download and extract the data on your machine.

In [None]:
#! kaggle datasets download -d jhoward/lsun_bedroom -p {path}  

In [None]:
#! unzip -q -n {path}/lsun_bedroom.zip -d {path}
#! unzip -q -n {path}/sample.zip -d {path}

We then grab all the images in the folder with the data block API. We don't create a validation set here for reasons we'll explain later.

In [None]:
class NoisyItem(ItemBase):
    "Item holding a random noise tensor of shape `(noise_sz, 1, 1)`, used as the generator's input."
    def __init__(self, noise_sz):
        # `obj` keeps the requested size; `data` is a freshly sampled standard-normal tensor.
        self.obj = noise_sz
        self.data = torch.randn(noise_sz, 1, 1)

    def __str__(self):
        # Noise has no meaningful textual label.
        return ''

    def apply_tfms(self, tfms, **kwargs):
        # Transforms are ignored: the item is returned unchanged.
        return self

In [None]:
class GANItemList(ImageItemList):
    "`ImageItemList` whose inputs are `NoisyItem`s of size `noise_sz`; its labels are built as an `ImageItemList`."
    _label_cls = ImageItemList

    def __init__(self, items, noise_sz:int=100, **kwargs):
        super().__init__(items, **kwargs)
        self.noise_sz = noise_sz
        # Registered in `copy_new`, presumably so fastai carries `noise_sz` over to derived lists.
        self.copy_new.append('noise_sz')

    def get(self, i):
        "Return a new `NoisyItem`; the index `i` is not used."
        return NoisyItem(self.noise_sz)

    def reconstruct(self, t):
        "Rebuild an input item from tensor `t`, taking its first dimension as the noise size."
        return NoisyItem(t.size(0))

    def show_xys(self, xs, ys, imgsize:int=4, figsize:Optional[Tuple[int,int]]=None, **kwargs):
        "Display the images `ys`, passing the noise inputs `xs` in the label position."
        super().show_xys(ys, xs, imgsize=imgsize, figsize=figsize, **kwargs)

    def show_xyzs(self, xs, ys, zs, imgsize:int=4, figsize:Optional[Tuple[int,int]]=None, **kwargs):
        "Display the predictions `zs`, passing the noise inputs `xs` in the label position; `ys` is not shown."
        super().show_xys(zs, xs, imgsize=imgsize, figsize=figsize, **kwargs)

In [None]:
def get_data(bs, size, noise_sz:int=100):
    """Build the GAN `ImageDataBunch` from the images under the global `path`.

    Inputs are `NoisyItem`s of length `noise_sz`; targets are the images themselves
    (labelled with `noop`), randomly cropped/padded to `size`. Only the targets are
    normalized, with mean 0.5 and std 0.5 per channel.

    bs: batch size.
    size: target image size passed to `crop_pad` and `transform`.
    noise_sz: length of the noise vector fed to the generator. It should match the
        noise size the generator expects. The default 100 is `GANItemList`'s default,
        so existing calls behave as before.
    """
    # (x - 0.5) / 0.5 maps pixels from [0, 1] to [-1, 1] (assuming inputs in [0, 1]).
    gan_stats = [torch.tensor([0.5,0.5,0.5]), torch.tensor([0.5,0.5,0.5])]
    train_ds = (GANItemList.from_folder(path, noise_sz=noise_sz).label_from_func(noop)
               .transform(tfms=[crop_pad(size=size, row_pct=(0,1), col_pct=(0,1))], size=size, tfm_y=True))
    # No validation set is created for the GAN.
    return (ImageDataBunch.create(train_ds, valid_ds=None, path=path, bs=bs)
                     .normalize(do_x=False, stats=gan_stats, do_y=True))

We'll begin with a small size and use gradual resizing.

In [None]:
# Start small: batches of 128 images at size 64.
data = get_data(bs=128, size=64)

In [None]:
# Preview one batch (5 rows). Presumably rendered through `GANItemList.show_xys`,
# i.e. the real target images rather than the noise inputs.
data.show_batch(rows=5)

## Models

GAN stands for [Generative Adversarial Nets](https://arxiv.org/pdf/1406.2661.pdf); they were invented by Ian Goodfellow. The concept is that we train two models at the same time: a generator and a critic. The generator tries to make new images similar to the ones in our dataset, and the critic's job is to tell real images apart from the ones the generator produces. The generator returns images, the critic a single number (usually 0. for fake images and 1. for real ones).

We train them against each other in the sense that at each step (more or less), we:
1. Freeze the generator and train the critic for one step by:
  - getting one batch of true images (let's call that `real`)
  - generating one batch of fake images (let's call that `fake`)
  - have the critic evaluate each batch and compute a loss function from that; the important part is that it rewards the critic for recognizing real images and penalizes it for being fooled by the fake ones
  - update the weights of the critic with the gradients of this loss
  
  
2. Freeze the critic and train the generator for one step by:
  - generating one batch of fake images
  - evaluate the critic on it
  - return a loss that rewards the generator when the critic thinks those are real images; the important part is that it rewards the generator for fooling the critic
  - update the weights of the generator with the gradients of this loss
  
Here, we'll use the [Wasserstein GAN](https://arxiv.org/pdf/1701.07875.pdf).

We create a generator and a critic, which we then wrap in a `GANModule` and pass to a `Learner`. The noise size (`noise_sz`, 100 by default in `GANItemList`) is the size of the random vector from which our generator creates images.

In [None]:
# Build both networks for 64px, 3-channel images, each with one extra layer.
generator, critic = (models.basic_generator(in_size=64, n_channels=3, n_extra_layers=1),
                     models.basic_critic(in_size=64, n_channels=3, n_extra_layers=1))

In [None]:
class GANModule(nn.Module):
    "Wrapper around a `generator` and a `critic` to create a GAN."
    def __init__(self, generator:nn.Module, critic:nn.Module, gen_mode:bool=False):
        super().__init__()
        self.generator = generator
        self.critic = critic
        # True: `forward` runs the generator; False: it runs the critic.
        self.gen_mode = gen_mode

    def forward(self, *args):
        "Run the generator when in generator mode, the critic otherwise."
        active = self.generator if self.gen_mode else self.critic
        return active(*args)

    def switch(self, gen_mode:bool=None):
        "Put the model in generator mode if `gen_mode`, in critic mode otherwise; toggle the mode when `gen_mode` is None."
        if gen_mode is None:
            gen_mode = not self.gen_mode
        self.gen_mode = gen_mode

In [None]:
class GANLoss(GANModule):
    "Loss wrapper: applies `loss_funcD` in critic mode and `loss_funcG` in generator mode, using the networks in `gan_model`."
    def __init__(self, loss_funcD:Callable, loss_funcG:Callable, gan_model:nn.Module, gen_mode:bool=False):
        # Deliberately skip `GANModule.__init__`: it would assign `generator`/`critic`
        # instance attributes that shadow the loss methods defined below. Because of
        # that skip, `gen_mode` has to be initialised here. Otherwise the inherited
        # `forward` (and `switch(None)`) fail until `switch` is called with an explicit mode.
        super(GANModule, self).__init__()
        self.gen_mode = gen_mode
        self.loss_funcD,self.loss_funcG,self.gan_model = loss_funcD,loss_funcG,gan_model

    def generator(self, output, target):
        "Generator loss: score the generated `output` with the critic, then apply `loss_funcG`."
        fake = self.gan_model.critic(output)
        return self.loss_funcG(fake, target)

    def critic(self, real, input):
        "Critic loss: `real` is presumably the critic's output on real images, `input` the noise batch."
        # Generate fakes from the noise without requiring gradients on the noise itself.
        fake = self.gan_model.generator(input.requires_grad_(False)).requires_grad_(True)
        fake = self.gan_model.critic(fake)
        return self.loss_funcD(real, fake)

In [None]:
class GANTrainer(LearnerCallback):
    "`LearnerCallback` that handles GAN Training."
    # Presumably a low order so this callback runs before default-order callbacks.
    _order=-20
    def __init__(self, learn:Learner, clip:float=0.01, beta:float=0.98, gen_mode:bool=False):
        "`clip`: critic weight clamp bound (None disables clamping); `beta`: loss smoothing factor; `gen_mode`: initial mode."
        super().__init__(learn)
        self.clip,self.beta,self.gen_mode = clip,beta,gen_mode
        self.generator,self.critic = self.model.generator,self.model.critic

    def _set_trainable(self):
        "Unfreeze the network trained in the current mode and freeze the other one."
        train_model = self.generator if     self.gen_mode else self.critic
        loss_model  = self.generator if not self.gen_mode else self.critic
        requires_grad(train_model, True)
        requires_grad(loss_model, False)

    def on_train_begin(self, **kwargs):
        "Create the optimizers for the generator and critic, and reset the loss and image history."
        self.opt_gen = self.opt.new([nn.Sequential(*flatten_model(self.generator))])
        self.opt_disc = self.opt.new([nn.Sequential(*flatten_model(self.critic))])
        self.switch(self.gen_mode)
        self.dlosses,self.glosses = [],[]
        self.smoothenerG,self.smoothenerD = SmoothenValue(self.beta),SmoothenValue(self.beta)
        self.recorder.no_val=True
        self.recorder.add_metric_names(['gen_loss', 'disc_loss'])
        self.imgs,self.titles = [],[]
        # Set by `on_backward_begin` once a generator step has happened.
        self.last_gen = None

    def on_train_end(self, **kwargs):
        "Leave the model in generator mode after training."
        self.switch(gen_mode=True)

    def on_batch_begin(self, last_input, last_target, **kwargs):
        "Clamp the critic weights to [-clip, clip], then return the batch with input/target swapped in critic mode."
        if self.clip is not None:
            for p in self.learn.model.critic.parameters():
                p.data.clamp_(-self.clip, self.clip)
        # In critic mode the model is fed the targets (real images) and the noise becomes the "target".
        return (last_input,last_target) if self.gen_mode else (last_target, last_input)

    def on_backward_begin(self, last_loss, last_output, **kwargs):
        "Record the smoothed `last_loss` for the current mode and keep the latest generated batch."
        last_loss = last_loss.detach().cpu()
        smooth = self.smoothenerG if self.gen_mode else self.smoothenerD
        losses = self.glosses if self.gen_mode else self.dlosses
        smooth.add_value(last_loss)
        losses.append(smooth.smooth)
        if self.gen_mode:
            self.last_gen = last_output.detach().cpu()

    def on_epoch_end(self, pbar, epoch, **kwargs):
        "Put the various losses in the recorder and show the last generated image, if any."
        # NOTE(review): a smoothener that has not seen a value yet may not have `smooth`;
        # fall back to None in that case. Confirm the recorder formats None.
        self.recorder.add_metrics([getattr(self.smoothenerG, 'smooth', None),
                                   getattr(self.smoothenerD, 'smooth', None)])
        # An epoch made only of critic steps has produced no generated batch yet.
        if self.last_gen is None: return
        # Presumably undoes the mean/std 0.5 normalization back to [0, 1] for display.
        self.imgs.append(Image(self.last_gen[0]/2 + 0.5))
        self.titles.append(f'Epoch {epoch}')
        pbar.show_imgs(self.imgs, self.titles)

    def switch(self, gen_mode:bool=None):
        "Switch to generator mode if `gen_mode`, critic mode if not, or toggle the current mode when `gen_mode` is None."
        self.gen_mode = (not self.gen_mode) if gen_mode is None else gen_mode
        self.opt.opt = self.opt_gen.opt if self.gen_mode else self.opt_disc.opt
        self._set_trainable()
        # Forward the resolved mode (not the raw argument) so the model and the loss always
        # match this trainer, even if either was toggled independently elsewhere.
        self.model.switch(self.gen_mode)
        self.loss_func.switch(self.gen_mode)

In [None]:
@dataclass
class FixedGANSwitcher(LearnerCallback):
    "Alternate `n_disc_iter` critic iterations with `n_gen_iter` generator iterations."
    # Each field is either an int (fixed phase length) or a callable. A callable receives
    # the total number of iterations the *other* network has done so far in this fit,
    # and returns the length of the current phase.
    n_disc_iter:Union[int,Callable]
    n_gen_iter:Union[int,Callable]

    def on_train_begin(self, **kwargs):
        "Reset the per-phase and cumulative iteration counters."
        self.n_d,self.n_g = 0,0
        # Cumulative counts are never reset on a switch. The per-phase counters are reset
        # on every switch, so passing those to a callable schedule would always pass 0.
        self.total_d,self.total_g = 0,0

    def on_batch_end(self, iteration, **kwargs):
        "Count the step just taken and switch networks once the current phase is complete."
        if self.learn.gan_trainer.gen_mode:
            self.n_g += 1
            self.total_g += 1
            n_iter,n_in,n_out = self.n_gen_iter,self.total_d,self.n_g
        else:
            self.n_d += 1
            self.total_d += 1
            n_iter,n_in,n_out = self.n_disc_iter,self.total_g,self.n_d
        target = n_iter if isinstance(n_iter, int) else n_iter(n_in)
        # `>=` instead of `==`: identical for positive targets, but a target <= 0 (or a
        # schedule that shrinks mid-phase) can no longer block switching forever.
        if n_out >= target:
            self.learn.gan_trainer.switch()
            self.n_d,self.n_g = 0,0

In [None]:
# Re-create the generator and critic (same settings as earlier), so the GAN built in the
# next cell starts from newly constructed models, even when this section is re-run.
# NOTE(review): in_size=64 presumably has to match the `size` passed to `get_data`; keep them in sync.
generator = models.basic_generator(in_size=64, n_channels=3, n_extra_layers=1)
critic = models.basic_critic(in_size=64, n_channels=3, n_extra_layers=1)

In [None]:
gan = GANModule(generator, critic)
# Critic trained with `WassersteinLoss`, generator with `NoopLoss`.
loss_func = GANLoss(loss_funcD=WassersteinLoss(), loss_funcG=NoopLoss(), gan_model=gan)
# GANTrainer manages the two optimizers/modes; FixedGANSwitcher runs 5 critic steps per generator step.
learn = Learner(data, gan,
                loss_func=loss_func,
                opt_func=optim.RMSprop,
                wd=0.,
                callback_fns=[GANTrainer,
                              partial(FixedGANSwitcher, n_disc_iter=5, n_gen_iter=1)])

In [None]:
# Train for 30 epochs with learning rate 2e-4; the callbacks alternate critic/generator steps.
learn.fit(30,2e-4)

In [None]:
# Make sure the model is in generator mode, then show results from the training set,
# since `get_data` builds the DataBunch with `valid_ds=None`.
learn.gan_trainer.switch(gen_mode=True)
learn.show_results(ds_type=DatasetType.Train, rows=25)

In [None]:
# Checkpoint after the 30-epoch run.
learn.save('wgan-30')

### Tests

In [None]:
# Train for another 10 epochs with the same learning rate.
learn.fit(10, 2e-4)

In [None]:
# Checkpoint after the additional training.
learn.save('stage1')

In [None]:
# There is no validation set (`valid_ds=None` in `get_data`), so show results from the
# training set, as the earlier `show_results` cell does. `GANTrainer.on_train_end` has
# already switched the model to generator mode.
learn.show_results(ds_type=DatasetType.Train, rows=5)