In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

In [2]:
class Discriminator(nn.Module):
  def __init__(self,img_channel,features_d):
    super(Discriminator,self).__init__()
    self.disc=nn.Sequential(
        #512
        nn.Conv2d(img_channel,features_d,kernel_size=4,stride=2,padding=1),
        nn.LeakyReLU(0.2),
        #256
        self._block(features_d,features_d*2,4,2,1),
        #128
        self._block(features_d*2,features_d*4,4,2,1),
        #64
        self._block(features_d*4,features_d*8,4,2,1),
        #32
        self._block(features_d*8,features_d*16,4,2,1),
        #16
        self._block(features_d*16,features_d*32,4,2,1),
        #8
        self._block(features_d*32,features_d*64,4,2,1),
        #4
        nn.Conv2d(features_d*64,1,kernel_size=4,stride=2,padding=0),

    )

  def _block(self,in_channel,out_channel,kernel,stride,padding):
    return nn.Sequential(
        nn.Conv2d(in_channel,out_channel,kernel,stride,padding,bias=False),
        nn.InstanceNorm2d(out_channel,affine=True),
        nn.LeakyReLU(0.2),
    )

  def forward(self,x):
    return self.disc(x)

In [3]:
class Generator(nn.Module):
  def __init__(self,z_dim,img_channel,features_g):
    super(Generator,self).__init__()
    self.gen=nn.Sequential(
        self._block(z_dim,features_g*128,4,1,0),#4
        self._block(features_g*128,features_g*64,4,2,1),#8
        self._block(features_g*64,features_g*32,4,2,1),#16
        self._block(features_g*32,features_g*16,4,2,1),#32
        self._block(features_g*16,features_g*8,4,2,1),#64
        self._block(features_g*8,features_g*4,4,2,1),#128
        self._block(features_g*4,features_g*2,4,2,1),#256
        nn.ConvTranspose2d(features_g*2,img_channel,4,2,1),#512
        nn.Tanh()
    )
  def _block(self,in_channel,out_channel,kernel,stride,padding):
    return nn.Sequential(
        nn.ConvTranspose2d(in_channel,out_channel,kernel,stride,padding,bias=False),
        nn.BatchNorm2d(out_channel),
        nn.LeakyReLU(0.2),
    )

  def forward(self,x):
    return self.gen(x)


In [4]:
def intilize_w(model):
  for m in model.modules():
    if isinstance(m,(nn.Conv2d,nn.ConvTranspose2d,nn.BatchNorm2d)):
      nn.init.normal_(m.weight.data,0.0,0.02)

In [4]:
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
learning_rate=1e-4
batch_size=30
image_size=512
img_channels=3
z_dim=100
num_epochs=100
features_d=16
features_g=16
critic_iteration=5
lambda_gp=10

In [5]:
transform=transforms.Compose([
    transforms.Resize([image_size,image_size]),
    transforms.ToTensor(),
    transforms.Normalize(
        [0.5 for _ in range(img_channels)],[0.5 for _ in range(img_channels)]
    ),
])

In [10]:
dataset=datasets.ImageFolder(root="/content/drive/MyDrive/img",transform=transform)
loader=DataLoader(dataset,batch_size=batch_size,shuffle=True)

In [6]:
gen=Generator(z_dim,img_channels,features_g).to(device)
disc=Discriminator(img_channels,features_d).to(device)
# intilize_w(gen)
# intilize_w(disc)
PATH1="/content/drive/MyDrive/weights/gen.pth"
gen.load_state_dict(torch.load(PATH1))
PATH2="/content/drive/MyDrive/weights/disc.pth"
disc.load_state_dict(torch.load(PATH2))

opt_gen=optim.Adam(gen.parameters(),lr=learning_rate,betas=(0.0,0.9))
opt_disc=optim.Adam(disc.parameters(),lr=learning_rate,betas=(0.0,0.9))

gen.train()
disc.train()

