In [15]:
latent_size = 100
hidden_size = 256
image_size = 300
num_epochs = 100
batch_size = 32


In [13]:
import torch
import torch.nn as nn
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from tqdm import tqdm
import os


def train_gan(latent_size, hidden_size, image_size, num_epochs, batch_size, sample_dir):
    # Device configuration
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Create directory for saving generated samples
    os.makedirs(sample_dir, exist_ok=True)

    # Image preprocessing and augmentation
    transform = transforms.Compose([
        transforms.Resize(image_size),
        transforms.CenterCrop(image_size),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    # Load the dataset
    dataset = ImageFolder(root=root, transform=transform)
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Discriminator network
    D = nn.Sequential(
        nn.Linear(image_size * image_size * 3, hidden_size),
        nn.LeakyReLU(0.2),
        nn.Linear(hidden_size, hidden_size),
        nn.LeakyReLU(0.2),
        nn.Linear(hidden_size, 1),
        nn.Sigmoid()
    ).to(device)

    # Generator network
    G = nn.Sequential(
        nn.Linear(latent_size, hidden_size),
        nn.ReLU(),
        nn.Linear(hidden_size, hidden_size),
        nn.ReLU(),
        nn.Linear(hidden_size, image_size * image_size * 3),
        nn.Tanh()
    ).to(device)

    # Loss function and optimizers
    criterion = nn.BCELoss()
    d_optimizer = torch.optim.Adam(D.parameters(), lr=0.0002)
    g_optimizer = torch.optim.Adam(G.parameters(), lr=0.0002)

    # Training the GAN
    total_step = len(data_loader)
    for epoch in range(num_epochs):
        for i, (images, _) in enumerate(tqdm(data_loader)):
            batch_size = images.size(0)
            images = images.reshape(batch_size, -1).to(device)

            # Create the labels for real and fake images
            real_labels = torch.ones(batch_size, 1).to(device)
            fake_labels = torch.zeros(batch_size, 1).to(device)

            # Training the discriminator
            # Real images
            outputs = D(images)
            d_loss_real = criterion(outputs, real_labels)
            real_score = outputs

            # Fake images
            z = torch.randn(batch_size, latent_size).to(device)
            fake_images = G(z)
            outputs = D(fake_images.detach())
            d_loss_fake = criterion(outputs, fake_labels)
            fake_score = outputs

            # Combined loss
            d_loss = d_loss_real + d_loss_fake

            # Update discriminator
            d_optimizer.zero_grad()
            d_loss.backward()
            d_optimizer.step()

            # Training the generator
            z = torch.randn(batch_size, latent_size).to(device)
            fake_images = G(z)
            outputs = D(fake_images)
            g_loss = criterion(outputs, real_labels)

            # Update generator
            g_optimizer.zero_grad()
            g_loss.backward()
            g_optimizer.step()

            # Print training progress
            if (i + 1) % 200 == 0:
                print('Epoch [{}/{}], Step [{}/{}], d_loss: {:.4f}, g_loss: {:.4f}, D(x): {:.2f}, D(G(z)): {:.2f}'
                      .format(epoch + 1, num_epochs, i + 1, total_step, d_loss.item(), g_loss.item(),
                              real_score.mean().item(), fake_score.mean().item()))

        # Save generated images
        if (epoch + 1) % 10 == 0:
            fake_images = fake_images.reshape(fake_images.size(0), 3, image_size, image_size)
            save_image((fake_images + 1) / 2, os.path.join(sample_dir, 'generated_images-{}.png'.format(epoch + 1)))

    # Save the trained models
    torch.save(G.state_dict(), os.path.join(sample_dir, 'generator.pth'))
    torch.save(D.state_dict(), os.path.join(sample_dir, 'discriminator.pth'))



In [21]:
sample_dir = 'conventional_unocated_1200dpi'
root ='E:\code\THESIS\printer_source\DATA_micro_print\DATA_micro_print\conventional_unocated_1200dpi'
train_gan(latent_size, hidden_size, image_size, num_epochs, batch_size, sample_dir)
torch.cuda.empty_cache()

100%|██████████| 89/89 [00:40<00:00,  2.19it/s]
100%|██████████| 89/89 [00:08<00:00, 10.09it/s]
100%|██████████| 89/89 [00:08<00:00, 10.15it/s]
100%|██████████| 89/89 [00:08<00:00, 10.31it/s]
100%|██████████| 89/89 [00:08<00:00, 10.00it/s]
100%|██████████| 89/89 [00:08<00:00, 10.02it/s]
100%|██████████| 89/89 [00:08<00:00, 10.26it/s]
100%|██████████| 89/89 [00:08<00:00, 10.21it/s]
100%|██████████| 89/89 [00:08<00:00,  9.91it/s]
100%|██████████| 89/89 [00:08<00:00, 10.22it/s]
100%|██████████| 89/89 [00:08<00:00, 10.28it/s]
100%|██████████| 89/89 [00:08<00:00, 10.17it/s]
100%|██████████| 89/89 [00:08<00:00, 10.32it/s]
100%|██████████| 89/89 [00:09<00:00,  9.65it/s]
100%|██████████| 89/89 [00:11<00:00,  7.50it/s]
100%|██████████| 89/89 [00:11<00:00,  7.65it/s]
100%|██████████| 89/89 [00:11<00:00,  7.63it/s]
100%|██████████| 89/89 [00:11<00:00,  7.61it/s]
100%|██████████| 89/89 [00:10<00:00,  8.74it/s]
100%|██████████| 89/89 [00:08<00:00,  9.97it/s]
100%|██████████| 89/89 [00:09<00:00,  9.

In [28]:
sample_dir = 'laser_coated_600dpi'
root ='E:\code\THESIS\printer_source\DATA_micro_print\DATA_micro_print\laser_coated_600dpi'
train_gan(latent_size, hidden_size, image_size, num_epochs, batch_size, sample_dir)
torch.cuda.empty_cache()

100%|██████████| 57/57 [00:07<00:00,  7.80it/s]
100%|██████████| 57/57 [00:06<00:00,  9.06it/s]
100%|██████████| 57/57 [00:06<00:00,  9.10it/s]
100%|██████████| 57/57 [00:06<00:00,  9.10it/s]
100%|██████████| 57/57 [00:06<00:00,  9.11it/s]
100%|██████████| 57/57 [00:06<00:00,  9.15it/s]
100%|██████████| 57/57 [00:06<00:00,  9.25it/s]
100%|██████████| 57/57 [00:06<00:00,  9.18it/s]
100%|██████████| 57/57 [00:06<00:00,  9.19it/s]
100%|██████████| 57/57 [00:06<00:00,  9.05it/s]
100%|██████████| 57/57 [00:06<00:00,  9.10it/s]
100%|██████████| 57/57 [00:06<00:00,  9.27it/s]
100%|██████████| 57/57 [00:06<00:00,  9.10it/s]
100%|██████████| 57/57 [00:06<00:00,  9.20it/s]
100%|██████████| 57/57 [00:06<00:00,  9.24it/s]
100%|██████████| 57/57 [00:06<00:00,  9.16it/s]
100%|██████████| 57/57 [00:06<00:00,  9.16it/s]
100%|██████████| 57/57 [00:06<00:00,  9.06it/s]
100%|██████████| 57/57 [00:06<00:00,  9.11it/s]
100%|██████████| 57/57 [00:06<00:00,  9.08it/s]
100%|██████████| 57/57 [00:06<00:00,  9.

In [27]:
sample_dir = 'laser_uncoated_600dpi'
root ='E:\code\THESIS\printer_source\DATA_micro_print\DATA_micro_print\laser_uncoated_600dpi'
train_gan(latent_size, hidden_size, image_size, num_epochs, batch_size, sample_dir)
torch.cuda.empty_cache()

100%|██████████| 76/76 [00:36<00:00,  2.06it/s]
100%|██████████| 76/76 [00:15<00:00,  4.77it/s]
100%|██████████| 76/76 [00:15<00:00,  4.80it/s]
100%|██████████| 76/76 [00:16<00:00,  4.60it/s]
100%|██████████| 76/76 [00:15<00:00,  4.97it/s]
100%|██████████| 76/76 [00:15<00:00,  4.95it/s]
100%|██████████| 76/76 [00:16<00:00,  4.74it/s]
100%|██████████| 76/76 [00:15<00:00,  4.83it/s]
100%|██████████| 76/76 [00:15<00:00,  4.85it/s]
100%|██████████| 76/76 [00:15<00:00,  4.82it/s]
100%|██████████| 76/76 [00:14<00:00,  5.13it/s]
100%|██████████| 76/76 [00:15<00:00,  4.97it/s]
100%|██████████| 76/76 [00:15<00:00,  4.88it/s]
100%|██████████| 76/76 [00:15<00:00,  4.84it/s]
100%|██████████| 76/76 [00:15<00:00,  4.92it/s]
100%|██████████| 76/76 [00:15<00:00,  4.87it/s]
100%|██████████| 76/76 [00:15<00:00,  4.80it/s]
100%|██████████| 76/76 [00:15<00:00,  4.98it/s]
100%|██████████| 76/76 [00:15<00:00,  4.97it/s]
100%|██████████| 76/76 [00:16<00:00,  4.71it/s]
100%|██████████| 76/76 [00:10<00:00,  7.

In [24]:
sample_dir = 'Waterless_coated_1200dpi'
root ='E:\code\THESIS\printer_source\DATA_micro_print\DATA_micro_print\Waterless_coated_1200dpi'
train_gan(latent_size, hidden_size, image_size, num_epochs, batch_size, sample_dir)
torch.cuda.empty_cache()

100%|██████████| 78/78 [00:39<00:00,  2.00it/s]
100%|██████████| 78/78 [00:07<00:00, 10.71it/s]
100%|██████████| 78/78 [00:07<00:00, 10.57it/s]
100%|██████████| 78/78 [00:07<00:00, 10.71it/s]
100%|██████████| 78/78 [00:07<00:00, 10.72it/s]
100%|██████████| 78/78 [00:07<00:00, 10.63it/s]
100%|██████████| 78/78 [00:07<00:00, 10.56it/s]
100%|██████████| 78/78 [00:07<00:00, 10.67it/s]
100%|██████████| 78/78 [00:07<00:00, 10.70it/s]
100%|██████████| 78/78 [00:07<00:00, 10.54it/s]
100%|██████████| 78/78 [00:07<00:00, 10.64it/s]
100%|██████████| 78/78 [00:07<00:00, 10.61it/s]
100%|██████████| 78/78 [00:07<00:00, 10.71it/s]
100%|██████████| 78/78 [00:07<00:00, 10.70it/s]
100%|██████████| 78/78 [00:07<00:00, 10.65it/s]
100%|██████████| 78/78 [00:07<00:00, 10.55it/s]
100%|██████████| 78/78 [00:07<00:00, 10.74it/s]
100%|██████████| 78/78 [00:07<00:00, 10.63it/s]
100%|██████████| 78/78 [00:07<00:00, 10.69it/s]
100%|██████████| 78/78 [00:07<00:00, 10.60it/s]
100%|██████████| 78/78 [00:07<00:00, 10.

In [25]:
sample_dir = 'Waterless_uncoated_1200dpi'
root ='E:\code\THESIS\printer_source\DATA_micro_print\DATA_micro_print\Waterless_uncoated_1200dpi'
train_gan(latent_size, hidden_size, image_size, num_epochs, batch_size, sample_dir)
torch.cuda.empty_cache()

100%|██████████| 82/82 [00:39<00:00,  2.07it/s]
100%|██████████| 82/82 [00:07<00:00, 10.47it/s]
100%|██████████| 82/82 [00:07<00:00, 10.71it/s]
100%|██████████| 82/82 [00:07<00:00, 10.80it/s]
100%|██████████| 82/82 [00:07<00:00, 10.78it/s]
100%|██████████| 82/82 [00:07<00:00, 10.81it/s]
100%|██████████| 82/82 [00:07<00:00, 10.79it/s]
100%|██████████| 82/82 [00:07<00:00, 10.75it/s]
100%|██████████| 82/82 [00:07<00:00, 10.84it/s]
100%|██████████| 82/82 [00:07<00:00, 10.72it/s]
100%|██████████| 82/82 [00:07<00:00, 10.77it/s]
100%|██████████| 82/82 [00:07<00:00, 10.76it/s]
100%|██████████| 82/82 [00:07<00:00, 10.74it/s]
100%|██████████| 82/82 [00:07<00:00, 10.68it/s]
100%|██████████| 82/82 [00:07<00:00, 10.78it/s]
100%|██████████| 82/82 [00:07<00:00, 10.80it/s]
100%|██████████| 82/82 [00:07<00:00, 10.82it/s]
100%|██████████| 82/82 [00:07<00:00, 10.82it/s]
100%|██████████| 82/82 [00:07<00:00, 10.84it/s]
100%|██████████| 82/82 [00:07<00:00, 10.82it/s]
100%|██████████| 82/82 [00:07<00:00, 10.