# Evaluate example for M3CAM

In [1]:
from torch.utils.data import Dataset,DataLoader
import torchvision.transforms as transforms
import os
from tqdm.auto import tqdm
import numpy as np
import torch
from sklearn.cluster import KMeans
from torch import nn
import random

import swin_util as swu
from criterion import *
import model.arch_util as arch_util

In [2]:
def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
setup_seed(1)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

## Dataset

In [3]:
def nextPermutation(nums):
    if len(nums)<=1:
        return 
    for i in range(len(nums)-2,-1,-1):
        if nums[i]<nums[i+1]:
            for k in range(len(nums)-1,i,-1):
                if nums[k]>nums[i]:
                    nums[i],nums[k]=nums[k],nums[i]
                    nums[i+1:]=sorted(nums[i+1:])
                    break
            break
        else:
            if i==0:
                nums.sort()
def get_choice(index,number):
    now_array = []
    def recursion(index,num,start,now):
        if num == 0:
            now_array.append(now.copy())
            return 0
        else:
            for i in range(start,len(index)):
                now.append(index[i])
                recursion(index,num-1,i+1,now)
                now.pop()
    recursion(index,number,0,[])
    return now_array
def compute_area(flows):
    ## 构建所有的矩形
    all_matrix = []
    for i in range(len(flows)):
        x,y = flows[i]
        dx,dy = x+1,y+1
        if dx<=2 and dy <=2:
            all_matrix.append((x,y,dx,dy))
        elif dx>2 and dy > 2:
            all_matrix.append((x,y,2,2))
        elif dx>2 and dy <= 2:
            all_matrix.append((x,y,2,dy))
            all_matrix.append((0,y,dx-2,dy))
        else:
            all_matrix.append((x,y,dx,2))
            all_matrix.append((x,0,dx,dy-2))
    
    ## 然后计算所有面积
    ## 每个矩形 左上 x y + 右下 x y
    ps = []
    for info in all_matrix:
        ps.append(info[0])
        ps.append(info[2])
#         print(info)
    ps.sort()
    ans = 0
    for i in range(1,len(ps)):
        a, b = ps[i - 1], ps[i]
        width = b - a
        if width == 0:
            continue
        lines = [(info[1], info[3]) for info in all_matrix if info[0] <= a and b <= info[2]]
        lines.sort()
        height, l, r = 0, -1, -1
        for cur in lines:
            if cur[0] > r:
                height += r - l
                l, r = cur
            elif cur[1] > r:
                r = cur[1]
        height += r - l
        ans += height * width
    return ans
def select_best(flows,number):
    flows = flows % 2
#     print(flows.shape)
    flows[0] = 0
    index = [i for i in range(flows.shape[0]) if i !=0]
    # 递归得到所有可能的组合
    all_choice = get_choice(index,number)
    minIndex = 0
    minArea  = 0
    for i in range(len(all_choice)):
        choice_now = all_choice[i] + [0]
        area = compute_area(flows[choice_now])
#         print(area)
        if minArea > area:
            minArea = area
            minIndex = i
    return [0]+all_choice[minIndex]

