In [21]:
import cv2
import numpy as np
import pandas as pd
import time
import os
import glob
import shutil
from horizon.flir_boson_settings import I, D, P

import pickle
import tqdm
from scipy.spatial.transform import Rotation
from scipy.interpolate import RegularGridInterpolator, NearestNDInterpolator

import rasterio
import rasterio.plot
import rasterio.merge
import rasterio.mask
import pyproj
from PIL import ImageColor
import shapely
from shapely import Polygon
import skimage 

from utils.utils import thermal2rgb
from utils.projections import (
    project_points, 
    power_spacing, 
    world2cam, 
    create_world_grid,
    world_to_camera_coords,
)

from utils.draw import (
    draw_overlay_and_labels, 
    points_to_segmentation, 
    generate_binary_mask,
    colorize_dynamic_world_label,
    dynamic_world_color_map,
    chesapeake_cvpr_landcover_color_map,
)



##########################################################################
### Set path to flight sequence folder containing images and csv files.
##########################################################################
# DATA_PATH = '/data/onr-thermal/2022-12-20_Castaic_Lake/flight4'
# DATA_PATH = '/data/onr-thermal/2022-05-15_ColoradoRiver/flight3'
DATA_PATH = '/data/onr-thermal/caltech_duck/ONR_2023-03-22-14-41-46'
# DATA_PATH = '/data/onr-thermal/big_bear/ONR_2022-05-08-11-23-59'
# DATA_PATH = '/data/onr-thermal/kentucky_river/flight3-1'

BASE_PATH = '/data/microsoft_planetary_computer/outputs/preprocessed/'
EPSG = 'epsg-32618'
PLACE = 'duck'
LULC_TYPE = 'dynamicworld'
D3_TYPE = 'dsm'
RESOLUTION = '0.6'

# BASELINE_ELEVATION = 428.54 # Water elevation of Castaic Lake, Dec. 22, 2022
# BASELINE_ELEVATION = 0 # Water elevation of Colorado River, Parker Dam
# BASELINE_ELEVATION = 114 # Water elevation of Colorado River, Parker Dam
# BASELINE_ELEVATION = 2058 # Water elevation of big bear lake
# BASELINE_ELEVATION = 177 # kentucky river
BASELINE_ELEVATION = 0 # duck

LABEL_RASTER_PATH = os.path.join(BASE_PATH, EPSG, PLACE, LULC_TYPE, RESOLUTION, 'crf_naip_naip-nir', 'mosaic.tiff')
DSM_PATH = os.path.join(BASE_PATH, EPSG, PLACE, D3_TYPE, RESOLUTION, 'mosaic.tiff')

##########################################################################
### Create output folders
##########################################################################
if os.path.exists('outputs') and os.path.isdir('outputs'):
    shutil.rmtree('outputs')
os.makedirs('outputs')

color_map = dynamic_world_color_map()


##########################################################################
### Read csv of uav global/local pose
##########################################################################
print('Reading data...')
t0 = time.time()
alignment_data = pd.read_csv(os.path.join(DATA_PATH, "aligned.csv"), header=13)
# alignment_data = pd.read_csv(os.path.join(DATA_PATH, "aligned.csv"), header=14)
# alignment_data = pd.read_csv(os.path.join(DATA_PATH, "aligned.csv"), header=0)

alignment_data = alignment_data.applymap(lambda x: x.strip() if isinstance(x, str) else x)
alignment_data.columns = alignment_data.columns.str.replace(' ', '')
t1 = time.time()
print('{:3f} seconds to read csvs'.format(t1 - t0))


##########################################################################
### Get rectified camera matrix
##########################################################################
print('Creating new camera matrix...')
H, W = (512, 640)
newcameramtx, roi = cv2.getOptimalNewCameraMatrix(I, D, (W, H), 0, (W, H))
new_P = np.hstack([newcameramtx, np.zeros((3,1))])


##########################################################################
### Read raster data (dynamic world labels + dsm)
##########################################################################
t0 = time.time()
label_tiff_data = rasterio.open(LABEL_RASTER_PATH)
dsm = rasterio.open(DSM_PATH)

