In [1]:
import os
import numpy as np
import torch
from tqdm import tqdm
import torch.nn as nn
import torch.utils.data as data_utils
import wandb
import custom_loaders
import conv_layers

In [2]:
# Notebook-wide model/data constants.
latent_len = 100   # number of channels of the generator's noise input
img_size = 64      # image side length handed to the data loader
n_channels = 1     # number of image channels


class Args:
    """Settings object passed to ``custom_loaders.get_data_loader``.

    Every setting is a keyword argument whose default reproduces the value
    that used to be hard-coded, so ``Args()`` behaves exactly as before while
    e.g. the dataset location can now be overridden without editing the class.

    Parameters
    ----------
    dataset : str
        Dataset identifier understood by the loader.
    imagePath : str
        Directory containing the images.
    image_size : int or None
        Image side length; ``None`` means "use the notebook-level ``img_size``".
    download : bool
        Forwarded to the loader unchanged.
    imgC : int or None
        Number of image channels; ``None`` means "use ``n_channels``".
    num_images : int or None
        Forwarded to the loader unchanged
        (presumably ``None`` means "no limit" -- TODO confirm in custom_loaders).
    """

    def __init__(self, dataset='GAN', imagePath='/root/data/data/JSRT/Images',
                 image_size=None, download=False, imgC=None, num_images=None):
        self.dataset = dataset
        self.imagePath = imagePath
        # Fall back to the notebook globals at call time, as the original did.
        self.image_size = img_size if image_size is None else image_size
        self.download = download
        self.imgC = n_channels if imgC is None else imgC
        self.num_images = num_images

# Bundle the loader settings and hand them to the project's loader helper.
# Despite the helper's name, the returned object is treated as a dataset:
# it is wrapped in a torch DataLoader further down in the notebook.
args = Args()
train_dataset = custom_loaders.get_data_loader(args)

100%|██████████| 247/247 [00:02<00:00, 105.81it/s]


In [7]:
# Run on the GPU when torch can see one, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Channel widths handed to the generator / discriminator constructors.
channelsG = [latent_len, 64, 128, 128, 32, n_channels]
channelsD = [n_channels, 64, 128, 128, 32, 1]

# Sanity checks: both lists have the same number of entries, and the image
# side length equals 2 ** (number of entries).
assert len(channelsD) == len(channelsG)
assert img_size == 2 ** len(channelsD)

# Generator first, then discriminator (construction order is kept so that
# parameter initialisation consumes the RNG in the same sequence).
G = conv_layers.GeneratorTrans(channelsG).to(device)
# Alternative generator:
# G = conv_layers.GeneratorUpSample(channelsG).to(device)
D = conv_layers.Discriminator(channelsD).to(device)

# The loss works on raw logits; one Adam optimiser per network.
criterion = nn.BCEWithLogitsLoss()
optimizerG, optimizerD = (
    torch.optim.Adam(net.parameters(), lr=1e-3) for net in (G, D)
)
wandb.init(project='pytorch-gan', entity='basujindal123')

In [8]:
# --- Hyper-parameters -------------------------------------------------------
bs = 32          # batch size (also the number of noise vectors drawn per step)
epochs = 100     # passes over train_loader
D_epochs = 5     # the generator is updated once every D_epochs batches
log_iter = 10    # wandb logging period, in batches
log = True       # master switch for wandb logging

# --- Data and the fixed noise used for the logged sample images -------------
train_loader = data_utils.DataLoader(train_dataset, batch_size=bs, shuffle=True)
fixed_noise = torch.rand(bs, latent_len, 1, 1).to(device)

# --- Mutable training state, updated by the training loop -------------------
lossD_Real = lossD_Fake = lossG = 0
iter = 0         # global batch counter (note: shadows the builtin `iter`)

