In [17]:
from extract_lidar_features import *
from visualization_lidar import *
from load_lidar_data import *

In [None]:
# Preprocessing configuration shared by the training, evaluation and feature cells.
downsample_factor = 4
threshold = 10
transform_method = "normalization_without_mask"

# How many frames to load and their full-resolution size before downsampling.
NUM_IMAGES = 6
INPUT_H, INPUT_W = 1280, 1920
BATCH_SIZE = 4

# Load the raw projections (numpy array), then wrap them in the Dataset class.
# Distinct names keep the raw array reachable instead of silently rebinding it.
lidar_raw = load_lidar_data(num_images=NUM_IMAGES, input_h=INPUT_H, input_w=INPUT_W,
                            downsample=downsample_factor)
lidar_data = ProjectedLidarDataset(lidar_raw, downsample=downsample_factor,
                                   transform_method=transform_method, threshold=threshold)

# Reconstruction size consumed by the visualisation and feature-extraction cells.
# NOTE(review): `.shape` is provided by ProjectedLidarDataset; a later cell calls
# `height.cpu()`, which suggests tensor-valued entries — confirm.
num_samples, height, width = lidar_data.shape

# Create the training dataloader.
train_data_loader = create_dataloader(input_lidar=lidar_data, batch_size=BATCH_SIZE)

In [None]:
# Fit the convolutional autoencoder on the training loader; `losses` is the
# loss record returned by cae_train.
model, losses = cae_train(
    train_data_loader,
    lr=1e-5,
    epochs=200,
)

In [None]:
# Show reconstructions of the first training samples next to their inputs.
visualize_cae_result(
    model,
    lidar_data,
    device=device,
    num_images2show=6,
    recon_width=width,
    recon_height=height,
)


In [None]:
# Visualise reconstructions on a separately loaded evaluation set, preprocessed with
# exactly the same Dataset settings as the training data.
# NOTE(review): load_lidar_data is called with the same arguments as for training, so
# it may return the same frames; also confirm whether ProjectedLidarDataset's
# normalisation should reuse training statistics rather than recomputing them here.
test_raw = load_lidar_data(num_images=6, input_h=1280, input_w=1920, downsample=downsample_factor)
test_data = ProjectedLidarDataset(test_raw, downsample=downsample_factor,
                                  transform_method=transform_method, threshold=threshold)
visualize_cae_result(model, test_data, recon_height=height, recon_width=width,
                     num_images2show=6, device=device)

In [None]:
# Persist / restore the trained CAE as a pickled (model, losses) tuple.
# Flip SAVE_MODEL to True to write the freshly trained model before reloading it.
SAVE_MODEL = False

# Override via the CAE_MODEL_DIR environment variable on other machines.
CAE_MODEL_DIR = os.environ.get('CAE_MODEL_DIR', '/home/meowater/Documents/ssd_drive/CAE_models/')
model_path = os.path.join(CAE_MODEL_DIR, 'lidar_data_cae_model.pkl')

if SAVE_MODEL:
    os.makedirs(CAE_MODEL_DIR, exist_ok=True)
    with open(model_path, 'wb') as handle:
        pickle.dump((model, losses), handle, protocol=pickle.HIGHEST_PROTOCOL)

# SECURITY: pickle.load can execute arbitrary code — only load model files you created.
with open(model_path, 'rb') as f:
    model, saved_losses = pickle.load(f)

In [None]:
def output_cae_encoded(input_lidar, cae_model, device=None):
    """Run a single LiDAR projection through the CAE encoder and return it as numpy.

    Parameters
    ----------
    input_lidar : array-like
        One projected LiDAR image, presumably 2-D (height, width). Two leading
        singleton axes are added before encoding. The input is copied, never mutated.
    cae_model : torch.nn.Module
        Model exposing an ``encode`` method. It is switched to eval mode as a side
        effect and left there.
    device : torch.device or str, optional
        Device for the input tensor. Defaults to the device of the model's first
        parameter (CPU if the model has none), so no notebook-global is required.

    Returns
    -------
    numpy.ndarray
        The encoder output on CPU with the leading batch axis removed.
    """
    if device is None:
        # Follow the model's weights so input and model are always co-located.
        first_param = next(cae_model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    cae_model.eval()
    with torch.no_grad():
        # torch.tensor always copies, so an in-place op inside encode cannot touch the caller's array.
        x = torch.tensor(input_lidar, dtype=torch.float32, device=device)
        # (H, W) -> (1, 1, H, W): presumably batch and channel axes for the encoder.
        batched = x.unsqueeze(0).unsqueeze(0)
        encoded = cae_model.encode(batched)
        return encoded.squeeze(0).cpu().numpy()

In [None]:

# Encode every projected LiDAR frame with the trained CAE and pickle the features,
# mirroring the input's <context>/<file> directory layout under save_path.
# NOTE(review): absolute machine-specific paths — consider moving them to a config cell.
lidar_path = '/home/meowater/Documents/ssd_drive/lidar_projected/'
save_path = '/home/meowater/Documents/ssd_drive/lidar_projected_cae/'

# Exactly one directory level below lidar_path (no `**`, so recursion is not involved);
# sorted for a deterministic processing order.
lidar_list = sorted(glob.glob(os.path.join(lidar_path, '*', '*.pkl')))

# The resize target is loop-invariant. int() accepts plain ints, numpy scalars and
# single-element tensors, so it works whatever type `height`/`width` carry.
target_shape = (int(height), int(width))

for fn in tqdm(lidar_list):
    base_name = os.path.basename(fn)
    # splitext strips only the final extension, so dots inside the stem survive.
    name_prefix = os.path.splitext(base_name)[0]
    context_name = name_prefix.split('_camera_image_camera')[0]
    timestamp_name = name_prefix.split('_timestamp-')[-1]

    sub_path = os.path.join(save_path, context_name)
    os.makedirs(sub_path, exist_ok=True)

    # SECURITY: pickle.load executes arbitrary code — only run on trusted, self-generated files.
    with open(fn, 'rb') as handle:
        data = pickle.load(handle)

    projected_lidar = resize(data["lidar_projection"], output_shape=target_shape, anti_aliasing=True)
    output = output_cae_encoded(projected_lidar, model)

    output_data = {
        "context_name": context_name,
        "timestamp": timestamp_name,
        "lidar_extracted": output,
    }
    new_fn = os.path.join(sub_path, name_prefix + '_cae_feature.pkl')
    with open(new_fn, 'wb') as handle:
        pickle.dump(output_data, handle, protocol=pickle.HIGHEST_PROTOCOL)




MaskRCNN(
  (transform): GeneralizedRCNNTransform(
      Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
      Resize(min_size=(800,), max_size=1333, mode='bilinear')
  )
  (backbone): BackboneWithFPN(
    (body): IntermediateLayerGetter(
      (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
      (bn1): FrozenBatchNorm2d(64, eps=0.0)
      (relu): ReLU(inplace=True)
      (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
      (layer1): Sequential(
        (0): Bottleneck(
          (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn1): FrozenBatchNorm2d(64, eps=0.0)
          (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
          (bn2): FrozenBatchNorm2d(64, eps=0.0)
          (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (bn3): FrozenBatchNorm2d(256, eps=0.0)
          (relu): ReLU(in