In [1]:
from GNN_for_Depth.data.dataset import DepthDataset
from GNN_for_Depth.data.utils import custom_collate
from GNN_for_Depth.model.GNN import DepthGNNModel
# from GNN_forDepth.utils.criterion import SiLogLoss
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

import os
import torch
import torch.nn as nn
import torchvision
from tqdm import tqdm
import torch.optim as optim
# from torchvision.ops import RoIAlign


In [11]:
# Clear cache to free up memory
torch.cuda.empty_cache()

torch.cuda.reset_max_memory_allocated()
torch.cuda.reset_max_memory_cached()

In [12]:
torch.cuda.reset_peak_memory_stats()

dataset.py

In [4]:
import os, random, glob

import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from torch_geometric.data import Data
import h5py, torch, cv2
import numpy as np
from statistics import mode

from torch import nn, optim



class DepthDataset(Dataset):
    
    def __init__(self, data_path = "/home3/fsml62/LLM_and_SGG_for_MDE/dataset/nyu_depth_v2/official_splits", transform=None, ext="jpg", mode='train', threshold=0.06):
        
        self.data_path = data_path
        
        self.filenames = glob.glob(os.path.join(self.data_path, mode, '**', '*.{}'.format(ext)), recursive=True)
        self.pt_path = "/home3/fsml62/LLM_and_SGG_for_MDE/GNN_for_MDE/results/depth_embedding/nyu_depth_v2/official_splits"
        self.depth_map_path = "/home3/fsml62/LLM_and_SGG_for_MDE/GNN_for_MDE/results/depth_map/nyu_depth_v2/official_splits"
        self.sg_path = "/home3/fsml62/LLM_and_SGG_for_MDE/GNN_for_MDE/results/SGG/nyu_depth_v2/official_splits"

        self.transform = transform if transform else transforms.ToTensor()
        self.mode = mode
        self.threshold = threshold
        
        self.cache = {}
        
        



    def __getitem__(self, idx):
        
        
        if idx in self.cache:
            return self.cache[idx]
        

        # image path
        img_path = self.filenames[idx]
        # get the image
        img = Image.open(img_path)

        # get the relative path
        relative_path = os.path.relpath(img_path, self.data_path)
        # get depth embedding path
        depth_emb_path = os.path.join(self.pt_path, '{}.pt'.format(relative_path.split('.')[0]))
        # get depth map path
        depth_path = os.path.join(self.mode, self.depth_map_path, '{}.pt'.format(relative_path.split('.')[0]))

        #get the scene graph path
        scenegraph_path = os.path.join(self.sg_path, '{}.h5'.format(relative_path.split('.')[0]))



        ## Depth Embedding
        depth_emb = self.normalize(torch.load(depth_emb_path))

        ## depth map
        depth_map = self.normalize(torch.load(depth_path))

        ## get the actual depth
        actual_depth_path = img_path.replace("rgb", "sync_depth").replace('.jpg', '.png')
        actual_depth = Image.open(actual_depth_path)

        
        # the resize size
        target_size = (25, 25)


        with h5py.File(scenegraph_path, 'r') as h5_file:
            loaded_output_dict = {key: torch.tensor(np.array(h5_file[key])) for key in h5_file.keys()}


        probas = loaded_output_dict['rel_logits'].softmax(-1)[0, :, :-1]
        probas_sub = loaded_output_dict['sub_logits'].softmax(-1)[0, :, :-1]
        probas_obj = loaded_output_dict['obj_logits'].softmax(-1)[0, :, :-1]
        
        
        keep = torch.logical_and(probas.max(-1).values > self.threshold, 
                                torch.logical_and(probas_sub.max(-1).values > self.threshold, probas_obj.max(-1).values > self.threshold))
        
        
        
      
        
        sub_bboxes_scaled = self.rescale_bboxes(loaded_output_dict['sub_boxes'][0, keep], img.size)
        obj_bboxes_scaled = self.rescale_bboxes(loaded_output_dict['obj_boxes'][0, keep], img.size)
        relations = loaded_output_dict['rel_logits'][0, keep]

        valid_sub_bboxes = self.validate_bounding_boxes(sub_bboxes_scaled, img.size, target_size)
        valid_obj_bboxes = self.validate_bounding_boxes(obj_bboxes_scaled, img.size, target_size)

        # Combine validity of subject and object bounding boxes
        valid_pairs = torch.tensor([vs and vo for vs, vo in zip(valid_sub_bboxes, valid_obj_bboxes)], dtype=torch.bool)

        # Apply the updated keep mask
        sub_bboxes_scaled = sub_bboxes_scaled[valid_pairs]
        obj_bboxes_scaled = obj_bboxes_scaled[valid_pairs]
        relations = relations[valid_pairs]
        
        
        
        sub_idxs, nodes1 = self.assign_index(sub_bboxes_scaled, [], threshold=0.7)
        obj_idxs, nodes2 = self.assign_index(obj_bboxes_scaled, nodes1, threshold=0.7)
        
        all_idxs = sub_idxs + obj_idxs
        bbox_lists = torch.concat((sub_bboxes_scaled, obj_bboxes_scaled), dim=0)
        
        unique_idxs = set()
        filtered_idxs = []
        filtered_bboxes = []

        for idx, bbox in zip(all_idxs, bbox_lists):   
            if idx not in unique_idxs:
                unique_idxs.add(idx)
                filtered_idxs.append(idx)
                filtered_bboxes.append(bbox.tolist())  

        # Sort the filtered indices along with their corresponding bounding boxes
        sorted_wrapped = sorted(zip(filtered_idxs, filtered_bboxes), key=lambda x: x[0])
        sorted_idxs, sorted_bboxes = zip(*sorted_wrapped)

        # Convert them back to torch tensors
        sorted_idxs = torch.tensor(sorted_idxs, dtype=torch.long)
        sorted_bboxes = torch.tensor(sorted_bboxes, dtype=torch.int32)        
        

        
        
        # Apply transform to image and depth
        img = self.transform(img)
        img = self.normalize(img)
        actual_depth = self.transform(actual_depth).float()
        actual_depth = self.normalize(actual_depth)
        
        device = 'cpu'
        
        pooled_images = self.pool_visual_content_and_depth(
            sorted_bboxes=sorted_bboxes,
            embedding=img,
            target_size=target_size).to(device)
        
        pooled_depths = self.pool_visual_content_and_depth(
            sorted_bboxes=sorted_bboxes,
            embedding=depth_emb[0],
            target_size=target_size).to(device)
        
        pooled_act_depths = self.resize_depth_map(
            sorted_bboxes=sorted_bboxes,
            embedding=actual_depth.squeeze(0),
            target_size=target_size).to(device)
        
        
        
        # node embedding
        node_embeddings = torch.cat([pooled_images, pooled_depths], dim=1).to(device)
        
        edge_index = torch.tensor([sub_idxs, obj_idxs]).to(device)
        
