In [1]:
import argparse
from collections import OrderedDict
from glob import glob
import albumentations as albu
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader
import torch
from torch import nn as nn
import torchvision
from PIL import Image, ImageDraw, ImageFont
import numpy as np
from rich.progress import track
from src.utils.draw import pasteImages, give_colors_to_mask
from src.data.components.grass import Grass
from src.utils.draw import pasteImages, give_colors_to_mask
from src.data.components.grass import Grass
from src.models.components.farseg import Farseg
from src.models.components.fcn import FCN
from src.models.components.linknet import Linknet
from src.models.components.pspnet import PSPNet
from src.models.components.unet_plus_plus import UnetPlusPlus
from src.models.components.pan import PAN
from src.models.components.unet import Unet
from src.models.components.deeplabv3plus import DeepLabV3Plus
from src.models.components.manet import MAnet
from src.models.components.fpn import FPN
from src.models.components.deeplabv3 import DeepLabV3



In [2]:
# Device used for every model and for inference. The second GPU is preferred (original
# configuration); fall back to the first GPU, then the CPU, so the notebook still runs on
# machines without two GPUs instead of failing on the first `.to(device)`.
if torch.cuda.device_count() > 1:
    device = "cuda:1"
elif torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"

In [3]:
# Run name -> model. Each factory is called in the listed order and the resulting model is
# moved to `device`. The run name is also the directory searched under ../logs/grasseg/
# when the checkpoints are loaded.
models = OrderedDict(
    (run_name, make_model().to(device))
    for run_name, make_model in (
        ("linknet-timm-resnest101e",
         lambda: Linknet(encoder_name="timm-resnest101e")),
        ("fpn-timm-regnetx_320",
         lambda: FPN(encoder_name="timm-regnetx_320")),
        ("manet-se_resnext101_32x4d",
         lambda: MAnet(encoder_name="se_resnext101_32x4d")),
        ("deeplabv3plus-timm-efficientnet-l2",
         lambda: DeepLabV3Plus(encoder_name="timm-efficientnet-l2", encoder_weights="noisy-student-475")),
        ("farseg_resnet50",
         lambda: Farseg(backbone="resnet50")),
        ("unet-timm-efficientnet-l2",
         lambda: Unet(encoder_name="timm-efficientnet-l2", encoder_weights="noisy-student-475")),
        ("pan-se_resnext101_32x4d",
         lambda: PAN(encoder_name="se_resnext101_32x4d")),
        ("unet_plus_plus-se_resnext101_32x4d",
         lambda: UnetPlusPlus(encoder_name="se_resnext101_32x4d")),
        ("fcn-resnet50",
         lambda: FCN(weights="resnet50", num_classes=6)),
        ("pspnet-timm-efficientnet-l2",
         lambda: PSPNet(encoder_name="timm-efficientnet-l2", encoder_weights="noisy-student-475")),
        ("deeplabv3-resnet152",
         lambda: DeepLabV3(encoder_name="resnet152")),
    )
)

In [4]:
def load_state_dict(filename: str):
    ckpt = torch.load(filename, map_location=device)
    state_dict = {}
    for k, v in ckpt["state_dict"].items():
        state_dict[k[4:]] = v
    return state_dict

In [5]:
# Load the trained weights for every model and switch it to evaluation mode.
for name, model in models.items():
    # glob() returns matches in filesystem order; sort so the chosen checkpoint is
    # deterministic when several files match.
    candidates = sorted(glob(f"../logs/grasseg/{name}/*/checkpoints/*epoch*.ckpt"))
    if not candidates:
        raise FileNotFoundError(
            f"No '*epoch*.ckpt' checkpoint found for model '{name}' under ../logs/grasseg/{name}/"
        )
    filename = candidates[0]
    model.load_state_dict(load_state_dict(filename))
    model.eval()

In [6]:
import tifffile
import numpy as np

from skimage import exposure
import cv2

# # 读取图像
# image = tifffile.imread('rgb_point1.tif')

# # 对各通道分别进行直方图均衡化
# equalized_image = np.zeros_like(image)
# for i in range(3):
#     equalized_image[..., i] = exposure.equalize_hist(image[..., i]) * 255

# image1 = equalized_image.astype(np.uint8)

# # 目标尺寸
# target_height, target_width = 256, 256

# # 计算需要补零的上下左右数量
# pad_top = (target_height - image.shape[0]) // 2
# pad_bottom = target_height - image.shape[0] - pad_top
# pad_left = (target_width - image.shape[1]) // 2
# pad_right = target_width - image.shape[1] - pad_left

# # 使用 numpy 的 pad 函数进行四周补零
# image1 = np.pad(image1, ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), mode='constant', constant_values=0)

In [7]:
# print(image1.shape,np.min(image1),np.max(image1))

In [8]:
# image1.dtype

In [9]:
# def read_image(filename:str):
#     image = tf.imread(filename)
#     image = (image - np.min(image)) / (np.max(image) - np.min(image))
#     image = (image * 255).astype(np.uint8)
#     return Image.fromarray(image)

In [10]:
# image1 = tf.imread("rgb_point1.tif")
# print(image1.shape,np.min(image1),np.max(image1))

In [11]:
# from matplotlib import pyplot as plt
# plt.imshow(image1)

In [12]:
def read_image(filename: str, target_height: int = 256, target_width: int = 256) -> np.ndarray:
    """Read a TIFF image, histogram-equalise its first three channels and centre it on a zero canvas.

    Args:
        filename: path to a TIFF file readable by ``tifffile.imread``; presumably an
            (H, W, C) image with C >= 3 -- TODO confirm for other inputs.
        target_height: height of the padded output (default 256).
        target_width: width of the padded output (default 256).

    Returns:
        A uint8 array of shape (target_height, target_width, C) with the equalised image
        centred and the border filled with zeros.

    Raises:
        ValueError: if the image is larger than the target canvas in either dimension.
    """
    image = tifffile.imread(filename)

    # Equalise each of the first three channels independently. equalize_hist returns floats
    # in [0, 1], so scale to [0, 255] before the uint8 conversion.
    # NOTE(review): any channel beyond the third is left as zeros.
    equalized_image = np.zeros_like(image)
    for i in range(3):
        equalized_image[..., i] = exposure.equalize_hist(image[..., i]) * 255
    image = equalized_image.astype(np.uint8)

    height, width = image.shape[:2]
    if height > target_height or width > target_width:
        # np.pad would otherwise fail with an opaque "negative values" error.
        raise ValueError(
            f"{filename}: image of size {height}x{width} does not fit in the "
            f"{target_height}x{target_width} canvas"
        )

    # Centre the image on the canvas; an odd remainder goes to the bottom/right.
    pad_top = (target_height - height) // 2
    pad_bottom = target_height - height - pad_top
    pad_left = (target_width - width) // 2
    pad_right = target_width - width - pad_left
    return np.pad(
        image,
        ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)),
        mode="constant",
        constant_values=0,
    )

In [14]:
# Offsets of the original tile inside the 256x256 canvas produced by read_image, used to
# crop the tile back out. NOTE(review): this assumes every input tile is 12x12 -- TODO
# derive from the actual image shape if other sizes are ever processed.
pad_top = pad_left = (256 - 12) // 2
pad_bottom = pad_right = 256 - 12 - pad_top

In [15]:
@torch.no_grad()
def inference(model: nn.Module, img: torch.Tensor) -> np.ndarray:
    """Run ``model`` on a batch without gradients and return the arg-max class per pixel.

    Args:
        model: segmentation network. This function does not call ``model.eval()``; the
            caller must do so.
        img: input batch; the class axis of the model output is assumed to be dim 1,
            i.e. presumably logits of shape (N, classes, H, W) -- TODO confirm per model.

    Returns:
        Class indices as a NumPy array: (H, W) for a single-image batch, (N, H, W) otherwise.
    """
    logits = model(img)
    # Drop only the batch axis: a bare squeeze() would also collapse spatial dims of size 1.
    return torch.argmax(logits, dim=1).squeeze(0).cpu().numpy()

In [17]:
def compute(mask: np.ndarray, pix: int, ratio: float, total=None) -> float:
    """Weighted share of pixels in ``mask`` whose class index equals ``pix``.

    Args:
        mask: array of per-pixel class indices.
        pix: class index to count.
        ratio: weight applied to the pixel share.
        total: denominator for the share; defaults to ``mask.size``. For a 12x12 crop this
            equals the 144 that used to be hard-coded, so existing results are unchanged.

    Returns:
        ``ratio * count(mask == pix) / total``, or 0.0 when ``total`` is 0 (e.g. an empty mask),
        instead of producing NaN.
    """
    if total is None:
        total = mask.size
    if total == 0:
        return 0.0
    mask_pix = np.sum(mask == pix)
    return ratio * (mask_pix / total)


