In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import ToTensor
from PIL import Image
import os
import re

In [2]:
# Define the generator model
class Generator(nn.Module):
    """Three-layer fully convolutional network mapping 1-channel input to 1-channel output.

    Every convolution uses a 3x3 kernel with padding 1 (and the default
    stride of 1), so the spatial size of the input is preserved. The two
    hidden layers have 64 channels and are followed by in-place ReLUs; the
    final convolution has no activation.
    """

    def __init__(self):
        super().__init__()
        hidden_channels = 64
        # Attribute names are kept stable because they define the state_dict keys.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=hidden_channels, kernel_size=3, padding=1)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(in_channels=hidden_channels, out_channels=hidden_channels, kernel_size=3, padding=1)
        self.relu2 = nn.ReLU(inplace=True)
        self.conv3 = nn.Conv2d(in_channels=hidden_channels, out_channels=1, kernel_size=3, padding=1)

    def forward(self, x):
        """Apply conv -> ReLU -> conv -> ReLU -> conv to ``x`` and return the result."""
        features = self.relu1(self.conv1(x))
        features = self.relu2(self.conv2(features))
        return self.conv3(features)


# Define the discriminator model
class Discriminator(nn.Module):
    """Three-layer fully convolutional critic producing a 1-channel score map.

    Every convolution uses a 3x3 kernel with padding 1 (and the default
    stride of 1), so the output has the same spatial size as the input. The
    two 64-channel hidden layers are followed by in-place LeakyReLU(0.2);
    the final convolution has no activation, so its outputs are raw scores.
    """

    def __init__(self):
        super().__init__()
        hidden_channels = 64
        negative_slope = 0.2
        # Attribute names are kept stable because they define the state_dict keys.
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=hidden_channels, kernel_size=3, padding=1)
        self.leakyrelu1 = nn.LeakyReLU(negative_slope, inplace=True)
        self.conv2 = nn.Conv2d(in_channels=hidden_channels, out_channels=hidden_channels, kernel_size=3, padding=1)
        self.leakyrelu2 = nn.LeakyReLU(negative_slope, inplace=True)
        self.conv3 = nn.Conv2d(in_channels=hidden_channels, out_channels=1, kernel_size=3, padding=1)

    def forward(self, x):
        """Apply conv -> LeakyReLU -> conv -> LeakyReLU -> conv to ``x`` and return the scores."""
        features = self.leakyrelu1(self.conv1(x))
        features = self.leakyrelu2(self.conv2(features))
        return self.conv3(features)


# Define dataset class
class ImageDataset(Dataset):
    """Paired dataset of compressed images and their high-quality originals.

    Every regular, non-hidden file in ``compressed_dir`` is one sample. Its
    partner is looked up in ``high_quality_dir`` as ``<stem>.png``, where
    ``<stem>`` is the part of the compressed file name before the first
    ``.`` or ``_``. Both images are loaded as 8-bit grayscale (PIL mode ``'L'``).

    Args:
        compressed_dir: Directory containing the compressed input images.
        high_quality_dir: Directory containing the matching ``<stem>.png`` targets.
        transform: Optional callable applied independently to both images of
            a pair. NOTE(review): a *random* transform would therefore be
            drawn separately for input and target -- confirm only
            deterministic transforms are passed.
    """

    def __init__(self, compressed_dir, high_quality_dir, transform=None):
        self.compressed_dir = compressed_dir
        self.high_quality_dir = high_quality_dir
        self.transform = transform
        # os.listdir returns entries in arbitrary order and includes
        # sub-directories and hidden files (e.g. ``.ipynb_checkpoints``,
        # ``.DS_Store``) that cannot be opened as images. Keep only regular,
        # non-hidden files and sort them so indexing is reproducible.
        self.compressed_files = sorted(
            name
            for name in os.listdir(self.compressed_dir)
            if not name.startswith('.')
            and os.path.isfile(os.path.join(self.compressed_dir, name))
        )

    def __len__(self):
        """Return the number of compressed images found."""
        return len(self.compressed_files)

    def __getitem__(self, index):
        """Return the ``(compressed, high_quality)`` pair for ``index``.

        Raises:
            IndexError: If ``index`` is out of range.
            FileNotFoundError: If the matching high-quality file does not exist.
        """
        compressed_file = self.compressed_files[index]
        compressed_path = os.path.join(self.compressed_dir, compressed_file)

        # The stem is everything before the first '.' or '_' in the file name.
        stem = re.split(r'\.|_', compressed_file)[0]
        high_quality_path = os.path.join(self.high_quality_dir, stem + '.png')

        # ``convert`` returns an independent, fully loaded image, so the
        # underlying files can be closed as soon as the ``with`` block exits.
        with Image.open(compressed_path) as source:
            compressed_image = source.convert('L')
        with Image.open(high_quality_path) as source:
            high_quality_image = source.convert('L')

        if self.transform:
            compressed_image = self.transform(compressed_image)
            high_quality_image = self.transform(high_quality_image)

        return compressed_image, high_quality_image


