In [1]:
# import torch
# import torchvision
# import numpy as np
# from torchvision.models import resnet18, ResNet18_Weights

# resnet_model = resnet18(weights=ResNet18_Weights.DEFAULT)
# The current behavior is equivalent to passing `weights=ResNet18_Weights.IMAGENET1K_V1`. 
# You can also use `weights=ResNet18_Weights.DEFAULT` to get the most up-to-date weights.
# print(resnet_model)
# # This method remove the layers upto to the convolution feature maps
# modules = list(resnet_model.children())[:-2]
# backbone = torch.nn.Sequential(*modules)
# print(backbone)

## Version 1 of the NaviNet


## Version 2 of NaviNet

## Version 3 of Navinet 

In [11]:
import torch
from torch.nn import Conv2d, Linear,Parameter
from torch.nn import Module
from torchvision.models import efficientnet_b1,EfficientNet_B1_Weights
import h5py
import numpy as np







class RGBNet(Module):
    def __init__(self,ablation_depth=2):
        super().__init__()
        resnet_model = efficientnet_b1(weights=EfficientNet_B1_Weights.IMAGENET1K_V2)
        modules = list(resnet_model.children())[:-ablation_depth]
        self.backbone = torch.nn.Sequential(*modules)

    def forward(self, x):
        x = self.backbone(x)
        return x

class DepthNet(Module):
    def __init__(self):
        super().__init__()
        self.filter = Conv2d(in_channels=1, out_channels=2, kernel_size=3, 
                             stride=1, padding=0, bias=False)

        Gx = torch.tensor([[2.0, 0.0, -2.0], [4.0, 0.0, -4.0], [2.0, 0.0, -2.0]])
        Gy = torch.tensor([[2.0, 4.0, 2.0], [0.0, 0.0, 0.0], [-2.0, -4.0, -2.0]])
        G = torch.cat([Gx.unsqueeze(0), Gy.unsqueeze(0)], 0)
        G = G.unsqueeze(1)
        self.filter.weight = Parameter(G, requires_grad=False)

    def forward(self, img):
        x = self.filter(img)
        x = torch.mul(x, x)
        x = torch.sum(x, dim=1, keepdim=True)
        x = torch.sqrt(x)
        return x

class NaviNet(Module):
    '''
    A deeplearning architecture for local navigation planning
    '''
    def __init__(self,
                 image_dims=(240,320),
                 goal_dims=7):
         super(NaviNet, self).__init__()
         self.depth_net = DepthNet()
         self.rgb_net = RGBNet(ablation_depth=2)
         self.fc_goal_pose = Linear(goal_dims, 128) 

    def forward(self, rgb_image, depth_image):
        rgb_features = self.rgb_net(rgb_image).squeeze()
        depth_features = self.depth_net(depth_image)
        print(f"rgb_features: {rgb_features.shape}")
        print(f"depth_features: {depth_features.shape}")
       
        
        # Concatenate features
        # concatenated_features = torch.cat((rgb_features.squeeze(), depth_features, goal_pose), dim=1)
        
        # return concatenated_features
        return rgb_features, depth_features

In [12]:
file_path="/home/foxy_user/foxy_ws/src/gail_navigation/GailNavigationNetwork/data/traj1.hdf5"
read_file= h5py.File(file_path, "r")
rgb_image=np.expand_dims((read_file['images']['rgb_data'][0]).reshape(3,240,320),axis=0)
rgb_image = torch.tensor(rgb_image,dtype=torch.float32) # Example shape, adjust as needed
depth_image=np.expand_dims((read_file['images']['depth_data'][0]).reshape(240,320),axis=0)
# depth_image=(read_file['images']['depth_data'][0])

depth_image = torch.tensor(depth_image,dtype=torch.float32)  # Example shape, adjust as needed
goal_pose=np.expand_dims(read_file['kris_dynamics']['odom_data']['target_vector'][0],axis=0)
goal_pose = torch.tensor(goal_pose,dtype=torch.float32)  # Example shape
actions=np.expand_dims(read_file['kris_dynamics']['odom_data']['odom_data_filtered'][0],axis=0)
actions= torch.tensor(actions,dtype=torch.float32)  # Example shape
print(f"rgb_image: {rgb_image.shape} and depth_image: {depth_image.shape} \
      goal_pose: {goal_pose.shape} and actions: {actions.shape} \n \n")
model= NaviNet()
rgb_features,depth_features = model(rgb_image, depth_image)




rgb_image: torch.Size([1, 3, 240, 320]) and depth_image: torch.Size([1, 240, 320])       goal_pose: torch.Size([1, 7]) and actions: torch.Size([1, 7]) 
 

rgb_features: torch.Size([1280, 8, 10])
depth_features: torch.Size([2, 1, 318])
