### Here we will visualize the outputs of a Migdal skim

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
from migYOLO.utils.readYAML import read_config_file
from migYOLO.pipeline.pipeline import downsample

In [None]:
'''Open up globalConf'''
# Parse the pipeline configuration; later cells read its 'yoloConf' and
# 'downsample' sections.
conf = read_config_file('globalConf.yaml')

In [None]:
# Display the loaded configuration (shown because it is the cell's last expression).
conf

In [None]:
'''Load outputs'''
# Model name is the weights filename stripped of its directory and extension
# (expected to be 'base' or 'augment').
model = os.path.splitext(os.path.basename(conf['yoloConf']['model']))[0]
# Both directories end with '/' because load_data appends filenames directly.
yolodir = '/'.join([conf['yoloConf']['outpath'], model, 'without_pixel_hits', ''])
migdal_candidate_dir = yolodir + 'migdal_candidates/'

In [None]:
# List what the YOLO output directory contains (presumably the per-batch
# .feather files plus the migdal_candidates/ subdirectory -- confirm).
os.listdir(yolodir)

In [None]:
'''Convenience functions for loading all YOLO output data.
We add an "fnum" column giving the file number, so that original frames
can be linked back using 'fnum' and 'original_index'.'''

import re
'''By convention our filenames contain a number right after an underscore; extract that number.'''
def find_number_in_filename(filename):
    """Return the first run of digits that directly follows an underscore.

    Parameters
    ----------
    filename : str
        Name to search, e.g. ``'Images_batch_12.feather'``.

    Returns
    -------
    str or None
        The digits as a string (``'12'`` in the example above), or ``None``
        when the name contains no ``_<digits>`` sequence.
    """
    found = re.search(r'_(\d+)', filename)
    return found.group(1) if found else None

def load_data(migdal_candidates):
    """Load and concatenate every ``.feather`` output file in one directory.

    Parameters
    ----------
    migdal_candidates : bool
        If true, read from the module-level ``migdal_candidate_dir``;
        otherwise read from the module-level ``yolodir``.

    Returns
    -------
    pandas.DataFrame
        The files concatenated in sorted-filename order. Each file's rows get
        an ``fnum`` column holding the number parsed from its filename by
        ``find_number_in_filename`` (a string, or ``None`` if absent). The
        index is reset to ``0..N-1``.

    Raises
    ------
    ValueError
        If the directory contains no ``.feather`` files.
    """
    path = migdal_candidate_dir if migdal_candidates else yolodir
    # endswith (not substring) so names like 'x.feather.bak' are skipped.
    feather_files = [fi for fi in sorted(os.listdir(path)) if fi.endswith('.feather')]
    if not feather_files:
        # pd.concat([]) would otherwise fail with an unhelpful message.
        raise ValueError(f'No .feather files found in {path!r}')
    dfs = []
    for fi in feather_files:
        df = pd.read_feather(os.path.join(path, fi))
        df['fnum'] = find_number_in_filename(fi)
        dfs.append(df)
    # ignore_index discards each file's own index and numbers rows 0..N-1.
    return pd.concat(dfs, ignore_index=True)

In [None]:
'''Load all tracks'''

# All .feather outputs in yolodir, concatenated into one DataFrame.
df = load_data(migdal_candidates = False)

In [None]:
'''Load Migdal candidates'''

# Only the .feather outputs in migdal_candidate_dir, concatenated.
migs = load_data(migdal_candidates = True)

### Let's plot some images

In [None]:
'''Raw image directory'''
# Directory holding the Images_batch_<fnum>.MTIFF stacks that the plotting
# functions below open.
imagepath = conf['downsample']['data_dir']

In [None]:
# Inspect the available columns; the plotting functions below read fnum,
# original_index, colmin/colmax/rowmin/rowmax, prediction and centroidx/centroidy.
migs.columns

In [None]:
#Open up a single frame from an MTIFF
import tifffile
def quick_read(MTIFF_file, frame_index):
    """Read a single page from a multi-page TIFF stack.

    Parameters
    ----------
    MTIFF_file : str
        Path to the MTIFF file.
    frame_index : int
        Page to read; forwarded as ``key`` to ``TiffFile.asarray``.

    Returns
    -------
    The requested page as an array. The file is closed before returning.
    """
    with tifffile.TiffFile(MTIFF_file) as stack:
        return stack.asarray(key=frame_index)