# Define the training loop for GAN
def train_gan(generator, discriminator, dataloader, criterion, optimizer_g, optimizer_d, device):
    """Run one pass over ``dataloader``, alternating discriminator and generator updates.

    For every batch the discriminator is stepped first, using the
    high-quality images with all-ones targets and the detached generator
    outputs with all-zeros targets. The generator is then stepped so that
    the (just updated) discriminator scores its outputs against all-ones
    targets.

    Args:
        generator: Model mapping compressed images to restored images.
        discriminator: Model scoring images.
        dataloader: Iterable yielding ``(compressed_images, high_quality_images)`` batches.
        criterion: Loss called as ``criterion(outputs, targets)``.
        optimizer_g: Optimizer for the generator parameters.
        optimizer_d: Optimizer for the discriminator parameters.
        device: Device the batches are moved to.

    Returns:
        tuple: ``(generator_loss, discriminator_loss)``, each the per-batch
        loss value summed (not averaged) over the whole pass.
    """
    generator.train()
    discriminator.train()

    gen_loss_total = 0.0
    dis_loss_total = 0.0

    for batch in dataloader:
        degraded, target = batch
        degraded = degraded.to(device)
        target = target.to(device)

        # --- Discriminator step -------------------------------------------
        optimizer_d.zero_grad()
        restored = generator(degraded)
        real_scores = discriminator(target)
        # detach() keeps this step from back-propagating into the generator.
        fake_scores = discriminator(restored.detach())

        loss_on_real = criterion(real_scores, torch.ones_like(real_scores))
        loss_on_fake = criterion(fake_scores, torch.zeros_like(fake_scores))
        dis_loss = loss_on_real + loss_on_fake
        dis_loss.backward()
        optimizer_d.step()
        dis_loss_total += dis_loss.item()

        # --- Generator step -----------------------------------------------
        optimizer_g.zero_grad()
        # Re-score the same generated batch, this time keeping the graph to the generator.
        fooled_scores = discriminator(restored)
        gen_loss = criterion(fooled_scores, torch.ones_like(fooled_scores))
        gen_loss.backward()
        optimizer_g.step()
        gen_loss_total += gen_loss.item()

    return gen_loss_total, dis_loss_total


# Training configuration: data locations and optimisation hyperparameters.
# compressed_dir holds the network inputs, high_quality_dir the targets.
compressed_dir, high_quality_dir = 'compressed_images', 'images'
batch_size = 8
# Generator and discriminator share one learning rate (both were 0.0002).
learning_rate_g = learning_rate_d = 1e-3
num_epochs = 50

In [None]:
# torch.nn.Module has no Keras-style ``summary()`` method, so calling it
# raises AttributeError. Printing an instance lists the layers instead; the
# parameter count is added for a quick size check.
_summary_model = Generator()
print(_summary_model)
print(f'Trainable parameters: {sum(p.numel() for p in _summary_model.parameters() if p.requires_grad):,}')

In [3]:
# Build the paired dataset; ToTensor is applied to both images of every pair.
transform = ToTensor()
dataset = ImageDataset(
    compressed_dir=compressed_dir,
    high_quality_dir=high_quality_dir,
    transform=transform,
)

# Shuffled mini-batches, loaded in the main process (num_workers=0).
dataloader = DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
)

