In [1]:

import torch
from torch import nn

class ConvBlock(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        discriminator=False,
        use_act=True,
        use_bn=True,
        **kwargs,
    ):
        super().__init__()
        self.use_act = use_act
        self.cnn = nn.Conv2d(in_channels, out_channels, **kwargs, bias=not use_bn)
        self.bn = nn.BatchNorm2d(out_channels) if use_bn else nn.Identity()
        self.act = (
            nn.LeakyReLU(0.2, inplace=True) if discriminator else nn.PReLU(num_parameters=out_channels)
        )

    def forward(self, x):
        return self.act(self.bn(self.cnn(x))) if self.use_act else self.bn(self.cnn(x))


class UpsampleBlock(nn.Module):
    def __init__(self, in_c, scale_factor):
        super().__init__()
        self.conv = nn.Conv2d(in_c, in_c * scale_factor ** 2, 3, 1, 1)
        self.ps = nn.PixelShuffle(scale_factor)  # in_c * 4, H, W --> in_c, H*2, W*2
        self.act = nn.PReLU(num_parameters=in_c)

    def forward(self, x):
        return self.act(self.ps(self.conv(x)))


class ResidualBlock(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.block1 = ConvBlock(
            in_channels,
            in_channels,
            kernel_size=3,
            stride=1,
            padding=1
        )
        self.block2 = ConvBlock(
            in_channels,
            in_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            use_act=False,
        )

    def forward(self, x):
        out = self.block1(x)
        out = self.block2(out)
        return out + x


class Generator(nn.Module):
    def __init__(self, in_channels=3, num_channels=64, num_blocks=16):
        super().__init__()
        self.initial = ConvBlock(in_channels, num_channels, kernel_size=9, stride=1, padding=4, use_bn=False)
        self.residuals = nn.Sequential(*[ResidualBlock(num_channels) for _ in range(num_blocks)])
        self.convblock = ConvBlock(num_channels, num_channels, kernel_size=3, stride=1, padding=1, use_act=False)
        self.upsamples = nn.Sequential(UpsampleBlock(num_channels, 2), UpsampleBlock(num_channels, 2))
        self.final = nn.Conv2d(num_channels, in_channels, kernel_size=9, stride=1, padding=4)

    def forward(self, x):
        initial = self.initial(x)
        x = self.residuals(initial)
        x = self.convblock(x) + initial
        x = self.upsamples(x)
        return torch.tanh(self.final(x))


class Discriminator(nn.Module):
    def __init__(self, in_channels=3, features=[64, 64, 128, 128, 256, 256, 512, 512]):
        super().__init__()
        blocks = []
        for idx, feature in enumerate(features):
            blocks.append(
                ConvBlock(
                    in_channels,
                    feature,
                    kernel_size=3,
                    stride=1 + idx % 2,
                    padding=1,
                    discriminator=True,
                    use_act=True,
                    use_bn=False if idx == 0 else True,
                )
            )
            in_channels = feature

        self.blocks = nn.Sequential(*blocks)
        self.classifier = nn.Sequential(
            nn.AdaptiveAvgPool2d((6, 6)),
            nn.Flatten(),
            nn.Linear(512*6*6, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 1),
        )

    def forward(self, x):
        x = self.blocks(x)
        return self.classifier(x)


Hyper Parameters & Configuration

In [2]:
import torch
from PIL import Image
import torchvision.transforms

LOAD_MODEL = True
SAVE_MODEL = True
# CHECKPOINT_GEN = "gen.pth.tar"
# CHECKPOINT_DISC = "disc.pth.tar"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
LEARNING_RATE = 1e-4
NUM_EPOCHS = 4
BATCH_SIZE = 1
NUM_WORKERS = 4
HIGH_RES = 2560
LOW_RES = HIGH_RES // 4
IMG_CHANNELS = 3

highres_transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize((HIGH_RES,HIGH_RES)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        
    ]
)

lowres_transform = torchvision.transforms.Compose(
    [
        torchvision.transforms.Resize((LOW_RES,LOW_RES)),
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
        
    ]
)

both_transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.RandomCrop((HIGH_RES, HIGH_RES)),
        torchvision.transforms.RandomHorizontalFlip(p=0.5),
        torchvision.transforms.RandomRotation(.5),
    ]
)

