# In [None]:
# -*- coding: utf-8 -*-
"""Face Restoration with Transformer-based Model
This script implements a face restoration model using a transformer-based architecture. 
It includes data loading, augmentation, degradation, model definition, training, validation, and inference.
It is designed to work with the FFHQ dataset and includes functionality for generating landmark heatmaps, applying degradation, and computing perceptual loss.
"""


# Import libraries
import glob
import logging
import math
import os
import random
import shutil

import cv2
import mediapipe as mp
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.models as models
from PIL import Image
from torch.cuda.amp import autocast, GradScaler
from torch.utils.data import Dataset, DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm

# Configure logging: INFO and above go to the root logger with a timestamp prefix.
logging.basicConfig(format='%(asctime)s - %(message)s', level=logging.INFO)
logger = logging.getLogger()

# Verify GPU: prefer CUDA when present and report which device will be used.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
logger.info(f"Device: {torch.cuda.get_device_name(0) if device.type == 'cuda' else 'CPU'}")

# Utility Functions
class FileClient:
    """Minimal byte reader. Only the local ``'disk'`` backend is implemented."""

    def __init__(self, backend='disk'):
        self.backend = backend

    def get(self, filepath):
        """Return the raw bytes stored at *filepath*.

        Raises:
            NotImplementedError: if ``self.backend`` is anything other than ``'disk'``.
        """
        if self.backend != 'disk':
            raise NotImplementedError(f"Backend {self.backend} not supported")
        with open(filepath, 'rb') as handle:
            payload = handle.read()
        return payload

def imfrombytes(content, float32=False):
    """Decode encoded image bytes (PNG/JPEG/...) into an RGB ``H x W x 3`` numpy array.

    Args:
        content: encoded image bytes.
        float32: if True, return float32 scaled to [0, 1]; otherwise uint8.

    Raises:
        ValueError: if OpenCV cannot decode *content*.
    """
    img_np = cv2.imdecode(np.frombuffer(content, np.uint8), cv2.IMREAD_COLOR)
    # imdecode signals failure by returning None rather than raising; fail loudly here
    # instead of letting cvtColor raise an opaque assertion error.
    if img_np is None:
        raise ValueError("Failed to decode image bytes")
    # OpenCV decodes to BGR; convert so the rest of the pipeline sees RGB.
    img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2RGB)
    if float32:
        img_np = img_np.astype(np.float32) / 255.0
    return img_np

def img2tensor(img, bgr2rgb=False, float32=True):
    """Convert an ``H x W x C`` (or ``H x W``) numpy image into a ``C x H x W`` tensor.

    Args:
        img: numpy image. Integer dtypes (e.g. uint8) are treated as [0, 255];
            floating dtypes are treated as already scaled to [0, 1].
        bgr2rgb: reverse the channel order of a 3-channel image.
        float32: return a float32 tensor; integer images are additionally divided by 255
            so the result is in [0, 1]. Float images are not rescaled a second time.

    Returns:
        torch tensor of shape (C, H, W).
    """
    if img.ndim == 2:
        img = img[:, :, None]
    if bgr2rgb and img.shape[2] == 3:
        # Plain channel reversal: works for any dtype (cv2.cvtColor rejects float64).
        img = img[:, :, ::-1]
    # Copy into a C-contiguous buffer: torch.from_numpy rejects negative strides such as
    # those produced by flip augmentation or the channel reversal above.
    tensor = torch.from_numpy(np.ascontiguousarray(img.transpose(2, 0, 1)))
    if float32:
        was_integer = not tensor.is_floating_point()
        tensor = tensor.float()
        if was_integer:
            tensor = tensor / 255.0
    return tensor

def augment(img, hflip=True, rotation=True):
    """Randomly flip and/or transpose an ``H x W x C`` numpy image.

    Args:
        img: numpy image in HWC layout.
        hflip: allow a horizontal flip (applied with probability 0.5).
        rotation: allow a vertical flip and an H/W transpose (each with probability 0.5);
            together with hflip these cover the 8 dihedral orientations.

    Returns:
        The augmented image as a C-contiguous array, so it can be passed straight to
        ``torch.from_numpy`` (which rejects the negative strides left by ``[::-1]``).
    """
    # Short-circuiting keeps the number of random() draws identical to the flags given.
    hflip = hflip and random.random() < 0.5
    vflip = rotation and random.random() < 0.5
    rot90 = rotation and random.random() < 0.5
    if hflip:
        img = img[:, ::-1, :]
    if vflip:
        img = img[::-1, :, :]
    if rot90:
        img = img.transpose(1, 0, 2)
    return np.ascontiguousarray(img)

