## Generate superpixel-based pseudolabels


### Overview

This is the third step for data preparation

Input: normalized images

Output: pseudolabel candidates for all the images

In [1]:
%reset
%load_ext autoreload
%autoreload 2
# import matplotlib.pyplot as plt
import argparse
import copy
import glob
import os

import numpy as np
import scipy.ndimage as ndi
import scipy.ndimage.morphology as snm
import SimpleITK as sitk
import skimage
from skimage import io
from skimage.measure import label
from skimage.segmentation import mark_boundaries
from skimage.segmentation import slic
from skimage.util import img_as_float

def to01(x):
    """Min-max rescale an array to the range [0, 1].

    Args:
        x: array-like object exposing ``min()`` and ``max()`` (e.g. a NumPy array).

    Returns:
        ``(x - x.min()) / (x.max() - x.min())``. When every element of ``x`` is
        equal (zero range) an all-zero array is returned instead of dividing by
        zero, which previously produced NaNs and a runtime warning.
    """
    lo = x.min()
    shifted = x - lo
    span = x.max() - lo
    if span == 0:
        # Constant input: `shifted` is already all zeros; dividing by 1.0 only
        # promotes it to the same floating dtype the regular division yields.
        return shifted / 1.0
    return shifted / span



**Summary**

a. Generate a mask of the patient to avoid pseudolabels of empty regions in the background

b. Generate superpixels as pseudolabels

**Configurations of pseudolabels**

```python
# default setting of minimum superpixel sizes
segs = seg_func(img[ii, ...], min_size = 400, sigma = 1)
# you can also try other configs
segs = seg_func(img[ii, ...], min_size = 100, sigma = 0.8)
```


In [None]:
# Per-dataset settings:
#   img_bname - glob pattern matching the normalized input volumes
#   out_dir   - directory that receives the generated masks / pseudolabels
#   fg_thresh - intensity threshold used to build the foreground mask
DATASET_CONFIG = {
    'SABS': {
        'img_bname': './SABS/sabs_CT_normalized/image_*.nii.gz',
        'out_dir': './SABS/sabs_CT_normalized',
        'fg_thresh': 1e-4,
    },
    'CHAOST2': {
        'img_bname': './CHAOST2/chaos_MR_T2_normalized/image_*.nii.gz',
        'out_dir': './CHAOST2/chaos_MR_T2_normalized',
        'fg_thresh': 1e-4 + 50,
    },
}

# Dataset to process: 'SABS' or 'CHAOST2'
DOMAIN = 'CHAOST2'

img_bname = DATASET_CONFIG[DOMAIN]['img_bname']
out_dir = DATASET_CONFIG[DOMAIN]['out_dir']
print(f'img_bname: {img_bname}')

# Unsorted list of every volume matching the pattern
imgs = glob.glob(img_bname)


img_bname: ./SABS/sabs_CT_normalized/image_*.nii.gz


In [3]:
imgs