#         edge_embeddings = torch.tensor(relations, dtype=torch.float).to(device)
        edge_embeddings = relations.clone().detach().float().to(device)
        
        
        
        gnndata = Data(x=node_embeddings, edge_index=edge_index, edge_attr=edge_embeddings)


        ## return data
        data = {
            'image': img,
            'depth_emb': depth_emb,
            'depth_map': depth_map,
            'depth': actual_depth,
            'pooled_act_depths': pooled_act_depths,
            'bboxs': filtered_bboxes,
            'gnndata': gnndata
        }
        
        self.cache[idx] = data
        
        return data

    def __len__(self):
        return len(self.filenames)
    
    def box_cxcywh_to_xyxy(self, x):

        x_c, y_c, w, h = x.unbind(1)
        b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
    
        return torch.stack(b, dim=1)
    
    def rescale_bboxes(self, out_bbox, size):

        img_w, img_h = size
        b = self.box_cxcywh_to_xyxy(out_bbox)
        b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
        
        b = torch.round(b).int()

        b[:, 0] = torch.clamp(b[:, 0], min=0, max=img_w)
        b[:, 1] = torch.clamp(b[:, 1], min=0, max=img_h)
        b[:, 2] = torch.clamp(b[:, 2], min=0, max=img_w)
        b[:, 3] = torch.clamp(b[:, 3], min=0, max=img_h)

        return b

    def validate_bounding_boxes(self, bboxes, img_size, target_size):
        """Return a list of booleans indicating whether each bounding box is valid."""
        valid_bboxes = []
        img_w, img_h = img_size
        
        t_w, t_h = target_size

        for bbox in bboxes:
            x1, y1, x2, y2 = bbox
            if x1 < 0: x1 = 0
            if y1 < 0: y1 = 0
            if x2 > img_w: x2 = img_w
            if y2 > img_h: y2 = img_h
            
            if (x2 - x1) >= t_w and (y2 - y1) >= t_h:
                valid_bboxes.append(True)
            else:
                valid_bboxes.append(False)

        return valid_bboxes


    
    def pool_visual_content_and_depth(self, sorted_bboxes, embedding, target_size=(50, 50)):

        pool = nn.AdaptiveAvgPool2d(target_size)

        pooled_embs = []

        for bbox in sorted_bboxes:
            cropped_emb = embedding[:, bbox[1]:bbox[3], bbox[0]:bbox[2]]

            if cropped_emb.dim() == 2:
                cropped_emb = cropped_emb.unsqueeze(0)  

            pooled_emb = pool(cropped_emb.unsqueeze(0)).squeeze(0)

            flattened_emb = pooled_emb.view(-1)
            
            pooled_embs.append(flattened_emb)


        pooled_embs = torch.stack(pooled_embs) if pooled_embs else torch.empty(0)

        return pooled_embs
    
    
    def calculate_iou(self, box1, box2):

        x1, y1, x2, y2 = box1
        x3, y3, x4, y4 = box2

        # Calculate intersection coordinates
        x_inter1 = max(x1, x3)
        y_inter1 = max(y1, y3)
        x_inter2 = min(x2, x4)
        y_inter2 = min(y2, y4)

        # Calculate intersection dimensions
        width_inter = max(0, x_inter2 - x_inter1)
        height_inter = max(0, y_inter2 - y_inter1)

        # Calculate intersection area
        area_inter = width_inter * height_inter

        # Calculate areas of the input boxes
        width_box1 = abs(x2 - x1)
        height_box1 = abs(y2 - y1)
        area_box1 = width_box1 * height_box1

        width_box2 = abs(x4 - x3)
        height_box2 = abs(y4 - y3)
        area_box2 = width_box2 * height_box2

        # Calculate union area
        area_union = area_box1 + area_box2 - area_inter

        # Calculate IoU
        if area_union == 0:
            return 0  # avoid division by zero
        iou = area_inter / area_union

        return iou
    
    
    def assign_index(self, bounding_boxes, nodes, threshold=0.5):
        indices = []
        existing_boxes = nodes

        for box in bounding_boxes:
            found_match = False
            for idx, existing_box in enumerate(existing_boxes):
                if self.calculate_iou(box, existing_box) > threshold:
                    indices.append(idx)
                    found_match = True
                    break

            if not found_match:
                existing_boxes.append(box)
                indices.append(len(existing_boxes) - 1)

        return indices, existing_boxes
    
    
    def normalize(self, tensor):
        tensor_min = tensor.min()
        tensor_max = tensor.max()
        normalized_tensor = (tensor - tensor_min) / ((tensor_max - tensor_min) + 1e-8)
        return normalized_tensor
    
    
    
    def resize_depth_map(self, sorted_bboxes, embedding, target_size=(25, 25)):

        pooled_embs = []

        for bbox in sorted_bboxes:
            cropped_emb = embedding[bbox[1]:bbox[3], bbox[0]:bbox[2]]


            pooled_emb = self.downsize_depth_map_mode(cropped_emb, target_size)

            flattened_emb = pooled_emb.view(-1)
            
            pooled_embs.append(flattened_emb)


        pooled_embs = torch.stack(pooled_embs) if pooled_embs else torch.empty(0)

        return pooled_embs
    
    
    def downsize_depth_map_mode(self, depth_map, new_size):
        
        original_height, original_width = depth_map.shape[-2], depth_map.shape[-1]
        new_height, new_width = new_size

        window_height = original_height // new_height
        window_width = original_width // new_width

        downsized_map = torch.zeros((new_height, new_width), dtype=depth_map.dtype, device=depth_map.device)

        for i in range(new_height):
            for j in range(new_width):
                # Define the window boundaries
                start_row = i * window_height
                end_row = start_row + window_height
                start_col = j * window_width
                end_col = start_col + window_width

                # Extract the window
                window = depth_map[start_row:end_row, start_col:end_col]

                # Flatten the window
                flat_window = window.flatten().cpu().numpy()  # Convert to numpy for mode calculation

                # Calculate the mode and handle exceptions
                try:
                    downsized_map[i, j] = mode(flat_window)
                except:
                    downsized_map[i, j] = torch.tensor(np.median(flat_window), dtype=depth_map.dtype)

        return downsized_map
    
    

