In [39]:
import argparse
import os
import torch, torchvision
import numpy as np
from torch.utils.data import DataLoader
import torch.nn.functional as F
from train_new import UNet, get_SIDD_validation, calculate_psnr, calculate_ssim, BayerPatternShifter
from torchvision import transforms
from PIL import Image
from scipy.io import loadmat
from tqdm import tqdm
import matplotlib.pyplot as plt

In [40]:
# Run on the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def get_SIDD_validation(dataset_dir):
    """Load the SIDD validation blocks stored as ``.mat`` files in ``dataset_dir``.

    Reads the ``ValidationNoisyBlocksRaw`` array from
    ``ValidationNoisyBlocksRaw.mat`` and the ``ValidationGtBlocksSrgb`` array
    from ``ValidationGtBlocksSrgb.mat``.

    NOTE(review): this definition rebinds the ``get_SIDD_validation`` name that
    the notebook also imports from ``train_new``; confirm which one is intended.

    Args:
        dataset_dir: Directory containing both ``.mat`` files.

    Returns:
        Tuple ``(num_img, num_block, val_data_noisy, val_data_gt)``. The two
        counts are the first two dimensions of the ground-truth array, which
        must have exactly five dimensions.
    """
    noisy_path = os.path.join(dataset_dir, "ValidationNoisyBlocksRaw.mat")
    gt_path = os.path.join(dataset_dir, 'ValidationGtBlocksSrgb.mat')

    val_data_noisy = loadmat(noisy_path)['ValidationNoisyBlocksRaw']
    val_data_gt = loadmat(gt_path)['ValidationGtBlocksSrgb']

    # Only the two leading axes are reported; the remaining three are unpacked
    # solely so that an array of any other rank is rejected.
    num_img, num_block, _h, _w, _c = val_data_gt.shape
    return num_img, num_block, val_data_noisy, val_data_gt


def test(model_path, data_dir, out_dir):
    """Run a trained UNet over the SIDD validation blocks and report PSNR/SSIM.

    Every block is passed through the model, scored against its sRGB ground
    truth, and written to ``out_dir`` as ``val_<image>_<block>.png``. The
    average PSNR and SSIM over all blocks are printed at the end.

    Args:
        model_path: Path to a checkpoint holding the UNet ``state_dict``.
        data_dir: Directory handed to ``get_SIDD_validation``.
        out_dir: Directory for the predicted PNGs; created if it does not exist.

    Returns:
        None. Results are printed and images are written to disk.
    """
    model = UNet(in_channels=4, out_channels=3, wf=48).to(device)
    # map_location remaps the stored tensors onto the active device, so a
    # checkpoint saved on a GPU still loads when running on a CPU-only machine.
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.eval()

    num_img, num_block, val_noisy, val_gt = get_SIDD_validation(data_dir)
    os.makedirs(out_dir, exist_ok=True)

    # Loop-invariant: build the ndarray -> tensor converter once.
    to_tensor = transforms.ToTensor()

    psnr_list, ssim_list = [], []

    for idx in tqdm(range(num_img), desc="Testing", unit="image"):
        # leave=False clears each finished inner bar instead of stacking one per image.
        for idy in tqdm(range(num_block), desc="Processing blocks", unit="block", leave=False):
            gt = val_gt[idx, idy] / 255.0  # [H, W, 3], scaled to [0, 1]
            noisy = val_noisy[idx, idy][:, :, np.newaxis]  # [H, W, 1]

            noisy_tensor = to_tensor(noisy).unsqueeze(0).to(device)
            # Presumably packs the single-plane mosaic into the 4 channels the
            # model expects (in_channels=4) -- see BayerPatternShifter to confirm.
            noisy_tensor = BayerPatternShifter.bayer_1ch_to_4ch(noisy_tensor)

            with torch.no_grad():
                pred_rgb = model(noisy_tensor)

            # [1, 3, H, W] -> [H, W, 3], clamped to the valid [0, 1] range.
            pred_rgb = pred_rgb.permute(0, 2, 3, 1).cpu().clamp(0, 1).numpy().squeeze(0)
            # +0.5 rounds to nearest before the truncating uint8 cast.
            pred255 = np.clip(pred_rgb * 255.0 + 0.5, 0, 255).astype(np.uint8)

            # PSNR is computed on [0, 1] data (peak 1.0); SSIM on [0, 255] data.
            psnr = calculate_psnr(gt.astype(np.float32), pred_rgb.astype(np.float32), 1.0)
            ssim = calculate_ssim(gt * 255.0, pred_rgb * 255.0)
            psnr_list.append(psnr)
            ssim_list.append(ssim)

            save_path = os.path.join(out_dir, f"val_{idx:03d}_{idy:03d}.png")
            Image.fromarray(pred255).save(save_path)

    if not psnr_list:
        # Avoid np.mean([]) -> nan plus a RuntimeWarning when nothing was scored.
        print(f"No validation blocks found in {data_dir}; nothing to evaluate.")
        return

    print(f"Average PSNR: {np.mean(psnr_list):.4f}, SSIM: {np.mean(ssim_list):.4f}")


