In [None]:
from matplotlib import pyplot as plt
import numpy as np
import cv2

In [None]:
def kp2pt(kp):
    """Convert a keypoint to integer pixel coordinates.

    Reads the ``pt`` ``(x, y)`` pair from ``kp`` and truncates each component
    with ``int()`` (toward zero), returning an ``(int, int)`` tuple.
    """
    x_coord, y_coord = kp.pt
    pixel = (int(x_coord), int(y_coord))
    return pixel

def getPairs(keypointmathces, keypoints1, keypoints2):
    """Turn descriptor matches into pairs of integer pixel coordinates.

    Each match's ``queryIdx`` indexes ``keypoints1`` and its ``trainIdx``
    indexes ``keypoints2``; both keypoints are converted with ``kp2pt``.
    Returns a list of ``((x1, y1), (x2, y2))`` tuples in match order.
    """
    return [
        (kp2pt(keypoints1[match.queryIdx]), kp2pt(keypoints2[match.trainIdx]))
        for match in keypointmathces
    ]

def calcDist(pa, pb):
    """Return the Euclidean distance between the 2-D points ``pa`` and ``pb``."""
    ax, ay = pa
    bx, by = pb
    return np.sqrt((ax - bx) * (ax - bx) + (ay - by) * (ay - by))

def addTupple(t1, t2):
    """Add two 2-tuples component-wise and return the resulting 2-tuple."""
    (a_x, a_y), (b_x, b_y) = t1, t2
    return a_x + b_x, a_y + b_y

def calcError(keypointmathces, keypoints1, keypoints2, center, td):
    """Score how well an offset center explains the matched keypoints.

    For each matched pair ``(p1, p2)`` the distance from ``p1`` to ``center``
    is compared with the distance from ``p2`` to ``center + td``. The result
    is the sum of absolute differences of those distances; it is 0 for an
    empty match list.

    Args:
        keypointmathces: matches exposing ``queryIdx`` / ``trainIdx``.
        keypoints1, keypoints2: keypoints exposing a ``pt`` attribute.
        center: ``(x, y)`` reference point in the first image.
        td: ``(dx, dy)`` offset applied to ``center`` for the second image.
    """
    pairs = getPairs(keypointmathces, keypoints1, keypoints2)
    shifted_center = addTupple(center, td)
    # Builtin sum (the previous accumulator variable shadowed it).
    return sum(
        abs(calcDist(p1, center) - calcDist(p2, shifted_center))
        for p1, p2 in pairs
    )

def optimizeCenter(keypointmathces, keypoints1, keypoints2, center,
                   step=0.1, iterations=2000):
    """Estimate the offset between matched keypoints by coordinate descent.

    Searches for an offset ``(dx, dy)`` that minimises
    ``sum | dist(p1, center) - dist(p2, center + (dx, dy)) |`` over the
    matched pairs. Even iterations move ``dx`` by ``+step`` or ``-step``
    (whichever gives the lower error; ties go to ``-step``), and odd
    iterations do the same for ``dy``.

    Args:
        keypointmathces: matches exposing ``queryIdx`` (into ``keypoints1``)
            and ``trainIdx`` (into ``keypoints2``).
        keypoints1, keypoints2: keypoints exposing a ``pt`` ``(x, y)`` pair.
        center: ``(x, y)`` reference point in the first image.
        step: size of each coordinate move (default 0.1).
        iterations: total number of coordinate moves (default 2000).

    Returns:
        ``(-dx, -dy)`` rounded to ints. Returns ``(0, 0)`` when there are no
        matches.
    """
    matches = list(keypointmathces)
    if not matches:
        # With no pairs every error is 0, so the descent would drift by
        # -step on every move and return a meaningless large shift.
        return (0, 0)

    # Extract the point pairs once rather than on every error evaluation;
    # int() truncation mirrors kp2pt.
    p1 = np.array([[int(v) for v in keypoints1[m.queryIdx].pt] for m in matches],
                  dtype=float)
    p2 = np.array([[int(v) for v in keypoints2[m.trainIdx].pt] for m in matches],
                  dtype=float)

    cx, cy = center
    # Distances in the first image do not depend on the offset.
    d1 = np.hypot(p1[:, 0] - cx, p1[:, 1] - cy)

    def _error(ox, oy):
        # Sum of |d1 - d2| with d2 measured from center + (ox, oy).
        d2 = np.hypot(p2[:, 0] - (cx + ox), p2[:, 1] - (cy + oy))
        return np.sum(np.abs(d1 - d2))

    dx = 0.0
    dy = 0.0
    for i in range(iterations):
        if i % 2 == 0:
            dx += step if _error(dx + step, dy) < _error(dx - step, dy) else -step
        else:
            dy += step if _error(dx, dy + step) < _error(dx, dy - step) else -step
    return (int(np.round(-dx)), int(np.round(-dy)))

In [None]:
def correctTranslate(original, attacked, center, n_matches=50):
    """Estimate the translation between ``original`` and ``attacked``.

    ORB keypoints/descriptors are computed on both images and matched with a
    cross-checked Hamming brute-force matcher. The ``n_matches`` matches with
    the smallest descriptor distance are passed to ``optimizeCenter`` together
    with ``center``.

    Args:
        original: first image passed to ``ORB.detectAndCompute``.
        attacked: second image passed to ``ORB.detectAndCompute``.
        center: ``(x, y)`` reference point in ``original``.
        n_matches: number of best matches to use (default 50).

    Returns:
        The ``(int, int)`` result of ``optimizeCenter``. Returns ``(0, 0)``
        (no correction) when descriptors or matches are unavailable.
    """
    # Initiate ORB detector
    orb = cv2.ORB_create()

    # find the keypoints and descriptors with ORB
    kp1, des1 = orb.detectAndCompute(original, None)
    kp2, des2 = orb.detectAndCompute(attacked, None)

    # detectAndCompute can yield None descriptors (e.g. no keypoints found);
    # BFMatcher.match would then raise, so report "no correction" instead.
    if des1 is None or des2 is None:
        return (0, 0)

    # Cross-checked brute-force Hamming matcher (ORB descriptors are binary).
    bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)

    # Match descriptors, best (smallest distance) first.
    matches = sorted(bf.match(des1, des2), key=lambda m: m.distance)
    if not matches:
        return (0, 0)

    return optimizeCenter(matches[:n_matches], kp1, kp2, center)