## Our dependencies

In [None]:
# Notebook dependencies, installed through IPython shell escapes.
!pip install omegaconf
!pip install  pyntcloud
!pip install open3d
# NOTE(review): the OpenCV Python bindings are normally published as `opencv-python`; confirm this name resolves.
!pip install OpenCV
!pip install Plotly
!pip install psutil requests
# aria2 provides the `aria2c` downloader used by the data-download cell.
!apt install aria2
# NOTE(review): presumably meant for Plotly static-image export; confirm the apt package `orca` is the intended tool.
!apt-get install -y orca
import tarfile
import os
import pyarrow.feather as feather
import numpy as np
import plotly.graph_objs as go
from ipywidgets import Button, VBox
from IPython.display import display, clear_output
from PIL import Image
import matplotlib.pyplot as pyplot
import tensorflow as tf
from tensorflow.keras.layers import Input, Dense, Reshape, Conv2D, BatchNormalization, Activation, UpSampling2D, Flatten
from tensorflow.keras.models import Model
import pyarrow.feather as feather
import open3d as o3d

## Downloading the data
##### Note: This will take a while. Please use aria2c, as it speeds up the download. Read more: add_Me

In [None]:
# Create directories for train and test files if they don't exist
!mkdir -p /argotrain
#!mkdir -p /argotest

# Train 1: download the first training shard into /argotrain, using up to 16 connections (-x 16)
!aria2c -x 16 https://s3.amazonaws.com/argoverse/datasets/av2/tars/sensor/train-000.tar -d /argotrain -o train-000.tar

# Test 1: download the first test shard into /argotest
# NOTE(review): later cells read /argotrain/sensor/train/...; no step that unpacks the .tar is visible in this cell - confirm the archive is extracted before they run.
!aria2c -x 16 https://s3.amazonaws.com/argoverse/datasets/av2/tars/sensor/test-000.tar -d /argotest -o test-000.tar


## To visualize lidar data

import plotly.graph_objs as go
import numpy as np
import pyarrow.feather as feather
import numpy as np
import open3d as o3d


lidar_file = '/argotrain/sensor/train/0749e9e0-ca52-3546-b324-d704138b11b5/sensors/lidar/315972769159647000.feather'

# Load one lidar sweep; every row of the resulting array is one point.
lidar_df = feather.read_feather(lidar_file)
lidar_data = lidar_df.to_numpy()

# Prepare your point cloud data
x = lidar_data[:, 0]
y = lidar_data[:, 1]
z = lidar_data[:, 2]

# Colour the markers by the 4th column when the sweep has one, otherwise use a constant.
intensity = lidar_data[:, 3] if lidar_data.shape[1] > 3 else np.ones_like(x)
camera = dict(
    eye=dict(x=0.980957, y=-4.300781, z=1.241211)  # Position the camera here (x, y, z are the distances)
)

# Create a 3D scatter plot
fig = go.Figure(data=[go.Scatter3d(
    x=x, y=y, z=z,
    mode='markers',
    marker=dict(
        size=1,
        color=intensity,
        colorscale='Viridis',
        opacity=0.8
    )
)])

# Set plot details.
# The camera must go to `scene_camera`; `meta` is free-form figure metadata and
# never moves the viewpoint, so the original `meta=camera` silently did nothing.
fig.update_layout(
    scene=dict(
        xaxis_title='X',
        yaxis_title='Y',
        zaxis_title='Z'),
    width=700,
    margin=dict(r=0, l=0, b=0, t=0),
    scene_camera=camera)

fig.show()
# NOTE(review): this writes `contentlidar_plot.html` into the filesystem root;
# presumably "/content/lidar_plot.html" was intended - confirm before changing.
fig.write_html("/contentlidar_plot.html")

### To visualize images

In [None]:
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import tarfile
import os
import pyarrow.feather as feather
import numpy as np
import plotly.graph_objs as go
from ipywidgets import Button, VBox
from IPython.display import display, clear_output
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
import re

# place holder.
# Path of a single ring_front_center frame used to smoke-test image loading.
image_file=  '/argotrain/sensor/train/00a6ffc1-6ce9-3bc3-a060-6006e9893a1a/sensors/cameras/ring_front_center/315967378149927216.jpg'

# Read the JPEG into an array and render it without axis ticks.
image = mpimg.imread(image_file)
plt.imshow(image)
plt.axis('off')  # Hide axes
plt.show()