# Fresh, untrained networks (generator is constructed first, then the discriminator).
generator, discriminator = Generator(), Discriminator()

In [4]:
# Binary cross-entropy on raw scores: neither network ends in a sigmoid.
criterion = nn.BCEWithLogitsLoss()

# One Adam optimizer per network, both with betas=(0.5, 0.999).
optimizer_g = optim.Adam(
    params=generator.parameters(),
    lr=learning_rate_g,
    betas=(0.5, 0.999),
)
optimizer_d = optim.Adam(
    params=discriminator.parameters(),
    lr=learning_rate_d,
    betas=(0.5, 0.999),
)

# Prefer the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Move both networks to the selected device; the last expression is shown as the cell output.
generator.to(device)
discriminator.to(device)

Discriminator(
  (conv1): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (leakyrelu1): LeakyReLU(negative_slope=0.2, inplace=True)
  (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (leakyrelu2): LeakyReLU(negative_slope=0.2, inplace=True)
  (conv3): Conv2d(64, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)

In [None]:
# Adversarial training: one call to train_gan per epoch, followed by a checkpoint.
for epoch in range(num_epochs):
    # The returned losses are sums over all batches of the epoch.
    gen_loss, dis_loss = train_gan(
        generator=generator,
        discriminator=discriminator,
        dataloader=dataloader,
        criterion=criterion,
        optimizer_g=optimizer_g,
        optimizer_d=optimizer_d,
        device=device,
    )
    print(f'Epoch [{epoch+1}/{num_epochs}], Generator Loss: {gen_loss:.4f}, Discriminator Loss: {dis_loss:.4f}')

    # Overwrite the rolling checkpoints so an interrupted run keeps the latest weights.
    torch.save(obj=generator.state_dict(), f='generbw3k.pth')
    torch.save(obj=discriminator.state_dict(), f='discribw3k.pth')

In [None]:
# Persist the final weights (state dicts only) of both networks.
torch.save(obj=generator.state_dict(), f='generatbw3k.pth')
torch.save(obj=discriminator.state_dict(), f='discriminatbw3k.pth')

In [None]:
from torchvision.transforms import ToPILImage

In [None]:
# Rebuild the generator architecture so the saved weights can be restored into it.
generator = Generator()

# Load the weights written by the "Save trained models" cell. The previous
# name ('generatorbw3k.pth') is not produced anywhere in this notebook.
# map_location lets a checkpoint saved on a GPU machine load on a CPU-only one.
generator.load_state_dict(torch.load('generatbw3k.pth', map_location=device))
#generator.load_state_dict(torch.load('generator_low.pth'))
#generator.load_state_dict(torch.load('generator_low.pth'), strict=False)

# The inference input is moved to `device`, so the model has to live there too;
# a freshly constructed Generator() is on the CPU.
generator.to(device)

# Set the generator to evaluation mode
generator.eval()

In [None]:
# Load the compressed image to restore. The generator's first convolution
# takes a single input channel and training used grayscale ('L') images, so
# the file is converted the same way; an RGB JPEG would otherwise produce a
# 3-channel tensor and fail inside conv1.
with Image.open('test.jpg') as source_image:
    compressed_image = source_image.convert('L')

# Apply the same image transformation used during training
transform = ToTensor()
compressed_image = transform(compressed_image)

# Add a batch dimension to the input image
compressed_image = compressed_image.unsqueeze(0)

# Move the input image to the same device used for training
compressed_image = compressed_image.to(device)

In [None]:
# Generate a high-quality image from the compressed input image using the trained generator
with torch.no_grad():
    high_quality_image = generator(compressed_image)

# Move the result to the CPU and drop the batch dimension. The generator has
# no output activation, so values can fall outside [0, 1]. ToPILImage scales
# float tensors by 255 and casts to uint8, which would wrap such values into
# speckle artefacts, so clamp to the valid range first.
high_quality_image = high_quality_image.cpu().squeeze(0).clamp(0.0, 1.0)
high_quality_image = ToPILImage()(high_quality_image)

# Save the generated high-quality image
# high_quality_image.save('GAN2-BW-test.png')
high_quality_image.save('GAN2-low-test.png')