def circular_lowpass_kernel(omega_c, kernel_size):
    """Build a radially symmetric sinc low-pass kernel normalised to sum to 1.

    Taps at distance ``r <= kernel_size // 2`` from the centre tap are
    ``sin(omega_c * r) / (omega_c * r)`` (1 at the centre); taps outside that disc are 0.
    If the taps happen to sum to exactly 0 the kernel is returned unnormalised.

    Returns:
        float32 array of shape (kernel_size, kernel_size).
    """
    center = kernel_size // 2
    rows, cols = np.mgrid[0:kernel_size, 0:kernel_size]
    radius = np.sqrt((rows - center) ** 2 + (cols - center) ** 2)
    # np.sinc(x) = sin(pi x) / (pi x), hence the division by pi.
    taps = np.where(radius <= center, np.sinc(radius * omega_c / np.pi), 0.0)
    kernel = taps.astype(np.float32)
    kernel /= kernel.sum() or 1.0
    return kernel

def random_mixed_kernels(kernel_list, kernel_prob, kernel_size, sigma):
    """Sample one Gaussian blur kernel whose type is drawn from *kernel_list*.

    Args:
        kernel_list: candidate kernel types; 'iso' and 'aniso' are supported.
        kernel_prob: sampling probability for each entry of *kernel_list*.
        kernel_size: side length of the square kernel (expected odd).
        sigma: (low, high) range for the Gaussian standard deviation(s).

    Returns:
        (kernel_size, kernel_size) float64 array normalised to sum to 1.

    Raises:
        ValueError: if the sampled kernel type is not supported.
    """
    kernel_type = np.random.choice(kernel_list, p=kernel_prob)
    if kernel_type == 'iso':
        sigma_val = np.random.uniform(sigma[0], sigma[1])
        kernel = cv2.getGaussianKernel(kernel_size, sigma_val)
        kernel = kernel * kernel.T
    elif kernel_type == 'aniso':
        sigma1 = np.random.uniform(sigma[0], sigma[1])
        sigma2 = np.random.uniform(sigma[0], sigma[1])
        theta = np.random.uniform(-math.pi, math.pi)
        # Rotated bivariate Gaussian with covariance R diag(s1^2, s2^2) R^T. Rotating an
        # isotropic kernel would be a no-op, so both sigmas and theta must enter the pdf.
        cos_t, sin_t = np.cos(theta), np.sin(theta)
        rot = np.array([[cos_t, -sin_t], [sin_t, cos_t]])
        inv_cov = np.linalg.inv(rot @ np.diag([sigma1 ** 2, sigma2 ** 2]) @ rot.T)
        # Coordinates relative to the centre tap, matching cv2.getGaussianKernel's centring.
        ax = np.arange(kernel_size) - (kernel_size - 1) / 2.0
        xx, yy = np.meshgrid(ax, ax)
        quad = (inv_cov[0, 0] * xx ** 2
                + (inv_cov[0, 1] + inv_cov[1, 0]) * xx * yy
                + inv_cov[1, 1] * yy ** 2)
        kernel = np.exp(-0.5 * quad)
    else:
        raise ValueError(f"Unsupported kernel type: {kernel_type}")
    kernel /= kernel.sum() or 1.0
    return kernel

