## Prerequisites
Before running this notebook, make sure you have masked all frames using OVSeg (with numbers in the file name) and make sure you have the following folder structure:
```
├── Backprojection_OVSeg.ipynb
└── PROJECT
    ├── masked_images                    //Output frames from OVSeg (numbered from 1...n)
        └── ...
    ├── Model                            //Output folder from COLMAP
        ├── cameras.txt 
        ├── images.txt
        ├── points3D.txt
        └── project.ini
    ├── Model_Original                   //Backup copy of COLMAP output
        ├── cameras.txt
        ├── images.txt
        ├── points3D.txt
        └── project.ini
    └── DB.db                            //Database needed in COLMAP
```

## GLOBAL VARIABLES

In [None]:
# Size in pixels of one frame; used as the (HEIGHT, WIDTH) shape of each label layer
WIDTH, HEIGHT = 1920, 1080
# Frames are processed with range(START_FRAME, END_FRAME), i.e. END_FRAME is exclusive
START_FRAME, END_FRAME = 0, 9018
# This should just be the name of the folder containing all the things shown above
PROJECT = "_FINAL"

## Imports

In [None]:
# System libs
import os, csv, torch, numpy, scipy.io, PIL.Image, cv2 
from tqdm.notebook import tqdm
import numpy as np

## Preparing Segmentation

In [None]:
# Maps the RGB colour of a mask pixel to a small temporary class label, so that
# frames can be kept as one label per pixel instead of three colour channels
# (mainly to save RAM). The label is the position of the colour in the list.
dict_RGB_to_label = {
    colour: label
    for label, colour in enumerate([
        (254,   0,   0),
        (  0,   0, 254),
        (  0, 255,   1),
        (255, 255,   0),
        (  0, 255, 255),
        (255,   0, 254),
        (174,   0, 255),
        (255, 176,   1),
        (255,  93,   0),
        (  0, 255, 199),
        (139,  64, 191),
        (239, 174,  32),
        ( 64, 129,  99),
        (150,  60,  98),
    ])
}

# Maps a class label back to the colour used to recolour the point cloud.
# Again the label is the position of the colour in the list.
dict_label_to_RGB = dict(enumerate([
    [255,   0,   0],
    [  0,   0, 255],
    [  0, 255,   0],
    [255, 255,   0],
    [  0, 255, 255],
    [255,   0, 255],
    [173,   0, 255],
    [255, 176,   0],
    [255,  93,   0],
    [  0, 255, 198],
    [138,  64, 191],
    [239, 174,  31],
    [ 64, 128,  99],
    [150,  60,  99],
]))

In [None]:
# This stores all frames as one 3D numpy array of labels (we don't store RGB values to save storage)
# WARNING: You need around 18GB of RAM if the video has around 9000 frames 

# Helper that turns one pixel colour into its class label
def better_getter(r,g,b):
    """Return the label stored in dict_RGB_to_label for the colour (r, g, b).

    Colours that are not a key of dict_RGB_to_label yield the sentinel 42.
    """
    try:
        return dict_RGB_to_label[(r, g, b)]
    except KeyError:
        # Not one of the defined class colours
        return 42

# Label array for all frames: one (HEIGHT, WIDTH) uint8 layer per processed frame
predictions = np.empty((END_FRAME - START_FRAME, HEIGHT, WIDTH),dtype=np.uint8)

# Build the element-wise colour -> label converter once, not once per frame.
# All labels fit into uint8, which is also the dtype of `predictions`.
rgb_to_label = np.vectorize(better_getter, otypes=[np.uint8])

# Store each frame as 2D label array in corresponding "layer" of predictions array
for i in tqdm(range(START_FRAME, END_FRAME)):
    # ADJUST IMAGE PATH HERE:
    image_name = PROJECT + "/masked_images/frame_id_" + f'{i+1:04d}' + ".jpg"
    frame_bgr = cv2.imread(image_name)
    if frame_bgr is None:
        # cv2.imread signals a missing/unreadable file by returning None
        raise FileNotFoundError("Could not read masked frame: " + image_name)
    # Reverse the channel axis so that channels 0, 1, 2 are R, G, B
    image = np.array(frame_bgr[:,:, ::-1], dtype=np.uint8)
    # The array only has END_FRAME - START_FRAME layers, so the layer index
    # must be relative to START_FRAME (layer 0 <-> first processed frame)
    predictions[i - START_FRAME] = rgb_to_label(image[:,:,0], image[:,:,1], image[:,:,2])

