# 学習データを作成する
#### このノートブックでは、spitzer bubbleを切り出し、augmentationする
#### bubbleは事前に選定したrank4~5のものを使用
#### augmentationの仕方は、切り出すサイズを変え、位置をずらす

In [1]:
import astropy.io.fits
import astroquery.vizier
import astropy.wcs
from astropy.coordinates import SkyCoord

import numpy as np
import pandas as pd
from scipy import signal

import matplotlib.pyplot as plt
import matplotlib.patches as patches
from PIL import Image

import time
import pathlib
import random
import collections
import copy

import torch
from torch.nn import functional as F
from torch import nn

In [3]:
# The bubbles were screened by hand beforehand: bubbles inside the Cygnus region and bubbles
# that could not be cut out properly (e.g. because FITS data were missing) were excluded.
nishimoto = (
    pd.read_csv('new_選定.csv')
    .drop(columns='Unnamed: 0')
    .fillna(0)
)
nishimoto.head()

Unnamed: 0,1,2,3,4,5
0,0.0,0.0,0.0,1.0,0.0
1,0.0,0.0,0.0,1.0,0.0
2,0.0,0.0,1.0,0.0,0.0
3,0.0,1.0,0.0,0.0,0.0
4,0.0,1.0,0.0,0.0,0.0


In [4]:
# Sort the row positions of `nishimoto` into one list per rank.
# A row is assigned to the rank whose column holds the value 1
# (first column -> rank1, ..., fifth column -> rank5).
rank1 = []
rank2 = []
rank3 = []
rank4 = []
rank5 = []
rank_lists = [rank1, rank2, rank3, rank4, rank5]
for i in range(len(nishimoto)):
    nishimoto_s = nishimoto.loc[i]
    Q = np.where(np.array(nishimoto_s.tolist()) == 1)[0]
    if Q.size == 0:
        # No column is flagged for this row, so it belongs to no rank list.
        # (Comparing an empty array with a scalar inside `if` is an error on recent NumPy.)
        continue
    if Q.size > 1:
        # Several flagged columns would make the rank ambiguous; fail with a clear message.
        raise ValueError('row %d has more than one rank flagged: columns %s' % (i, Q.tolist()))
    if Q[0] < len(rank_lists):
        rank_lists[Q[0]].append(i)

In [5]:
# Catalogue that was used when making the selection
catalogue = pd.read_csv('all_ring_mwp_pytorch_remove_cygnus_catalogue.csv')
catalogue.head()

Unnamed: 0.1,Unnamed: 0,GLON,GLAT,Disp,MajAxis,MinAxis,Reff,e_Reff,theta,e_theta,...,HR3,RelFlag,HierFlag,IDDR1,IDA14,Dist,IDCW,_RA.icrs,_DE.icrs,id
0,2G0020120-0068213,2.012,-0.6821,0.12,0.53,0.47,0.5,0.14,26,84,...,0.23,R,,1G002011-006818,G002.009-00.680,,CN28,268.2442,-27.5619,0
1,2G0021660+0000943,2.166,0.0094,0.03,0.2,0.18,0.19,0.07,169,30,...,0.18,C,,,G002.164+00.010,,,267.6633,-27.0767,1
2,2G0022683+0024069,2.2683,0.2407,0.77,2.79,2.1,2.47,0.45,50,91,...,0.28,C,,1G002270+002402,G002.272+00.237,,,267.4993,-26.8702,2
3,2G0023153+0024455,2.3153,0.2446,0.23,0.82,0.58,0.71,0.17,28,103,...,0.12,R,,,,,,267.5226,-26.8279,3
4,2G0023189-0014641,2.3189,-0.1464,0.2,0.91,0.81,0.86,0.12,43,57,...,0.14,C,,1G002319-001476,G002.317-00.149,,,267.9014,-27.025,4


In [6]:
# Keep only the catalogue rows that were ranked 4 or 5 (rank-4 rows first, then rank-5 rows)
catalogue_ = catalogue.iloc[rank4 + rank5]

In [7]:
# The CSV column 'Unnamed: 0' holds the bubble names; expose it as 'MWP'
catalogue_ = catalogue_.rename(columns={'Unnamed: 0':'MWP'})
# MWP is the selected catalogue indexed by bubble name
MWP = catalogue_.set_index('MWP')
MWP.head()

Unnamed: 0_level_0,GLON,GLAT,Disp,MajAxis,MinAxis,Reff,e_Reff,theta,e_theta,Ecc,...,HR3,RelFlag,HierFlag,IDDR1,IDA14,Dist,IDCW,_RA.icrs,_DE.icrs,id
MWP,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
2G0020120-0068213,2.012,-0.6821,0.12,0.53,0.47,0.5,0.14,26,84,0.44,...,0.23,R,,1G002011-006818,G002.009-00.680,,CN28,268.2442,-27.5619,0
2G0021660+0000943,2.166,0.0094,0.03,0.2,0.18,0.19,0.07,169,30,0.48,...,0.18,C,,,G002.164+00.010,,,267.6633,-27.0767,1
2G0024521+0013755,2.4522,0.1376,0.14,0.6,0.46,0.53,0.1,27,117,0.65,...,0.13,C,,1G002452+001373,G002.451+00.136,,,267.7041,-26.7652,8
2G0032101-0010494,3.2101,-0.105,0.05,0.26,0.16,0.22,0.04,17,65,0.8,...,0.18,R,D,,,,CN37,268.369,-26.2368,20
2G0035672+0005046,3.5672,0.0505,0.12,0.44,0.37,0.4,0.07,9,100,0.53,...,0.19,C,,1G003567+000524,G003.570+00.050,,,268.4215,-25.8501,26


