In [18]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

In [12]:
class Discriminator(nn.Module):
  """DCGAN-style discriminator scoring 512x512 images with a probability in (0, 1).

  A stride-2 stem conv (no BatchNorm) takes 512 -> 256.  Six stride-2
  conv/BatchNorm/LeakyReLU blocks then take 256 -> 4, each doubling the channel
  width (features_d -> features_d * 64).  A 4x4 valid conv collapses the 4x4
  map to one value per sample, and a Sigmoid squashes it into (0, 1).

  Args:
    img_channel: number of channels in the input images.
    features_d: channel width of the stem; the widest block has features_d * 64.
  """

  def __init__(self,img_channel,features_d):
    super(Discriminator,self).__init__()
    # Stem: 512 -> 256, plain LeakyReLU (no normalisation on the first layer).
    layers = [
        nn.Conv2d(img_channel,features_d,kernel_size=4,stride=2,padding=1),
        nn.LeakyReLU(0.2),
    ]
    # Six down-sampling blocks, 256 -> 4, doubling the width each time.
    width = features_d
    for _ in range(6):
      layers.append(self._block(width,width*2,4,2,1))
      width *= 2
    # Head: 4x4 -> 1x1 score.
    layers.append(nn.Conv2d(width,1,kernel_size=4,stride=2,padding=0))
    layers.append(nn.Sigmoid())
    self.disc = nn.Sequential(*layers)

  def _block(self,in_channel,out_channel,kernel,stride,padding):
    """Conv (bias-free, BatchNorm follows) -> BatchNorm2d -> LeakyReLU(0.2)."""
    conv = nn.Conv2d(in_channel,out_channel,kernel,stride,padding,bias=False)
    norm = nn.BatchNorm2d(out_channel)
    act = nn.LeakyReLU(0.2)
    return nn.Sequential(conv,norm,act)

  def forward(self,x):
    """Score a batch; for (N, C, 512, 512) input the result is (N, 1, 1, 1)."""
    return self.disc(x)

In [13]:
class Generator(nn.Module):
  """DCGAN-style generator: latent (N, z_dim, 1, 1) -> images (N, img_channel, 512, 512).

  A stride-1 transposed conv projects the 1x1 latent to a 4x4 map with
  features_g * 128 channels.  Six stride-2 up-sampling blocks then take
  4 -> 256 while halving the channel width down to features_g * 2.  A final
  transposed conv produces img_channel channels at 512x512, and Tanh bounds the
  output to [-1, 1].

  Args:
    z_dim: latent vector length (channel count of the 1x1 input).
    img_channel: number of channels in the generated images.
    features_g: base width; the widest layer has features_g * 128 channels.
  """

  def __init__(self,z_dim,img_channel,features_g):
    super(Generator,self).__init__()
    # Channel multipliers of the feature maps at 4, 8, 16, ..., 256 pixels.
    mults = (128, 64, 32, 16, 8, 4, 2)
    layers = [self._block(z_dim,features_g*mults[0],4,1,0)]  # 1 -> 4
    for wide, narrow in zip(mults, mults[1:]):
      layers.append(self._block(features_g*wide,features_g*narrow,4,2,1))
    layers.append(nn.ConvTranspose2d(features_g*mults[-1],img_channel,4,2,1))  # 256 -> 512
    layers.append(nn.Tanh())
    self.gen = nn.Sequential(*layers)

  def _block(self,in_channel,out_channel,kernel,stride,padding):
    """ConvTranspose (bias-free, BatchNorm follows) -> BatchNorm2d -> LeakyReLU(0.2)."""
    up = nn.ConvTranspose2d(in_channel,out_channel,kernel,stride,padding,bias=False)
    norm = nn.BatchNorm2d(out_channel)
    act = nn.LeakyReLU(0.2)
    return nn.Sequential(up,norm,act)

  def forward(self,x):
    """Map latent noise of shape (N, z_dim, 1, 1) to images in [-1, 1]."""
    return self.gen(x)


In [14]:
def intilize_w(model):
  """Apply DCGAN-style weight initialisation to ``model`` in place.

  - Conv2d / ConvTranspose2d weights ~ N(0, 0.02); their biases keep the PyTorch
    default initialisation.
  - Affine BatchNorm2d scale (gamma) ~ N(1, 0.02) and shift (beta) = 0, as in the
    PyTorch DCGAN reference.  Drawing gamma around 0 instead makes every
    BatchNorm layer start out suppressing its output (and flips the sign of
    roughly half the channels).
  - Non-affine BatchNorm2d (no weight/bias) is skipped rather than crashing.

  Args:
    model: any ``nn.Module``; all sub-modules are visited via ``model.modules()``.
  """
  for m in model.modules():
    if isinstance(m,(nn.Conv2d,nn.ConvTranspose2d)):
      nn.init.normal_(m.weight,0.0,0.02)
    elif isinstance(m,nn.BatchNorm2d) and m.weight is not None:
      nn.init.normal_(m.weight,1.0,0.02)
      nn.init.zeros_(m.bias)