print(predictions.shape)

# Save array to make repeated runs faster. np.save appends ".npy", so this
# writes <PROJECT>/<PROJECT>_RGB_masks.npy, the file that the load cell reads.
np.save(PROJECT +'/' + PROJECT +'_RGB_masks', predictions)

In [None]:
# Shortcut for repeated runs: read the label array back from disk.
# Run this ONLY if you haven't freshly created `predictions`.
masks_path = f"{PROJECT}/{PROJECT}_RGB_masks.npy"
predictions = np.load(masks_path)
print(predictions.shape)

## Colouring 3D Points

In [None]:
# COLMAP helper functions, taken from here: https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py

import os
import collections
import numpy as np
import struct
import argparse

# Plain record types that the text readers below fill in
BaseImage = collections.namedtuple(
    "Image", "id qvec tvec camera_id name xys point3D_ids")
Point3D = collections.namedtuple(
    "Point3D", "id xyz rgb error image_ids point2D_idxs")


class Image(BaseImage):
    """Image record with a convenience accessor for its rotation matrix."""

    def qvec2rotmat(self):
        """Return the rotation matrix computed from this record's `qvec`."""
        rotation = qvec2rotmat(self.qvec)
        return rotation

def qvec2rotmat(qvec):
    """Convert a quaternion into a 3x3 rotation matrix.

    Only positions 0..3 of `qvec` are read; position 0 is used as the scalar
    part (w) and positions 1..3 as the vector part (x, y, z).
    NOTE(review): the formula does not normalise `qvec`; presumably callers
    pass unit quaternions -- confirm.
    """
    w, x, y, z = qvec[0], qvec[1], qvec[2], qvec[3]
    row0 = [1 - 2 * y**2 - 2 * z**2,
            2 * x * y - 2 * w * z,
            2 * z * x + 2 * w * y]
    row1 = [2 * x * y + 2 * w * z,
            1 - 2 * x**2 - 2 * z**2,
            2 * y * z - 2 * w * x]
    row2 = [2 * z * x - 2 * w * y,
            2 * y * z + 2 * w * x,
            1 - 2 * x**2 - 2 * y**2]
    return np.array([row0, row1, row2])

    
def read_points3D_text(path):
    """Parse a points3D text file into a dict of Point3D records.

    Blank lines and lines starting with "#" are ignored. Every other line is
    split on whitespace and read as
    ``ID X Y Z R G B ERROR (IMAGE_ID POINT2D_IDX)*``.

    Args:
        path: Path of the text file to read.

    Returns:
        Dict mapping the integer point ID to its Point3D record.
    """
    points3D = {}
    with open(path, "r") as fid:
        # readline() returns "" only at end of file, which ends the iteration
        for raw_line in iter(fid.readline, ""):
            line = raw_line.strip()
            if not line or line.startswith("#"):
                continue
            elems = line.split()
            point3D_id = int(elems[0])
            # Everything after the error value is the track: alternating
            # image IDs and 2D point indices
            track = elems[8:]
            points3D[point3D_id] = Point3D(
                id=point3D_id,
                xyz=np.array([float(v) for v in elems[1:4]]),
                rgb=np.array([int(v) for v in elems[4:7]]),
                error=float(elems[7]),
                image_ids=np.array([int(v) for v in track[0::2]]),
                point2D_idxs=np.array([int(v) for v in track[1::2]]))
    return points3D

