In [2]:
import torch
from torch import nn
from torch.utils import data
import torchvision
import torch.nn.functional as F
from  torch.nn.utils import spectral_norm

In [39]:
class Config():
    IMAGE_HEIGHT = 256
    IMAGE_WIDTH = 256
    BATCH_SIZE = 4
    LATENT_DIM = 256
    DEVICE = "cpu"
cfg = Config()
class EncoderBlock(nn.Module):
    def __init__(self,channels_in,channels_out,with_norm=True):
        super(EncoderBlock,self).__init__()
        if with_norm:
            self.block = nn.Sequential(
                                        nn.Conv2d(in_channels=channels_in,out_channels=channels_out,\
                                                    kernel_size=3,stride=2,bias=False,padding=1),           
                                        nn.InstanceNorm2d(channels_out),
                                        nn.LeakyReLU()
                                        )
        else:
            self.block = nn.Sequential(
                                        nn.Conv2d(in_channels=channels_in,out_channels=channels_out,\
                                                    kernel_size=3,stride=2,bias=False,padding=1),
                                        nn.LeakyReLU()
                                        )
    def forward(self,x):
        return self.block(x)
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder,self).__init__()
        self.block1 = EncoderBlock(3,64,with_norm=False)
        self.block2 = EncoderBlock(64,128)
        self.block3 = EncoderBlock(128,256)
        self.block4 = EncoderBlock(256,512)
        self.block5 = EncoderBlock(512,512)
        self.block6 = EncoderBlock(512,512)
        self.block7 = EncoderBlock(512,512)
        self.flattening_block = nn.Conv2d(512,2048,kernel_size=1,padding=0)

        self.linear_mu_branch = nn.Linear(in_features=2048,out_features=cfg.LATENT_DIM)
        self.linear_var_branch = nn.Linear(in_features=2048,out_features=cfg.LATENT_DIM)
        

    def forward(self,x):
        x = self.block1(x)
        print(x.shape)
        x = self.block2(x)
        print(x.shape)
        x = self.block3(x)
        print(x.shape)
        x = self.block4(x)
        print(x.shape)
        x = self.block5(x)
        print(x.shape)
        x = self.block6(x)
        print(x.shape)
        x = self.block7(x)
        print(x.shape)
        x = x.reshape(x.shape[0],x.shape[1]*x.shape[2]*x.shape[3])
        print(x.shape)
        
        
        mu = self.linear_mu_branch(x)
        var = self.linear_var_branch(x)
        
        print(mu.shape,var.shape)
        return mu,var
        # return x
    def get_latent_vector(self,mu,var):
        epsilon = torch.randn(mu.size(),device=cfg.DEVICE)
        latent_vec = mu  + torch.exp((var*0.5)) * epsilon  
        return latent_vec


img = torch.randn(1,3,256,256)
# print(img.shape)
enc = Encoder()
mean,variance  = enc(img)
enc.get_latent_vector(mean,variance).shape

torch.Size([1, 64, 128, 128])
torch.Size([1, 128, 64, 64])
torch.Size([1, 256, 32, 32])
torch.Size([1, 512, 16, 16])
torch.Size([1, 512, 8, 8])
torch.Size([1, 512, 4, 4])
torch.Size([1, 512, 2, 2])
torch.Size([1, 2048])
torch.Size([1, 256]) torch.Size([1, 256])


torch.Size([1, 256])

In [69]:
class SPADE(nn.Module):
    def __init__(self,num_channels):
        super(SPADE, self).__init__()
        self.bn = nn.BatchNorm2d(num_channels,affine=False)
        self.conv_1 = nn.Sequential(spectral_norm(nn.Conv2d(num_channels,128,kernel_size=3,padding=1)),\
                                   nn.ReLU())
        self.conv_1_1  = spectral_norm(nn.Conv2d(128, num_channels, kernel_size=3, padding=1))
        self.conv_2 = spectral_norm(nn.Conv2d(128,  num_channels, kernel_size=3, padding=1))
        
    def forward(self,x,segmentation_map):
        # print(x.shape)
        # BN
        x = self.bn(x)
        # Resize Map
        segmentation_map = F.interpolate(segmentation_map, size=x.size()[2:], mode='nearest')
        # Calc gamma and beta 
        output_shared = self.conv_1(x)
        gamma = self.conv_1_1(output_shared)
        beta = self.conv_2(output_shared)
        # rescale
        # print(x.shape,gamma.shape,beta.shape)
        out = x*(1+gamma) + beta
        return out