# helper function if you want to just visualize images via file path
def plot_lidar_and_image(feather_file:str):
    """Display the ring_front_center image closest in time to a lidar sweep.

    The sweep timestamp is parsed from the ``<digits>.feather`` file name, the
    sibling ``sensors/cameras/ring_front_center`` directory is scanned for
    ``<digits>.jpg`` files, and the one with the smallest absolute timestamp
    difference is shown with matplotlib.

    Args:
        feather_file: Path to a lidar sweep inside a ``sensors/lidar`` directory.

    Returns:
        None. Every exception is caught and reported with ``print``.
    """
    try:
        stamp = re.search(r"(\d+)\.feather", feather_file)
        if stamp is None:
            print("Could not extract timestamp from feather file name.")
            return
        sweep_time = int(stamp.group(1))

        # The sweep is read (and thereby validated) but not plotted yet.
        feather.read_feather(feather_file).to_numpy()

        # ... (your existing lidar plotting code)

        camera_dir = os.path.dirname(feather_file).replace("sensors/lidar", "sensors/cameras/ring_front_center")

        # Collect (time difference, file name) for every timestamped JPEG.
        candidates = []
        for name in os.listdir(camera_dir):
            if not name.endswith(".jpg"):
                continue
            found = re.search(r"(\d+)\.jpg", name)
            if found:
                candidates.append((abs(int(found.group(1)) - sweep_time), name))

        if not candidates:
            print("No corresponding image file found.")
            return

        # min() keeps the first entry on ties, matching a strict "<" scan.
        _, best_name = min(candidates, key=lambda pair: pair[0])
        nearest_path = os.path.join(camera_dir, best_name)

        print(nearest_path)
        pixels = mpimg.imread(nearest_path)
        plt.figure(figsize=(8, 6))  # Adjust figure size as needed
        plt.imshow(pixels)
        plt.axis('off')  # Hide axes
        plt.title(f"Nearest Image to {feather_file}")  # Display file name
        plt.show()

    except Exception as e:
        print(f"An error occurred: {e}")
# Example usage (replace with your actual file path):
# Runs as soon as the cell executes and shows the front-centre frame nearest to this sweep.
plot_lidar_and_image('/argotrain/sensor/train/0749e9e0-ca52-3546-b324-d704138b11b5/sensors/lidar/315972769159647000.feather')

## Our Data loaders

In [None]:

import os
import torch
from torch.utils.data import Dataset, DataLoader
import pyarrow.feather as feather
import numpy as np
import re
from PIL import Image

