In [1]:
import os
import random
from collections import OrderedDict

import pytorch_lightning as pl
import torch
import torch.nn.functional as F
import torch.utils.data as data_utils
import torchvision as tv
import wandb
from pytorch_lightning import LightningModule
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger
from torch import nn
from torch.utils.data import DataLoader
from torchvision import transforms

In [2]:
# Scratch check of the run-id recipe: format a random float with 4 decimals and drop the leading "0.".
f"{random.random():.4f}"[2:]

'0158'

In [3]:
from my_pgan import *

In [4]:
class PGAN(pl.LightningModule):
    """Progressively grown GAN: a `Generator`/`Discriminator` pair trained with two optimizers.

    Optimizer 0 updates the generator, optimizer 1 the discriminator. Resolution
    growth is driven from `on_train_epoch_end`, which adds a new scale, steps
    `alpha` from `alpha_step` up to 1 and then finalises the scale.

    Args:
        lr: learning rate used for both Adam optimizers.
        latent_size: dimensionality of the latent vectors fed to the generator.
        final_res: forwarded unchanged to `Generator` and `Discriminator`.
        curr_res: resolution currently being trained; doubled on every added scale.
        k: the generator takes an optimizer step only every `k`-th batch.
        alpha: current value of the scale transition (0 means "no transition running").
        alpha_step: amount `alpha` advances per epoch during a transition.
        loss_f: callable invoked as `loss_f(module, zi=..., net='generator')` and
            `loss_f(module, zi=..., xi=..., net='discriminator')`.
        normalize: if true, latent vectors are L2-normalised; also forwarded to both networks.
        activation_f: activation module forwarded to both networks.
    """

    def __init__(self, lr=0.1, latent_size=512, final_res=32, curr_res=4, k=1,
                 alpha=0.0, alpha_step=0.1, loss_f=WGANGP_loss,
                 normalize=True, activation_f=nn.LeakyReLU(negative_slope=0.2)):
        super().__init__()
        # Short random tag used as the filename prefix of saved sample images.
        self.id = f"{random.random():.3f}"[2:]
        # `activation_f` and `loss_f` are objects, not plain values, so keep them out of hparams.
        self.save_hyperparameters(ignore=['activation_f', 'loss_f'])

        self.loss_f = loss_f
        self.generator = Generator(latent_size=latent_size, final_res=final_res,
                                   normalize=normalize, activation_f=activation_f)
        self.discriminator = Discriminator(latent_size=latent_size, final_res=final_res,
                                           normalize=normalize, activation_f=activation_f)

    def forward(self, z):
        """Generate images from latent vectors `z`."""
        return self.generator(z)

    def _sample_latent(self, n, device=None):
        """Draw `n` Gaussian latent vectors on `device`, L2-normalised when `hparams.normalize` is set."""
        zi = torch.randn(n, self.hparams.latent_size, device=device)  # TODO update zi sampling
        if self.hparams.normalize:
            zi = F.normalize(zi, dim=1, p=2)
        return zi

    def training_step(self, batch, batch_idx, optimizer_idx):
        """Compute the loss for the network owned by `optimizer_idx` (0: generator, >0: discriminator)."""
        xi, _ = batch
        # Sample directly on the batch's device instead of on the CPU.
        zi = self._sample_latent(xi.shape[0], device=xi.device)

        if optimizer_idx == 0:  # train generator
            g_loss = self.loss_f(self, zi=zi, net='generator')
            self.log("generator_loss", g_loss)
            self.log("curr_res", float(self.hparams.curr_res))
            return g_loss

        if optimizer_idx > 0:  # train discriminator
            d_loss = self.loss_f(self, zi=zi, xi=xi, net='discriminator')
            self.log("discriminator_loss", d_loss)
            self.log("curr_res", float(self.hparams.curr_res))
            return d_loss

    def configure_optimizers(self):
        """Return `[generator_optimizer, discriminator_optimizer]` (the order defines `optimizer_idx`)."""
        # Weight decay is only applied when the loss is not WGANGP_loss.
        decay = 0 if self.loss_f == WGANGP_loss else 1e-4
        g_opt = torch.optim.Adam(self.generator.parameters(), lr=self.hparams.lr,
                                 betas=(0, 0.99), eps=1e-8, weight_decay=decay)
        d_opt = torch.optim.Adam(self.discriminator.parameters(), lr=self.hparams.lr,
                                 betas=(0, 0.99), eps=1e-8, weight_decay=decay)
        return [g_opt, d_opt]

    def optimizer_step(
            self,
            epoch,
            batch_idx,
            optimizer,
            optimizer_idx,
            optimizer_closure,
            on_tpu=False,
            using_native_amp=False,
            using_lbfgs=False,
        ):
        """Step the discriminator on every batch and the generator on every `k`-th batch."""
        if optimizer_idx == 1:
            optimizer.step(closure=optimizer_closure)
        elif optimizer_idx == 0:
            if (batch_idx + 1) % self.hparams.k == 0:
                # the closure (which includes the `training_step`) will be executed by `optimizer.step`
                optimizer.step(closure=optimizer_closure)
            else:
                # call the closure by itself to run `training_step` + `backward` without an optimizer step
                optimizer_closure()

    def save_generated_images(self, n=10, save_dir='./images/'):
        """Generate `n` images and save them as PNG files in `save_dir`.

        Files are named `<id>_res_<curr_res>_img_<i>.png`. The directory is created
        if it does not exist. Latents are sampled the same way as in `training_step`.
        """
        os.makedirs(save_dir, exist_ok=True)
        # NOTE(review): Normalize(0.5, 0.5) maps x -> (x - 0.5) / 0.5; confirm this matches the
        # generator's output range before relying on the saved colours.
        t = transforms.Compose([transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                                transforms.ToPILImage()])
        device = next(self.generator.parameters()).device
        # Sampling only: no autograd graph is needed, which also keeps memory use down.
        with torch.no_grad():
            gen_imgs = self.generator(self._sample_latent(n, device=device)).cpu()

        for i, img in enumerate(gen_imgs):
            file_name = self.id + '_res_' + str(self.hparams.curr_res) + '_img_' + str(i) + '.png'
            t(img).save(os.path.join(save_dir, file_name))

    def on_train_epoch_end(self):
        """Advance the progressive-growing schedule by one step per epoch.

        After the warm-up, each epoch does exactly one of:
        start a new scale (alpha == 0), finalise the current scale (alpha >= 1),
        or advance alpha by `alpha_step`.
        """
        # Warm-up: no growth while current_epoch <= int(0.5 / alpha_step)
        # (e.g. epochs 0..5 for alpha_step=0.1).
        if self.current_epoch <= int(0.5 / self.hparams.alpha_step):
            return

        if self.hparams.alpha == 0:
            # Snapshot samples at the resolution that is about to be left.
            self.save_generated_images()

            # NOTE(review): nothing here stops growth once curr_res reaches the data resolution;
            # confirm whether add_scale() is expected to handle that.
            self.generator.add_scale(start_alpha=self.hparams.alpha_step)
            self.discriminator.add_scale(start_alpha=self.hparams.alpha_step)
            self.hparams.alpha += self.hparams.alpha_step
            self.hparams.curr_res *= 2

            # From resolution 32 on, alpha advances half as fast.
            if self.hparams.curr_res == 32:
                self.hparams.alpha_step /= 2

            # Register the parameters of the new scale with the optimizers (0: generator, 1: discriminator).
            opts = self.optimizers()
            opts[0].add_param_group({'params': self.generator.residual.model.parameters()})
            opts[0].add_param_group({'params': self.generator.residual.introduce.parameters()})
            opts[1].add_param_group({'params': self.discriminator.residual.model.parameters()})
            opts[1].add_param_group({'params': self.discriminator.residual.introduce.parameters()})

        elif self.hparams.alpha >= 1:
            self.generator.finish_adding_scale()
            self.discriminator.finish_adding_scale()
            self.hparams.alpha = 0

            print("Done with: ", self.hparams.curr_res)
        else:
            by = self.hparams.alpha_step
            self.generator.increase_alpha(by=by)
            self.discriminator.increase_alpha(by=by)
            # Clamp so the `alpha >= 1` branch is reached despite float accumulation.
            self.hparams.alpha = min(self.hparams.alpha + by, 1.0)


