In [None]:
from matplotlib import pyplot as plt
import numpy as np
import cv2

In [None]:
def kp2pt(kp):
    """Return the keypoint's ``pt`` position as an ``(x, y)`` pair of ints.

    Each coordinate goes through ``int``, i.e. it is truncated toward zero.
    """
    x_coord, y_coord = kp.pt
    return int(x_coord), int(y_coord)

def getPairs(keypointmathces, keypoints1, keypoints2):
    """Turn descriptor matches into pairs of integer pixel positions.

    For each match, ``queryIdx`` selects a keypoint from ``keypoints1`` and
    ``trainIdx`` selects one from ``keypoints2``; both are converted with
    ``kp2pt``.

    Returns:
        A list of ``((x1, y1), (x2, y2))`` tuples, in the order of the matches.
    """
    return [
        (kp2pt(keypoints1[match.queryIdx]), kp2pt(keypoints2[match.trainIdx]))
        for match in keypointmathces
    ]

def calcDist(pa, pb):
    """Return the Euclidean distance between the 2-D points ``pa`` and ``pb``."""
    (ax, ay), (bx, by) = pa, pb
    delta_x, delta_y = ax - bx, ay - by
    squared_length = delta_x * delta_x + delta_y * delta_y
    return np.sqrt(squared_length)

def addTupple(t1, t2):
    """Return the component-wise sum of two ``(x, y)`` pairs."""
    (first_x, first_y), (second_x, second_y) = t1, t2
    summed_x = first_x + second_x
    summed_y = first_y + second_y
    return (summed_x, summed_y)

def calcError(keypointmathces, keypoints1, keypoints2, center1, center2, s):
    """Mean absolute radial-distance mismatch between matched keypoints.

    For every match, the distance of the first keypoint to ``center1`` is
    compared with ``s`` times the distance of the second keypoint to
    ``center2``; the absolute differences are averaged.

    Args:
        keypointmathces: matches whose ``queryIdx``/``trainIdx`` index
            ``keypoints1``/``keypoints2`` (converted via ``getPairs``).
        keypoints1: keypoints of the first image.
        keypoints2: keypoints of the second image.
        center1: ``(x, y)`` centre used for the first image.
        center2: ``(x, y)`` centre used for the second image.
        s: scale factor applied to distances measured in the second image.

    Returns:
        The mean of ``|dist(p1, center1) - s * dist(p2, center2)|``.

    Raises:
        ValueError: if there are no matches to average over.
    """
    pairs = getPairs(keypointmathces, keypoints1, keypoints2)
    if not pairs:
        # Fail with a clear message instead of a ZeroDivisionError below.
        raise ValueError("calcError needs at least one keypoint match")
    # Builtin sum (no longer shadowed by a local) over the per-pair mismatch.
    total = sum(
        abs(calcDist(p1, center1) - s * calcDist(p2, center2))
        for p1, p2 in pairs
    )
    return total / len(pairs)