if os.path.exists('dw_interp.pkl') and os.path.exists('dsm_interp.pkl'):
    with open('dw_interp.pkl', 'rb') as f:
        label_interp = pickle.load(f)
    
    with open('dsm_interp.pkl', 'rb') as f:
        dsm_interp = pickle.load(f)

else:
    label_array = label_tiff_data.read()
    dsm_array = dsm.read()

    print(label_array.shape)
    n_bands, height, width = label_array.shape
    cols, rows = np.meshgrid(np.arange(width), np.arange(height))
    xs, ys = rasterio.transform.xy(label_tiff_data.transform, rows, cols)
    label_utm_grid = np.stack([xs, ys], axis=2).reshape(-1, 2)
    label_interp = NearestNDInterpolator(label_utm_grid, label_array.transpose(1, 2, 0).reshape(height*width, n_bands))

    with open('dw_interp.pkl', 'wb') as f:
        pickle.dump(label_interp, f, pickle.HIGHEST_PROTOCOL)

    print(dsm_array.shape)
    n_bands, height, width = dsm_array.shape
    cols, rows = np.meshgrid(np.arange(width), np.arange(height))
    xs, ys = rasterio.transform.xy(dsm.transform, rows, cols)
    dsm_utm_grid = np.stack([xs, ys], axis=2).reshape(-1, 2)
    dsm_interp = NearestNDInterpolator(dsm_utm_grid, dsm_array.transpose(1, 2, 0).reshape(height*width, 1))
    with open('dsm_interp.pkl', 'wb') as f:
        pickle.dump(dsm_interp, f, pickle.HIGHEST_PROTOCOL)
t1 = time.time()
print('{:3f} seconds to read rasters'.format(t1 - t0))

crs = label_tiff_data.crs
tform = pyproj.Transformer.from_crs("epsg:4326", "epsg:{}".format(crs.to_epsg()))


Reading data...
0.284083 seconds to read csvs
Creating new camera matrix...
(1, 23361, 24769)
(1, 23361, 24769)
423.878205 seconds to read rasters


In [22]:
import moderngl

import numpy as np
from PIL import Image

import cv2
from PIL import ImageColor
import pyvista as pv

def glOrtho(left, right, bottom, top, near, far):
    tx = -(right + left) / (right - left)
    ty = -(top + bottom) / (top - bottom)
    tz = -(far + near) / (far - near)
    x = np.array([
        [2 / (right - left), 0, 0, tx],
        [0, 2 / (top - bottom), 0, ty],
        [0, 0, -2 / (far - near), tz],
        [0, 0, 0, 1],
    ], dtype='f4')
    return x

HEX_COLORS = [
    '#419BDF', '#397D49', '#88B053', '#7A87C6', '#E49635', '#DFC35A',
    '#C4281B', '#ffffff', '#B39FE1', '#A8DEFF'
]
    
# K_gl = np.zeros((4,4), dtype='f4')
# K_gl[0,0] = -I[0,0]
# K_gl[1,1] = -I[1,1]
# K_gl[0,2] = (cols - I[0,2])
# K_gl[1,2] = (rows - I[1,2])
# K_gl[2,2] = A
# K_gl[2,3] = B
# K_gl[3,2] = 1

# NDC = np.zeros((4,4), dtype='f4')
# NDC[0,0] = -2 / cols
# NDC[1,1] = 2 / rows
# NDC[2,2] = -2 / (far - near)

# NDC[0,3] = 1
# NDC[1,3] = -1
# NDC[2,3] = -(far + near) / (far - near)
# NDC[3,3] = 1



def dynamic_world_color_map():
    rgb_colors = [ImageColor.getcolor(c, "RGB") for c in HEX_COLORS]
    color_map = dict(zip(list(range(0, 10)), rgb_colors))
    return color_map

def colorize_dynamic_world_label(label):
    mapping = dynamic_world_color_map()

    h = len(label)
    color_label = np.zeros((h, 3), dtype=np.uint8)
    for i in range(0, 10):
        color_label[label == i, :] = mapping[i] 
    return color_label / 255

