In [19]:
import os
import cv2
import torch
import torch.nn.functional as F
import torchvision.transforms as T

from PIL import Image
from torchvision.models.optical_flow import raft_large

# Path load
# Each listing holds the sorted, directory-prefixed entries of its folder;
# later code pairs outputs[i] with gts[i] by position in these sorted lists.
outputs_path = 'sample_output'
gt_path = 'sample_gt'
outputs = sorted(os.path.join(outputs_path, name) for name in os.listdir(outputs_path))
gts = sorted(os.path.join(gt_path, name) for name in os.listdir(gt_path))

# Device setting
# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda:0"

# Save option
# Warped images are written into `save_folder` only when `save_image` is True.
save_folder = 'save'
save_image = False

In [2]:
# Data preprocess
def preprocess(batch):
    """Turn an input image into a normalized float tensor sized for RAFT.

    The transforms are applied in this order: ``ToTensor``, conversion to
    ``torch.float32``, normalization with mean 0.5 / std 0.5, and a resize
    to 520x960.

    Args:
        batch: The image to transform (the caller in this notebook passes a
            PIL image returned by ``Image.open``).

    Returns:
        The transformed tensor produced by the composed pipeline.
    """
    pipeline = T.Compose([
        T.ToTensor(),
        T.ConvertImageDtype(torch.float32),
        T.Normalize(mean=0.5, std=0.5),  # map [0, 1] into [-1, 1]
        T.Resize(size=(520, 960)),  # Resize to a size supported by RAFT
    ])
    return pipeline(batch)


# Warp function
def warp_feature(x, flow):
    """Warp ``x`` by sampling it at an identity grid displaced by ``flow``.

    Args:
        x: Tensor of shape (B, C, H, W) to be sampled.
        flow: Tensor of shape (B, 2, H, W). Channel 0 is added to the
            horizontal (x) grid coordinate and channel 1 to the vertical (y)
            coordinate. The values must already be expressed in
            ``grid_sample``'s normalized coordinate system, where -1 and 1
            are the two opposite borders of the image.

    Returns:
        Tensor of shape (B, C, H, W): ``x`` bilinearly sampled at the
        displaced grid, with zeros wherever the grid leaves the image.
    """
    B, _, H, W = x.size()

    # Identity sampling grid in normalized coordinates. It is created with the
    # dtype and device of `x` because grid_sample requires input and grid to
    # share a dtype (a float32 grid would fail for float64/half inputs).
    ys = torch.linspace(-1, 1, H, device=x.device, dtype=x.dtype)
    xs = torch.linspace(-1, 1, W, device=x.device, dtype=x.dtype)
    # Explicit 'ij' indexing keeps the historical (row, column) layout and
    # avoids the deprecation warning raised when `indexing` is omitted.
    y, x_ = torch.meshgrid(ys, xs, indexing='ij')

    # (H, W, 2) -> (B, H, W, 2); the last axis is ordered (x, y) as grid_sample expects.
    grid = torch.stack((x_, y), dim=-1).unsqueeze(0).repeat(B, 1, 1, 1)

    # Displace the grid by the horizontal (u) and vertical (v) flow components.
    grid[:, :, :, 0] += flow[:, 0, :, :]
    grid[:, :, :, 1] += flow[:, 1, :, :]

    return F.grid_sample(x, grid, mode='bilinear', padding_mode='zeros', align_corners=True)


def flow_normalize(x):
    """Rescale a flow tensor by fixed per-channel constants and flip its sign.

    The expression below is algebraically ``-x / max_`` (up to the 1e-8
    epsilon), where ``max_`` is 520 for channel 0 and 960 for channel 1.

    Args:
        x: Flow tensor of shape (B, 2, H, W).

    Returns:
        Tensor with the same shape and device as ``x``.

    NOTE(review): 520 and 960 equal the (height, width) that ``preprocess``
    resizes to, while ``warp_feature`` treats channel 0 as the horizontal
    component. Confirm that this per-channel order and scale are intended;
    the numeric behavior is deliberately left unchanged here.
    """
    B, _, _, _ = x.shape
    # Create the constants on the input's device rather than a hard-coded
    # 'cuda:0', so the function also works when the notebook runs on CPU.
    max_ = torch.tensor([[[520]], [[960]]], device=x.device).repeat(B, 1, 1, 1)
    x_norm = 2 * (-x + max_) / (2 * max_ + 1e-8) - 1
    return x_norm

In [24]:
# Model load
# Pretrained RAFT-large, moved to the selected device and put in eval mode.
model = raft_large(pretrained=True, progress=False).to(device)
model = model.eval()

# Compute optical flow and warp images
total_loss = []  # one L1 loss value per (output, ground-truth) pair
for i, gt_file in enumerate(gts):
    # RAFT only supports color images so repeat the gray image across three channels.
    outputs_batch = preprocess(Image.open(outputs[i])).repeat(1, 3, 1, 1).to(device)
    gts_batch = preprocess(Image.open(gt_file)).repeat(1, 3, 1, 1).to(device)

    # Flow estimation runs without gradient tracking; the model returns a
    # list of flow predictions and only the last entry is used.
    with torch.no_grad():
        list_of_flows = model(outputs_batch, gts_batch)
    flow = list_of_flows[-1]
    flow_normalized = flow_normalize(flow)

    # Warp the output image and an all-ones mask with the same flow, then
    # average over the channel axis. The warped mask marks where sampling
    # stayed inside the image (warp_feature pads with zeros outside).
    mask = torch.ones_like(outputs_batch)
    x = warp_feature(outputs_batch, flow_normalized).mean(dim=1)
    mask = warp_feature(mask, flow_normalized).mean(dim=1)

    # Zero the ground truth wherever the warped mask is at most 0.5, then
    # compare it against the warped output.
    gts_batch = gts_batch.mean(dim=1) * (mask > 0.5)
    total_loss.append(F.l1_loss(gts_batch, x))

    if save_image:
        # Undo the [-1, 1] normalization, quantize to 8-bit, and store the
        # result as a 256x256 grayscale PNG.
        x_ = x.squeeze(0) * 0.5 + 0.5
        x_ = (x_.clamp(0, 1) * 255).detach().to(torch.uint8).cpu().numpy()
        x_ = cv2.resize(x_, (256, 256), interpolation=cv2.INTER_LINEAR)
        x_ = Image.fromarray(x_, mode='L')
        os.makedirs(save_folder, exist_ok=True)
        x_.save(f'{save_folder}/tmp_{i}.png')
print('Total L1 loss :', torch.tensor(total_loss).mean())

Total L1 loss : tensor(0.1602)