In [70]:
# spade , relu ,conv, spade, relu, conv
# skip should also be a spade,conv block
class SPADEResBlk(nn.Module):
    def __init__(self,num_features_in,num_features_out):
        super(SPADEResBlk,self,).__init__()
        self.spade1 = SPADE(num_channels=num_features_in)
        self.conv1 = spectral_norm(nn.Conv2d(in_channels=num_features_in,\
            out_channels=num_features_out,kernel_size=3,padding=1))
        self.spade2 = SPADE(num_channels=num_features_out)
        self.conv2 = spectral_norm(nn.Conv2d(in_channels=num_features_out,\
            out_channels=num_features_out,kernel_size=3,padding=1))
        self.skip_connection_spade = SPADE(num_channels=num_features_in)
        self.skip_connection_conv = spectral_norm(nn.Conv2d(in_channels=num_features_in,\
                                                out_channels=num_features_out,\
                                                    kernel_size=1,\
                                                        bias=False))
    
    def forward(self,x,segmentation_map):
        skip_features = self.skip_connection_spade(x,segmentation_map)
        skip_features = F.leaky_relu(skip_features)
        skip_features = self.skip_connection_conv(skip_features)

        x = self.conv1(F.leaky_relu(self.spade1(x,segmentation_map)))
        x = self.conv2(F.leaky_relu(self.spade2(x,segmentation_map)))
        return skip_features + x

In [125]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator,self).__init__()
        self.linear1 = nn.Linear(cfg.LATENT_DIM,16384)
        self.upsample = nn.Upsample(scale_factor=2)
        self.block1 = SPADEResBlk(num_features_in=1024,num_features_out=1024)
        self.block2 = SPADEResBlk(num_features_in=1024,num_features_out=1024)
        self.block3 = SPADEResBlk(num_features_in=1024,num_features_out=512)
        self.block4 = SPADEResBlk(num_features_in=512,num_features_out=256)
        self.block5 = SPADEResBlk(num_features_in=256,num_features_out=128)
        self.block6 = SPADEResBlk(num_features_in=128,num_features_out=64)
        self.block7 = SPADEResBlk(num_features_in=64,num_features_out=32)

        self.conv = nn.Conv2d(in_channels=32,out_channels=3,kernel_size=3,padding=1)

    def forward(self,latent_vec,segmentation_map):
        x = self.linear1(latent_vec)
        x = x.reshape(-1,1024,4,4)
        x = self.block1(x,segmentation_map)
        x = self.upsample(x)
        x = self.block2(x,segmentation_map)
        x = self.upsample(x)
        x = self.block3(x,segmentation_map)
        x = self.upsample(x)
        x = self.block4(x,segmentation_map)
        x = self.upsample(x)
        x = self.block5(x,segmentation_map)
        x = self.upsample(x)
        x = self.block6(x,segmentation_map)
        x = self.upsample(x)
        print(x.shape)        
        x = self.block7(x,segmentation_map)
        x = self.upsample(x)
        print(x.shape)        
        x = F.leaky_relu(x)
        x = self.conv(x)
        x = torch.tanh(x)
        return x
        

In [100]:
gen = Generator().to("cpu")
segmentation_map = torch.randn(1,256,256,10).to("cpu")
lvec = torch.randn(1,256).to("cpu")
op = gen(lvec,segmentation_map)
op.shape