# Dataset Class
class FFHQsubDataset(Dataset):
    """Loads ground-truth face images and samples per-item random blur kernels.

    Each item is a dict with:
        'gt': (3, H, W) float32 RGB tensor in [0, 1] (after optional flips/transpose),
        'kernel': (21, 21) first-stage blur kernel (sinc or Gaussian, zero-padded to 21),
        'sinc_kernel': (21, 21) final sinc kernel, or an identity pulse,
        'gt_path': path of the loaded file.
    The degradation itself is applied later by ``degrade_image``.

    ``opt`` keys read: dataroot_gt, meta_info (one file name per line), blur_kernel_size,
    kernel_list, kernel_prob, blur_sigma, sinc_prob, final_sinc_prob, use_hflip, use_rot.
    """

    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        self.file_client = FileClient()
        self.gt_folder = opt['dataroot_gt']
        with open(opt['meta_info']) as f:
            # Skip blank lines: they would resolve to the folder itself and fail to load.
            self.paths = [os.path.join(self.gt_folder, line.strip()) for line in f if line.strip()]

        self.blur_kernel_size = opt['blur_kernel_size']
        self.kernel_list = opt['kernel_list']
        self.kernel_prob = opt['kernel_prob']
        self.blur_sigma = opt['blur_sigma']
        self.sinc_prob = opt['sinc_prob']
        self.kernel_range = [2 * v + 1 for v in range(3, 11)]  # odd sizes 7..21
        # Identity kernel used when no final sinc filter is sampled.
        self.pulse_tensor = torch.zeros(21, 21).float()
        self.pulse_tensor[10, 10] = 1

    @staticmethod
    def _pad_to_21(kernel):
        """Zero-pad an odd-sized square kernel to 21x21 so batches collate to one shape."""
        pad_size = (21 - kernel.shape[0]) // 2
        return np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))

    def __getitem__(self, index):
        gt_path = self.paths[index]
        # On read failure, retry up to 3 times with a randomly chosen replacement image.
        retry = 3
        while retry > 0:
            try:
                img_bytes = self.file_client.get(gt_path)
                break
            except (IOError, OSError):
                index = random.randint(0, len(self.paths) - 1)
                gt_path = self.paths[index]
                retry -= 1
        if retry == 0:
            raise RuntimeError(f"Failed to load image {gt_path}")

        img_gt = imfrombytes(img_bytes, float32=True)  # RGB float32 in [0, 1]
        img_gt = augment(img_gt, self.opt['use_hflip'], self.opt['use_rot'])

        # First-stage blur: sinc with probability sinc_prob, otherwise a Gaussian.
        kernel_size = random.choice(self.kernel_range)
        if np.random.uniform() < self.sinc_prob:
            omega_c = np.random.uniform(np.pi / 5, np.pi) if kernel_size >= 13 else np.random.uniform(np.pi / 3, np.pi)
            kernel = circular_lowpass_kernel(omega_c, kernel_size)
        else:
            kernel = random_mixed_kernels(self.kernel_list, self.kernel_prob, kernel_size, self.blur_sigma)
        kernel = self._pad_to_21(kernel)

        # Final sinc filter. Its size is sampled independently, so its padding must be
        # derived from its own size (reusing the first kernel's pad breaks the 21x21 shape).
        if np.random.uniform() < self.opt['final_sinc_prob']:
            final_size = random.choice(self.kernel_range)
            omega_c = np.random.uniform(np.pi / 3, np.pi)
            sinc_kernel = self._pad_to_21(circular_lowpass_kernel(omega_c, final_size))
        else:
            sinc_kernel = self.pulse_tensor.clone()

        # The decoded image is already RGB in [0, 1]: convert HWC -> CHW directly rather than
        # via img2tensor(bgr2rgb=True), which would swap back to BGR and rescale by 1/255 again.
        # ascontiguousarray removes negative strides left by the flip augmentation.
        img_gt = torch.from_numpy(np.ascontiguousarray(img_gt.transpose(2, 0, 1))).float()
        kernel = torch.as_tensor(kernel, dtype=torch.float32)
        sinc_kernel = torch.as_tensor(sinc_kernel, dtype=torch.float32)

        return {'gt': img_gt, 'kernel': kernel, 'sinc_kernel': sinc_kernel, 'gt_path': gt_path}

    def __len__(self):
        return len(self.paths)

# Degradation Function
def degrade_image(img, kernel, sinc_kernel, scale=0.25, out_size=128):
    """Synthesize low-quality inputs: blur -> bicubic downscale -> sinc filter -> resize.

    Args:
        img: (B, C, H, W) float tensor of ground-truth images.
        kernel: (B, k, k) per-sample first-stage blur kernels (k odd).
        sinc_kernel: (B, k2, k2) per-sample final sinc / identity kernels (k2 odd).
        scale: bicubic scale factor applied after the first blur.
        out_size: side length of the returned square images.

    Returns:
        (B, C, out_size, out_size) tensor clamped to [0, 1]. Filtering uses zero padding
        of ``k // 2`` so each convolution preserves the spatial size.
    """
    B, C, H, W = img.shape
    if B == 0:
        return torch.zeros(0, C, out_size, out_size, device=img.device, dtype=img.dtype)
    kernel = kernel.to(device=img.device, dtype=img.dtype)
    sinc_kernel = sinc_kernel.to(device=img.device, dtype=img.dtype)

    def _filter(x, k):
        # Fold the batch into channels and run one grouped conv: channel b*C + c of the
        # input is filtered by k[b], i.e. every sample gets its own kernel on all channels.
        ksize = k.shape[-1]
        weight = k.repeat_interleave(C, dim=0).unsqueeze(1)  # (B*C, 1, k, k)
        out = F.conv2d(x.reshape(1, B * C, *x.shape[-2:]), weight, padding=ksize // 2, groups=B * C)
        return out.reshape(B, C, *out.shape[-2:])

    out = _filter(img, kernel)
    out = F.interpolate(out, scale_factor=scale, mode='bicubic')
    out = _filter(out, sinc_kernel)
    out = F.interpolate(out, size=(out_size, out_size), mode='bicubic')
    return out.clamp(0, 1)

# Model Definition
class TransformerBlock(nn.Module):
    def __init__(self, dim, num_heads=4):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads)
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = nn.Sequential(
            nn.Linear(dim, dim * 2),
            nn.GELU(),
            nn.Linear(dim * 2, dim)
        )

    def forward(self, x):
        B, C, H, W = x.shape
        x = x.permute(0, 2, 3, 1).reshape(B, H * W, C)
        x = x + self.attn(self.norm1(x), self.norm1(x), self.norm1(x))[0]
        x = x + self.ffn(self.norm2(x))
        return x.reshape(B, H, W, C).permute(0, 3, 1, 2)