class ArgoverseDataset_Images_To_Lidar(Dataset):
    """Pairs each lidar sweep with the nearest-in-time image of every camera.

    Expected layout below ``root_dir`` (one directory per log)::

        <log>/sensors/lidar/<timestamp>.feather
        <log>/sensors/cameras/<camera>/<timestamp>.jpg

    ``__getitem__`` returns ``(lidar, images)`` where ``lidar`` is an ``(N, 3)``
    float32 tensor of x/y/z scaled by the largest absolute coordinate and
    ``images`` is a ``(num_cameras, 1, H, W)`` tensor holding the first channel
    of each camera image. It returns ``None`` when a sample cannot be built, so
    the consumer (e.g. a ``collate_fn``) must drop ``None`` entries.

    Args:
        root_dir: Directory that is walked recursively for lidar sweeps.
        transform: Optional callable applied to each PIL image. When it is
            missing or does not return a tensor, the image is converted to a
            CHW float tensor in [0, 1].
    """

    # Logical camera name -> directory name under ``sensors/cameras``.
    CAMERA_TYPES = {
        "front": "ring_front_center",
        "left": "ring_front_left",
        "right": "ring_front_right",
        "rear_left": "ring_rear_left",
        "rear_right": "ring_rear_right",
        "side_left": "ring_side_left",
        "side_right": "ring_side_right",
        "stereo_front_left": "stereo_front_left",
        "stereo_front_right": "stereo_front_right",
        # Add more cameras as needed
    }

    # Placeholder shape used only when *no* camera image of a sample could be
    # loaded; otherwise the placeholder copies the shape of a loaded image.
    PLACEHOLDER_SHAPE = (1, 224, 224)

    def __init__(self, root_dir:str, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.lidar_files = []
        for subdir, _, files in os.walk(self.root_dir):
            for file in files:
                # Only "<digits>.feather" carries a timestamp; other feather
                # files could never produce a sample and would just yield None.
                if re.fullmatch(r"\d+\.feather", file):
                    self.lidar_files.append(os.path.join(subdir, file))
        # os.walk order is filesystem dependent; sort for reproducible indices.
        self.lidar_files.sort()

    def __len__(self):
        return len(self.lidar_files)

    @staticmethod
    def _find_closest_image(image_dir, lidar_timestamp):
        """Return the path of the ``<digits>.jpg`` in ``image_dir`` nearest to ``lidar_timestamp``, or None."""
        if not os.path.isdir(image_dir):
            return None
        min_diff = float('inf')
        closest_image = None
        for image_file in os.listdir(image_dir):
            if not image_file.endswith(".jpg"):
                continue
            image_match = re.search(r"(\d+)\.jpg", image_file)
            if image_match:
                diff = abs(int(image_match.group(1)) - lidar_timestamp)
                if diff < min_diff:
                    min_diff = diff
                    closest_image = os.path.join(image_dir, image_file)
        return closest_image

    def __getitem__(self, idx:int):
        lidar_file = self.lidar_files[idx]

        try:
            # Extract timestamp from feather file name
            match = re.search(r"(\d+)\.feather", lidar_file)
            if not match:
                print(f"Could not extract timestamp from {lidar_file}. Skipping.")
                return None  # Skip this item if timestamp extraction fails

            lidar_timestamp = int(match.group(1))

            lidar_data = feather.read_feather(lidar_file).to_numpy()

            # Keep x, y, z and scale into [-1, 1]; skip the division for an
            # empty or all-zero cloud instead of dividing by zero.
            lidar_data = lidar_data[:, 0:3]
            scale = np.max(np.abs(lidar_data)) if lidar_data.size else 0
            if scale > 0:
                lidar_data = lidar_data / scale
            lidar_data = torch.tensor(lidar_data, dtype=torch.float32)

            # Load surrounding images; a missing camera is recorded as None
            # and replaced by a shape-compatible placeholder afterwards.
            images = []
            base_dir = os.path.dirname(lidar_file)
            for cam_name, cam_type in self.CAMERA_TYPES.items():
                image_dir = base_dir.replace("sensors/lidar", f"sensors/cameras/{cam_type}")
                closest_image = self._find_closest_image(image_dir, lidar_timestamp)
                if closest_image is None:
                    print(f"Missing image for camera {cam_name} near {lidar_file}")
                    images.append(None)
                    continue

                image = Image.open(closest_image).convert("RGB")
                if self.transform:
                    image = self.transform(image)
                if not torch.is_tensor(image):
                    # HWC uint8 -> CHW float so the channel slice below is valid.
                    image = torch.from_numpy(np.asarray(image).copy()).permute(2, 0, 1).float() / 255.0
                # Keep only the first channel, giving (1, H, W).
                images.append(image[0:1, :, :])

            # The placeholder must match the loaded images or torch.stack fails
            # (the loaded images are 1-channel, not 3-channel).
            reference = next((img for img in images if img is not None), None)
            if reference is not None:
                placeholder = torch.zeros_like(reference)
            else:
                placeholder = torch.zeros(self.PLACEHOLDER_SHAPE)
            images = [placeholder if img is None else img for img in images]

            # Return lidar data and images as tensors
            return lidar_data, torch.stack(images)

        except Exception as e:
            # Deliberate best effort: report and let the caller drop the sample.
            print(f"Error processing {lidar_file}: {e}")
            return None
# Assuming we've defined our transformations (e.g., using torchvision.transforms) for y labels and vice versa
from torchvision import transforms

# Per-image preprocessing handed to the dataset: resize, convert to a CHW
# tensor, then normalise each RGB channel.
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize images
    transforms.ToTensor(),           # Convert to tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) #Normalize
])


# The dataset class defined in this notebook is ArgoverseDataset_Images_To_Lidar;
# the previous name `ArgoverseDataset` does not exist and raised NameError.
dataset = ArgoverseDataset_Images_To_Lidar('/argotrain/sensor/train', transform=transform)
# Samples are kept as a list of (lidar, images) tuples; failed samples (None) are dropped.
dataloader = DataLoader(dataset, batch_size=4, shuffle=True, collate_fn=lambda x: [i for i in x if i is not None])

for batch in dataloader:
    if batch: # Check if batch is not empty
        # Access lidar_data and images from the inner tuples
        lidar_data, images = batch[0][0], batch[0][1]

        print("Lidar data shape:", lidar_data.shape)
        print("Images shape:", images.shape)

    break # Break after our first batch for demonstration

### For our lidar generator via images

In [None]:
import torch
import torch.nn as nn