In [5]:
# DataLoader settings shared by every loader in this notebook.
workers, batch_size = 8, 32

In [6]:
# CIFAR-10 training split as [0, 1] tensors; the validation split is not used in this notebook.
data_train = DataLoader(
    tv.datasets.CIFAR10("../data/01_raw", transform=transforms.ToTensor()),
    batch_size=batch_size,
    num_workers=workers,
)

In [7]:
# Keep only the images labelled 0, each stored as a (1, 3, 32, 32) tensor.
oneclass_data = []
for images, targets in data_train:
    for image, target in zip(images, targets):
        if target == 0:
            oneclass_data.append(image.view(1, 3, 32, 32))

In [8]:
# Stack the filtered images into one tensor and pair them with all-zero labels.
oneclass_images = torch.cat(oneclass_data, dim=0)
oneclass_labels = torch.zeros(len(oneclass_data))
oneclass_dataloader = DataLoader(
    data_utils.TensorDataset(oneclass_images, oneclass_labels),
    batch_size=batch_size,
    num_workers=workers,
)

# Load or create a new model

In [9]:
# Fresh model for a new run; the logger stays None until the optional W&B cell below is executed.
model = PGAN(
    lr=1e-4,
    latent_size=512,
    final_res=4,
    activation_f=nn.LeakyReLU(negative_slope=0.2),
    alpha_step=0.5,
)
wandb.finish()
wandb_logger = None