util.py

In [5]:
import torch
from torch_geometric.data import Batch, Data


def custom_collate(batch):
    
    
    batch = [item for item in batch if item is not None]

    if len(batch) == 0:
        return None
    
    # Handle images, depth maps, and other tensors separately
    images = torch.stack([item['image'] for item in batch])
    depth_embs = torch.stack([item['depth_emb'] for item in batch])
    depth_maps = torch.stack([item['depth_map'] for item in batch])
    depths = torch.stack([item['depth'] for item in batch])
    
#     pooled_act_depths = [torch.tensor(item['pooled_act_depths'], dtype=torch.int) for item in batch]
    pooled_act_depths = [item['pooled_act_depths'].clone().detach().float() for item in batch]




#     bboxs = torch.stack([item['bboxs'] for item in batch])
    bboxs = [torch.tensor(item['bboxs'], dtype=torch.int) for item in batch]

    gnndata_list = [item['gnndata'] for item in batch]

    # Batch the graph data using PyTorch Geometric's Batch
    graph_batch = Batch.from_data_list(gnndata_list)

    return {
        'image': images,
        'depth_emb': depth_embs,
        'depth_map': depth_maps,
        'depth': depths,
        'pooled_act_depths': pooled_act_depths,
        'bboxs': bboxs,
        'gnndata': graph_batch,  # Batched graph data
    }

