In [1]:
import os
import cv2
import torch
import torch.nn.functional as F
import torchvision.transforms as T
import matplotlib.pyplot as plt
import numpy as np

from PIL import Image
from torchvision.models.optical_flow import raft_large, Raft_Large_Weights

# Path load
root = 'results/sem_cyclegan_6b/test_latest/images'
outputs = sorted([os.path.join(root, output) for output in os.listdir(root) if ('fake_B' in output)])
noises = sorted([os.path.join(root, output) for output in os.listdir(root) if ('real_A' in output)])
gts = sorted([os.path.join(root, output) for output in os.listdir(root) if ('real_B' in output)])

# Device setting
device = "cuda:0" if torch.cuda.is_available() else "cpu"

# Save option
save_folder = 'results/aligned'
save_image = False

In [2]:
# Data preprocess
def preprocess(batch):
    transforms = T.Compose(
        [   
            T.ToTensor(),
            T.ConvertImageDtype(torch.float32),
            T.Normalize(mean=0.5, std=0.5),  # map [0, 1] into [-1, 1]
            T.Resize(size=(520, 960), antialias=True), # Resize to a size supported by RAFT
        ]
    )
    return transforms(batch)


# Data resize to original size
def resize(batch):
    transforms = T.Compose(
        [   
            T.Resize(size=(704, 704), antialias=True), # Resize to original size
        ]
    )
    return transforms(batch)


# Warp function
def warp_feature(x, flow):
    """
    x: (B, C, H, W)
    flow: (B, 2, H, W) - (u, v) shape optical flow
        Convert the flow into a normalized grid with values in the range [-1, 1].
    """
    B, C, H, W = x.size()
    # generate index grid
    # normalized coords: -1 ~ 1 range
    y, x_ = torch.meshgrid(
        torch.linspace(-1, 1, H, device=x.device),
        torch.linspace(-1, 1, W, device=x.device),
        indexing='ij',
    )
    # (H, W) -> (B, H, W)
    grid = torch.stack((x_, y), dim=-1).unsqueeze(0).repeat(B, 1, 1, 1)
    # Separate the flow into its horizontal (u) and vertical (v) components.
    flow_u = flow[:, 0, :, :]  # (B, H, W)
    flow_v = flow[:, 1, :, :]  # (B, H, W)

    # Assume that flow_u and flow_v are normalized to the range [–1, 1].
    grid[:, :, :, 0] += flow_u
    grid[:, :, :, 1] += flow_v

    return F.grid_sample(x, grid, mode='bilinear', padding_mode='zeros', align_corners=True)


def flow_normalize(x):
    B, _, _, _ = x.shape
    max_ = torch.tensor([[[520]], [[960]]], device='cuda:0').repeat(B, 1, 1, 1)
    x_norm = 2 * (-x + max_) / (2 * max_ + 1e-8) - 1
    return x_norm

In [3]:
# L1 loss with an align method
# Model load
model = raft_large(weights=Raft_Large_Weights.C_T_SKHT_V2, progress=False).to(device)
model = model.eval()

# Compute optical flow and warp images
total_loss = []
for i in range(len(gts)):
    # RAFT only supports color images so repeat the gray image across three channels.
    outputs_batch = preprocess(Image.open(outputs[i])).repeat(1, 3, 1, 1).to(device)
    gts_batch = preprocess(Image.open(gts[i])).repeat(1, 3, 1, 1).to(device)
    mask = torch.ones_like(outputs_batch)
    with torch.no_grad():
        list_of_flows = model(outputs_batch, gts_batch)
    flow = list_of_flows[-1]
    flow_normalized = flow_normalize(flow)
    x = torch.mean(warp_feature(outputs_batch, flow_normalized), dim=1)
    mask = torch.mean(warp_feature(mask, flow_normalized), dim=1)
    gts_batch_ = torch.mean(gts_batch, dim=1) * (mask > 0.5)
    total_loss.append(F.l1_loss(resize((0.5 + gts_batch_ * 0.5).clamp(0,1)), resize((0.5 + 0.5 * x).clamp(0, 1))))

    if save_image:
        x_ = ((x.squeeze(0) * 0.5 + 0.5).clamp(0, 1) * 255).detach().to(torch.uint8).cpu().numpy()
        outputs_batch = ((torch.mean(outputs_batch.squeeze(0), dim=0) * 0.5 + 0.5).clamp(0, 1) * 255).detach().to(torch.uint8).cpu().numpy()
        gts_batch_ = ((torch.mean(gts_batch.squeeze(0), dim=0) * 0.5 + 0.5).clamp(0, 1) * 255).detach().to(torch.uint8).cpu().numpy()
        x_ = cv2.resize(x_, (704, 704), interpolation=cv2.INTER_LINEAR)
        noise = np.array(Image.open(noises[i]))
        gts_batch_ = cv2.resize(gts_batch_, (704, 704), interpolation=cv2.INTER_LINEAR)
        outputs_batch = cv2.resize(outputs_batch, (704, 704), interpolation=cv2.INTER_LINEAR)
        images = [noise, outputs_batch, x_, gts_batch_]
        labels = ['Noise Input', 'Output', 'Aligned Output', 'Ground Truth']
        fig, axes = plt.subplots(2, 2, figsize=(8, 8))
        for ax, img, lbl in zip(axes.flatten(), images, labels):
            ax.imshow(img, cmap='gray')
            ax.set_title(lbl, fontsize=12)
            ax.axis('off')
        os.makedirs(save_folder, exist_ok=True)
        plt.tight_layout(pad=1.0)
        plt.savefig(f'{save_folder}/aligned_{i:04d}.pdf', dpi=300, bbox_inches='tight')
        plt.close(fig)
print('Total L1 loss with an align method : \n', torch.mean(torch.tensor(total_loss)))

Total L1 loss with an align method : 
 tensor(0.0701)


In [8]:
# L1 loss without an align method
to_tensor = T.ToTensor()

total_loss = []
for i in range(len(gts)):
    outputs_batch = to_tensor(Image.open(outputs[i]))
    gts_batch = to_tensor(Image.open(gts[i]))
    total_loss.append(F.l1_loss(gts_batch, outputs_batch))

print('Total L1 loss without an align method : \n', torch.mean(torch.tensor(total_loss)))

Total L1 loss without an align method : 
 tensor(0.1395)