In [None]:
# Optional: log this run to Weights & Biases.
wandb_logger = WandbLogger(
    project="PGAN",
    name='locally_WPGANPG_airplanes',
    entity="dl_image_classification",
)

In [10]:
# The generator keeps its own `device` attribute, set by hand before training on the GPU.
model.generator.device = 'cuda'
trainer = Trainer(
    gpus=1,
    max_epochs=15,
    log_every_n_steps=50,
    logger=wandb_logger,
)
trainer.fit(model, oneclass_dataloader)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name          | Type          | Params
------------------------------------------------
0 | generator     | Generator     | 6.6 M 
1 | discriminator | Discriminator | 6.6 M 
------------------------------------------------
13.1 M    Trainable params
0         Non-trainable params
13.1 M    Total params
52.472    Total estimated model params size (MB)


Training: 0it [00:00, ?it/s]

Done with:  8
Done with:  16
Done with:  32


RuntimeError: CUDA out of memory. Tried to allocate 256.00 MiB (GPU 0; 6.00 GiB total capacity; 2.95 GiB already allocated; 0 bytes free; 3.53 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
# Close the current Weights & Biases run once training is over.
wandb.finish()

In [None]:
# Serialise the whole PGAN object (not just a state_dict) to disk.
# NOTE(review): loading this file back needs the same class definitions to be importable.
torch.save(model, 'model_16_local_airplanes.model')

In [None]:
# Scratch value; presumably the learning rate tried in the checkpoint-resume cell (lr=1e-3).
1e-3

# Load from checkpoint

In [None]:
# Fresh model to receive checkpoint weights; alpha is set to 0.02 on the hparams and on both networks.
model = PGAN(
    lr=1e-3,
    latent_size=512,
    final_res=4,
    activation_f=nn.LeakyReLU(negative_slope=0.2),
    alpha_step=0.02,
)
for alpha_holder in (model.hparams, model.generator, model.discriminator):
    alpha_holder.alpha = 0.02

In [None]:
# Rebuild the architecture the checkpoint expects: two fully added scales ...
for _ in range(2):
    model.generator.add_scale()
    model.discriminator.add_scale()
    model.generator.finish_adding_scale()
    model.discriminator.finish_adding_scale()
    model.hparams.curr_res *= 2

# ... followed by a third scale that is still in transition (alpha = 0.02).
model.generator.add_scale(start_alpha=0.02)
model.discriminator.add_scale(start_alpha=0.02)
model.hparams.curr_res *= 2


In [None]:
# Sanity check: the three alpha values should agree.
model.hparams.alpha, model.generator.alpha, model.discriminator.alpha

In [None]:
# Load the raw checkpoint; its 'state_dict' entry is copied into `model` by the next cell.
# NOTE: torch.load unpickles the file, so only use it on checkpoints you trust.
checkpoint = torch.load('./PGAN/2lg1fi03/checkpoints/epoch=120-step=4840.ckpt')

In [None]:
# Copy every checkpoint tensor into the matching wrapped layer of `model`.
# Keys look like "<net>.layers.<i>....", "<net>.residual.<model|introduce>.<i>...."
# or (discriminator only) "<net>.<decision layer>.<i>....", ending in "weight" or "bias".
for name, tensor in checkpoint['state_dict'].items():
    parts = name.split(sep='.')
    net = model.generator if parts[0] == 'generator' else model.discriminator

    if parts[1] == 'layers':
        wrapped = net.layers[int(parts[2])]
    elif parts[0] == 'generator' or parts[1] == 'residual':
        # The generator has only `layers` and `residual`; pick the residual sub-stack named in the key.
        stack = net.residual.model if parts[2] == 'model' else net.residual.introduce
        wrapped = stack[int(parts[3])]
    else:
        # Remaining discriminator keys belong to the decision layer.
        wrapped = net.decision_layer[int(parts[2])]

    # Weights go to `.weight` and everything else to `.bias`, for both networks alike.
    target = 'weight' if parts[-1] == 'weight' else 'bias'
    getattr(wrapped.module, target).data = tensor

In [None]:
# Continue training the restored model under a new W&B run name.
wandb_logger = WandbLogger(
    project="PGAN",
    name='locally_from_ckpt_airplanes',
    entity="dl_image_classification",
)
model.generator.device = 'cuda'
trainer = Trainer(
    gpus=1,
    max_epochs=60,
    log_every_n_steps=50,
    logger=wandb_logger,
)
trainer.fit(model, oneclass_dataloader)

In [None]:
# `load_from_checkpoint` is a classmethod that returns a NEW module; bind the result,
# otherwise the restored weights are thrown away and `model` stays unchanged.
# NOTE(review): the freshly constructed PGAN has no added scales, so a checkpoint saved after
# growth may not match its state_dict; use the manual copy cell above in that case.
model = PGAN.load_from_checkpoint('epoch=317-step=124656.ckpt')


In [None]:
# Close the Weights & Biases run of the resumed training.
wandb.finish()

# Debugging

In [None]:
# Display the generator's module tree.
model.generator

In [None]:
# 40 debug latents drawn like the training latents: Gaussian samples, L2-normalised per row.
latent_vec = F.normalize(torch.randn(40, 512), p=2, dim=1)

# Run the generator on the CPU for inspection; its own `device` attribute is kept in sync by hand.
model.generator.to('cpu')
model.generator.device = 'cpu'

In [None]:
from torchvision.utils import save_image


In [None]:
# Display the current hyperparameters (curr_res, alpha, ...).
model.hparams

In [None]:

# Generate one image per debug latent, upsample it 4x for viewing, then save and display each one.
os.makedirs("./imgs", exist_ok=True)

# NOTE(review): Normalize(0.5, 0.5) maps x -> (x - 0.5) / 0.5; confirm this matches the generator's output range.
t = transforms.Compose([transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)), transforms.ToPILImage()])
upsample = nn.Upsample(scale_factor=(4, 4), mode='nearest')
res = model.hparams.curr_res