In [6]:
train_dataset = DepthDataset()
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=0, collate_fn=custom_collate)


batch = next(iter(train_dataloader))

In [7]:
# import networkx as nx
# import matplotlib.pyplot as plt
# from torch_geometric.utils import to_networkx
# data_list = batch['gnndata'].to_data_list()
# # Convert and visualize each graph
# for i, data in enumerate(data_list):
#     G = to_networkx(data, to_undirected=True)  # Convert to networkx graph
#     plt.figure(figsize=(8, 8))
#     plt.title(f"Graph {i+1}")
#     nx.draw(G, with_labels=True, node_size=700, node_color='lightblue')
#     plt.show()


In [10]:
data_list = batch['gnndata'].to_data_list()

In [11]:
data_list[0].x.shape

torch.Size([10, 41875])

In [9]:
batch['pooled_act_depths'][0].shape

torch.Size([10, 625])

# Defining GNN

I am going to do node and edge.

- node: **Zero Padding First, Then Pooling" approach for both visual content (e.g., features extracted from bounding boxes) and depth embeddings.**
- relationship: relation

In [2]:
from torch_geometric.nn import MessagePassing
import torch.nn as nn

class DepthGNNModel(MessagePassing):
    def __init__(self, node_features_size, edge_features_size, hidden_channels, output_size):
        super(DepthGNNModel, self).__init__(aggr='add')

        self.message_mlp = nn.Sequential(
            nn.Linear(2 * node_features_size + edge_features_size, 1024),  # Reduced size
            nn.LeakyReLU(inplace=True),
            nn.Dropout(p=0.3),
            nn.Linear(1024, hidden_channels)
        )

        self.node_mlp = nn.Sequential(
            nn.Linear(node_features_size + hidden_channels, 2048),  # Reduced size
            nn.LeakyReLU(inplace=True),
            nn.Dropout(p=0.3),
            nn.Linear(2048, 1024),
            nn.LeakyReLU(inplace=True),
            nn.Dropout(p=0.3),
            nn.Linear(1024, output_size)  # Output a flattened 25x25 depth map
        )

    def forward(self, x, edge_index, edge_attr):
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_i, x_j, edge_attr):
        message_input = torch.cat([x_i, edge_attr, x_j], dim=-1)
        return self.message_mlp(message_input)

    def update(self, aggr_out, x):
        updated_node_features = torch.cat([x, aggr_out], dim=-1)
        return self.node_mlp(updated_node_features)


In [7]:
# # Instantiate the model and move it to GPU
# device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# gnn_model = DepthGNNModel(node_features_size=82075, edge_features_size=52, hidden_channels=128, output_size=1225).to(device)

# # Define the loss function and optimizer
# criterion = nn.MSELoss()
# optimizer = optim.Adam(gnn_model.parameters(), lr=0.001)


In [8]:
# train_dataset = DepthDataset()
# train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=0, collate_fn=custom_collate)


In [15]:
# # Get a batch of data
batch = next(iter(train_dataloader))


In [16]:
batch['gnndata'].x

tensor([[0.3104, 0.3547, 0.3478,  ..., 0.4681, 0.4504, 0.4573],
        [0.3544, 0.3539, 0.3431,  ..., 0.4329, 0.4199, 0.4193],
        [0.0538, 0.0651, 0.0690,  ..., 0.4432, 0.4260, 0.4278],
        ...,
        [0.4521, 0.4317, 0.4115,  ..., 0.4586, 0.4422, 0.4296],
        [0.3595, 0.3383, 0.3423,  ..., 0.4378, 0.4319, 0.4291],
        [0.1691, 0.1658, 0.1478,  ..., 0.4523, 0.4557, 0.4754]])

In [17]:
# def normalize(tensor):
#     tensor_min = tensor.min()
#     tensor_max = tensor.max()
#     normalized_tensor = (tensor - tensor_min) / (tensor_max - tensor_min)
#     return normalized_tensor



# gnndata = batch['gnndata'].to('cuda')
# pooled_act_depths = batch['pooled_act_depths']

# # Training step for a single batch
# gnn_model.train()
# optimizer.zero_grad()

# # Forward pass
# output = gnn_model(gnndata.x, gnndata.edge_index, gnndata.edge_attr)

