In [1]:
import matplotlib as mpl
mpl.use('Agg')
import pandas as pd
import numpy as np
import os,time,pickle,bz2,datetime,fnmatch,random,logging,argparse,sys
from scipy.signal import savgol_filter
from scipy.optimize import minimize,curve_fit,leastsq
import matplotlib.pyplot as plt
from datetime import datetime


In [2]:
BYTES_PER_PIXEL = 2
SATURATION_VALUE = 2**10-1 #for CU40 camera
STRIPE_SPACING = 7.4
H = 720//3 # hard coded - sloppy! - I cut off the top and bottom thirds of the image to save data
N_TEST = 101 #number of files to find stripe positions on
REPORT_FOLDER = 'report'
EXP_DARK = 6300 #the longest exposure time to use for calibration (shorter than the shortest one ever used at nighttime)


In [3]:
def startlogger(logfilename):
    logger = logging.getLogger()
    logger.setLevel(logging.DEBUG)
    formatter = logging.Formatter()
    fh = logging.FileHandler(logfilename)
    fh.setLevel(logging.DEBUG)
    fh.setFormatter(formatter)
    logger.addHandler(fh)
    ch = logging.StreamHandler()
    ch.setLevel(logging.DEBUG)
    ch.setFormatter(formatter)
    logger.addHandler(ch)
    return logger,ch,fh

def log2html(logger,message,fmt=None):
    if fmt==None:
        msg = message
    elif fmt=='p':
        msg = '<p>{:s}</p>'.format(message)
    elif fmt=='pre':
        msg = '<pre>{:s}</pre>'.format(message)
    elif fmt[0]=='h':
        msg = '<h{:s}>{:s}</h{:s}>'.format(fmt[1],message,fmt[1])
    elif fmt=='i':
        msg = '<img src="{:s}" alt="{:s}"><br>'.format(message,message)
    logger.info(msg)
    return None

def update_progress(progress):
    barLength = 30 # Modify this to change the length of the progress bar
    status = ""
    if isinstance(progress, int):
        progress = float(progress)
    if not isinstance(progress, float):
        progress = 0
        status = "error: progress var must be float\r\n"
    if progress < 0:
        progress = 0
        status = "Halt...\r\n"
    if progress >= 1:
        progress = 1
        status = "Done...\r\n"
    block = int(round(barLength*progress))
    text = "\rPercent: [{0}] {1:.2f}% {2}".format( "#"*block + "-"*(barLength-block), progress*100, status)
    sys.stdout.write(text)
    sys.stdout.flush()

def get_phase_vslice(vslice,stripespacing):
    x = np.arange(len(vslice))
    cos_slice = lambda x: np.cos(2*np.pi*x/stripespacing)
    sin_slice = lambda x: np.sin(2*np.pi*x/stripespacing)
    cos_overlap = np.sum(cos_slice(x)*vslice)
    sin_overlap = np.sum(sin_slice(x)*vslice)
    lcol_phase = np.arctan2(sin_overlap,cos_overlap)
    return lcol_phase


# In[4]:


def findstripes(dat,num_fibers,stripespacing=14.9,halfwidth=-1,background_top=False):
    '''
    Finds the rows that are at the centre of the horizontal stripes (from spectra of optical fiber tips)
    inputs:
        dat: image data (h,w) - monochrome
        num_fibers: number of stripes sought
    returns:
        stripelocs: rows position at the left side of the frame for each stripe
        bak: estimated image background - mean of the top row in the image (scalar)
        halfwidth: half-width of each stripe
    
    This method uses sin/cos functions to calculate the phase of the variation in signal from the stripes along the vertical axis
    to define an offset and tilt.  
    Then picks the num_fibers bins with the largest signals corresponding to the peaks in this matched cos function
    '''
    # first find the stripes:
    #columns to use for finding the tilt:
    h,w = dat.shape
    lcol = int(0.3*w)
    rcol = int(0.7*w)
    if halfwidth==-1:
        halfwidth = int(stripespacing*0.4) #more conservative to capture all data
    x = np.arange(h)
    vsliceleft = dat[:,lcol-10:lcol+10].sum(axis=1)
    vsliceright = dat[:,rcol-10:rcol+10].sum(axis=1)
    #now we have two coordinates that define a centerline for the stripe:
    colspacing = 50
    phasecols = np.arange(lcol,rcol,colspacing)
    phases = np.array([get_phase_vslice(dat[:,pc],stripespacing) for pc in phasecols])
    phases = np.unwrap(phases)
    z = np.polyfit(phasecols, phases, 1)
    tilt = z[0]*stripespacing/(2*np.pi)
    bak = np.mean(dat[0])
    lcol_phase = phases[0]
    #split into bins centered at the middles of the slices
    offset = lcol_phase*stripespacing/(2*np.pi)
    all_stripelocs = np.arange(offset,h,stripespacing)
    all_stripeidxs = all_stripelocs.astype(np.int)
    all_stripe_bins = [x[i-halfwidth:i+halfwidth] for i in all_stripeidxs]
    all_stripe_vals = np.array([vsliceleft[bin].sum() for bin in all_stripe_bins])
    stripe_bunch_vals = np.array([all_stripe_vals[i:i+num_fibers].sum() for i in range(len(all_stripe_vals)-num_fibers)]) #sum groups of fiber signals
    stripes = np.arange(num_fibers) + stripe_bunch_vals.argmax() #pick the group with the highest score - these are where the stripes are.
    stripelocs = all_stripelocs[stripes] - lcol*tilt
    if background_top:
        stripelocs = np.hstack([stripelocs[0] - stripespacing,stripelocs])
    else:
        stripelocs = np.hstack([stripelocs,stripelocs[-1] + stripespacing]) #add one more for the background just beyond the last fiber.
    return stripelocs,bak,halfwidth,tilt


