In [None]:
import shutil
from pathlib import Path
from os.path import join
from dataclasses import asdict
import pickle
import torch

from perception_bev_learning.dataset import get_bev_dataloader
from perception_bev_learning.cfg import ExperimentParams
from perception_bev_learning.lightning import LightningBEV
from perception_bev_learning.utils import denormalize_img

from dataclasses import asdict
import yaml
import os
import random
import time
from sklearn.decomposition import PCA
import numpy as np
from efficientnet_pytorch import EfficientNet
from torch import nn

In [None]:
# Build the experiment config, the Lightning module and the train/val dataloaders.
cfg = ExperimentParams()
cfg.update()
module = LightningBEV(cfg)
loader_train, loader_val = get_bev_dataloader(cfg, return_test_dataloader=False)

# Grab a single training batch for inspection. Fail here, with a clear message, if the
# loader is empty, rather than leaving the batch variables undefined for later cells.
batch = next(iter(loader_train), None)
if batch is None:
    raise RuntimeError("loader_train yielded no batches; check the training dataset configuration")
imgs, rots, trans, intrins, post_rots, post_trans, target, aux, *_, pcd_new = batch

In [None]:
# Baseline: run an ImageNet-pretrained EfficientNet-b0 trunk by hand (stem + blocks) and
# collect its intermediate feature maps in `endpoints`.
D = module._model.image_backbone.D
C = module._model.image_backbone.camC
trunk = EfficientNet.from_pretrained("efficientnet-b0")
# Inference only: nn.Modules start in training mode, where BatchNorm uses/updates batch
# statistics and the blocks apply drop-connect stochastically. eval() disables both.
trunk.eval()
BS, NR_CAMS = tuple(imgs.shape)[:2]

endpoints = dict()
with torch.no_grad():
    # Flatten BS and NR_CAMS into a single batch dimension; image shape comes from the data.
    x = imgs.flatten(0, 1)

    # Stem
    x = trunk._swish(trunk._bn0(trunk._conv_stem(x)))
    prev_x = x

    # Blocks: whenever the spatial size shrinks, store the feature map from just before it
    # as the next "reduction_<k>" entry.
    for idx, block in enumerate(trunk._blocks):
        drop_connect_rate = trunk._global_params.drop_connect_rate
        if drop_connect_rate:
            drop_connect_rate *= float(idx) / len(trunk._blocks)  # scale drop connect_rate
        x = block(x, drop_connect_rate=drop_connect_rate)
        if prev_x.size(2) > x.size(2):
            endpoints["reduction_{}".format(len(endpoints) + 1)] = prev_x
        prev_x = x
    # Head: the final block output becomes the last reduction entry.
    endpoints["reduction_{}".format(len(endpoints) + 1)] = x

# Features used for the PCA below: (BS * NR_CAMS, C, H, W).
# To also include the coarser map, upsample endpoints["reduction_5"] 2x (bilinear) and
# concatenate it with reduction_4 along dim=1.
img_features = endpoints["reduction_4"].detach()

In [None]:
from PIL import Image

# PCA over per-pixel image features of all samples/cameras; the first 3 components of
# sample 0 / camera 0 are rendered as an RGB image at input resolution.
n_components = 12
n_ch, feat_h, feat_w = img_features.shape[1:]
# (BS * NR_CAMS, C, H, W) -> (BS * NR_CAMS * H * W, C): one PCA sample per feature pixel.
# Channels must be moved last with permute before flattening; a plain reshape would mix
# values from different channels, cameras and pixels into each sample.
np_img_features = img_features.permute(0, 2, 3, 1).reshape(-1, n_ch).cpu().numpy()
pca = PCA(n_components=n_components, random_state=0)  # fixed seed for randomized solvers
pca.fit(np_img_features)
pca_comp = pca.transform(np_img_features).reshape(BS, NR_CAMS, feat_h, feat_w, n_components)
pca_comp_img = pca_comp[0, 0]  # (H, W, n_components)
# Upsample component maps to the input image size: (1, n, H, W) -> (H_img, W_img, n).
pca_com_inter = nn.functional.interpolate(
    torch.from_numpy(pca_comp_img).permute(2, 0, 1)[None],
    size=tuple(imgs.shape[-2:]),
    mode="bilinear",
    align_corners=True,
)[0].permute(1, 2, 0)
pca_com_three = pca_com_inter[:, :, 0:3].numpy()
# Jointly rescale the three components to [0, 1]; skip division for a constant map.
pca_com_three -= pca_com_three.min()
max_val = pca_com_three.max()
if max_val > 0:
    pca_com_three /= max_val