class LandmarkAttention(nn.Module):
    """Modulate features with a learned gate computed from a 1-channel landmark heatmap.

    Output is ``conv(x * (1 + sigmoid_gate(heatmap)))``, or just ``conv(x)`` when no
    heatmap is given. The heatmap is bilinearly resized to the feature map's spatial size.

    Args:
        channels: channel count of the feature map ``x``.
    """

    def __init__(self, channels):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, 3, padding=1)
        self.attention = nn.Sequential(
            nn.Conv2d(1, 16, 3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(16, 1, 3, padding=1),
            nn.Sigmoid()
        )

    def forward(self, x, landmark_heatmap):
        if landmark_heatmap is None:
            return self.conv(x)
        # Compare both spatial dims; checking only the height misses width mismatches.
        if landmark_heatmap.shape[-2:] != x.shape[-2:]:
            landmark_heatmap = F.interpolate(landmark_heatmap, size=tuple(x.shape[-2:]),
                                             mode='bilinear', align_corners=False)
        attention_weights = self.attention(landmark_heatmap)
        # (1 + gate) amplifies landmark regions without ever zeroing features elsewhere.
        return self.conv(x * (1 + attention_weights))

class PixelShuffleUpsample(nn.Module):
    """Upsample by ``scale``: 3x3 conv to ``in_ch * scale**2`` channels, then PixelShuffle.

    The output keeps ``in_ch`` channels and has H and W multiplied by ``scale``.
    """

    def __init__(self, in_ch, scale=2):
        super().__init__()
        expanded_ch = in_ch * scale * scale
        self.conv = nn.Conv2d(in_ch, expanded_ch, kernel_size=3, padding=1)
        self.shuffle = nn.PixelShuffle(scale)

    def forward(self, x):
        expanded = self.conv(x)
        return self.shuffle(expanded)

class FaceRestormer(nn.Module):
    """U-shaped transformer restorer that outputs a 4x-upscaled image in [0, 1].

    Layout: stem conv -> encoder stage at full res -> stride-2 conv -> encoder stage at 1/2
    -> stride-2 conv -> bottleneck at 1/4 -> PixelShuffle up + landmark attention + skip ->
    decoder at 1/2 -> PixelShuffle up + skip -> decoder at full res -> two further 2x
    PixelShuffle upsamplers and a sigmoid head.

    H and W of the input must be divisible by 4 so the down/up paths line up for the
    additive skip connections.
    """

    def __init__(self, in_ch=3, out_ch=3, dim=64):
        super().__init__()
        # Module creation order is kept fixed so default parameter initialisation draws
        # from the RNG in the same sequence.
        self.initial = nn.Conv2d(in_ch, dim, kernel_size=3, padding=1)
        self.encoder1 = nn.Sequential(TransformerBlock(dim), TransformerBlock(dim))
        self.down1 = nn.Conv2d(dim, dim, kernel_size=4, stride=2, padding=1)
        self.encoder2 = TransformerBlock(dim)
        self.down2 = nn.Conv2d(dim, dim, kernel_size=4, stride=2, padding=1)
        self.bottleneck = TransformerBlock(dim)
        self.up1 = PixelShuffleUpsample(dim, 2)
        self.landmark_attention = LandmarkAttention(dim)
        self.decoder1 = TransformerBlock(dim)
        self.up2 = PixelShuffleUpsample(dim, 2)
        self.decoder2 = TransformerBlock(dim)
        self.final_upsample = nn.Sequential(
            PixelShuffleUpsample(dim, 2),
            nn.Conv2d(dim, dim, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),
            PixelShuffleUpsample(dim, 2),
            nn.Conv2d(dim, out_ch, kernel_size=3, padding=1),
            nn.Sigmoid()
        )

    def forward(self, x, landmark_heatmap=None):
        # Encoder: full -> 1/2 -> 1/4 resolution, keeping skips at full and 1/2.
        skip_full = self.encoder1(self.initial(x))
        skip_half = self.encoder2(self.down1(skip_full))
        latent = self.bottleneck(self.down2(skip_half))
        # Decoder: landmark guidance is injected at 1/2 resolution before the skip add.
        latent = self.landmark_attention(self.up1(latent), landmark_heatmap)
        latent = self.decoder1(latent + skip_half)
        latent = self.decoder2(self.up2(latent) + skip_full)
        return self.final_upsample(latent)

# Landmark Heatmap Generation
def generate_landmark_heatmap(image_batch):
    mp_face = mp.solutions.face_mesh.FaceMesh(static_image_mode=True, max_num_faces=1)
    batch_size = image_batch.size(0)
    device = image_batch.device
    heatmaps = []
    for b in range(batch_size):
        img = (image_batch[b].permute(1, 2, 0).cpu().numpy() * 255).astype(np.uint8)
        results = mp_face.process(img)
        heatmap = np.zeros((128, 128), dtype=np.float32)
        if results.multi_face_landmarks:
            for face_landmarks in results.multi_face_landmarks:
                for lm in face_landmarks.landmark:
                    x, y = int(lm.x * 128), int(lm.y * 128)
                    if 0 <= x < 128 and 0 <= y < 128:
                        for i in range(max(0, y-2), min(128, y+3)):
                            for j in range(max(0, x-2), min(128, x+3)):
                                dist = np.sqrt((i-y)**2 + (j-x)**2)
                                if dist < 3:
                                    heatmap[i, j] = max(heatmap[i, j], np.exp(-dist))
        else:
            center_x, center_y = 64, 64
            for i in range(128):
                for j in range(128):
                    dist = np.sqrt((i - center_y) ** 2 + (j - center_x) ** 2)
                    heatmap[i, j] = np.exp(-dist / 10.0)
        heatmaps.append(torch.from_numpy(heatmap))
    heatmap_batch = torch.stack(heatmaps).unsqueeze(1).to(device)
    mp_face.close()
    return heatmap_batch

# Perceptual Loss
class PerceptualLoss(nn.Module):
    """L1 distance between frozen VGG16 feature maps (``features[:16]``) of pred and gt.

    The pretrained ImageNet VGG16 weights expect mean/std-normalised RGB input, so both
    images are normalised with the ImageNet statistics before feature extraction.
    Assumes inputs are RGB in [0, 1] — confirm for any new caller.
    """

    def __init__(self):
        super().__init__()
        vgg = models.vgg16(weights='IMAGENET1K_V1').features[:16].eval().to(device)
        for param in vgg.parameters():
            param.requires_grad = False
        self.vgg = vgg
        self.criterion = nn.L1Loss()
        # Non-persistent so they do not appear in state_dict.
        self.register_buffer('mean', torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1), persistent=False)
        self.register_buffer('std', torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1), persistent=False)

    def _normalize(self, x):
        # .to(x) matches both the device and dtype (e.g. fp16 under autocast) of the input.
        return (x - self.mean.to(x)) / self.std.to(x)

    def forward(self, pred, gt):
        pred_feat = self.vgg(self._normalize(pred))
        # The target branch needs no graph; skipping it saves activation memory.
        with torch.no_grad():
            gt_feat = self.vgg(self._normalize(gt))
        return self.criterion(pred_feat, gt_feat)