In [4]:
transform_raw = transforms.Compose([
    transforms.ToTensor(),
])
class RealDataset(Dataset):
    def __init__(self,path):
        self.image_path = path
        self.image_list = os.listdir(self.image_path)
        
        self.data_path_list = []
        self.capture_path   = []
        
        for img_path in self.image_list:
            img = os.path.join(self.image_path,img_path)
            self.capture_path.append(img)
            part_name = os.listdir(img)
            for part in part_name:
                part_path = os.path.join(img,part)
                if os.path.isdir(part_path):
                    self.data_path_list.append(part_path)
        self.nframes = 4
        self.choice_list = self.getChoice()
        
    def getChoice(self):
        choice_list = []
        for i in tqdm(range(len(self.data_path_list))):
            path = self.data_path_list[i]
            dir_path = os.path.dirname(path)

            data = {}
            image_flows = []
            for i in range(20):
                flows = np.load(os.path.join(path,"offset"+str(i)+".npy"))
                flows = torch.FloatTensor(flows)
                image_flows.append(flows)
            # 接下来筛选
            image_flows = torch.stack(image_flows,dim=0)
            flows_mean = torch.mean(image_flows[:,:,14:-14,14:-14],dim=[2,3])
            select_choice = select_best(flows_mean,self.nframes - 1)
            choice_list.append(select_choice)
        return choice_list
        
    def __getitem__(self,index):
        path = self.data_path_list[index]
        dir_path = os.path.dirname(path)
        
        data = {}
        image_burst = []
        image_flows = []
        for i in range(20):
            burst = np.load(os.path.join(path,str(i)+'.npy'))
            burst = transform_raw(burst)
            image_burst.append(burst)
            flows = np.load(os.path.join(path,"offset"+str(i)+".npy"))
            flows = torch.FloatTensor(flows)
            image_flows.append(flows)
        # 接下来筛选
        image_flows = torch.stack(image_flows,dim=0)
        image_burst = torch.stack(image_burst,dim=0)
        
        select_choice = self.choice_list[index]
        
        data["burst"] = image_burst[select_choice]  #(4,4,56,56) (N,C,H,W)
        data["flows"] = image_flows[select_choice]
        image_frame = np.load(os.path.join(path,'original'+str(0)+'.npy'))
        image_frame = transform_raw(image_frame)
        data['frame_gt'] = image_frame
        meta_info = np.load(os.path.join(dir_path,'meta_info'+str(0)+'.npy'),allow_pickle=True).item()
        data['meta_info'] = meta_info
        return data
    def __len__(self):
        return len(self.data_path_list)

In [5]:
imgpath = "./dataset/"
testDataset = RealDataset(imgpath)

  0%|          | 0/52 [00:00<?, ?it/s]

## Model