# # Compute the loss for each node with the corresponding ground truth in pooled_act_depths
# loss = 0
# for i, node_output in enumerate(output):
# #     node_output_reshaped = node_output.view(-1, 625)
#     ground_truth = pooled_act_depths[0][i].cuda()
    
#     normalized_output = normalize(node_output)
#     normalized_ground_truth = normalize(ground_truth)
    
    
#     loss += criterion(normalized_output, normalized_ground_truth)

# # Average the loss
# loss = loss / len(output)

# # Backward pass
# loss.backward()
# optimizer.step()

# print(f'Loss: {loss.item()}')

In [18]:
# # Move to cuda
# gnndata = batch['gnndata'].to('cuda')
# pooled_act_depths = [depth_map.cuda() for depth_map in batch['pooled_act_depths']]

In [19]:
# output

In [6]:


# output = gnn_model(gnndata.x, gnndata.edge_index, gnndata.edge_attr)

# loss = 0
# for i, node_output in enumerate(output):
# #     node_output_reshaped = node_output.view(-1, output_size)
# #     ground_truth = pooled_act_depths[0][i].view(-1, output_size)

#     normalized_output = (node_output - node_output.min()) / (node_output.max() - node_output.min())
#     print(normalized_output)
#     normalized_ground_truth = (pooled_act_depths[0][i] - pooled_act_depths[0][i].min()) / (pooled_act_depths[0][i].max() - pooled_act_depths[0][i].min())

#     loss += criterion(normalized_output, normalized_ground_truth)
#     print(loss)

# loss = loss / len(output)

# # Backward
# optimizer.zero_grad()
# loss.backward()
# optimizer.step()

# epoch_loss += loss.item()

# # Free up memory
# del gnndata, pooled_act_depths
# torch.cuda.empty_cache()

# # Print the average loss for the epoch
# avg_loss = epoch_loss / len(train_dataloader)
# print(f"Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}")


## Training the whole dataset

### define the loss
criterion.py

In [11]:
# import torch
# import torch.nn as nn

# class SiLogLoss(nn.Module):
#     def __init__(self, lambd=0.5):
#         super().__init__()
#         self.lambd = lambd

#     def forward(self, pred, target):
#         valid_mask = (target > 0).detach()
#         diff_log = torch.log(target[valid_mask]) - torch.log(pred[valid_mask])
#         loss = torch.sqrt(torch.pow(diff_log, 2).mean() -
#                           self.lambd * torch.pow(diff_log.mean(), 2))

#         return loss

### checkpoints

In [3]:
checkpoint_path = './model_weights/current_checkpoint.pth'

def save_checkpoint(epoch, model, optimizer, path):
    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, path)