# Dataset Configuration
# Source of the course release on Google Drive; copied to local /content storage below.
dataset_base = "/content/drive/MyDrive/2024-S2-AI6126-Project2-Release/datasets"

# Options consumed by FFHQsubDataset (paths, augmentation flags, kernel sampling).
opt = dict(
    dataroot_gt='/content/FFHQ/train/GT',
    meta_info='/content/FFHQ/train/meta_info_FFHQfull_GT.txt',
    use_hflip=True,
    use_rot=True,
    blur_kernel_size=21,
    kernel_list=['iso', 'aniso'],
    kernel_prob=[0.45, 0.55],
    blur_sigma=[0.2, 3],
    sinc_prob=0.1,
    final_sinc_prob=0.8,
)

# Setup Directories
for _required_dir in ('/content/FFHQ/train/GT', '/content/FFHQ/val/GT', '/content/FFHQ/val/LQ',
                      '/content/FFHQ/test/LQ', '/content/results/val', '/content/results/test',
                      '/content/models'):
    os.makedirs(_required_dir, exist_ok=True)


def _copy_dir_contents(src, dst, label):
    """Merge-copy everything under *src* into *dst*; log (do not raise) on failure.

    Replaces ``cp -r src/* dst/``: no shell, no glob expansion (which can exceed the
    argument-length limit for large folders), and failures are reported via the logger.
    """
    try:
        shutil.copytree(src, dst, dirs_exist_ok=True)
    except OSError as err:  # shutil.Error is a subclass of OSError
        logger.error(f"Error copying {label}: {err}")


