In [4]:
import os
import torch
import torch.optim as optim
import torch.nn as nn
from torch.utils.data import DataLoader
from models import Generator, Discriminator
from utils.data_utils import QRDataset
from utils.model_utils import save_model

# Run on the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [5]:
def train(generator, discriminator, dataloader, num_epochs=10, start_epoch = 0):
    """Alternately train a discriminator and a generator on paired image batches.

    For every batch the discriminator is first updated to score the target
    images as real (label 1) and the generator's detached output as fake
    (label 0); the generator is then updated with an MSE content loss against
    the targets plus a 1e-3-weighted adversarial loss.

    Args:
        generator: ``nn.Module`` applied to the first tensor of each batch.
        discriminator: ``nn.Module`` scoring images. Because ``nn.BCELoss`` is
            used, its output must already be a probability in [0, 1]
            (e.g. a final sigmoid). Any output shape is accepted.
        dataloader: Iterable yielding ``(lr_imgs, hr_imgs)`` tensor pairs. It is
            iterated once per epoch, so it must be re-iterable and non-empty.
        num_epochs: Exclusive upper bound of the epoch counter.
        start_epoch: First epoch index to run (use > 0 when resuming).

    Raises:
        ValueError: If ``dataloader`` yields no batches in an epoch.

    Notes:
        * Batches are moved to the module-level ``device``.
        * New Adam optimizers are created on every call, so optimizer state is
          not carried over when resuming from a checkpoint.
        * ``save_model`` (imported from ``utils.model_utils``) is called with
          ``(generator, discriminator, epoch)`` on epochs divisible by 5 and on
          the final epoch; what it writes is defined outside this block.
    """
    g_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    d_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

    content_loss = nn.MSELoss()
    # NOTE(review): nn.BCELoss clamps its log terms at -100, so a logged
    # "D Loss: 100.0000" means the discriminator output has saturated at
    # exactly 0 or 1. If that shows up, consider having the discriminator
    # return logits and switching to nn.BCEWithLogitsLoss.
    adversarial_loss = nn.BCELoss()

    # Make sure layers such as BatchNorm/Dropout are in training mode even if
    # an earlier notebook cell left the models in eval().
    generator.train()
    discriminator.train()

    for epoch in range(start_epoch, num_epochs):
        # Reset per epoch so an empty/exhausted loader is detected instead of
        # hitting an UnboundLocalError or re-reporting a previous epoch's loss.
        d_loss = g_loss = None

        for lr_imgs, hr_imgs in dataloader:
            lr_imgs, hr_imgs = lr_imgs.to(device), hr_imgs.to(device)

            # Train Discriminator
            discriminator.zero_grad()

            real_output = discriminator(hr_imgs)
            # Build the targets from the discriminator output so shape, dtype
            # and device always match what BCELoss requires.
            real_loss = adversarial_loss(real_output, torch.ones_like(real_output))

            generated_imgs = generator(lr_imgs)
            # detach(): the discriminator step must not backprop into the generator.
            fake_output = discriminator(generated_imgs.detach())
            fake_loss = adversarial_loss(fake_output, torch.zeros_like(fake_output))

            d_loss = real_loss + fake_loss
            d_loss.backward()
            d_optimizer.step()

            # Train Generator
            generator.zero_grad()
            # Re-score the (non-detached) fakes with the just-updated discriminator.
            fake_output = discriminator(generated_imgs)
            g_loss = content_loss(generated_imgs, hr_imgs) + 1e-3 * adversarial_loss(
                fake_output, torch.ones_like(fake_output)
            )
            g_loss.backward()
            g_optimizer.step()

        if d_loss is None:
            raise ValueError(
                f"dataloader yielded no batches in epoch {epoch}; "
                "it must be non-empty and re-iterable"
            )

        # Losses shown are those of the last batch of the epoch.
        print(f"Epoch [{epoch}/{num_epochs}]  D Loss: {d_loss.item():.4f}, G Loss: {g_loss.item():.4f}")
        
        # Save model every few epochs
        if epoch % 5 == 0 or epoch == num_epochs - 1:
            save_model(generator, discriminator, epoch)

In [6]:
# Build the data pipeline and the two networks on the selected device.
dataset = QRDataset("data/")
dataloader = DataLoader(dataset, batch_size=4, shuffle=True)

generator = Generator().to(device)
discriminator = Discriminator().to(device)

# Load checkpoint if available.
# RESUME_EPOCH is the single source of truth for both the checkpoint file
# names and the epoch training continues from, so the two cannot drift apart.
RESUME_EPOCH = 25
start_epoch = 0
checkpoint_gen = f'models/generator_epoch_{RESUME_EPOCH}.pth'
checkpoint_disc = f'models/discriminator_epoch_{RESUME_EPOCH}.pth'

has_gen = os.path.exists(checkpoint_gen)
has_disc = os.path.exists(checkpoint_disc)

if has_gen and has_disc:
    # map_location lets a checkpoint saved on a GPU load on a CPU-only machine
    # (and vice versa) instead of failing during deserialization.
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    generator.load_state_dict(torch.load(checkpoint_gen, map_location=device))
    discriminator.load_state_dict(torch.load(checkpoint_disc, map_location=device))
    start_epoch = RESUME_EPOCH + 1  # Continue from the next epoch after the checkpoint
elif has_gen or has_disc:
    # A half-present checkpoint pair is almost certainly a mistake; say so
    # rather than silently restarting from epoch 0.
    print(
        f"Warning: only one of {checkpoint_gen!r} / {checkpoint_disc!r} exists; "
        "training from scratch."
    )

# Start or continue training.
# NOTE(review): only model weights are restored; train() creates new Adam
# optimizers on each call, so optimizer state restarts when resuming.
train(generator, discriminator, dataloader, num_epochs=41, start_epoch=start_epoch)

Epoch [26/41]  D Loss: 100.0000, G Loss: 0.0074
Epoch [27/41]  D Loss: 100.0000, G Loss: 0.0038
Epoch [28/41]  D Loss: 100.0000, G Loss: 0.0042
Epoch [29/41]  D Loss: 100.0000, G Loss: 0.0029
Epoch [30/41]  D Loss: 100.0000, G Loss: 0.0033
Models saved at epoch 30:
Epoch [31/41]  D Loss: 100.0000, G Loss: 0.0045
Epoch [32/41]  D Loss: 100.0000, G Loss: 0.0036
Epoch [33/41]  D Loss: 100.0000, G Loss: 0.0042
Epoch [34/41]  D Loss: 100.0000, G Loss: 0.0047
Epoch [35/41]  D Loss: 100.0000, G Loss: 0.0030
Models saved at epoch 35:
Epoch [36/41]  D Loss: 100.0000, G Loss: 0.0026
Epoch [37/41]  D Loss: 100.0000, G Loss: 0.0026
Epoch [38/41]  D Loss: 100.0000, G Loss: 0.0034
Epoch [39/41]  D Loss: 100.0000, G Loss: 0.0027
Epoch [40/41]  D Loss: 100.0000, G Loss: 0.0026
Models saved at epoch 40:
