# 可视化光流

In [None]:
from torchvision.utils import flow_to_image
from einops import rearrange
import torchvision.transforms.functional as tf
def flow2img(flow, name):
    """Render an optical-flow field as a color-coded image and save it.

    Args:
        flow: Flow field laid out as (h, w, 2). May be a ``torch.Tensor`` or
            anything ``torch.as_tensor`` accepts (e.g. a NumPy array); it is
            converted to a detached float32 CPU tensor before rendering.
        name: Output file path handed to PIL's ``Image.save``.
    """
    import torch  # local import: this notebook cell does not import torch

    # torchvision's flow_to_image rejects anything that is not a float32
    # torch tensor, so normalize the input first. This lets NumPy arrays and
    # float64 data through, and drops any autograd history.
    flow = torch.as_tensor(flow).detach().cpu().float()
    flow = rearrange(flow, "h w c -> c h w")  # flow_to_image wants (2, H, W)
    flow_im = flow_to_image(flow)  # uint8 RGB tensor of shape (3, H, W)
    image = tf.to_pil_image(flow_im)
    image.save(name)

# 可视化模型特征

In [None]:
"""
Reference:
+ https://blog.csdn.net/weixin_45826022/article/details/118830531
"""

import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as f
from torchvision import transforms
import torchvision.transforms.functional as tf
import numpy as np 
from PIL import Image
from collections import OrderedDict
import cv2
from einops import rearrange
from archs.MIMOUNetV1 import MIMOUNet

def read_img(img_path):
    """Load an image from disk as a batched RGB float tensor.

    Args:
        img_path: Path of the image file to read with OpenCV.

    Returns:
        torch.Tensor: float32 tensor of shape (1, 3, h, w) in RGB channel
        order, with pixel values divided by 255.

    Raises:
        FileNotFoundError: If OpenCV cannot read the file. ``cv2.imread``
            signals failure by returning ``None`` rather than raising, which
            would otherwise surface as an unrelated error in ``cvtColor``.
    """
    img = cv2.imread(img_path)  # (h, w, 3) in BGR order
    if img is None:
        raise FileNotFoundError(f"Could not read image: {img_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB).astype("float32")
    # HWC -> CHW, scale to [0, 1] for 8-bit input
    img = torch.from_numpy(img.transpose(2, 0, 1)).float() / 255.0
    return img.unsqueeze(0)  # add the leading batch dimension

def tensor2img(tensor, name):
    """Write a (c, h, w) tensor to ``name`` as an image file.

    Conversion is delegated to torchvision's ``to_pil_image``. That function
    multiplies float tensors by 255 before casting to uint8, so float input
    is expected to lie in [0, 1].
    """
    tf.to_pil_image(tensor).save(name)

def show_feature_gray(feature, name, nrow=16):
    """Tile every channel of a feature map into one grayscale image and save it.

    Args:
        feature: Either a batched (b, c, h, w) tensor, in which case only the
            first batch element is shown, or an unbatched (c, h, w) tensor.
        name: Output image path.
        nrow: Number of tile rows in the grid. ``c`` must be divisible by
            ``nrow``; the grid then has ``c // nrow`` columns.
    """
    feature = feature.detach().cpu()
    if feature.dim() == 4:
        feature = feature[0]  # keep only the first sample of the batch
    grid = rearrange(feature, "(c1 c2) h w -> (c1 h) (c2 w)", c1=nrow).float()

    # Raw activations are unbounded, but to_pil_image converts floats via
    # x * 255 -> uint8 without clamping, so out-of-range values would be
    # corrupted. Min-max normalize the whole grid (one shared scale keeps the
    # channels comparable); a constant map is rendered black.
    lo, hi = grid.min(), grid.max()
    if hi > lo:
        grid = (grid - lo) / (hi - lo)
    else:
        grid = torch.zeros_like(grid)

    grid = grid.unsqueeze(0)  # (1, H, W): single-channel image for PIL
    print(grid.shape)
    tensor2img(grid, name)

def load_feature(model_path, img_path):
    """Build MIMOUNet from a checkpoint and run it on a single image.

    Args:
        model_path: Checkpoint path; the model weights are read from its
            ``'params'`` entry.
        img_path: Image to feed through the network (loaded via ``read_img``).

    Returns:
        The raw output of the model's forward call. NOTE(review): the script
        in this file iterates it as a sequence of ``{name: tensor}`` dicts,
        which depends on the project's MIMOUNet variant. Confirm against
        archs/MIMOUNetV1.
    """
    # Load onto CPU so checkpoints saved from GPU also open on CPU-only
    # machines; the model and input below both live on CPU.
    state_dict = torch.load(model_path, map_location="cpu")
    model = MIMOUNet([4, 8, 12], [32, 64, 128], True)
    model.load_state_dict(state_dict['params'])
    model.eval()
    input_image = read_img(img_path)
    # Inference only: skip autograd bookkeeping to save memory.
    with torch.no_grad():
        feature = model(input_image)
    return feature

if __name__ == "__main__":
    from pathlib import Path

    weight_path = "weights/MIMOV1_600000.pth"
    img_path = "testmini/blur/1.png"

    # One output folder per input image, named after the image stem
    # ("1" for "1.png"). PIL's save() does not create missing directories.
    out_dir = Path("feature_show") / Path(img_path).stem
    out_dir.mkdir(parents=True, exist_ok=True)

    feature_list = load_feature(weight_path, img_path)
    for feature in feature_list:
        for key, val in feature.items():
            show_feature_gray(val, str(out_dir / f"{key}.png"))