Image.fromarray(np.uint8(pca_com_three * 255))

In [None]:
# Display the de-normalized input image of batch sample 0, camera 0.
first_cam_img = imgs[0][0]
denormalize_img(first_cam_img)

In [None]:
from perception_bev_learning.network import BevTravNet

path_checkpoint_folder = "/media/Data/Results/bev_learning/2023-08-28T09-16-59_aux_regress_elevation"
device = "cpu"

# Set to an explicit .ckpt path to override; None picks the last checkpoint in the folder.
checkpoint = None
if checkpoint is None:
    # NOTE(review): "last" is by lexicographic path order (e.g. "epoch=10" sorts before
    # "epoch=9"); confirm the checkpoint naming makes the final entry the intended one.
    checkpoints = sorted(str(s) for s in Path(path_checkpoint_folder).rglob("*.ckpt"))
    if not checkpoints:
        raise FileNotFoundError(f"No *.ckpt files found under {path_checkpoint_folder}")
    checkpoint = checkpoints[-1]
# map_location lets checkpoints saved on GPU load onto the selected (possibly CPU-only) device.
ckpt = torch.load(checkpoint, map_location=device)
# pickle.load can execute arbitrary code: only load params from trusted result folders.
with open(join(path_checkpoint_folder, "experiment_params.pkl"), "rb") as file:
    cfg = pickle.load(file)

_cfg = cfg
# Initialize model and restore the trained weights
_model = BevTravNet(cfg.model, cfg.dataset_train)
_model.to(device)
_model.eval()
# Weights of the wrapped network are stored under the "_model." attribute prefix of the
# Lightning module; keep only those keys and strip the leading prefix (not nested matches).
prefix = "_model."
state_dict = {k[len(prefix):]: v for k, v in ckpt["state_dict"].items() if k.startswith(prefix)}
_model.load_state_dict(state_dict)

In [None]:
# Run the trained image backbone on the batch: frustum point geometry plus depth and
# context features. Inference only, so no autograd graph is built.
with torch.no_grad():
    geom = _model.image_backbone.get_geometry(rots, trans, intrins, post_rots, post_trans)
    # Flatten BS and NR_CAMS into a single batch dimension; image shape comes from the data.
    depth, x = _model.image_backbone.camencode.get_depth_feat(imgs.flatten(0, 1))
print(f"depth: {tuple(depth.shape)}, geom: {tuple(geom.shape)}")
aux[0]

In [None]:
from perception_bev_learning.visu import paper_colors_rgb_f
import open3d as o3d
import numpy as np

# Show the frustum points of batch sample 0 in Open3D, coloring N equal contiguous chunks
# of the flattened points with one palette color each.
# NOTE(review): assumes geom[0] is ordered camera-first so each chunk is one camera — confirm.
N = 4
points = geom
palette = list(paper_colors_rgb_f.values())
if N > len(palette):
    raise ValueError(f"N={N} chunks requested but paper_colors_rgb_f has only {len(palette)} colors")
co = torch.stack([torch.as_tensor(v, dtype=points.dtype, device=points.device) for v in palette[:N]])
pcd = o3d.geometry.PointCloud()
vis = points[0].reshape(-1, 3).detach().clone()
num_pts = vis.shape[0]
# Chunk index per point: identical to equal slicing when num_pts % N == 0, and any
# remainder points still receive a palette color (never their xyz values).
chunk_ids = torch.div(
    torch.arange(num_pts, device=points.device) * N, max(num_pts, 1), rounding_mode="floor"
)
col = co[chunk_ids]
pcd.points = o3d.utility.Vector3dVector(vis.cpu().numpy())
pcd.colors = o3d.utility.Vector3dVector(col.cpu().numpy())
mesh_frame = o3d.geometry.TriangleMesh.create_coordinate_frame(size=10.0, origin=[0, 0, 0])
o3d.visualization.draw_geometries([mesh_frame, pcd])