test_transform = torchvision.transforms.Compose(
    [
        
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize(mean=[0, 0, 0], std=[1, 1, 1]),
    ]
)


Loss Function

In [3]:
import torch.nn as nn
from torchvision.models import vgg19

# phi_5,4 5th conv layer before maxpooling but after activation

class VGGLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.vgg = vgg19(pretrained=True).features[:36].eval().to(DEVICE)
        self.loss = nn.MSELoss()

        for param in self.vgg.parameters():
            param.requires_grad = False

    def forward(self, input, target):
        vgg_input_features = self.vgg(input)
        vgg_target_features = self.vgg(target)
        return self.loss(vgg_input_features, vgg_target_features)

In [4]:
import torch
import os
import numpy as np
from PIL import Image
from torchvision.utils import save_image


# def gradient_penalty(critic, real, fake, device):
#     BATCH_SIZE, C, H, W = real.shape
#     alpha = torch.rand((BATCH_SIZE, 1, 1, 1)).repeat(1, C, H, W).to(device)
#     interpolated_images = real * alpha + torch.tensor(fake).detach() * (1 - alpha)
#     torch.tensor(interpolated_images).requires_grad_(True)

#     # Calculate critic scores
#     mixed_scores = critic(interpolated_images)

#     # Take the gradient of the scores with respect to the images
#     gradient = torch.autograd.grad(
#         inputs=interpolated_images,
#         outputs=mixed_scores,
#         grad_outputs=torch.ones_like(mixed_scores),
#         create_graph=True,
#         retain_graph=True,
#     )[0]
#     gradient = gradient.view(gradient.shape[0], -1)
#     gradient_norm = gradient.norm(2, dim=1)
#     gradient_penalty = torch.mean((gradient_norm - 1) ** 2)
#     return gradient_penalty


# # def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
# #     print("=> Saving checkpoint")
# #     checkpoint = {
# #         "state_dict": model.state_dict(),
# #         "optimizer": optimizer.state_dict(),
# #     }
# #     torch.save(checkpoint, filename)


# # def load_checkpoint(checkpoint_file, model, optimizer, lr):
# #     print("=> Loading checkpoint")
# #     checkpoint = torch.load(checkpoint_file, map_location=DEVICE)
# #     model.load_state_dict(checkpoint["state_dict"])
# #     optimizer.load_state_dict(checkpoint["optimizer"])

#     # If we don't do this then it will just have learning rate of old checkpoint
#     # and it will lead to many hours of debugging \:
#     # for param_group in optimizer.param_groups:
#     #     param_group["lr"] = lr


def plot_examples(low_res_folder, gen):
    files = os.listdir(low_res_folder)

    gen.eval()
    for file in files:
        image = Image.open(os.path.join(low_res_folder, file))
        with torch.no_grad():
            upscaled_img = gen(
                test_transform(image)
                .unsqueeze(0)
                .to(DEVICE)
            )
            # Saves tensor to image file
        save_image(upscaled_img * 0.5 + 0.5, "HR_gen"+file)
        break
    gen.train()


Train Segment

In [5]:
import torch
from torch import nn
from torch import optim
from loss import VGGLoss
import os
from tqdm import tqdm

torch.backends.cudnn.benchmark = True

genr_loss=[]
dic_loss=[]