def get_mask(model: nn.Module, image: np.ndarray, model_name: str, image_name: str) -> "Image.Image":
    """Segment one padded image, save a side-by-side preview and print a weighted coverage score.

    Args:
        model: segmentation network; the caller should already have put it in eval mode.
        image: padded uint8 image as returned by read_image; presumably (H, W, 3) -- TODO confirm.
        model_name: used as the output file name.
        image_name: sub-directory of ./results/ to write into; it must already exist.

    Returns:
        The PIL image that was saved: the cropped input tile pasted next to the cropped
        output of give_colors_to_mask (presumably a colour-coded mask).

    Side effects:
        Writes ./results/{image_name}/{model_name}.png at 300 dpi and prints the model name
        with its weighted vegetation-coverage score.
    """
    # Min-max normalise to [0, 1]; a constant image would otherwise divide by zero (NaN input).
    lo, hi = np.min(image), np.max(image)
    span = hi - lo
    if span > 0:
        img = (image - lo) / span
    else:
        img = np.zeros_like(image, dtype=np.float64)

    # Send the input to whatever device the model lives on instead of a hard-coded GPU.
    model_device = next(model.parameters()).device
    batch = torch.from_numpy(img).unsqueeze(0).float().permute(0, 3, 1, 2).to(model_device)
    mask = inference(model, batch)
    color_mask = give_colors_to_mask(image, mask)

    # Crop the original 12x12 tile back out of the padded canvas (offsets are the global
    # pad_top / pad_left computed for a 12x12 tile).
    tile = 12
    rows = slice(pad_top, pad_top + tile)
    cols = slice(pad_left, pad_left + tile)
    restored_image = image[rows, cols, :]
    color_mask = color_mask[rows, cols, :]
    show_image = Image.fromarray(pasteImages([restored_image, color_mask]))
    show_image.save(f"./results/{image_name}/{model_name}.png", dpi=(300, 300))

    # Weighted coverage: (class index, weight) pairs summed over the cropped mask.
    mask_clip = mask[rows, cols]
    class_weights = ((5, 0.9), (4, 0.6), (3, 0.3), (2, 0.15), (1, 0.05))
    res = [compute(mask_clip, cls, weight) for cls, weight in class_weights]
    print(f"{model_name}:植被覆盖度:{sum(res)}")

    return show_image

In [18]:
from glob import glob
import os
# Run every model on every input tile; sort so the processing (and output) order is
# deterministic rather than filesystem-dependent.
images = sorted(glob("./images/*.tif"))
for filename in images:
    image = read_image(filename)
    # basename + splitext keeps dots inside the stem ("a.b.tif" -> "a.b"), unlike split(".")[0].
    image_name = os.path.splitext(os.path.basename(filename))[0]
    print(f"{filename} is processing...")
    os.makedirs(f"./results/{image_name}", exist_ok=True)
    for model_name, model in models.items():
        get_mask(model, image, model_name, image_name)

./images/rgb_point4.tif is processing...
linknet-timm-resnest101e:植被覆盖度:0.04027777777777778
fpn-timm-regnetx_320:植被覆盖度:0.08541666666666667
manet-se_resnext101_32x4d:植被覆盖度:0.08750000000000001
deeplabv3plus-timm-efficientnet-l2:植被覆盖度:0.04097222222222222
farseg_resnet50:植被覆盖度:0.0
unet-timm-efficientnet-l2:植被覆盖度:0.5211805555555555
pan-se_resnext101_32x4d:植被覆盖度:0.0
unet_plus_plus-se_resnext101_32x4d:植被覆盖度:0.011458333333333334
fcn-resnet50:植被覆盖度:0.23888888888888887
pspnet-timm-efficientnet-l2:植被覆盖度:0.0
deeplabv3-resnet152:植被覆盖度:0.5572916666666665
./images/rgb_point3.tif is processing...
linknet-timm-resnest101e:植被覆盖度:0.07361111111111111
fpn-timm-regnetx_320:植被覆盖度:0.196875
manet-se_resnext101_32x4d:植被覆盖度:0.2829861111111111
deeplabv3plus-timm-efficientnet-l2:植被覆盖度:0.5517361111111111
farseg_resnet50:植被覆盖度:0.08541666666666667
unet-timm-efficientnet-l2:植被覆盖度:0.9
pan-se_resnext101_32x4d:植被覆盖度:0.0
unet_plus_plus-se_resnext101_32x4d:植被覆盖度:0.04583333333333334
fcn-resnet50:植被覆盖度:0.08541666666666667
ps

In [None]:
# 1 deeplabv3plus-timm-efficientnet-l2
# 