# In[5]:


def import_oneframe_raw(fn,stripe_params):
    if fn[-3:]=='bz2': #if file is compressed
        input_file = bz2.BZ2File(fn, 'rb')
        npdat = np.frombuffer(input_file.read(),dtype=np.uint16).astype(np.uint16)
        input_file.close()
    elif fn[-3:]=='raw':
        npdat = np.fromfile(fn,dtype=np.uint16).astype(np.uint16)
    else:
        log2html(logger,'don\'t know about this file format',fmt='p')
    dat_all = npdat.reshape(-1,stripe_params['shape'][0],stripe_params['shape'][1])
    nframes = dat_all.shape[0]
    #choose one frame at random
    dat = dat_all[np.random.randint(nframes)]
    dat_color = np.array([dat[::2,::2],dat[::2,1::2],dat[1::2,::2],dat[1::2,1::2]])
    return dat_color


# In[6]:


def plot_raw_image(dat_color,fn):
    dat_mono = dat_color.sum(axis=0)
    fig = plt.figure(figsize=(8,2))
    ax = fig.add_subplot(111)
    ax.imshow(dat_mono)
    ax.set_axis_off()
    image_fn = fn[:-3]+'png'
    fig.savefig(os.path.join(REPORT_FOLDER,image_fn))
    return image_fn


# In[7]:


def init_stripe_parmas():
    num_fibers=8
    camera_model='See3CAM_CU40'
    camera_serno=None
    stripespacing=10
    shape=(720,1280)
    wavelength_params=(1.6666666666666667e-06, -4.46384411e-01, -1.49934459e+03)
    stripe_params = {}
    stripe_params['num_fibers'] = num_fibers
    stripe_params['shape'] = shape
    stripe_params['camera_model'] = camera_model
    if camera_model=='See3CAM_CU40':
        stripe_params['saturation_value'] = 2**10-1 #10-bit for CU40 camera
    stripe_params['camera_serno'] = camera_serno

    return stripe_params


# In[8]:


def plot_stripe_positions(dat,stripe_params,fn,display=False):
    '''
    Make a pretty plot showing how the stripes line up with the extracted spectral bins
    '''
    stripelocs = stripe_params['stripelocs']
    halfwidth = stripe_params['halfwidth']
    tilt = stripe_params['tilt']
    hmax = np.argmax(dat.sum(axis=0))
    vslice = dat[:,hmax-50:hmax+50].sum(axis=1)
    h,w = dat.shape
    x = np.arange(h)
    y = vslice - vslice.min()
    fig = plt.figure()
    ax = fig.add_subplot(111,xlim = (stripelocs.min()-2*halfwidth+tilt*hmax,stripelocs.max()+2*halfwidth+tilt*hmax))
    for s in stripelocs:
        ax.plot(x,y.max()*np.logical_and(x>s+hmax*tilt-halfwidth,x<s+hmax*tilt+halfwidth))
    ax.plot(y)
    ax.set_xlabel('pixel row')
    ax.set_ylabel('Signal (summed pixel counts, across row, background removed)')
    ax.set_title('Locations of the stripes used in data analysis')
    fig_fn = fn[:-4]+'_'+stripe_params['camera_serno']+'_stripe_locations.png'
    if display:
        fig.show()
    else:
        fig.savefig(fig_fn)
    return fig_fn


# In[9]:


def optimize_stripelocs(dat,stripe_params):
    yshift_range = stripe_params['halfwidth']*2
    cutoff = yshift_range*2
    pxvals = np.zeros(yshift_range)
    halfwidth_small=1
    rgrid,bgrid,ggrid1,ggrid2 = sw.colorgrids(dat)
    stripe_masks_cut = sw.get_stripe_masks(dat[:-cutoff],
                                       stripelocs-yshift_range,
                                       tilt,halfwidth_small,
                                       rgrid[:-cutoff],
                                       bgrid[:-cutoff],
                                       ggrid1[:-cutoff],
                                       ggrid2[:-cutoff])
    for i,shift_y in enumerate(np.arange(yshift_range)*2):
        rgbspectra,satpxs = sw.getstripes(dat[shift_y:-cutoff+shift_y],stripe_masks_cut,stripelocs,tilt,halfwidth_small)
        pxvals[i] = np.sum(rgbspectra[:-1,:,0])
    real_shift_vals = np.arange(yshift_range)*2-yshift_range
    best_shift = real_shift_vals[np.argmax(pxvals)]
    stripelocs = stripelocs+best_shift
    return stripelocs


# In[10]:


def get_curves(dat,stripe_params,xidx,yidx):
    n_colors = dat.shape[0]
    curves = np.zeros([len(stripe_params['stripelocs']),stripe_params['shape'][1],n_colors])
    for c in range(n_colors): #loop through colors
        stripes = dat[c][(yidx,xidx)] #returns selected data in one long vector
        stripes_reshaped = stripes.reshape(-1,len(stripe_params['stripelocs']),stripe_params['halfwidth']*2) #separate the stripes apart
        satpx = np.sum(stripes_reshaped==stripe_params['saturation_value'])  #calculate the number of saturated pixels
        curves[:,:,c] = stripes_reshaped.sum(axis=2).T #sum vertically across each stripe
    return curves

# In[11]:


def get_stripe_masks_mono(datshape,stripelocs,tilt,halfwidth,verbose=True):    
    '''
    Requires:
        dat: image data
        stripelocs: returned from 'findstripes'
        tilt: returned from 'findstripes'
    returns:
        stripe_mask: masks for each stripe, with tilt and Bayer filter embedded
    '''
    #separate the colors
    ygrid,xgrid = np.indices(datshape)
    stripe_masks = np.zeros([len(stripelocs),datshape[0],datshape[1]],dtype=np.bool)
    for i in range(len(stripelocs)):
        yint = stripelocs[i]
        ystr = np.round(yint+tilt*np.arange(datshape[1])).astype(np.int) #vertical center of stripe at each point in the picture
        stripe_masks[i,:,:] = np.logical_and(ygrid>ystr-halfwidth,ygrid<=ystr+halfwidth)
    return stripe_masks


# In[12]:


def get_manystripes_indices(datshape,stripelocs,tilt,halfwidth):    
    '''
    Parameters:
    -----------
        datshape    : tuple (int,int)
                        image shape 
        stripelocs  : list
                        vertical locations of the stripes in the image
        tilt        : float
                        tangent of the tilt angle of the stripes 
    Returns:
    --------
        stripe_indices  : tuple (array,array)
                            indices for each stripe, with tilt and Bayer filter embedded
                            returned from np.where call
    '''
    stripe_masks = get_stripe_masks_mono(datshape,stripelocs,tilt,halfwidth,verbose=True)
    manystripe_mask = np.sum(stripe_masks,axis=0)
    xidx,yidx = np.where(manystripe_mask.T)
    return xidx,yidx

def import_oneframe_raw(fn,stripe_params,verbose=False):
    if fn[-3:]=='bz2': #if file is compressed
        input_file = bz2.BZ2File(fn, 'rb')
        npdat = np.fromstring(input_file.read(),dtype=np.uint16).astype(np.uint16)
        input_file.close()
    elif fn[-3:]=='raw':
        npdat = np.fromfile(fn,dtype=np.uint16).astype(np.uint16)
    else:
        log2html(logger,'don\'t know about this file format',fmt='p')
    dat_all = npdat.reshape(-1,stripe_params['shape'][0],stripe_params['shape'][1])
    nframes = dat_all.shape[0]
    dat = dat_all[np.random.randint(nframes)]
    dat_color = np.array([dat[::2,::2],dat[::2,1::2],dat[1::2,::2],dat[1::2,1::2]])
    if verbose:
        dat_mono = dat_color.sum(axis=0)
        fig = plt.figure()
        ax = fig.add_subplot(111)
        ax.imshow(dat_mono)
        ax.set_axis_off()
        fig.savefig(fn[:-3]+'png')
    return dat_color

