# In [None]:
import os
import rasterio
from rasterio.warp import calculate_default_transform, reproject, Resampling
from pyproj import CRS # don't use rasterio CRS

def reproject_to_crs(crs, input_path, output_path, resampling=None):
    """Reproject every band of a raster file into the CRS given by an EPSG code.

    The output grid (transform, width, height) is computed with
    ``rasterio.warp.calculate_default_transform`` from the source CRS, size
    and bounds. All other metadata in the source's ``meta`` (driver, dtype,
    band count, nodata, ...) is copied unchanged to the output file.

    Args:
        crs: EPSG code of the target CRS, passed to ``pyproj.CRS.from_epsg``.
        input_path: Path of the raster to read.
        output_path: Path of the raster file to create.
        resampling: ``rasterio.enums.Resampling`` member used when warping.
            ``None`` (the default) keeps the previous behaviour,
            ``Resampling.nearest``. Nearest-neighbour only ever copies
            existing source values, so it suits categorical data. For
            continuous data, pass e.g. ``Resampling.bilinear``.

    Side effects:
        Prints the source band-stack shape ``(count, height, width)``.
    """
    # Resolve the default inside the body so the signature does not need the
    # Resampling enum at definition time.
    if resampling is None:
        resampling = Resampling.nearest

    dst_crs = CRS.from_epsg(crs)
    with rasterio.open(input_path) as src:
        # This is the tuple ``src.read().shape`` would give. Building it from
        # metadata avoids loading every pixel of every band into memory just
        # to report its size.
        print((src.count, src.height, src.width))

        # Calculate the transform and dimensions for the new CRS.
        dst_transform, dst_width, dst_height = calculate_default_transform(
            src.crs, dst_crs, src.width, src.height, *src.bounds)

        # Output metadata: the source metadata with the grid and CRS replaced.
        out_meta = src.meta.copy()
        out_meta.update({
            "crs": dst_crs,
            "transform": dst_transform,
            "width": dst_width,
            "height": dst_height,
        })

        # Warp band by band. rasterio band indexes are 1-based.
        with rasterio.open(output_path, 'w', **out_meta) as dst:
            for band_index in range(1, src.count + 1):
                reproject(
                    source=rasterio.band(src, band_index),
                    destination=rasterio.band(dst, band_index),
                    src_transform=src.transform,
                    src_crs=src.crs,
                    dst_transform=dst_transform,
                    dst_crs=dst_crs,
                    resampling=resampling,
                )