In [1]:
# 標準ライブラリ
import os
import sys
import glob
import time
import copy

# データ操作・数値計算ライブラリ
import numpy as np
import cv2
import scipy.ndimage
from scipy import signal
from scipy.ndimage import median_filter

# PyTorch関連
import torch
from torch.nn import functional as F
from torchsummary import summary

# 天文学・画像処理ライブラリ
import astropy.io.fits
from astropy.stats import sigma_clipped_stats
from photutils.detection import DAOStarFinder
import matplotlib.pyplot as plt

# W&B (Weights & Biases)
import wandb

# NpyAppendArray
from npy_append_array import NpyAppendArray

# tqdm (プログレスバー)
from tqdm import tqdm
import warnings
warnings.resetwarnings()
warnings.simplefilter("ignore")

In [2]:
# プロジェクト固有モジュール
# ユーザーの環境に合わせてパスを設定
sys.path.append('/home/elmegreen/galactic_bubble/photoutils/')
from utils.ssd_model import SSD, Detect
from utils.ssd_predict_show import SSDPredictShow
# from processing import conv

In [3]:
def remove_nan(data1):
    """
    Fill small NaN islands of an array with the array's largest finite value.

    Connected NaN regions (as labelled by ``scipy.ndimage.label`` with its
    default structuring element) that cover fewer than 500 elements are
    overwritten with ``np.nanmax(data1)``. Larger NaN regions are left as NaN.

    The input array is modified in place and the same object is returned.
    """
    # 1 where the value is NaN, 0 elsewhere.
    nan_flags = np.where(np.isnan(data1), 1, 0)
    regions, n_regions = scipy.ndimage.label(nan_flags)

    # Element count per label. Label 0 is the non-NaN background, whose flags
    # sum to 0, so it is excluded by the "> 0" test below.
    region_sizes = scipy.ndimage.sum(nan_flags, regions, np.arange(n_regions + 1))
    max_fill_size = 500
    fillable = (region_sizes > 0) & (region_sizes < max_fill_size)

    # Look up the per-label decision for every element.
    data1[fillable[regions]] = np.nanmax(data1)
    return data1

In [4]:
def fit_lognormal_component(data):
    """
    Fit a log-normal distribution to the data and return its mean and std.

    Only finite-comparable, strictly positive values take part in the fit;
    the location parameter is fixed to 0.

    Parameters
    ----------
    data : numpy.ndarray
        Input values (may contain NaN and non-positive values).

    Returns
    -------
    (mean, std)
        Mean and standard deviation of the fitted log-normal distribution.
        If fewer than 10 usable values exist, ``np.nanmean(data)`` and
        ``np.nanstd(data)`` are returned instead. If the fit itself fails,
        the plain mean/std of the usable (positive, non-NaN) values are
        returned.
    """
    # `lognorm` is not imported by the notebook's import cell. Without this
    # import every call raised NameError inside the try block below and
    # silently fell back to plain statistics (the warning is suppressed by the
    # notebook-wide warnings.simplefilter("ignore")).
    from scipy.stats import lognorm

    # Drop NaN and keep strictly positive values only.
    clean_data = data[~np.isnan(data) & (data > 0)]

    if len(clean_data) < 10:  # too few samples for a meaningful fit
        warnings.warn("データが少なすぎるため、通常の統計値を使用します")
        return np.nanmean(data), np.nanstd(data)

    try:
        # lognorm parameters: s (shape), loc (location, fixed to 0), scale.
        s, loc, scale = lognorm.fit(clean_data, floc=0)

        lognorm_mean = lognorm.mean(s, loc=loc, scale=scale)
        lognorm_std = lognorm.std(s, loc=loc, scale=scale)
        return lognorm_mean, lognorm_std

    except Exception as e:
        # Best-effort fallback: keep the pipeline running on a failed fit.
        warnings.warn(f"対数正規分布のフィッティングに失敗: {e}. 通常の統計値を使用します")
        return np.nanmean(clean_data), np.nanstd(clean_data)