# Copy Dataset
logger.info("Copying dataset to Colab...")
_copy_dir_contents(f"{dataset_base}/train/GT", '/content/FFHQ/train/GT', 'train GT')
_copy_dir_contents(f"{dataset_base}/val/GT", '/content/FFHQ/val/GT', 'val GT')
_copy_dir_contents(f"{dataset_base}/val/LQ", '/content/FFHQ/val/LQ', 'val LQ')
if os.path.exists(f"{dataset_base}/test/LQ"):
    _copy_dir_contents(f"{dataset_base}/test/LQ", '/content/FFHQ/test/LQ', 'test LQ')
else:
    logger.warning(f"Test directory does not exist: {dataset_base}/test/LQ")

# Verify Dataset
train_images = len([f for f in os.listdir('/content/FFHQ/train/GT') if f.endswith('.png')])
val_gt_images = len([f for f in os.listdir('/content/FFHQ/val/GT') if f.endswith('.png')])
val_lq_images = len([f for f in os.listdir('/content/FFHQ/val/LQ') if f.endswith('.png')])
test_images = len([f for f in os.listdir('/content/FFHQ/test/LQ') if f.endswith('.png')])
logger.info(f"Training images: {train_images}, Val GT: {val_gt_images}, Val LQ: {val_lq_images}, Test: {test_images}")
if train_images == 0 or val_gt_images == 0 or val_lq_images == 0:
    raise ValueError("Dataset directories are empty.")

# Generate Meta Info
def _write_png_listing(folder, meta_path):
    """Write the sorted ``*.png`` file names found in *folder* to *meta_path*, one per line."""
    with open(meta_path, 'w') as listing:
        listing.writelines(f"{name}\n" for name in sorted(os.listdir(folder)) if name.endswith('.png'))


_write_png_listing('/content/FFHQ/train/GT', opt['meta_info'])
_write_png_listing('/content/FFHQ/val/GT', '/content/FFHQ/val/meta_info_FFHQval_GT.txt')

# DataLoaders
train_dataset = FFHQsubDataset(opt)
# NOTE(review): validation re-samples random degradation kernels from the GT images on
# every pass, so per-epoch PSNR (and best-checkpoint selection) is noisy, and the copied
# val/LQ images are never read. Consider evaluating on the provided LQ/GT pairs instead.
val_dataset = FFHQsubDataset(dict(
    opt,
    dataroot_gt='/content/FFHQ/val/GT',
    meta_info='/content/FFHQ/val/meta_info_FFHQval_GT.txt',
    use_hflip=False,
    use_rot=False,
))
train_loader = DataLoader(dataset=train_dataset, batch_size=4, shuffle=True, num_workers=4, pin_memory=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=4, shuffle=False, num_workers=2, pin_memory=True)
logger.info(f"Train samples: {len(train_dataset)}, Validation samples: {len(val_dataset)}")

# Training Setup
model = FaceRestormer().to(device)
params = sum(p.numel() for p in model.parameters())
logger.info(f"Model parameters: {params}")
# NOTE(review): by a manual count, FaceRestormer(dim=64) has roughly 1.0M parameters
# (each 3x3 PixelShuffleUpsample conv alone is ~148k), so this check fails as configured.
# Reduce `dim` or the upsampler widths to fit the budget. Verify the count after changes.
# An explicit raise (not assert) keeps the check active under `python -O`.
if params >= 250000:
    raise ValueError(f"Model exceeds parameter limit: {params}")

def init_weights(m):
    """Kaiming-normal (fan_out, ReLU gain) init for Conv2d/Linear weights, zeroing biases.

    Intended for ``model.apply(init_weights)``; any other module type is left untouched.
    """
    if isinstance(m, (nn.Conv2d, nn.Linear)):
        nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

model.apply(init_weights)