In [None]:
def plot_event(df,i,process_image):
    """Show one frame from ``df`` with its bounding boxes drawn on top.

    Parameters
    ----------
    df : pandas.DataFrame
        Table produced by ``load_data``. Each row must provide ``fnum`` and
        ``original_index`` (used to read the frame from
        ``<imagepath>/Images_batch_<fnum>.MTIFF``) and array-like
        ``colmin``/``colmax``/``rowmin``/``rowmax``/``prediction`` entries
        with one element per box.
    i : int
        Positional row index (``df.iloc[i]``).
    process_image : bool
        If true, run the frame through ``downsample`` and display
        ``log10(image + 1)`` with negative pixels clipped to 0. If false,
        display the raw frame and multiply the box coordinates by 4.

    Notes
    -----
    ``df`` is never modified. The axis limits are set to the union of all
    boxes, padded by 5 pixels.
    """
    row = df.iloc[i]
    imfile_basename = 'Images_batch_'
    imname = imagepath + '/' + imfile_basename + str(row['fnum']) + '.MTIFF'
    im = quick_read(imname, row['original_index'])

    # Keep the box edges in local arrays instead of assigning into ``row``.
    # Writing into a Series taken from ``df.iloc`` triggers pandas'
    # SettingWithCopyWarning, and when ``iloc`` returns a view it would
    # rescale the caller's DataFrame on every call.
    edges = {col: np.asarray(row[col]) for col in ('colmin', 'colmax', 'rowmin', 'rowmax')}
    if process_image:
        a = downsample(im)
        im = a.processedImages
        im[im < 0] = 0
        plt.imshow(np.log10(im + 1), cmap='jet', vmin=1.4, vmax=4)
    else:
        # Boxes are scaled by 4 to overlay the raw frame (presumably the
        # downsampling factor -- TODO confirm against the downsample config).
        edges = {col: vals * 4 for col, vals in edges.items()}
        plt.imshow(im, cmap='jet')

    # Box colour keyed by the values in ``prediction``.
    colors = {0:'pink',1:'cyan',2:'red',3:'yellow',4:'goldenrod',5:'white',6:'green',7:'darkgreen',8:'white'}
    for cmin, cmax, rmin, rmax, pred in zip(edges['colmin'], edges['colmax'],
                                            edges['rowmin'], edges['rowmax'],
                                            row['prediction']):
        color = colors[pred]
        plt.hlines(rmin, cmin, cmax, color=color, lw=2)
        plt.hlines(rmax, cmin, cmax, color=color, lw=2)
        plt.vlines(cmin, rmin, rmax, color=color, lw=2)
        plt.vlines(cmax, rmin, rmax, color=color, lw=2)

    plt.xlim(edges['colmin'].min() - 5, edges['colmax'].max() + 5)
    # NOTE(review): imshow's default origin='upper' puts row 0 at the top;
    # passing (low, high) here flips the view so low rows are at the bottom.
    # Confirm this orientation is intended.
    plt.ylim(edges['rowmin'].min() - 5, edges['rowmax'].max() + 5)

In [None]:
# Draw the Migdal candidate at positional row 4 on the processed (downsampled) image.
plot_event(migs,4,process_image = True)

In [None]:
'''Plot the straight line joining the centroids of the first two bounding boxes'''

def plot_distance(df,i,process_image,BB=False):
    """Draw a straight line between the first two box centroids of row ``i``.

    Parameters
    ----------
    df : pandas.DataFrame
        Table produced by ``load_data``. Besides ``fnum``/``original_index``
        (used to locate the frame) and the array-like box columns
        ``colmin``/``colmax``/``rowmin``/``rowmax``/``prediction``, each row
        needs array-like ``centroidx`` and ``centroidy`` with at least two
        entries.
    i : int
        Positional row index (``df.iloc[i]``).
    process_image : bool
        If true, display ``log10(image + 1)`` of the ``downsample``-processed
        frame. Otherwise, display the raw frame and scale every coordinate
        by 4.
    BB : bool, optional
        Also draw the bounding boxes (default False).

    Raises
    ------
    ValueError
        If the row has fewer than two centroids.

    Notes
    -----
    ``df`` is never modified. The axis limits always follow the union of the
    boxes padded by 5 pixels, even when ``BB`` is false.
    """
    row = df.iloc[i]
    # Local arrays so the caller's DataFrame is never written to. Assigning
    # into a Series from ``df.iloc`` raises SettingWithCopyWarning and can
    # modify ``df`` when the row is a view.
    cols = ('colmin', 'colmax', 'rowmin', 'rowmax', 'centroidx', 'centroidy')
    coords = {col: np.asarray(row[col]) for col in cols}
    if min(coords['centroidx'].size, coords['centroidy'].size) < 2:
        raise ValueError(f'row {i} has fewer than two centroids; cannot draw a line')

    imfile_basename = 'Images_batch_'
    imname = imagepath + '/' + imfile_basename + str(row['fnum']) + '.MTIFF'
    im = quick_read(imname, row['original_index'])
    if process_image:
        a = downsample(im)
        im = a.processedImages
        im[im < 0] = 0
        plt.imshow(np.log10(im + 1), cmap='jet', vmin=1.4, vmax=4)
    else:
        # Scale to raw-frame pixels (factor 4, presumably the downsampling
        # factor -- TODO confirm).
        coords = {col: vals * 4 for col, vals in coords.items()}
        plt.imshow(im, cmap='jet')
    colors = {0:'pink',1:'cyan',2:'red',3:'yellow',4:'goldenrod',5:'white',6:'green',7:'darkgreen',8:'white'}

    # Plot the segment between the two centroids directly. A degree-1
    # np.polyfit through two points is singular when x1 == x2, which would
    # collapse a vertical track to a single point.
    x1, x2 = coords['centroidx'][:2]
    y1, y2 = coords['centroidy'][:2]
    plt.plot([x1, x2], [y1, y2], color='w', lw=2)

    # Optional: plot bounding boxes, coloured by ``prediction`` value.
    if BB:
        for cmin, cmax, rmin, rmax, pred in zip(coords['colmin'], coords['colmax'],
                                                coords['rowmin'], coords['rowmax'],
                                                row['prediction']):
            color = colors[pred]
            plt.hlines(rmin, cmin, cmax, color=color, lw=2)
            plt.hlines(rmax, cmin, cmax, color=color, lw=2)
            plt.vlines(cmin, rmin, rmax, color=color, lw=2)
            plt.vlines(cmax, rmin, rmax, color=color, lw=2)

    plt.xlim(coords['colmin'].min() - 5, coords['colmax'].max() + 5)
    # NOTE(review): with imshow's default origin='upper', passing (low, high)
    # flips the view vertically. Confirm this orientation is intended.
    plt.ylim(coords['rowmin'].min() - 5, coords['rowmax'].max() + 5)

In [None]:
# Row 7: centroid-to-centroid line plus bounding boxes on the processed image.
plot_distance(migs,7,process_image = True, BB=True)