def norm_rp_improved(data, nan_data_dim=None, use_lognormal=True, min_method='std'):
    """
    Rescale ``data`` linearly as ``(data - lower) / (upper - lower)``.

    Parameters
    ----------
    data : numpy.ndarray
        Array to normalize. It is copied; the caller's array is not modified.
    nan_data_dim : numpy.ndarray, optional
        Reference array used to compute the statistics. When None, ``data``
        itself is the reference.
    use_lognormal : bool, default=True
        If True, mean/std come from ``fit_lognormal_component``; otherwise
        from ``np.nanmean`` / ``np.nanstd`` of the reference.
    min_method : str, default='std'
        How the lower bound is chosen:
        'min' -> nanmin of the reference, 'std' -> mean - std,
        '2std' -> mean - 2 * std.

    Returns
    -------
    numpy.ndarray
        Normalized copy of ``data``.

    Raises
    ------
    ValueError
        If ``min_method`` is not one of 'min', 'std', '2std'.
    """
    reference = data if nan_data_dim is None else nan_data_dim

    if use_lognormal:
        mean, std = fit_lognormal_component(reference)
    else:
        mean, std = np.nanmean(reference), np.nanstd(reference)

    # Upper bound: mean + 3 sigma.
    upper = mean + 3 * std

    # Lower bound according to the requested method.
    if min_method == 'min':
        lower = np.nanmin(reference)
    elif min_method in ('std', '2std'):
        n_sigma = 1 if min_method == 'std' else 2
        lower = mean - n_sigma * std
    else:
        raise ValueError("min_method must be 'min', 'std', or '2std'")

    # Never let the upper bound fall below 0.5.
    if upper < 0.5:
        upper = 0.5

    # If the lower bound sits above the midpoint between the reference
    # minimum and the upper bound, fall back to the reference minimum.
    floor = np.nanmin(reference)
    if lower > floor + (upper - floor) * 0.5:
        lower = floor

    # In-place arithmetic on the copy keeps the input dtype.
    scaled = data.copy()
    scaled -= lower
    scaled /= (upper - lower)

    return scaled

In [5]:
def norm_rp(data, nan_data_dim=None):
    """
    Legacy normalization: subtract the minimum, divide by mean + 3 * std.

    Statistics (nanmin, nanmean, nanstd) are taken from ``nan_data_dim`` when
    it is given, otherwise from ``data``. Note that the divisor is
    ``mean + 3 * std`` itself, not the distance between it and the minimum.

    ``data`` is modified in place and the same object is returned.
    """
    stats_source = data if nan_data_dim is None else nan_data_dim
    offset = np.nanmin(stats_source)
    scale = np.nanmean(stats_source) + 3 * np.nanstd(stats_source)

    data -= offset
    data /= scale
    return data


def normalize_rp(array, r_header, g_header):
    """
    Normalize every channel of a cut-out independently.

    Input : (y, x, 2 or 3)
    Output: (y, x, 2 or 3)

    Channels 0 and 2 are normalized with statistics of the channel itself.
    Every other channel is normalized with statistics taken from the output of
    ``remove_peak`` (a copy in which detected sources are set to NaN).
    ``r_header`` and ``g_header`` are forwarded unchanged to ``remove_peak``
    as its ``r_resolution`` / ``g_resolution`` arguments.
    """
    normalized = []
    for channel in range(array.shape[2]):
        plane = array[:, :, channel]
        if channel in (0, 2):
            scaled = norm_rp_improved(plane)
        else:
            reference = remove_peak(plane, channel, r_header, g_header)
            scaled = norm_rp_improved(plane, reference)
        # Restore the trailing channel axis so the planes can be concatenated.
        normalized.append(scaled[:, :, None])

    return np.concatenate(normalized, axis=2)