# Optimisation: Adam at 1e-4, halved every 5 epochs (scheduler is stepped once per epoch).
optimizer = optim.Adam(params=model.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=5, gamma=0.5)

# Loss terms combined in the training loop.
criterion_l1, criterion_perceptual = nn.L1Loss(), PerceptualLoss()

# Mixed-precision gradient scaler and TensorBoard event writer.
scaler = GradScaler()
writer = SummaryWriter(log_dir="runs/face_restormer")

# Training Loop
num_epochs = 15
best_psnr = 0
best_epoch = 0

for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    num_batches = 0

    for data in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        gt = data['gt'].to(device)
        kernel = data['kernel'].to(device)
        sinc_kernel = data['sinc_kernel'].to(device)
        # Synthesise the low-quality input on the fly and derive its landmark prior.
        lq = degrade_image(gt, kernel, sinc_kernel)
        landmark_heatmap = generate_landmark_heatmap(lq)

        optimizer.zero_grad()
        with autocast():
            pred = model(lq, landmark_heatmap)
            # Upsample the heatmap once to GT resolution; it weights the facial-region L1 term.
            landmark_weight = F.interpolate(landmark_heatmap, size=tuple(gt.shape[-2:]),
                                            mode='bilinear', align_corners=False)
            loss_l1 = criterion_l1(pred, gt)
            loss_perceptual = criterion_perceptual(pred, gt)
            loss_landmark = criterion_l1(pred * landmark_weight, gt * landmark_weight)
            loss = 0.5 * loss_l1 + 0.3 * loss_perceptual + 0.2 * loss_landmark
        # Backward and optimizer steps run outside autocast, as the AMP docs recommend.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        epoch_loss += loss.item()
        num_batches += 1

    scheduler.step()

    avg_train_loss = epoch_loss / max(num_batches, 1)
    logger.info(f"Epoch {epoch+1}/{num_epochs}, Train Loss: {avg_train_loss:.4f}, LR: {scheduler.get_last_lr()[0]:.2e}")
    writer.add_scalar('Loss/train', avg_train_loss, epoch)

    # Validation
    model.eval()
    psnr_list = []
    with torch.no_grad():
        for data in val_loader:
            gt = data['gt'].to(device)
            kernel = data['kernel'].to(device)
            sinc_kernel = data['sinc_kernel'].to(device)
            lq = degrade_image(gt, kernel, sinc_kernel)
            landmark_heatmap = generate_landmark_heatmap(lq)
            pred = model(lq, landmark_heatmap)
            # Quantise both images to 8-bit (rounding, not truncating), then compute the MSE
            # in float64: subtracting uint8 arrays directly wraps around modulo 256.
            pred_u8 = (pred.clamp(0, 1) * 255).round().byte().permute(0, 2, 3, 1).contiguous().cpu().numpy()
            gt_u8 = (gt.clamp(0, 1) * 255).round().byte().permute(0, 2, 3, 1).contiguous().cpu().numpy()
            for pred_np, gt_np in zip(pred_u8, gt_u8):
                mse = np.mean((pred_np.astype(np.float64) - gt_np.astype(np.float64)) ** 2)
                psnr = 20 * np.log10(255.0 / np.sqrt(mse)) if mse > 0 else 100
                psnr_list.append(psnr)
                Image.fromarray(pred_np).save(f"/content/results/val/{len(psnr_list):04d}.png")

    mean_psnr = float(np.mean(psnr_list)) if psnr_list else 0.0
    logger.info(f"Epoch {epoch+1}/{num_epochs}, Validation PSNR: {mean_psnr:.2f} dB")
    if mean_psnr > best_psnr:
        best_psnr = mean_psnr
        best_epoch = epoch + 1
        torch.save(model.state_dict(), "/content/models/face_restormer_best.pth")

    writer.add_scalar('PSNR/val', mean_psnr, epoch)

# Flush buffered TensorBoard events so the log copy at the end of the script is complete.
writer.flush()
torch.save(model.state_dict(), "/content/models/face_restormer_final.pth")
logger.info(f"Training completed. Best PSNR: {best_psnr:.2f} dB at epoch {best_epoch}")

# Inference
best_ckpt = "/content/models/face_restormer_best.pth"
if not os.path.exists(best_ckpt):
    # The best checkpoint is only written when validation PSNR improves (e.g. never if it
    # was NaN); fall back to the final weights rather than crashing.
    logger.warning(f"{best_ckpt} not found; using final weights")
    best_ckpt = "/content/models/face_restormer_final.pth"
