In [None]:
import cv2
import numpy as np
import os
import matplotlib.pyplot as plt

In [None]:
# The project root is taken to be the parent of the current working
# directory, i.e. the notebook is expected to be launched from a sub-folder.
ROOT_DIR = os.path.split(os.getcwd())[0]
DATA_FOLDER = os.path.join(ROOT_DIR, "data")

In [None]:
def correct_distortion(image, k1=-0.5, k2=0, alpha=1, crop=False):
    """
    Correct radial lens distortion of an image using OpenCV.

    No real calibration is used: a synthetic pinhole camera matrix is built
    with focal length fx = fy = image width and the principal point at the
    image centre. Only the radial coefficients k1 and k2 are applied; the
    tangential coefficients p1 and p2 are fixed at 0.

    Parameters
    ----------
    image : numpy.ndarray
        Image array, e.g. as returned by ``cv2.imread``.
    k1, k2 : float
        Radial distortion coefficients.
    alpha : float, default 1
        Free-scaling parameter forwarded to ``cv2.getOptimalNewCameraMatrix``.
        Per the OpenCV docs, 0 keeps only valid pixels and 1 keeps all source
        pixels, which may leave black borders.
    crop : bool, default False
        If True, crop the result to the valid-pixel ROI reported by
        ``cv2.getOptimalNewCameraMatrix``.

    Returns
    -------
    numpy.ndarray
        The undistorted image (cropped when ``crop`` is True and the ROI is
        non-empty).

    Raises
    ------
    ValueError
        If ``image`` is None or empty. ``cv2.imread`` returns None when it
        cannot read a file.
    """
    if image is None or image.size == 0:
        raise ValueError(
            "correct_distortion() received an empty image; "
            "check that cv2.imread() was able to read the file."
        )

    height, width = image.shape[:2]
    # Uncalibrated guess: square pixels with focal length equal to the width,
    # optical centre in the middle of the frame.
    camera_matrix = np.array([[width, 0, width / 2],
                              [0, width, height / 2],
                              [0, 0, 1]], dtype=np.float32)
    # OpenCV coefficient order is (k1, k2, p1, p2); tangential terms stay zero.
    dist_coeffs = np.array([k1, k2, 0, 0], dtype=np.float32)

    new_camera_matrix, roi = cv2.getOptimalNewCameraMatrix(
        camera_matrix, dist_coeffs, (width, height), alpha, (width, height))

    undistorted = cv2.undistort(
        image, camera_matrix, dist_coeffs, None, new_camera_matrix)

    if crop:
        x, y, w, h = roi
        # A degenerate ROI (zero width/height) would yield an empty image.
        if w > 0 and h > 0:
            undistorted = undistorted[y:y + h, x:x + w]

    return undistorted

In [None]:
# Path (a string, not pixel data) of the distorted sample image.
image = os.path.join(DATA_FOLDER, "distortion.png")
# cv2.imread() returns None instead of raising for a missing file, so fail
# loudly here rather than with a confusing error in a later cell.
if not os.path.isfile(image):
    raise FileNotFoundError(f"Sample image not found: {image}")

In [None]:
image1 = cv2.imread(image)
# cv2.imread signals a missing or unreadable file by returning None.
if image1 is None:
    raise OSError(f"cv2.imread could not read image: {image}")
corrected1 = correct_distortion(image1)

In [None]:
plt.figure(figsize=(20, 20))
# cv2.imread loads channels in BGR order while matplotlib expects RGB;
# convert so the displayed colours are correct.
plt.imshow(cv2.cvtColor(corrected1, cv2.COLOR_BGR2RGB))
plt.title("Corrected image (default k1, k2)")
plt.axis("off")
plt.show()

In [None]:
# Sweep k1 and k2 over the same grid as np.arange(-1, 1, 0.2), i.e. -1.0 to 0.8
# in steps of 0.2. Rounding strips accumulated floating-point error so the
# plot titles read cleanly.
coefficient_grid = np.round(np.arange(-1, 1, 0.2), 1)

# The source image does not change between iterations, so read it once.
image1 = cv2.imread(image)
if image1 is None:
    raise OSError(f"cv2.imread could not read image: {image}")

for k1 in coefficient_grid:
    for k2 in coefficient_grid:
        corrected1 = correct_distortion(image1, k1=k1, k2=k2)
        fig, ax = plt.subplots(figsize=(20, 20))
        # OpenCV images are BGR; convert so matplotlib shows true colours.
        ax.imshow(cv2.cvtColor(corrected1, cv2.COLOR_BGR2RGB))
        ax.set_title(f"k1={k1:+.1f}, k2={k2:+.1f}")
        ax.axis("off")
        plt.show()
        # Release the figure explicitly; 100 large figures are produced here.
        plt.close(fig)