In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as dset
import matplotlib.pyplot as plt
import numpy as np
import os
import tarfile
import shutil
import requests
from torchvision.datasets import ImageFolder

# Run on the first CUDA GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

def download_xray_dataset(url, save_path, chunk_size=1 << 20, timeout=60):
    """Stream the file at ``url`` into ``save_path`` on disk.

    Args:
        url: HTTP(S) address of the file (archive) to download.
        save_path: local path to write; an existing file is overwritten.
        chunk_size: number of bytes requested per read from the response
            (default 1 MiB, so large archives are written in few large calls).
        timeout: seconds ``requests`` waits to connect and between received
            bytes before raising; ``None`` waits indefinitely.

    Raises:
        requests.HTTPError: if the server replies with a 4xx/5xx status.
    """
    # Using the response as a context manager releases the underlying
    # connection even if raise_for_status() or a write fails part-way.
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        with open(save_path, 'wb') as fd:
            for chunk in response.iter_content(chunk_size=chunk_size):
                fd.write(chunk)

# To fetch the data, uncomment the lines below. All archive parts are assumed
# to be named 'data_part1.tar.gz', 'data_part2.tar.gz', ... and are extracted
# in order into ./data.
# NOTE: on Python 3.12+ consider tar.extractall('./data', filter='data') to
# reject archive members with unsafe paths.
#
# download_xray_dataset('https://nihcc.box.com/shared/static/vfk49d74nhbxq3nqjg0900w5nvkorp5c.gz', 'data_part1.tar.gz')
#
# with tarfile.open('data_part1.tar.gz', 'r:gz') as tar:
#     tar.extractall('./data')

# ImageFolder only loads files that live inside class sub-directories, so every
# PNG sitting directly in ./data/images is moved into one "real_images" class
# folder. This runs on every execution (not only when the folder is first
# created), so images extracted from later archive parts are picked up too.
real_dir = './data/images/real_images'
os.makedirs(real_dir, exist_ok=True)

for img_file in os.listdir('./data/images'):
    src = os.path.join('./data/images', img_file)
    dst = os.path.join(real_dir, img_file)
    # Move only regular .png files; an image whose name already exists in
    # real_images is left in place rather than overwriting the copy there.
    if img_file.endswith('.png') and os.path.isfile(src) and not os.path.exists(dst):
        shutil.move(src, dst)