# Inspection only, so skip building the autograd graph.
with torch.no_grad():
    gen_imgs = model.generator(latent_vec)
print(gen_imgs.shape)

for i, img in enumerate(gen_imgs):
    img_upsampled = upsample(img.view(1, 3, res, res))[0]
    pil_img = t(img_upsampled)
    # One file per image, so later images no longer overwrite earlier ones.
    pil_img.save(f"./imgs/img_{i}.png")
    display(pil_img)
    


In [None]:
# Show the first real training image, upsampled 4x, for comparison with generated samples.
d = next(iter(oneclass_dataloader))
img = d[0][0]
print(img.shape)

img_upsampled = (nn.Upsample(scale_factor=(4, 4), mode='nearest')(img.view(1, 3, 32, 32)))[0]
# Summarise instead of dumping the whole tensor.
print(img_upsampled.shape, img_upsampled.min().item(), img_upsampled.max().item())

# The data comes from ToTensor() and is already in [0, 1], so convert it directly
# rather than through whichever `t` transform an earlier cell left behind.
display(transforms.ToPILImage()(img_upsampled))

In [None]:
# Shape of the per-position minimum taken over dim 0 of `img` (dim kept with size 1).
img.min(0, keepdim=True)[0].shape

In [None]:
# Debug experiment: overwrite index 0 along the first dim of `img_upsampled` with ones (in place).
img_upsampled[0]=1