def get_mask_mgl(pts, label):

    # Intrinsics matrix
    I_old = np.array([
        [511.03573247, 0.000000, 311.80346835], 
        [0.000000, 508.22913692, 261.56701122], 
        [0.000000, 0.000000, 1.000000]
    ])
    D = np.array([-0.43339599, 0.18974767, -0.00146426, 0.00118333, 0.000000])

    H, W = (512, 640)
    I, roi = cv2.getOptimalNewCameraMatrix(I_old, D, (W, H), 0, (W, H))
    # print(I)
    near = 0.1
    far = 10000
    rows = 512
    cols = 640
    A = -(near + far)
    B  = near * far
    
    K_gl = np.array([
        [I[0,0], 0, -I[0,2], 0],
        [0, I[1,1], -I[1,2], 0],
        [0, 0, near+far, near*far],
        [0, 0, -1, 0],
    ], dtype='f4')
    NDC = glOrtho(0, cols, rows, 0, near, far)

    P_gl = NDC @ K_gl
    # pts = np.load('outputs/thermal-10000.npy').T

    xyz = pts[[1, 2, 0], :].T
    xyz[:,2] *= -1
    xyz[:,0] *= -1

    # print("Input Label: ", label.shape)
    label = pts[3,:].T
    # print("After: ", label.shape)

    # num_vertices = xyz.shape[0]
    color = colorize_dynamic_world_label(label)
    # colors = np.hstack([color, np.ones(num_vertices).reshape(-1,1)])

    cloud = pv.PolyData(xyz)
    surf = cloud.delaunay_2d()
    print(surf.point_normals)
    vertex_normals = surf.point_normals
    vertices = surf.points
    vertex_colors = color

    faces = surf.faces.reshape(-1, 4)
    triangles = faces[:,1:]

    # -------------------
    # CREATE CONTEXT HERE
    # -------------------

    with moderngl.create_context(standalone=True, backend='egl') as ctx:

        prog = ctx.program(
            vertex_shader='''
                #version 330
                uniform mat4 proj;

                in vec3 in_vert;
                in vec3 in_color;

                out vec3 v_color;
                out vec4 pose;

                void main() {
                    v_color = in_color;
                    pose = proj*vec4(in_vert, 1.0);
                    gl_Position = proj*vec4(in_vert, 1.0);
                }
            ''',
            fragment_shader='''
                #version 330

                in vec3 v_color;
                out vec3 f_color;

                void main() {
                    f_color = v_color;
                }
            ''',
            varyings=['pose']
        )

        P_gl = np.ascontiguousarray(P_gl.T)
        prog['proj'].write(P_gl)
        
        vertices_info = np.hstack([vertices, vertex_colors]).astype('f4')

        fbo = ctx.simple_framebuffer((640, 512))
        fbo.use()
        
        ctx.enable(moderngl.DEPTH_TEST)
        ibo = ctx.buffer(triangles.astype('i4'))
        vbo = ctx.buffer(vertices_info.astype('f4'))
        vao = ctx.vertex_array(
            prog, 
            [(vbo, '3f 3f', 'in_vert', 'in_color')],
            ibo,
        )

        vao.render(moderngl.TRIANGLES)
        image = Image.frombytes('RGB', fbo.size, fbo.read(), 'raw', 'RGB', 0, 1)
        image.save('output.png')
        print(ctx.error)
        # Return as np array for plotting over original image
    return np.asarray(image) 

In [23]:
import importlib
import sys
# import pydensecrf.densecrf as dcrf

# from mgl_v2 import get_mask_mgl
import utils.draw
import utils.projections
import utils.postprocessing
# importlib.reload(sys.modules['mgl_v2']) 
importlib.reload(sys.modules['utils.draw']) 

importlib.reload(sys.modules['utils.projections']) 
importlib.reload(sys.modules['utils.postprocessing']) 