In [8]:
# Names of the FITS tile directories used for training: one tile every 3 degrees of
# Galactic longitude, from 3 to 60 degrees and from 297 to 357 degrees.
tile_longitudes = list(range(3, 61, 3)) + list(range(297, 358, 3))
l = sorted('spitzer_%03d00+0000_rgb' % lon for lon in tile_longitudes)
# 'spitzer_29400+0000_rgb' is excluded because it has almost no 8 um data.
# 'spitzer_06300+0000_rgb' is used as validation data.
# 'spitzer_00000+0000_rgb' is excluded: no rings have been identified in it.

In [9]:
# bubbleをタイル状に並べるための関数
from PIL import Image, ImageDraw, ImageFont
def data_view_rectangl(col, imgs, infos=None, moji_size=100):
    '''
    Tile a batch of images into one sheet, optionally drawing ids and bounding boxes.

    col: number of columns
    imgs: nparray with a shape of (?, y, x, 1) or (?, y, x, 3); values are cast to uint8
    infos: optional table with the columns 'id', 'xmin', 'xmax', 'ymin', 'ymax';
           row i belongs to imgs[i] and every coordinate entry is a list of box edges
           expressed as fractions (0-1) of the image size
    moji_size: font size of the id label

    Returns the tiled sheet as an RGB PIL image.
    '''
    # Flip every image vertically; a single-channel batch also loses its channel axis.
    imgs = np.uint8(imgs[:, ::-1, :, 0]) if imgs.shape[3] == 1 else np.uint8(imgs[:, ::-1])
    n_imgs, img_h, img_w = imgs.shape[:3]
    row = -(-n_imgs // col)  # ceiling division
    # PIL sizes and positions are (width, height), i.e. (x, y) -- not the array's (y, x).
    dst = Image.new('RGB', (img_w*col, img_h*row))
    ## The font file location depends on the machine; fall back to PIL's built-in
    ## font when this path cannot be opened.
    try:
        font = ImageFont.truetype('/usr/share/fonts/truetype/freefont/FreeMono.ttf', moji_size)
    except OSError:
        font = ImageFont.load_default()

    if infos is not None:
        # Convert the columns once instead of once per image.
        ids = infos['id'].tolist()
        xmins = infos['xmin'].tolist()
        xmaxs = infos['xmax'].tolist()
        ymins = infos['ymin'].tolist()
        ymaxs = infos['ymax'].tolist()

    for i, arr in enumerate(imgs):
        img = Image.fromarray(arr)
        img = img.point(lambda x: x * 1.5)  # brighten for display
        if infos is not None:
            draw = ImageDraw.Draw(img)
            draw.text((10, 10), '%s' % ids[i], font=font)
            for j in range(len(xmins[i])):
                # The image was flipped vertically above, so y is mirrored (1 - y).
                draw.rectangle((xmins[i][j]*img_w, (1 - ymaxs[i][j])*img_h,
                                xmaxs[i][j]*img_w, (1 - ymins[i][j])*img_h), width=2)

        quo, rem = divmod(i, col)
        dst.paste(img, (img_w*rem, img_h*quo))

    return dst

In [10]:
## Normalise a single channel
def norm(data):
    """
    Shift one channel so its minimum is 0, clip bright values and rescale by the maximum.

    The array is modified in place and is also returned.
    A constant channel is returned as all zeros instead of being divided by zero.
    """
    min_ = np.min(data)
    b = np.std(data)
    mean = np.mean(data)
    data -= min_
    # NOTE(review): the clip level (mean + 3*std) is computed before the minimum is
    # subtracted but applied to the shifted data -- confirm this is intended.
    max_ = b*3 + mean
    data[data>max_] = max_

    peak = np.max(data)
    if peak != 0:
        data /= peak
    return data

In [11]:
# Works together with norm():
# every channel is normalised on its own and the results are stacked again
def normalize(array):
    """
    Normalise each channel of an image independently with norm().

    input : (y, x, 2 or 3)
    output: (y, x, 2 or 3)
    """
    channels = [norm(array[:, :, k]) for k in range(array.shape[2])]
    return np.stack(channels, axis=2)

In [12]:
# Resize an image
def resize(data, size):
    """
    Resize a channel-last image with bilinear interpolation.

    data : array of shape (y, x, channels)
    size : an int for a square output, or a (height, width) tuple
    returns an array of shape (height, width, channels)
    """
    if np.ndim(size) == 0:
        out_size = (int(size), int(size))
    else:
        out_size = tuple(int(s) for s in size)

    # F.interpolate expects (batch, channels, y, x)
    chw = np.ascontiguousarray(np.transpose(data, (2, 0, 1)))
    batch = torch.from_numpy(chw).unsqueeze(0)
    resized = F.interpolate(batch, out_size, mode='bilinear', align_corners=False)
    # Drop only the batch axis: squeezing every length-1 axis would also remove a
    # single channel or a length-1 output side.
    resized = resized[0].detach().numpy()
    return np.transpose(resized, (1, 2, 0))

In [13]:
# Collect the bounding box of every bubble listed for one FITS tile
def all_star(dataframe, world):
    """
    Build a dict mapping each row's 'MWP' name to its pixel bounding box.

    The extent of the ring label is decided here: the box reaches 1.5 * Reff / 60
    on both sides of (GLON, GLAT) and is converted with world.all_world2pix.
    Each value is [x_pix_min, y_pix_min, x_pix_max, y_pix_max].
    """
    boxes = {}
    for _, entry in dataframe.iterrows():
        # Half-size of the box in catalogue coordinates
        # (presumably Reff is in arcmin and /60 converts to degrees -- TODO confirm)
        half = 1.5*entry['Reff']/60
        # The larger-longitude / smaller-latitude corner gives the minimum pixel corner
        x_lo, y_lo = world.all_world2pix(entry['GLON'] + half, entry['GLAT'] - half, 0)
        # The opposite corner gives the maximum pixel corner
        x_hi, y_hi = world.all_world2pix(entry['GLON'] - half, entry['GLAT'] + half, 0)
        boxes[entry['MWP']] = [x_lo, y_lo, x_hi, y_hi]
    return boxes

In [14]:
def calc_pix(row, world):
    """
    Decide the pixel range of the image to cut out around one bubble.

    A random scale factor and a random shift are drawn, so repeated calls give
    differently sized and positioned cut-outs (this is the augmentation step).
    Returns (x_pix_min, y_pix_min, x_pix_max, y_pix_max, width, height).
    """
    # Random magnification of the box relative to the 1.5 * Reff label half-size
    scale = 1/np.random.uniform(0.052, 0.95)
    half = scale*1.5*row['Reff']/60

    # Pixel corners of the scaled box around the bubble
    x_lo, y_lo = world.all_world2pix(row['GLON'] + half, row['GLAT'] - half, 0)
    x_hi, y_hi = world.all_world2pix(row['GLON'] - half, row['GLAT'] + half, 0)
    ring_r = int((x_hi - x_lo)/(2*scale))  # ring radius in pixels

    span_x = x_hi - x_lo
    span_y = y_hi - y_lo

    # Random shift, limited by the scale factor so the ring is not pushed out of the image
    max_shift = (scale-1.5)*ring_r
    dx = int(random.uniform(-max_shift, max_shift))
    dy = int(random.uniform(-max_shift, max_shift))

    # Enlarge the box by half its size on every side, then apply the shift
    x_pix_min = x_lo - span_x/2 + dx
    x_pix_max = x_hi + span_x/2 + dx
    y_pix_min = y_lo - span_y/2 + dy
    y_pix_max = y_hi + span_y/2 + dy

    width = x_pix_max - x_pix_min
    height = y_pix_max - y_pix_min

    return  x_pix_min, y_pix_min, x_pix_max, y_pix_max, width, height

In [15]:
def judge_01(number):
    """Clamp number to the range 0-1: values above 1 become 1, values below 0 become 0."""
    if number < 0:
        return 0
    if number > 1:
        return 1
    return number

In [16]:
def make_label(x_pix_min, y_pix_min,x_pix_max, y_pix_max, cover_star_position, cover_star_name,  width, hight, MWP):
    """
    Turn the pixel boxes of the rings found in a cut-out into normalised labels.

    x_pix_min, y_pix_min : lower corner of the cut-out in tile pixels
    x_pix_max, y_pix_max : upper corner of the cut-out (not used by the computation)
    cover_star_position  : (name, [xmin, ymin, xmax, ymax]) entries of the rings in the cut-out
    cover_star_name      : names matching cover_star_position one-to-one
    width, hight         : size of the cut-out in pixels
    MWP                  : table indexed by ring name; rings missing from the index are skipped

    Coordinates are measured from the central region of the cut-out (a quarter in
    from the lower corner, half the cut-out in size) and clamped to 0-1.
    Returns (xmin_list, ymin_list, xmax_list, ymax_list, named_list).
    """
    xmin_list = []
    ymin_list = []
    xmax_list = []
    ymax_list = []
    named_list = []
    # A set gives constant-time membership tests instead of scanning a list per ring
    selected_names = set(MWP.index)
    origin_x = x_pix_min+width/4
    origin_y = y_pix_min+hight/4

    for p, n in zip(cover_star_position, cover_star_name):
        # p looks like ('2G0020120-0068213', [xmin, ymin, xmax, ymax]): ring name and position
        if p[0] not in selected_names:
            continue
        xmin_list.append(judge_01((p[1][0] - origin_x)/(width/2)))
        xmax_list.append(judge_01((p[1][2] - origin_x)/(width/2)))
        ymin_list.append(judge_01((p[1][1] - origin_y)/(hight/2)))
        ymax_list.append(judge_01((p[1][3] - origin_y)/(hight/2)))
        named_list.append(n)

    return xmin_list, ymin_list, xmax_list, ymax_list, named_list

In [17]:
def find_cover(star_list, x_pix_min, y_pix_min, x_pix_max, y_pix_max):
    """
    Find the rings that fall inside the central region of a cut-out so they can be labelled.

    star_list is a dict whose values are ordered x_pix_min, y_pix_min, x_pix_max, y_pix_max.
    A ring is kept when at least half of its box area lies inside the central region
    and its box is at least 1/20 of that region's size in both directions.
    Returns (list of (name, box) items, list of names).
    """
    quarter_w = (x_pix_max - x_pix_min)/4
    quarter_h = (y_pix_max - y_pix_min)/4
    # Central region: the cut-out with a quarter trimmed off every side
    inner_x_lo, inner_x_hi = x_pix_min+quarter_w, x_pix_max-quarter_w
    inner_y_lo, inner_y_hi = y_pix_min+quarter_h, y_pix_max-quarter_h

    hits = []
    hit_names = []
    for item in star_list.items():
        name, box = item
        box_x = np.array([box[0], box[2]])
        box_y = np.array([box[1], box[3]])
        # Part of the box that survives clipping to the central region
        clipped_x = np.clip(box_x, inner_x_lo, inner_x_hi)
        clipped_y = np.clip(box_y, inner_y_lo, inner_y_hi)
        full_area = (box_x[1]-box_x[0])*(box_y[1]-box_y[0])
        kept_area = (clipped_x[1]-clipped_x[0])*(clipped_y[1]-clipped_y[0])

        mostly_inside = kept_area >= full_area/2
        wide_enough = (box[2]-box[0]) >= (quarter_w*2)/20
        tall_enough = (box[3]-box[1]) >= (quarter_h*2)/20
        if mostly_inside and wide_enough and tall_enough:
            hits.append(item)
            hit_names.append(name)

    return hits, hit_names

In [18]:
def conv(obj_size, obj_sig, data):
    """
    Smooth a cut-out with a Gaussian kernel before it is shrunk to obj_size.

    data    : array of shape (y, x, 2 or 3)
    returns : array with the same shape as data
    -------------------------------
    If the cut-out is larger than obj_size along its first axis it is smoothed;
    otherwise it is returned unchanged. It is also returned unchanged when the target
    sigma does not exceed obj_sig, because no real kernel width exists in that case.
    """
    if data.shape[0] <= obj_size:
        return data

    fwhm = (data.shape[0]/obj_size)*2
    sig3 = fwhm/(2*(2*np.log(2))**(1/2))
    if sig3 <= obj_sig:
        return data
    # Kernel sigma chosen so that obj_sig and sig2 add in quadrature to sig3
    sig2 = (sig3**2-obj_sig**2)**(1/2)

    # signal.gaussian was removed from recent SciPy; the window lives in signal.windows
    half_width = 4*int(round(sig2))
    window = signal.windows.gaussian(2*half_width+1, sig2)
    kernel = np.outer(window, window)
    kernel = kernel/np.sum(kernel)  # unit sum

    # Convolve every channel separately and stack the channels again
    conv_list = [signal.fftconvolve(data[:, :, k], kernel, mode='same')[:, :, None]
                 for k in range(data.shape[2])]
    return np.concatenate(conv_list, axis=2)


In [19]:
# Label table: one row per generated image
frame_mwp = pd.DataFrame(columns=['fits', 'name', 'xmin', 'xmax', 'ymin', 'ymax'])
# Generated images, appended in the same order as the rows of frame_mwp
mwp_ring_list = []
# Placeholder path -- point this at the directory that contains the FITS tile folders
file_path = pathlib.Path('fitsがあるpath')
# Value passed to conv() as obj_sig
sig1 = 1/(2*(np.log(2))**(1/2))

In [20]:
start = time.time()

# Label rows are collected here and appended to frame_mwp once after the loop;
# concatenating inside the loop copies the whole table again for every cut-out.
label_rows = []
for fits_path in l:
    # Open the three bands in a context manager so the file handles are released.
    # np.concatenate copies the pixels, so nothing refers to the files afterwards.
    with astropy.io.fits.open(file_path/fits_path/'r.fits') as r_hdul, \
            astropy.io.fits.open(file_path/fits_path/'g.fits') as g_hdul, \
            astropy.io.fits.open(file_path/fits_path/'b.fits') as b_hdul:
        # Stack the bands on a trailing axis to get an RGB cube
        data = np.concatenate([r_hdul[0].data[:,:,None],
                               g_hdul[0].data[:,:,None],
                               b_hdul[0].data[:,:,None]], axis=2)
        w = astropy.wcs.WCS(r_hdul[0].header)

    a = data.shape[0]
    b = data.shape[1]
    GLON_min, GLAT_min = w.all_pix2world(b, 0, 0)
    GLON_max, GLAT_max = w.all_pix2world(0, a, 0)

    # Only bubbles within +-1.5 of the tile's central longitude are taken from this tile
    GLON_center = (GLON_min+GLON_max)/2
    GLON_new_min = GLON_center-1.5
    GLON_new_max = GLON_center+1.5

    mwp = MWP.query('@GLON_new_min < GLON <= @GLON_new_max')
    mwp = mwp.reset_index()
    # star_dic maps bubble name -> pixel bounding box
    star_dic = all_star(mwp, w)
    print(fits_path)
    for index, row in mwp.iterrows():

        # Ten random cut-outs per bubble
        for _ in range(10):

            x_pix_min, y_pix_min, x_pix_max, y_pix_max, width, hight = calc_pix(row, w)

            # Skip cut-outs that leave the tile on any side. Without the upper-bound test
            # the slice below is silently truncated and the labels no longer match the image.
            if (x_pix_min < 0 or y_pix_min < 0
                    or int(x_pix_max) > b or int(y_pix_max) > a):
                continue

            cover_star_position, cover_star_name = find_cover(star_dic, x_pix_min, y_pix_min, x_pix_max, y_pix_max)

            # Copy the cut-out: normalize() rescales its input in place and must not touch the tile
            cut_data = data[int(y_pix_min):int(y_pix_max), int(x_pix_min):int(x_pix_max)].copy()
            if np.isnan(cut_data.sum()):
                continue

            pi = conv(300, sig1, cut_data)
            r_shape_y = pi.shape[0]
            r_shape_x = pi.shape[1]
            # Keep the central half of the smoothed cut-out, then normalise and resize to 300 px
            res_data = pi[int(r_shape_y/4):int(r_shape_y*3/4), int(r_shape_x/4):int(r_shape_x*3/4)]
            res_data = normalize(res_data)
            res_data = resize(res_data, 300)
            xmin_list, ymin_list, xmax_list, ymax_list, name_list = make_label(x_pix_min, y_pix_min, x_pix_max, y_pix_max,
                                                                               cover_star_position, cover_star_name,
                                                                               width, hight, MWP)
            mwp_ring_list.append(res_data)

            info = [[fits_path, name_list, xmin_list, xmax_list, ymin_list, ymax_list]]
            label_rows.append(pd.DataFrame(columns=['fits', 'name', 'xmin', 'xmax', 'ymin', 'ymax'], data=info))

if label_rows:
    frame_mwp = pd.concat([frame_mwp] + label_rows)

print(time.time()-start)

spitzer_00300+0000_rgb
spitzer_00600+0000_rgb
spitzer_00900+0000_rgb
spitzer_01200+0000_rgb
spitzer_01500+0000_rgb
spitzer_01800+0000_rgb
spitzer_02100+0000_rgb
spitzer_02400+0000_rgb
spitzer_02700+0000_rgb
spitzer_03000+0000_rgb
spitzer_03300+0000_rgb
spitzer_03600+0000_rgb
spitzer_03900+0000_rgb
spitzer_04200+0000_rgb
spitzer_04500+0000_rgb
spitzer_04800+0000_rgb
spitzer_05100+0000_rgb
spitzer_05400+0000_rgb
spitzer_05700+0000_rgb
spitzer_06000+0000_rgb
spitzer_29700+0000_rgb
spitzer_30000+0000_rgb
spitzer_30300+0000_rgb
spitzer_30600+0000_rgb
spitzer_30900+0000_rgb
spitzer_31200+0000_rgb
spitzer_31500+0000_rgb
spitzer_31800+0000_rgb
spitzer_32100+0000_rgb
spitzer_32400+0000_rgb
spitzer_32700+0000_rgb
spitzer_33000+0000_rgb
spitzer_33300+0000_rgb
spitzer_33600+0000_rgb
spitzer_33900+0000_rgb
spitzer_34200+0000_rgb
spitzer_34500+0000_rgb
spitzer_34800+0000_rgb
spitzer_35100+0000_rgb
spitzer_35400+0000_rgb
spitzer_35700+0000_rgb
823.8004236221313


In [21]:
# Number of cut-out images generated so far
len(mwp_ring_list)

10593

In [22]:
# Check that every image really contains at least one ring:
# print the position of each row whose box list is empty
for i, xmins in enumerate(frame_mwp['xmin']):
    if len(xmins) == 0:
        print(i)

In [23]:
# Number the rows consecutively; the id is also the position of the image in mwp_ring_list
frame_mwp['id'] = list(range(len(frame_mwp)))

In [24]:
# Arrange the rings as tiles.
# Removing the trailing .save(...) call displays the sheet without writing it to disk.
mwp_ring_list_ = np.uint8(np.array(mwp_ring_list)*255)

# Every 10th image, 20 per row
data_view_rectangl(20, mwp_ring_list_[::10], frame_mwp[::10]).save('example_hukusuu_many.pdf')

# 画像の中に複数リング
### 上のセルでは、一つ（切り出し方によっては複数）のbubbleを切り出したが、これより下は
### 意図的に複数のbubbleを切り出す。

In [25]:
# Cut out a randomly placed region
def random_cut(data, GLON_new_min1, GLON_new_max1, GLAT_new_min1, GLAT_new_max1, world=None):
    """
    Cut a randomly placed, randomly sized region out of a tile.

    data  : tile image, indexed as data[y, x]
    GLON_new_min1 .. GLAT_new_max1 : range from which the centre of the region is drawn
    world : object providing all_world2pix; defaults to the notebook-level WCS `w`

    Returns (cut_data, y_random_min, y_random_max, x_random_min, x_random_max);
    the pixel bounds describe the enlarged region that was actually cut.
    """
    if world is None:
        world = w  # WCS of the tile currently loaded by the notebook

    random_GLON = random.uniform(GLON_new_min1, GLON_new_max1)
    random_GLAT = random.uniform(GLAT_new_min1, GLAT_new_max1)
    random_Rout = random.uniform(1, 6)
    lmax_random = random_GLON + random_Rout/60
    bmin_random = random_GLAT - random_Rout/60
    # opposite corner
    lmin_random = random_GLON - random_Rout/60
    bmax_random = random_GLAT + random_Rout/60
    x_random_min, y_random_min = world.all_world2pix(lmax_random, bmin_random, 0)
    x_random_max, y_random_max = world.all_world2pix(lmin_random, bmax_random, 0)
    width = x_random_max - x_random_min
    height = y_random_max - y_random_min

    # Enlarge the box by half its own size on every side; the y margin must use the
    # height (not the width) so callers' quarter-margin arithmetic holds on both axes.
    x_random_min = x_random_min - width/2
    x_random_max = x_random_max + width/2
    y_random_min = y_random_min - height/2
    y_random_max = y_random_max + height/2

    cut_data = data[int(y_random_min):int(y_random_max), int(x_random_min):int(x_random_max)]

    return cut_data, y_random_min, y_random_max, x_random_min, x_random_max

In [26]:
# Works together with random_cut():
# keeps drawing random regions until one contains two or more bubbles
def judge_two(data, star_dic, GLON_min, GLON_max, GLAT_min, GLAT_max, max_trials=1000):
    """
    Draw random cut-outs until one holds at least two selected bubbles.

    star_dic maps bubble name -> [x_pix_min, y_pix_min, x_pix_max, y_pix_max] for
    every ring of the tile. A bubble counts when its box lies completely inside the
    central region of the cut-out (a quarter trimmed off each side), is wider than 1/20
    of that region, is listed in MWP, and the cut-out contains no NaN.

    Returns (cut_data_random, new_star_list, y_random_min, y_random_max,
    x_random_min, x_random_max, flag, name_list). flag is 1 when no suitable
    cut-out was found within max_trials attempts, 0 otherwise.
    """
    if max_trials < 1:
        raise ValueError('max_trials must be at least 1')

    # Build the name set once instead of converting the index to a list for every bubble
    selected_names = set(MWP.index)
    flag = 1
    for _ in range(max_trials):
        new_star_list = []
        name_list = []
        cut_data_random, y_random_min, y_random_max, x_random_min, x_random_max = random_cut(data,
                                                                                         GLON_min, GLON_max,
                                                                                         GLAT_min, GLAT_max)
        width = x_random_max - x_random_min
        height = y_random_max - y_random_min
        inner_x_min = x_random_min+width/4
        inner_x_max = x_random_max-width/4
        inner_y_min = y_random_min+height/4
        inner_y_max = y_random_max-height/4
        min_size = (inner_x_max-inner_x_min)/20

        # A cut-out containing NaN never qualifies, whatever bubbles it holds
        if not np.isnan(np.sum(cut_data_random)):
            for k, i in star_dic.items():
                if (inner_x_min <= i[0] < i[2] <= inner_x_max
                        and inner_y_min <= i[1] < i[3] <= inner_y_max
                        and (i[2]-i[0]) > min_size
                        and k in selected_names):
                    new_star_list.append(i)
                    name_list.append(k)

        if len(new_star_list) >= 2:
            # Success, even on the very last attempt
            flag = 0
            break

    return cut_data_random, new_star_list, y_random_min, y_random_max, x_random_min, x_random_max, flag, name_list

In [27]:
# Build the labels
def make_label2(star_list, x_pix_min, y_pix_min, x_pix_max, y_pix_max):
    """
    Convert ring boxes in tile pixels into labels normalised to the final image.

    star_list holds one [x_pix_min, y_pix_min, x_pix_max, y_pix_max] entry per ring.
    x_pix_min .. y_pix_max are the bounds of the cut-out. Coordinates are measured
    from the central region of the cut-out (a quarter in from the lower corner, half
    the cut-out in size) and clamped to 0-1.
    Returns (xmin_list, ymin_list, xmax_list, ymax_list).
    """
    width = x_pix_max - x_pix_min
    hight = y_pix_max - y_pix_min
    origin_x = x_pix_min+width/4
    origin_y = y_pix_min+hight/4

    xmin_list = [judge_01((c[0] - origin_x)/(width/2)) for c in star_list]
    ymin_list = [judge_01((c[1] - origin_y)/(hight/2)) for c in star_list]
    xmax_list = [judge_01((c[2] - origin_x)/(width/2)) for c in star_list]
    ymax_list = [judge_01((c[3] - origin_y)/(hight/2)) for c in star_list]

    return xmin_list, ymin_list, xmax_list, ymax_list

In [28]:
start = time.time()
file_path = pathlib.Path('../../../../fits_data/remove_saturation_nan_fits//')
sig1 = 1/(2*(np.log(2))**(1/2))

# Label rows are collected here and appended to frame_mwp once after the loop
label_rows = []
for fits_path in l:
    # Open the three bands in a context manager so the file handles are released.
    # np.concatenate copies the pixels, so nothing refers to the files afterwards.
    with astropy.io.fits.open(file_path/fits_path/'r.fits') as r_hdul, \
            astropy.io.fits.open(file_path/fits_path/'g.fits') as g_hdul, \
            astropy.io.fits.open(file_path/fits_path/'b.fits') as b_hdul:
        # Stack the bands on a trailing axis to get an RGB cube
        data = np.concatenate([r_hdul[0].data[:,:,None],
                               g_hdul[0].data[:,:,None],
                               b_hdul[0].data[:,:,None]], axis=2)
        # random_cut() reads this notebook-level WCS
        w = astropy.wcs.WCS(r_hdul[0].header)

    a = data.shape[0]
    b = data.shape[1]
    GLON_min, GLAT_min = w.all_pix2world(b, 0, 0)
    GLON_max, GLAT_max = w.all_pix2world(0, a, 0)

    # Random centres are drawn within +-1.3 in longitude and +-1 in latitude of the tile centre
    GLON_center = (GLON_min+GLON_max)/2
    GLON_min = GLON_center-1.3
    GLON_max = GLON_center+1.3

    GLAT_center = (GLAT_min+GLAT_max)/2
    GLAT_min = GLAT_center-1
    GLAT_max = GLAT_center+1

    mwp = MWP.query('@GLON_min < GLON <= @GLON_max')
    mwp = mwp.reset_index()
    # star_dic maps bubble name -> pixel bounding box
    star_dic = all_star(mwp, w)
    print(fits_path)

    # Up to ten multi-bubble cut-outs per tile
    for _ in range(10):
        cut_data_random_, new_star_list, y_random_min, y_random_max, x_random_min, x_random_max, flag, name_list = judge_two(
            data, star_dic, GLON_min, GLON_max, GLAT_min, GLAT_max)
        if flag == 1:
            # No region with two or more bubbles was found; give up on this tile
            break

        # The cut-out is a view into the tile and normalize() rescales its input in place,
        # so work on a copy to keep the tile intact for the following cut-outs.
        pi = conv(300, sig1, cut_data_random_.copy())
        r_shape_y = pi.shape[0]
        r_shape_x = pi.shape[1]
        # Keep the central half of the smoothed cut-out, then normalise and resize to 300 px
        res_data = pi[int(r_shape_y/4):int(r_shape_y*3/4), int(r_shape_x/4):int(r_shape_x*3/4)]
        res_data = normalize(res_data)
        res_data = resize(res_data, 300)
        mwp_ring_list.append(res_data)

        xmin_list, ymin_list, xmax_list, ymax_list = make_label2(new_star_list,
                                                     x_random_min, y_random_min,
                                                     x_random_max, y_random_max)
        info = [[fits_path, name_list, xmin_list, xmax_list, ymin_list, ymax_list]]
        label_rows.append(pd.DataFrame(columns=['fits', 'name', 'xmin', 'xmax', 'ymin', 'ymax'], data=info))

if label_rows:
    frame_mwp = pd.concat([frame_mwp] + label_rows)

spitzer_00300+0000_rgb
spitzer_00600+0000_rgb
spitzer_00900+0000_rgb
spitzer_01200+0000_rgb
spitzer_01500+0000_rgb
spitzer_01800+0000_rgb
spitzer_02100+0000_rgb
spitzer_02400+0000_rgb
spitzer_02700+0000_rgb
spitzer_03000+0000_rgb
spitzer_03300+0000_rgb
spitzer_03600+0000_rgb
spitzer_03900+0000_rgb
spitzer_04200+0000_rgb
spitzer_04500+0000_rgb
spitzer_04800+0000_rgb
spitzer_05100+0000_rgb
spitzer_05400+0000_rgb
spitzer_05700+0000_rgb
spitzer_06000+0000_rgb
spitzer_29700+0000_rgb
spitzer_30000+0000_rgb
spitzer_30300+0000_rgb
spitzer_30600+0000_rgb
spitzer_30900+0000_rgb
spitzer_31200+0000_rgb
spitzer_31500+0000_rgb
spitzer_31800+0000_rgb
spitzer_32100+0000_rgb
spitzer_32400+0000_rgb
spitzer_32700+0000_rgb
spitzer_33000+0000_rgb
spitzer_33300+0000_rgb
spitzer_33600+0000_rgb
spitzer_33900+0000_rgb
spitzer_34200+0000_rgb
spitzer_34500+0000_rgb
spitzer_34800+0000_rgb
spitzer_35100+0000_rgb
spitzer_35400+0000_rgb
spitzer_35700+0000_rgb


In [29]:
# Renumber all rows consecutively now that the multi-bubble cut-outs have been added
frame_mwp['id'] = list(range(len(frame_mwp)))
frame_mwp

Unnamed: 0,fits,name,xmin,xmax,ymin,ymax,id
0,spitzer_00300+0000_rgb,[2G0020120-0068213],[0.1365964173721316],[0.23625005149378678],[0.5551416764976318],[0.6547953106192858],0
0,spitzer_00300+0000_rgb,[2G0020120-0068213],[0.2680902043345587],[0.5759350668725902],[0.11622007366783164],[0.42406493620585944],1
0,spitzer_00300+0000_rgb,[2G0020120-0068213],[0.20913421224965428],[0.7757558768455542],[0.17891439044007815],[0.7455360550359712],2
0,spitzer_00300+0000_rgb,[2G0020120-0068213],[0.0013299567596892655],[0.8233135459636813],[0.13284732971717245],[0.9548309189211464],3
0,spitzer_00300+0000_rgb,[2G0020120-0068213],[0.0362237721148361],[0.11609338729111267],[0.07988582797457988],[0.15975544315085544],4
...,...,...,...,...,...,...,...
0,spitzer_35700+0000_rgb,"[2G3563404-0009054, 2G3562769-0007904]","[0.2451006239665258, 0.5459186474259973]","[0.3225072113401268, 0.7007318221731992]","[0.45632328754134577, 0.47901142698633115]","[0.5337298749149418, 0.6338246017335217]",10905
0,spitzer_35700+0000_rgb,"[2G3579667-0016881, 2G3579868-0015789, 2G35800...","[0.6169360219919319, 0.4533268855395715, 0.303...","[0.6946688673583395, 0.5606722436887062, 0.491...","[0.7791167239108852, 0.8450045666568464, 0.750...","[0.8568495692769483, 0.9523499248059714, 0.939...",10906
0,spitzer_35700+0000_rgb,"[2G3558423-0048948, 2G3558804-0052267]","[0.4062549266650938, 0.027078547797243238]","[0.5988846380979796, 0.4630299925664748]","[0.7005793951317584, 0.3545218091713647]","[0.8932091065648169, 0.7904732539407884]",10907
0,spitzer_35700+0000_rgb,"[2G3558423-0048948, 2G3558804-0052267]","[0.40473551847200395, 0.11204213937229014]","[0.5534300160634377, 0.4485612597718746]","[0.3143074011994009, 0.04717904130453973]","[0.463001898790974, 0.3836981617042862]",10908


In [30]:
# Scale the images to 0-255 and store them as uint8 for display
mwp_ring_list_ = np.uint8(np.array(mwp_ring_list)*255)

# Every 10th image, 20 per row
data_view_rectangl(20, mwp_ring_list_[::10], frame_mwp[::10]).save('Augmentation_ring_many_.pdf')

In [31]:
# Shape of the uint8 image stack
mwp_ring_list_.shape

(10910, 300, 300, 3)

In [32]:
# Save the results
# The list of image arrays goes to a .npy file and the label table to a CSV file
np.save('std_many/augmentation_for_ssd.npy', mwp_ring_list)
frame_mwp.to_csv('std_many/augmentation_for_ssd_label.csv')

In [33]:
# Show the label row with id 490
frame_mwp[frame_mwp['id']==490]

Unnamed: 0,fits,name,xmin,xmax,ymin,ymax,id
0,spitzer_00900+0000_rgb,[2G0081525+0024103],[0],[0.813003927681147],[0.10319974867147803],[0.9165170333631977],490


In [34]:
# Show the label row with id 60
frame_mwp[frame_mwp['id']==60]

Unnamed: 0,fits,name,xmin,xmax,ymin,ymax,id
0,spitzer_00300+0000_rgb,[2G0037405+0002309],[0.22475151631875193],[0.4063290474737769],[0.1382860234576225],[0.3198635546126413],60


In [35]:
# Print the position of every row whose box list is empty (nothing is printed when all rows have a ring)
for i, xmins in enumerate(frame_mwp['xmin']):
    if len(xmins) == 0:
        print(i)

In [12]:
# mwp_ring_list_ = np.load('std_many/augmentation_for_ssd.npy')

In [13]:
# frame_mwp = pd.read_csv('std_many/augmentation_for_ssd_label.csv')

In [14]:
# import ast
# frame_mwp['xmin'] = [ast.literal_eval(d) for d in frame_mwp['xmin']]
# frame_mwp['xmax'] = [ast.literal_eval(d) for d in frame_mwp['xmax']]
# frame_mwp['ymin'] = [ast.literal_eval(d) for d in frame_mwp['ymin']]
# frame_mwp['ymax'] = [ast.literal_eval(d) for d in frame_mwp['ymax']]

In [15]:
# mwp_ring_list_ = mwp_ring_list_*255
# mwp_ring_list_ = np.uint8(mwp_ring_list_)

# data_view_rectangl(20, mwp_ring_list_[::10], frame_mwp[::10]).save('Augmentation_ring_many_noid.pdf')