class LidarGenerator(nn.Module):
    """Predicts a fixed-size lidar point cloud from a set of camera images.

    Every image goes through the same convolutional encoder (shared weights),
    the per-image feature vectors are averaged, and one linear layer maps the
    fused vector to ``num_points * 3`` coordinates.

    Args:
        num_images: Number of camera views per sample (informational; the
            forward pass adapts to the number of views it receives).
        image_channels: Channels of each input image.
        num_points: Number of generated points; the output has
            ``num_points * 3`` values (x, y, z per point, including padding).
        image_size: Height/width of the square input images, used to size the
            linear layer.

    NOTE(review): with the defaults the linear layer holds
    256*28*28 * 31500 (about 6.3e9) weights; confirm the target hardware can
    hold it or shrink ``image_size`` / ``num_points``.
    """

    def __init__(self, num_images:int=9, image_channels:int=1, num_points:int=10500, image_size:int=224):
        super(LidarGenerator, self).__init__()
        self.num_images = num_images
        self.image_channels = image_channels
        self.num_points = num_points

        # Image Conv is our siamese network
        self.image_conv = nn.Sequential(
            nn.Conv2d(image_channels, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),

            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.SELU(),

            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Flatten()
        )
        #self.attention= nn.Transformer(d_model=512, nhead=4,6,) FIXME

        # Spatial size after the three stride-2 convolutions (kernel 3, padding 1):
        # 224 -> 112 -> 56 -> 28 for the default image size.
        side = image_size
        for _ in range(3):
            side = (side + 2 * 1 - 3) // 2 + 1
        self.feature_dim = 256 * side * side

        # Lidar generation layers (nn.Linear, capital L).
        self.lidar_fc = nn.Linear(self.feature_dim, num_points * 3)  # remember our X,Y,Z plus padding

    def forward(self, images):
        """Generate lidar coordinates.

        Args:
            images: ``(batch, views, C, H, W)``, or ``(views, C, H, W)`` for a
                single sample.

        Returns:
            Tensor of shape ``(batch, num_points * 3)`` (batch is 1 for a
            single-sample input).
        """
        if images.dim() == 4:
            images = images.unsqueeze(0)
        batch_size, num_views = images.shape[0], images.shape[1]

        # Fold the views into the batch axis so BatchNorm2d receives 4-D input
        # and all views share the same encoder weights.
        flat_images = images.reshape(batch_size * num_views, *images.shape[2:])
        embeddings = self.image_conv(flat_images).reshape(batch_size, num_views, -1)

        # The linear layer takes a single image's feature size, so the views
        # are fused by averaging rather than concatenation.
        combined = embeddings.mean(dim=1)

        # Generate lidar points
        lidar_points = self.lidar_fc(combined)
        return lidar_points

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = LidarGenerator().to(device)


#generated_lidar = model(images.to(device))

#print("Generated lidar shape:", generated_lidar.shape) FIXME

criterion = nn.MSELoss()  # THIS LOSS IS SUBJECT TO CHANGE
optimizer = torch.optim.Adam(model.parameters(), lr=0.0002)

# Number of points the generator emits (its last layer outputs 3 values per point).
num_lidar_points = model.lidar_fc.out_features // 3


def _to_fixed_size(points, num_points):
    """Zero-pad or truncate an (N, 3) point tensor to exactly ``num_points`` rows."""
    if points.shape[0] >= num_points:
        return points[:num_points]
    padding = points.new_zeros((num_points - points.shape[0], points.shape[1]))
    return torch.cat([points, padding], dim=0)


# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    # The dataloader's collate_fn yields a list of (lidar, images) tuples with
    # failed samples removed, so a batch is a list, not a pair of tensors.
    for batch in dataloader:
      if not batch:
        continue

      # Sweeps have different point counts; bring them to the generator's
      # output size so they can be stacked and compared with the prediction.
      lidar_data = torch.stack([_to_fixed_size(points, num_lidar_points) for points, _ in batch]).to(device)
      images = torch.stack([views for _, views in batch]).to(device)

      optimizer.zero_grad()
      generated_lidar = model(images)
      loss = criterion(generated_lidar.reshape(len(batch), -1), lidar_data.reshape(len(batch), -1))
      loss.backward()
      # Update the weights
      optimizer.step()

      print(f"Epoch [{epoch + 1}/{num_epochs}], Loss: {loss.item():.4f}")

#### Inspecting our training files

In [None]:
import os
TRAIN_ROOT = '/argotrain/sensor/train/'


def _collect_sensor_files(log_ids, sensor_subdir, label):
    """Return the sorted paths of every file in ``<log>/<sensor_subdir>`` over all logs.

    Prints ``label`` and the file count once per log, in the order of ``log_ids``.
    """
    collected = []
    for log_id in log_ids:
        sensor_dir = TRAIN_ROOT + log_id + '/' + sensor_subdir
        files = os.listdir(sensor_dir)
        print(label, len(files))
        collected.extend(os.path.join(sensor_dir, file) for file in files)
    return sorted(collected)