##########################################################################
### Begin segmentation projection here
##########################################################################
print('Starting segmentation estimation')
image_paths = sorted(glob.glob(os.path.join(DATA_PATH, 'images/thermal/*')))[2000::1000]
for t, img_path in tqdm.tqdm(enumerate(image_paths), total=len(image_paths)):
    # if t < 8:
    #     continue
    # img_path = os.path.join(DATA_PATH, 'images/thermal/thermal-50000.tiff')
    # img_path = os.path.join(DATA_PATH, 'images/thermal/thermal-20000.tiff')
    # img_path = '/home/carson/data/thermal/2022-05-15_ColoradoRiver/flight3/images/thermal/thermal-02500.tiff'
    # output_path = 'outputs/{}'.format(os.path.basename(img_path).replace('tiff', 'png'))

    image_data = alignment_data[alignment_data['image'] == "images/thermal/{}".format(os.path.basename(img_path))]
    # image_data = alignment_data[alignment_data['ir_file'] == "images/thermal/{}".format(os.path.basename(img_path))]
    if len(image_data) == 0:
        print('Skipping {}, no pose info...'.format(img_path))
        continue
    
    coords = image_data[['camLLA_lat', 'camLLA_lon']].values.astype(float)[0]
    cam_xyzw = image_data[['camNED_qx', 'camNED_qy', 'camNED_qz', 'camNED_qw']].values.astype(float)[0]
    height = image_data[['camNED_D']].values.astype(float)[0, 0]

    # coords = image_data[['uav_lat', 'uav_lon']].values.astype(float)[0]
    # cam_xyzw = image_data[['ir_qx', 'ir_qy', 'ir_qz', 'ir_qw']].values.astype(float)[0]
    # height = image_data[['ir_Z']].values.astype(float)[0, 0]
    
    dist_to_ground_plane = image_data[['riverNED_Z']].values.astype(float)[0, 0]
    if np.isnan(coords).any():
        print('Skipping {}, lat/lng (camLLA_lat/lng) has NaNs...'.format(img_path))
        continue
    if np.isnan(cam_xyzw).any():
        print('Skipping {}, quaternion has NaNs...'.format(img_path))
        continue
    if np.isnan(height).any():
        print('Skipping {}, uav altitude (camNED_D) has NaNs...'.format(img_path))
        continue
    if np.isnan(dist_to_ground_plane).any():
        print('Skipping {}, uav to ground distance (riverNED_Z) has NaNs...'.format(img_path))
        continue

    # print('Processing image: {}'.format(img_path))
    img = cv2.imread(img_path, -1)
    img = thermal2rgb(img)
    undistorted_image = cv2.undistort(img, I, D, None, newcameramtx)

    # dist_to_ground_plane = -9
    z = height + dist_to_ground_plane
    r = Rotation.from_quat(cam_xyzw)
    yaw, pitch, roll =  r.as_euler('ZYX', degrees=False)
    
    Nx = 500
    Ny = 300

    t0 = time.time()
    x_unit_vec, y_unit_vec, x_magnitudes, y_magnitudes, world_pts, xx, yy = utils.projections.create_world_grid(
        yaw, 
        x_mag=10000,
        y_mag=8000,
        Nx=Nx,
        Ny=Ny,
        exp_x=3,
        exp_y=3,
    )
    t1 = time.time()
    # print('{:3f} seconds to create world grid'.format(t1 - t0))
    
    utm_e, utm_n = tform.transform(coords[0], coords[1])
    rows, cols = rasterio.transform.rowcol(label_tiff_data.transform, xs=utm_e, ys=utm_n)

    ptA_utm = np.array([utm_e, utm_n])
    ptA_rc = np.array([cols, rows])

    # y_grid = y_unit_vec.reshape(2, 1) * y_magnitudes.reshape(1, N)
    # x_grid = ptA_utm.reshape(2, 1) + x_unit_vec.reshape(2, 1) * x_magnitudes.reshape(1, N)
    # utm_grid = x_grid.T.reshape(N, 1, 2) + y_grid.T.reshape(1, N, 2)

    y_grid = y_unit_vec.reshape(2, 1) * yy.reshape(1, Nx*Ny)
    x_grid = ptA_utm.reshape(2, 1) + x_unit_vec.reshape(2, 1) * xx.reshape(1, Nx*Ny)
    utm_grid = x_grid + y_grid 
    utm_grid = utm_grid.reshape(2, Ny, Nx).transpose(2, 1, 0)

    # t0 = time.time()
    # sampled_labels = rasterio.sample.sample_gen(label_tiff_data, xy=utm_grid.reshape(-1, 2))
    # sampled_z = rasterio.sample.sample_gen(dsm, xy=utm_grid.reshape(-1, 2))
    # t1 = time.time()
    # print('{:3f} seconds to sample from rasters'.format(t1 - t0))

    t0 = time.time()
    # print(utm_grid.shape)
    sampled_labels = label_interp(utm_grid.reshape(-1, 2))
    sampled_z = dsm_interp(utm_grid.reshape(-1, 2))
    t1 = time.time()
    # print('{:3f} seconds to sample from array'.format(t1 - t0))

    t0 = time.time()
    # word_coord_pts = np.zeros((N*N, 4)) # x, y, z, label
    # for i, (val, h) in enumerate(zip(sampled_labels, sampled_z)):
    #     print(type(val), h)
    #     word_coord_pts[i, 0:2] = world_pts[i]
    #     word_coord_pts[i, 2] = np.clip(h - BASELINE_ELEVATION, 0, None)
    #     word_coord_pts[i, 3] = val[-1] 
    #     exit(0)

    world_coord_z = np.clip(sampled_z.reshape(Nx*Ny, 1) - BASELINE_ELEVATION, 0, None)
    word_coord_pts = np.concatenate([world_pts.reshape(Nx*Ny, 2), world_coord_z], axis=1)
    N_LULC_CLASSES = 10
    # world_coord_labels = sampled_labels.reshape(Nx*Ny, N_LULC_CLASSES)
    world_coord_labels = sampled_labels.reshape(Nx*Ny, 1)
    # print("World Coordinate Labels", world_coord_labels)
    t1 = time.time()
    # print('{:3f} seconds to label world coordinates'.format(t1 - t0))

    # Label points in back first
    # ind = np.argsort(word_coord_pts[:,0])[::-1]
    # word_coord_pts = word_coord_pts[ind]
    # world_coord_labels = world_coord_labels[ind]
    word_coord_pts[:, 2] = -word_coord_pts[:, 2] - z

    ### SARASWATI: word_coord_pts are the 3D labels (x: forward, y: right, z: down)
    camera_pts = world_to_camera_coords(cam_xyzw, word_coord_pts).T
    labeled_camera_pts = np.concatenate([camera_pts, world_coord_labels[:,-1].reshape(-1, 1)], axis=1)
    name = os.path.basename(img_path).split('.')[0]
    np.save('outputs/{}.npy'.format(name), labeled_camera_pts)

    surface_elevation = np.copy(word_coord_pts[:, 2])
    
    ### SARASWATI: dig into world2cam to find where the word_coord_pts get transformed into the 
    ### camera coordinate frame (3D), and subsequently, projected into the 2D image frame. 
    Xn = world2cam(cam_xyzw, new_P, word_coord_pts)

    original_img, masked_img, pts_img = utils.draw.draw_overlay_and_labels(
        undistorted_image, 
        points=Xn, 
        labels=world_coord_labels[:,-1], 
        color_map=color_map
    )

    name = os.path.basename(img_path).split('.')[0]
    cv2.imwrite('outputs/{}.png'.format(name), original_img)
    cv2.imwrite('outputs/{}_autoseg.png'.format(name), masked_img)
    cv2.imwrite('outputs/{}_pts.png'.format(name), pts_img)

    ### Start postprocessing
    probability_img = utils.draw.project_prob(undistorted_image.shape[:2], Xn, world_coord_labels)
    surface_img = utils.draw.project_elevation(undistorted_image.shape[:2], Xn, surface_elevation)

    thermal_surface_img = np.copy(original_img)
    # thermal_surface_img[:,:,2] = surface_img / 1000 * 255
    ground_truth_img = np.copy(original_img)
    ground_truth_img[:,:,2] = surface_img / 1000 * 255
    # params = dict(
    #     sxy=(45, 45),
    #     srgb=(35, 35, 100), 
    #     compat=10, 
    #     kernel=dcrf.FULL_KERNEL, 
    #     normalization=dcrf.NORMALIZE_SYMMETRIC,
    #     inference_steps=3,
    # )

    # refined_labels = utils.postprocessing.dense_crf(probability_img, thermal_surface_img, **params)
    # refined_mask = utils.draw.colorize_dynamic_world_label(refined_labels)
    
    mgl_mask = get_mask_mgl(labeled_camera_pts.T, None)
    
    # cv2.imwrite('original_output/{}_autoseg_refined.png'.format(name), cv2.cvtColor(refined_mask, cv2.COLOR_RGB2BGR))
    overlay = cv2.addWeighted(original_img, 0.7, mgl_mask, 0.4, 0)
    cv2.imwrite('outputs/{}_autoseg_mgl.png'.format(name), cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR))