In [12]:
class M3CAM(nn.Module):
    def __init__(self, num_in_ch = 4, embed_dim = 96, ape = False, drop_rate = 0., num_feat = 64):
        
        super(M3CAM, self).__init__()
        embed_dim = 96
        self.pixel_shuffle = nn.PixelShuffle(2)
        
        self.upconv0 = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1, bias=True)
        self.upconv1 = nn.Conv2d(embed_dim, embed_dim * 4, 3, 1, 1, bias=True)
        self.upconv2 = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1, bias=True)
        self.upconv3 = nn.Conv2d(embed_dim, num_feat, 3, 1, 1, bias=True)
        self.lrelu = nn.LeakyReLU(0.1,inplace=True)
        
        nframes = 4
        depths=[6, 6, 6, 6]
        num_heads=[6, 6, 6, 6]
        mlp_ratio = 2.0
        qkv_bias = True
        qk_scale = None
        drop_rate = 0.
        attn_drop_rate = 0.
        drop_path_rate = 0.1
        norm_layer = nn.LayerNorm
        
        img_size = 56
        img_size_1 = 112
        img_size_2 = 224
        patch_size = 4
        use_checkpoint=False
        window_size = 7
        
        self.patch_norm = True
        self.ape = ape        
        self.norm_layer = norm_layer
        self.num_features = embed_dim
        self.norm = self.norm_layer(self.num_features)
        
        self.num_layers = 1 #len(depths)
        
        
        self.patch_embed = swu.PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches # 14 * 14 
        patches_resolution = self.patch_embed.patches_resolution #（14，14）
        
        num_patches_1 = num_patches * 4 # 28 * 28 = 784
        patches_resolution_1 = patches_resolution * 2 # 28 ,28
        
        num_patches_2 = num_patches_1 * 2 # 56 * 56 = 784*4
        patches_resolution_2 = patches_resolution_1 * 2 # 56,56
        
        self.patch_unembed = swu.PatchUnEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=embed_dim, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 
        self.swin_layers = nn.ModuleList()
        for i_layer in range(depths[0]):
            layer = swu.SwinTransformerBlock(dim=embed_dim,
                                             input_resolution=(patches_resolution[0]//2,patches_resolution[1]//2),
                                             num_heads=num_heads[0],
                                             window_size=window_size,
                                             shift_size=0 if (i_layer % 2 == 0) else window_size // 2, # 一层fix + 一层 shift
                                             mlp_ratio = mlp_ratio,
                                             qkv_bias = qkv_bias,
                                             qk_scale = qk_scale,   
                                             drop = drop_rate,
                                             attn_drop = attn_drop_rate,
                                             drop_path = dpr[i_layer],
                                             norm_layer = norm_layer
                                             )
            self.swin_layers.append(layer)
        self.pre_norm = norm_layer(embed_dim)

        # PatchEmbed 将 x 从维度2 开始拉直 x (B,96,112,112) -> (B.96,112*112)
        dpr_1 = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 
        
        self.up_swin_1 = nn.ModuleList()
        for i_layer in range(self.num_layers):
            # swin trans 提取特征
            layer = swu.RSTB(dim=embed_dim,
                     input_resolution=(patches_resolution_1[0],
                                       patches_resolution_1[1]),
                     depth=depths[i_layer],
                     num_heads=num_heads[i_layer],
                     window_size=window_size,
                     mlp_ratio=mlp_ratio,
                     qkv_bias=qkv_bias, qk_scale=qk_scale,
                     drop=drop_rate, attn_drop=attn_drop_rate,
                     drop_path=dpr_1[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],  # no impact on SR results
                     norm_layer=norm_layer,
                     downsample=None,
                     use_checkpoint=use_checkpoint,
                     img_size=img_size_1,
                     patch_size=patch_size)
            self.up_swin_1.append(layer)


        dpr_2 = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] 
        self.up_swin_2 = nn.ModuleList()
        for i_layer in range(self.num_layers):
            # swin trans 提取特征
            layer = swu.RSTB(dim=embed_dim,
                     input_resolution=(patches_resolution_2[0],
                                       patches_resolution_2[1]),
                     depth=depths[i_layer],
                     num_heads=num_heads[i_layer],
                     window_size=window_size,
                     mlp_ratio=mlp_ratio,
                     qkv_bias=qkv_bias, qk_scale=qk_scale,
                     drop=drop_rate, attn_drop=attn_drop_rate,
                     drop_path=dpr_2[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],  # no impact on SR results
                     norm_layer=norm_layer,
                     downsample=None,
                     use_checkpoint=use_checkpoint,
                     img_size=img_size_2,
                     patch_size=patch_size)
            self.up_swin_2.append(layer)
        
        self.pos_drop = nn.Dropout(p=drop_rate)
        
        self.Fusion = nn.Conv2d(embed_dim*nframes, embed_dim, 1, 1, padding=0, bias=True)
        
        self.upsample1 = nn.ConvTranspose2d(embed_dim,embed_dim,3, 2, 1,output_padding=1, bias=True)
        self.upsample2 = nn.ConvTranspose2d(num_feat,num_feat,3, 2, 1,output_padding=1, bias=True)
        
        self.convHR = nn.Conv2d(num_feat, num_feat, 1, 1, padding=0, bias=True)
        self.convLast = nn.Conv2d(num_feat, 1, 1, 1, padding=0, bias=True)
        
    def forward_features(self,x):
        # x_size (56,56)
        x_size = (x.shape[-2],x.shape[-1])
        # x (B,96,56,56) -> (B.56*56,96) 长度为56*56 特征 96 的序列
        x = self.patch_embed(x,use_norm=True)
        x = self.pos_drop(x)
        for idx,layer in enumerate(self.swin_layers):
            x = layer(x,x_size)
        x = self.pre_norm(x) #(BN,56*56,96)
        x = self.patch_unembed(x,x_size) # 变回 (BN,96,56,56)
        return x
    def up_sample_1(self,x):
        #(B,96,112,112)
        x_size = (x.shape[-2], x.shape[-1])
        x = self.patch_embed(x) #(B,112*112,96)
        x = self.pos_drop(x) #(B,112*112,96)
        for idx,layer in enumerate(self.up_swin_1):
            x = layer(x, x_size)
        x = self.norm(x) #(B,112*112,96)
        x = self.patch_unembed(x,x_size)  #(B,96,112,112)
        return x
    def up_sample_2(self,x):
        #(B,96,224,224)
        x_size = (x.shape[-2], x.shape[-1])
        x = self.patch_embed(x) #(B,224,224,96)
        x = self.pos_drop(x) #(B,224,224,96)
        for idx,layer in enumerate(self.up_swin_2):
            x = layer(x, x_size)
        x = self.norm(x) #(B,224,224,96)
        x = self.patch_unembed(x,x_size)  #(B,96,224,224)
        return x
    def forward(self,x,offset):
        B,N,C,H,W = x.size() # x (B,N,4,56,56)
        # 升维
        x   = x.view(B*N,C,H,W)
        img = self.lrelu(self.upconv0(x))  # img (B*N,96,56,56)
        # img1 (B,96,56,56)
        # 特征提取
        img = self.forward_features(img)   # img (B*N,96,56,56)
        img = self.lrelu(self.pixel_shuffle(self.upconv1(img)))  # img (B*N,96,112,112)
        
        _,_,H,W = img.size()               
        # 对齐
        
        offset = offset.view(B*N,-1,H,W) # offsets(B*N,2,H,W)
        
        warp_fea = arch_util.flow_warp(img,offset.permute(0,2,3,1),'bilinear')  # warp_fea (B*N,96,112,112)
        warp_fea = warp_fea.view(B,N,-1,H,W)  # warp_fea (B,N,96,112,112)
        
        warp_fea[:,0,:,:,:] = img.view(B,N,-1,H,W)[:,0,:,:,:]
        
        warp_fea = warp_fea.view(B,-1,H,W)    # warp_fea (B,96*N,112,112)
        warp_fea = self.lrelu(self.Fusion(warp_fea))  # warp_fea (B,96,112,112)
        
        # swin Transformer
        features = self.up_sample_1(warp_fea)  # features (B,96,112,112)
        # 升维,放大
        features = self.lrelu(self.upsample1(self.upconv2(features))) # features (B,96,224,224)
        # swin Transformer
        features = self.up_sample_2(features) # features (B,96,224,224)
        # 放大，降维
        features = self.lrelu(self.upsample2(self.upconv3(features))) # features (B,64,448,448)
        # 降维
        features = self.lrelu(self.convHR(features)) # features (B,64,448,448)
        
        output = self.convLast(features)  # features (B,1,448,448)
        # upsample 
        return output

