In [1]:
import os
import pandas as pd
import numpy as np


In [2]:
import subprocess

# Source of the project package (`app.consume_data`) used by the cells below.
REPO_URL = "https://github.com/nmud19/thesisGAN.git"
REPO_BRANCH = "18-pl_light_bolts_model_consume_data_changed"
REPO_DIR = "thesisGAN"

# Make the cell safe to re-run: skip the clone if the checkout already exists
# here, or if a later cell has already chdir'ed into it (cloning again would
# nest a second copy inside it). check=True surfaces a failed clone instead of
# silently continuing like `!git clone` does.
already_checked_out = os.path.isdir(REPO_DIR) or os.path.basename(os.getcwd()) == REPO_DIR
if not already_checked_out:
    subprocess.run(["git", "clone", "-b", REPO_BRANCH, REPO_URL, REPO_DIR], check=True)

Cloning into 'thesisGAN'...
remote: Enumerating objects: 521, done.[K
remote: Counting objects: 100% (267/267), done.[K
remote: Compressing objects: 100% (167/167), done.[K
remote: Total 521 (delta 163), reused 161 (delta 100), pack-reused 254[K
Receiving objects: 100% (521/521), 14.10 MiB | 12.81 MiB/s, done.
Resolving deltas: 100% (260/260), done.


In [3]:
# Work from inside the cloned repo so `app.*` is importable. Guarded so that
# re-running this cell does not try to enter thesisGAN/thesisGAN and fail.
if os.path.basename(os.getcwd()) != "thesisGAN":
    os.chdir("thesisGAN")
from app.consume_data import consume_data

In [4]:
!pip install pytorch-lightning lightning-bolts

Collecting lightning-bolts
  Downloading lightning_bolts-0.5.0-py3-none-any.whl (316 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m316.8/316.8 kB[0m [31m1.1 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: lightning-bolts
Successfully installed lightning-bolts-0.5.0
[0m

In [5]:
from typing import Union, List

import torch
import torchvision
from pl_bolts.models.gans import Pix2Pix
from pytorch_lightning.utilities.types import EPOCH_OUTPUT
import torchvision.models as models
from pytorch_lightning.callbacks import EarlyStopping
from app.consume_data import consume_data
import pytorch_lightning as pl
from pl_bolts.models.gans.pix2pix.components import PatchGAN


class OverpoweredPix2Pix(Pix2Pix):
    """Pix2Pix with an extra perceptual (feature cosine-similarity) generator loss.

    On top of the base ``Pix2Pix`` adversarial and L1 reconstruction losses, the
    generator loss adds ``lambda_recon * (1 - mean cosine similarity)`` between
    ResNet-50 features of the real and generated images.

    Args:
        in_channels: number of channels of the conditioning (sketch) image.
        out_channels: number of channels of the generated (colour) image.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__(in_channels=in_channels, out_channels=out_channels)
        self._create_inception_score()

    def _create_inception_score(self):
        """Build the frozen ResNet-50 feature extractor and the cosine-similarity op.

        Despite its name, this does not compute an Inception Score. It sets
        ``self.feature_extractor`` (an ImageNet-pretrained ResNet-50 without its
        final ``fc`` layer) and ``self.cosine_similarity`` (cosine over dim 1).
        """
        backbone = models.resnet50(pretrained=True)
        layers = list(backbone.children())[:-1]  # drop the classifier head (fc)
        self.feature_extractor = torch.nn.Sequential(*layers)
        # The extractor is a fixed perceptual metric, not a trained part of the GAN
        # (the optimizers only cover gen / patch_gan). Freeze its weights so no
        # gradient buffers are filled for it, and keep it in eval mode so its
        # BatchNorm layers neither normalise with per-batch statistics nor
        # overwrite their pretrained running statistics.
        self.feature_extractor.requires_grad_(False)
        self.feature_extractor.eval()
        self.cosine_similarity = torch.nn.CosineSimilarity(dim=1, eps=1e-6)

    def train(self, mode=True):
        """Set train/eval mode as usual, but always keep the feature extractor in eval mode.

        Lightning calls ``model.train()`` when training (re)starts, which would
        otherwise flip the frozen extractor's BatchNorm layers back to train mode.
        """
        super().train(mode)
        if hasattr(self, "feature_extractor"):  # guard: may run before __init__ has built it
            self.feature_extractor.eval()
        return self

    def _mean_feature_cosine_similarity(self, real_images, fake_images):
        """Return the batch-mean cosine similarity of ResNet features of two image batches.

        The result is a 0-dim tensor. Gradients flow back through ``fake_images``
        (the extractor's own weights are frozen). ``real_images`` are only
        targets, so they are encoded under ``no_grad``.

        NOTE(review): images are fed to the ImageNet-pretrained ResNet as-is, with
        no ImageNet mean/std normalisation. Confirm the data module's value range.
        """
        with torch.no_grad():
            representations_real = self.feature_extractor(real_images).flatten(1)
        representations_fake = self.feature_extractor(fake_images).flatten(1)
        return self.cosine_similarity(representations_real, representations_fake).mean()

    def _gen_step(self, real_images, conditioned_images):
        """Generator loss for one batch.

        loss = adversarial + lambda_recon * L1(fake, real) + lambda_recon * (1 - cosine_sim)

        The cosine term reuses ``lambda_recon`` as its weight.
        """
        # Adversarial loss: the generator wants the PatchGAN to label fakes as real.
        fake_images = self.gen(conditioned_images)
        disc_logits = self.patch_gan(fake_images, conditioned_images)
        adversarial_loss = self.adversarial_criterion(disc_logits, torch.ones_like(disc_logits))

        # Pixel-wise reconstruction loss.
        recon_loss = self.recon_criterion(fake_images, real_images)
        lambda_recon = self.hparams.lambda_recon

        # Perceptual loss: 1 - mean cosine similarity of ResNet features.
        cosine_loss = 1 - self._mean_feature_cosine_similarity(real_images, fake_images)
        # Log a detached tensor instead of a numpy value (no forced device->host copy).
        self.log("Gen Cosine Sim Loss ", cosine_loss.detach())

        return adversarial_loss + lambda_recon * recon_loss + lambda_recon * cosine_loss

    def validation_step(self, batch, batch_idx):
        """Log the discriminator loss, generator loss and feature cosine similarity for one batch.

        Returns a dict holding the batch's sketches and real colour images, which
        ``validation_epoch_end`` uses to render an example grid.
        """
        real, condition = batch
        with torch.no_grad():
            disc_loss = self._disc_step(real, condition)
            self.log("Valid PatchGAN Loss", disc_loss)

            gan_loss = self._gen_step(real, condition)
            self.log("Valid Generator Loss", gan_loss)

            fake_images = self.gen(condition)
            cosine_sim = self._mean_feature_cosine_similarity(real, fake_images)
            self.log("Valid Cosine Sim", cosine_sim)

        return {
            'sketch': condition,
            'colour': real
        }

    def validation_epoch_end(self, outputs: Union[EPOCH_OUTPUT, List[EPOCH_OUTPUT]]) -> None:
        """Log a (sketch | real colour | generated colour) image grid for the first validation batch.

        The tag includes the epoch and the batch's mean feature cosine similarity.
        Does nothing if no validation batches produced outputs.
        """
        if not outputs:
            return
        sketch = outputs[0]['sketch']
        colour = outputs[0]['colour']
        with torch.no_grad():
            gen_coloured = self.gen(sketch)
            similarity_score = self._mean_feature_cosine_similarity(colour, gen_coloured)

        grid_image = torchvision.utils.make_grid(
            [sketch[0], colour[0], gen_coloured[0]],
            normalize=True,
        )
        # .item() so the tag shows a plain number rather than "tensor(..., device='cuda:0')".
        self.logger.experiment.add_image(
            f'Image Grid {self.current_epoch} __ {similarity_score.item()} ',
            grid_image,
            self.current_epoch,
        )
        

In [6]:

def test_lightning_model(
    model,
    data_dir: str = "/kaggle/input/anime-sketch-colorization-pair/data/",
    train_batch_size: int = 32,
    val_batch_size: int = 16,
    max_epochs: int = 4,
    early_stop_patience=None,
):
    """Train ``model`` end to end on the anime sketch-colourisation data module.

    Args:
        model: the LightningModule to fit (e.g. ``OverpoweredPix2Pix``).
        data_dir: root directory passed to ``AnimeSketchDataModule``.
        train_batch_size: training dataloader batch size.
        val_batch_size: validation dataloader batch size.
        max_epochs: number of training epochs.
        early_stop_patience: if an int, stop training once "Valid PatchGAN Loss"
            (the key logged by ``OverpoweredPix2Pix.validation_step``) has not
            decreased for that many validation checks. ``None`` (the default)
            disables early stopping.

    Returns:
        The ``pl.Trainer`` after ``fit`` has finished.
    """
    anime_sketch_data_module = consume_data.AnimeSketchDataModule(
        data_dir=data_dir,
        val_batch_size=val_batch_size,
        train_batch_size=train_batch_size,
    )
    logger = pl.loggers.TensorBoardLogger("tb_logs_v2", name="lightning_logs")

    callbacks = [pl.callbacks.TQDMProgressBar(refresh_rate=10)]
    if early_stop_patience is not None:
        # The monitor key must match the name logged in validation_step exactly.
        callbacks.append(
            EarlyStopping(
                monitor="Valid PatchGAN Loss",
                patience=early_stop_patience,
                verbose=True,
                mode="min",
            )
        )

    trainer = pl.Trainer(
        max_epochs=max_epochs,
        logger=logger,
        callbacks=callbacks,
        default_root_dir="chk",
        accelerator='gpu',
        devices=1,
    )
    trainer.fit(model=model, datamodule=anime_sketch_data_module)
    print("complete!")
    return trainer


In [7]:
# Seed python/numpy/torch (and dataloader workers) so weight init, shuffling
# and dropout are reproducible across Restart & Run All.
pl.seed_everything(42, workers=True)

model = OverpoweredPix2Pix(
    in_channels=3,
    out_channels=3,
)
test_lightning_model(model=model)

Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth


  0%|          | 0.00/97.8M [00:00<?, ?B/s]

The train test dataset lengths are :  14224 3545


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  np.object,
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  np.bool,
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  np.object:
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  np.bool:
  import imp
Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  def randint(low, high=None, size=None, dtype=onp.int):  # pylint: disable=missing-function-docstring
  'nearest': pil_image.NEAREST,
  'bilinear': pil_image.BILINEAR,
  'bicubic': pil_image.BICUBIC,
  if hasattr(pil_image, 'HAMMING'):
  if hasattr(pil_image, 'BOX'):
  if hasattr(pil_image, 'LANCZOS'):


Sanity Checking: 0it [00:00, ?it/s]

  image = image.resize((scaled_width, scaled_height), Image.ANTIALIAS)


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

complete!
