In [None]:
#  load weight
from networks.vnet import VNet
import torch

net = VNet(
    n_channels=1, n_classes=2, normalization="batchnorm",
)


path_weight = "../model/bin_no_1/AdamW/15k/e0.5/t0.75/no_cutmix/s1_to_s2/ker7_off15/best_model.pth"
# map_location="cpu" makes loading independent of the device the checkpoint was
# saved from; the model is moved to the GPU right after.
check_point = torch.load(path_weight, map_location="cpu", weights_only=True)
net.load_state_dict(check_point)

net.cuda()
# This network is only used for inference below. Switch BatchNorm layers to eval
# mode so they use the running statistics learned in training instead of the
# statistics of each sliding-window batch (torch.no_grad() does NOT do this).
net.eval()

In [2]:
import numpy as np
import cv2


def open_morphology(img, kernel_size=5, iterations=2):
    """Morphologically open ``img`` with a square all-ones structuring element.

    Args:
        img: image accepted by ``cv2.morphologyEx``; in this notebook a 2-D
            uint8 slice of a segmentation mask.
        kernel_size: side length of the square structuring element.
        iterations: number of times the opening is applied.

    Returns:
        The opened image produced by ``cv2.morphologyEx(..., cv2.MORPH_OPEN, ...)``.
    """
    structuring_element = np.full((kernel_size, kernel_size), 1, dtype=np.uint8)
    opened = cv2.morphologyEx(img, cv2.MORPH_OPEN, structuring_element, iterations=iterations)
    return opened
    
    
def _slice_has_foreground(outs, axis, index, erode, kernel_size, iterations):
    """Return True if the slice of ``outs`` at ``index`` along ``axis`` has a positive sum.

    With ``erode`` the slice is squeezed, cast to uint8 and passed through
    ``open_morphology`` first, so only structures surviving the opening count.
    """
    # Equivalent to outs[:, index, ...] / outs[:, :, index, ...] / ... for the given axis.
    selector = [slice(None)] * outs.ndim
    selector[axis] = index
    plane = outs[tuple(selector)]
    if erode:
        plane = open_morphology(plane.squeeze().astype(np.uint8),
                                kernel_size=kernel_size, iterations=iterations)
    return plane.sum() > 0


def _first_foreground_index(outs, axis, indices, erode, kernel_size, iterations):
    """Return the first index in ``indices`` whose slice along ``axis`` has foreground, else None."""
    for index in indices:
        if _slice_has_foreground(outs, axis, index, erode, kernel_size, iterations):
            return index
    return None


def find_roi(outs, erode=False, out_offset=10, kernel_size=5, iterations=2, extra_pad=20):
    """Compute a padded bounding box around the foreground of a segmentation volume.

    Args:
        outs: array whose axes 1, 2 and 3 are scanned as x, y and z; in this
            notebook it is the (1, X, Y, Z) uint8 mask produced by the inference
            loops.
        erode: if True, every 2-D slice is morphologically *opened* (via
            ``open_morphology``) before testing for foreground, so structures
            that do not survive the opening do not extend the box.
        out_offset: margin (in voxels) added on both sides of the detected extent.
        kernel_size, iterations: forwarded to ``open_morphology`` when ``erode``.
        extra_pad: additional fixed margin added on top of ``out_offset``
            (this was a hard-coded 20 before; the default keeps that behavior).

    Returns:
        ``(x_min, x_max, y_min, y_max, z_min, z_max)`` as *inclusive* indices,
        clamped to ``[0, size - 1]`` per axis, so ``outs[:, x_min:x_max + 1, ...]``
        selects the box.

    Raises:
        ValueError: if no slice along some axis contains foreground (previously
            this surfaced as an accidental UnboundLocalError).
    """
    pad = out_offset + extra_pad
    bounds = []
    for axis, label in ((1, "x"), (2, "y"), (3, "z")):
        size = outs.shape[axis]
        low = _first_foreground_index(outs, axis, range(size), erode, kernel_size, iterations)
        high = _first_foreground_index(outs, axis, reversed(range(size)), erode, kernel_size, iterations)
        if low is None or high is None:
            raise ValueError(f"find_roi: no foreground found along the {label} axis")
        bounds.append(max(0, low - pad))
        # Clamp to the last valid index (not `size`) so the bound stays inclusive,
        # consistent with the lower bound and with how ROI_posi is stored.
        bounds.append(min(size - 1, high + pad))
    return tuple(bounds)