torch.Size([1, 64, 256, 256])
torch.Size([1, 32, 512, 512])




torch.Size([1, 3, 512, 512])

In [102]:
imga = torch.randn(2,3,128,128)
imgb = torch.randn(2,3,128,128)

torch.concat([imga,imgb],axis=1).shape

torch.Size([2, 6, 128, 128])

In [129]:
class DescriminatorBlock(nn.Module):
    def __init__(self,channels_in,channels_out,with_norm=True):
        super(DescriminatorBlock,self).__init__()
        if with_norm:
            self.block = nn.Sequential(
                                        spectral_norm(nn.Conv2d(in_channels=channels_in,out_channels=channels_out,\
                                                    kernel_size=4,stride=2,bias=False,padding=1)),           
                                        nn.InstanceNorm2d(channels_out),
                                        nn.LeakyReLU()
                                        )
        else:
            self.block = nn.Sequential(
                                        spectral_norm(nn.Conv2d(in_channels=channels_in,out_channels=channels_out,\
                                                    kernel_size=4,stride=2,bias=False,padding=1)),
                                        nn.LeakyReLU()
                                        )
    def forward(self,x):
        return self.block(x)
class Descriminator(nn.Module):
    def __init__(self):
        super(Descriminator,self).__init__()
        # change input channeldim        
        self.block1 = spectral_norm(nn.Conv2d(10,64,kernel_size=4,stride=2,bias=True))
        self.block2 = DescriminatorBlock(64,128,False)
        self.block3 = DescriminatorBlock(128,256)
        self.block4 = DescriminatorBlock(256,512)
        self.block5 = DescriminatorBlock(512,512)
        self.in7 = nn.InstanceNorm2d(512)
        self.conv8 = spectral_norm(nn.Conv2d(512,1,kernel_size=4))
    
    def forward(self,segmentation_map,img):
        concat_img = torch.concat([segmentation_map,img],dim=1)
        op1 = self.block2(self.block1(concat_img))
        op2 = self.block3(op1)
        op3 = self.block4(op2)
        # print(op3.shape)
        op4 = self.block5(op3)
        op5 = self.conv8(F.leaky_relu(self.in7(op4))).mean(dim=(1,2,3))
        return [op1,op2,op3,op4,op5]

desc = Descriminator()
imgA = torch.randn(5,7,512,512)
imgB = torch.randn(5,3,512,512)
op = desc(imgA,imgB)
[t.shape for t in op]

[torch.Size([5, 128, 127, 127]),
 torch.Size([5, 256, 63, 63]),
 torch.Size([5, 512, 31, 31]),
 torch.Size([5, 512, 15, 15]),
 torch.Size([5])]

In [162]:
# gen loss = g_loss + kl_loss + vgg_loss + feature_loss
# g loss between - loss between descriminator prediction, and actual label
# kl_loss : encoder output mean,variance
# vgg_loss: loss between generated image (by generator) and actual image
# feature_loss: loss between real desc output and fake desc output
# def gen_loss(pred,target):
#     loss = F.binary_cross_entropy_with_logits(pred,target)
#     return loss
class Gen_loss(nn.Module):
    def __init__(self):
        super(Gen_loss,self).__init__()
        self.criterion = F.binary_cross_entropy
    def forward(self,pred,target):
        return self.criterion(pred,target)
# def kl_loss( mu, logvar):
#     return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
class KLD_Loss(nn.Module):
    def __init__(self):
        super(KLD_Loss,self).__init__()
    def forward(self,mu,logvar):
        return -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

# def vgg_loss ()

In [132]:
from torchvision import models
vgg = models.vgg19(pretrained=True)
vgg

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to C:\Users\noufal.samsudin/.cache\torch\hub\checkpoints\vgg19-dcbb9e9d.pth
100%|██████████| 548M/548M [06:03<00:00, 1.58MB/s] 


VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padd

In [136]:
for i, v in vgg.named_children():
    print(i)

features
avgpool
classifier