In [None]:
# Regenerate the debug batch and compare input and output shapes.
print(latent_vec.shape)
gen_imgs = model.generator(latent_vec)
gen_imgs.shape

In [None]:
# Run the discriminator on the generated batch and inspect the output shape.
decision=model.discriminator(gen_imgs)
decision.shape

In [None]:
 # Sum of squared discriminator outputs scaled by 1e-3; presumably a penalty term of the loss, confirm in my_pgan.
 (decision ** 2).sum() * 1e-3

In [None]:
# Element-wise check that flattening the squared outputs equals squaring column 0.
(decision ** 2).flatten() == (decision[:, 0] ** 2)

In [None]:
# Mean of column 0 of the discriminator output for the generated batch.
torch.mean(decision[:, 0])#.sum()

In [None]:
# Display every generated image as-is (plain tensor -> PIL conversion, no normalisation).
t = transforms.ToPILImage()
for img in gen_imgs:
    display(t(img))

In [None]:
# Scratch: add one scale to a standalone generator and inspect its layers.
# NOTE(review): `gen` is not defined in any visible cell; presumably it comes from
# `my_pgan` or a deleted cell. Confirm before running on a fresh kernel.
#gen.finish_adding_scale()
gen.add_scale()
#gen.finish_adding_scale()
gen.layers

In [None]:
# Scratch: add one scale to a standalone discriminator and inspect its layers.
# NOTE(review): `dis` is not defined in any visible cell; presumably it comes from
# `my_pgan` or a deleted cell. Confirm before running on a fresh kernel.
#dis.finish_adding_scale()
dis.add_scale()
#dis.finish_adding_scale()
dis.layers

In [None]:
# Scratch: run the standalone generator on the debug latents and compare shapes.
# NOTE(review): `gen` is not defined in any visible cell; confirm where it comes from.
print(latent_vec.shape)
gen_imgs=gen(latent_vec)
print(gen_imgs.shape)

In [None]:
# Scratch: score the generated batch with the standalone discriminator.
# NOTE(review): `dis` is not defined in any visible cell; confirm where it comes from.
decision=dis(gen_imgs)
decision.shape, decision