In [4]:
import torch

# Release cached, currently unused GPU memory held by PyTorch's caching allocator.
torch.cuda.empty_cache()

In [None]:
# Quick shape check of the most recent prediction.
# NOTE(review): `outs` is only created by the inference loops further down, so
# this cell fails on a fresh kernel unless one of those has run first.
outs.shape

In [None]:
# Bounding box of the raw mask vs. the box obtained after morphological opening
# with a 12x12 kernel (erode=True), printed one per line for comparison.
roi_raw = find_roi(outs)
roi_opened = find_roi(outs, erode=True, kernel_size=12)
x_min, x_max, y_min, y_max, z_min, z_max = roi_raw
x_min_e, x_max_e, y_min_e, y_max_e, z_min_e, z_max_e = roi_opened

for bounds in (roi_raw, roi_opened):
    print(*bounds)

In [9]:
# Crop the image and the predicted mask to the two bounding boxes computed above
# (bounds are used inclusively, hence the +1 on each slice end).
# NOTE(review): relies on hidden kernel state. `img` must be the 5-D GPU tensor
# left by an inference loop (the lab-display loop rebinds `img` to a file name),
# and `outs` presumably the matching (1, X, Y, Z) mask — confirm cell order.
roi_img = img[0,0,x_min:x_max+1,y_min:y_max+1,z_min:z_max+1].cpu().detach().numpy()
roi_mask = outs[0,x_min:x_max+1,y_min:y_max+1,z_min:z_max+1]
roi_mask_e = outs[0,x_min_e:x_max_e+1,y_min_e:y_max_e+1,z_min_e:z_max_e+1]

In [None]:
import matplotlib.pyplot as plt
from monai.visualize import blend_images, matshow3d

# Preview every 3rd slice (along the last axis) of the mask cropped with the
# bounds found on the opened mask. The title previously said "input image",
# which mislabelled this mask.
matshow3d(roi_mask_e, fig=None,
    title="ROI mask (opened-mask bounds)",
    figsize=(50, 50),
    every_n=3,
    frame_dim=-1,
    show=True,
    cmap="gray",
)

In [None]:
import os
import h5py
import numpy as np
from tqdm import tqdm
from monai.visualize import blend_images, matshow3d
# NOTE(review): absolute path at the filesystem root, unlike the relative
# "MICCAI2024/alls_data_no/..." paths used by other cells — confirm it is intended.
lab_path = "/alls_data_no/lab/"

# Display the stored ROI label of every labelled volume.
# The loop variable is `fname` (not `img`) so it no longer overwrites the `img`
# tensor that the ROI-crop cell reads.
for fname in tqdm(sorted(os.listdir(lab_path))):
    img_path = os.path.join(lab_path, fname)

    # Read-only: this cell only displays data, so opening the file writable
    # ('a') is unnecessary and would fail on read-only files.
    with h5py.File(img_path, 'r') as h5file:
        ROI_label = h5file['ROI_label'][:]

    matshow3d(ROI_label, fig=None,
              title="ROI label",
              figsize=(10, 10),
              every_n=3,
              frame_dim=-1,
              show=True,
              cmap="gray",
              )
        
        

In [None]:
import h5py

# Open one H5 file read-only and list the names of its top-level members.
with h5py.File('MICCAI2024/alls_data_no/unlab/imgs_0001_norm.h5', 'r') as file:
    # Iterating an h5py File/Group yields its member names, same as .keys().
    keys = [name for name in file]
    print(keys)