In [None]:
# parser = argparse.ArgumentParser()
# parser.add_argument('--checkpoint', type=str, required=True, help='Path to trained model (.pth)')
# parser.add_argument('--test_dir', type=str, required=True, help='Path to validation data folder')
# parser.add_argument('--out_dir', type=str, default='./test_results', help='Directory to save predictions')
# args = parser.parse_args()

# test(args.checkpoint, args.test_dir, args.out_dir)
shifter = BayerPatternShifter()

# Load the demo image as an RGB tensor on the active device.
test_im = Image.open('./data/validation/kodim01.png').convert('RGB')
test_im = torchvision.transforms.ToTensor()(test_im).to(device)

# Four-plane RGGB mosaic of the image.
bayer_f_4ch = shifter.remosaic(test_im, 'RGGB', out_channel=4)
print(bayer_f_4ch.shape)  # the recorded run printed torch.Size([4, 512, 768])

# Collapse to three planes: plane 0, the sum of planes 1 and 2 (the two G
# planes), and plane 3.
bayer_f_3ch = torch.stack(
    [bayer_f_4ch[0], bayer_f_4ch[1] + bayer_f_4ch[2], bayer_f_4ch[3]],
    dim=0,
)  # [3, H, W]

# Single-plane variant of the same mosaic.
# NOTE(review): the recorded run of this cell ends with "UnboundLocalError:
# local variable 'B' referenced before assignment" right after the shape print
# above. Nothing in this cell names `B`, so the error presumably comes from
# inside remosaic(..., out_channel=1) -- confirm in BayerPatternShifter.
bayer_f_1ch = shifter.remosaic(test_im, 'RGGB', out_channel=1)

def bilinear_demosaic_rggb(bayer):
    """Bilinearly demosaic a sparse three-plane Bayer mosaic into dense RGB.

    The previous implementation called ``F.interpolate`` with
    ``scale_factor=1``, which resamples onto the same grid and therefore
    returned the sparse planes unchanged. This version fills the missing
    sites by convolving each plane with the standard bilinear demosaicing
    kernel for its colour.

    Args:
        bayer: Floating-point tensor ``[B, 3, H, W]`` with ``H, W >= 2`` holding
            the R, G and B planes in that order (any channels past the third
            are ignored). Each plane is assumed to carry its samples at the
            Bayer sites of that colour and zeros everywhere else
            -- TODO confirm against the code that produces the mosaic.

    Returns:
        Tensor ``[B, 3, H, W]`` with every site of every plane filled in.
        Values at sites that were already sampled are left unchanged.
    """
    planes = bayer[:, :3, :, :]

    # R and B are sampled on a period-2 lattice: a missing site averages its
    # 2 axial or 4 diagonal neighbours. G is sampled on a checkerboard: a
    # missing site averages its 4 axial neighbours. In both kernels the centre
    # weight is 1, so already-sampled sites pass through untouched.
    rb_kernel = bayer.new_tensor([[1.0, 2.0, 1.0],
                                  [2.0, 4.0, 2.0],
                                  [1.0, 2.0, 1.0]]) / 4.0
    g_kernel = bayer.new_tensor([[0.0, 1.0, 0.0],
                                 [1.0, 4.0, 1.0],
                                 [0.0, 1.0, 0.0]]) / 4.0
    # One depthwise filter per colour plane: [3, 1, 3, 3].
    weight = torch.stack([rb_kernel, g_kernel, rb_kernel]).unsqueeze(1)

    # Reflect padding mirrors row/column -1 onto 1 and H onto H-2, which keeps
    # the Bayer site parity intact, so border pixels are averaged correctly
    # instead of being darkened by zero padding.
    padded = F.pad(planes, (1, 1, 1, 1), mode='reflect')
    rgb = F.conv2d(padded, weight, groups=3)  # [B, 3, H, W]
    return rgb



# Demosaic the three-plane mosaic (a batch axis is added for the call).
simple_demosaic = bilinear_demosaic_rggb(bayer_f_3ch.unsqueeze(0))
print(simple_demosaic.shape)

# Side-by-side comparison: source image, mosaic planes, demosaiced result.
panels = [
    ('Original', test_im),
    ('Bayer Pattern', bayer_f_3ch.squeeze(0)),
    ('Simple Demosaic', simple_demosaic.squeeze(0)),
]
for position, (panel_title, chw_image) in enumerate(panels, start=1):
    plt.subplot(1, 3, position)
    # imshow expects channels last, on the CPU, as a numpy array.
    plt.imshow(chw_image.permute(1, 2, 0).cpu().numpy())
    plt.title(panel_title)

torch.Size([4, 512, 768])


UnboundLocalError: local variable 'B' referenced before assignment