In [4]:
#args = parser.parse_args()
#global NUM_FIBERS
NUM_FIBERS = 12

global BACKGROUND_TOP
BACKGROUND_TOP = args.background_top

if not os.path.exists(REPORT_FOLDER):
    os.makedirs(REPORT_FOLDER)

logger,ch,fh = startlogger(os.path.join(REPORT_FOLDER,'calibration_log.html'))

# In[3]:
#initiate the html file:
logger.info('<!DOCTYPE html>\n<html>\n<body>\n')
log2html(logger,'Vertical calibration report',fmt='h1')
log2html(logger,'HydraSpectra sensor',fmt='h2')
log2html(logger,'script run on {:%B %d, %Y}'.format(datetime.now()),fmt='h3')
fncsv = 'hydraspectra_metadata.csv'
if os.path.isfile(fncsv):
    log2html(logger,'found {:s}'.format(fncsv),fmt='p')
else:
    log2html(logger,'Error: can\'t find file {:s}'.format(fncsv),fmt='p')


# In[4]:

log2html(logger,'loading metadata...',fmt='p')
metadata_df = pd.read_csv(fncsv)
sLength = len(metadata_df)
metadata_df = metadata_df.assign(file_time=pd.Series(np.zeros(sLength)).values)

# In[5]:
log2html(logger,'locating data files...',fmt='p')
matches = {}
for root, dirnames, filenames in os.walk('.'):
    for filename in fnmatch.filter(filenames, '*.raw.bz2'):
        matches[os.path.basename(os.path.join(root, filename))] = os.path.join(root, filename)


# In[6]:

for i in metadata_df.index.values:
    base_fn = metadata_df.loc[i,'filename']
    if base_fn in matches.keys():
        metadata_df.loc[i,'filename'] = matches[base_fn]
        metadata_df.loc[i,'file_time'] = os.path.getmtime(base_fn)
    else:
        metadata_df.loc[i,'filename'] = ''
metadata_df = metadata_df[metadata_df.filename != '']
log2html(logger,'found {:d} spectral data files'.format(len(metadata_df)),fmt='p')
log2html(logger,metadata_df.head().to_string(),fmt='pre')


# In[25]:



# In[27]:

#allstripes_index = random.choice(metadata_df.index.values)
#fn_allstripes = metadata_df.loc[allstripes_index,'filename']
#default values for CU40 camera
stripe_params = {'saturation_value': SATURATION_VALUE,
                 'num_fibers': NUM_FIBERS,
                 'tilt': 0,
                 'camera_serno': metadata_df.loc[0,'serial_no'],
                 'shape': (H, metadata_df.loc[0,'w']),
                 'camera_model': metadata_df.loc[0,'model'],
                 'halfwidth': 3,
                 'stripespacing':STRIPE_SPACING,
                 'stripelocs': None,
                 'wavelength_params': (1.6666666666666667e-06, -0.446384411, -1499.34459)
                }

# In[20]:
#consider only images that used the shortest exposures in the dataset
exposures_used = np.unique(metadata_df.exposure.values)
log2html(logger,'exposures used:{}'.format(exposures_used),fmt='p')
short_exp_df = metadata_df[metadata_df.exposure<=EXP_DARK]
#pick data points at random, and find the vertical positions
log2html(logger,'calculating vertical positions of fibres in {:d} randomly chosen images:'.format(N_TEST),fmt='p')
for j in range(N_TEST):
    i = random.choice(short_exp_df.index.values)
    fn = short_exp_df.loc[i,'filename']
    dat_mono = import_oneframe_raw(fn,stripe_params).sum(axis=0)
    stripelocs,bak,halfwidth,tilt = findstripes(dat_mono,stripe_params['num_fibers'],stripespacing=STRIPE_SPACING,background_top=BACKGROUND_TOP)
    #stripelocs,bak,halfwidth,tilt = findstripes(dat_mono,stripe_params['num_fibers'],stripespacing=STRIPE_SPACING)
    metadata_df.loc[i,'first_stripe_location'] = stripelocs[0]
    update_progress(j/N_TEST)
#log2html(logger,metadata_df.loc[i,'filename']+' : '+short_exp_df.loc[i,'filename'],fmt='p')
# In[29]:

#choose an image to use for the stripe parameters:
test_df = metadata_df[~np.isnan(metadata_df['first_stripe_location'].values)]
vert_stripe_median = np.median(test_df['first_stripe_location'].values)
min_idx = np.argmin(np.abs(test_df['first_stripe_location']-vert_stripe_median))
m = test_df.loc[min_idx,'first_stripe_location']
# In[30]:

n_tested_files = len(test_df)
cal_basename = os.path.basename(test_df[test_df.index==min_idx]['filename'].values[0])
fig_fn = os.path.join(REPORT_FOLDER,cal_basename.split('.')[0]+'_fibre_drift.png')
log2html(logger,'Tested {:d} files</p>'.format(n_tested_files))
fig, ax = plt.subplots()
ax.plot(test_df.index,test_df['first_stripe_location'],'o')
vert_stripe_median = np.median(test_df['first_stripe_location'].values)
ax.plot(test_df.index,np.ones(n_tested_files)*(vert_stripe_median+halfwidth),'r--',label='width of single stripe')
ax.plot(test_df.index,np.ones(n_tested_files)*(vert_stripe_median-halfwidth),'r--')
ax.plot(test_df.index,np.ones(n_tested_files)*(vert_stripe_median),'g-')
ax.plot(min_idx,m,'r*',label='image used for calibration')
ax.set_xlabel('File index number')
ax.set_ylabel('Vertical position (pixels)')
ax.set_title('Drift of fibre image positions')
leg = ax.legend(loc='best', fancybox=True)
leg.get_frame().set_alpha(0.5)
fig.savefig(fig_fn)
log2html(logger,fig_fn.split('/')[1],fmt='i')

# In[31]:
fn = test_df[test_df.index==min_idx]['filename'].values[0]
log2html(logger,'Image used for calibration:\n{:s}'.format(fn),fmt='pre')

# In[20]:


dat_color = import_oneframe_raw(fn,stripe_params)
dat_mono = dat_color.sum(axis=0)


# In[21]:


image_fn = plot_raw_image(dat_color,fn)
log2html(logger,image_fn.split('/')[1],fmt='i')

# In[34]:


stripelocs,bak,halfwidth,tilt = findstripes(dat_mono,stripe_params['num_fibers'],stripespacing=7.4,background_top=BACKGROUND_TOP)


# In[35]:


#update stripe_params values
w = dat_mono.shape[1]
stripe_params['stripelocs'] = stripelocs#-b
stripe_params['tilt'] = tilt
stripe_params['shape'] = dat_mono.shape
stripe_params['halfwidth'] = halfwidth
if not('calibration_date' in stripe_params.keys()):
    stripe_params['calibration_date'] = time.strftime("%d_%b_%Y", time.gmtime(os.path.getmtime(fn)))


# In[36]:


stripe_pos_fn = plot_stripe_positions(dat_mono,stripe_params,os.path.join(REPORT_FOLDER,os.path.basename(fn)[:-4]))
log2html(logger,os.path.basename(stripe_pos_fn),fmt='i')

# In[37]:


stripe_indices = get_manystripes_indices(dat_color.shape[1:],stripe_params['stripelocs'],
                                     stripe_params['tilt'],stripe_params['halfwidth'])
xidx,yidx = stripe_indices
len(xidx)


# In[38]:


curves = get_curves(dat_color,stripe_params,xidx,yidx)

fig = plt.figure(figsize=(10,7))
ax = fig.add_subplot(111)
for f in range(len(stripelocs)):
    ax.plot(curves[f].sum(axis=1))
ax.set_xlabel('Pixel')
ax.set_ylabel('Signal (summed pixel counts, vertically across stripe)')
ax.set_title('Raw Spectra: {}'.format(fn))
image_fn = fn.split('/')[1].split('.')[0]+'_curves.png'
fig.savefig(os.path.join(REPORT_FOLDER,image_fn))
log2html(logger,image_fn,fmt='i')
'''
stripe_params = find_stripes(fn_allstripes,
            num_fibers=NUM_FIBERS,
            camera_model=metadata_df['model'].values[allstripes_index],
            camera_serno=metadata_df['serial_no'].values[allstripes_index],
            stripespacing=STRIPE_SPACING,
            shape=(H,metadata_df['w'].values[allstripes_index]),
            save_raw_image=True
            )

log2html(logger,fn_allstripes[:-3]+'png',fmt='i')
log2html(logger,fn_allstripes[:-8]+'_'+stripe_params['camera_serno']+'_stripe_locations.png',fmt='i')

# In[32]:

stripe_params['calibration_date'] = datetime.fromtimestamp(os.path.getmtime(fn_allstripes)).strftime("%B %d %Y")
fnp = os.path.join(REPORT_FOLDER,'stripe_params_'+stripe_params['camera_serno']+'.pkl')
fspickle = open(fnp,'wb')
pickle.dump(stripe_params,fspickle)
fspickle.close()
'''
#finish the html report file:
logger.info('</body>\n</html>\n')

NameError: name 'args' is not defined