# Number of lidar sweeps in one example log, then the number of logs.
print(len(os.listdir('/argotrain/sensor/train/00a6ffc1-6ce9-3bc3-a060-6006e9893a1a/sensors/lidar')))

print(len((os.listdir(TRAIN_ROOT))))
list_of_examples=os.listdir(TRAIN_ROOT)

# One globally sorted path list per sensor; later cells index these lists by position.
lidar_file_paths = _collect_sensor_files(list_of_examples, 'sensors/lidar', 'lidar examples count')
ring_front_center_file_paths = _collect_sensor_files(list_of_examples, 'sensors/cameras/ring_front_center', 'camera examples count')
ring_rear_right_file_paths = _collect_sensor_files(list_of_examples, 'sensors/cameras/ring_rear_right', 'camera examples count')
ring_side_right_file_paths = _collect_sensor_files(list_of_examples, 'sensors/cameras/ring_side_right', 'camera examples count')
ring_side_left_file_paths = _collect_sensor_files(list_of_examples, 'sensors/cameras/ring_side_left', 'camera examples count')




## Other helper functions 

In [None]:
def get_max_avg_points(lidar_file_paths:list):
    """Return ``(max_points, mean_points)`` over the given lidar files.

    Every file is loaded with ``get_file_data(path, islidar=True)`` and its row
    count is taken as the number of points.

    Args:
        lidar_file_paths: Iterable of paths to lidar sweeps.

    Returns:
        Tuple of the largest point count and the average point count;
        ``(0, 0.0)`` when no paths are given (previously a ZeroDivisionError).
    """
    point_counts = [get_file_data(path, islidar=True).shape[0] for path in lidar_file_paths]
    if not point_counts:
        return 0, 0.0
    return max(point_counts), sum(point_counts) / len(point_counts)


def pad_lidar_data(lidar_data:np.ndarray, max_points:int):
    """Append zero rows so the cloud has at least ``max_points`` rows.

    Args:
        lidar_data: 2-D array with one point per row.
        max_points: Target number of rows.

    Returns:
        The padded array, or ``lidar_data`` unchanged when it already has
        ``max_points`` rows or more.
    """
    # The annotation uses np.ndarray: the name `Tensor` is never imported in
    # this notebook, so it raised NameError when the function was defined.
    padding_size = max_points - lidar_data.shape[0]
    if padding_size > 0:
        padding = np.zeros((padding_size, lidar_data.shape[1]))
        lidar_data = np.vstack([lidar_data, padding])
    return lidar_data



def truncate_lidar_data(lidar_data:np.ndarray, max_points:int):
    """Keep at most the first ``max_points`` rows of a 2-D point array.

    Returns ``lidar_data`` unchanged when it has ``max_points`` rows or fewer.
    """
    # The annotation uses np.ndarray: the name `Tensor` is never imported in
    # this notebook, so it raised NameError when the function was defined.
    if lidar_data.shape[0] > max_points:
        lidar_data = lidar_data[:max_points, :]
    return lidar_data



def get_file_data(file_path:str, islidar=False):
    """Load a lidar sweep or a camera image as a numpy array.

    Args:
        file_path: Path to a feather file (lidar) or an image file.
        islidar: When true, read ``file_path`` as a feather table; otherwise
            open it as an image.

    Returns:
        The table converted with ``to_numpy()`` for lidar, or the image pixels.
        Images whose first two dimensions are (2048, 1550) have those two axes
        swapped so they come back as (1550, 2048, channels).
    """
    if not islidar:
        pixels = np.asarray(Image.open(file_path))
        #Since front camera is different
        if (pixels.shape[0], pixels.shape[1]) == (2048, 1550):
            pixels = np.transpose(pixels, (1, 0, 2))
        return pixels
    return feather.read_feather(file_path).to_numpy()


def visualize_image(file_path:str):
    """Open the image at ``file_path`` and show it in a matplotlib figure."""
    pyplot.imshow(Image.open(file_path))
    pyplot.show()


def visualize_lidar(file:str):
    """Read a lidar sweep from a feather file and show it as a 3-D scatter plot.

    Args:
        file: Path to the feather file; it is converted with ``to_numpy()`` and
            handed to ``visualize_lidar_data``.
    """
    visualize_lidar_data(feather.read_feather(file).to_numpy())



