# In[ ]:
import numpy as np
from osgeo import gdal
from read_tif import read_tif  
import os

def write_tif(im_data, im_geotrans, im_proj, path, nodata_value=-999):
    """
    Write a numpy array as a GeoTIFF file.

    Parameters
    ----------
    im_data : array_like
        Raster data, 2D (rows, cols) or 3D (bands, rows, cols). A 2D array
        is written as a single-band image.
    im_geotrans : tuple
        Affine geotransform (e.g. taken from the original image).
    im_proj : str
        Projection information (WKT).
    path : str
        Output file path.
    nodata_value : int or float, optional
        Value substituted for non-finite pixels (NaN, inf, -inf) in
        floating-point input and recorded as each band's nodata value.
        It is only recorded when the output band type can represent it
        (e.g. it is skipped for unsigned integer bands). Default -999.

    Raises
    ------
    ValueError
        If ``im_data`` is not 2D or 3D.
    RuntimeError
        If GDAL cannot create the output file.
    """
    im_data = np.asarray(im_data)

    # Only floating/complex arrays can hold NaN/inf. Skipping integer input
    # also stops np.where from upcasting it (e.g. uint8 -> int16) to fit
    # the nodata value, which used to change the output band type.
    if np.issubdtype(im_data.dtype, np.inexact):
        im_data = np.where(np.isfinite(im_data), im_data, nodata_value)

    # Exact dtype-name lookup. The previous substring test ('int8' in name)
    # also matched 'uint8', and mapped signed int16 onto UInt16, so negative
    # values were wrapped on write.
    int_types = {
        "uint8": gdal.GDT_Byte,
        # GDT_Int8 only exists in newer GDAL; Int16 holds every int8 value.
        "int8": getattr(gdal, "GDT_Int8", gdal.GDT_Int16),
        "uint16": gdal.GDT_UInt16,
        "int16": gdal.GDT_Int16,
        "uint32": gdal.GDT_UInt32,
        "int32": gdal.GDT_Int32,
    }
    # Floats, 64-bit integers, bool, ... keep the original Float32 output.
    datatype = int_types.get(im_data.dtype.name, gdal.GDT_Float32)

    # Only record the nodata value if the band type can represent it
    # (e.g. -999 is outside the range of unsigned integer bands).
    if im_data.dtype.name in int_types:
        limits = np.iinfo(im_data.dtype)
        set_nodata = limits.min <= nodata_value <= limits.max
    else:
        set_nodata = True

    # Ensure im_data has shape (bands, rows, cols)
    if im_data.ndim == 2:
        im_data = im_data[np.newaxis, ...]
    elif im_data.ndim != 3:
        raise ValueError(
            f"im_data must be 2D or 3D, got shape {im_data.shape}")
    im_bands, im_height, im_width = im_data.shape

    # Create GeoTIFF
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(path, int(im_width), int(im_height),
                            int(im_bands), datatype)
    if dataset is None:
        # Previously this case returned silently without writing anything.
        raise RuntimeError(f"GDAL could not create GeoTIFF at {path!r}")

    try:
        dataset.SetGeoTransform(im_geotrans)
        dataset.SetProjection(im_proj)

        for i in range(im_bands):
            band = dataset.GetRasterBand(i + 1)  # GDAL band indices are 1-based
            if set_nodata:
                band.SetNoDataValue(nodata_value)
            band.WriteArray(im_data[i])
    finally:
        # Release resources (dropping the dataset reference closes it),
        # even if writing a band failed.
        dataset = None

# In[ ]:
if __name__ == "__main__":
    # Demo: read a reference raster and write it back out as a copy.
    ref_path = "example.tif"     # replace with your own file
    out_path = "copy_example.tif"
    (data_ref, X, Y,
     geotrans, proj, bands) = read_tif(ref_path)
    write_tif(data_ref, im_geotrans=geotrans, im_proj=proj, path=out_path)
    print("Saved:", out_path)