Starting segmentation estimation


  3%|▎         | 1/35 [00:07<04:18,  7.60s/it]

[[-0.00118662  0.9999833   0.00564619]
 [ 0.          0.          0.        ]
 [ 0.          0.          0.        ]
 ...
 [ 0.00474865  0.9158474   0.40149838]
 [ 0.00474865  0.9158474   0.40149838]
 [ 0.00474865  0.91584754  0.40149844]]
GL_NO_ERROR


  6%|▌         | 2/35 [00:14<03:50,  6.99s/it]

[[0.14377324 0.79661256 0.5871436 ]
 [0.         0.         0.        ]
 [0.         0.         0.        ]
 ...
 [0.14377326 0.7966127  0.58714366]
 [0.14377326 0.7966127  0.58714366]
 [0.14377324 0.79661256 0.5871436 ]]
GL_NO_ERROR


  9%|▊         | 3/35 [00:21<03:48,  7.13s/it]

[[0.0794169  0.75219125 0.6541418 ]
 [0.         0.         0.        ]
 [0.         0.         0.        ]
 ...
 [0.07981312 0.86507887 0.49524596]
 [0.07981312 0.86507887 0.49524596]
 [0.07981311 0.8650788  0.4952459 ]]
GL_NO_ERROR


 11%|█▏        | 4/35 [00:29<03:46,  7.29s/it]