def train_fn(hr_dir,lr_dir, disc, gen, opt_gen, opt_disc, mse, bce, vgg_loss):
    # loop = tqdm(loader, leave=True)
    hr_files=os.listdir(hr_dir)
    lr_files=os.listdir(lr_dir)
    
    for idx in tqdm(range(len(lr_files))):
        
        # Get the file name 
        high_res=hr_files[idx]
        low_res=lr_files[idx]
        
        # Access the image
        high_res=Image.open(os.path.join(hr_dir,high_res))
        low_res=Image.open(os.path.join(lr_dir,low_res))
        # Transform the images 
        high_res=highres_transform(high_res)
        low_res=lowres_transform(low_res)

        
        # Send to device
        high_res = high_res.to(DEVICE)
        low_res = low_res.to(DEVICE)
        
        ### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
        low_res=torch.unsqueeze(low_res,0)
        fake = gen(low_res)
        
        #print("Success 1")
        
        high_res=torch.unsqueeze(high_res,0)
        disc_real = disc(high_res)
        #print("Success 2")
        disc_fake = disc(fake.detach())
        #print("Success 3")
        disc_loss_real = bce(
            disc_real, torch.ones_like(disc_real) - 0.1 * torch.rand_like(disc_real)
        )
        disc_loss_fake = bce(disc_fake, torch.zeros_like(disc_fake))
        loss_disc = disc_loss_fake + disc_loss_real
        
        dic_loss.append(loss_disc)
        
        #print("Success 4")
        opt_disc.zero_grad()
        loss_disc.backward()
        opt_disc.step()
        #print("Success 5")
        # Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z))
        disc_fake = disc(fake)
        #print("Success 6")
        #l2_loss = mse(fake, high_res)
        adversarial_loss = 1e-3 * bce(disc_fake, torch.ones_like(disc_fake))
        loss_for_vgg = 0.006 * vgg_loss(fake, high_res)
        gen_loss = loss_for_vgg + adversarial_loss
        
        genr_loss.append(gen_loss)
        
        # print("gen_loss:" ,gen_loss)
        # print("disc_loss:" ,loss_disc)
        
        
        
        opt_gen.zero_grad()
        gen_loss.backward()
        opt_gen.step()
        #print("Success 7")
        #print(idx)
        if idx % 20 == 0:
            plot_examples("D:\All-Projects\Super Resolution Dataset Generator\Histopathology\Dataset\Testing Data\LR", gen)
        break



def main():
    # dataset = MyImageFolder(root_dir="SR_GAN_Prac\Data\\")
    # loader = DataLoader(
    #     dataset,
    #     batch_size=config.BATCH_SIZE,
    #     shuffle=True,
    #     pin_memory=True,
    #     num_workers=config.NUM_WORKERS,
    # )


    
    hr_dir="D:\All-Projects\Super Resolution Dataset Generator\Histopathology\Dataset\Testing Data\HR"
    lr_dir="D:\All-Projects\Super Resolution Dataset Generator\Histopathology\Dataset\Testing Data\LR"
    
    gen = Generator(in_channels=3).to(DEVICE)
    disc = Discriminator(in_channels=3).to(DEVICE)
    opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.999))
    opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.9, 0.999))
    mse = nn.MSELoss()
    bce = nn.BCEWithLogitsLoss()
    vgg_loss = VGGLoss()

    # if config.LOAD_MODEL:
    #     load_checkpoint(
    #         config.CHECKPOINT_GEN,
    #         gen,
    #         opt_gen,
    #         config.LEARNING_RATE,
    #     )
    #     load_checkpoint(
    #        config.CHECKPOINT_DISC, disc, opt_disc, config.LEARNING_RATE,
    #     )

    for epoch in range(NUM_EPOCHS):
        train_fn(hr_dir,lr_dir, disc, gen, opt_gen, opt_disc, mse, bce, vgg_loss)
        print("Gen_loss:",torch.mean(torch.stack(genr_loss)))
        print("Disc_loss:",torch.mean(torch.stack(dic_loss)))
        # break
        #print("Final Success : {epoch} ")
        # break

        # if SAVE_MODEL:
        #     save_checkpoint(gen, opt_gen, filename=CHECKPOINT_GEN)
        #     save_checkpoint(disc, opt_disc, filename=CHECKPOINT_DISC)

main()

  0%|          | 0/10 [04:27<?, ?it/s]


Gen_loss: tensor(0.0042, device='cuda:0', grad_fn=<MeanBackward0>)
Disc_loss: tensor(1.3948, device='cuda:0', grad_fn=<MeanBackward0>)


  0%|          | 0/10 [00:16<?, ?it/s]


Gen_loss: tensor(0.0029, device='cuda:0', grad_fn=<MeanBackward0>)
Disc_loss: tensor(1.2561, device='cuda:0', grad_fn=<MeanBackward0>)


  0%|          | 0/10 [00:16<?, ?it/s]


Gen_loss: tensor(0.0028, device='cuda:0', grad_fn=<MeanBackward0>)
Disc_loss: tensor(1.3082, device='cuda:0', grad_fn=<MeanBackward0>)


  0%|          | 0/10 [00:16<?, ?it/s]

Gen_loss: tensor(0.0031, device='cuda:0', grad_fn=<MeanBackward0>)
Disc_loss: tensor(1.1373, device='cuda:0', grad_fn=<MeanBackward0>)



