In [None]:
# Mount Google Drive so the project inputs, model checkpoint and export folder
# used by this notebook (all under /content/drive) are reachable from this runtime.
import google.colab.drive as drive
drive.mount('/content/drive')

In [None]:
pip install rasterio lightning pytorch-msssim kornia==0.6.3 rio-cogeo

In [None]:
# Set before importing any library that may load an OpenMP runtime (torch, numpy, ...):
# the runtime reads this variable when it initialises, so setting it after the
# imports (as before) can be too late to suppress the duplicate-libiomp error.
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

import gc
import subprocess
import sys

# Make the project folder importable *before* loading the pipeline, so any module it
# imports by plain name (e.g. `import inference_pipline`) can be resolved.
sys.path.append('/content/drive/MyDrive/SuperResolution12RV2')

import matplotlib.pyplot as plt
import numpy as np
import rasterio
import torch
from rasterio.merge import merge
from rasterio.plot import show

# Pipeline module, resolved relative to the Colab working directory (/content).
import drive.MyDrive.SuperResolution12RV2.inference_pipline as ip

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# To force CPU execution: device = torch.device('cpu')
# Tensor factory calls without an explicit device= (torch.zeros, torch.tensor, ...)
# allocate on this device from here on.
torch.set_default_device(device=device)

In [None]:
# Run configuration: site to process, input / model / output locations, chip tiling.
# Absolute path, like model_path and export_path ("./drive/..." only worked with cwd=/content).
src = "/content/drive/MyDrive/SuperResolution12RV2/"
site = 3
# Eight source images per site: Site_{site}_Image_0.tif ... Site_{site}_Image_7.tif
files = [f"{src}Site_{site}_Image_{num}.tif" for num in range(8)]
model_path = f"{src}src/model.ckpt"
# Derived from `site` so switching sites cannot write into another site's output folder.
export_path = f"/content/drive/MyDrive/SuperResolutionInference/site{site}-full/"
# Chip tiling / normalisation options forwarded to ip.full_inference_to_chips.
overlap = 13
chip_size = 26
chip_norm_method = "global"
verbose = False

In [None]:
# Run super-resolution inference on every source image and export the predicted
# tiles into export_path; the stitching cell expects them named x{X}_y{Y}.tif there.
# Per the original author's note, COG export and error catching live inside
# inference_pipline (not visible in this notebook).
# mu / std are returned by the pipeline but are not read by any later cell shown here.
inference_options = {
    "batch_size": 8,
    "chip_size": chip_size,
    "overlap": overlap,
    "device": device,
    "chip_norm": chip_norm_method,
    "verbose": verbose,
}
mu, std = ip.full_inference_to_chips(files, model_path, export_path, **inference_options)

In [None]:
def load_tif(path, device='cpu'):
    '''
    Load every band of a raster file into a tensor.

    Parameters:
    - path (str): Path to a raster readable by rasterio.
    - device (str or torch.device): Device of the returned tensor (default 'cpu').

    Returns:
    - torch.Tensor holding the array from rasterio's ``read()`` (all bands), with
      the same dtype, placed on ``device``.
    '''
    with rasterio.open(path) as dataset:
        pixels = dataset.read()
    # Build the tensor directly on the requested device. torch.tensor(pixels) would
    # first allocate on the global default device (CUDA once torch.set_default_device
    # is used) and then copy, wasting GPU memory and a transfer for CPU loads.
    return torch.as_tensor(pixels, device=device)

In [None]:
# Release memory before the next steps: collect unreachable Python objects first
# (so CUDA tensors they held go back to PyTorch's allocator cache), then return
# the cached-but-unused CUDA memory.
for _release_memory in (gc.collect, torch.cuda.empty_cache):
    _release_memory()

In [None]:
# Stitch the exported super-resolved tiles into four quadrant GeoTIFFs
# ({export_path}group_*_normTest.tif), which the merge cell combines into one COG.
# NOTE: no normalisation is applied here; the per-tile renormalisation line is disabled.
gc.collect()
torch.cuda.empty_cache()

TILE_STRIDE = 208                    # input-pixel spacing encoded in tile names x{col*208}_y{row*208}.tif
OUT_TILE = 1248                      # side length (output pixels) of each exported tile
OUT_PER_IN = OUT_TILE // TILE_STRIDE  # = 6 output pixels per input pixel

# `test`: profile of the first source image, turned into the output template.
base_metadata_path = files[0]
with rasterio.open(base_metadata_path) as r:
    test = r.profile

width = test["width"]
height = test["height"]

# Output template: 3 bands, one tile in size, pixel size scaled by chip_size/156.
src_transform = test["transform"]
test["height"] = OUT_TILE
test["width"] = OUT_TILE
test["count"] = 3
test["transform"] = rasterio.Affine(
    src_transform.a * (chip_size / 156), src_transform.b, src_transform.c,
    src_transform.d, src_transform.e * (chip_size / 156), src_transform.f,
)

# Define quadrant parameters: (name, sec_x, sec_y)
quadrants = [
    ("group_top_left", False, False),
    ("group_top_right", True, False),
    ("group_bottom_left", False, True),
    ("group_bottom_right", True, True)
]

for quadrant_name, sec_x, sec_y in quadrants:
    # Each quadrant covers half of the full 208-px tile columns/rows; the right
    # (sec_x) and bottom (sec_y) quadrants take one extra column/row.
    n_cols = width // TILE_STRIDE // 2 + sec_x
    n_rows = height // TILE_STRIDE // 2 + sec_y

    # Input-pixel offset of this quadrant's first tile.
    start_x = sec_x * (n_cols - 1) * TILE_STRIDE
    start_y = sec_y * (n_rows - 1) * TILE_STRIDE

    n_height = n_rows * OUT_TILE
    n_width = n_cols * OUT_TILE
    # Assemble on the CPU: nothing here needs the GPU, and a quadrant canvas can be
    # several GB (the global default device may be CUDA).
    canvas = torch.zeros((3, n_height, n_width), dtype=torch.float32, device="cpu")

    for x in range(n_cols):
        for y in range(n_rows):
            tile_path = f"{export_path}x{x * TILE_STRIDE + start_x}_y{y * TILE_STRIDE + start_y}.tif"
            if not os.path.exists(tile_path):
                # Missing tiles are reported and left as zeros in the canvas.
                print(f"Missing tile: {tile_path}")
                continue

            infer = load_tif(tile_path, device="cpu").to(torch.float32)
            # Disabled per-tile renormalisation. NOTE(review): mu_g / std_g are not defined in
            # this notebook (the inference cell returns `mu, std`) — confirm before re-enabling.
            # infer = (infer - infer.mean(dim=(1,2), keepdim=True)) / infer.std(dim=(1,2), keepdim=True) * std_g + mu_g
            canvas[:, y * OUT_TILE:(y + 1) * OUT_TILE, x * OUT_TILE:(x + 1) * OUT_TILE] += infer
            del infer

    # Shift the origin by the quadrant's offset, converted from input to output pixels.
    base = test["transform"]
    quadrant_transform = rasterio.Affine(
        base.a, base.b, base.c + base.a * start_x * OUT_PER_IN,
        base.d, base.e, base.f + base.e * start_y * OUT_PER_IN,
    )

    quadrant_profile = test.copy()
    quadrant_profile.update({
        "transform": quadrant_transform,
        "height": n_height,
        "width": n_width
    })

    # The int32 cast truncates toward zero (unchanged behaviour).
    with rasterio.open(f"{export_path}{quadrant_name}_normTest.tif", 'w', **quadrant_profile) as w:
        w.write(canvas.to(torch.int32).numpy())

    # Free this quadrant's canvas before the next one is allocated.
    del canvas

In [None]:
# Release memory before merging: collect unreachable Python objects first (so CUDA
# tensors they held go back to PyTorch's allocator cache), then return the
# cached-but-unused CUDA memory.
for _release_memory in (gc.collect, torch.cuda.empty_cache):
    _release_memory()

In [None]:
# Merge the quadrant GeoTIFFs into one final Cloud Optimized GeoTIFF.
def _write_mosaic(files, temp_output):
    """Mosaic ``files`` with rasterio's ``merge`` and write the result to ``temp_output`` as a GTiff."""
    datasets = []
    try:
        for path in files:
            datasets.append(rasterio.open(path))

        mosaic, transform = merge(datasets)

        output_meta = datasets[0].meta.copy()
        output_meta.update({
            "driver": "GTiff",
            "height": mosaic.shape[1],
            "width": mosaic.shape[2],
            "transform": transform,
            "count": mosaic.shape[0]
        })

        with rasterio.open(temp_output, "w", **output_meta) as dest:
            dest.write(mosaic)
    finally:
        # Close inputs even if opening a later file, merging or writing fails.
        for dataset in datasets:
            dataset.close()


def merge_rasters(files, output_path):
    """
    Merge multiple raster files into one and save the result as a COG.

    The mosaic is first written to a temporary GeoTIFF next to ``output_path``
    and then converted with the ``rio cogeo create`` command-line tool. The
    temporary file is removed afterwards, also when the conversion fails.

    Parameters:
    - files (list of str): Paths to input raster files. Must be non-empty.
    - output_path (str): Path where the merged COG will be saved.

    Raises:
    - ValueError: if ``files`` is empty.
    - subprocess.CalledProcessError: if ``rio cogeo create`` exits non-zero.
    """
    if not files:
        raise ValueError("merge_rasters() needs at least one input raster")

    # Derive the temp name from the extension only. str.replace(".tif", ...) also
    # rewrote ".tif" elsewhere in the path, and when the path had no ".tif" the temp
    # file equalled output_path, so the final COG was deleted by the cleanup.
    root, ext = os.path.splitext(output_path)
    temp_output = f"{root}_temp{ext or '.tif'}"

    try:
        _write_mosaic(files, temp_output)
        print(f"Temporary raster saved to: {temp_output}")

        subprocess.run(["rio", "cogeo", "create", temp_output, output_path], check=True)
        print(f"Final Cloud Optimized GeoTIFF saved to: {output_path}")
    finally:
        # Never leave the (potentially large) intermediate GeoTIFF behind.
        if os.path.exists(temp_output):
            os.remove(temp_output)

In [None]:
# Mosaic the four quadrant GeoTIFFs written by the stitching cell into one COG.
# Use a separate name so `files` (the source images) is not clobbered for re-runs.
quadrant_files = [f"{export_path}{quadrant_name}_normTest.tif" for quadrant_name, _, _ in quadrants]

output = f"{export_path}merged_raster_COG.tif"
merge_rasters(quadrant_files, output)

In [None]:
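# Check output resolution (target ≈1.5 m per pixel). transform.a / transform.e are in the raster CRS's units, which are metres only for a projected CRS.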

# with rasterio.open("/content/drive/MyDrive/BenSuperResolutionInference/site0/x1872_y4160.tif") as src:
#     # Get pixel size in meters (for projected CRS like UTM)
#     x_res = src.transform.a  # X resolution (e.g., 1.5)
#     y_res = abs(src.transform.e)  # Y resolution (e.g., 1.5)
#     crs_units = src.crs.linear_units  # Confirm units (should be "metre")

# print(f"Resolution: {x_res}m (X), {y_res}m (Y)")

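# # Recorded result: Resolution: 1.4971921401992024e-05 (X), 1.4971921401992024e-05 (Y)
# # NOTE: far too small to be metres; presumably degrees of a geographic CRS (≈1.67 m at the
# # equator), so the "m" suffix printed above is misleading — confirm via src.crs.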