def load_checkpoint(model, optimizer, path):
    checkpoint = torch.load(path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch']

### Early Stop

In [4]:
import torch
import numpy as np

class EarlyStopping:
    def __init__(self, patience=5, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0
        self.best_loss = np.inf
        self.early_stop = False

    def __call__(self, val_loss):
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.early_stop = True


### Normalise

In [5]:
def normalize(tensor):
    tensor_min = tensor.min()
    tensor_max = tensor.max()
    normalized_tensor = (tensor - tensor_min) / (tensor_max - tensor_min)
    return normalized_tensor

### Loading data

In [6]:
train_dataset = DepthDataset()
train_dataloader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=0, collate_fn=custom_collate)


### initialise model

In [7]:
# Initialize model, optimizer, and other components
node_features_size =  41875 # 82075 # 
edge_features_size = 52
hidden_channels = 728
output_size = 625 #1225 # 

gnn_model = DepthGNNModel(node_features_size, edge_features_size, hidden_channels, output_size).cuda()
# torch.nn.utils.clip_grad_norm_(gnn_model.parameters(), max_norm=1.0)

In [26]:
# def initialize_weights(m):
#     if isinstance(m, nn.Linear):
#         nn.init.kaiming_normal_(m.weight, nonlinearity='leaky_relu')
#         if m.bias is not None:
#             nn.init.constant_(m.bias, 0)

# gnn_model.apply(initialize_weights)


### Testing

In [19]:
import torch.multiprocessing as mp

if __name__ == '__main__':

    mp.set_start_method('spawn', force=True)
    single_batch = next(iter(train_dataloader))

In [20]:
criterion = nn.MSELoss()
optimizer = optim.Adam(gnn_model.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)
# early_stopping = EarlyStopping(patience=5, min_delta=0.001)

In [21]:
gnndata = single_batch['gnndata'].to('cuda')
pooled_act_depths = [depth_map.cuda() for depth_map in single_batch['pooled_act_depths']]

output = gnn_model(gnndata.x, gnndata.edge_index, gnndata.edge_attr)


NameError: name 'torch' is not defined

In [26]:

loss = 0
for i, node_output in enumerate(output):
    node_output_reshaped = node_output.view(-1, output_size)
    ground_truth = pooled_act_depths[0][i].view(-1, output_size)

    normalized_output = (node_output_reshaped - node_output_reshaped.min()) / (node_output_reshaped.max() - node_output_reshaped.min() + 1e-8)
    normalized_ground_truth = (ground_truth - ground_truth.min()) / (ground_truth.max() - ground_truth.min() + 1e-8)

    loss += criterion(normalized_output, normalized_ground_truth)

loss = loss / len(output)

if torch.isnan(loss):
    print("Loss became NaN. Stopping training.")


# Backward
optimizer.zero_grad()
loss.backward()
optimizer.step()

# epoch_loss += loss.item()

# Print the loss for the epoch
# print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")


IndexError: index 19 is out of bounds for dimension 0 with size 19

### Training

In [8]:


# # Load checkpoint if exists
# if os.path.exists(checkpoint_path):
#     start_epoch = load_checkpoint(gnn_model, optimizer, checkpoint_path)
#     print(f"Model loaded from checkpoint, starting from epoch {start_epoch + 1}")
# else:
#     start_epoch = 0

gnn_model.train()

criterion = nn.MSELoss()
optimizer = optim.Adam(gnn_model.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)
early_stopping = EarlyStopping(patience=5, min_delta=0.001)

num_epochs = 50

for epoch in range(num_epochs):

    epoch_loss = 0.0
    for batch in tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        
        # Move to cuda
        gnndata = batch['gnndata'].to('cuda')
        pooled_act_depths = [depth_map.cuda() for depth_map in batch['pooled_act_depths']]

        output = gnn_model(gnndata.x, gnndata.edge_index, gnndata.edge_attr)

        loss = 0
        for i, node_output in enumerate(output):
            node_output_reshaped = node_output.view(-1, output_size)
            ground_truth = pooled_act_depths[0][i].view(-1, output_size)

            normalized_output = (node_output_reshaped - node_output_reshaped.min()) / (node_output_reshaped.max() - node_output_reshaped.min() + 1e-8)
            normalized_ground_truth = (ground_truth - ground_truth.min()) / (ground_truth.max() - ground_truth.min() + 1e-8)

            loss += criterion(normalized_output, normalized_ground_truth)

        loss = loss / len(output)
#         print(loss)
        
#         print(loss)
        if torch.isnan(loss):
            print("Loss became NaN. Stopping training.")
            break

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

        # Free up memory
        del gnndata, pooled_act_depths
        torch.cuda.empty_cache()

    # Print the average loss for the epoch
    avg_loss = epoch_loss / len(train_dataloader)
    print(f"Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}")

    # early stopping
    early_stopping(avg_loss)
    if early_stopping.early_stop:
        print("Early stopping triggered")
        break

    # Step the scheduler
    scheduler.step(avg_loss)

    # Save checkpoint every 8 epochs
    if (epoch + 1) % 8 == 0:
        checkpoint_name = f'./model_weights/checkpoint_epoch_{epoch+1}.pth'
        save_checkpoint(epoch, gnn_model, optimizer, checkpoint_name)
        print(f"Checkpoint saved: {checkpoint_name}")

    save_checkpoint(epoch, gnn_model, optimizer, checkpoint_path)

Epoch 1/50: 100%|██████████| 795/795 [24:28<00:00,  1.85s/it]


Epoch [1/50], Average Loss: 0.1441


Epoch 2/50: 100%|██████████| 795/795 [24:29<00:00,  1.85s/it]


Epoch [2/50], Average Loss: 0.1430


Epoch 3/50: 100%|██████████| 795/795 [24:30<00:00,  1.85s/it]


Epoch [3/50], Average Loss: 0.1423


Epoch 4/50: 100%|██████████| 795/795 [24:24<00:00,  1.84s/it]


Epoch [4/50], Average Loss: 0.1405


Epoch 5/50: 100%|██████████| 795/795 [24:36<00:00,  1.86s/it]


Epoch [5/50], Average Loss: 0.1375


Epoch 6/50: 100%|██████████| 795/795 [24:34<00:00,  1.85s/it]


Epoch [6/50], Average Loss: 0.1352


Epoch 7/50: 100%|██████████| 795/795 [24:37<00:00,  1.86s/it]


Epoch [7/50], Average Loss: 0.1349


Epoch 8/50: 100%|██████████| 795/795 [24:37<00:00,  1.86s/it]


Epoch [8/50], Average Loss: 0.1338
Checkpoint saved: ./model_weights/checkpoint_epoch_8.pth


Epoch 9/50: 100%|██████████| 795/795 [24:36<00:00,  1.86s/it]


Epoch [9/50], Average Loss: 0.1322


Epoch 10/50: 100%|██████████| 795/795 [24:35<00:00,  1.86s/it]


Epoch [10/50], Average Loss: 0.1314


Epoch 11/50: 100%|██████████| 795/795 [24:38<00:00,  1.86s/it]


Epoch [11/50], Average Loss: 0.1304


Epoch 12/50: 100%|██████████| 795/795 [24:38<00:00,  1.86s/it]


Epoch [12/50], Average Loss: 0.1284


Epoch 13/50:  99%|█████████▉| 788/795 [24:19<00:11,  1.64s/it]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

Epoch 14/50: 100%|██████████| 795/795 [24:32<00:00,  1.85s/it]


Epoch [14/50], Average Loss: 0.1246


Epoch 15/50: 100%|██████████| 795/795 [24:35<00:00,  1.86s/it]


Epoch [15/50], Average Loss: 0.1234


Epoch 16/50: 100%|██████████| 795/795 [24:28<00:00,  1.85s/it]


Epoch [16/50], Average Loss: 0.1237
Checkpoint saved: ./model_weights/checkpoint_epoch_16.pth


Epoch 17/50: 100%|██████████| 795/795 [24:26<00:00,  1.84s/it]


Epoch [17/50], Average Loss: 0.1220


Epoch 18/50: 100%|██████████| 795/795 [24:26<00:00,  1.85s/it]


Epoch [18/50], Average Loss: 0.1215


Epoch 19/50: 100%|██████████| 795/795 [24:26<00:00,  1.84s/it]


Epoch [19/50], Average Loss: 0.1203


Epoch 20/50: 100%|██████████| 795/795 [24:25<00:00,  1.84s/it]


Epoch [20/50], Average Loss: 0.1196


Epoch 21/50: 100%|██████████| 795/795 [24:40<00:00,  1.86s/it]


Epoch [21/50], Average Loss: 0.1192


Epoch 22/50:  52%|█████▏    | 416/795 [12:44<09:17,  1.47s/it]IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)