[[ 0.39824978  0.77249986 -0.49461192]
 [ 0.          0.          0.        ]
 [ 0.          0.          0.        ]
 ...
 [ 0.10911317  0.8889362   0.4448447 ]
 [ 0.10911317  0.8889362   0.4448447 ]
 [ 0.10911316  0.88893616  0.4448447 ]]
GL_NO_ERROR


 14%|█▍        | 5/35 [00:36<03:41,  7.39s/it]

[[0.10301533 0.9115288  0.3981244 ]
 [0.         0.         0.        ]
 [0.         0.         0.        ]
 ...
 [0.1027948  0.91154945 0.39813444]
 [0.1027948  0.91154945 0.39813444]
 [0.10279478 0.9115493  0.3981344 ]]
GL_NO_ERROR


 17%|█▋        | 6/35 [00:44<03:39,  7.56s/it]

[[0.08475045 0.8732952  0.47976342]
 [0.         0.         0.        ]
 [0.         0.         0.        ]
 ...
 [0.08501433 0.9080696  0.41010016]
 [0.08501433 0.9080696  0.41010016]
 [0.08501433 0.9080696  0.41010013]]
GL_NO_ERROR


 20%|██        | 7/35 [00:51<03:28,  7.44s/it]

[[0.14166117 0.76199645 0.6318968 ]
 [0.         0.         0.        ]
 [0.         0.         0.        ]
 ...
 [0.14307539 0.84429336 0.51642835]
 [0.14307539 0.84429336 0.51642835]
 [0.14307539 0.84429336 0.5164283 ]]
GL_NO_ERROR


 23%|██▎       | 8/35 [01:01<03:39,  8.12s/it]

[[ 0.04599185 -0.8653282  -0.49909106]
 [ 0.          0.          0.        ]
 [ 0.          0.          0.        ]
 ...
 [ 0.04599185 -0.8653282  -0.4990911 ]
 [ 0.04599185 -0.8653282  -0.4990911 ]
 [ 0.04599185 -0.8653282  -0.4990911 ]]
GL_NO_ERROR


 26%|██▌       | 9/35 [01:11<03:52,  8.93s/it]

[[ 0.11707334 -0.79319274 -0.5976111 ]
 [ 0.          0.          0.        ]
 [ 0.          0.          0.        ]
 ...
 [ 0.13177991 -0.8857805  -0.44500214]
 [ 0.20735303 -0.9294345  -0.3052151 ]
 [ 0.1176472  -0.9670821  -0.22563596]]
GL_NO_ERROR


 26%|██▌       | 9/35 [01:13<03:31,  8.12s/it]


KeyboardInterrupt: 