In [None]:
# @title SRGAN (x4, Y-channel, VGG19 φ5,4, 16 blocos)
# @markdown Este notebook:
# @markdown 1) Baixa um mini-conjunto de imagens (5 imagens "HR")
# @markdown 2) Gera "LR" com bicúbica ×4 (fiel ao procedimento do paper)
# @markdown 3) Define Gerador (SRResNet-16) e Discriminador
# @markdown 4) Pré-treina o Gerador com MSE (poucos passos) e treina SRGAN (poucos passos)
# @markdown 5) Avalia PSNR/SSIM no canal Y com recorte de borda de 4 px (como no paper)

import os, math, random, urllib.request, zipfile, io
from pathlib import Path

import numpy as np
from PIL import Image
import cv2

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision.models import vgg19, VGG19_Weights
from torchvision import transforms
from skimage.metrics import peak_signal_noise_ratio as psnr_metric
from skimage.metrics import structural_similarity as ssim_metric

In [None]:
# =========================
# 0) Configuration
# =========================
# Use the GPU when PyTorch can see one; otherwise run on CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Geometry / data loading
UPSCALE = 4      # super-resolution factor (x4)
PATCH_HR = 96    # HR training patch side (96x96 patches, as in the paper)
BATCH_SIZE = 4   # kept small so the demo also runs on CPU
WORKERS = 2      # worker processes for the training DataLoader

# Optimisation schedule -- heavily reduced for a DEMO
G_PRETRAIN_STEPS = 200   # generator MSE pre-training steps (paper used 1e6)
GAN_TRAIN_STEPS = 200    # adversarial steps (paper used 2e5 in total)
G_LR_1 = 1e-4            # initial generator learning rate
G_LR_2 = 1e-5            # second-phase rate for long runs; not used in this demo
D_LR = 1e-4              # discriminator learning rate

# Basic reproducibility: seed the Python, NumPy and PyTorch RNGs.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)

# Output tree: HR images under data/hr, generated results under out/.
root = Path("./srgan_demo")
(root / "data/hr").mkdir(parents=True, exist_ok=True)
(root / "out").mkdir(exist_ok=True, parents=True)

In [None]:
# ==============================
# 1) Mini dataset (5 imagens HR)
# ==============================
!pip -q install scikit-image

from skimage import data
import numpy as np

# Directory that receives the HR PNGs (idempotent: the config cell already creates it).
hr_dir = root / "data" / "hr"
hr_dir.mkdir(parents=True, exist_ok=True)

def to_rgb(img):
    """Return ``img`` with exactly three colour channels.

    A 2-D (grayscale) array is replicated into three identical channels, and a
    4-channel array (e.g. RGBA) has its last channel dropped. Any other 3-D array
    is returned unchanged (the same object).
    """
    rgb = np.stack((img,) * 3, axis=-1) if img.ndim == 2 else img
    return rgb[:, :, :3] if rgb.shape[2] == 4 else rgb

def save_img(arr, path):
    """Write ``arr`` to ``path`` as an 8-bit image.

    Values are clipped to [0, 255] and cast with ``astype(np.uint8)``, so any
    fractional part is truncated, not rounded.
    """
    pixels = np.clip(arr, 0, 255).astype(np.uint8)
    image = Image.fromarray(pixels)
    image.save(path)