Discriminator(
  (disc): Sequential(
    (0): Conv2d(3, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(negative_slope=0.2)
    (2): Sequential(
      (0): Conv2d(16, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (3): Sequential(
      (0): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (4): Sequential(
      (0): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
      (1): InstanceNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=False)
      (2): LeakyReLU(negative_slope=0.2)
    )
    (5): Sequential(
      (0): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=

In [7]:
def gp(disc,real,fake,device="cpu"):
  b,c,h,w=real.shape
  epsilon=torch.randn(b,1,1,1).repeat(1,c,h,w).to(device)
  interpolated_img=real*epsilon+fake*(1-epsilon)
  mixed_scores=disc(interpolated_img)
  gradient=torch.autograd.grad(
      inputs=interpolated_img,
      outputs=mixed_scores,
      grad_outputs=torch.ones_like(mixed_scores),
      create_graph=True,
      retain_graph=True,
  )[0]
  gradient=gradient.view(gradient.shape[0],-1)
  gradient_norm=gradient.norm(2,dim=1)
  gradient_penalty=torch.mean((gradient_norm-1)**2)
  return gradient_penalty

In [8]:
from torchvision.utils import save_image
import os
sample_dir='/content/drive/MyDrive/gen2'

In [11]:
from tqdm import tqdm
for epoch in range(num_epochs):
  for batch_idx,(real,_) in enumerate(tqdm(loader)):
    real=real.to(device)

    for _ in range(critic_iteration):
      noise=torch.randn(batch_size,z_dim,1,1).to(device)
      fake=gen(noise)
      disc_real=disc(real).reshape(-1)
      disc_fake=disc(fake).reshape(-1)
      gpp=gp(disc,real,fake,device=device)
      loss_disc=-(torch.mean(disc_real)-torch.mean(disc_fake))+lambda_gp*gpp
      disc.zero_grad()
      loss_disc.backward(retain_graph=True)
      opt_disc.step()

    output=disc(fake).reshape(-1)
    loss_gen=-torch.mean(output)
    gen.zero_grad()
    loss_gen.backward()
    opt_gen.step()

    if ((epoch % 5 == 0) and (batch_idx%100==0)):
            print(
                f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} \
                  Loss D: {loss_disc:.4f}, loss G: {loss_gen:.4f}"
            )
            fixed_noise=torch.randn(1,z_dim,1,1).to(device)
            with torch.no_grad():
                fake = gen(fixed_noise)
                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)

                fake_fname = 'generated-images1-{0:0=4d}.png'.format(batch_idx+epoch)
                save_image(img_grid_fake, os.path.join(sample_dir, fake_fname), nrow=1)

  0%|          | 0/253 [00:00<?, ?it/s]

Epoch [0/100] Batch 0/253                   Loss D: -140.5782, loss G: 9584.6328


 40%|███▉      | 101/253 [14:00<15:01,  5.93s/it]

Epoch [0/100] Batch 100/253                   Loss D: -502.6076, loss G: 9439.4736


 79%|███████▉  | 200/253 [23:08<04:53,  5.53s/it]

Epoch [0/100] Batch 200/253                   Loss D: -64.7062, loss G: 9932.8008


100%|██████████| 253/253 [28:02<00:00,  6.65s/it]
100%|██████████| 253/253 [23:01<00:00,  5.46s/it]
100%|██████████| 253/253 [23:02<00:00,  5.46s/it]
100%|██████████| 253/253 [23:01<00:00,  5.46s/it]
100%|██████████| 253/253 [23:02<00:00,  5.46s/it]
  0%|          | 0/253 [00:00<?, ?it/s]

Epoch [5/100] Batch 0/253                   Loss D: -25.2509, loss G: -11373.3174


 40%|███▉      | 100/253 [09:06<13:53,  5.45s/it]

Epoch [5/100] Batch 100/253                   Loss D: -33.5646, loss G: -11318.0908


 79%|███████▉  | 200/253 [18:13<04:49,  5.46s/it]

Epoch [5/100] Batch 200/253                   Loss D: -28.6698, loss G: -11499.8633


100%|██████████| 253/253 [23:04<00:00,  5.47s/it]
100%|██████████| 253/253 [23:02<00:00,  5.46s/it]
100%|██████████| 253/253 [23:01<00:00,  5.46s/it]
100%|██████████| 253/253 [23:00<00:00,  5.46s/it]
100%|██████████| 253/253 [23:00<00:00,  5.46s/it]
  0%|          | 1/253 [00:06<27:30,  6.55s/it]

Epoch [10/100] Batch 0/253                   Loss D: -50.6007, loss G: -11023.5938


 40%|███▉      | 101/253 [09:13<14:40,  5.79s/it]

Epoch [10/100] Batch 100/253                   Loss D: -36.1520, loss G: -10427.5645


 79%|███████▉  | 200/253 [18:12<04:49,  5.46s/it]

Epoch [10/100] Batch 200/253                   Loss D: -44.5852, loss G: -10230.9814


100%|██████████| 253/253 [23:02<00:00,  5.46s/it]
100%|██████████| 253/253 [23:00<00:00,  5.45s/it]
100%|██████████| 253/253 [23:00<00:00,  5.46s/it]
100%|██████████| 253/253 [23:00<00:00,  5.46s/it]
100%|██████████| 253/253 [22:59<00:00,  5.45s/it]
  0%|          | 1/253 [00:06<27:29,  6.55s/it]

Epoch [15/100] Batch 0/253                   Loss D: -21.1246, loss G: -10439.2363


 40%|███▉      | 101/253 [09:13<14:38,  5.78s/it]

Epoch [15/100] Batch 100/253                   Loss D: -8.0934, loss G: -10179.5566


 79%|███████▉  | 201/253 [18:19<05:00,  5.78s/it]

Epoch [15/100] Batch 200/253                   Loss D: -17.9825, loss G: -9938.2178


100%|██████████| 253/253 [23:02<00:00,  5.46s/it]
100%|██████████| 253/253 [23:00<00:00,  5.46s/it]
100%|██████████| 253/253 [23:01<00:00,  5.46s/it]
100%|██████████| 253/253 [23:00<00:00,  5.46s/it]
100%|██████████| 253/253 [23:00<00:00,  5.45s/it]
  0%|          | 1/253 [00:06<27:31,  6.55s/it]

Epoch [20/100] Batch 0/253                   Loss D: -27.7576, loss G: -9499.9902


 40%|███▉      | 101/253 [09:12<14:40,  5.79s/it]

Epoch [20/100] Batch 100/253                   Loss D: -24.0155, loss G: -9491.1650


 79%|███████▉  | 201/253 [18:18<05:00,  5.78s/it]

Epoch [20/100] Batch 200/253                   Loss D: -9.2519, loss G: -9415.0898


100%|██████████| 253/253 [23:02<00:00,  5.46s/it]
100%|██████████| 253/253 [23:00<00:00,  5.46s/it]
100%|██████████| 253/253 [23:00<00:00,  5.46s/it]
100%|██████████| 253/253 [23:00<00:00,  5.46s/it]
100%|██████████| 253/253 [23:00<00:00,  5.45s/it]
  0%|          | 0/253 [00:00<?, ?it/s]

Epoch [25/100] Batch 0/253                   Loss D: -25.4309, loss G: -9115.2549


 40%|███▉      | 101/253 [09:13<14:41,  5.80s/it]

Epoch [25/100] Batch 100/253                   Loss D: -20.8227, loss G: -8980.8105


 79%|███████▉  | 201/253 [18:19<05:00,  5.78s/it]

Epoch [25/100] Batch 200/253                   Loss D: -17.2978, loss G: -8819.5693


100%|██████████| 253/253 [23:02<00:00,  5.47s/it]
 83%|████████▎ | 211/253 [19:12<03:49,  5.46s/it]


KeyboardInterrupt: 

In [12]:
torch.save(gen.state_dict(),"/content/drive/MyDrive/weights/gen1.pth")

In [13]:
torch.save(disc.state_dict(),"/content/drive/MyDrive/weights/disc1.pth")