In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds
from torchmetrics import MeanMetric
from tqdm import tqdm
import sklearn
import cv2
import torch
import torchvision
import torchmetrics
from tqdm import tqdm

from semsegcluster.eval_munkres import get_measurements, measure_from_confusion_matrix
import semsegcluster.data.scannet
from semsegcluster.settings import EXP_OUT
from semsegcluster.data.nyu_depth_v2 import TRAINING_LABEL_NAMES, NYU40_COLORS
import detectron2 as det2
import detectron2.utils.visualizer
from semsegcluster.sacred_utils import get_incense_loader
from deeplab.oaisys_utils import OAISYS_LABELS, data_converter_rugd, load_checkpoint
from semsegcluster.sacred_utils import get_checkpoint
from semsegcluster.data.tfds_to_torch import TFDataIterableDataset
from deeplab.marsscapes_utils import data_converter_marsscapes
from semsegcluster.marsscapes_eval_munkres import get_measurements

In [None]:
# Quick look at detectron2's palette: draw one random RGB colour and cast it to uint8.
np.asarray(det2.utils.colormap.random_color(rgb=True), dtype=np.uint8)

In [None]:
skip_frames = 10
uncert_treshold = -4.7
uncert = 'maxlogit-pp'
inference = 'inference'
stop_index = None
OAISYS_LABELS[15] = 'Sand'#'Grass-Field'
path = '/home/asl/Downloads/MarsScapes/processed'

directory = '/home/asl/Downloads/file_transfer/marsscapes/inference/deeplab/'
adapted_directory = '/home/asl/Downloads/file_transfer/marsscapes/adapted_inference/deeplab/'
cluster_directory = '/home/asl/Downloads/file_transfer/marsscapes/clustering/'
# Frames are enumerated from the cluster maps; this suffix is stripped to get the frame name.
cluster_suffix = '_seg440.npy'

# Overlay metadata: the first 11 OAISYS class names, then generic ids c0..c79.
m = det2.data.Metadata()
m.stuff_classes = OAISYS_LABELS[0:11] + [f'c{i}' for i in range(80)]
m.stuff_colors = NYU40_COLORS + NYU40_COLORS + NYU40_COLORS
# Binary OOD map: 0 -> 'outlier', 1 -> 'inlier'.
m_bool = det2.data.Metadata()
m_bool.stuff_classes = ['outlier', 'inlier']
m_bool.stuff_colors = NYU40_COLORS
m_label = det2.data.Metadata()
m_label.stuff_classes = OAISYS_LABELS[0:11] + [f'c{i}' for i in range(80)]
m_label.stuff_colors = NYU40_COLORS + NYU40_COLORS + NYU40_COLORS
# NOTE(review): index 15 of this list is the generic entry 'c4' (11 OAISYS names precede 'c0');
# renaming it 'Sand' presumably matches the label id used for sand — confirm.
m_label.stuff_classes[15] = 'Sand'

test_path = f'{path}/test/'
val_path = f'{path}/val/'
train_path = f'{path}/train/'

