# In [None]:
def IHS(data_low, data_high):
    """
    Fusion algorithm based on the IHS (intensity-hue-saturation) transform.

    Both images are projected into IHS space. The intensity component of the
    low-resolution image is replaced by the intensity of the high-resolution
    image, and the result is transformed back to RGB.

    Input: data_low, data_high -- np.ndarray of identical shape (3, rows, cols);
        data_low must already be resampled to the grid of data_high.
    Returns: the result of Image.fromarray on a (rows, cols, 3) uint8 array,
        contrast-stretched to 0..255 (presumably a PIL image).
    Raises: ValueError if the inputs are not matching 3-band arrays.
    """
    return Image.fromarray(_ihs_fuse(data_low, data_high))


def _ihs_fuse(data_low, data_high):
    """Return the IHS-fused image as a contiguous (rows, cols, 3) uint8 array."""
    # RGB -> IHS forward transformation matrix
    forward = np.array([
        [1. / 3., 1. / 3., 1. / 3.],
        [-np.sqrt(2) / 6., -np.sqrt(2) / 6., 2 * np.sqrt(2) / 6.],
        [1. / np.sqrt(2), -1. / np.sqrt(2), 0.],
    ])
    # IHS -> RGB inverse transformation matrix (the exact inverse of `forward`)
    inverse = np.array([
        [1., -1. / np.sqrt(2), 1. / np.sqrt(2)],
        [1., -1. / np.sqrt(2), -1. / np.sqrt(2)],
        [1., np.sqrt(2), 0.],
    ])

    high = np.asarray(data_high, dtype=np.float64)
    low = np.asarray(data_low, dtype=np.float64)
    if high.ndim != 3 or high.shape[0] != 3:
        raise ValueError(
            "data_high must have shape (3, rows, cols), got %r" % (high.shape,))
    if low.shape != high.shape:
        raise ValueError(
            "data_low shape %r does not match data_high shape %r"
            % (low.shape, high.shape))

    bands, rows, cols = high.shape
    ihs_high = np.dot(forward, high.reshape(bands, -1))  # high-res image -> IHS
    ihs_low = np.dot(forward, low.reshape(bands, -1))    # low-res image -> IHS
    # Replace the intensity band of the low-res image with the high-res one
    ihs_low[0, :] = ihs_high[0, :]
    # Back to RGB; reshape with the SAME (bands, rows, cols) order as the input,
    # otherwise non-square images come out scrambled.
    fused = np.dot(inverse, ihs_low).reshape(bands, rows, cols)

    lo, hi = fused.min(), fused.max()
    span = hi - lo
    if span <= 1e-9 * max(1.0, abs(lo), abs(hi)):
        # Numerically flat image: avoid 0/0 (NaN) and stretching float noise
        stretched = np.zeros_like(fused)
    else:
        stretched = (fused - lo) / span * 255
    # Truncate to uint8, then move the band axis last -> (rows, cols, 3)
    return np.ascontiguousarray(np.moveaxis(stretched.astype(np.uint8), 0, -1))
 
def imresize(data_low, data_high):
    """
    Image scaling function: resample every band of data_low onto the pixel
    grid of data_high using bilinear (order-1 spline) interpolation.

    Input: data_low, data_high -- np.ndarray of shape (bands, rows, cols);
        data_low needs at least as many bands as data_high.
    Returns: a float64 np.ndarray with the same shape as data_high.

    Values keep their original radiometric range. Unlike the removed
    scipy.misc.imresize, bands are not rescaled to 0..255 uint8.
    """
    from scipy import ndimage  # scipy.misc.imresize was removed in SciPy 1.3

    bands, rows, cols = data_high.shape
    data = np.zeros((bands, rows, cols))
    for i in range(bands):
        src = np.asarray(data_low[i], dtype=np.float64)
        factors = (rows / float(src.shape[0]), cols / float(src.shape[1]))
        # mode='nearest': samples at/near the border use edge pixels, not a 0 fill
        data[i] = ndimage.zoom(src, factors, order=1, mode='nearest')
    return data
 
def gdal_open(path):
    """
    Read image function: load raster bands 1, 2 and 3 of an image via GDAL.

    Input: image path.
    Returns: a float64 np.ndarray of shape (3, rows, cols), bands in order 1, 2, 3.
    Raises: OSError if GDAL cannot open the file; ValueError if the raster
        has fewer than three bands.
    """
    dataset = gdal.Open(path)
    if dataset is None:
        # gdal.Open signals failure by returning None when GDAL exceptions are off
        raise OSError("GDAL could not open raster: %r" % (path,))
    if dataset.RasterCount < 3:
        raise ValueError("expected at least 3 bands in %r, found %d"
                         % (path, dataset.RasterCount))
    cols = dataset.RasterXSize  # image width in pixels
    rows = dataset.RasterYSize  # image height in pixels
    bands = [
        dataset.GetRasterBand(i).ReadAsArray(0, 0, cols, rows).astype(np.float64)
        for i in (1, 2, 3)
    ]
    dataset = None  # drop the reference so GDAL releases the file handle
    return np.array(bands)


def main(path_low, path_high, out_path="IHS.png"):
    """
    Fuse a low-resolution image with a high-resolution one and save the result.

    path_low: low-resolution image; it is resampled onto path_high's grid.
    path_high: high-resolution image (its first three bands are read).
    out_path: destination file, always written in PNG format (default "IHS.png").
    Returns: the fused image object produced by IHS.
    """
    data_low = gdal_open(path_low)
    data_high = gdal_open(path_high)
    data_low = imresize(data_low, data_high)
    fused = IHS(data_low, data_high)
    fused.save(out_path, 'png')
    return fused
    
    
    
if __name__ == "__main__":
    # Script entry point: fuse the low-resolution RGB image with the
    # high-resolution band image and write the result to disk.
    path_low, path_high = 'RGB.tif', 'Band8.tif'
    main(path_low, path_high)