def remove_peak(array, dim, r_resolution, g_resolution):
    """
    Return a copy of ``array`` in which detected point sources are set to NaN.

    Sources are located with ``DAOStarFinder`` using a threshold of
    ``mean + 3 * std`` (sigma-clipped statistics of the image). A filled disc
    of radius 4 pixels around every centroid is replaced by NaN.

    Parameters
    ----------
    array : numpy.ndarray
        2-D image of one channel. It is not modified.
    dim : int
        Channel index: 0 uses a 0.7 arcsec FWHM with ``r_resolution``,
        1 uses a 0.3 arcsec FWHM with ``g_resolution``.
    r_resolution, g_resolution : float
        Divisors that convert the FWHM from arcsec to pixels (the caller in
        this notebook passes ``CDELT1 * 3600``). The sign is ignored.

    Returns
    -------
    numpy.ndarray
        The masked copy, or an unmasked copy when no source is found.

    Raises
    ------
    ValueError
        If ``dim`` is neither 0 nor 1 (previously an UnboundLocalError).
    """
    data = array.copy()
    mean, _median, std = sigma_clipped_stats(data, sigma=3)

    if dim == 0:
        fwhm_arcsec = 0.7  # rounded up from 0.674
        fwhm_pixel = fwhm_arcsec / r_resolution
    elif dim == 1:
        fwhm_arcsec = 0.3  # rounded up from 0.269
        fwhm_pixel = fwhm_arcsec / g_resolution
    else:
        raise ValueError(f"remove_peak supports dim 0 or 1, got {dim!r}")

    # abs(): CDELT1 can be negative, which would give a negative FWHM.
    daofind = DAOStarFinder(fwhm=abs(fwhm_pixel), threshold=mean + 3 * std)
    sources = daofind(data)

    # The finder may return None when nothing is detected. Handle that case
    # explicitly instead of hiding every possible error behind a bare except.
    if sources is None or len(sources) == 0:
        return data

    # Draw the discs on an 8-bit mask; cv2 expects the centre as (x, y).
    peak_mask = np.zeros(data.shape, dtype=np.uint8)
    for x_center, y_center in zip(sources["xcentroid"], sources["ycentroid"]):
        if not (np.isfinite(x_center) and np.isfinite(y_center)):
            continue  # a non-finite centroid cannot be drawn
        peak_mask = cv2.circle(peak_mask, (int(x_center), int(y_center)), 4, (255, 255, 255), -1)

    data[peak_mask > 0] = np.nan
    return data


def resize(data, size):
    """
    Resize a channel-last image to ``size`` x ``size`` with bilinear
    interpolation (``align_corners=False``).

    Input  : (y, x, C)
    Output : (size, size, C)

    Parameters
    ----------
    data : numpy.ndarray
        Floating-point image with the channel axis last.
    size : int
        Output edge length in pixels.

    Returns
    -------
    numpy.ndarray
        Resized image with the channel axis last.
    """
    # (y, x, C) -> (C, y, x) -> (1, C, y, x), the layout F.interpolate expects.
    channels_first = np.transpose(data, (2, 0, 1))
    batch = torch.from_numpy(channels_first).unsqueeze(0)

    resized = F.interpolate(batch, (size, size), mode="bilinear", align_corners=False)

    # Drop only the batch axis. np.squeeze() would also drop a singleton
    # channel axis (C == 1) or spatial axis (size == 1) and break the
    # transpose back to channel-last.
    resized = resized[0].detach().numpy()

    # (C, size, size) -> (size, size, C)
    return np.transpose(resized, (1, 2, 0))


def norm_res(data, r_header, g_header):
    """
    Normalize a cut-out channel by channel and resize it to 300 x 300.

    The input is deep-copied first, so the caller's array is left untouched.
    ``r_header`` and ``g_header`` are passed straight to ``normalize_rp``.
    """
    working_copy = copy.deepcopy(data)
    normalized = normalize_rp(working_copy, r_header, g_header)
    return resize(normalized, 300)

In [6]:
def conv(obj_size, obj_sig, data):
    """
    Smooth a cut-out with a Gaussian before it is shrunk to ``obj_size``.

    Input : (y, x) or (y, x, C)
    Output: same shape as the input

    If the cut-out is more than 10 % larger than ``obj_size`` along axis 0,
    a target Gaussian sigma is derived from a FWHM of
    ``2 * data.shape[0] / obj_size`` pixels. When that target exceeds
    ``obj_sig`` (the sigma the data already has), the data are convolved with
    a normalized Gaussian of sigma ``sqrt(target**2 - obj_sig**2)``.
    In every other case ``data`` is returned unchanged (same object).

    Parameters
    ----------
    obj_size : int or float
        Edge length the cut-out will later be resized to.
    obj_sig : float
        Gaussian sigma (pixels) already present in ``data``.
    data : numpy.ndarray
        2-D image, or 3-D cube with the channel axis last. Each channel of a
        cube is smoothed independently with the same kernel.

    Returns
    -------
    numpy.ndarray
    """
    if not (data.shape[0] > obj_size * 1.1):
        return data

    # sigma = FWHM / (2 * sqrt(2 ln 2))
    fwhm = (data.shape[0] / obj_size) * 2
    sig3 = fwhm / (2 * (2 * np.log(2)) ** (1 / 2))
    if not (sig3 > obj_sig):
        return data

    # Gaussian sigmas add in quadrature, so only the missing part is applied.
    sig2 = (sig3**2 - obj_sig**2) ** (1 / 2)
    kernel_len = 8 * round(sig2) + 1

    # scipy.signal.gaussian was removed from SciPy; the window lives in
    # scipy.signal.windows.
    window = signal.windows.gaussian(kernel_len, sig2)
    kernel = np.outer(window, window)
    kernel = kernel / np.sum(kernel)

    # fftconvolve needs equal dimensionality: give the kernel a length-1
    # channel axis so a (y, x, C) cube is smoothed per channel.
    if data.ndim == 3:
        kernel = kernel[:, :, None]

    return signal.fftconvolve(data, kernel, mode="same")