# pega 5 imagens estáveis do scikit-image
# Five stable sample images bundled with scikit-image (original sizes noted).
imgs = [
    data.astronaut(),  # 512x512
    data.coffee(),     # 400x600
    data.chelsea(),    # 300x451
    data.rocket(),     # 427x640
    data.camera(),     # 512x512 (grayscale)
]
# Convert to RGB and CROP each image so both sides are multiples of UPSCALE
# ("modcrop": drop at most UPSCALE-1 trailing rows/columns). Resizing here, as
# before, would resample and soften the very HR images used as ground truth.
for i, im in enumerate(imgs, 1):
    im = to_rgb(im)
    H = (im.shape[0] // UPSCALE) * UPSCALE
    W = (im.shape[1] // UPSCALE) * UPSCALE
    im = im[:H, :W]
    save_img(im, hr_dir / f"img_{i:02d}.png")

print("✅ Dataset HR pronto em", hr_dir)


✅ Dataset HR pronto em srgan_demo/data/hr


In [None]:
# ======================================================
# 2) Utilitários de imagem (YCbCr, canal Y, bicúbica ×4)
# ======================================================
def rgb_to_ycbcr(img: np.ndarray):
    """Convert an RGB image with values in [0, 1] to YCbCr in [0, 1].

    Uses the BT.601 luma weights with both chroma channels offset by 0.5 (the
    full-range, JPEG-style variant). Only the first three entries of the last
    axis are read.

    NOTE(review): SR benchmark code often uses MATLAB ``rgb2ycbcr`` (studio range,
    Y in [16, 235]/255). PSNR computed on this Y is not directly comparable to
    numbers obtained with that convention.
    """
    red, green, blue = (img[..., k] for k in range(3))
    luma = 0.299*red + 0.587*green + 0.114*blue
    blue_diff = -0.168736*red - 0.331264*green + 0.5*blue + 0.5
    red_diff = 0.5*red - 0.418688*green - 0.081312*blue + 0.5
    return np.stack((luma, blue_diff, red_diff), axis=-1)

def ycbcr_to_rgb(img: np.ndarray):
    """Inverse of the full-range YCbCr conversion.

    Takes YCbCr in [0, 1] (chroma centred at 0.5) and returns RGB clipped to [0, 1].
    """
    luma = img[..., 0]
    cb = img[..., 1] - 0.5
    cr = img[..., 2] - 0.5
    rgb = np.stack(
        [luma + 1.402*cr,
         luma - 0.344136*cb - 0.714136*cr,
         luma + 1.772*cb],
        axis=-1,
    )
    return np.clip(rgb, 0.0, 1.0)

def pil_to_np(img_pil):
    """Convert a PIL image (or array-like) to a float32 NumPy array scaled from [0, 255] to [0, 1]."""
    as_float = np.asarray(img_pil, dtype=np.float32)
    return as_float / 255.0

def np_to_pil(img_np):
    """Convert a float image in [0, 1] to an 8-bit PIL image.

    Values are scaled by 255, rounded to the nearest integer and clipped to
    [0, 255]. Rounding, rather than a bare ``astype(np.uint8)`` truncation, matters
    because floats that went through tensor round-trips (e.g. ``(x + 1) / 2`` of a
    [-1, 1] tensor) can land just below an integer, e.g. 254.99998, which
    truncation would turn into 254.
    """
    img = np.clip(np.rint(np.asarray(img_np) * 255.0), 0, 255).astype(np.uint8)
    return Image.fromarray(img)

def make_lr_bicubic(hr_img_pil, scale=4):
    """Downscale a PIL image by an integer ``scale`` with bicubic resampling.

    The target size uses floor division, so any remainder rows or columns are
    absorbed by the resampler. Pass HR images whose sides are multiples of
    ``scale`` to keep ``LR * scale == HR``.
    """
    width, height = hr_img_pil.size
    target = (width // scale, height // scale)
    return hr_img_pil.resize(target, Image.BICUBIC)

# crop de avaliação: remove 4px de borda no HR (paper)
def shave(img: np.ndarray, border=4):
    """Remove ``border`` pixels from each side of the two leading (spatial) axes.

    Used for evaluation: metrics are computed on the image interior, with a 4-px
    border dropped by default. Works for (H, W) and (H, W, C) arrays.
    ``border <= 0`` returns ``img`` unchanged. Plain ``img[0:-0]`` slicing would
    silently return an empty array for ``border=0``.
    """
    if border <= 0:
        return img
    h, w = img.shape[:2]
    return img[border:h - border, border:w - border, ...]

In [None]:
# ====================================================================================================
# 3) Dataset (gera patches 96x96 HR e LR bicúbica x4)
#    Notas do paper: patches HR 96x96, LR via bicúbica r=4,
#    normalização LR [0,1], HR [-1,1] (para MSE).
# ====================================================================================================
class SRPatches(Dataset):
    """(LR, HR) pairs generated on the fly from the PNG files in ``hr_dir``.

    * ``split == "train"``: every item is a random ``patch`` x ``patch`` HR crop
      taken from a randomly chosen image (``idx`` is ignored). The dataset length
      is ``n_images * n_patches_per_img``.
    * any other ``split``: item ``idx`` is the full image ``idx``, cropped (not
      resampled) so both sides are multiples of ``scale``.

    The LR input is the bicubic x``scale`` downscale of the HR crop. Returned
    tensors are channel-first RGB: ``lr_t`` in [0, 1] and ``hr_t`` in [-1, 1]
    (the paper's normalisation for the MSE loss).

    Raises:
        ValueError: for ``split == "train"``, if ``patch`` is not a multiple of
            ``scale`` or an image is smaller than ``patch`` on either side.
    """

    def __init__(self, hr_dir, patch=96, scale=4, n_patches_per_img=20, split="train"):
        self.hr_paths = sorted(
            str(Path(hr_dir) / name)
            for name in os.listdir(hr_dir)
            if name.lower().endswith(".png")
        )
        self.patch = patch
        self.scale = scale
        self.n = n_patches_per_img
        self.split = split
        # Keep every image in memory. `with` closes each file handle once decoded,
        # and convert() returns an independent in-memory copy.
        self.hr_imgs = []
        for path in self.hr_paths:
            with Image.open(path) as im:
                self.hr_imgs.append(im.convert("RGB"))
        self._to_tensor = transforms.ToTensor()  # uint8 [0, 255] -> float [0, 1]
        if split == "train":
            # Fail early with a clear message rather than with a random.randint
            # error mid-training, and guarantee LR * scale == HR for every patch.
            if patch % scale != 0:
                raise ValueError(f"patch ({patch}) must be a multiple of scale ({scale})")
            too_small = [p for p, im in zip(self.hr_paths, self.hr_imgs) if min(im.size) < patch]
            if too_small:
                raise ValueError(f"images smaller than the {patch}px training patch: {too_small}")

    def __len__(self):
        if self.split == "train":
            return len(self.hr_imgs) * self.n
        return len(self.hr_imgs)

    def __getitem__(self, idx):
        if self.split == "train":
            img = random.choice(self.hr_imgs)
            W, H = img.size
            # Random top-left corner such that the whole patch fits inside the image.
            x = random.randint(0, W - self.patch)
            y = random.randint(0, H - self.patch)
            hr = img.crop((x, y, x + self.patch, y + self.patch))
        else:
            # Validation: whole image. "Modcrop" drops the remainder rows/columns
            # instead of resampling (and thereby altering) the ground truth.
            img = self.hr_imgs[idx]
            W, H = img.size
            hr = img.crop((0, 0, W - W % self.scale, H - H % self.scale))

        lr = make_lr_bicubic(hr, scale=self.scale)

        # The network works in RGB; evaluation later converts to Y.
        lr_t = self._to_tensor(lr)               # [0, 1]
        hr_t = self._to_tensor(hr) * 2.0 - 1.0   # [-1, 1]
        return lr_t, hr_t

# Training split: random HR patches (20 per image per epoch). Validation: full images.
train_ds = SRPatches(hr_dir, patch=PATCH_HR, scale=UPSCALE, n_patches_per_img=20, split="train")
val_ds = SRPatches(hr_dir, patch=PATCH_HR, scale=UPSCALE, split="val")

# Shuffled mini-batches for training (the incomplete last batch is dropped);
# one full image at a time, in file order, for evaluation.
train_dl = DataLoader(
    train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=WORKERS,
    drop_last=True,
)
val_dl = DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=0)

In [None]:
# ====================================================================================================
# 4) Modelos (Gerador SRResNet-16 + Discriminador)
#    Fiel ao paper: blocos residuais 3x3/64 + BN + PReLU; PixelShuffle para upsample ×4;
#    Discriminador com convs 3x3, canais 64→512, strides, LeakyReLU(0.2), FC + sigmoid.
# ====================================================================================================
class ResidualBlock(nn.Module):
    """SRResNet residual block: conv3x3 -> BN -> PReLU -> conv3x3 -> BN, plus identity skip.

    Input and output have the same shape (N, channels, H, W).
    """

    def __init__(self, channels=64):
        super().__init__()
        # Registration order is unchanged, so seeded weight init and state_dict keys match.
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.prelu = nn.PReLU()
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)

    def forward(self, x):
        branch = self.conv1(x)
        branch = self.prelu(self.bn1(branch))
        branch = self.conv2(branch)
        branch = self.bn2(branch)
        return x + branch

class UpsampleBlock(nn.Module):
    """Sub-pixel upsampling: conv3x3 to ``in_ch * scale**2`` channels -> PixelShuffle -> PReLU.

    Maps (N, in_ch, H, W) to (N, in_ch, H*scale, W*scale).
    """

    def __init__(self, in_ch=64, scale=2):
        super().__init__()
        expanded = in_ch * scale * scale
        self.conv = nn.Conv2d(in_ch, expanded, kernel_size=3, stride=1, padding=1)
        self.ps = nn.PixelShuffle(scale)
        self.prelu = nn.PReLU()

    def forward(self, x):
        shuffled = self.ps(self.conv(x))
        return self.prelu(shuffled)

class GeneratorSRResNet16(nn.Module):
    """SRResNet-style x4 generator.

    9x9 head conv + PReLU, ``n_blocks`` residual blocks, a 3x3 conv + BN closed by
    a long skip from the head, two x2 sub-pixel upsamplers, and a 9x9 output conv
    followed by tanh, so the output lies in [-1, 1].
    """

    def __init__(self, n_blocks=16, in_ch=3, out_ch=3):
        super().__init__()
        # Layer creation order is unchanged: seeded init and state_dict keys match.
        self.conv1 = nn.Conv2d(in_ch, 64, kernel_size=9, stride=1, padding=4)
        self.prelu = nn.PReLU()
        self.resblocks = nn.Sequential(*(ResidualBlock(64) for _ in range(n_blocks)))
        self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.up1 = UpsampleBlock(64, 2)  # x2
        self.up2 = UpsampleBlock(64, 2)  # x2 again -> x4 overall
        self.conv3 = nn.Conv2d(64, out_ch, kernel_size=9, stride=1, padding=4)

    def forward(self, x):
        head = self.prelu(self.conv1(x))
        trunk = self.bn2(self.conv2(self.resblocks(head)))
        features = head + trunk                  # long (global) skip connection
        upsampled = self.up2(self.up1(features))
        return torch.tanh(self.conv3(upsampled))  # [-1, 1]

class Discriminator(nn.Module):
    """SRGAN discriminator returning, per input, a probability in (0, 1).

    Eight 3x3 conv blocks (64 -> 512 channels; stride 2 in every second block;
    BatchNorm on all but the first; LeakyReLU(0.2)), then
    Flatten -> Linear(flatten_dim, 1024) -> LeakyReLU(0.2) -> Linear(1024, 1) -> Sigmoid.

    Because of the fully connected head, inputs must have shape
    (N, in_ch, input_size, input_size).

    Args:
        in_ch: number of input channels.
        input_size: side of the square inputs. ``None`` (default) means ``PATCH_HR``.
    """

    def __init__(self, in_ch=3, input_size=None):
        super().__init__()
        if input_size is None:
            input_size = PATCH_HR

        def block(i, o, k=3, s=1, p=1, bn=True):
            """Conv(i -> o, k, s, p) [+ BatchNorm] + LeakyReLU(0.2)."""
            layers = [nn.Conv2d(i, o, k, s, p)]
            if bn:
                layers.append(nn.BatchNorm2d(o))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return nn.Sequential(*layers)

        self.features = nn.Sequential(
            block(in_ch, 64, 3, 1, 1, bn=False),
            block(64, 64, 3, 2, 1, bn=True),
            block(64, 128, 3, 1, 1, bn=True),
            block(128, 128, 3, 2, 1, bn=True),
            block(128, 256, 3, 1, 1, bn=True),
            block(256, 256, 3, 2, 1, bn=True),
            block(256, 512, 3, 1, 1, bn=True),
            block(512, 512, 3, 2, 1, bn=True),
        )
        # Infer the flattened feature size with a dry forward pass. The dummy uses
        # ``in_ch`` channels (a hard-coded 3 broke any in_ch != 3). The pass runs in
        # eval mode so the all-zero batch does not update the BatchNorm running
        # statistics (running_mean / running_var / num_batches_tracked).
        self.features.eval()
        with torch.no_grad():
            dummy = torch.zeros(1, in_ch, input_size, input_size)
            flatten_dim = self.features(dummy).numel()
        self.features.train()

        self.classifier = nn.Sequential(
            nn.Flatten(),
            nn.Linear(flatten_dim, 1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(1024, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """x: (N, in_ch, input_size, input_size) -> probabilities of shape (N, 1)."""
        return self.classifier(self.features(x))
# Instantiate both networks and move their parameters to DEVICE.
G = GeneratorSRResNet16().to(device=DEVICE)
D = Discriminator().to(device=DEVICE)

In [None]:
# =========================
# 5) Perdas
#    - Content loss VGG19 φ5,4 (relu5_4)
#    - Adversarial loss: -log D(G(x))  (usamos BCE)
#    - Pré-treino G com MSE em RGB (paper pré-treinou com MSE)
# =========================
# VGG19 para feature extraction
# VGG19 convolutional trunk with ImageNet weights, frozen and in eval mode. It is
# used only as a fixed feature extractor for the content loss.
vgg = vgg19(weights=VGG19_Weights.IMAGENET1K_V1).features.to(DEVICE).eval()
for p in vgg.parameters():
    p.requires_grad = False

# Indices of the last stage in torchvision's vgg19().features:
#   28 conv5_1, 29 relu5_1, 30 conv5_2, 31 relu5_2,
#   32 conv5_3, 33 relu5_3, 34 conv5_4, 35 relu5_4, 36 maxpool (pool5)
# vgg[:VGG_FEATURE_LAYER] is an exclusive slice, so 36 keeps everything up to and
# including relu5_4, i.e. phi_5,4 taken after the activation. A value of 35 would
# stop at conv5_4, before the ReLU.
VGG_FEATURE_LAYER = 36
# Guard against a different torchvision layout silently changing the feature layer.
assert isinstance(vgg[VGG_FEATURE_LAYER - 1], nn.ReLU), "expected relu5_4 at index 35"
assert isinstance(vgg[VGG_FEATURE_LAYER], nn.MaxPool2d), "expected pool5 at index 36"
def vgg54_feats(x):
    """VGG19 features of a batch of RGB images given in [-1, 1].

    Maps ``x`` to [0, 1], applies the ImageNet per-channel mean/std normalisation
    to the whole batch at once, and runs ``vgg[:VGG_FEATURE_LAYER]``.

    Args:
        x: tensor (N, 3, H, W) in [-1, 1]. Assumed to be on the same device as ``vgg``.
    Returns:
        The feature tensor produced by the truncated VGG19 trunk.
    """
    x = (x + 1.0) / 2.0
    # Same arithmetic as transforms.Normalize ((x - mean) / std per channel), but
    # broadcast over the batch instead of a per-sample Python loop that also rebuilt
    # the Normalize object on every call. new_tensor keeps x's dtype and device.
    mean = x.new_tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
    std = x.new_tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
    x = (x - mean) / std
    return vgg[:VGG_FEATURE_LAYER](x)

# Criteria: BCE on the discriminator's sigmoid outputs; MSE for pixel/feature losses.
bce = nn.BCELoss()
mse = nn.MSELoss()

# Separate Adam optimisers (hence separate moment estimates) for the generator's
# MSE pre-training and for its GAN phase, plus one for the discriminator.
opt_G_pre = torch.optim.Adam(params=G.parameters(), lr=G_LR_1, betas=(0.9, 0.999))
opt_G = torch.optim.Adam(params=G.parameters(), lr=G_LR_1, betas=(0.9, 0.999))
opt_D = torch.optim.Adam(params=D.parameters(), lr=D_LR, betas=(0.9, 0.999))

Downloading: "https://download.pytorch.org/models/vgg19-dcbb9e9d.pth" to /root/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth


100%|██████████| 548M/548M [00:07<00:00, 73.2MB/s]


In [None]:
# =========================
# 6) Pré-treino do Gerador (MSE)
# =========================
G.train()
if len(train_dl) == 0:
    raise RuntimeError("train_dl yields no batches; cannot pre-train the generator")
# Cycle over the loader until G_PRETRAIN_STEPS updates have been made. A single
# `for` over train_dl is not enough: one epoch here has only len(train_dl) batches
# (5 images x 20 patches / batch 4 = 25 with drop_last=True), so the loop used to
# stop long before G_PRETRAIN_STEPS and never reached a progress print.
step = 0
while step < G_PRETRAIN_STEPS:
    for lr_t, hr_t in train_dl:
        step += 1
        lr_t, hr_t = lr_t.to(DEVICE), hr_t.to(DEVICE)
        sr_t = G(lr_t)
        loss = mse(sr_t, hr_t)  # pixel MSE in RGB space (HR in [-1, 1])
        opt_G_pre.zero_grad()
        loss.backward()
        opt_G_pre.step()
        if step % 50 == 0:
            print(f"[G-pre] step {step}/{G_PRETRAIN_STEPS}  MSE: {loss.item():.4f}")
        if step >= G_PRETRAIN_STEPS:
            break

In [None]:
# =========================
# 7) Treino adversarial (SRGAN) — (content VGG54 + 1e-3 * adversarial)
# =========================
# Both networks in training mode (BatchNorm uses batch statistics).
G.train()
D.train()

def real_fake_labels(bs):
    """Return the BCE targets ``(ones, zeros)``, each of shape (bs, 1), on DEVICE."""
    real = torch.ones(bs, 1, device=DEVICE)
    fake = torch.zeros(bs, 1, device=DEVICE)
    return real, fake

# Weight of the adversarial term in the generator loss (1e-3, as in the paper).
ADV_WEIGHT = 1e-3

# Adversarial training: one discriminator update followed by one generator update
# per batch (k = 1), cycling over train_dl until GAN_TRAIN_STEPS batches are done.
# NOTE(review): if train_dl yields no batches, this `while` never terminates.
step = 0
while step < GAN_TRAIN_STEPS:
    for lr_t, hr_t in train_dl:
        step += 1
        lr_t, hr_t = lr_t.to(DEVICE), hr_t.to(DEVICE)
        bs = lr_t.size(0)

        # === Update D ===
        # G runs without autograd here, so D's loss cannot backpropagate into G.
        with torch.no_grad():
            sr_t = G(lr_t)
        D_real = D((hr_t + 1)/2)   # D is fed [0,1] images: map HR from [-1,1] to [0,1]
        D_fake = D((sr_t + 1)/2)   # same mapping for SR
        y_real, y_fake = real_fake_labels(bs)
        loss_D = bce(D_real, y_real) + bce(D_fake, y_fake)
        opt_D.zero_grad()
        loss_D.backward()
        opt_D.step()

        # === Update G ===
        # Fresh forward pass with autograd so gradients reach G. loss_G.backward()
        # also writes gradients into D's parameters; opt_D.zero_grad() clears them
        # before the next D update.
        sr_t = G(lr_t)
        D_fake = D((sr_t + 1)/2)
        # content loss: MSE between the VGG19 feature maps (vgg54_feats) of SR and HR
        f_sr = vgg54_feats(sr_t)
        f_hr = vgg54_feats(hr_t)
        content_loss = mse(f_sr, f_hr)
        adv_loss = bce(D_fake, y_real)  # -log(D(G(x))): BCE against target 1
        loss_G = content_loss + ADV_WEIGHT * adv_loss

        opt_G.zero_grad()
        loss_G.backward()
        opt_G.step()

        if step % 50 == 0:
            print(f"[GAN] step {step}/{GAN_TRAIN_STEPS}  "
                  f"content(VGG54): {content_loss.item():.4f}  "
                  f"adv: {adv_loss.item():.4f}  D: {loss_D.item():.4f}")
        if step >= GAN_TRAIN_STEPS:
            break

[GAN] step 50/200  content(VGG54): 12.5854  adv: 24.7100  D: 0.0000
[GAN] step 100/200  content(VGG54): 36.0304  adv: 21.7004  D: 0.0003
[GAN] step 150/200  content(VGG54): 15.1095  adv: 18.2490  D: 0.0001
[GAN] step 200/200  content(VGG54): 13.1520  adv: 30.5570  D: 0.0000


In [None]:
# =========================
# 8) Avaliação: PSNR/SSIM no canal Y com recorte de 4 px (fiel ao paper)
# =========================
G.eval()
def evaluate_model(G, loader, save_examples=3):
    """Mean PSNR/SSIM of ``G`` on the luma (Y) channel, with a 4-px border shaved.

    For every batch only the first sample (index 0) is scored, so ``loader`` is
    expected to use batch_size=1 (as ``val_dl`` does). Generator outputs and HR
    targets are in [-1, 1] and are mapped to [0, 1] before converting to Y.
    The first ``save_examples`` SR images are written to ``root / "out"``.

    SSIM uses the settings scikit-image documents for matching Wang et al. (2004):
    Gaussian weighting with sigma=1.5 and population (not sample) covariance.

    Returns:
        ``(mean_psnr, mean_ssim)`` as floats. Both are NaN if ``loader`` is empty.
    """
    G.eval()
    all_psnr, all_ssim = [], []
    saved = 0
    with torch.no_grad():
        for i, (lr_t, hr_t) in enumerate(loader):
            lr_t = lr_t.to(DEVICE)
            sr_t = G(lr_t).clamp(-1, 1)
            # (C, H, W) tensors in [-1, 1] -> numpy RGB (H, W, C) in [0, 1]
            sr = (sr_t[0].permute(1, 2, 0).cpu().numpy() + 1) / 2
            hr = (hr_t[0].permute(1, 2, 0).cpu().numpy() + 1) / 2

            # Y channel only, shaved, as a 2-D array. [..., 0] is used rather than
            # squeeze(), which would also drop a spatial axis of size 1.
            sr_y = shave(rgb_to_ycbcr(sr)[..., 0:1], 4)[..., 0]
            hr_y = shave(rgb_to_ycbcr(hr)[..., 0:1], 4)[..., 0]

            all_psnr.append(psnr_metric(hr_y, sr_y, data_range=1.0))
            all_ssim.append(ssim_metric(
                hr_y, sr_y, data_range=1.0,
                gaussian_weights=True, sigma=1.5, use_sample_covariance=False,
            ))

            # Save a few SR examples.
            if saved < save_examples:
                np_to_pil(sr).save(root / "out" / f"sr_{i:02d}.png")
                saved += 1
    if not all_psnr:
        return float("nan"), float("nan")
    return float(np.mean(all_psnr)), float(np.mean(all_ssim))

mean_psnr, mean_ssim = evaluate_model(G, loader=val_dl, save_examples=3)
print(f"RESULTADOS (demo) — Y-channel, crop 4 px, x{UPSCALE}:  PSNR={mean_psnr:.2f} dB  SSIM={mean_ssim:.4f}")

print(f"\nSaídas de exemplo salvas em: {root/'out'}")

# =========================
# 9) Fidelity notes:
# - Generator: 16 residual blocks, BN, PReLU, PixelShuffle x4 (SRResNet)
# - Discriminator: 3x3 convs, 64 -> 512 channels, strided convs, LeakyReLU(0.2), MLP head + sigmoid
# - Losses: content (VGG19 phi_5,4) + 1e-3 * adversarial (BCE with target = 1)
# - G pre-trained with MSE, then adversarial training alternating G/D (k = 1)
# - Data: LR produced by bicubic x4; evaluation on the Y channel with a 4-px crop
# - Scale is drastically reduced for speed; for meaningful results, train far longer
#   on a large dataset (e.g. DIV2K / ImageNet).


RESULTADOS (demo) — Y-channel, crop 4 px, x4:  PSNR=8.54 dB  SSIM=0.0635

Saídas de exemplo salvas em: srgan_demo/out


In [None]:
import shutil
from google.colab import files

# Folder written by the notebook, derived from `root` instead of a hard-coded
# "/content/..." path, so it stays correct if the working directory is not /content.
folder_path = str((root / "out").resolve())

# Compress the whole folder. make_archive appends ".zip" and returns the full
# archive path; that return value is what gets downloaded, so the two cannot diverge.
zip_path = shutil.make_archive("/content/srgan_out", 'zip', folder_path)

# Download the .zip to the local machine (Colab only).
files.download(zip_path)


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>

In [None]:
# =========================
# 8) Avaliação (fiel ao paper)
# =========================
import csv
from PIL import ImageDraw, ImageFont

G.eval()

def make_bicubic_from_lr(lr_rgb, scale=4):
    """Bicubic-upscale an LR RGB image in [0, 1] by ``scale`` (baseline for comparisons).

    Args:
        lr_rgb: (H, W, 3) float array in [0, 1].
        scale: integer upscaling factor.
    Returns:
        (H*scale, W*scale, 3) float32 array in [0, 1].
    """
    h, w = lr_rgb.shape[:2]
    # Round (not truncate) when quantizing. Values like k/255 that came from a float
    # tensor can land a hair below k after *255, and a bare astype(uint8) would turn
    # them into k-1. Clipping stops out-of-range values from wrapping around in uint8.
    lr_u8 = np.clip(np.rint(np.asarray(lr_rgb) * 255.0), 0, 255).astype(np.uint8)
    up = Image.fromarray(lr_u8).resize((w * scale, h * scale), Image.BICUBIC)
    return np.asarray(up, dtype=np.float32) / 255.0

def concat_horiz(imgs, labels=("Bicubic", "SRGAN", "HR")):
    """Place images side by side below a 30-px black caption strip.

    Args:
        imgs: sequence of float RGB arrays in [0, 1] (converted with ``np_to_pil``).
        labels: caption for each image. Images beyond ``len(labels)`` get "img{idx}".
    Returns:
        PIL RGB image of width ``sum(widths)`` and height ``max(heights) + 30``.
        Shorter images are top-aligned under the strip.
    Raises:
        ValueError: if ``imgs`` is empty.
    """
    if len(imgs) == 0:
        raise ValueError("concat_horiz needs at least one image")
    caption_h = 30  # height of the black caption strip
    pil_imgs = [np_to_pil(im) for im in imgs]
    widths, heights = zip(*(im.size for im in pil_imgs))
    out = Image.new('RGB', (sum(widths), max(heights) + caption_h), color=(0, 0, 0))
    draw = ImageDraw.Draw(out)

    try:
        font = ImageFont.truetype("DejaVuSans-Bold.ttf", 16)
    except (OSError, ImportError):
        # Font file not found (OSError) or FreeType support missing (ImportError):
        # fall back to Pillow's built-in bitmap font.
        font = ImageFont.load_default()

    x = 0
    for idx, im in enumerate(pil_imgs):
        out.paste(im, (x, caption_h))
        w = im.size[0]
        label = labels[idx] if idx < len(labels) else f"img{idx}"

        # textbbox is available on newer Pillow; textsize only on older releases.
        if hasattr(draw, "textbbox"):
            left, top, right, bottom = draw.textbbox((0, 0), label, font=font)
            text_w, text_h = right - left, bottom - top
        else:
            text_w, text_h = draw.textsize(label, font=font)

        # Centre the caption horizontally over its image, vertically in the strip.
        draw.text(
            (x + (w - text_w) // 2, (caption_h - text_h) // 2),
            label,
            font=font,
            fill=(255, 255, 255)
        )
        x += w
    return out

def central_crop_box(H, W, min_size=128):
    """Centred square crop box ``(x0, y0, size, size)`` for an H x W image.

    ``size`` is the larger of ``min_size`` and ``min(H, W) // 3`` (a third of the
    short side, but never below ``min_size``). It is then capped at ``min(H, W)``
    so the box always fits inside the image.
    """
    short_side = min(H, W)
    size = min(max(min_size, short_side // 3), H, W)
    return ((W - size) // 2, (H - size) // 2, size, size)

def crop_by_box(img, box):
    """Return the ``(x0, y0, w, h)`` region of an (H, W, C) image as a NumPy slice (a view)."""
    left, top, width, height = box
    rows = slice(top, top + height)
    cols = slice(left, left + width)
    return img[rows, cols, :]

def _y_psnr_ssim(sr, hr, border=4):
    """PSNR and SSIM between RGB [0, 1] images ``sr`` and ``hr`` on the Y channel.

    ``border`` pixels are shaved on every side first. SSIM uses the settings
    scikit-image documents for matching Wang et al. (2004): Gaussian weights,
    sigma=1.5, population covariance.
    """
    # [..., 0] yields 2-D Y planes; squeeze() could also drop a size-1 spatial axis.
    sr_y = shave(rgb_to_ycbcr(sr)[..., 0:1], border)[..., 0]
    hr_y = shave(rgb_to_ycbcr(hr)[..., 0:1], border)[..., 0]
    psnr_val = psnr_metric(hr_y, sr_y, data_range=1.0)
    ssim_val = ssim_metric(hr_y, sr_y, data_range=1.0,
                           gaussian_weights=True, sigma=1.5, use_sample_covariance=False)
    return psnr_val, ssim_val


def _source_hr_path(dataset, i, fallback):
    """Return ``dataset.hr_paths[i]`` when the dataset exposes such a path, else ``fallback``."""
    paths = getattr(dataset, "hr_paths", None)
    if paths is not None and 0 <= i < len(paths):
        return paths[i]
    return fallback


def evaluate_model(G, loader, dataset_for_names=None, save_examples=3):
    """Evaluate ``G`` on ``loader``, saving every output, comparison images and a CSV.

    For each batch (only sample 0 is used, so batch_size=1 is expected):
      * PSNR/SSIM on the Y channel with a 4-px shave (see ``_y_psnr_ssim``);
      * saves ``sr_XX.png``, ``hr_XX.png``, ``bicubic_XX.png`` (bicubic upscale of
        the LR input), a side-by-side ``compare_full_XX.png`` and a centre-crop
        ``compare_patch_XX.png``;
      * appends one row to ``root/out/examples_info.csv``.
    A console summary is printed for the first ``save_examples`` items only.

    Args:
        G: generator; its outputs are clamped to [-1, 1] here.
        loader: yields ``(lr_t, hr_t)`` with LR in [0, 1] and HR in [-1, 1].
        dataset_for_names: optional object with an ``hr_paths`` list used to report
            source file names. With ``None`` the saved HR path is reported instead,
            so ``evaluate_model(G, loader, save_examples=...)`` also works.
        save_examples: number of items summarised on the console.
    Returns:
        ``(mean_psnr, mean_ssim)`` floats. Both are NaN if ``loader`` is empty.
    """
    G.eval()
    all_psnr, all_ssim = [], []
    saved = 0

    out_dir = root / "out"
    out_dir.mkdir(parents=True, exist_ok=True)
    csv_path = out_dir / "examples_info.csv"

    with torch.no_grad(), open(csv_path, "w", newline="", encoding="utf-8") as fcsv:
        writer = csv.writer(fcsv)
        writer.writerow([
            "idx","original_hr_path","sr_path","hr_path","bicubic_path",
            "compare_full_path","compare_patch_path","psnr_y","ssim_y","crop_box_(x0,y0,w,h)"
        ])

        for i, (lr_t, hr_t) in enumerate(loader):
            lr_t = lr_t.to(DEVICE)
            sr_t = G(lr_t).clamp(-1, 1)

            # Tensors -> numpy RGB (H, W, C) in [0, 1]
            sr = (sr_t[0].permute(1, 2, 0).cpu().numpy() + 1) / 2
            hr = (hr_t[0].permute(1, 2, 0).cpu().numpy() + 1) / 2
            # LR is smaller; bicubic-upscale it for the baseline (same size as HR/SR).
            lr_rgb = lr_t[0].permute(1, 2, 0).cpu().numpy()  # [0, 1]
            bic = make_bicubic_from_lr(lr_rgb, scale=UPSCALE)

            # === PSNR/SSIM metrics (Y + 4-px shave) ===
            psnr_val, ssim_val = _y_psnr_ssim(sr, hr, border=4)
            all_psnr.append(psnr_val)
            all_ssim.append(ssim_val)

            # === Save files ===
            sr_p = out_dir / f"sr_{i:02d}.png"
            hr_p = out_dir / f"hr_{i:02d}.png"
            bic_p = out_dir / f"bicubic_{i:02d}.png"
            original_path = _source_hr_path(dataset_for_names, i, str(hr_p))

            np_to_pil(sr).save(sr_p)
            np_to_pil(hr).save(hr_p)
            np_to_pil(bic).save(bic_p)

            # Full comparison (Bicubic | SR | HR)
            cmp_full_p = out_dir / f"compare_full_{i:02d}.png"
            concat_horiz([bic, sr, hr]).save(cmp_full_p)

            # "Same slice": identical centred crop in all three images
            H, W = hr.shape[:2]
            box = central_crop_box(H, W, min_size=128)
            patch_cmp_p = out_dir / f"compare_patch_{i:02d}.png"
            concat_horiz([
                crop_by_box(bic, box), crop_by_box(sr, box), crop_by_box(hr, box)
            ]).save(patch_cmp_p)

            # CSV record
            writer.writerow([
                i, original_path, str(sr_p), str(hr_p), str(bic_p),
                str(cmp_full_p), str(patch_cmp_p),
                f"{psnr_val:.4f}", f"{ssim_val:.6f}",
                f"({box[0]},{box[1]},{box[2]},{box[3]})"
            ])

            # Console summary for the first few items
            if saved < save_examples:
                print(f"[Exemplo {i}]")
                print(f"  Original HR: {original_path}")
                print(f"  SR salvo em: {sr_p.name} | HR: {hr_p.name} | Bicubic: {bic_p.name}")
                print(f"  Comparação (full): {cmp_full_p.name}")
                print(f"  Comparação (fatia): {patch_cmp_p.name}  crop_box={box}")
                print(f"  Métricas (Y, crop 4px): PSNR={psnr_val:.2f}dB  SSIM={ssim_val:.4f}\n")
                saved += 1

    if not all_psnr:
        return float("nan"), float("nan")
    return float(np.mean(all_psnr)), float(np.mean(all_ssim))

# Full evaluation: metrics plus saved images, comparisons and the CSV index.
mean_psnr, mean_ssim = evaluate_model(G, val_dl, dataset_for_names=val_ds, save_examples=3)
print(f"RESULTADOS (demo) — Y-channel, crop 4 px, x{UPSCALE}:  PSNR={mean_psnr:.2f} dB  SSIM={mean_ssim:.4f}")
print(f"\nArquivos salvos em: {root/'out'}")
print(f"Planilha com caminhos e métricas: {(root/'out'/'examples_info.csv')}")


[Exemplo 0]
  Original HR: srgan_demo/data/hr/img_01.png
  SR salvo em: sr_00.png | HR: hr_00.png | Bicubic: bicubic_00.png
  Comparação (full): compare_full_00.png
  Comparação (fatia): compare_patch_00.png  crop_box=(171, 171, 170, 170)
  Métricas (Y, crop 4px): PSNR=6.77dB  SSIM=0.0496

[Exemplo 1]
  Original HR: srgan_demo/data/hr/img_02.png
  SR salvo em: sr_01.png | HR: hr_01.png | Bicubic: bicubic_01.png
  Comparação (full): compare_full_01.png
  Comparação (fatia): compare_patch_01.png  crop_box=(233, 133, 133, 133)
  Métricas (Y, crop 4px): PSNR=8.17dB  SSIM=0.0407

[Exemplo 2]
  Original HR: srgan_demo/data/hr/img_03.png
  SR salvo em: sr_02.png | HR: hr_02.png | Bicubic: bicubic_02.png
  Comparação (full): compare_full_02.png
  Comparação (fatia): compare_patch_02.png  crop_box=(160, 86, 128, 128)
  Métricas (Y, crop 4px): PSNR=11.71dB  SSIM=0.1032

RESULTADOS (demo) — Y-channel, crop 4 px, x4:  PSNR=8.54 dB  SSIM=0.0635

Arquivos salvos em: srgan_demo/out
Planilha com camin

In [None]:
import shutil
from google.colab import files

# Folder with every file written by the evaluation cell (derived from `root`).
folder_path = str((root / "out").resolve())

# Compress the folder. make_archive appends ".zip" and returns the archive's path;
# downloading that return value guarantees the archive just created is the one
# shipped. Previously the folder was zipped to ".../srgan_out_full_files.zip" while
# "/content/srgan_out.zip" was downloaded: a stale archive or a missing file.
zip_path = shutil.make_archive("/content/srgan_out_full_files", 'zip', folder_path)

# Download the .zip to your computer (Colab only).
files.download(zip_path)


<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>