# os.listdir order is arbitrary, so sort to make the skip_frames subsampling reproducible;
# files without the expected suffix are ignored instead of turning into bogus frame names.
frame_names = sorted(
    file_name[:-len(cluster_suffix)]
    for file_name in os.listdir(cluster_directory)
    if file_name.endswith(cluster_suffix)
)
idx = 0
for idx, image_name in enumerate(frame_names, start=1):
  # show only every skip_frames-th frame
  if idx % skip_frames != 0:
    continue

  # load image: OpenCV reads BGR, convert to RGB scaled to [0, 1]
  image_file = os.path.join(test_path, f'{image_name}.png')
  image_bgr = cv2.imread(image_file)
  if image_bgr is None:  # cv2.imread signals failure by returning None, not by raising
    raise FileNotFoundError(f'could not read image: {image_file}')
  image = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB) / 255.

  # load label, predictions, uncertainty and clustering; every map is set to 255
  # wherever the label is 255
  label = np.load(os.path.join(directory, f'{image_name}_label.npy')).squeeze()
  ignore = label == 255
  pred = np.load(os.path.join(directory, f'{image_name}_ensemble_pred.npy')).squeeze()
  pred = np.where(ignore, 255, pred)
  adapted_file = os.path.join(adapted_directory, f'{image_name}_pred401-merged-ensemble_pred-seg107.png')
  adapted_raw = cv2.imread(adapted_file, cv2.IMREAD_ANYDEPTH)
  if adapted_raw is None:
    raise FileNotFoundError(f'could not read adapted prediction: {adapted_file}')
  adapted_pred = np.where(ignore, 255, adapted_raw.squeeze().astype(np.int32))
  uncertainty = np.load(os.path.join(directory, f'{image_name}_ensemble_{uncert}.npy')).squeeze()
  cluster = np.load(os.path.join(cluster_directory, f'{image_name}{cluster_suffix}')).squeeze()
  cluster = np.where(ignore, 255, cluster)
  # values below uncert_treshold are marked inlier (True/1), the rest outlier (False/0)
  inlier = np.where(ignore, 255, uncertainty < uncert_treshold)

  # (title, metadata, segmentation); segmentation None means show the plain image
  panels = [
      ('Image', None, None),
      ('Label', m_label, label),
      ('OOD Detection', m_bool, inlier),
      ('Deep Ensemble Prediction', m, pred),
      # +12 shifts cluster ids past the 11 OAISYS class names onto the generic 'c*' entries
      ('Deeplabv3 Clustering', m, cluster + 12),
      ('Adapted Prediction', m, adapted_pred),
  ]
  fig, axs = plt.subplots(1, len(panels), figsize=(30, 20))
  for ax, (title, metadata, seg) in zip(axs, panels):
    if seg is None:
      ax.imshow(image)
    else:
      vis = det2.utils.visualizer.Visualizer(image * 255, metadata, scale=0.4)
      ax.imshow(vis.draw_sem_seg(seg).get_image())
    ax.axis('off')
    ax.set_title(title)
  plt.show()
  # one figure is created per displayed frame; release it explicitly
  plt.close(fig)

In [None]:
subset = 'oaisys_trajectory_grass'
methods = ['label', 'ensemble_pred', 'pred401-merged-ensemble_pred-seg107', 'seg440']
OAISYS_LABELS[15] = 'Sand'
# NOTE(review): `get_measurements` is imported twice in the imports cell; the later import
# (semsegcluster.marsscapes_eval_munkres) shadows semsegcluster.eval_munkres and is what runs here.
# NOTE(review): the subset name refers to an OAISYS trajectory while the cell above visualizes
# MarsScapes outputs — confirm this is the intended evaluation set.
# Stored as `measurements` rather than `m`, which the plotting cells use for detectron2 Metadata.
measurements = get_measurements(subset=subset, pretrained_id='deeplab', methods=methods)

# Number of class ids scanned in the per-class IoU arrays.
num_eval_classes = 37
# Keep only classes whose IoU for the 'label' entry is not NaN.
label_iou = measurements['label']['assigned_iou']
present_classes = [l for l in range(num_eval_classes) if not np.isnan(label_iou[l])]

# rows: methods (excluding 'label'); columns: class names; values: assigned IoU
df = pd.DataFrame.from_dict({
    name: {OAISYS_LABELS[l]: measurements[name]['assigned_iou'][l] for l in present_classes}
    for name in measurements
    if name != 'label'
}).T
df = df.fillna(0)
df['mean'] = df.mean(axis=1)
df['v_score'] = [measurements[name]['v_score'] for name in df.index]
pd.set_option("display.float_format", lambda f: f"{f:0.2f}")
# LaTeX export for the paper table:
#print(df.to_latex(float_format=lambda x: f'{x * 100:.0f}'))
df.sort_values(by='mean', ascending=False)[:45]

### Predict and Visualize MarsScapes

In [None]:
label_path = '/home/asl/Downloads/MarsScapes/processed/train/439_2_2_13_1_semanticId.png'
# cv2.imread returns None (instead of raising) when the file is missing or unreadable.
label_bgr = cv2.imread(label_path)
if label_bgr is None:
  raise FileNotFoundError(f'could not read MarsScapes label image: {label_path}')
# Swap OpenCV's BGR order to RGB; channel 0 is what later cells use as the class-id map.
# NOTE(review): assumes the semanticId PNG stores the class id in its colour channels — confirm.
mars_label = cv2.cvtColor(label_bgr, cv2.COLOR_BGR2RGB)
print(mars_label[:, :, 0])
print(mars_label.shape)