In [7]:
def cut_data(data_, many_ind, cut_shape, r_hdu_header, g_hdu_header, sig_r=None, sig_g=None):
    """
    Cut square windows out of the image cube and turn them into model inputs.

    For every origin a window of ``cut_shape`` pixels plus a margin of
    ``cut_shape / 50`` on each side is extracted. Windows that are empty,
    contain NaN, or contain an exact 0 in any channel are skipped. Each
    remaining channel is median-filtered (3 x 3) and passed through ``conv``;
    the window is then cropped, normalized and resized by ``norm_res``.

    Parameters
    ----------
    data_ : numpy.ndarray
        Image cube indexed as [y, x, channel].
    many_ind : iterable
        Window origins; element [0] is y and element [1] is x.
    cut_shape : int or float
        Nominal window edge length in pixels.
    r_hdu_header, g_hdu_header : mapping
        Only ``['CDELT1']`` is read; it is multiplied by 3600 and handed to
        ``norm_res``.
    sig_r, sig_g : float, optional
        ``obj_sig`` given to ``conv`` for channel 0 / channel 1. When None,
        the module-level ``sig1_r`` / ``sig1_g`` are used, as before.

    Returns
    -------
    (data_list, position_list_)
        ``data_list`` holds the processed windows; ``position_list_`` holds
        the matching ``[y, x]`` integer positions.

    Notes
    -----
    Only channels 0 and 1 are kept; further channels are checked for zeros
    but not added to the output.
    """
    data_list = []
    position_list_ = []

    margin = cut_shape / 50
    # NOTE(review): the crop bounds are derived from cut_shape, while the
    # extracted window is about cut_shape * 52 / 50 wide, so the cropped
    # region is roughly cut_shape * 50 / 52 pixels rather than cut_shape.
    # Left unchanged here; confirm against the training preprocessing.
    crop_lo = int(cut_shape / 52)
    crop_hi = int(cut_shape * 51 / 52)

    for i in many_ind:
        x_min = i[1] - margin
        x_max = i[1] + cut_shape + margin
        y_min = i[0] - margin
        y_max = i[0] + cut_shape + margin
        data_c = data_[int(y_min):int(y_max), int(x_min):int(x_max)]

        # An origin outside the image gives an empty slice (np.max used to
        # raise on it); windows containing NaN are not usable either.
        if data_c.size == 0 or np.isnan(data_c).any():
            continue

        dim_data = []
        usable = True
        for dim in range(data_c.shape[2]):
            plane = data_c[:, :, dim]
            # An exact 0 anywhere in a channel disqualifies the whole window.
            if np.count_nonzero(plane) != plane.size:
                usable = False
                break

            # median_filter returns a new array, so no defensive copy of the
            # whole cube is needed.
            plane = median_filter(plane, size=3)  # noise suppression
            # Pre-smooth so that shrinking a large window to 300 px later
            # does not alias.
            if dim == 0:
                sigma = sig1_r if sig_r is None else sig_r
                dim_data.append(conv(300, sigma, plane)[:, :, None])
            elif dim == 1:
                sigma = sig1_g if sig_g is None else sig_g
                dim_data.append(conv(300, sigma, plane)[:, :, None])

        if not usable:
            continue

        d = np.concatenate(dim_data, axis=2)
        d = d[crop_lo:crop_hi, crop_lo:crop_hi]
        d = norm_res(d, r_hdu_header['CDELT1']*3600, g_hdu_header['CDELT1']*3600)
        data_list.append(d)
        position_list_.append([int(y_min)+int(cut_shape/50), int(x_min)+int(cut_shape/50)])

    return data_list, position_list_