In [9]:
# GAN training loop.
#   * The discriminator D is updated on every batch.
#   * The generator G is updated once every `D_epochs` batches.
#   * Every `log_iter` batches (when `log` is set) the discriminator losses
#     accumulated since the previous log, the latest generator loss and sample
#     images are sent to wandb.
for epoch in range(epochs):
    for real_imgs in tqdm(train_loader):

        iter += 1

        # ---- Discriminator step ------------------------------------------
        optimizerD.zero_grad()

        # Fakes are produced under no_grad: this step must only update D.
        with torch.no_grad():
            z = torch.rand(bs, latent_len, 1, 1).to(device)
            fake_imgs = G(z)

        # Fake batch -> target 0.
        output = D(fake_imgs).flatten()
        lossF = criterion(output, torch.zeros_like(output))

        # Real batch -> target 1.
        real_imgs = real_imgs.to(device)
        output = D(real_imgs).flatten()
        lossR = criterion(output, torch.ones_like(output))

        lossD = lossR + lossF
        lossD.backward()
        optimizerD.step()
        lossD_Real += lossR.item()
        # Fix: this accumulator previously added lossR, so the logged
        # "fake" loss was just a copy of the "real" loss.
        lossD_Fake += lossF.item()

        # ---- Generator step (every D_epochs batches) ---------------------
        if (iter + 1) % D_epochs == 0:
            G.train()
            optimizerG.zero_grad()
            z = torch.rand(bs, latent_len, 1, 1).to(device)
            fake_imgs = G(z)
            output = D(fake_imgs).flatten()

            # G is rewarded when D labels its samples as real (target 1).
            lossG_step = criterion(output, torch.ones_like(output))
            lossG_step.backward()
            optimizerG.step()
            # Keep only the Python float for logging.
            lossG = lossG_step.item()

        # ---- Logging (every log_iter batches) ----------------------------
        if log and (iter + 1) % log_iter == 0:
            G.eval()
            with torch.no_grad():
                fixed_fake_imgs = G(fixed_noise)
            # Fix: restore training mode; otherwise the following
            # discriminator steps sample fakes from an eval-mode generator
            # until the next generator update calls G.train().
            G.train()

            # NOTE(review): `[:-4]` keeps all images except the last four.
            # If only four samples were meant to be logged this should be
            # `[:4]` -- confirm the intent.
            wandb.log({
                'lossG': lossG,
                'lossD_Real': lossD_Real,
                'lossD_Fake': lossD_Fake,
                'lossD': lossD_Real + lossD_Fake,
                'Fake Images': [wandb.Image(img) for img in fixed_fake_imgs[:-4]],
                'Real Images': [wandb.Image(img) for img in real_imgs.detach()[:-4]],
            })

            # Restart the running sums for the next logging window.
            lossD_Real = 0
            lossD_Fake = 0