In [None]:
# load data
# Run the network over every validation volume, crop a padded ROI around the
# confidently predicted foreground and write it back into the same .h5 file.
unlab_path = "data/MICCAI2024/alls_data_no/val/"
import os
import h5py
from scipy.ndimage import binary_erosion
import nibabel as nib
from monai.inferers import SlidingWindowInferer
import numpy as np
from tqdm import tqdm
from monai.visualize import blend_images, matshow3d

# Voxels whose max softmax probability is not above this are set to background.
CONF_THRESHOLD = 0.85

now_device = torch.device("cuda")
sliding_window_inferer = SlidingWindowInferer(roi_size=(112, 112, 80), sw_batch_size=1, overlap=0.5, device=now_device)

# Inference only: BatchNorm layers must use their running statistics.
net.eval()

for fname in tqdm(sorted(os.listdir(unlab_path))):
    img_path = os.path.join(unlab_path, fname)
    with torch.no_grad(), h5py.File(img_path, 'a') as h5file:
        img_yuan = h5file['image'][:]
        # Add batch and channel axes expected by the network.
        img = torch.from_numpy(img_yuan).unsqueeze(0).unsqueeze(0).float().to(now_device)

        # Infer, then keep only foreground predicted with high confidence.
        logits = sliding_window_inferer(img, net).cpu().detach()
        outs_prob = logits.softmax(dim=1).max(dim=1)[0].numpy()
        outs = (logits.numpy().argmax(axis=1) * (outs_prob > CONF_THRESHOLD)).astype(np.uint8)

        # find_roi fails when the opened mask has no foreground along some axis;
        # retry with a smaller opening kernel (same fallback as the unlab cell).
        try:
            roi_bounds = find_roi(outs, erode=True, kernel_size=12, iterations=1, out_offset=20)
        except (ValueError, UnboundLocalError):
            roi_bounds = find_roi(outs, erode=True, kernel_size=5, iterations=1, out_offset=20)
        x_min, x_max, y_min, y_max, z_min, z_max = roi_bounds

        roi_img = img_yuan[x_min:x_max + 1, y_min:y_max + 1, z_min:z_max + 1]
        roi_mask = outs[0, x_min:x_max + 1, y_min:y_max + 1, z_min:z_max + 1]
        roi_posi = [x_min, x_max, y_min, y_max, z_min, z_max]

        # Replace any ROI datasets left over from a previous run.
        for key, value in (('ROI_image', roi_img), ('ROI_label', roi_mask),
                           ('ROI_posi', roi_posi), ('ROI_prob', outs_prob)):
            if key in h5file:
                del h5file[key]
            h5file.create_dataset(key, data=value)

    print(img_path)
    matshow3d(roi_mask, fig=None,
              title="predicted ROI mask",
              figsize=(10, 10),
              every_n=3,
              frame_dim=-1,
              show=True,
              cmap="gray",
              )

In [12]:
# Disabled viewer: re-open each processed file read-only and display its stored ROI label.
# for img in tqdm(os.listdir(unlab_path)):
#     # show roi masks
#     img_path = unlab_path + img
#     with h5py.File(img_path, 'r') as h5file:
#         roi_img = h5file['ROI_image'][:]
#         roi_lab = h5file['ROI_label'][:]
#         mask = h5file['ROI_posi'][:]
        
#         matshow3d(roi_lab, fig=None,
#             title="input image",
#             figsize=(50, 50),
#             every_n=3,
#             frame_dim=-1,
#             show=True,
#             cmap="gray",
#         )

In [None]:
# Inspect one labelled volume: full image shape and its stored ROI bounds.
path = "MICCAI2024/alls_data_no/lab/imgs_0013_norm.h5"
import h5py

with h5py.File(path, "r") as file:
    h5image = file['image'][:]
    # Named roi_posi (was `h5file`, the name other cells use for the open File handle).
    # Presumably [x_min, x_max, y_min, y_max, z_min, z_max] as written by the ROI loops.
    roi_posi = file['ROI_posi'][:]
    print(h5image.shape)
    print(roi_posi)
    