In [13]:
model = M3CAM().to(device)

## raw 2 rgb processing

In [14]:
def process_raw(image,meta_info):
    image = image.clamp(0.0, 1.0)
    image = image.permute(1,2,0)
    image = image.numpy()
    def f(x):
        return x / 4
    black_level = np.array(meta_info['black_level'][0:2] + meta_info['black_level'][-1:])
    
    
    image = np.rint(f(image*1024)).astype(np.uint8)
    black_level = np.rint(f(np.array(black_level, dtype=np.uint8)))
    max_image = meta_info["max_image"]
    min_image = meta_info["min_image"]

    image = np.clip(np.rint((image - min_image) / (max_image - min_image) * 255.0), 0.0, 255.0).astype(np.uint8)
    black_level = np.clip(np.rint((black_level - min_image) / (max_image - min_image) * 255.0), 0.0, 255.0).astype(np.uint8)
    image -= black_level[0]
    image = image.astype(np.uint8)
    im_dem_np = cv2.cvtColor(image, cv2.COLOR_BayerRG2BGR)
    image = im_dem_np / 255.0
    wb = meta_info['wb']
    ave = meta_info["ave"]
    mid = meta_info["mid"]
    bright = meta_info["bright"]


    gains = np.array([wb[0],wb[1],wb[2]]) * bright
    image[:,:,0] = image[:,:,0]  * gains[0]
    image[:,:,1] = image[:,:,1]  * gains[1]
    image[:,:,2] = image[:,:,2]  * gains[2]
    image[image<=0.0] = 0.0
    image[image>=1.0] = 1.0
    image[image<=1e-8] = 1e-8
    image = image**(1.0/2.2)
    image = 3 * image**2 - 2 * image**3
    def color_scale_display(image, input_, output_):
        shadow, midtones, highlight = input_
        outShadow, outHighlight = output_
        diff = highlight - shadow
        imageDiff = np.maximum(image - shadow, 0.0)
        clImage = np.power(imageDiff / diff, 1. / midtones)
        outImage = clImage * (outHighlight - outShadow) + outShadow
        image = outImage
        image[image <= 0.0] = 0.0
        image[image >= 1.0] = 1.0
        return image

    im_rgb = np.zeros_like(image)
    im_rgb[:, :, 0] = color_scale_display(image[:, :, 0], [16.0/255., 1., 255./255.], [0., 1.])
    im_rgb[:, :, 1] = color_scale_display(image[:, :, 1], [15.0/255., 1., 221./255.], [0., 1.])
    im_rgb[:, :, 2] = color_scale_display(image[:, :, 2], [15.0/255., 1., 215./255.], [0., 1.])

    im_rgb = torch.from_numpy(im_rgb).float().permute(2, 0, 1)
    return im_rgb