In [167]:
vgg = models.vgg19(pretrained=True).features
f5 = nn.Sequential(*[vgg[x] for x in range(30)])
img = torch.randn(1,3,256,256)
f5(img).shape
# for param in vgg.parameters():
#     param.requires_grad = False
#     print(param.name)
i=0
for _ in vgg.parameters():
    i+=1
print(i)

32


In [169]:
from torchvision import models
class VGGLoss(nn.Module):
    def __init__(self):
        super(VGGLoss,self).__init__()
        vgg = models.vgg19(pretrained=True).to(cfg.DEVICE).features
        for param in vgg.parameters():
            param.requires_grad = False
        self.f1 = nn.Sequential(*[vgg[x] for x in range(2)])
        self.f2 = nn.Sequential(*[vgg[x] for x in range(7)])
        self.f3 = nn.Sequential(*[vgg[x] for x in range(12)])
        self.f4 = nn.Sequential(*[vgg[x] for x in range(21)])
        self.f5 = nn.Sequential(*[vgg[x] for x in range(30)])
    def forward(self,x,y):
        # print(x_input.shape,x_input.shape)
        loss=0

        x1 = self.f1(x)
        y1 = self.f1(y)
        # print(x.shape,y.shape)
        # print(x1.shape,y1.shape)
        loss1 = F.l1_loss(x1,y1)
        # print(loss1)

        x2 = self.f2(x)
        y2 = self.f2(y)
        loss2 = F.l1_loss(x2,y2)

        x3 = self.f3(x)
        y3 = self.f3(y)
        loss3 = F.l1_loss(x3,y3)

        x4 = self.f4(x)
        y4 = self.f4(y)
        loss4 = F.l1_loss(x4,y4)

        x5 = self.f5(x)
        y5 = self.f5(y)
        loss5 = F.l1_loss(x5,y5)

        loss += loss1/32 + loss2/16 + loss3/8 + loss4/4 + loss5
        print(loss)

        return loss

vggloss = VGGLoss()
img1 = torch.randn(2,3,256,256)
img2 = torch.randn(2,3,256,256)
vggloss(img1,img2)

tensor(0.8025)


tensor(0.8025)

In [171]:
def feature_loss_disc(real_disc_outputs,fake_disc_outputs):
    with torch.no_grad():
        loss = 0
        for real_disc_output,fake_disc_output in zip(real_disc_outputs,fake_disc_outputs):
            for r_disc_output_feature,f_disc_output_feature in zip(real_disc_output,fake_disc_output):
                loss+= F.l1_loss(r_disc_output_feature,f_disc_output_feature)
        return loss/len(real_disc_outputs)
class FeatureLossDisc(nn.Module):
    def __init__(self):
        super(FeatureLossDisc,self).__init__()
    def forward(self,real_disc_outputs,fake_disc_outputs):
        loss=0
        for real_disc_output,fake_disc_output in zip(real_disc_outputs,fake_disc_outputs):
            loss+= F.l1_loss(real_disc_output,fake_disc_output)
        return loss/len(real_disc_outputs)

imgA = torch.randn(5,7,256,256)
imgB = torch.randn(5,3,256,256)
imgC = torch.randn(5,3,256,256)
op_real = desc(imgA,imgB)
op_fake = desc(imgA,imgC)
FeatureLossDisc()(op_real,op_fake)

tensor(0.2813, grad_fn=<DivBackward0>)

In [176]:
hinge_loss = nn.HingeEmbeddingLoss()
op_real[-1],torch.ones_like(op_real[-1]),torch.zeros_like(op_real[-1])

(tensor([0.4656, 0.3340, 0.4195, 0.3598, 0.6264], grad_fn=<MeanBackward1>),
 tensor([1., 1., 1., 1., 1.]),
 tensor([0., 0., 0., 0., 0.]))

In [177]:
hinge_loss(op_real[-1],torch.ones_like(op_real[-1]))

tensor(0.4411, grad_fn=<MeanBackward0>)