In [8]:
# --- Inputs ---
fwhm_arcsec = 0.269  # FWHM of F770W [arcsec]
pixel_scale = 0.11   # MIRI pixel scale [arcsec / pixel]

# --- Derived values ---
# FWHM expressed in pixels.
fwhm_pixels = fwhm_arcsec / pixel_scale
# Gaussian sigma from FWHM: sigma = FWHM / (2 * sqrt(2 ln 2)).
obj_sig = fwhm_pixels / (2 * (2 * np.log(2)) ** 0.5)

print(f"FWHM in pixels: {fwhm_pixels:.4f}")
print(f"Calculated obj_sig: {obj_sig:.4f}")

FWHM in pixels: 2.4455
Calculated obj_sig: 1.0385


In [9]:
# Gaussian sigmas (pixels, 0.11 arcsec / pixel) read by cut_data() and passed
# to conv() as obj_sig: sig1_r for channel 0, sig1_g for channel 1.
# The FWHM values 0.3 / 0.7 arcsec are used instead of 0.269 / 0.674.
sig1_g = (0.3 / 0.11) / (2 * (2 * np.log(2)) ** 0.5)
sig1_r = (0.7 / 0.11) / (2 * (2 * np.log(2)) ** 0.5)

In [10]:
# Display the sigma that cut_data() uses for channel 0.
sig1_r

2.7023875463709697

In [11]:
# api = wandb.Api()
# artifact = api.artifact("galactic_bubble/search_BestModel_SpitzerBubblePaper2/training_log:v20")
# artifact.download("")

In [12]:
# Build the SSD network and restore the trained weights from the checkpoint.
net_w = SSD()
# map_location='cpu' lets a checkpoint saved on a GPU be loaded on a machine
# without CUDA; the model is moved to the selected device in the next cell.
net_weights = torch.load(
    'training_log:v20/earlystopping.pth', map_location='cpu')
net_w.load_state_dict(net_weights['model_state_dict'])
# The checkpoint dict is no longer needed once the weights are copied.
del net_weights

In [13]:
# Enable cuDNN's benchmark mode.
torch.backends.cudnn.benchmark = True
# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net_w.to(device)

