In [None]:
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from tqdm import tqdm

from src.train.dataset import load_patch_pairs, SRGANDataset
from src.train.model import Generator, Discriminator

# Fixed seed so weight initialisation and DataLoader shuffling (which draws from
# the global torch RNG when num_workers=0) repeat across Restart & Run All.
# torch.manual_seed also seeds the CUDA generators; some CUDA kernels may still
# be nondeterministic.
SEED = 42
torch.manual_seed(SEED)

# Prefer the GPU when one is visible to PyTorch, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Device:", device)


In [None]:
# Demo-size settings: only a small slice of the training split is used so the
# cell finishes quickly. Set N_TRAIN_SUBSET = None to train on the full split.
N_TRAIN_SUBSET = 512
BATCH_SIZE = 8

splits = load_patch_pairs()
train_lr, train_hr = splits["train"]
# LR and HR patches are consumed as aligned pairs, so the two lists must match.
assert len(train_lr) == len(train_hr), "LR/HR patch lists have different lengths"

# Slicing with [:None] keeps the whole sequence, so None means "no subsetting".
train_dataset = SRGANDataset(train_lr[:N_TRAIN_SUBSET], train_hr[:N_TRAIN_SUBSET])
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)

print("Mini train batches:", len(train_loader))


In [None]:
# Instantiate the 2x super-resolution generator and the discriminator.
# G is constructed before D so parameter initialisation draws from the RNG in
# the same order as before.
G, D = Generator(upscale_factor=2), Discriminator()

# nn.Module.to moves parameters in place, so the G / D names stay valid.
for net in (G, D):
    net.to(device)

print("Generator & Discriminator siap")


In [None]:
# One demo epoch of adversarial SR training over the mini loader.
# Hyperparameters for this demo run.
LEARNING_RATE = 1e-4
ADV_WEIGHT = 1e-3  # weight of the adversarial term relative to the L1 pixel loss

criterion_L1 = nn.L1Loss()
# NOTE(review): nn.BCELoss expects probabilities in [0, 1], so this assumes the
# Discriminator ends with a sigmoid -- confirm in src/train/model.py. If D
# returns raw logits, switch to nn.BCEWithLogitsLoss instead.
criterion_BCE = nn.BCELoss()

opt_G = Adam(G.parameters(), lr=LEARNING_RATE)
opt_D = Adam(D.parameters(), lr=LEARNING_RATE)

loss_G_list, loss_D_list = [], []

# Re-running this cell continues from the current G/D weights with freshly
# created optimizers; re-run the model cell first for a clean start.
G.train()
D.train()

for lr_img, hr_img in tqdm(train_loader, desc="train"):
    lr_img, hr_img = lr_img.to(device), hr_img.to(device)

    # --- Discriminator step: real HR patches -> 1, generated SR patches -> 0.
    # G is only sampled here, so no autograd graph is built for it.
    with torch.no_grad():
        sr_img = G(lr_img)
    pred_real = D(hr_img)
    pred_fake = D(sr_img)
    # Targets are built from D's own output so they match its shape/device.
    loss_D = (
        criterion_BCE(pred_real, torch.ones_like(pred_real))
        + criterion_BCE(pred_fake, torch.zeros_like(pred_fake))
    )
    opt_D.zero_grad()
    loss_D.backward()
    opt_D.step()

    # --- Generator step: pixel-wise L1 plus a small adversarial term that
    # rewards D predicting "real" on the SR output. Gradients that this
    # backward pass leaves on D are cleared by opt_D.zero_grad() next iteration.
    sr_img = G(lr_img)
    pred_sr = D(sr_img)
    loss_G = criterion_L1(sr_img, hr_img) + ADV_WEIGHT * criterion_BCE(
        pred_sr, torch.ones_like(pred_sr)
    )
    opt_G.zero_grad()
    loss_G.backward()
    opt_G.step()

    loss_G_list.append(loss_G.item())
    loss_D_list.append(loss_D.item())


In [None]:
# Per-batch generator vs. discriminator loss from the training cell above.
fig, ax = plt.subplots()
for series, name in ((loss_G_list, "Loss Generator"), (loss_D_list, "Loss Discriminator")):
    ax.plot(series, label=name)
ax.set_title("Grafik Loss Training (Demo)")
ax.set_xlabel("Batch")
ax.set_ylabel("Loss")
ax.legend()
ax.grid()