def read_images_text(path):
    """Parse an images text file into a dict of Image records.

    Blank lines and lines starting with "#" are ignored. Each image takes two
    consecutive lines: a header ``ID QW QX QY QZ TX TY TZ CAMERA_ID NAME``
    followed by a line of flat ``X Y POINT3D_ID`` triples.

    Args:
        path: Path of the text file to read.

    Returns:
        Dict mapping the integer image ID to its Image record.
    """
    images = {}
    with open(path, "r") as fid:
        # readline() returns "" only at end of file, which ends the iteration
        for raw_line in iter(fid.readline, ""):
            header = raw_line.strip()
            if not header or header.startswith("#"):
                continue
            fields = header.split()
            image_id = int(fields[0])
            qvec = np.array([float(v) for v in fields[1:5]])
            tvec = np.array([float(v) for v in fields[5:8]])
            camera_id = int(fields[8])
            image_name = fields[9]
            # The line right after the header holds this image's 2D points
            observations = fid.readline().split()
            xs = [float(v) for v in observations[0::3]]
            ys = [float(v) for v in observations[1::3]]
            xys = np.column_stack([xs, ys])
            point3D_ids = np.array([int(v) for v in observations[2::3]])
            images[image_id] = Image(
                id=image_id, qvec=qvec, tvec=tvec,
                camera_id=camera_id, name=image_name,
                xys=xys, point3D_ids=point3D_ids)
    return images

def write_points3D_text(points3D, path):
    """Write Point3D records to a text file, one point per line.

    The file starts with a three-line "#" header that includes the number of
    points and the mean track length. Each point line is
    ``ID X Y Z R G B ERROR`` followed by its (IMAGE_ID, POINT2D_IDX) pairs,
    all separated by single spaces. A tqdm progress bar is shown while writing.

    Args:
        points3D: Dict mapping point IDs to Point3D records.
        path: Path of the file to write; it is opened in "w" mode.
    """
    num_points = len(points3D)
    if num_points:
        mean_track_length = sum(len(pt.image_ids) for pt in points3D.values()) / num_points
    else:
        mean_track_length = 0
    header_lines = [
        "# 3D point list with one line of data per point:\n",
        "#   POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[] as (IMAGE_ID, POINT2D_IDX)\n",
        "# Number of points: {}, mean track length: {}\n".format(num_points, mean_track_length),
    ]

    with open(path, "w") as fid:
        fid.write("".join(header_lines))
        for pt in tqdm(points3D.values()):
            head = " ".join(str(v) for v in (pt.id, *pt.xyz, *pt.rgb, pt.error))
            # Flatten the (image_id, point2D_idx) pairs into one space-separated run
            track = " ".join(
                str(v)
                for pair in zip(pt.image_ids, pt.point2D_idxs)
                for v in pair)
            fid.write(head + " " + track + "\n")

### Loading COLMAP .txt files

In [None]:
# Read the COLMAP text export of the project
model_dir = PROJECT + "/Model"
pts = read_points3D_text(model_dir + "/points3D.txt")
ims = read_images_text(model_dir + "/images.txt")
# Shallow working copy, so `pts` stays untouched up until just before overwriting
res_pts = dict(pts)

### Create all possible mappings from 3D points to other information

In [None]:
# Build the lookup dictionaries; "3D" in a name means the ID of a 3D point

# First dictionary: point ID -> 2-column array with one row per observation
# (left column: image ID, right column: index into that image's `xys`)
dict_3D_to_imageIDs = {
    point_id: np.column_stack((pts[point_id].image_ids, pts[point_id].point2D_idxs))
    for point_id in tqdm(pts.keys())
}

print(len(pts.keys()))

# Second dictionary: point ID -> list with one tuple per observation:
# (image file name, index into `xys`, [x, y] rounded to the nearest integer)
dict_3D_to_imageName_XY = {}
for point_id in tqdm(dict_3D_to_imageIDs.keys()):
    dict_3D_to_imageName_XY[point_id] = [
        (ims[image_id].name,
         xy_index,
         list(np.rint(ims[image_id].xys[xy_index]).astype(int)))
        for image_id, xy_index in dict_3D_to_imageIDs[point_id]
    ]

print(len(dict_3D_to_imageIDs.keys()))

import re

# Create dictionary going from pointID to all class labels for this point (used to later vote for the most common one)
num_frames, frame_height, frame_width = predictions.shape
skipped_votes = 0  # observations from frames that have no layer in `predictions`

