In [1]:
import torch.nn as nn
import torch
import os
import cv2
import numpy as np
import torchvision
from torch.utils.data import DataLoader
from torchinfo import summary

In [2]:
class Discriminator(nn.Module):
    def __init__(self, channels_img, features_d):
        super(Discriminator, self).__init__()
        self.features_d = features_d
        # VGG16 model
        self.vgg16 = torchvision.models.vgg16(weights=torchvision.models.VGG16_Weights.DEFAULT)
        for para in self.vgg16.parameters():
            para.requires_grad=True
        self.vgg16_feature_extractor = nn.Sequential(*list(self.vgg16.features.children())[:])
        # Discriminator
        self.disc = nn.Sequential(
            # input: N x channels_img x 480 x 480 (adjusted for image size 480)
            nn.Conv2d(channels_img, features_d, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2),
            # _block(in_channels, out_channels, kernel_size, stride, padding)
            self._conv_block(features_d, features_d * 2, 4, 2, 1),
            self._conv_block(features_d * 2, features_d * 4, 4, 2, 1),
            self._conv_block(features_d * 4, features_d * 8, 4, 2, 1),
            # After all _block img output is 4x4 (Conv2d below makes into 1x1)
            nn.Conv2d(features_d * 8, 1, kernel_size=4, stride=2, padding=0),
            nn.Sigmoid(),
        )
    def _conv_block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.Conv2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False,
            ),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        # VGG16 feature extraction
        vgg16_features = self.vgg16_feature_extractor(x)
        # Discriminator
        disc_output = self.disc(x)
        return disc_output

In [3]:
model=torchvision.models.vgg16()


In [4]:
summary(model=model)

Layer (type:depth-idx)                   Param #
VGG                                      --
├─Sequential: 1-1                        --
│    └─Conv2d: 2-1                       1,792
│    └─ReLU: 2-2                         --
│    └─Conv2d: 2-3                       36,928
│    └─ReLU: 2-4                         --
│    └─MaxPool2d: 2-5                    --
│    └─Conv2d: 2-6                       73,856
│    └─ReLU: 2-7                         --
│    └─Conv2d: 2-8                       147,584
│    └─ReLU: 2-9                         --
│    └─MaxPool2d: 2-10                   --
│    └─Conv2d: 2-11                      295,168
│    └─ReLU: 2-12                        --
│    └─Conv2d: 2-13                      590,080
│    └─ReLU: 2-14                        --
│    └─Conv2d: 2-15                      590,080
│    └─ReLU: 2-16                        --
│    └─MaxPool2d: 2-17                   --
│    └─Conv2d: 2-18                      1,180,160
│    └─ReLU: 2-19                

In [5]:
dis=Discriminator(channels_img=3,features_d=128)

In [6]:
summary(model=dis)

Layer (type:depth-idx)                   Param #
Discriminator                            --
├─VGG: 1-1                               --
│    └─Sequential: 2-1                   --
│    │    └─Conv2d: 3-1                  1,792
│    │    └─ReLU: 3-2                    --
│    │    └─Conv2d: 3-3                  36,928
│    │    └─ReLU: 3-4                    --
│    │    └─MaxPool2d: 3-5               --
│    │    └─Conv2d: 3-6                  73,856
│    │    └─ReLU: 3-7                    --
│    │    └─Conv2d: 3-8                  147,584
│    │    └─ReLU: 3-9                    --
│    │    └─MaxPool2d: 3-10              --
│    │    └─Conv2d: 3-11                 295,168
│    │    └─ReLU: 3-12                   --
│    │    └─Conv2d: 3-13                 590,080
│    │    └─ReLU: 3-14                   --
│    │    └─Conv2d: 3-15                 590,080
│    │    └─ReLU: 3-16                   --
│    │    └─MaxPool2d: 3-17              --
│    │    └─Conv2d: 3-18                

In [7]:
x=torch.randn(size=(1,3,480,480))
res=dis(x)

In [8]:
print(res)