# Preprocessing pipeline:
#   * resize to 64x64 (change this if the GAN architecture changes),
#   * replicate the grayscale channel into 3 channels,
#   * convert to a tensor and shift pixels from [0, 1] to [-1, 1] so real
#     images share the range of the generator's Tanh output.
transform = transforms.Compose(
    [
        transforms.Resize((64, 64)),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

dataset = ImageFolder(root='./data/images', transform=transform)

# 384 (= 6 x 64) images per batch, reshuffled every epoch.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=384,
    shuffle=True,
)

class ResidualBlockUp(nn.Module):
    """Upsampling residual block built from transposed convolutions.

    The main path is ConvT(3x3, stride) -> BN -> ReLU -> ConvT(3x3) -> BN.
    The shortcut is the identity when input and output shapes already match.
    Otherwise it is a 1x1 transposed convolution + BN projecting the input to
    the output shape. The block returns ReLU(main(x) + shortcut(x)).

    With kernel 3 / padding 1 and ``output_padding = stride - 1``, both paths
    turn an H x W input into a (stride*H) x (stride*W) output. So
    ``stride=2`` doubles the spatial size and ``stride=1`` preserves it.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map.
        stride: upsampling factor applied by the first transposed convolution.
    """

    def __init__(self, in_channels, out_channels, stride=1):
        super(ResidualBlockUp, self).__init__()

        self.main = nn.Sequential(
            # output_padding=stride-1 makes the output exactly stride * input size.
            nn.ConvTranspose2d(in_channels, out_channels, 3, stride=stride, padding=1, output_padding=stride-1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.ConvTranspose2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels)
        )

        # Empty Sequential acts as the identity.
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            # Projection so the shortcut matches the main path's shape.
            self.shortcut = nn.Sequential(
                nn.ConvTranspose2d(in_channels, out_channels, kernel_size=1, stride=stride, output_padding=stride-1),
                nn.BatchNorm2d(out_channels)
            )

    def forward(self, x):
        """Return ReLU(main(x) + shortcut(x))."""
        # Functional ReLU instead of constructing a fresh nn.ReLU module on
        # every forward call; the result is numerically identical.
        return torch.relu(self.main(x) + self.shortcut(x))

class Generator(nn.Module):
    """Maps latent codes of shape (batch, 100, 1, 1) to 3 x 64 x 64 images.

    Layout: a 4x4 projection, then three residual upsampling stages
    (4 -> 8 -> 16 -> 32), then a final stride-2 transposed convolution to
    64x64. A Tanh at the end keeps pixel values in [-1, 1].
    """

    def __init__(self):
        super(Generator, self).__init__()
        layers = [
            # 100 x 1 x 1 -> 512 x 4 x 4
            nn.ConvTranspose2d(100, 512, 4, 1, 0),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
        ]
        # Each stage halves the channels and doubles H and W:
        # 512@4x4 -> 256@8x8 -> 128@16x16 -> 64@32x32
        for in_ch, out_ch in ((512, 256), (256, 128), (128, 64)):
            layers.append(ResidualBlockUp(in_ch, out_ch, stride=2))
        layers += [
            # 64 x 32 x 32 -> 3 x 64 x 64
            nn.ConvTranspose2d(64, 3, 4, 2, 1),
            nn.Tanh(),
        ]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Generate a batch of images from latent codes ``x``."""
        return self.model(x)


class Discriminator(nn.Module):
    """DCGAN-style discriminator scoring 3 x 64 x 64 images as real vs. fake.

    Four stride-2 convolutions reduce 64x64 inputs to 512 x 4 x 4 features.
    A linear layer maps these to one score per image, squashed by a Sigmoid.
    The Linear layer's input size (512*4*4) fixes the expected input
    resolution to 64x64.
    """

    def __init__(self):
        super(Discriminator, self).__init__()
        self.main = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1, bias=False),  # 32x32
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(64, 128, 4, 2, 1, bias=False),  # 16x16
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(128, 256, 4, 2, 1, bias=False),  # 8x8
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(256, 512, 4, 2, 1, bias=False),  # 4x4
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Flatten(),
            nn.Linear(512*4*4, 1),

            nn.Sigmoid()
        )

    def forward(self, input):
        """Return a 1-D tensor of shape (batch,) with one probability per image.

        ``squeeze(1)`` removes only the singleton score axis of the
        (batch, 1) output. A bare ``squeeze()`` would also drop the batch
        axis when batch == 1 and return a 0-d tensor.
        """
        return self.main(input).squeeze(1)





# Instantiate both networks on the selected device.
generator = Generator().to(device)
discriminator = Discriminator().to(device)


# Loss and Optimizers
# BCELoss expects probabilities, which is what the Discriminator's final
# Sigmoid produces.
criterion = nn.BCELoss()
# beta1 = 0.5 (instead of Adam's default 0.9) is the DCGAN recommendation:
# the higher momentum makes adversarial training oscillate and destabilise.
optimizer_g = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))


# Number of epochs
num_epochs = 25

# Lists to keep track of progress
img_list = []  # one grid of generator samples appended per epoch
G_losses = []  # generator loss, one entry per iteration
D_losses = []  # discriminator loss, one entry per iteration

# Fixed latent batch used only for the per-epoch snapshots. Every grid in
# img_list then shows the same 64 latent points, so epochs can be compared.
fixed_noise = torch.randn(64, 100, 1, 1, device=device)

# Training Loop
for epoch in range(num_epochs):

    for i, data in enumerate(dataloader, 0):

        # (1) Update Discriminator: maximize log(D(x)) + log(1 - D(G(z)))
        # This also clears the discriminator gradients that g_loss.backward()
        # accumulated during the previous iteration's generator step.
        discriminator.zero_grad()

        ## Train with real images
        real_images = data[0].to(device)
        b_size = real_images.size(0)
        real_labels = torch.ones(b_size, device=device)
        fake_labels = torch.zeros(b_size, device=device)

        output = discriminator(real_images).view(-1)
        d_loss_real = criterion(output, real_labels)
        d_loss_real.backward()

        ## Train with fake images
        noise = torch.randn(b_size, 100, 1, 1, device=device)
        # Keep the autograd graph through the generator: these same fakes are
        # reused in step (2), where the generator needs gradients from them.
        fake_images = generator(noise)

        # Detach only here, so the discriminator loss does not back-propagate
        # into the generator.
        output = discriminator(fake_images.detach()).view(-1)
        d_loss_fake = criterion(output, fake_labels)
        d_loss_fake.backward()

        d_loss = d_loss_real + d_loss_fake
        optimizer_d.step()

        # (2) Update Generator: maximize log(D(G(z)))
        generator.zero_grad()

        # NOT detached: gradients must flow through D back into G's weights.
        output = discriminator(fake_images).view(-1)
        g_loss = criterion(output, real_labels)
        g_loss.backward()

        optimizer_g.step()

        # Print stats
        if i % 5 == 0:
            print(f"[{epoch}/{num_epochs}] [{i}/{len(dataloader)}] D_loss: {d_loss.item()} | G_loss: {g_loss.item()}")

        # Save losses for plotting later
        G_losses.append(g_loss.item())
        D_losses.append(d_loss.item())

    # Save generator's output on the fixed noise after each epoch
    with torch.no_grad():
        samples = generator(fixed_noise).cpu()
    img_list.append(torchvision.utils.make_grid(samples, padding=2, normalize=True))


print("Finished Training")

# Display the sample grid saved after the final epoch. make_grid returns a
# channels-first (C, H, W) image, so the channel axis is moved last for
# matplotlib.
final_grid = img_list[-1]
fig, ax = plt.subplots(figsize=(10, 10))
ax.axis("off")
ax.imshow(np.transpose(final_grid, (1, 2, 0)))
plt.show()


[0/25] [0/14] D_loss: 1.4899190664291382 | G_loss: 1.8973863124847412
[0/25] [5/14] D_loss: 0.00783935934305191 | G_loss: 9.400535583496094
[0/25] [10/14] D_loss: 0.002823860850185156 | G_loss: 11.908961296081543
[1/25] [0/14] D_loss: 0.009661608375608921 | G_loss: 12.882421493530273