Epoch 31/50: 100%|██████████| 795/795 [24:22<00:00,  1.84s/it]


Epoch [31/50], Average Loss: 0.1140


Epoch 32/50: 100%|██████████| 795/795 [24:20<00:00,  1.84s/it]


Epoch [32/50], Average Loss: 0.1136
Checkpoint saved: ./model_weights/checkpoint_epoch_32.pth


Epoch 33/50: 100%|██████████| 795/795 [24:20<00:00,  1.84s/it]


Epoch [33/50], Average Loss: 0.1132


Epoch 34/50: 100%|██████████| 795/795 [24:22<00:00,  1.84s/it]


Epoch [34/50], Average Loss: 0.1127


Epoch 35/50: 100%|██████████| 795/795 [24:21<00:00,  1.84s/it]

Epoch [35/50], Average Loss: 0.1131
Early stopping triggered





In [21]:
!nvidia-smi

Wed Aug 14 18:41:04 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100 80GB PCIe          On  |   00000000:49:00.0 Off |                   On |
| N/A   36C    P0            107W /  300W |    6603MiB /  81920MiB |     N/A      Default |
|                                         |                        |              Enabled |
+-----------------------------------------+------------------------+----------------------+

+----------------------------------------------

In [15]:
save_checkpoint(10, gnn_model, optimizer, checkpoint_path)

## Evaluation

In [9]:
test_dataset = DepthDataset(mode='test')
test_dataloader = DataLoader(test_dataset, batch_size=1, shuffle=True, num_workers=0, collate_fn=custom_collate)


In [10]:
criterion = nn.MSELoss()

# Load the model checkpoint if available
# checkpoint_path = 'checkpoint.pth'
# if os.path.exists(checkpoint_path):
#     checkpoint = torch.load(checkpoint_path)
#     gnn_model.load_state_dict(checkpoint['model_state_dict'])
#     print("Model loaded from checkpoint")

In [11]:
# Load checkpoint if exists
if os.path.exists(checkpoint_path):
    start_epoch = load_checkpoint(gnn_model, optimizer, checkpoint_path)
    print(f"Model loaded from checkpoint, starting evaluation from epoch {start_epoch + 1}")
else:
    raise FileNotFoundError("Checkpoint not found. Ensure the correct checkpoint path is provided.")

gnn_model.eval()  # Switch to evaluation mode

# criterion = nn.MSELoss()