In [None]:
# load data
# Run the network over the unlabeled volumes, crop a padded ROI around the
# confidently predicted foreground and write it back into each .h5 file.
unlab_path = "data/MICCAI2024/alls_data_no/unlab/"
import os
import h5py
from scipy.ndimage import binary_erosion
import nibabel as nib
from monai.inferers import SlidingWindowInferer
import numpy as np
from tqdm import tqdm
from monai.visualize import blend_images, matshow3d

# Voxels whose max softmax probability is not above this are set to background.
CONF_THRESHOLD = 0.85
# Resume point: skip the first SKIP_FIRST directory entries (presumably handled
# by an earlier, interrupted run). NOTE(review): os.listdir order is arbitrary,
# so this only skips the intended files on the same, unchanged directory.
SKIP_FIRST = 200

now_device = torch.device("cuda")
sliding_window_inferer = SlidingWindowInferer(roi_size=(112, 112, 80), sw_batch_size=1, overlap=0.5, device=now_device)

# Inference only: BatchNorm layers must use their running statistics.
net.eval()

for fname in tqdm(os.listdir(unlab_path)[SKIP_FIRST:]):
    img_path = os.path.join(unlab_path, fname)
    with torch.no_grad(), h5py.File(img_path, 'a') as h5file:
        img_yuan = h5file['image'][:]
        # Add batch and channel axes expected by the network.
        img = torch.from_numpy(img_yuan).unsqueeze(0).unsqueeze(0).float().to(now_device)

        # Infer, then keep only foreground predicted with high confidence.
        logits = sliding_window_inferer(img, net).cpu().detach()
        outs_prob = logits.softmax(dim=1).max(dim=1)[0].numpy()
        outs = (logits.numpy().argmax(axis=1) * (outs_prob > CONF_THRESHOLD)).astype(np.uint8)

        # find_roi fails when the opened mask has no foreground along some axis;
        # retry with a smaller opening kernel. Only that failure is caught (not a
        # bare except), so unrelated errors such as CUDA OOM still surface.
        try:
            roi_bounds = find_roi(outs, erode=True, kernel_size=12, iterations=1, out_offset=20)
        except (ValueError, UnboundLocalError):
            roi_bounds = find_roi(outs, erode=True, kernel_size=5, iterations=1, out_offset=20)
        x_min, x_max, y_min, y_max, z_min, z_max = roi_bounds

        roi_img = img_yuan[x_min:x_max + 1, y_min:y_max + 1, z_min:z_max + 1]
        roi_mask = outs[0, x_min:x_max + 1, y_min:y_max + 1, z_min:z_max + 1]
        roi_posi = [x_min, x_max, y_min, y_max, z_min, z_max]

        # Replace any ROI datasets left over from a previous run.
        for key, value in (('ROI_image', roi_img), ('ROI_label', roi_mask),
                           ('ROI_posi', roi_posi), ('ROI_prob', outs_prob)):
            if key in h5file:
                del h5file[key]
            h5file.create_dataset(key, data=value)

    print(img_path)
    matshow3d(roi_mask, fig=None,
              title="predicted ROI mask",
              figsize=(10, 10),
              every_n=3,
              frame_dim=-1,
              show=True,
              cmap="gray",
              )

In [None]:
# Sanity check: list unlabeled files that have no 'ROI_image' dataset, i.e.
# files the ROI-extraction loop has not written to.
unlab_path = "MICCAI2024/alls_data_no/unlab/"
import os
import h5py
from scipy.ndimage import binary_erosion
import nibabel as nib
from monai.inferers import SlidingWindowInferer
import numpy as np
from tqdm import tqdm
from monai.visualize import blend_images, matshow3d

count = 0
missing_roi = []

for count, fname in enumerate(tqdm(os.listdir(unlab_path)), start=1):
    img_path = os.path.join(unlab_path, fname)
    with h5py.File(img_path, 'r') as h5file:
        # Use a membership test: `h5file['ROI_image'][:]` is never None, it
        # raises KeyError when the dataset is missing (and reads it all otherwise).
        if 'ROI_image' not in h5file:
            missing_roi.append(img_path)

print(f"{len(missing_roi)} of {count} files have no ROI_image")
missing_roi
        