def visualize_lidar_data(lidar_data:np.ndarray):
    """Show a point cloud as an interactive Plotly 3-D scatter plot.

    Args:
        lidar_data: 2-D array whose first three columns are x, y, z. A fourth
            column, when present, is used to colour the markers; otherwise all
            markers get the same colour value.

    The current cell output is cleared before the figure is shown.
    """
    # The annotation uses np.ndarray: the name `Tensor` is never imported in
    # this notebook, so it raised NameError when the function was defined.
    x = lidar_data[:, 0]
    y = lidar_data[:, 1]
    z = lidar_data[:, 2]
    intensity = lidar_data[:, 3] if lidar_data.shape[1] > 3 else np.ones_like(x)

    fig = go.Figure(data=[go.Scatter3d(
        x=x, y=y, z=z,
        mode='markers',
        marker=dict(
            size=1,
            color=intensity,  # Color by intensity or z
            colorscale='Viridis',
            opacity=0.8
        )
    )])

    fig.update_layout(scene=dict(xaxis_title='X', yaxis_title='Y', zaxis_title='Z'),
                      width=700, height=500,
                      margin=dict(r=20, l=10, b=10, t=10))

    # wait=True delays the clear until new output is available.
    clear_output(wait=True)
    fig.show()


def flatten_y_train(Y_train:np.ndarray):
    """Flatten every target sample to one row.

    Args:
        Y_train: Array whose first axis is the sample axis.

    Returns:
        ``Y_train`` reshaped to ``(num_samples, -1)``.
    """
    batch_size = Y_train.shape[0]
    return Y_train.reshape(batch_size, -1)

def flatten_x_train(X_train:np.ndarray):
    """Flatten every input sample to one row.

    Args:
        X_train: Array whose first axis is the sample axis.

    Returns:
        ``X_train`` reshaped to ``(num_samples, -1)``.
    """
    # The annotation uses np.ndarray: the name `Tensor` is never imported in
    # this notebook, so it raised NameError when the function was defined.
    batch_size = X_train.shape[0]
    return X_train.reshape(batch_size, -1)
# Runs when the cell executes: loads every lidar file once to get the largest and the mean point count.
max_cloud_ppoints,avg_cloud_ppoints=get_max_avg_points(lidar_file_paths)

## Lidar to images data generator

In [None]:
def generate_training_data():
    """Build lidar -> camera-image samples for four camera views.

    Every 100th position of the sorted camera lists is sampled. The lidar sweep
    at half that position is padded/truncated to ``max_cloud_ppoints`` rows,
    reduced to x, y, z, and combined with a 4-element one-hot flag naming the
    view, giving one ``(max_cloud_ppoints, 7)`` input per view. The target is
    the camera image of that view. Samples are emitted in the order rear_right,
    side_right, side_left, front_center; the first ``train_quota`` samples go to
    the training split and the rest to the test split.

    Reads the module-level path lists and ``max_cloud_ppoints``.

    NOTE(review): lidar and camera files are paired purely by position in
    globally sorted lists (one lidar file per two camera files is assumed);
    confirm this lines up across logs.

    Returns:
        ``(X_train, Y_train, X_test, Y_test)`` as numpy arrays, or as the
        partially converted lists if the conversion raises ValueError.
    """
    X_train = []
    Y_train = []
    X_test = []
    Y_test = []

    # (camera path list, one-hot view flag) in emission order.
    views = (
        (ring_rear_right_file_paths, np.array([0, 0, 0, 1])),
        (ring_side_right_file_paths, np.array([0, 0, 1, 0])),
        (ring_side_left_file_paths, np.array([0, 1, 0, 0])),
        (ring_front_center_file_paths, np.array([1, 0, 0, 0])),
    )

    # reason why I have this is that we can use sampling techniques to better short cut the heavy data usage
    sampled_targets = int((len(lidar_file_paths) * 2) * 0.009)
    # 80% of the emitted samples (4 per sampled target) form the training split.
    train_quota = int((len(lidar_file_paths) * 2) * 0.009 * 0.8 * 4)
    samples_emitted = 0

    for step in range(sampled_targets):
        target_index = step * 100
        # Fetch the camera file paths for the target index
        image_paths = [paths[target_index] for paths, _ in views]

        # Fetch the lidar file (one lidar file for every two camera files)
        lidar_index = target_index // 2
        lidar_df = get_file_data(lidar_file_paths[lidar_index], islidar=True)
        if max_cloud_ppoints > lidar_df.shape[0]:
            lidar_df = pad_lidar_data(lidar_df, max_cloud_ppoints)
        if max_cloud_ppoints < lidar_df.shape[0]:
            lidar_df = truncate_lidar_data(lidar_df, max_cloud_ppoints)
        lidar_df = lidar_df[:, :3]  # Only take the first 3 columns (x, y, z)

        # Fetching camera data for each target view
        view_images = [get_file_data(path) for path in image_paths]

        # Checking and making sure that the camera data has the same shape before appending
        if any(image.shape != view_images[0].shape for image in view_images):
            print(f"Shape mismatch at index {target_index}. Skipping this set.")
            continue

        for (_, flag), image in zip(views, view_images):
            # Create input data by appending LiDAR data with the corresponding flag
            # (multi-task learning: the flag tells the model which view to generate).
            sample = np.hstack([lidar_df, np.tile(flag, (lidar_df.shape[0], 1))])
            # Same strict "<" boundary for every view; the first view used "<="
            # before, which shifted the split by one sample at the boundary.
            if samples_emitted < train_quota:
                X_train.append(sample)
                Y_train.append(image)
            else:
                X_test.append(sample)
                Y_test.append(image)
            samples_emitted += 1

    try:
        X_train = np.asarray(X_train)
        Y_train = np.asarray(Y_train)
        X_test = np.asarray(X_test)
        Y_test = np.asarray(Y_test)
        # Nothing is written to disk here, so the messages no longer claim "saved".
        print(f"Generated X_train and Y_train with shapes {X_train.shape} and {Y_train.shape}")
        print(f"Generated X_test and Y_test with shapes {X_test.shape} and {Y_test.shape}")
    except ValueError as e:
        print(f"Error during conversion to numpy arrays: {e}")

    return X_train, Y_train, X_test, Y_test


