In [None]:
import os
from io import BytesIO
from itertools import combinations
from time import time

import cv2
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.patches import Circle, Polygon
from pydng.core import RPICAM2DNG
from skimage.feature import blob_log

%matplotlib inline
figsize = (20,20)

def _plot(img, vmax=0.5):
    """Display ``img`` in grayscale on a fresh figure of size ``figsize``.

    Returns the Axes so callers can overlay markers on the image.

    NOTE(review): ``vmax`` is accepted but not applied; the display range is
    matplotlib's automatic min/max scaling. Confirm whether a fixed stretch
    was intended before wiring it into ``imshow``.
    """
    _, axes = plt.subplots(figsize=figsize)
    axes.imshow(img, interpolation='none', cmap='gray')
    return axes

def _circle(ax, pt, rad=6, c=(1, 0, 0)):
    """Add a filled circle of radius ``rad`` centred at ``pt`` (x, y) to ``ax``.

    ``c`` is the face colour (any matplotlib colour spec). Returns None.
    """
    ax.add_patch(Circle(pt, rad, facecolor=c))
    
def _triangle(ax, tri, c=(1,0,0)):
    """Draw the unfilled, closed outline of polygon ``tri`` on ``ax``.

    Parameters
    ----------
    ax : matplotlib Axes
        Target axes.
    tri : array-like, shape (K, 2)
        Vertices in (x, y) order (e.g. an entry of ``getConstellation``'s
        ``tpoints``).
    c : colour spec
        Outline colour.

    Returns
    -------
    The added ``Polygon`` patch (previously None; returning it is harmless to
    existing callers).
    """
    # `closed` must be passed by keyword: passing it positionally was
    # deprecated in matplotlib 3.6 and it is keyword-only in later releases.
    poly = Polygon(tri, closed=True, color=c, fill=False)
    ax.add_patch(poly)
    return poly
    
    
def openRaw(image_file, crop=(200, 200, 400, 500)):
    """Load the raw Bayer data from a capture file as a normalised luminance image.

    The raw mosaic is extracted with pydng's ``RPICAM2DNG`` (presumably a
    Raspberry Pi camera capture with embedded raw data). Each 2x2 Bayer cell is
    collapsed to one luminance value with weights 0.2125 R + 0.7154 G +
    0.0721 B, the result is cropped, and it is scaled so its maximum is 1.

    Parameters
    ----------
    image_file : str
        Path of the capture file.
    crop : tuple of int
        Pixels trimmed from (top, bottom, left, right) of the half-resolution
        luminance image. The default reproduces the previous fixed crop.

    Returns
    -------
    numpy.ndarray
        2-D float64 image with maximum 1.0. An all-zero (or empty) crop is
        returned unscaled instead of producing NaNs.
    """
    with open(image_file, 'rb') as fh:
        buf = BytesIO(fh.read())
    rpicam = RPICAM2DNG()
    raw = rpicam.__extractRAW__(buf)
    # np.float was removed in NumPy 1.24; float64 is what it aliased.
    raw = np.asarray(raw, dtype=np.float64)

    # NOTE(review): assumes R at (0,0), greens at (1,0)/(0,1), B at (1,1) —
    # confirm against the sensor's Bayer order.
    red = raw[::2, ::2]
    # Average the two green sites. The old `/ 0.5` doubled the sum instead,
    # over-weighting green 4x relative to red and blue.
    green = (raw[1::2, ::2] + raw[::2, 1::2]) / 2
    blue = raw[1::2, 1::2]
    carray = 0.2125 * red + 0.7154 * green + 0.0721 * blue

    top, bottom, left, right = crop
    rows, cols = carray.shape
    # Use explicit end indices so a zero margin does not become an empty `:-0` slice.
    carray = carray[top:rows - bottom, left:cols - right]

    peak = np.max(carray) if carray.size else 0
    if peak > 0:
        carray /= peak
    return carray

def getConstellation(pts):
    '''Enumerate every triangle formed by three detections.

    Parameters
    ----------
    pts : array-like, shape (N, >=2)
        Detections whose first two columns are (row, col), i.e. (y, x), as in
        ``skimage.feature.blob_log`` output.

    Returns
    -------
    triangles : list of numpy.ndarray
        For each triangle, its three side lengths sorted ascending (a signature
        unchanged by translation and rotation).
    tpoints : list of numpy.ndarray
        The matching (3, 2) vertex arrays in (x, y) order.

    Both lists hold C(N, 3) entries in lexicographic index order, so cost grows
    as O(N**3). A diagnostic line with the star count is printed.
    '''
    pts = np.asarray(pts)
    N = len(pts)
    # pts[0] does not exist for an empty detection list.
    print("Number stars", N, pts[0] if N else None)

    def _length(a, b):
        return np.sqrt(np.sum((a - b) ** 2))

    triangles = []
    tpoints = []
    for i, j, k in combinations(range(N), 3):
        # [1::-1] turns (y, x, ...) into (x, y)
        pt1 = pts[i][1::-1]
        pt2 = pts[j][1::-1]
        pt3 = pts[k][1::-1]
        sides = sorted((_length(pt1, pt2), _length(pt2, pt3), _length(pt3, pt1)))
        triangles.append(np.array(sides))
        tpoints.append(np.array((pt1, pt2, pt3)))
    return triangles, tpoints

def compareConstelations(tri1, tri2, limit=100):
    '''Match each triangle signature in ``tri1`` to its closest one in ``tri2``.

    Parameters
    ----------
    tri1, tri2 : sequence of array-like, each of length 3
        Sorted side-length triples, as produced by ``getConstellation``.
    limit : int
        Maximum number of pairs returned (default 100, the previous fixed cap).

    Returns
    -------
    list of (distance, index_in_tri1, index_in_tri2)
        Sorted by ascending Euclidean distance between signatures and truncated
        to ``limit`` entries. Ties on distance keep the first minimum in
        ``tri2``. When ``tri2`` is empty every entry is ``(inf, i, None)``.
    '''
    if len(tri2) == 0:
        return [(np.inf, cnt, None) for cnt in range(len(tri1))][:limit]
    candidates = np.asarray(tri2, dtype=float)
    ptlist = []
    for cnt, triangle in enumerate(tri1):
        # One distance row per query keeps memory O(len(tri2)); a full
        # len(tri1) x len(tri2) matrix can be enormous for O(N**3) triangles.
        diff = candidates - np.asarray(triangle, dtype=float)
        distances = np.sqrt(np.sum(diff ** 2, axis=1))
        idx = int(np.argmin(distances))
        ptlist.append((distances[idx], cnt, idx))
    ptlist.sort()
    return ptlist[:limit]

def findMatching(blob1, blob2, max_distance=5):
    '''Pair each detection in ``blob1`` with its nearest detection in ``blob2``.

    Parameters
    ----------
    blob1, blob2 : array-like, shape (N, >=2)
        Detections whose first two columns are (y, x), e.g. ``blob_log`` output.
    max_distance : float
        A pair is kept only when its Euclidean distance is strictly below this
        (default 5 pixels, the previous fixed radius).

    Returns
    -------
    list of [y1, x1, y2, x2]
        One entry per matched ``blob1`` detection, in ``blob1`` order.

    The previous version took the *first* ``blob2`` entry within range, so the
    result depended on detection order. Choosing the nearest is order-independent.
    '''
    matches = []
    if len(blob2) == 0:
        return matches
    candidates = np.asarray(blob2)[:, :2]
    for bl1 in blob1:
        distance = np.sqrt((bl1[0] - candidates[:, 0]) ** 2 + (bl1[1] - candidates[:, 1]) ** 2)
        best = int(np.argmin(distance))
        if distance[best] < max_distance:
            bl2 = candidates[best]
            matches.append([bl1[0], bl1[1], bl2[0], bl2[1]])
    return matches

def stackImagesRaw(file_list, min_matches=4, ransac_thresh=5.0):
    """Align raw frames to the first one by matching detected stars, then sum them.

    Stars are found with ``blob_log`` in every frame. Each later frame's stars
    are paired with the reference frame's stars (``findMatching``). A RANSAC
    homography is fitted and the frame is warped onto the reference before
    being added.

    Parameters
    ----------
    file_list : sequence of str
        Capture files; the first one is the alignment reference.
    min_matches : int
        Frames with fewer star matches are skipped (a homography needs >= 4).
    ransac_thresh : float
        RANSAC reprojection threshold in pixels passed to ``findHomography``.

    Returns
    -------
    numpy.ndarray or None
        Sum of the reference and all successfully aligned frames (not
        averaged), or None when ``file_list`` is empty.

    For the last file, the warped frame is plotted with reference stars (red)
    and that frame's matched stars (yellow).
    """
    stacked_image = None
    reference_blobs = None
    for count, file in enumerate(file_list):
        print(file)
        imageF = openRaw(file)
        blobs_log = blob_log(imageF, min_sigma=2, max_sigma=10, num_sigma=5, threshold=.02)
        print("Number stars", len(blobs_log))

        if stacked_image is None:
            # The first frame is the reference everything is aligned to.
            stacked_image = imageF
            reference_blobs = blobs_log
            continue

        matches = findMatching(reference_blobs, blobs_log)
        print("Found matches", len(matches))
        if len(matches) < min_matches:
            print("Skipping", file, "- too few star matches for a homography")
            continue

        # blob_log yields (row, col) = (y, x); OpenCV geometry expects (x, y).
        # Feeding (y, x) produced a transposed transform.
        src = np.array([(x1, y1) for y1, x1, _, _ in matches], dtype=np.float32).reshape(-1, 1, 2)
        dst = np.array([(x2, y2) for _, _, y2, x2 in matches], dtype=np.float32).reshape(-1, 1, 2)

        # Map this frame (dst) onto the reference (src).
        M, _ = cv2.findHomography(dst, src, cv2.RANSAC, ransac_thresh)
        if M is None:
            print("Skipping", file, "- homography estimation failed")
            continue
        print("M", M)
        height, width = imageF.shape
        imageF = cv2.warpPerspective(imageF, M, (width, height))  # dsize is (width, height)

        if count == len(file_list) - 1:
            ax = _plot(imageF)
            for y1, x1, y2, x2 in matches:
                _circle(ax, (x1, y1), c='red', rad=6)
                _circle(ax, (x2, y2), c='yellow', rad=3)

        stacked_image += imageF
    return stacked_image

# Align and stack images by matching ORB keypoints.
# Faster but less accurate than matching detected stars.
def stackImagesKeypointMatching(file_list, max_frames=25, min_matches=4):
    """Align frames to the first one via ORB keypoint matches, then sum them.

    Parameters
    ----------
    file_list : sequence of str
        Capture files; the first one is the alignment reference.
    max_frames : int
        Stop after this many frames have been aligned and added on top of the
        reference (default 25, the previous fixed cap).
    min_matches : int
        Frames with fewer descriptor matches are skipped (a homography needs >= 4).

    Returns
    -------
    numpy.ndarray or None
        Sum of the reference and the aligned frames (not averaged), or None
        when ``file_list`` is empty.

    The first aligned frame is plotted with its three best matches: reference
    keypoints in red, that frame's keypoints in green (offset 6 px in x).
    Previously a leftover debug ``return`` right after this plot made the
    function always return None.
    """
    orb = cv2.ORB_create()

    # disable OpenCL because of a bug in ORB in OpenCV 3.1
    cv2.ocl.setUseOpenCL(False)

    # Hamming distance suits ORB's binary descriptors; crossCheck keeps only mutual best matches.
    matcher = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)

    stacked_image = None
    first_kp = None
    first_des = None
    count = 0
    for file in file_list:
        print(file)
        imageF = openRaw(file)
        # ORB works on 8-bit images; openRaw output has maximum 1.0.
        image = np.array(imageF * 240, dtype=np.uint8)

        kp = orb.detect(image, None)
        kp, des = orb.compute(image, kp)

        if stacked_image is None:
            # The first frame is the reference everything is aligned to.
            stacked_image = imageF
            first_kp = kp
            first_des = des
            continue

        if first_des is None or des is None:
            print("Skipping", file, "- no ORB descriptors")
            continue

        matches = sorted(matcher.match(first_des, des), key=lambda m: m.distance)
        print("Found matches", len(matches))
        if len(matches) < min_matches:
            print("Skipping", file, "- too few keypoint matches for a homography")
            continue

        # KeyPoint.pt is already (x, y), as OpenCV geometry expects.
        src_pts = np.float32([first_kp[m.queryIdx].pt for m in matches]).reshape(-1, 1, 2)
        dst_pts = np.float32([kp[m.trainIdx].pt for m in matches]).reshape(-1, 1, 2)

        # Map this frame (dst) onto the reference (src).
        M, _ = cv2.findHomography(dst_pts, src_pts, cv2.RANSAC, 5.0)
        if M is None:
            print("Skipping", file, "- homography estimation failed")
            continue
        print("M", M)
        height, width = imageF.shape
        imageF = cv2.warpPerspective(imageF, M, (width, height))  # dsize is (width, height)

        if count == 0:
            ax = _plot(imageF)
            for cnt in range(min(3, len(matches))):
                _circle(ax, (src_pts[cnt, 0, 0], src_pts[cnt, 0, 1]), c=(1, 0, 0))
                _circle(ax, (dst_pts[cnt, 0, 0] + 6, dst_pts[cnt, 0, 1]), c=(0, 1, 0))

        stacked_image += imageF
        count += 1
        if count >= max_frames:
            break

    return stacked_image

In [None]:
# Stack every 'testa_*' capture in the folder, in name order, aligned on its stars.
image_folder = "."
file_list = sorted(
    os.path.join(image_folder, name)
    for name in os.listdir(image_folder)
    if name.startswith('testa_')
)
if not file_list:
    # Fail clearly instead of a TypeError from dividing None below.
    raise FileNotFoundError("no 'testa_*' captures found in {!r}".format(image_folder))
stacked_image = stackImagesRaw(file_list)
_plot(stacked_image / np.max(stacked_image));

In [None]:
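# ===== MAIN (command-line entry point, currently disabled) =====
# NOTE: this commented-out driver calls stackImagesECC, which is not defined in
# this notebook, and its default method 'KP' matches neither handled choice
# ('ECC' or 'ORB'), so it must be updated before being re-enabled.
# Read all files in directory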
# import argparse

# if __name__ == '__main__':

#     parser = argparse.ArgumentParser(description='')
#     parser.add_argument('input_dir', help='Input directory of images ()')
#     parser.add_argument('output_image', help='Output image name')
#     parser.add_argument('--method', help='Stacking method ORB (faster) or ECC (more precise)')
#     parser.add_argument('--show', help='Show result image',action='store_true')
#     args = parser.parse_args()

#     image_folder = args.input_dir
#     if not os.path.exists(image_folder):
#         print("ERROR {} not found!".format(image_folder))
#         exit()

#     file_list = os.listdir(image_folder)
#     file_list = [os.path.join(image_folder, x)
#                  for x in file_list if x.endswith(('.jpg', '.png','.bmp'))]

#     if args.method is not None:
#         method = str(args.method)
#     else:
#         method = 'KP'

#     tic = time()

#     if method == 'ECC':
#         # Stack images using ECC method
#         description = "Stacking images using ECC method"
#         print(description)
#         stacked_image = stackImagesECC(file_list)

#     elif method == 'ORB':
#         #Stack images using ORB keypoint method
#         description = "Stacking images using ORB method"
#         print(description)
#         stacked_image = stackImagesKeypointMatching(file_list)

#     else:
#         print("ERROR: method {} not found!".format(method))
#         exit()

#     print("Stacked {0} in {1} seconds".format(len(file_list), (time()-tic) ))

#     print("Saved {}".format(args.output_image))
#     cv2.imwrite(str(args.output_image),stacked_image)

#     # Show image
#     if args.show:
#         cv2.imshow(description, stacked_image)
#         cv2.waitKey(0)