class UnNormalize(object):
    def __init__(self,mean,std):
        self.mean=mean
        self.std=std
 
    def __call__(self,tensor):
        """
        Args:
        :param tensor: tensor image of size (B,C,H,W) to be un-normalized
        :return: UnNormalized image
        """
        res = torch.zeros_like(tensor).to(device)
        for i,d  in enumerate(zip(tensor,self.mean,self.std)):
            t, m, s = d
            res[i] = t.mul(s).add(m)
        return res
mean_list = (0.5,)
std_list = (0.5,)
unnorm = UnNormalize(mean_list, std_list)

## Evaluate

In [15]:
class Actor():
    def __init__(self, net, objective, loss_weight=None):
        super().__init__()
        if loss_weight is None:
            loss_weight = {'rgb': 1.0}
        self.loss_weight = loss_weight
        self.net = net
        self.objective = objective

    def __call__(self, data, isTrain):
        pred = self.net(data['burst'],data['flows'])
        pred = pred.clamp(0.0,1.0)
        loss_rgb_raw = self.objective['rgb'](pred, data['frame_gt'])
        loss_rgb = self.loss_weight['rgb'] * loss_rgb_raw
        loss = loss_rgb
        stats = {'Loss/rgb': loss_rgb.item(),
                 'Loss/raw/rgb': loss_rgb_raw.item()}
        if isTrain:
            pass
        else:
            pred_process = pred.detach().clone()
            pred_process = pred_process.clamp(0.0,1.0)
            pred_process = process_raw(pred_process[0].to('cpu'),data['meta_info']).to(device)
            frame_gt = data['frame_gt']
            frame_gt = process_raw(frame_gt[0].to('cpu'),data['meta_info']).to(device)
            if 'psnr' in self.objective.keys():
                psnr =  self.objective['psnr'](pred_process, frame_gt)
                stats['Stat/psnr'] = psnr.item()
            if 'ssim' in self.objective.keys():
                ssim =  self.objective['ssim'](pred_process, frame_gt, valid=None)
                stats['Stat/ssim'] = 1 - ssim.item()
            if 'lpips' in self.objective.keys():
                lpips =  self.objective['lpips'](pred_process, frame_gt, valid=None)
                stats['Stat/lpips'] = lpips.item()
        stats['Loss/total'] = loss.item()
        return loss, stats

In [16]:
criterion = {'rgb': PixelWiseError(metric='l1', boundary_ignore=56).to(device), 'psnr': PSNR(boundary_ignore=56).to(device),'ssim': SSIM(boundary_ignore=56).to(device),'lpips':LPIPS(boundary_ignore=56).to(device)}

Setting up [LPIPS] perceptual loss: trunk [alex], v[0.1], spatial [off]
Loading model from: D:\Users\22496\anaconda3\envs\pytorch\lib\site-packages\lpips\weights\v0.1\alex.pth


In [17]:
model.load_state_dict(torch.load("./models/SR.pth"))

<All keys matched successfully>

In [18]:
loss_weight = {'rgb': 1.0}
actor = Actor(net=model, objective=criterion, loss_weight=loss_weight)

In [20]:
Loss = 0.0
psnr = 0.0
ssim = 0.0
lpips = 0.0
model.eval()
with torch.no_grad():
    for i in tqdm(range(len(testDataset))):
        data = testDataset[i]
        data['burst'] = data['burst'].unsqueeze(0).to(device)
        data['flows'] = data['flows'].unsqueeze(0).to(device)
        data['frame_gt'] = data['frame_gt'].unsqueeze(0).to(device)
        loss, stats = actor(data,False)
        psnr += stats['Stat/psnr']
        ssim += stats['Stat/ssim']
        lpips += stats['Stat/lpips']
        Loss += loss.item()
Loss /= len(testDataset)
psnr /= len(testDataset)
ssim /= len(testDataset)
lpips /= len(testDataset)
print("Loss:{}，PSNR:{},SSIM:{},LPIPS:{}".format(Loss,psnr,ssim,lpips))

  0%|          | 0/52 [00:00<?, ?it/s]

Loss:0.006316702822984483，PSNR:38.102121573228104,SSIM:0.917108797110044,LPIPS:0.069236967802191