tensor([[[[0.5038, 0.4998, 0.5008, 0.4971, 0.5012, 0.5031, 0.4981, 0.5016,
           0.5024, 0.4940, 0.4990, 0.5007, 0.4960, 0.5003],
          [0.4920, 0.5065, 0.4976, 0.5029, 0.4939, 0.4999, 0.4973, 0.4994,
           0.4966, 0.5020, 0.4966, 0.4985, 0.5029, 0.5043],
          [0.5037, 0.5018, 0.5033, 0.5034, 0.5029, 0.5008, 0.5065, 0.4987,
           0.5028, 0.5034, 0.4956, 0.5029, 0.4970, 0.4972],
          [0.4993, 0.4993, 0.5034, 0.4956, 0.5003, 0.4960, 0.4993, 0.5013,
           0.5024, 0.5007, 0.5026, 0.4997, 0.5055, 0.5001],
          [0.5028, 0.5033, 0.4996, 0.5095, 0.5002, 0.4984, 0.5017, 0.4968,
           0.4990, 0.5000, 0.4980, 0.5002, 0.5041, 0.5019],
          [0.4979, 0.5007, 0.5002, 0.4991, 0.5030, 0.5025, 0.5049, 0.4989,
           0.5042, 0.4996, 0.5035, 0.5022, 0.4992, 0.4987],
          [0.5045, 0.4939, 0.4998, 0.5016, 0.4984, 0.5052, 0.4965, 0.5002,
           0.4987, 0.5038, 0.4969, 0.5030, 0.4959, 0.4989],
          [0.5020, 0.5028, 0.4964, 0.5038, 0.5001, 0.49

In [9]:
class Generator(nn.Module):
    def __init__(self, channels_noise, channels_img, features_g):
        super(Generator, self).__init__()
        self.net = nn.Sequential(
            # Input: N x channels_noise x 1 x 1
            self._block(channels_noise, features_g * 32, 4, 1, 0),
            self._block(features_g * 32, features_g * 16, 4, 1, 0),  # img: 4x4
            self._block(features_g * 16, features_g * 8, 4, 2, 1),   # img: 8x8
            self._block(features_g * 8, features_g * 4, 4, 2, 1),    # img: 16x16
            self._block(features_g * 4, features_g * 2, 4, 2, 1),    # img: 32x32

            nn.ConvTranspose2d(
                features_g * 2, channels_img, kernel_size=4, stride=2, padding=1
            ),
            nn.Tanh(),
        )

    def _block(self, in_channels, out_channels, kernel_size, stride, padding):
        return nn.Sequential(
            nn.ConvTranspose2d(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        return self.net(x)


In [10]:
N, in_channels, H, W = 1, 3, 480, 480
noise_dim = 128

In [11]:
gen = Generator(128, in_channels, 8)


In [12]:
class Dataset(torch.utils.data.Dataset):
    def __init__(self,images_dir,transform):
        super(Dataset,self).__init__()
        self.images_dir=images_dir
        self.images=os.listdir(self.images_dir)
        self.transform=transform
    def __len__(self):
        return len(os.listdir(self.images_dir))
    def __getitem__(self,idx):
        image_dir=self.images[idx]
        image=cv2.imread(os.path.join(self.images_dir,image_dir))
        image=cv2.cvtColor(image,cv2.COLOR_BGR2RGB)
        label=torch.tensor(data=[1])
        image=self.transform(image)
        return image,label
        
        

In [13]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
LEARNING_RATE_G = 2e-4  # could also use two lrs, one for gen and one for disc
LEARNING_RATE_D=1e-2
BATCH_SIZE = 1
IMAGE_SIZE = 480
CHANNELS_IMG = 3
NOISE_DIM = 128
NUM_EPOCHS = 100
FEATURES_DISC = 64
FEATURES_GEN = 64

transforms = torchvision.transforms.Compose(
    [
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Resize(IMAGE_SIZE,antialias=True)
    ]
)
print(device)

cuda


In [14]:
training_data=Dataset(images_dir="GAN_dataset",transform=transforms)

In [15]:
image,_=training_data[0]

In [16]:
print(image)

tensor([[[0.3792, 0.4387, 0.4141,  ..., 0.5222, 0.5079, 0.5902],
         [0.3712, 0.4364, 0.4326,  ..., 0.5320, 0.5043, 0.6215],
         [0.4543, 0.4513, 0.4479,  ..., 0.6220, 0.5585, 0.6169],
         ...,
         [0.5989, 0.6022, 0.6065,  ..., 0.5697, 0.5293, 0.5399],
         [0.6087, 0.6046, 0.5976,  ..., 0.5703, 0.5364, 0.5465],
         [0.5793, 0.5925, 0.6008,  ..., 0.5100, 0.5441, 0.5217]],

        [[0.2775, 0.3317, 0.3092,  ..., 0.3305, 0.3292, 0.4293],
         [0.2833, 0.3430, 0.3435,  ..., 0.3421, 0.3418, 0.4830],
         [0.3845, 0.3930, 0.4101,  ..., 0.4259, 0.4134, 0.5044],
         ...,
         [0.5353, 0.5300, 0.5318,  ..., 0.5061, 0.4720, 0.4962],
         [0.5437, 0.5321, 0.5218,  ..., 0.5129, 0.4760, 0.4937],
         [0.5085, 0.5164, 0.5200,  ..., 0.4501, 0.4783, 0.4593]],

        [[0.2445, 0.3044, 0.2837,  ..., 0.2612, 0.2465, 0.3430],
         [0.2625, 0.3296, 0.3386,  ..., 0.2704, 0.2465, 0.3752],
         [0.3656, 0.3787, 0.4045,  ..., 0.3423, 0.2999, 0.

In [17]:
print(training_data.__len__())

56


In [18]:
train_dataloader = DataLoader(training_data, batch_size=BATCH_SIZE, shuffle=True)


In [19]:
gen=gen.to(device=device)
disc=dis.to(device=device)

In [20]:
opt_gen = torch.optim.SGD(gen.parameters(), lr=LEARNING_RATE_G)
opt_disc = torch.optim.SGD(disc.parameters(), lr=LEARNING_RATE_D)
criterion = nn.BCELoss()


In [21]:
from torch.utils.tensorboard import SummaryWriter
criterion = nn.BCELoss()
fixed_noise = torch.randn(1, NOISE_DIM, 1, 1).to(device)
writer_real = SummaryWriter(f"logs/real")
writer_fake = SummaryWriter(f"logs/fake")
step = 0



In [22]:
gen.train()
disc.train()

Discriminator(
  (vgg16): VGG(
    (features): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): ReLU(inplace=True)
      (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (6): ReLU(inplace=True)
      (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (8): ReLU(inplace=True)
      (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (11): ReLU(inplace=True)
      (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (13): ReLU(inplace=True)
      (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (15): ReLU(inplace=True)
      

In [23]:
def initialize_weights(model):
    # Initializes weights according to the DCGAN paper
    for m in model.modules():
        if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.BatchNorm2d)):
            nn.init.normal_(m.weight.data, 0.0, 0.02)

In [24]:
initialize_weights(gen)
#initialize_weights(disc)

In [25]:
# Assuming your generator expects 128 channels
for epoch in range(NUM_EPOCHS):
    for batch_idx, (real, _) in enumerate(train_dataloader):
        real = real.to(device)
        noise = torch.randn(BATCH_SIZE, 128, 1, 1).to(device)  # Use 128 channels for the noise
        fake = gen(noise)

        ### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
        disc_real = disc(real).reshape(-1)
        loss_disc_real = criterion(disc_real, torch.ones_like(disc_real))
        disc_fake = disc(fake.detach()).reshape(-1)
        loss_disc_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        loss_disc = (loss_disc_real + loss_disc_fake) / 2
        disc.zero_grad()
        loss_disc.backward()
        opt_disc.step()

        ### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z))
        output = disc(fake).reshape(-1)
        loss_gen = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        loss_gen.backward()
        opt_gen.step()

        # Print losses occasionally and print to tensorboard
        if batch_idx % 100 == 0:
            print(
                f"Epoch [{epoch}/{NUM_EPOCHS}] Batch {batch_idx+1}/{len(train_dataloader)} \
                  Loss D: {loss_disc:.4f}, loss G: {loss_gen:.4f}"
            )

            with torch.no_grad():
                fake = gen(fixed_noise)
                # take out (up to) 32 examples
                img_grid_real = torchvision.utils.make_grid(real[:32], normalize=True)
                img_grid_fake = torchvision.utils.make_grid(fake[:32], normalize=True)
    
                writer_real.add_image("Real", img_grid_real, global_step=step)
                writer_fake.add_image("Fake", img_grid_fake, global_step=step)

            step += 1

Epoch [0/100] Batch 1/56                   Loss D: 0.6930, loss G: 0.6915
Epoch [1/100] Batch 1/56                   Loss D: 0.3880, loss G: 0.7249
Epoch [2/100] Batch 1/56                   Loss D: 0.0544, loss G: 2.4040
Epoch [3/100] Batch 1/56                   Loss D: 0.0172, loss G: 3.7571
Epoch [4/100] Batch 1/56                   Loss D: 0.0579, loss G: 2.9082