['./SABS/sabs_CT_normalized/image_1.nii.gz',
 './SABS/sabs_CT_normalized/image_8.nii.gz',
 './SABS/sabs_CT_normalized/image_7.nii.gz',
 './SABS/sabs_CT_normalized/image_18.nii.gz',
 './SABS/sabs_CT_normalized/image_12.nii.gz',
 './SABS/sabs_CT_normalized/image_13.nii.gz',
 './SABS/sabs_CT_normalized/image_25.nii.gz',
 './SABS/sabs_CT_normalized/image_19.nii.gz',
 './SABS/sabs_CT_normalized/image_17.nii.gz',
 './SABS/sabs_CT_normalized/image_22.nii.gz',
 './SABS/sabs_CT_normalized/image_11.nii.gz',
 './SABS/sabs_CT_normalized/image_20.nii.gz',
 './SABS/sabs_CT_normalized/image_28.nii.gz',
 './SABS/sabs_CT_normalized/image_2.nii.gz',
 './SABS/sabs_CT_normalized/image_15.nii.gz',
 './SABS/sabs_CT_normalized/image_3.nii.gz',
 './SABS/sabs_CT_normalized/image_21.nii.gz',
 './SABS/sabs_CT_normalized/image_24.nii.gz',
 './SABS/sabs_CT_normalized/image_14.nii.gz',
 './SABS/sabs_CT_normalized/image_16.nii.gz',
 './SABS/sabs_CT_normalized/image_29.nii.gz',
 './SABS/sabs_CT_normalized/image_9.nii

In [4]:
# Order the volumes numerically by the integer id found between the last "_"
# and the ".nii.gz" suffix (so image_2 comes before image_10).
imgs = sorted(imgs, key = lambda path: int(path.rsplit('_', 1)[-1].split('.nii.gz')[0]))

In [5]:
imgs

['./SABS/sabs_CT_normalized/image_0.nii.gz',
 './SABS/sabs_CT_normalized/image_1.nii.gz',
 './SABS/sabs_CT_normalized/image_2.nii.gz',
 './SABS/sabs_CT_normalized/image_3.nii.gz',
 './SABS/sabs_CT_normalized/image_4.nii.gz',
 './SABS/sabs_CT_normalized/image_5.nii.gz',
 './SABS/sabs_CT_normalized/image_6.nii.gz',
 './SABS/sabs_CT_normalized/image_7.nii.gz',
 './SABS/sabs_CT_normalized/image_8.nii.gz',
 './SABS/sabs_CT_normalized/image_9.nii.gz',
 './SABS/sabs_CT_normalized/image_10.nii.gz',
 './SABS/sabs_CT_normalized/image_11.nii.gz',
 './SABS/sabs_CT_normalized/image_12.nii.gz',
 './SABS/sabs_CT_normalized/image_13.nii.gz',
 './SABS/sabs_CT_normalized/image_14.nii.gz',
 './SABS/sabs_CT_normalized/image_15.nii.gz',
 './SABS/sabs_CT_normalized/image_16.nii.gz',
 './SABS/sabs_CT_normalized/image_17.nii.gz',
 './SABS/sabs_CT_normalized/image_18.nii.gz',
 './SABS/sabs_CT_normalized/image_19.nii.gz',
 './SABS/sabs_CT_normalized/image_20.nii.gz',
 './SABS/sabs_CT_normalized/image_21.nii.gz'

In [6]:
MODE = 'MIDDLE' # minimum size of pseudolabels. 'MIDDLE' is the default setting
from skimage.segmentation import slic

# wrapper for process 3d image in 2d
def superpix_vol(img, method = 'fezlen', **kwargs):
    """
    Segment a volume into superpixels one slice at a time along axis 0.

    Args:
        img: volume indexed as img[slice, ...] (axis order z, x, y is assumed).
        method: only 'fezlen' is accepted; anything else raises NotImplementedError.
        **kwargs: accepted but not used.

    Returns:
        Array of zeros(img.shape) filled with the per-slice SLIC label maps.

    Raises:
        NotImplementedError: for an unknown `method`, or when the module-level
            MODE is not 'MIDDLE' (checked once per slice).

    NOTE(review): although `method` names Felzenszwalb, every slice is segmented
    with `slic` using fixed parameters; the Felzenszwalb function is looked up
    but never called. An earlier configuration was
    seg_func(slice, min_size = 400, sigma = 1).
    """
    if method != 'fezlen':
        raise NotImplementedError
    seg_func = skimage.segmentation.felzenszwalb  # looked up, currently unused

    # Fixed SLIC settings applied to every 2D slice; labels start at 1.
    slic_params = dict(
        n_segments=25,
        compactness=0.1,
        max_num_iter=7,
        sigma=0.3,
        enforce_connectivity=True,
        min_size_factor=0.1,
        max_size_factor=3,
        start_label=1,
        channel_axis=None,
    )

    out_vol = np.zeros(img.shape)
    n_slices = img.shape[0]
    for z in range(n_slices):
        if MODE != 'MIDDLE':
            raise NotImplementedError

        segs = slic(img[z, ...], **slic_params)

        # Per-slice label statistics
        print("segs min: ", segs.min())
        print("segs max: ", segs.max())
        print("unique segs: ", len(np.unique(segs)))

        out_vol[z, ...] = segs

    return out_vol

# thresholding the intensity values to get a binary mask of the patient
def fg_mask2d(img_2d, thresh):
    """Build a binary foreground mask for one 2D slice by intensity thresholding.

    Pixels strictly greater than `thresh` are foreground. When at least one
    foreground pixel exists, only the largest connected component is kept and
    its holes are filled. Adjust the threshold to your data.

    Args:
        img_2d: 2D intensity array.
        thresh: intensity threshold.

    Returns:
        An all-zero float32 array when no pixel exceeds `thresh`; otherwise a
        boolean array marking the hole-filled largest connected component.
    """
    mask_map = np.float32(img_2d > thresh)

    # Nothing above the threshold: return the empty mask as-is.
    if mask_map.max() < 0.999:
        return mask_map

    labels = label(mask_map)
    assert labels.max() != 0  # at least one connected component is expected here

    # bincount index 0 is the background, so skip it and shift the argmax back by one.
    component_sizes = np.bincount(labels.flat)[1:]
    largest_cc = labels == (np.argmax(component_sizes) + 1)

    # scipy.ndimage.morphology is a deprecated namespace; call the public
    # scipy.ndimage function instead.
    return ndi.binary_fill_holes(largest_cc)

# remove superpixels within the empty regions
def superpix_masking(raw_seg2d, mask2d):
    """Drop superpixels outside the foreground mask and renumber the rest from 1.

    Label 0 is reserved for "outside the mask" in the output, so a superpixel
    that carries label 0 in the input is first moved to a fresh label
    (largest label + 1) before the mask is applied.

    Args:
        raw_seg2d: 2D superpixel label map; values are cast to int32. The
            caller's array is not modified.
        mask2d: 2D mask with the same shape, non-zero inside the foreground.

    Returns:
        float64 array with the same shape: 0 where the mask is zero, otherwise
        the new id of the superpixel. Ids are assigned in ascending order of the
        original labels; a superpixel that is completely masked out still
        consumes its id.
    """
    seg = np.array(raw_seg2d, dtype=np.int32)  # always a private copy
    labels = [int(v) for v in np.unique(seg)]

    if 0 in labels:
        # Move superpixel 0 out of the way and remember its new label so it is
        # emitted below (the label list must contain max + 1, not max).
        zero_alias = max(labels) + 1
        seg[seg == 0] = zero_alias
        labels.append(zero_alias)

    seg = seg * mask2d

    out_seg2d = np.zeros(seg.shape)
    lb_new = 1
    for lbv in labels:
        if lbv == 0:
            continue
        out_seg2d[seg == lbv] = lb_new
        lb_new += 1

    return out_seg2d
            
def superpix_wrapper(img, verbose = False, fg_thresh = 1e-4):
    """
    Produce the foreground mask and the masked superpixel map for a whole volume.

    Args:
        img: volume indexed as img[slice, ...].
        verbose: when True, print the index of each slice being processed.
        fg_thresh: intensity threshold forwarded to fg_mask2d.

    Returns:
        (fg_mask_vol, processed_seg_vol): two arrays shaped like the raw
        superpixel volume returned by superpix_vol.
    """
    raw_seg = superpix_vol(img)
    vol_shape = raw_seg.shape
    fg_mask_vol = np.zeros(vol_shape)
    processed_seg_vol = np.zeros(vol_shape)

    for z in range(vol_shape[0]):
        if verbose:
            print(f"doing {z} slice")
        slice_mask = fg_mask2d(img[z, ...], fg_thresh)
        fg_mask_vol[z] = slice_mask
        processed_seg_vol[z] = superpix_masking(raw_seg[z, ...], slice_mask)

    return fg_mask_vol, processed_seg_vol
        
# copy spacing and orientation info between sitk objects
def copy_info(src, dst):
    """Copy spacing, origin and direction from `src` onto `dst` and return `dst`."""
    # Applied in this order: spacing, origin, direction.
    for field in ('Spacing', 'Origin', 'Direction'):
        value = getattr(src, 'Get' + field)()
        getattr(dst, 'Set' + field)(value)
    return dst


def strip_(img, lb):
    """Keep only the requested label value(s) of a label map; zero everything else.

    Args:
        img: label map; values are cast to int32 before comparison.
        lb: a single label (Python or NumPy int/float) or a list/tuple of
            labels. Non-integer labels are truncated with int().

    Returns:
        For a single label: float32 array equal to the label where img matches
        it and 0 elsewhere. For a list/tuple: float64 array holding the sum of
        the per-label results.

    Raises:
        TypeError: if `lb` is neither a number nor a list/tuple of numbers.
    """
    img = np.int32(img)

    def _single(value):
        # Compare and scale with the same truncated integer label.
        value = int(value)
        return np.float32(img == value) * float(value)

    if isinstance(lb, (list, tuple)):
        out = np.zeros(img.shape)
        for _lb in lb:
            out += _single(_lb)
        return out

    if isinstance(lb, (int, float, np.integer, np.floating)):
        return _single(lb)

    raise TypeError(
        "lb must be a number or a list/tuple of numbers, got {}".format(type(lb).__name__)
    )

In [7]:
# Generate pseudolabels for every image and save them
for img_fid in imgs:
    # Volume id: the text between the last "_" and ".nii.gz" in the file name
    idx = os.path.basename(img_fid).split("_")[-1].split(".nii.gz")[0]
    seg_fid = os.path.join(out_dir, f'superpix-{MODE}_{idx}.nii.gz')
    msk_fid = os.path.join(out_dir, f'fgmask_{idx}.nii.gz')

    # Read the volume and compute its foreground mask and superpixel labels
    im_obj = sitk.ReadImage(img_fid)
    out_fg, out_seg = superpix_wrapper(
        sitk.GetArrayFromImage(im_obj),
        fg_thresh = DATASET_CONFIG[DOMAIN]['fg_thresh'],
    )

    # Wrap the arrays as images carrying the source spacing/origin/direction
    out_fg_o = copy_info(im_obj, sitk.GetImageFromArray(out_fg))
    out_seg_o = copy_info(im_obj, sitk.GetImageFromArray(out_seg))

    # Save the mask first, then the pseudolabels
    sitk.WriteImage(out_fg_o, msk_fid)
    sitk.WriteImage(out_seg_o, seg_fid)
    print(f'image with id {idx} has finished')


segs min:  1
segs max:  29
unique segs:  29
segs min:  1
segs max:  29
unique segs:  29
segs min:  1
segs max:  29
unique segs:  29
segs min:  1
segs max:  33
unique segs:  33
segs min:  1
segs max:  33
unique segs:  33
segs min:  1
segs max:  35
unique segs:  35
segs min:  1
segs max:  35
unique segs:  35
segs min:  1
segs max:  39
unique segs:  39
segs min:  1
segs max:  40
unique segs:  40
segs min:  1
segs max:  38
unique segs:  38
segs min:  1
segs max:  39
unique segs:  39
segs min:  1
segs max:  39
unique segs:  39
segs min:  1
segs max:  35
unique segs:  35
segs min:  1
segs max:  34
unique segs:  34
segs min:  1
segs max:  37
unique segs:  37
segs min:  1
segs max:  37
unique segs:  37
segs min:  1
segs max:  35
unique segs:  35
segs min:  1
segs max:  38
unique segs:  38
segs min:  1
segs max:  35
unique segs:  35
segs min:  1
segs max:  34
unique segs:  34
segs min:  1
segs max:  31
unique segs:  31
segs min:  1
segs max:  30
unique segs:  30
segs min:  1
segs max:  30
uniqu

  fill_mask = snm.binary_fill_holes(post_mask)


image with id 0 has finished
segs min:  1
segs max:  35
unique segs:  35
segs min:  1
segs max:  35
unique segs:  35
segs min:  1
segs max:  30
unique segs:  30
segs min:  1
segs max:  33
unique segs:  33
segs min:  1
segs max:  32
unique segs:  32
segs min:  1
segs max:  30
unique segs:  30
segs min:  1
segs max:  29
unique segs:  29
segs min:  1
segs max:  30
unique segs:  30
segs min:  1
segs max:  31
unique segs:  31
segs min:  1
segs max:  31
unique segs:  31
segs min:  1
segs max:  34
unique segs:  34
segs min:  1
segs max:  33
unique segs:  33
segs min:  1
segs max:  33
unique segs:  33
segs min:  1
segs max:  33
unique segs:  33
segs min:  1
segs max:  33
unique segs:  33
segs min:  1
segs max:  33
unique segs:  33
segs min:  1
segs max:  32
unique segs:  32
segs min:  1
segs max:  32
unique segs:  32
segs min:  1
segs max:  31
unique segs:  31
segs min:  1
segs max:  27
unique segs:  27
segs min:  1
segs max:  27
unique segs:  27
segs min:  1
segs max:  25
unique segs:  25
seg

In [8]:
import torch
from collections import OrderedDict # needed when the state_dict is an OrderedDict

# NOTE(review): machine-specific absolute path; change it to your own checkpoint.
pretrained_path = "/root/ducnt/fewshot_medical_segmentor/exps/myexperiments_MIDDLE_0/mySSL_train_CHAOST2_Superpix_lbgroup0_scale_MIDDLE_vfold0_CHAOST2_Superpix_sets_0_1shot/31/snapshots/slic_res101_alp_attention_epoch_15000.pth" # change this path


def print_key_shapes(state_dict, keys):
    """Print one numbered line (starting at 1) per key with the shape of its entry."""
    for i, key in enumerate(keys):
        tensor_shape = state_dict[key].shape
        print(f"{i+1:4d}: {key} (Shape: {tensor_shape})")


try:
    # Load the checkpoint onto the CPU.
    # SECURITY: torch.load unpickles the file and can execute arbitrary code;
    # only open checkpoints from a trusted source.
    checkpoint = torch.load(pretrained_path, map_location='cpu')

    # Locate the actual state_dict. The type is checked first so that a
    # non-mapping checkpoint reaches the "Unsupported" message instead of
    # failing on the membership test.
    if not isinstance(checkpoint, (dict, OrderedDict)):
        print(f"Error: Unsupported checkpoint format in {pretrained_path}")
        state_dict = None
    elif 'state_dict' in checkpoint:
        state_dict = checkpoint['state_dict']
        print(f"--- Keys inside 'state_dict' from {pretrained_path} ---")
    else:
        # The .pth file may contain the state_dict itself
        state_dict = checkpoint
        print(f"--- Keys directly from {pretrained_path} ---")

    if state_dict:
        # Print every key
        key_list = list(state_dict.keys())
        print(f"Total keys found: {len(key_list)}")
        print_key_shapes(state_dict, key_list)

        # Then only the keys that start with a specific prefix
        for prefix in ("encoder.", "cls_unit."):
            print(f"\n--- Keys starting with '{prefix}' ---")
            print_key_shapes(state_dict, [k for k in key_list if k.startswith(prefix)])

except FileNotFoundError:
    print(f"Error: File not found at {pretrained_path}")
except Exception as e:
    print(f"An error occurred: {e}")

Error: File not found at /root/ducnt/fewshot_medical_segmentor/exps/myexperiments_MIDDLE_0/mySSL_train_CHAOST2_Superpix_lbgroup0_scale_MIDDLE_vfold0_CHAOST2_Superpix_sets_0_1shot/31/snapshots/slic_res101_alp_attention_epoch_15000.pth