100%|██████████| 8/8 [00:00<00:00, 94.65it/s]
100%|██████████| 8/8 [00:00<00:00, 31.52it/s]
100%|██████████| 8/8 [00:00<00:00, 20.76it/s]
100%|██████████| 8/8 [00:00<00:00, 29.66it/s]
100%|██████████| 8/8 [00:00<00:00, 29.33it/s]
100%|██████████| 8/8 [00:00<00:00, 98.28it/s]
100%|██████████| 8/8 [00:00<00:00, 31.72it/s]
100%|██████████| 8/8 [00:00<00:00, 30.51it/s]
100%|██████████| 8/8 [00:00<00:00, 29.29it/s]
100%|██████████| 8/8 [00:00<00:00, 30.26it/s]
100%|██████████| 8/8 [00:00<00:00, 93.00it/s]
100%|██████████| 8/8 [00:00<00:00, 20.38it/s]
100%|██████████| 8/8 [00:00<00:00, 30.25it/s]
100%|██████████| 8/8 [00:00<00:00, 29.78it/s]
100%|██████████| 8/8 [00:00<00:00, 27.72it/s]
100%|██████████| 8/8 [00:00<00:00, 82.69it/s]
100%|██████████| 8/8 [00:00<00:00, 29.43it/s]
100%|██████████| 8/8 [00:00<00:00, 29.37it/s]
100%|██████████| 8/8 [00:00<00:00, 30.97it/s]
100%|██████████| 8/8 [00:00<00:00, 20.21it/s]
100%|██████████| 8/8 [00:00<00:00, 94.60it/s]
100%|██████████| 8/8 [00:00<00:00,

In [None]:
# Channel widths for a down-sampling path, from image channels to latent size.
channelsDown = [n_channels, 64, 128, 128, 128, 256, latent_len]
# The up-sampling path uses the same widths in the opposite order
# (latent size back to image channels).
channelsUp = list(reversed(channelsDown))


class UNet(nn.Module):
    """Down-sampling stack followed by an up-sampling stack.

    ``DownConv2d`` and ``UpSampleConv`` are not defined in this cell and must
    already be in scope (presumably they come from ``conv_layers`` -- TODO
    confirm).

    Parameters
    ----------
    channels : sequence of int
        Channel widths; ``len(channels) - 2`` down blocks and the same number
        of up blocks are created from consecutive pairs.
    kernelSizeUp : int
        Kernel size of every up block and of the final output convolution.
    kernelSizeDown : int
        Kernel size of every down block and of the ``out`` convolution.

    NOTE(review): both stacks are built from the same ``channels`` list in
    the same order, so the first up block expects ``channels[0]`` input
    channels while the down stack ends with ``channels[-2]`` channels.
    The decoder most likely needs the reversed list -- confirm before use.
    """

    def __init__(self, channels, kernelSizeUp=3, kernelSizeDown=4):
        super().__init__()

        n_blocks = len(channels) - 2

        # Fix: the kernel-size arguments were ignored here (4 and 3 were
        # hard-coded); the defaults keep the previous behaviour.
        self.dis = nn.ModuleList([
            DownConv2d(channels[i], channels[i + 1], kernelSize=kernelSizeDown)
            for i in range(n_blocks)
        ])
        self.out = nn.Conv2d(in_channels=channels[-2], out_channels=channels[-1],
                             kernel_size=kernelSizeDown)

        self.gen = nn.ModuleList([
            UpSampleConv(channels[i], channels[i + 1], kernelSize=kernelSizeUp)
            for i in range(n_blocks)
        ])
        self.upSample = nn.Upsample(scale_factor=2, mode='nearest')
        # "Same" padding for odd kernel sizes.
        self.output = nn.Conv2d(channels[-2], channels[-1], kernel_size=kernelSizeUp,
                                padding=(kernelSizeUp - 1) // 2)

    def forward(self, image):
        """Run the down blocks, then the up blocks, and squash with tanh.

        Each up block and the final convolution are preceded by a
        nearest-neighbour x2 upsampling.
        """
        for block in self.dis:
            image = block(image)

        # NOTE(review): the result of this projection is discarded, exactly
        # as in the original code; it was probably meant to feed the up path
        # (image = self.out(image)). Left unchanged until the intended
        # architecture is confirmed.
        self.out(image)

        for block in self.gen:
            image = block(self.upSample(image))

        return torch.tanh(self.output(self.upSample(image)))

In [None]:
# UNet Architecture
class UNet(nn.Module):
    """U-Net style encoder/decoder built from unpadded 3x3 convolutions.

    NOTE: defining this class rebinds the name ``UNet`` and therefore
    replaces any class of the same name defined earlier in the notebook.

    Parameters
    ----------
    in_channels : int
        Channels of the input image (default 3, the former hard-coded value).
    out_channels : int
        Channels produced by the final 1x1 convolution (default 1).
    base_channels : int
        Width of the first stage; later stages use 2x, 4x and 8x this value
        (default 64 gives the original 64/128/256/512 layout).

    Input is a 4-D tensor ``(N, in_channels, H, W)``. Because no convolution
    is padded, the output is spatially much smaller than the input; with the
    three pooling stages used here, inputs below 100 pixels per side produce
    an empty output.
    """

    def __init__(self, in_channels=3, out_channels=1, base_channels=64):
        super().__init__()
        c1 = base_channels
        c2, c3, c4 = 2 * c1, 4 * c1, 8 * c1

        # ---- Encoder stage 1 ----
        self.conv1 = nn.Conv2d(in_channels, c1, 3)
        self.conv2 = nn.Conv2d(c1, c1, 3)
        self.bn1 = nn.BatchNorm2d(c1)
        self.bn2 = nn.BatchNorm2d(c1)
        self.pool1 = nn.MaxPool2d(2, 2)

        # ---- Encoder stage 2 ----
        self.conv3 = nn.Conv2d(c1, c2, 3)
        self.conv4 = nn.Conv2d(c2, c2, 3)
        self.bn3 = nn.BatchNorm2d(c2)
        self.bn4 = nn.BatchNorm2d(c2)
        self.pool2 = nn.MaxPool2d(2, 2)

        # ---- Encoder stage 3 ----
        self.conv5 = nn.Conv2d(c2, c3, 3)
        self.conv6 = nn.Conv2d(c3, c3, 3)
        self.bn5 = nn.BatchNorm2d(c3)
        self.bn6 = nn.BatchNorm2d(c3)
        self.pool3 = nn.MaxPool2d(2, 2)

        # ---- Bottleneck ----
        self.conv7 = nn.Conv2d(c3, c4, 3)
        self.conv8 = nn.Conv2d(c4, c4, 3)
        self.bn7 = nn.BatchNorm2d(c4)
        self.bn8 = nn.BatchNorm2d(c4)
        self.upconv1 = nn.ConvTranspose2d(c4, c3, 2, stride=2)

        # ---- Decoder stage 1 (input = skip c3 + upsampled c3) ----
        self.conv9 = nn.Conv2d(c4, c3, 3)
        self.conv10 = nn.Conv2d(c3, c3, 3)
        self.bn9 = nn.BatchNorm2d(c3)
        self.bn10 = nn.BatchNorm2d(c3)
        self.upconv2 = nn.ConvTranspose2d(c3, c2, 2, stride=2)

        # ---- Decoder stage 2 (input = skip c2 + upsampled c2) ----
        self.conv11 = nn.Conv2d(c3, c2, 3)
        self.conv12 = nn.Conv2d(c2, c2, 3)
        self.bn11 = nn.BatchNorm2d(c2)
        self.bn12 = nn.BatchNorm2d(c2)
        self.upconv3 = nn.ConvTranspose2d(c2, c1, 2, stride=2)

        # ---- Decoder stage 3 (input = skip c1 + upsampled c1) + head ----
        self.conv13 = nn.Conv2d(c2, c1, 3)
        self.conv14 = nn.Conv2d(c1, c1, 3)
        self.conv15 = nn.Conv2d(c1, out_channels, 1)
        self.bn13 = nn.BatchNorm2d(c1)
        self.bn14 = nn.BatchNorm2d(c1)

    @staticmethod
    def _center_crop(skip, like):
        """Crop ``skip`` (dims 2 and 3) to the spatial size of ``like``, centred.

        Uses explicit ``start:start + size`` bounds. The former
        ``pad:-pad`` slicing returned an empty tensor when the pad was 0 and
        a tensor one element too large when the size difference was odd.
        """
        height, width = like.shape[2], like.shape[3]
        top = (skip.shape[2] - height) // 2
        left = (skip.shape[3] - width) // 2
        return skip[:, :, top:top + height, left:left + width]

    def forward(self, x):
        """Return the network output for ``x``.

        The result has shape ``(N, out_channels, h, w)``; the channel axis is
        squeezed away when ``out_channels == 1``.
        """
        # torch.relu is used instead of F.relu: `F` was never imported in the
        # notebook, so the original forward raised NameError.
        relu = torch.relu

        # Encoder: conv -> batch-norm -> ReLU, keeping pre-pool activations
        # (x1, x2, x3) for the skip connections.
        x = relu(self.bn1(self.conv1(x)))
        x1 = relu(self.bn2(self.conv2(x)))
        x = self.pool1(x1)

        x = relu(self.bn3(self.conv3(x)))
        x2 = relu(self.bn4(self.conv4(x)))
        x = self.pool2(x2)

        x = relu(self.bn5(self.conv5(x)))
        x3 = relu(self.bn6(self.conv6(x)))
        x = self.pool3(x3)

        # Bottleneck.
        x = relu(self.bn7(self.conv7(x)))
        x = relu(self.bn8(self.conv8(x)))
        x = self.upconv1(x)

        # Decoder: note the order here is conv -> ReLU -> batch-norm, which
        # differs from the encoder; kept as in the original.
        x = torch.cat((self._center_crop(x3, x), x), 1)
        x = self.bn9(relu(self.conv9(x)))
        x = self.bn10(relu(self.conv10(x)))
        x = self.upconv2(x)

        x = torch.cat((self._center_crop(x2, x), x), 1)
        x = self.bn11(relu(self.conv11(x)))
        x = self.bn12(relu(self.conv12(x)))
        x = self.upconv3(x)

        x = torch.cat((self._center_crop(x1, x), x), 1)
        x = self.bn13(relu(self.conv13(x)))
        x = self.bn14(relu(self.conv14(x)))
        x = self.conv15(x)
        # Trim a 2-pixel border from the prediction.
        x = x[:, :, 2:-2, 2:-2]

        return x.squeeze(1)

In [None]:
## MNIST
# NOTE(review): `transforms`, `torchvision` and `plt` are not imported in the
# visible cells of this notebook -- confirm they are imported before this
# cell runs. This cell also rebinds `train_dataset`, which an earlier cell
# uses for the dataset returned by custom_loaders.

# Only the tensor conversion is active; the normalisation step is disabled:
#   transforms.Normalize((0.5, ), (0.5, ))
trans = transforms.Compose([transforms.ToTensor()])

# Both splits share the same location, download flag and transform.
mnist_kwargs = dict(root="datasets", download=True, transform=trans)
train_dataset = torchvision.datasets.MNIST(train=True, **mnist_kwargs)
test_dataset = torchvision.datasets.MNIST(train=False, **mnist_kwargs)

# Show the first test image, with its raw values divided by 255.
plt.imshow((test_dataset.data.float()[0] / 255).numpy())