In [5]:
import os
import glob
import cv2
import torch
import numpy as np

from drct.archs.DRCT_arch import DRCT  # Make sure the path to DRCT_arch is correct

# Inference settings, hard-coded in place of command-line parsing.
class Args:
    """Configuration for the DRCT super-resolution run in this notebook.

    All values are class attributes and are read through the ``args``
    instance created right after the class.
    """

    # --- locations -------------------------------------------------------
    # Trained generator weights to load.
    model_path: str = "/home/rahul_b/Run1/DRCT/experiments/train_DRCT-L_SRx4_finetune_from_ImageNet_pretrain_archived_20250419_230123/models/net_g_latest.pth"
    # Directory whose files are globbed and fed to the model.
    input: str = "/home/rahul_b/Run1/data/val/LR"
    # Directory the upscaled PNGs are written into.
    output: str = "results/DRCT-L_X4"

    # --- model / inference -----------------------------------------------
    # Upscaling factor: output height/width = scale * input height/width.
    scale: int = 4
    # Tile size for patch-wise inference; None means one pass per image.
    tile: "int | None" = None
    # Overlap (in input pixels) between neighbouring tiles.
    tile_overlap: int = 32


args = Args()

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Attention window size of the network. The same value is passed to test()
# and used to pad inputs in the inference loop, so it is defined once here.
window_size = 16

# Architecture hyper-parameters must match the checkpoint at args.model_path;
# load_state_dict(strict=True) below rejects any mismatch.
model = DRCT(
    upscale=args.scale, in_chans=3, img_size=64, window_size=window_size,
    compress_ratio=3, squeeze_factor=30, conv_scale=0.01, overlap_ratio=0.5,
    img_range=1., depths=[6]*12, embed_dim=180, num_heads=[6]*12, gc=32,
    mlp_ratio=2, upsampler='pixelshuffle', resi_connection='1conv'
)
# map_location='cpu' lets a checkpoint saved from GPU load on a CPU-only
# machine (the device fallback above); the model is moved to `device` after.
model.load_state_dict(
    torch.load(args.model_path, map_location='cpu')['params'], strict=True
)
model.eval()
model = model.to(device)

os.makedirs(args.output, exist_ok=True)

def test(img_lq, model, args, window_size):
    """Run ``model`` on ``img_lq``, either in one pass or tile by tile.

    Args:
        img_lq: Input batch, unpacked as ``(b, c, h, w)``. When tiling, ``h``
            and ``w`` are expected to already be padded to multiples of
            ``window_size`` (the caller pads before calling).
        model: Callable expected to map a ``(b, c, t, t)`` patch to a
            ``(b, c, t * args.scale, t * args.scale)`` output.
        args: Object with ``tile`` (int or None), ``tile_overlap`` (int) and
            ``scale`` (int) attributes.
        window_size: The effective tile size must be a multiple of this.

    Returns:
        ``model(img_lq)`` when ``args.tile`` is None. Otherwise a
        ``(b, c, h * scale, w * scale)`` tensor in which overlapping tile
        outputs are averaged.

    Raises:
        ValueError: If the effective tile size is not a multiple of
            ``window_size``, or if ``tile_overlap`` is outside ``[0, tile)``.
            Such an overlap would give a non-positive stride or leave pixels
            uncovered, which previously produced a crash or NaNs.
    """
    if args.tile is None:
        # Single forward pass over the whole image.
        return model(img_lq)

    b, c, h, w = img_lq.size()
    tile = min(args.tile, h, w)
    if tile % window_size != 0:
        raise ValueError(
            f"tile size {tile} must be a multiple of window_size {window_size}"
        )
    overlap = args.tile_overlap
    if not 0 <= overlap < tile:
        raise ValueError(
            f"tile_overlap must be in [0, {tile}) for tile size {tile}, got {overlap}"
        )
    stride = tile - overlap
    sf = args.scale

    # Regular start offsets plus one final tile flush with the bottom/right
    # edge, so every input pixel is covered by at least one tile.
    h_idx_list = list(range(0, h - tile, stride)) + [h - tile]
    w_idx_list = list(range(0, w - tile, stride)) + [w - tile]

    # E accumulates tile outputs. W counts how many tiles cover each output
    # pixel, so E / W averages the overlaps. Both live on img_lq's device.
    E = img_lq.new_zeros((b, c, h * sf, w * sf))
    W = torch.zeros_like(E)

    for h_idx in h_idx_list:
        for w_idx in w_idx_list:
            in_patch = img_lq[..., h_idx:h_idx + tile, w_idx:w_idx + tile]
            out_patch = model(in_patch)
            rows = slice(h_idx * sf, (h_idx + tile) * sf)
            cols = slice(w_idx * sf, (w_idx + tile) * sf)
            E[..., rows, cols].add_(out_patch)
            W[..., rows, cols].add_(1)

    return E.div_(W)


# ``test`` matches pytest's default ``test*`` naming; opt out of collection so
# importing this code under pytest doesn't try to run it as a test case.
test.__test__ = False


for idx, path in enumerate(sorted(glob.glob(os.path.join(args.input, '*')))):
    imgname = os.path.splitext(os.path.basename(path))[0]
    print(f"Processing {idx}: {imgname}")

    # cv2.imread returns None (it does not raise) for unreadable paths. The
    # '*' glob can match directories or non-image files, so skip those.
    img_bgr = cv2.imread(path, cv2.IMREAD_COLOR)
    if img_bgr is None:
        print(f"Skipped (not a readable image): {path}")
        continue

    # OpenCV BGR uint8 HWC -> RGB float32 CHW in [0, 1], plus a batch dim.
    img = img_bgr.astype(np.float32) / 255.
    img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
    img = img.unsqueeze(0).to(device)

    with torch.no_grad():
        _, _, h_old, w_old = img.size()
        # Grow height/width to a multiple of window_size by mirroring the
        # bottom/right edges. This formula adds a full extra window when the
        # size is already divisible.
        h_pad = (h_old // window_size + 1) * window_size - h_old
        w_pad = (w_old // window_size + 1) * window_size - w_old
        img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, :h_old + h_pad, :]
        img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, :w_old + w_pad]

        output = test(img, model, args, window_size)
        # Crop off the upscaled padding.
        output = output[..., :h_old * args.scale, :w_old * args.scale]

    # RGB float CHW -> BGR uint8 HWC for OpenCV. squeeze(0) drops only the
    # batch dim, never a spatial dim that happens to be 1.
    output = output.squeeze(0).float().cpu().clamp_(0, 1).numpy()
    output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
    output = (output * 255.0).round().astype(np.uint8)

    out_path = os.path.join(args.output, f'{imgname}_DRCT-L_X{args.scale}.png')
    # cv2.imwrite reports failure via its return value rather than raising.
    if cv2.imwrite(out_path, output):
        print(f"Saved: {out_path}")
    else:
        print(f"Failed to save: {out_path}")


Processing 0: img_19
Saved: results/DRCT-L_X4/img_19_DRCT-L_X4.png
Processing 1: img_23
Saved: results/DRCT-L_X4/img_23_DRCT-L_X4.png
Processing 2: img_27
Saved: results/DRCT-L_X4/img_27_DRCT-L_X4.png
Processing 3: img_9
Saved: results/DRCT-L_X4/img_9_DRCT-L_X4.png