In [None]:
skip_frames = 20
uncert_treshold = 0
subset = 'oaisys_trajectory3'
uncert = 'maxlogit-pp'
postfix = ''
stop_index = None
OAISYS_LABELS[15] = 'Sand'#'Grass-Field'
subsample = 1000
num_classes = 11
device = 'cpu'
split = 'validation'
pretrained_model = '/home/asl/Downloads/file_transfer/logs/282/deeplab_oaisys_1000_test_00004epochs.pth'
# Single MarsScapes frame (RGB image + semanticId label) passed through the same model.
image_path = '/home/asl/Downloads/MarsScapes/processed/train/439_2_2_13_1.png'
label_path = '/home/asl/Downloads/MarsScapes/processed/train/439_2_2_13_1_semanticId.png'

mars_labels = ['soil','bedrock','gravel','sand','big rock','steep slope','sky','unknown']


# DATA SETUP
data = tfds.load(f'{subset}', split=split, as_supervised=True)
dataset = TFDataIterableDataset(data.map(data_converter_rugd))
data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                           batch_size=1,
                                           pin_memory=True,
                                           drop_last=True)

# MODEL SETUP
model = torchvision.models.segmentation.deeplabv3_resnet101(
    pretrained=False,
    pretrained_backbone=False,
    progress=True,
    num_classes=num_classes,
    aux_loss=None)
checkpoint, pretrained_id = get_checkpoint(pretrained_model)
# remove any aux classifier stuff
removekeys = [k for k in checkpoint.keys() if k.startswith('aux_classifier')]
for k in removekeys:
  del checkpoint[k]
load_checkpoint(model, checkpoint)
model.to(device)
model.eval()

# DIRECTORY SETUP
directory = os.path.join(EXP_OUT, 'oaisys_inference', f'{subset}', pretrained_id)
os.makedirs(directory, exist_ok=True)

# INFERENCE
# Only the first batch of the split is needed for the side-by-side figure below.
image, label = next(iter(data_loader))
mars_bgr = cv2.imread(image_path)
if mars_bgr is None:  # cv2.imread signals failure by returning None, not by raising
  raise FileNotFoundError(f'could not read MarsScapes image: {image_path}')
# BGR -> RGB scaled to [0, 1], then HxWxC -> 1xCxHxW float32 as model input.
mars_img = torch.from_numpy(cv2.cvtColor(mars_bgr, cv2.COLOR_BGR2RGB) / 255.).permute(2, 0, 1)[None].float()

# Pure inference: skip autograd bookkeeping. Inputs are moved to `device` so that changing
# the config value above actually takes effect.
with torch.no_grad():
  logits = model(image.to(device))['out']
  max_logit, pred = torch.max(logits, 1)
  mars_logits = model(mars_img.to(device))['out']
  mars_max_logit, mars_pred = torch.max(mars_logits, 1)


m = det2.data.Metadata()
m.stuff_classes = OAISYS_LABELS + [f'c{i}' for i in range(80)]
m.stuff_colors = NYU40_COLORS + NYU40_COLORS + NYU40_COLORS
m_bool = det2.data.Metadata()
m_bool.stuff_classes = ['outlier', 'inlier']
m_bool.stuff_colors = NYU40_COLORS

m_mars = det2.data.Metadata()
m_mars.stuff_classes = mars_labels
m_mars.stuff_colors = NYU40_COLORS

# Convert to host numpy arrays for plotting: CxHxW -> HxWxC images, first batch element only.
image = image[0].permute(1, 2, 0).numpy()
label = label[0].numpy()
pred = pred[0].cpu().numpy()
mars_img = mars_img[0].permute(1, 2, 0).numpy()
label_bgr = cv2.imread(label_path)
if label_bgr is None:
  raise FileNotFoundError(f'could not read MarsScapes label image: {label_path}')
mars_label = cv2.cvtColor(label_bgr, cv2.COLOR_BGR2RGB)[:, :, 0]
mars_pred = mars_pred[0].cpu().numpy()

# (title, background image, metadata, segmentation); segmentation None shows the plain image
panels = [
    ('image', image, m, None),
    ('label', image, m, label),
    ('prediction ', image, m, pred),
    ('mars image', mars_img, m_mars, None),
    ('mars label', mars_img, m_mars, mars_label),
    ('mars prediction', mars_img, m, mars_pred),
]
fig, axs = plt.subplots(1, len(panels), figsize=(30, 20))
for ax, (title, background, metadata, seg) in zip(axs, panels):
  if seg is None:
    ax.imshow(background)
  else:
    vis = det2.utils.visualizer.Visualizer(background * 255, metadata, scale=0.4)
    ax.imshow(vis.draw_sem_seg(seg).get_image())
  ax.axis('off')
  ax.set_title(title)
plt.show()