with torch.no_grad():  # Disable gradient calculation
    total_loss = 0.0
    for batch in tqdm(test_dataloader, desc="Evaluating"):

        # Move to cuda
        gnndata = batch['gnndata'].to('cuda')
        pooled_act_depths = [depth_map.cuda() for depth_map in batch['pooled_act_depths']]

        output = gnn_model(gnndata.x, gnndata.edge_index, gnndata.edge_attr)

        loss = 0
        for i, node_output in enumerate(output):
            node_output_reshaped = node_output.view(-1, output_size)
            ground_truth = pooled_act_depths[0][i].view(-1, output_size)

            normalized_output = (node_output_reshaped - node_output_reshaped.min()) / (node_output_reshaped.max() - node_output_reshaped.min() + 1e-8)
            normalized_ground_truth = (ground_truth - ground_truth.min()) / (ground_truth.max() - ground_truth.min() + 1e-8)

            loss += criterion(normalized_output, normalized_ground_truth)

        loss = loss / len(output)
        
        if torch.isnan(loss):
            print("Loss became NaN during evaluation. Stopping evaluation.")
            break

        total_loss += loss.item()

        # Free up memory
        del gnndata, pooled_act_depths
        torch.cuda.empty_cache()

    # Print the average loss for the evaluation
    avg_loss = total_loss / len(test_dataloader)
    print(f"Average Loss during Evaluation: {avg_loss:.4f}")


Model loaded from checkpoint, starting evaluation from epoch 34


Evaluating: 100%|██████████| 654/654 [27:52<00:00,  2.56s/it]


NameError: name 'eval_dataloader' is not defined

In [12]:
avg_loss = total_loss / len(test_dataloader)
print(f"Average Loss during Evaluation: {avg_loss:.4f}")

Average Loss during Evaluation: 0.1133


In [14]:
# Evaluate the model
evaluate(gnn_model, test_dataloader, criterion)

Evaluating: 100%|██████████| 654/654 [02:51<00:00,  3.82it/s]

Average Loss: 0.1169, Average MAE: 0.3851





(0.11687608749893372, 0.38510026233641015)

In [15]:
evaluate(gnn_model, test_dataloader, criterion)

Evaluating: 100%|██████████| 654/654 [05:22<00:00,  2.03it/s]

Average Loss: 0.0964, Average MAE: 0.3200





(0.096388840765923, 0.319969099732714)

In [16]:
evaluate(gnn_model, train_dataloader, criterion)

Evaluating: 100%|██████████| 795/795 [03:13<00:00,  4.12it/s]

Average Loss: 0.0678, Average MAE: 0.2585





(0.06777471745000133, 0.2584519677754468)

## Testing the output

In [None]:
batch = next(iter(train_dataloader))

In [None]:
with torch.no_grad():

    # Move data to GPU
    sub_imgs = torch.stack([item.cuda() for item in batch['sub_imgs']])
    obj_imgs = torch.stack([item.cuda() for item in batch['obj_imgs']])
    sub_depth_emb = torch.stack([item.cuda() for item in batch['sub_depth_emb']])
    obj_depth_emb = torch.stack([item.cuda() for item in batch['obj_depth_emb']])
    sub_act_depths = torch.stack([item.cuda() for item in batch['sub_act_depths']])
    obj_act_depths = torch.stack([item.cuda() for item in batch['obj_act_depths']])
    edges = torch.stack([item.cuda() for item in batch['relation']])

    node1_features = torch.cat([sub_imgs, sub_depth_emb], dim=-1)
    node2_features = torch.cat([obj_imgs, obj_depth_emb], dim=-1)

    # Forward pass through the GNN model
    depth_map1, depth_map2, updated_edges = model(node1_features, node2_features, edges)

    # Reshape the depth maps to match the ground truth dimensions if necessary
    depth_map1 = depth_map1.view(sub_act_depths.shape)
    depth_map2 = depth_map2.view(obj_act_depths.shape)

    # Normalize depth maps
    normalized_depth_1 = (depth_map1 - depth_map1.min()) / (depth_map1.max() - depth_map1.min())
    normalized_depth_2 = (depth_map2 - depth_map2.min()) / (depth_map2.max() - depth_map2.min())
    normalized_sub_act_depths = (sub_act_depths - sub_act_depths.min()) / (sub_act_depths.max() - sub_act_depths.min())
    normalized_obj_act_depths = (obj_act_depths - obj_act_depths.min()) / (obj_act_depths.max() - obj_act_depths.min())

    # Calculate loss
    loss1 = criterion(normalized_depth_1, normalized_sub_act_depths)
    loss2 = criterion(normalized_depth_2, normalized_obj_act_depths)
    loss = loss1 + loss2

    # Calculate MAE
    mae1 = torch.mean(torch.abs(normalized_depth_1 - normalized_sub_act_depths))
    mae2 = torch.mean(torch.abs(normalized_depth_2 - normalized_obj_act_depths))
    mae = mae1 + mae2

    # Accumulate loss and MAE
    total_loss += loss.item()
    total_mae += mae.item()
