# Semantic Segmentation of House Tour Videos
The first half of this notebook was adapted from the following [notebook](https://github.com/CSAILVision/semantic-segmentation-pytorch/tree/master/notebooks), which was originally made for running the benchmark semantic segmentation network from the [ADE20K MIT Scene Parsing Benchmark](http://sceneparsing.csail.mit.edu/).

## Prerequisites
Before running this notebook, make sure you have extracted all video frames (with numbers in the file name) and make sure you have the following folder structure:
```
├── Backprojection_CSAIL.ipynb
├── config
    ├── color150.mat
    ├── object150_info.csv
    ├── encoder_epoch_20.pth
    ├── decoder_epoch_20.pth
    └── ...
├── mit_semseg
    └── ...
└── PROJECT
    ├── img                              //Image frames (numbered from 1...n)
        └── ...
    ├── Model                            //Output folder from COLMAP
        ├── cameras.txt 
        ├── images.txt
        ├── points3D.txt
        └── project.ini
    ├── Model_Original                   //Backup copy of COLMAP output
        ├── cameras.txt
        ├── images.txt
        ├── points3D.txt
        └── project.ini
    ├── DB.db                            //Database needed in COLMAP
    ├── labels.txt                       //Labels mentioned in subtitles
    └── PROJECT_segmentation.npy         //Segmentation output from 1st part
```

## GLOBAL VARIABLES

In [None]:
# Frame size in pixels; used as the last two dimensions of the segmentation array.
WIDTH, HEIGHT = 1920, 1080
# Half-open range of zero-based frame indices iterated by the segmentation loop.
START_FRAME, END_FRAME = 0, 9018
# This should just be the name of the folder containing all the things shown above
PROJECT = "_FINAL"

## Imports and Utility Functions

In [None]:
# Parts taken from here: https://github.com/CSAILVision/semantic-segmentation-pytorch/blob/master/notebooks/DemoSegmenter.ipynb

# System libs
import os, csv, torch, numpy, scipy.io, PIL.Image, torchvision.transforms 
from tqdm.notebook import tqdm
import numpy as np
# Our libs
from mit_semseg.models import ModelBuilder, SegmentationModule
from mit_semseg.utils import colorEncode

# Colour palette stored under the 'colors' key of the .mat file.
colors = scipy.io.loadmat('config/color150.mat')['colors']

# Class-name lookup: key is the integer in the first CSV column, value is the
# first ';'-separated entry of the sixth column.
names = {}
with open('config/object150_info.csv') as f:
    rows = csv.reader(f)
    next(rows)  # skip the header row
    for row in rows:
        class_id = int(row[0])
        primary_name = row[5].split(";")[0]
        names[class_id] = primary_name

def visualize_result(img, pred, index=None):
    """Display an image next to its colour-encoded prediction.

    Args:
        img: image array placed on the left of the output.
        pred: per-pixel class prediction array.
        index: optional class index; when given, every pixel of a copy of
            ``pred`` that differs from it is set to -1 and the name looked
            up as ``names[index+1]`` is printed.
    """
    show_single_class = index is not None
    if show_single_class:
        # Work on a copy so the caller's prediction array is left untouched.
        pred = pred.copy()
        other_classes = pred != index
        pred[other_classes] = -1
        print(f'{names[index+1]}:')

    # Turn class indices into colours using the loaded palette.
    pred_color = colorEncode(pred, colors).astype(numpy.uint8)

    # Join input and prediction along axis 1 and show the result inline.
    im_vis = numpy.concatenate([img, pred_color], axis=1)
    display(PIL.Image.fromarray(im_vis))
    
# Network Builders
# ADJUST PATHS HERE: the two 'weights' entries point at the checkpoint files.
encoder_options = dict(
    arch='resnet50dilated',
    fc_dim=2048,
    weights='config/encoder_epoch_20.pth',
)
decoder_options = dict(
    arch='ppm_deepsup',
    fc_dim=2048,
    num_class=150,
    weights='config/decoder_epoch_20.pth',
    use_softmax=True,
)
net_encoder = ModelBuilder.build_encoder(**encoder_options)
net_decoder = ModelBuilder.build_decoder(**decoder_options)

# Loss passed to SegmentationModule; targets equal to -1 are ignored.
crit = torch.nn.NLLLoss(ignore_index=-1)
segmentation_module = SegmentationModule(net_encoder, net_decoder, crit)
# Switch to evaluation mode, then move the module to the GPU.
segmentation_module.eval()
segmentation_module.cuda()

## Image Segmentation

In [None]:
# Run this only if you want to create a new segmentation numpy file
# WARNING: You need around 18GB of RAM if the video has around 9000 frames 

# Normalisation pipeline, built once because it is identical for every frame.
pil_to_tensor = torchvision.transforms.Compose([
    torchvision.transforms.ToTensor(),
    torchvision.transforms.Normalize(
        mean=[0.485, 0.456, 0.406], # These are RGB mean+std values
        std=[0.229, 0.224, 0.225])  # across a large photo dataset.
])

# Array that stores the labels of all pixels of all frames.
# Row j holds the labels of frame index START_FRAME + j.
predictions = np.empty((END_FRAME - START_FRAME, HEIGHT, WIDTH), dtype=np.uint8)

for i in tqdm(range(START_FRAME, END_FRAME)):
    # ADJUST IMAGE PATH HERE (file names are 1-based, hence i+1):
    image_name = PROJECT + "/img/frame_id_" + f'{i+1:04d}' + ".jpg"

    # Load and normalize one image as a singleton tensor batch
    pil_image = PIL.Image.open(image_name).convert('RGB')
    # Not used inside the loop; kept so the last frame stays available afterwards.
    img_original = numpy.array(pil_image)
    img_data = pil_to_tensor(pil_image)
    singleton_batch = {'img_data': img_data[None].cuda()}
    output_size = img_data.shape[1:]

    # Run the segmentation at the highest resolution.
    with torch.no_grad():
        scores = segmentation_module(singleton_batch, segSize=output_size)

    # Pick the best-scoring class for each pixel
    _, pred = torch.max(scores, dim=1)
    pred = pred.cpu()[0].numpy()

    # Index relative to START_FRAME: the array has only END_FRAME - START_FRAME
    # rows, so indexing with i directly is wrong for any non-zero START_FRAME.
    predictions[i - START_FRAME] = pred.astype(dtype=np.uint8)


print(predictions.shape)

np.save(PROJECT +'/' + PROJECT +'_segmentation', predictions)

## Colouring 3D Points

In [None]:
# COLMAP helper functions, taken from here: https://github.com/colmap/colmap/blob/dev/scripts/python/read_write_model.py

import os
import collections
import numpy as np
import struct
import argparse

# Record types for one parsed image entry and one parsed 3D point entry.
BaseImage = collections.namedtuple(
    "Image", "id qvec tvec camera_id name xys point3D_ids")
Point3D = collections.namedtuple(
    "Point3D", "id xyz rgb error image_ids point2D_idxs")


class Image(BaseImage):
    """Image record that can convert its own ``qvec`` to a rotation matrix."""

    def qvec2rotmat(self):
        """Return the module-level ``qvec2rotmat`` applied to ``self.qvec``."""
        return qvec2rotmat(self.qvec)

def qvec2rotmat(qvec):
    """Build a 3x3 matrix from the first four components of ``qvec``.

    Args:
        qvec: indexable sequence; elements 0..3 are read.

    Returns:
        A 3x3 ``numpy`` array.
    """
    w, x, y, z = qvec[0], qvec[1], qvec[2], qvec[3]
    row0 = [1 - 2 * y**2 - 2 * z**2,
            2 * x * y - 2 * w * z,
            2 * z * x + 2 * w * y]
    row1 = [2 * x * y + 2 * w * z,
            1 - 2 * x**2 - 2 * z**2,
            2 * y * z - 2 * w * x]
    row2 = [2 * z * x - 2 * w * y,
            2 * y * z + 2 * w * x,
            1 - 2 * x**2 - 2 * y**2]
    return np.array([row0, row1, row2])

    
def read_points3D_text(path):
    """Parse a points3D text file into a dict of ``Point3D`` records.

    Blank lines and lines starting with ``#`` are skipped. Each remaining
    line is split on whitespace: field 0 is the point id, fields 1-3 the
    xyz floats, fields 4-6 the rgb ints, field 7 the error, and the rest
    alternates image id / 2D point index.

    Args:
        path: path of the text file to read.

    Returns:
        Dict mapping point id (int) to ``Point3D``.
    """
    points3D = {}
    with open(path, "r") as fid:
        for raw_line in fid:
            line = raw_line.strip()
            if not line or line.startswith("#"):
                continue
            elems = line.split()
            point3D_id = int(elems[0])
            track = elems[8:]  # alternating (image id, 2D point index)
            points3D[point3D_id] = Point3D(
                id=point3D_id,
                xyz=np.array([float(v) for v in elems[1:4]]),
                rgb=np.array([int(v) for v in elems[4:7]]),
                error=float(elems[7]),
                image_ids=np.array([int(v) for v in track[0::2]]),
                point2D_idxs=np.array([int(v) for v in track[1::2]]))
    return points3D

def read_images_text(path):
    """Parse an images text file into a dict of ``Image`` records.

    Blank lines and lines starting with ``#`` are skipped. Every record
    spans two lines: the first carries id, four qvec floats, three tvec
    floats, camera id and file name; the second carries whitespace-separated
    triples of x, y and 3D point id.

    Args:
        path: path of the text file to read.

    Returns:
        Dict mapping image id (int) to ``Image``.
    """
    images = {}
    with open(path, "r") as fid:
        lines = iter(fid)
        for raw_line in lines:
            line = raw_line.strip()
            if not line or line.startswith("#"):
                continue
            head = line.split()
            image_id = int(head[0])
            # The line right after the header holds this image's 2D observations.
            obs = next(lines, "").split()
            xs = [float(v) for v in obs[0::3]]
            ys = [float(v) for v in obs[1::3]]
            images[image_id] = Image(
                id=image_id,
                qvec=np.array([float(v) for v in head[1:5]]),
                tvec=np.array([float(v) for v in head[5:8]]),
                camera_id=int(head[8]),
                name=head[9],
                xys=np.column_stack([xs, ys]),
                point3D_ids=np.array([int(v) for v in obs[2::3]]))
    return images

def write_points3D_text(points3D, path):
    """Write a dict of point records to a points3D text file.

    A three-line comment header (with point count and mean track length) is
    written first, followed by one line per point: id, xyz, rgb, error and
    then the alternating image id / 2D point index pairs.

    Args:
        points3D: dict whose values expose ``id``, ``xyz``, ``rgb``,
            ``error``, ``image_ids`` and ``point2D_idxs``.
        path: destination file; it is opened in write mode.
    """
    if len(points3D) == 0:
        mean_track_length = 0
    else:
        total_track_length = sum(len(pt.image_ids) for pt in points3D.values())
        mean_track_length = total_track_length/len(points3D)
    HEADER = "".join([
        "# 3D point list with one line of data per point:\n",
        "#   POINT3D_ID, X, Y, Z, R, G, B, ERROR, TRACK[] as (IMAGE_ID, POINT2D_IDX)\n",
        "# Number of points: {}, mean track length: {}\n".format(len(points3D), mean_track_length),
    ])

    with open(path, "w") as fid:
        fid.write(HEADER)
        for pt in tqdm(points3D.values()):
            fields = " ".join(str(v) for v in [pt.id, *pt.xyz, *pt.rgb, pt.error])
            track = " ".join(
                str(v)
                for pair in zip(pt.image_ids, pt.point2D_idxs)
                for v in pair)
            fid.write(fields + " " + track + "\n")

### Loading .txt files

In [None]:
# Parse the points and images text files from the Model folder.
pts = read_points3D_text(PROJECT + "/Model/points3D.txt")
ims = read_images_text(PROJECT + "/Model/images.txt")
# Shallow working copy, so pts itself stays clean until just before overwriting.
res_pts = dict(pts)

# Load and preview keywords
keywords = np.loadtxt(PROJECT + "/labels.txt", dtype=str)
print(keywords)

### Create all mappings from 3D points to other information

In [None]:
# Create all necessary dictionaries; "3D" in a name means the ID of a 3D pt

# First dictionary: point ID -> two-column array
# (left: image ID, right: index of the x,y entry inside that image).
dict_3D_to_imageIDs = {
    point_id: np.column_stack((pts[point_id].image_ids, pts[point_id].point2D_idxs))
    for point_id in tqdm(pts.keys())
}

print(len(pts.keys()))

# Second dictionary: point ID -> list with one tuple per contributing image:
# (image file name, index of the x,y entry, [x, y] rounded to integers).
dict_3D_to_imageName_XY = {}
for point_id in tqdm(dict_3D_to_imageIDs.keys()):
    observations = []
    for image_id, point2D_idx in dict_3D_to_imageIDs[point_id]:
        image = ims[image_id]
        pixel_xy = list(np.rint(image.xys[point2D_idx]).astype(int))
        observations.append((image.name, point2D_idx, pixel_xy))
    dict_3D_to_imageName_XY[point_id] = observations


# Per-frame label maps written by the image segmentation section.
segmentation_labels = np.load(PROJECT +"/" + PROJECT +"_segmentation.npy")
print(segmentation_labels.shape)


import re

# Collect, for every 3D point, the labels voted for by all pixels that observe it.
# NOTE: `names` is looked up with lbl + 1, i.e. this expects the table loaded from
# object150_info.csv. The colour-dictionary cell rebinds `names` to a 0-based table,
# so re-running this cell after that one changes the lookup.

# Set membership replaces a scan of the keywords array on every single vote.
keyword_set = set(np.asarray(keywords).ravel().tolist())
n_frames, frame_height, frame_width = segmentation_labels.shape

dict_3D_to_labels = {}

for k in tqdm(dict_3D_to_imageName_XY):
    label_list = []
    for trip in dict_3D_to_imageName_XY[k]:
        #Extract frame number from image name (e.g., "out23.png" -> 23)
        frame_number = int(re.search(r'\d+', trip[0]).group())
        # Row 0 of segmentation_labels belongs to frame number START_FRAME + 1.
        frame_index = frame_number - 1 - START_FRAME
        if not 0 <= frame_index < n_frames:
            # No label map for this frame; a negative index would otherwise
            # silently wrap around to a frame at the end of the array.
            continue
        # Get pixel coordinates, clamped into the frame: without the clamp a
        # coordinate that rounds to 0 becomes -1 and reads the opposite border.
        x = min(max(int(trip[2][0]) - 1, 0), frame_width - 1)
        y = min(max(int(trip[2][1]) - 1, 0), frame_height - 1)
        # Get pixel label (plain int, so lbl + 1 cannot overflow uint8)
        lbl = int(segmentation_labels[frame_index, y, x])

        #Suppression: only count votes for classes that are listed in the keywords
        if names[lbl + 1] in keyword_set:
            label_list.append(lbl)
    dict_3D_to_labels[k] = label_list
    
# Helper function used to get most common class label for each point
from collections import Counter
def Most_Common(lst):
    """Return the most frequent element of ``lst`` as an ``int``.

    Args:
        lst: list of labels.

    Returns:
        The element with the highest count converted to ``int``, or
        ``None`` when ``lst`` is empty.
    """
    if not lst:
        return None
    (winner, _votes), = Counter(lst).most_common(1)
    return int(winner)


# For each pointID, reduce its label list to one label by majority vote
# (Most_Common yields None for a point whose list is empty).
dict_3D_to_SINGLE_LABEL = {
    point_id: Most_Common(dict_3D_to_labels[point_id])
    for point_id in tqdm(dict_3D_to_labels.keys())
}

print(len(dict_3D_to_SINGLE_LABEL.keys()))  

#### Colour dictionaries

In [None]:
# RGB colour ([R, G, B]) assigned to each class index 0-149 when the 3D points
# are recoloured. Several indices share one colour (e.g. 7, 23 and 138 are all
# [173, 0, 255]), so a colour does not identify a single class.

colours = {
    0 : [120 , 120 , 120],
    1 : [180 , 120 , 120],
    2 : [6 , 230 , 230],
    3 : [80 , 50 , 50],
    4 : [4 , 200 , 3],
    5 : [120 , 120 , 80],
    6 : [140 , 140 , 140],
    7 : [173, 0, 255],
    8 : [129, 238, 185],
    9 : [176, 254, 98],
    10 : [255, 0, 255],
    11 : [235 , 255 , 7],
    12 : [150 , 5 , 61],
    13 : [120 , 120 , 70],
    14 : [8, 255, 51],
    15 : [255, 0, 0],
    16 : [143 , 255 , 140],
    17 : [176, 254, 98],
    18 : [0, 255, 255],
    19 : [0, 0, 255],
    20 : [0 , 102 , 200],
    21 : [61 , 230 , 250],
    22 : [255, 176, 0],
    23 : [173, 0, 255],
    24 : [252, 93, 0],
    25 : [255, 9, 224],
    26 : [9 , 7 , 230],
    27 : [220, 220, 220],
    28 : [138, 64, 191],
    29 : [112 , 9 , 255],
    30 : [0, 0, 255],
    31 : [0, 0, 255],
    32 : [255 , 184 , 6],
    33 : [255, 0, 0],
    34 : [255 , 41 , 10],
    35 : [255, 0, 255],
    36 : [0 , 255, 0],
    37 : [102 , 8 , 255],
    38 : [255 , 61 , 6],
    39 : [0, 255 , 198],
    40 : [255 , 122 , 8],
    41 : [206, 92, 0],
    42 : [255 , 8 , 41],
    43 : [255 , 5 , 153],
    44 : [252, 93, 0],
    45 : [235 , 12 , 255],
    46 : [160 , 150 , 20],
    47 : [150, 79, 165],
    48 : [140 , 140 , 140],
    49 : [250 , 10 , 15],
    50 : [20, 255, 0],
    51 : [31 , 255 , 0],
    52 : [255 , 31 , 0],
    53 : [255 , 224 , 0],
    54 : [153 , 255 , 0],
    55 : [0 , 0 , 255],
    56 : [255 , 0, 0],
    57 : [0, 255 , 198],
    58 : [0 , 173 , 255],
    59 : [31, 0, 255],
    60 : [11 , 200 , 200],
    61 : [255 , 82 , 0],
    62 : [0 , 255 , 245],
    63 : [0 , 61 , 255],
    64 : [255 , 0, 0],
    65 : [0 , 255 , 133],
    66 : [255 , 0 , 10],
    67 : [255 , 163 , 0],
    68 : [176, 254, 98],
    69 : [255 , 0, 0],
    70 : [255 , 0, 0],
    71 : [51 , 255 , 0],
    72 : [0 , 82 , 255],
    73 : [0 , 255 , 41],
    74 : [143, 89, 2],
    75 : [0, 0, 255],
    76 : [173 , 255 , 0],
    77 : [0 , 255 , 153],
    78 : [255 , 92 , 0],
    79 : [255 , 0 , 255],
    80 : [255 , 0 , 245],
    81 : [0, 255, 255],
    82 : [0 , 255, 0],
    83 : [255 , 0 , 20],
    84 : [255 , 184 , 184],
    85 : [0 , 255, 0],
    86 : [0 , 255 , 61],
    87 : [0 , 255, 0],
    88 : [255 , 0 , 204],
    89 : [255, 255, 0],
    90 : [0 , 255 , 82],
    91 : [0 , 10 , 255],
    92 : [0, 255, 255],
    93 : [51 , 0 , 255],
    94 : [0 , 194 , 255],
    95 : [0 , 122 , 255],
    96 : [0 , 255 , 163],
    97 : [255 , 153 , 0],
    98 : [0 , 255 , 10],
    99 : [255 , 112 , 0],
    100 : [255, 176, 0],
    101 : [82 , 0 , 255],
    102 : [163 , 255 , 0],
    103 : [255 , 235 , 0],
    104 : [8 , 184 , 170],
    105 : [133 , 0 , 255],
    106 : [0 , 255 , 92],
    107 : [184 , 0 , 255],
    108 : [245, 121, 0],
    109 : [0 , 184 , 255],
    110 : [0, 0, 255],
    111 : [255 , 0 , 112],
    112 : [92 , 255 , 0],
    113 : [0 , 224 , 255],
    114 : [112 , 224 , 255],
    115 : [149, 187, 24],
    116 : [163 , 0 , 255],
    117 : [153 , 0 , 255],
    118 : [71 , 255 , 0],
    119 : [255 , 0 , 163],
    120 : [255 , 204 , 0],
    121 : [255 , 0 , 143],
    122 : [0 , 255 , 235],
    123 : [133 , 255 , 0],
    124 : [94, 183, 102],
    125 : [245 , 0 , 255],
    126 : [255 , 0 , 122],
    127 : [255 , 245 , 0],
    128 : [10 , 190 , 212],
    129 : [79, 130, 203],
    130 : [255, 255, 0],
    131 : [0, 255, 255],
    132 : [255 , 255 , 0],
    133 : [0 , 153 , 255],
    134 : [0 , 41 , 255],
    135 : [0 , 255 , 204],
    136 : [0 , 255, 0],
    137 : [41 , 255 , 0],
    138 : [173 , 0 , 255],
    139 : [0 , 245 , 255],
    140 : [71 , 0 , 255],
    141 : [255, 255, 0],
    142 : [0, 255, 184],
    143 : [255, 255, 0],
    144 : [184 , 255 , 0],
    145 : [0, 133, 255],
    146 : [255 , 214 , 0],
    147 : [25 , 194 , 194],
    148 : [239, 174, 31],
    149 : [92 , 0 , 255]
}

# Class names keyed by 0-based class index (0-149).
# NOTE(review): this rebinds the module-level `names` that the imports cell
# fills from object150_info.csv and that other cells index as names[index+1] /
# names[lbl+1]. Any of those cells re-run after this one sees this 0-based
# table instead - confirm the intended execution order.
# "screen" occurs twice (58 and 130), so a name -> colour mapping built from
# this table keeps only one of the two entries.
names = {
    0: "wall", 
    1: "building", 
    2: "sky", 
    3: "floor", 
    4: "tree", 
    5: "ceiling", 
    6: "road", 
    7: "bed", 
    8: "windowpane", 
    9: "grass", 
    10: "cabinet", 
    11: "sidewalk", 
    12: "person", 
    13: "earth", 
    14: "door", 
    15: "table", 
    16: "mountain", 
    17: "plant", 
    18: "curtain", 
    19: "chair", 
    20: "car", 
    21: "water", 
    22: "painting", 
    23: "sofa", 
    24: "shelf", 
    25: "house", 
    26: "sea", 
    27: "mirror", 
    28: "rug", 
    29: "field", 
    30: "armchair", 
    31: "seat", 
    32: "fence", 
    33: "desk", 
    34: "rock", 
    35: "wardrobe", 
    36: "lamp", 
    37: "bathtub", 
    38: "railing", 
    39: "cushion", 
    40: "base", 
    41: "box", 
    42: "column", 
    43: "signboard", 
    44: "chest", 
    45: "counter", 
    46: "sand", 
    47: "sink", 
    48: "skyscraper", 
    49: "fireplace", 
    50: "refrigerator", 
    51: "grandstand", 
    52: "path", 
    53: "stairs", 
    54: "runway", 
    55: "case", 
    56: "pool", 
    57: "pillow", 
    58: "screen", 
    59: "stairway", 
    60: "river", 
    61: "bridge", 
    62: "bookcase", 
    63: "blind", 
    64: "coffee", 
    65: "toilet", 
    66: "flower", 
    67: "book", 
    68: "hill", 
    69: "bench", 
    70: "countertop", 
    71: "stove", 
    72: "palm", 
    73: "kitchen", 
    74: "computer", 
    75: "swivel", 
    76: "boat", 
    77: "bar", 
    78: "arcade", 
    79: "hovel", 
    80: "bus", 
    81: "towel", 
    82: "light", 
    83: "truck", 
    84: "tower", 
    85: "chandelier", 
    86: "awning", 
    87: "streetlight", 
    88: "booth", 
    89: "television", 
    90: "airplane", 
    91: "dirt", 
    92: "apparel", 
    93: "pole", 
    94: "land", 
    95: "bannister", 
    96: "escalator", 
    97: "ottoman", 
    98: "bottle", 
    99: "buffet", 
    100: "poster", 
    101: "stage", 
    102: "van", 
    103: "ship", 
    104: "fountain", 
    105: "conveyer", 
    106: "canopy", 
    107: "washer", 
    108: "plaything", 
    109: "swimming", 
    110: "stool", 
    111: "barrel", 
    112: "basket", 
    113: "waterfall", 
    114: "tent", 
    115: "bag", 
    116: "minibike", 
    117: "cradle", 
    118: "oven", 
    119: "ball", 
    120: "food", 
    121: "step", 
    122: "tank", 
    123: "trade", 
    124: "microwave", 
    125: "pot", 
    126: "animal", 
    127: "bicycle", 
    128: "lake", 
    129: "dishwasher", 
    130: "screen", 
    131: "blanket", 
    132: "sculpture", 
    133: "hood", 
    134: "sconce", 
    135: "vase", 
    136: "traffic", 
    137: "tray", 
    138: "ashcan", 
    139: "fan", 
    140: "pier", 
    141: "crt", 
    142: "plate", 
    143: "monitor", 
    144: "bulletin", 
    145: "shower", 
    146: "radiator", 
    147: "glass", 
    148: "clock", 
    149: "flag"
}

# Class name -> colour of the class with the same index; when a name occurs
# more than once in names, the later index overwrites the earlier one.
name_to_colour = {names[idx]: colours[idx] for idx in names.keys()}

# Concatenated "RGB" digit string (no separators) -> class name; a later name
# with the same digit string overwrites an earlier one.
concat_color_to_name = {
    str(rgb[0]) + str(rgb[1]) + str(rgb[2]): class_name
    for class_name, rgb in name_to_colour.items()
}

## Creating mapping from 3D ID to final segmentation colour

In [None]:
# This prepares the colours of all segmented points

# For each pointID get the colour of the class that won the vote.
# The check must be `is not None`, not truthiness: class index 0 is a valid
# label but is falsy, so a truthiness test would leave every point voted as
# class 0 uncoloured. None marks a point whose label list was empty.
dict_3D_to_RGB = {}
for k in tqdm(dict_3D_to_SINGLE_LABEL.keys()):
    voted_label = dict_3D_to_SINGLE_LABEL[k]
    if voted_label is not None:
        dict_3D_to_RGB[k] = colours[voted_label]

print(len(dict_3D_to_SINGLE_LABEL.keys()))    


keys = dict_3D_to_RGB.keys()

# Make all points black (that way any non-segmented points are black)
for k in tqdm(res_pts.keys()):
    res_pts[k] = res_pts[k]._replace(rgb=[0, 0, 0])

# Now, go over all segmented points and colour the point according to the class RGB value.
# Every key of dict_3D_to_RGB already has a voted label, so no further filter is needed.
for k in tqdm(keys):
    res_pts[k] = res_pts[k]._replace(rgb=dict_3D_to_RGB[k])

In [None]:
# Write the points to the original .txt file and show the percentage of segmented points
write_points3D_text(res_pts, PROJECT +"/Model/points3D.txt")
# Guard against an empty point set, which would otherwise divide by zero.
segmented_percentage = 100 * len(dict_3D_to_RGB) / len(pts) if len(pts) else 0.0
# A single "%" here: this is plain concatenation, not %-formatting, so "%%" would be printed verbatim.
print(str(len(dict_3D_to_RGB)) + " segmented points out of " + str(len(pts)) + " total points = " +  str(segmented_percentage) + "% segmented")