## Model spec

In [None]:
from tensorflow.keras.layers import Input, Dense, BatchNormalization, Reshape, Conv2DTranspose, Activation
from tensorflow.keras.models import Model
import matplotlib.pyplot as plt
from tensorflow.keras import backend as K
def custom_flattened_mse_loss(y_true:"tf.Tensor", y_pred:"tf.Tensor"):
    """Mean squared error between per-sample flattened targets and predictions.

    Args:
        y_true: Ground-truth batch; any trailing shape.
        y_pred: Predicted batch with the same number of elements per sample.

    Returns:
        Scalar tensor with the mean of the squared differences.
    """
    # Annotations are strings so they are not evaluated at definition time
    # (the bare name `Tensor` is never imported and raised NameError).
    batch_size = tf.shape(y_pred)[0]
    # Flatten both sides per sample. Flattening only y_true across the whole
    # batch gave shapes (batch*D,) vs (batch, D), which do not line up.
    y_true_flat = tf.reshape(tf.cast(y_true, y_pred.dtype), (batch_size, -1))
    y_pred_flat = tf.reshape(y_pred, (batch_size, -1))

    mse_loss = tf.reduce_mean(tf.square(y_true_flat - y_pred_flat))

    return mse_loss
class GenerateResNetFCC:
    """Small fully connected Keras model mapping flattened lidar input to a flattened image.

    Note: This is a quick dry run on the task of changing Lidar to image.

    Args:
        input_shape: Shape of one input sample, passed to ``Input(shape=...)``.
        output_shape: Number of output units; it is passed directly to the
            final ``Dense`` layer, so it must be an integer (the flattened
            image size), not an image shape.
    """

    def __init__(self, input_shape, output_shape):
        self.input_shape = input_shape
        self.output_shape = output_shape

        # Build and compile the network once at construction time.
        self.model = self.build_model()

    def build_model(self):
        """Create and compile the Dense(64) -> Dense(128) -> Dense(output) network."""
        # The input layer for LiDAR data (x, y, z + flag)
        input_layer = Input(shape=self.input_shape)

        x = Dense(64, activation='relu')(input_layer)
        x = BatchNormalization()(x)
        x = Dense(128, activation='relu')(x)
        x = BatchNormalization()(x)
        # Projection into a larger feature space
        x = Dense(self.output_shape, activation='relu')(x)

        model = Model(inputs=input_layer, outputs=x)

        # Mean absolute error replaces 'accuracy', which is not meaningful for
        # a regression model trained with an MSE loss.
        model.compile(optimizer='adam', loss='mse', metrics=['mae'])
        return model

    def train(self, X_train, Y_train, epochs=50, batch_size=32, validation_split=0.2):
        """Alias of ``fit``, kept for existing callers."""
        return self.fit(X_train, Y_train, epochs=epochs, batch_size=batch_size, validation_split=validation_split)

    def fit(self, X_train, Y_train, epochs=50, batch_size=32, validation_split=0.2)  :
        """Train on LiDAR inputs and target images; returns the Keras History object."""
        history = self.model.fit(X_train, Y_train, epochs=epochs, batch_size=batch_size, validation_split=validation_split)
        return history

    def generate_image(self, lidar_input):
        """Return the model prediction for ``lidar_input``."""
        generated_image = self.model.predict(lidar_input)
        return generated_image

    def save_model(self, path='generate_resnet_fcc_model.h5'):
        """Save the trained model to ``path``."""
        self.model.save(path)

    def load_model(self, path='generate_resnet_fcc_model.h5'):
        """Replace ``self.model`` with the model stored at ``path``."""
        self.model = tf.keras.models.load_model(path)

