In [1]:
!python datsr/test.py -opt "options/test/test_restoration_mse.yml" --launcher "none"

OrderedDict([('name', 'test_restoration_mse'), ('suffix', None), ('scale', 4), ('model_type', 'RefRestorationModel'), ('set_CUDA_VISIBLE_DEVICES', None), ('crop_border', 4), ('gpu_ids', None), ('datasets', OrderedDict([('test_1', OrderedDict([('name', 'WR-SR'), ('type', 'RefCUFEDDataset'), ('dataroot_in', 'datasets\\WR-SR\\input'), ('dataroot_ref', 'datasets\\WR-SR\\ref'), ('io_backend', OrderedDict([('type', 'disk')])), ('bicubic_model', 'PIL'), ('ann_file', 'datasets\\WR-SR_pairs.txt')]))])), ('val_func', 'BasicSRValidation'), ('save_img', False), ('network_g', OrderedDict([('type', 'SwinUnetv3RestorationNet'), ('ngf', 128), ('n_blocks', 8), ('groups', 8), ('embed_dim', 128), ('depths', [4, 4]), ('num_heads', [4, 4]), ('window_size', 8), ('use_checkpoint', True)])), ('network_map', OrderedDict([('type', 'FlowSimCorrespondenceGenerationArch'), ('patch_size', 3), ('stride', 1), ('vgg_layer_list', ['relu1_1', 'relu2_1', 'relu3_1']), ('vgg_type', 'vgg19')])), ('network_extractor', Ordere

  'On January 1, 2023, MMCV will release v2.0.0, in which it will remove '
2024-11-17 13:05:33,090.090 - INFO:   name: test_restoration_mse
  suffix: None
  scale: 4
  model_type: RefRestorationModel
  set_CUDA_VISIBLE_DEVICES: None
  crop_border: 4
  gpu_ids: None
  datasets:[
    test_1:[
      name: WR-SR
      type: RefCUFEDDataset
      dataroot_in: datasets\WR-SR\input
      dataroot_ref: datasets\WR-SR\ref
      io_backend:[
        type: disk
      ]
      bicubic_model: PIL
      ann_file: datasets\WR-SR_pairs.txt
      phase: test
      scale: 4
    ]
  ]
  val_func: BasicSRValidation
  save_img: False
  network_g:[
    type: SwinUnetv3RestorationNet
    ngf: 128
    n_blocks: 8
    groups: 8
    embed_dim: 128
    depths: [4, 4]
    num_heads: [4, 4]
    window_size: 8
    use_checkpoint: True
  ]
  network_map:[
    type: FlowSimCorrespondenceGenerationArch
    patch_size: 3
    stride: 1
    vgg_layer_list: ['relu1_1', 'relu2_1', 'relu3_1']
    vgg_type: vgg19
  ]
  networ

In [None]:
import torch
import torch.nn as nn
from torchvision import models
from datsr.models.archs.datsr_arch import DATSR
# FIXME(review): this cell contained the unfinished statement `from datsr.test.`,
# which is a SyntaxError and stops the whole cell from running. Complete the
# import (module path and imported names) or drop this note.

# Grayscale transformation function
def grayscale_transform(image):
    """Convert a single BGR image tensor to a one-channel grayscale tensor.

    Pure-PyTorch equivalent of ``cv2.cvtColor(..., cv2.COLOR_BGR2GRAY)`` for
    floating-point input. It avoids three problems of the OpenCV round trip:
    ``cv2`` is not imported in this notebook, ``.numpy()`` fails on tensors
    that require grad, and an unconditional ``.cuda()`` fails on CPU-only
    machines.

    Args:
        image: Tensor of shape ``(1, 3, H, W)`` or ``(3, H, W)`` whose channels
            are in B, G, R order (the order ``COLOR_BGR2GRAY`` assumes).

    Returns:
        Tensor of shape ``(1, H, W)`` on the same device as ``image``.
        Integer inputs are promoted to float32 (OpenCV would instead return a
        rounded integer image).

    Raises:
        ValueError: if ``image`` is not a single 3-channel image.
    """
    # Drop a leading batch axis of size 1; a batch larger than 1 stays 4-D and
    # is rejected below, matching the single-image contract.
    single = image.squeeze(0) if image.dim() == 4 else image
    if single.dim() != 3 or single.shape[0] != 3:
        raise ValueError(
            "grayscale_transform expects one image of shape (1, 3, H, W) or "
            f"(3, H, W); got {tuple(image.shape)}"
        )
    if not single.is_floating_point():
        single = single.float()

    # ITU-R BT.601 luma weights, listed in B, G, R channel order.
    luma = single.new_tensor([0.114, 0.587, 0.299]).view(3, 1, 1)
    return (single * luma).sum(dim=0, keepdim=True)

# Whitening and Coloring Transform (WCT) function
def wct(content_features, style_features):
    """Whiten ``content_features`` and re-color them with ``style_features``.

    For every (batch, channel) slice ``X`` of shape ``(H, W)`` the function
    subtracts the slice's spatial mean, builds the ``H x H`` matrix
    ``X_c @ X_c^T``, and applies ``S^(1/2) @ C^(-1/2)`` to the centered content
    slice before adding the style mean back.

    NOTE(review): this is a per-channel *spatial* (row) covariance, not the
    ``C x C`` channel covariance used by the usual WCT formulation, and it is
    not normalised by the number of samples. The math is kept as it was;
    confirm this is the intended variant.

    Args:
        content_features: Tensor of shape ``(N, C, H, W_c)``.
        style_features: Tensor of shape ``(N, C, H, W_s)``. The height must
            equal the content height because the coloring matrix is
            ``H x H`` and is applied to the content rows.

    Returns:
        Tensor with the shape of ``content_features``.

    Raises:
        ValueError: if either input is not 4-D or the heights differ.
    """
    if content_features.dim() != 4 or style_features.dim() != 4:
        raise ValueError(
            "wct expects 4-D (N, C, H, W) tensors; got "
            f"{tuple(content_features.shape)} and {tuple(style_features.shape)}"
        )
    if content_features.shape[-2] != style_features.shape[-2]:
        raise ValueError(
            "wct needs content and style features of equal height; got "
            f"{content_features.shape[-2]} and {style_features.shape[-2]}"
        )

    eps = 1e-5

    c_mean = content_features.mean(dim=[2, 3], keepdim=True)
    s_mean = style_features.mean(dim=[2, 3], keepdim=True)

    content_centered = content_features - c_mean
    style_centered = style_features - s_mean

    c_cov = content_centered @ content_centered.transpose(-2, -1)
    s_cov = style_centered @ style_centered.transpose(-2, -1)

    # torch.symeig has been removed from PyTorch; torch.linalg.eigh is its
    # replacement and likewise returns eigenvalues in ascending order.
    c_eigval, c_eigvec = torch.linalg.eigh(c_cov)
    s_eigval, s_eigvec = torch.linalg.eigh(s_cov)

    # The matrices are positive semi-definite, but round-off can yield small
    # negative eigenvalues; clamp so the square roots below cannot become NaN.
    c_scale = (c_eigval.clamp_min(0) + eps).rsqrt()
    s_scale = (s_eigval.clamp_min(0) + eps).sqrt()

    # V @ diag(d) @ V^T, with the diagonal applied by scaling V's columns.
    whitening = (c_eigvec * c_scale.unsqueeze(-2)) @ c_eigvec.transpose(-2, -1)
    coloring = (s_eigvec * s_scale.unsqueeze(-2)) @ s_eigvec.transpose(-2, -1)

    transformed_features = coloring @ (whitening @ content_centered) + s_mean
    return transformed_features

# Phase Replacement (PR) function
def phase_replacement(content_feature, stylized_feature):
    """Combine the Fourier amplitude of one tensor with the phase of another.

    Both inputs are transformed with a 2-D FFT over their last two
    dimensions. The result keeps the amplitude spectrum of
    ``stylized_feature`` and the phase spectrum of ``content_feature``, and is
    transformed back with the inverse 2-D FFT.

    Args:
        content_feature: Tensor supplying the phase spectrum.
        stylized_feature: Tensor supplying the amplitude spectrum. Its
            spectrum must be broadcast-compatible with the content spectrum.

    Returns:
        Real part of the inverse FFT of the recombined spectrum.
    """
    content_spectrum = torch.fft.fft2(content_feature, dim=(-2, -1))
    stylized_spectrum = torch.fft.fft2(stylized_feature, dim=(-2, -1))

    # Only the content phase and the stylized amplitude are needed; the content
    # amplitude the original computed was never used.
    content_phase = content_spectrum.angle()
    stylized_amplitude = stylized_spectrum.abs()

    recombined = stylized_amplitude * torch.exp(1j * content_phase)
    return torch.fft.ifft2(recombined, dim=(-2, -1)).real

# Domain Matching Module
class DomainMatchingSR(nn.Module):
    """Match the feature statistics of a low-resolution image to a reference.

    Pipeline: grayscale conversion -> VGG19 features -> whitening/coloring
    transform -> Fourier phase replacement.

    NOTE(review): ``forward`` returns a VGG feature tensor, not an image; there
    is no decoder in this module. Confirm that downstream consumers expect
    features.
    NOTE(review): ``wct`` multiplies an ``H x H`` style matrix into the content
    rows, so the LR and reference feature maps need equal height. Confirm the
    two inputs are the same size, or resize before calling.
    """

    def __init__(self):
        super().__init__()
        # First 21 layers of torchvision's VGG19 feature stack (conv1_1 through
        # relu4_1 in torchvision's layer order). Requires the pretrained
        # weights to be available locally or downloadable.
        self.encoder = models.vgg19(pretrained=True).features[:21]  # Use VGG19 for feature extraction

    def _to_encoder_input(self, gray):
        """Shape a one-channel grayscale map into a batch the encoder accepts."""
        if gray.dim() == 3:
            # (1, H, W) from grayscale_transform -> (1, 1, H, W)
            gray = gray.unsqueeze(1)
        if gray.shape[1] == 1:
            # The VGG19 stem convolves over 3 input channels, so replicate the
            # single gray channel instead of feeding a 1-channel tensor.
            gray = gray.repeat(1, 3, 1, 1)
        reference = next(self.encoder.parameters())
        return gray.to(device=reference.device, dtype=reference.dtype)

    def forward(self, lr_image, ref_image):
        """Return domain-matched encoder features for ``lr_image``.

        Args:
            lr_image: Low-resolution input image tensor, in the layout
                ``grayscale_transform`` accepts.
            ref_image: Reference image tensor in the same layout.

        Returns:
            Output of ``phase_replacement`` applied to the LR features and
            their WCT-transformed version.
        """
        # Step 1: Grayscale Transformation
        lr_gray = self._to_encoder_input(grayscale_transform(lr_image))
        ref_gray = self._to_encoder_input(grayscale_transform(ref_image))

        # Step 2: Feature Extraction (with VGG19)
        lr_features = self.encoder(lr_gray)
        ref_features = self.encoder(ref_gray)

        # Step 3: Whitening and Coloring Transform
        transformed_features = wct(lr_features, ref_features)

        # Step 4: Phase Replacement
        output = phase_replacement(lr_features, transformed_features)

        return output

# Integrated DATSR with Domain Matching
class IntegratedDATSR(nn.Module):
    """Chain the domain-matching module in front of a DATSR network.

    NOTE(review): ``DomainMatchingSR.forward`` returns the result of
    ``phase_replacement`` on encoder features, and that tensor is passed to
    ``DATSR`` as its only argument. Confirm that ``DATSR`` accepts this input
    and does not also need the reference image.
    """

    def __init__(self):
        super().__init__()
        # Sub-modules are registered in the same order as before: the existing
        # DATSR model first, then the domain-matching front end.
        self.datsr = DATSR()
        self.domain_matching = DomainMatchingSR()

    def forward(self, lr_image, ref_image):
        """Domain-match ``lr_image`` against ``ref_image``, then run DATSR on the result."""
        matched = self.domain_matching(lr_image, ref_image)
        return self.datsr(matched)