dict_3D_to_lbl = {}
for k in tqdm(dict_3D_to_imageName_XY):
    lbl_list = []
    for trip in dict_3D_to_imageName_XY[k]:
        # Extract frame number from image name (e.g., "out23.png" -> 23)
        match = re.search(r'\d+', trip[0])
        if match is None:
            raise ValueError(f"Cannot extract a frame number from image name {trip[0]!r}")
        # Frame numbers in file names start at 1. `predictions` is allocated
        # with END_FRAME - START_FRAME layers, so layer 0 is taken to be the
        # first processed frame (number START_FRAME + 1).
        frame_index = int(match.group()) - 1 - START_FRAME
        if not 0 <= frame_index < num_frames:
            # A negative index would silently wrap around to the last frames
            # and a too large one would raise, so such votes are dropped
            skipped_votes += 1
            continue
        # Get pixel coordinates, clamped to the frame: without the clamp a
        # value of -1 wraps around to the opposite image border and a value
        # past the border raises an IndexError.
        # NOTE(review): the "-1" shift is kept from the original code; confirm
        # it matches the pixel convention of the COLMAP export.
        x = min(max(int(trip[2][0]) - 1, 0), frame_width - 1)
        y = min(max(int(trip[2][1]) - 1, 0), frame_height - 1)
        # Get pixel label as a plain Python int
        lbl = int(predictions[frame_index, y, x])

        # Suppression: only count votes for classes that were used for
        # segmentation (42 marks pixels whose colour is not a class colour)
        if lbl < 42:
            lbl_list.append(lbl)
    dict_3D_to_lbl[k] = lbl_list

if skipped_votes:
    print(str(skipped_votes) + " observations skipped because their frame is outside the loaded predictions")
    
    
# Voting helper: picks the most frequent class label of a point
from collections import Counter
def Most_Common(lst):
    """Return the most frequent element of `lst` converted to int.

    Returns None when `lst` is empty.
    """
    if not lst:
        return None
    winner, _votes = Counter(lst).most_common(1)[0]
    return int(winner)


# Majority vote: reduce each point's list of labels to the single most common
# label (Most_Common yields None for a point without any votes)
dict_3D_to_SINGLE_LABEL = {
    point_id: Most_Common(dict_3D_to_lbl[point_id])
    for point_id in tqdm(dict_3D_to_lbl.keys())
}

print(len(dict_3D_to_SINGLE_LABEL.keys()))

## Creating mapping from 3D ID to final segmentation colour

In [None]:
# This prepares the colours of all segmented points

# For each pointID get colour of corresponding class that got voted.
# Points without any vote carry the label None and are left out.
# The test must be `is not None`: label 0 is a valid class but is falsy, so a
# plain truthiness check would silently drop every point voted into class 0.
dict_3D_to_RGB = {}
for k in tqdm(dict_3D_to_SINGLE_LABEL.keys()):
    label = dict_3D_to_SINGLE_LABEL[k]
    if label is not None:
        dict_3D_to_RGB[k] = dict_label_to_RGB[label]

print(len(dict_3D_to_SINGLE_LABEL.keys()))


keys = dict_3D_to_RGB.keys()

# Make all points black (that way any non-segmented points are black)
for k in tqdm(res_pts.keys()):
    res_pts[k] = res_pts[k]._replace(rgb=[0, 0, 0])

# Now, go over all segmented points and colour the point according to the class RGB value.
# Every key of dict_3D_to_RGB has a voted label, so no further check is needed.
for k in tqdm(keys):
    res_pts[k] = res_pts[k]._replace(rgb=dict_3D_to_RGB[k])

In [None]:
# Write the points to the original .txt file and show the percentage of segmented points
write_points3D_text(res_pts, PROJECT +"/Model/points3D.txt")

num_segmented = len(dict_3D_to_RGB)
num_total = len(pts)
# Guard against an empty model, which would otherwise divide by zero
percentage = 100 * num_segmented / num_total if num_total else 0.0
# A single "%" is correct here: "%%" is only an escape in %-style formatting
# and would be printed as two percent signs in a plain string
print(f"{num_segmented} segmented points out of {num_total} total points = {percentage}% segmented")