In [20]:
# --- Runtime ---
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# --- Optimisation ---
learning_rate = 2e-4      # Adam step size for both networks
batch_size = 64
num_epochs = 1_000

# --- Data ---
image_size = 512          # Generator / Discriminator layer stacks are sized for 512x512
img_channels = 3          # colour (3-channel) images

# --- Model ---
z_dim = 100               # latent vector length fed to the generator
# NOTE(review): the generator's widest layer has features_g * 128 channels and the
# discriminator's features_d * 64 (8192 / 4096 at 64). That is very large for
# 512x512 training -- confirm it fits in GPU memory or reduce these widths.
features_d = 64
features_g = 64

In [17]:
# Preprocessing for training images.
# Resize takes the target size as ONE argument: an (h, w) tuple gives an exact
# square, which the 512x512-only networks require. Its second positional
# parameter is `interpolation`, so Resize(image_size, image_size) was wrong, and
# an int alone would keep the aspect ratio (non-square images could not be batched).
transform=transforms.Compose([
    transforms.Resize((image_size,image_size)),
    transforms.ToTensor(),
    # (x - 0.5) / 0.5 maps ToTensor's [0, 1] range to [-1, 1], the generator's Tanh range.
    transforms.Normalize(
        [0.5 for _ in range(img_channels)],[0.5 for _ in range(img_channels)]
    ),
])

torch.Size([8, 1, 1, 1])


In [None]:
# Pass the `transform` pipeline (not the `transforms` module) so each image is
# resized, tensorised and normalised.
# NOTE(review): root=" " is a placeholder -- set it to an ImageFolder-style
# directory (images inside class sub-folders) before running.
dataset=datasets.ImageFolder(root=" ",transform=transform)
loader=DataLoader(dataset,batch_size=batch_size,shuffle=True)

In [24]:
# Build and initialise both networks, then move them to `device`. The training
# loop sends every real batch to `device`, so the models must live there too.
gen=Generator(z_dim,img_channels,features_g)
disc=Discriminator(img_channels,features_d)
intilize_w(gen)
intilize_w(disc)
gen=gen.to(device)
disc=disc.to(device)

# Optimisers are created after the move so they track the on-device parameters.
opt_gen=optim.Adam(gen.parameters(),lr=learning_rate,betas=(0.5,0.999))
opt_disc=optim.Adam(disc.parameters(),lr=learning_rate,betas=(0.5,0.999))
# The discriminator ends in a Sigmoid, so BCE on probabilities is used.
criterion=nn.BCELoss()

In [None]:
# Training mode for both networks (BatchNorm normalises with batch statistics).
gen.train(mode=True)
disc.train(mode=True)

In [28]:
from torchvision.utils import save_image
import os
# Output folder for sample images written during training. Defaults to a Colab
# Google Drive path; set the SAMPLE_DIR environment variable to use another one.
sample_dir=os.environ.get('SAMPLE_DIR','/content/drive/MyDrive/generated')
# save_image does not create missing folders, so create it once up front.
os.makedirs(sample_dir,exist_ok=True)

In [None]:
# One latent vector drawn once, so saved samples are comparable across steps.
fixed_noise=torch.randn(1,z_dim,1,1,device=device)

for epoch in range(num_epochs):
  for batch_idx,(real,_) in enumerate(loader):
    real=real.to(device)
    # Match the real batch size (the last batch of an epoch may be smaller).
    cur_batch=real.size(0)
    noise=torch.randn(cur_batch,z_dim,1,1,device=device)
    fake=gen(noise)

    # --- Discriminator step: minimise -[log D(x) + log(1 - D(G(z)))] ---
    disc_real=disc(real).reshape(-1)
    loss_disc_real=criterion(disc_real,torch.ones_like(disc_real))
    # detach(): this loss must not back-propagate into the generator, and the
    # generator graph then does not need to be retained for the next step.
    disc_fake=disc(fake.detach()).reshape(-1)
    loss_disc_fake=criterion(disc_fake,torch.zeros_like(disc_fake))
    disc_loss=loss_disc_real+loss_disc_fake
    disc.zero_grad()
    disc_loss.backward()
    opt_disc.step()

    # --- Generator step: minimise -log D(G(z)) (non-saturating loss) ---
    output=disc(fake).reshape(-1)
    loss_gen=criterion(output,torch.ones_like(output))
    gen.zero_grad()
    loss_gen.backward()
    opt_gen.step()

    if batch_idx % 50 == 0:
      print(
          f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} "
          f"Loss D: {disc_loss.item():.4f}, loss G: {loss_gen.item():.4f}"
      )
      with torch.no_grad():
        sample=gen(fixed_noise)
      # Encode epoch and batch separately: the old `batch_idx+epoch` index let
      # different (epoch, batch) pairs overwrite each other's files.
      fake_fname='generated-images1-e{0:04d}-b{1:05d}.png'.format(epoch,batch_idx)
      # normalize=True min-max rescales the Tanh output into [0, 1] before saving.
      save_image(sample[:1],os.path.join(sample_dir,fake_fname),normalize=True)