# map_location lets a CUDA-trained checkpoint load on a CPU-only runtime.
model.load_state_dict(torch.load(best_ckpt, map_location=device))
model.eval()

# Validation Inference
psnr_list = []
with torch.no_grad():
    for data in tqdm(val_loader, desc="Validation Inference"):
        gt = data['gt'].to(device)
        kernel = data['kernel'].to(device)
        sinc_kernel = data['sinc_kernel'].to(device)
        lq = degrade_image(gt, kernel, sinc_kernel)
        landmark_heatmap = generate_landmark_heatmap(lq)
        pred = model(lq, landmark_heatmap)
        # 8-bit quantisation with rounding; MSE in float64 to avoid uint8 wrap-around.
        pred_u8 = (pred.clamp(0, 1) * 255).round().byte().permute(0, 2, 3, 1).contiguous().cpu().numpy()
        gt_u8 = (gt.clamp(0, 1) * 255).round().byte().permute(0, 2, 3, 1).contiguous().cpu().numpy()
        for pred_np, gt_np in zip(pred_u8, gt_u8):
            mse = np.mean((pred_np.astype(np.float64) - gt_np.astype(np.float64)) ** 2)
            psnr = 20 * np.log10(255.0 / np.sqrt(mse)) if mse > 0 else 100
            psnr_list.append(psnr)
            Image.fromarray(pred_np).save(f"/content/results/val/{len(psnr_list):04d}.png")

mean_psnr = float(np.mean(psnr_list)) if psnr_list else 0.0
with open("/content/results/val/scores.txt", 'w') as f:
    f.write(f"PSNR: {mean_psnr:.4f}")
logger.info(f"Validation PSNR: {mean_psnr:.2f} dB")

# Test Inference
test_dir = "/content/FFHQ/test/LQ"
if os.path.exists(test_dir) and os.listdir(test_dir):
    test_files = sorted([f for f in os.listdir(test_dir) if f.endswith('.png')])
    with torch.no_grad():
        for fname in tqdm(test_files, desc="Test Images"):
            # Context manager closes the file handle that Image.open keeps lazily.
            with Image.open(os.path.join(test_dir, fname)) as src:
                img = src.convert('RGB').resize((128, 128))
            # PIL already yields RGB, so no channel swap (bgr2rgb=True would feed BGR).
            lq = img2tensor(np.array(img), bgr2rgb=False, float32=True).unsqueeze(0).to(device)
            landmark_heatmap = generate_landmark_heatmap(lq)
            pred = model(lq, landmark_heatmap)
            pred_np = (pred[0].clamp(0, 1) * 255).round().byte().permute(1, 2, 0).contiguous().cpu().numpy()
            # PIL images are written with .save(); there is no .write() method.
            Image.fromarray(pred_np).save(f"/content/results/test/{fname}")
    logger.info("Test predictions saved")
else:
    logger.warning(f"No test set available at {test_dir}")

# Submission Prep
os.makedirs('/content/submit', exist_ok=True)
os.makedirs('/content/adx/src', exist_ok=True)
os.makedirs('/content/adx/logs', exist_ok=True)


def _copy_matching(pattern, dst_dir, missing_msg):
    """Copy every file matching glob *pattern* into *dst_dir*; log *missing_msg* if none match."""
    matches = sorted(glob.glob(pattern))
    if not matches:
        logger.warning(missing_msg)
    for path in matches:
        shutil.copy2(path, dst_dir)
    return len(matches)


def _copy_file_if_exists(src, dst, missing_msg):
    """Copy a single file, logging *missing_msg* instead of raising when *src* is absent."""
    if os.path.exists(src):
        shutil.copy2(src, dst)
    else:
        logger.warning(missing_msg)


_copy_matching('/content/results/test/*.png', '/content/submit/', 'No test results to copy')
# Archive relative to /content/submit so the PNGs sit at the zip root instead of under
# content/submit/ (which is what `zip -r submit.zip /content/submit/` produced).
# NOTE(review): confirm the expected layout against the submission instructions.
shutil.make_archive('/content/submit', 'zip', root_dir='/content/submit')
_copy_matching('/content/results/val/*.png', '/content/adx/', 'No validation results')
_copy_file_if_exists('/content/results/val/scores.txt', '/content/adx/codalab_score.txt', 'No scores.txt')
_copy_file_if_exists('/content/models/face_restormer_best.pth', '/content/adx/', 'No best checkpoint')
if os.path.isdir('/content/runs/face_restormer'):
    shutil.copytree('/content/runs/face_restormer', '/content/adx/logs', dirs_exist_ok=True)
else:
    logger.warning('No logs')