SSD(
  (vgg): ModuleList(
    (0): Conv2d(2, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU()
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU()
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU()
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU()
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU()
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU()
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU()
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=True)
    (17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(

#### 切り出すインデックスを計算

In [14]:
# Window edge lengths (pixels) scanned by the inference loop, and for each
# size the number of np.linspace split points used to batch the window
# origins (bb points give bb - 1 batches). The two lists are paired by index.
# size_list = [25, 50, 75, 100, 125, 150, 200, 250, 300, 400, 500, 600, 700, 800, 900, 950]
size_list = [25, 50, 75, 100, 150, 300, 600]
# size_list = [600]
batch_list = [3000, 1000, 500, 100, 30, 30, 30]
# batch_list = [30]
# Sanity check: both lengths must match.
len(size_list), len(batch_list)

(7, 7)

In [15]:
# Post-processing object; the inference loop calls it as detect(*output) on
# the network output.
detect = Detect(nms_thresh=0.45, top_k=100)

In [16]:
# Galaxy names; the selected name is inserted into the FITS directory path
# and into the result/<name>/ output directory.
galaxies=["ic5332", "ngc0628", "ngc1087", "ngc1300", "ngc1365", "ngc1385", "ngc1433", "ngc1512",
          "ngc1566", "ngc1672", "ngc2835", "ngc3351", "ngc3627", "ngc4254", "ngc4303", "ngc4321", 
          "ngc4535", "ngc5068", "ngc7496"]

In [17]:
# Galaxy processed by the cells below (only the first list entry here).
gal_name = galaxies[0]
gal_name

'ic5332'

In [24]:
# NOTE(review): g_path is assigned in the FITS-loading cell that follows.
# This cell was executed out of order (In [24]) and raises NameError on a
# fresh kernel with "Run All"; move it below that cell or delete it.
g_path

['/home/elmegreen/jupyter/research/Bubble_detection/Paper2/phangs_jwst/Detection_Model/SN_Model_New_Norm/analyse_JWST/noise/smooth_fits/ic5332/hlsp_phangs-jwst_jwst_miri_ic5332_f770w_v1p1_img_subtract_0.1.smooth0.3arcsec.fits']

In [18]:
# Locate the smoothed FITS mosaics of the selected galaxy.
fits_dir = f'/home/elmegreen/jupyter/research/Bubble_detection/Paper2/phangs_jwst/Detection_Model/SN_Model_New_Norm/analyse_JWST/noise/smooth_fits/{gal_name}'
# sorted(): glob order is arbitrary, so make the choice of [0] reproducible.
r_path = sorted(glob.glob(f'{fits_dir}/*f2100w*0.25.smooth0.7arcsec.fits'))
g_path = sorted(glob.glob(f'{fits_dir}/*f770w*0.1.smooth0.3arcsec.fits'))

# Fail with a clear message. The previous version only printed a "skipping"
# notice and then crashed with IndexError on r_path[0] below.
if not r_path or not g_path:
    raise FileNotFoundError(f"{gal_name} のFITSファイルが見つかりません: {fits_dir}")

# Read the primary HDU of each band. r_hdu / g_hdu are kept because the
# inference cell reads their headers.
r_hdu = astropy.io.fits.open(r_path[0])[0]
g_hdu = astropy.io.fits.open(g_path[0])[0]

# Exact zeros are replaced by NaN so that cut_data() skips windows touching
# them. (Percentile clipping and remove_nan() gap filling are disabled.)
r_hdu_data = r_hdu.data
r_hdu_data[r_hdu_data == 0.0] = np.nan

g_hdu_data = g_hdu.data
g_hdu_data[g_hdu_data == 0.0] = np.nan

# Stack the two bands as channels: channel 0 = f2100w, channel 1 = f770w.
data_ = np.concatenate([r_hdu_data[:, :, None],
                        g_hdu_data[:, :, None]], axis=2)

In [19]:
start_time = time.time()
print(f"=============== 開始: {gal_name} の処理 ===============")
os.makedirs(f'result/{gal_name}/', exist_ok=True)

# Inference only: switch the network to eval mode once, not on every batch.
net_w.eval()

# One sliding-window pass over the mosaic per window size.
for size_index, size in enumerate(size_list):
    bb = batch_list[size_index]
    print(f"\n  -> 現在の処理サイズ: {size}x{size}, バッチ分割数: {bb}")

    # ----------------- window origins ----------------- #
    cut_shape = (size, size)
    fragment = 3  # the stride is the window size divided by this value
    slide_pix = (int(round(cut_shape[0] / fragment)), int(round(cut_shape[1] / fragment)))
    shape = data_.shape
    x_num = int(shape[1] / slide_pix[1]) - 1
    y_num = int(shape[0] / slide_pix[0]) - 1
    x_idx = np.arange(cut_shape[1] / 5, slide_pix[1] * x_num, slide_pix[1])
    y_idx = np.arange(cut_shape[0] / 5, slide_pix[0] * y_num, slide_pix[0])
    x_ind, y_ind = np.meshgrid(x_idx, y_idx)

    # (N, 2) array of [y, x] origins in row-major grid order.
    ind = np.stack([y_ind.ravel(), x_ind.ravel()], axis=1)
    # -------------------------------------------------- #

    # ----------------- inference ----------------- #
    result, position = [], []
    result_filename = f'result/{gal_name}/result_ring_select_csize{cut_shape[0]}.npy'
    # Start from a clean file: NpyAppendArray appends to an existing one.
    if os.path.exists(result_filename):
        os.remove(result_filename)
    # bb split points over the origin list -> bb - 1 batches.
    batch = np.linspace(0, ind.shape[0], bb, dtype=int)

    # Progress bar over the batches.
    pbar = tqdm(range(len(batch) - 1), desc=f"  [推論中] Galaxy: {gal_name}, Size: {size}")
    for i in pbar:
        # Cut out and preprocess the windows of this batch.
        cut_ind = ind[batch[i]:batch[i+1]]
        data_list, p_list = cut_data(data_, cut_ind, cut_shape[0], r_hdu.header, g_hdu.header)
        if len(data_list) == 0:
            continue

        # (N, y, x, C) -> (N, C, y, x) for the network.
        p_data = torch.from_numpy(np.array(data_list).astype(np.float32))
        pp_data = p_data.permute(0, 3, 1, 2)

        with torch.no_grad():
            pp_data = pp_data.to(device)
            output, decoded_box = net_w(pp_data)  # decoded_box is not used
            detections = detect(*output)
        position.append(p_list)

        # Store the detections. Size 25 is streamed to disk batch by batch;
        # the other sizes are collected in memory and saved once.
        detections_cpu = detections.to('cpu').detach().numpy().copy()
        if size == 25:
            with NpyAppendArray(result_filename) as npaa:
                npaa.append(detections_cpu)
        else:
            result.append(detections_cpu)

    if not position:
        print(f"    -> サイズ {size} では有効なデータがなかったため、結果ファイルは作成されませんでした。")
        continue

    position = np.concatenate(position)
    np.save(f'result/{gal_name}/position_ring_select_csize{cut_shape[0]}.npy', position)

    if size != 25 and result:
        result = np.concatenate(result)
        np.save(result_filename, result)
    # --------------------------------------------- #

end_time = time.time()
total_time_minutes = (end_time - start_time) / 60
print(f"\n=============== 完了: {gal_name} の処理 ===============")
print(f"処理時間: {total_time_minutes:.2f} 分")
print("=" * 50 + "\n")


  -> 現在の処理サイズ: 25x25, バッチ分割数: 3000


  [推論中] Galaxy: ic5332, Size: 25: 100%|█████████████████████████████████████████████████████████████| 2999/2999 [01:57<00:00, 25.44it/s]/s]



  -> 現在の処理サイズ: 50x50, バッチ分割数: 1000


  [推論中] Galaxy: ic5332, Size: 50: 100%|███████████████████████████████████████████████████████████████| 999/999 [00:25<00:00, 39.08it/s]/s]



  -> 現在の処理サイズ: 75x75, バッチ分割数: 500


  [推論中] Galaxy: ic5332, Size: 75: 100%|███████████████████████████████████████████████████████████████| 499/499 [00:12<00:00, 39.71it/s]/s]



  -> 現在の処理サイズ: 100x100, バッチ分割数: 100


  [推論中] Galaxy: ic5332, Size: 100: 100%|████████████████████████████████████████████████████████████████| 99/99 [00:13<00:00,  7.16it/s]/s]



  -> 現在の処理サイズ: 150x150, バッチ分割数: 30


  [推論中] Galaxy: ic5332, Size: 150: 100%|████████████████████████████████████████████████████████████████| 29/29 [00:08<00:00,  3.24it/s]/s]



  -> 現在の処理サイズ: 300x300, バッチ分割数: 30


  [推論中] Galaxy: ic5332, Size: 300: 100%|████████████████████████████████████████████████████████████████| 29/29 [00:01<00:00, 17.42it/s]/s]



  -> 現在の処理サイズ: 600x600, バッチ分割数: 30


  [推論中] Galaxy: ic5332, Size: 600: 100%|████████████████████████████████████████████████████████████████| 29/29 [00:00<00:00, 35.92it/s]/s]


処理時間: 3.02 分






In [20]:
# Scratch check of the conv() formulas for a 600-pixel window shrunk to 300.
# NOTE(review): the last line rebinds the global sig1_r that cut_data() reads
# (0.674 arcsec here, 0.7 arcsec for the run above). Re-running the inference
# cell after this one would use the new value; consider a different name.
fwhm = (600 / 300) * 2
sig3 = fwhm / (2 * (2 * np.log(2)) ** (1 / 2))
sig1_r = (0.674/0.11)/(2*(2*np.log(2))**(1/2))

In [21]:
# conv() smooths only when sig3 > obj_sig; compare the two values.
sig3, sig1_r

(1.6986436005760381, 2.6020131517914766)

In [22]:
# Same expression as sig2 in conv(), evaluated with obj_sig from the FWHM
# cell above (0.269 arcsec at 0.11 arcsec / pixel).
(sig3**2 - obj_sig**2) ** (1 / 2)

1.344221271625419