def optimizeCenter(keypointmathces, keypoints1, keypoints2, c1, c2,
                   d_step=0.1, s_step=0.0005, n_steps=10000):
    """Estimate translation and scale between two sets of matched keypoints.

    Runs a fixed-step coordinate search over ``(dx, dy, s)``. Iterations
    cycle through the x offset, the y offset and the scale. Each iteration
    moves that parameter one step in whichever direction gives the lower
    mean of ``|dist(p1, center1) - s * dist(p2, center2 + (dx, dy))|``
    over all matched pairs. A tie moves in the negative direction.

    Args:
        keypointmathces: matches with ``queryIdx``/``trainIdx`` indexing
            ``keypoints1``/``keypoints2``.
        keypoints1: keypoints of the first image, each exposing ``pt`` as
            ``(x, y)``.
        keypoints2: keypoints of the second image, same form.
        c1: centre of the first image, either an ``(x, y)`` pair or a
            scalar ``c`` meaning ``(c, c)`` (the original calling form).
        c2: centre of the second image, same form as ``c1``.
        d_step: step applied to ``dx``/``dy`` whenever they are updated.
        s_step: step applied to ``s`` whenever it is updated.
        n_steps: total number of search iterations.

    Returns:
        ``(tx, ty, s)``: ``-dx`` and ``-dy`` rounded to int, and the final
        scale ``s``.

    Raises:
        ValueError: if there are no matches.
    """
    if len(keypointmathces) == 0:
        raise ValueError("optimizeCenter needs at least one keypoint match")

    # Extract the matched positions once (the old loop rebuilt them on every
    # error evaluation). astype(int) truncates toward zero, like kp2pt.
    pts1 = np.array([keypoints1[m.queryIdx].pt for m in keypointmathces],
                    dtype=float).astype(int)
    pts2 = np.array([keypoints2[m.trainIdx].pt for m in keypointmathces],
                    dtype=float).astype(int)

    # A scalar centre broadcasts to (c, c), preserving the legacy behaviour;
    # an (x, y) pair allows correct centres for non-square images.
    center1 = np.broadcast_to(np.asarray(c1, dtype=float), (2,))
    center2 = np.broadcast_to(np.asarray(c2, dtype=float), (2,))

    # Distances in the first image never depend on (dx, dy, s).
    dist1 = np.linalg.norm(pts1 - center1, axis=1)

    def _error(dx, dy, s):
        # Mean absolute mismatch with the second centre shifted by (dx, dy).
        dist2 = np.linalg.norm(pts2 - (center2 + (dx, dy)), axis=1)
        return np.mean(np.abs(dist1 - s * dist2))

    dx = 0
    dy = 0
    s = 1
    for i in range(n_steps):
        phase = i % 3
        if phase == 0:
            better = _error(dx + d_step, dy, s) < _error(dx - d_step, dy, s)
            dx += d_step if better else -d_step
        elif phase == 1:
            better = _error(dx, dy + d_step, s) < _error(dx, dy - d_step, s)
            dy += d_step if better else -d_step
        else:
            better = _error(dx, dy, s + s_step) < _error(dx, dy, s - s_step)
            s += s_step if better else -s_step

    return (int(np.round(-dx)), int(np.round(-dy)), s)

In [None]:
def correctTranslate(original, attacked, max_matches=50):
    """Estimate how ``attacked`` is shifted/scaled relative to ``original``.

    ORB keypoints are detected in both images and matched by Hamming
    distance with cross-checking. The ``max_matches`` lowest-distance
    matches are then passed to ``optimizeCenter``.

    Args:
        original: reference image array.
        attacked: image to compare against ``original``.
        max_matches: how many of the best matches to use (default 50).

    Returns:
        The ``(tx, ty, s)`` tuple produced by ``optimizeCenter``.

    Raises:
        ValueError: if either image yields no ORB descriptors, or if no
            matches survive cross-checking.
    """
    # Initiate ORB detector
    orb = cv2.ORB_create()

    # Find the keypoints and descriptors with ORB.
    kp1, des1 = orb.detectAndCompute(original, None)
    kp2, des2 = orb.detectAndCompute(attacked, None)
    if des1 is None or des2 is None:
        # detectAndCompute can return None descriptors when no keypoints
        # are found; fail clearly instead of deep inside matching/optimizing.
        raise ValueError("no ORB features found in one of the images")

    # Brute-force Hamming matcher; crossCheck keeps only mutual best matches.
    bf = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True)

    # Match descriptors and sort them so the closest matches come first.
    matches = sorted(bf.match(des1, des2), key=lambda m: m.distance)
    if not matches:
        raise ValueError("no ORB matches between the two images")

    # NOTE(review): only shape[0] (the row count) is read, so the centre
    # handed on is the scalar (rows - 1) / 2, which optimizeCenter applies
    # to both x and y. That is the true image centre only for square images.
    (n1, *_) = original.shape
    (n2, *_) = attacked.shape

    return optimizeCenter(matches[:max_matches], kp1, kp2, (n1 - 1) / 2, (n2 - 1) / 2)