def plot_training_history(history):
    """Plot training/validation curves from a Keras ``History`` object.

    The left panel always shows ``loss`` and ``val_loss``. The right panel shows
    ``accuracy`` and ``val_accuracy`` and is drawn only when ``accuracy`` was
    recorded.
    """
    curves = history.history

    # (subplot index, train key, val key, train label, val label, title, y label, legend position)
    panels = (
        (1, 'loss', 'val_loss', 'Train Loss', 'Validation Loss', 'Model Loss', 'Loss', 'upper right'),
        (2, 'accuracy', 'val_accuracy', 'Train Accuracy', 'Validation Accuracy', 'Model Accuracy', 'Accuracy', 'upper left'),
    )

    plt.figure(figsize=(10, 5))

    for position, train_key, val_key, train_label, val_label, title, y_label, legend_loc in panels:
        # Only the accuracy panel is optional; the loss panel is unconditional.
        if position == 2 and 'accuracy' not in curves:
            continue
        plt.subplot(1, 2, position)
        plt.plot(curves[train_key], label=train_label)
        plt.plot(curves[val_key], label=val_label)
        plt.title(title)
        plt.ylabel(y_label)
        plt.xlabel('Epoch')
        plt.legend(loc=legend_loc)

    plt.tight_layout()
    plt.show()



## To train

In [None]:
# Build the train/test samples, then flatten every sample to one row so the
# arrays can be fed to a fully connected model.
X_train, Y_train,X_test,Y_test=generate_training_data()
X_train_flattened=flatten_x_train(X_train)
Y_train_flattened=flatten_y_train(Y_train)
Y_test_flattened=flatten_y_train(Y_test)
X_test_flattened=flatten_x_train(X_test)

In [None]:
# Build the Keras model here. Without this, `model` still refers to whatever an
# earlier cell bound to that name (the PyTorch LidarGenerator), which has no
# `.model` attribute.
model = GenerateResNetFCC(
    input_shape=(X_train_flattened.shape[1],),
    output_shape=Y_train_flattened.shape[1],
)

with tf.device('/gpu:0'):
  history = model.model.fit(X_train_flattened, Y_train_flattened, epochs=1000, batch_size=32, validation_data=(X_test_flattened,Y_test_flattened))

plot_training_history(history)

# Preview at most 32 test samples; fewer when the test split is smaller.
predicted_output = model.model.predict(X_test_flattened[0:32])
print(predicted_output.shape)

# Restore the image layout of the targets; -1 lets the sample count follow the
# prediction instead of assuming exactly 32.
predicted_image = predicted_output.reshape((-1,) + tuple(Y_test.shape[1:]))
predicted_image = np.clip(predicted_image, 0, 255).astype(np.uint8)

print(predicted_image[0][0][0])

for i in range(len(predicted_image)):
    plt.imshow(predicted_image[i])
    plt.title("Generated Image from Lidar")
    plt.axis('off')  # To hide axes
    plt.show()
    plt.imshow(Y_test[i])
    plt.title("actual Image from camera")
    plt.axis('off')  # To hide axes
    plt.show()

## To visualize model generated image given lidar cloud points

import matplotlib.pyplot as plt
import numpy as np
# Predict on up to the first 32 training samples.
predicted_output = model.model.predict(X_train_flattened[0:32])
print(predicted_output.shape)
# -1 lets the sample count follow the prediction; a fixed 32 fails when fewer
# samples are available.
predicted_image = predicted_output.reshape(-1, 1550, 2048, 3)

# Clamp to the displayable 8-bit range before converting.
predicted_image = np.clip(predicted_image, 0, 255).astype(np.uint8)
print(predicted_image[0][0][0])
for i in range(len(predicted_image)):
    plt.imshow(predicted_image[i])
    plt.title("Generated Image from Lidar")
    plt.axis('off')  # To hide axes
    plt.show()