In [1]:
# from GNN_for_Depth.data.dataset import DepthDataset
# from GNN_for_Depth.data.utils import custom_collate
# from GNN_for_Depth.model.GNN import DepthGNNModel
# from GNN_for_Depth.utils.criterion import SiLogLoss
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt

import os
import torch
import torch.nn as nn
import torchvision
from tqdm import tqdm
import torch.optim as optim
# from torchvision.ops import RoIAlign


### Dataset

In [4]:
import os, random, glob

import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
from torch_geometric.data import Data
import h5py, torch, cv2
import numpy as np
from statistics import mode
import torch.nn.functional as F

from torch import nn, optim



class DepthDataset(Dataset):
    
    def __init__(self, data_path = "/home3/fsml62/LLM_and_SGG_for_MDE/dataset/nyu_depth_v2/official_splits", transform=None, ext="jpg", mode='train', threshold=0.06):
        
        self.data_path = data_path
        
        self.filenames = glob.glob(os.path.join(self.data_path, mode, '**', '*.{}'.format(ext)), recursive=True)
        self.pt_path = "/home3/fsml62/LLM_and_SGG_for_MDE/GNN_for_MDE/results/depth_embedding/nyu_depth_v2/official_splits"
        self.depth_map_path = "/home3/fsml62/LLM_and_SGG_for_MDE/GNN_for_MDE/results/depth_map/nyu_depth_v2/official_splits"
#         self.depth_map_path = "/home3/fsml62/LLM_and_SGG_for_MDE/GNN_for_MDE/results/adabins/depth_map/nyu_depth_v2/official_splits"
        self.sg_path = "/home3/fsml62/LLM_and_SGG_for_MDE/GNN_for_MDE/results/SGG/nyu_depth_v2/official_splits"

        self.transform = transform if transform else transforms.ToTensor()
        self.mode = mode
        self.threshold = threshold
        
        self.cache = {}
        
        



    def __getitem__(self, idx):
        
        
        if idx in self.cache:
            return self.cache[idx]
        

        # image path
        img_path = self.filenames[idx]
        # get the image
        img = Image.open(img_path)

        # get the relative path
        relative_path = os.path.relpath(img_path, self.data_path)
        # get depth embedding path
        depth_emb_path = os.path.join(self.pt_path, '{}.pt'.format(relative_path.split('.')[0]))
        # get depth map path
        depth_path = os.path.join(self.mode, self.depth_map_path, '{}.pt'.format(relative_path.split('.')[0]))

        #get the scene graph path
        scenegraph_path = os.path.join(self.sg_path, '{}.h5'.format(relative_path.split('.')[0]))



        ## Depth Embedding
        depth_emb = self.normalize(torch.load(depth_emb_path))

        ## depth map
        depth_map = self.normalize(torch.load(depth_path))

        ## get the actual depth
        actual_depth_path = img_path.replace("rgb", "sync_depth").replace('.jpg', '.png')
        actual_depth = Image.open(actual_depth_path)

        
        # the resize size
        target_size = (25, 25)


        with h5py.File(scenegraph_path, 'r') as h5_file:
            loaded_output_dict = {key: torch.tensor(np.array(h5_file[key])) for key in h5_file.keys()}


        probas = loaded_output_dict['rel_logits'].softmax(-1)[0, :, :-1]
        probas_sub = loaded_output_dict['sub_logits'].softmax(-1)[0, :, :-1]
        probas_obj = loaded_output_dict['obj_logits'].softmax(-1)[0, :, :-1]
        
        
        keep = torch.logical_and(probas.max(-1).values > self.threshold, 
                                torch.logical_and(probas_sub.max(-1).values > self.threshold, probas_obj.max(-1).values > self.threshold))
        
        
        
      
        
        sub_bboxes_scaled = self.rescale_bboxes(loaded_output_dict['sub_boxes'][0, keep], img.size)
        obj_bboxes_scaled = self.rescale_bboxes(loaded_output_dict['obj_boxes'][0, keep], img.size)
        relations = loaded_output_dict['rel_logits'][0, keep]

        valid_sub_bboxes = self.validate_bounding_boxes(sub_bboxes_scaled, img.size, target_size)
        valid_obj_bboxes = self.validate_bounding_boxes(obj_bboxes_scaled, img.size, target_size)

        # Combine validity of subject and object bounding boxes
        valid_pairs = torch.tensor([vs and vo for vs, vo in zip(valid_sub_bboxes, valid_obj_bboxes)], dtype=torch.bool)

        # Apply the updated keep mask
        sub_bboxes_scaled = sub_bboxes_scaled[valid_pairs]
        obj_bboxes_scaled = obj_bboxes_scaled[valid_pairs]
        relations = relations[valid_pairs]
        
        
        
        sub_idxs, nodes1 = self.assign_index(sub_bboxes_scaled, [], threshold=0.7)
        obj_idxs, nodes2 = self.assign_index(obj_bboxes_scaled, nodes1, threshold=0.7)
        
        all_idxs = sub_idxs + obj_idxs
        bbox_lists = torch.concat((sub_bboxes_scaled, obj_bboxes_scaled), dim=0)
        
        unique_idxs = set()
        filtered_idxs = []
        filtered_bboxes = []

        for idx, bbox in zip(all_idxs, bbox_lists):   
            if idx not in unique_idxs:
                unique_idxs.add(idx)
                filtered_idxs.append(idx)
                filtered_bboxes.append(bbox.tolist())  

        # Sort the filtered indices along with their corresponding bounding boxes
        sorted_wrapped = sorted(zip(filtered_idxs, filtered_bboxes), key=lambda x: x[0])
        sorted_idxs, sorted_bboxes = zip(*sorted_wrapped)

        # Convert them back to torch tensors
        sorted_idxs = torch.tensor(sorted_idxs, dtype=torch.long)
        sorted_bboxes = torch.tensor(sorted_bboxes, dtype=torch.int32)        
        

        
        
        # Apply transform to image and depth
        img = self.transform(img)
        img = self.normalize(img)
        actual_depth = self.transform(actual_depth).float()
        actual_depth = self.normalize(actual_depth)
        depth_emb = self.normalize(depth_emb)
        
        device = 'cpu'
        

        pooled_images = self.linear_visual_content_and_depth(
            sorted_bboxes=sorted_bboxes,
            embedding=img,
            target_size=target_size).to(device)
        
        pooled_depths = self.linear_visual_content_and_depth(
            sorted_bboxes=sorted_bboxes,
            embedding=depth_emb[0],
            target_size=target_size).to(device)
        
        pooled_act_depths = self.resize_depth_map(
            sorted_bboxes=sorted_bboxes,
            embedding=actual_depth.squeeze(0),
            target_size=target_size,
            method='nearest').to(device)
        
        
        
        node_embeddings = torch.cat([pooled_images, pooled_depths], dim=1).to(device)
        
        edge_index = torch.tensor([sub_idxs, obj_idxs]).to(device)
        
        edge_embeddings = relations.clone().detach().float().to(device)
        
        
        
        gnndata = Data(x=node_embeddings, edge_index=edge_index, edge_attr=edge_embeddings)

        ## return data
        data = {
            'image': img,
            'depth_emb': depth_emb,
            'depth_map': depth_map,
            'depth': actual_depth,
            'pooled_act_depths': pooled_act_depths,
            'bboxs': filtered_bboxes,
            'gnndata': gnndata
        }
        
        self.cache[idx] = data
        
        return data

    def __len__(self):
        return len(self.filenames)
    
    def box_cxcywh_to_xyxy(self, x):

        x_c, y_c, w, h = x.unbind(1)
        b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
    
        return torch.stack(b, dim=1)
    
    def rescale_bboxes(self, out_bbox, size):

        img_w, img_h = size
        b = self.box_cxcywh_to_xyxy(out_bbox)
        b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
        
        b = torch.round(b).int()

        b[:, 0] = torch.clamp(b[:, 0], min=0, max=img_w)
        b[:, 1] = torch.clamp(b[:, 1], min=0, max=img_h)
        b[:, 2] = torch.clamp(b[:, 2], min=0, max=img_w)
        b[:, 3] = torch.clamp(b[:, 3], min=0, max=img_h)

        return b

    def validate_bounding_boxes(self, bboxes, img_size, target_size):
        """Return a list of booleans indicating whether each bounding box is valid."""
        valid_bboxes = []
        img_w, img_h = img_size
        
        t_w, t_h = target_size

        for bbox in bboxes:
            x1, y1, x2, y2 = bbox
            if x1 < 0: x1 = 0
            if y1 < 0: y1 = 0
            if x2 > img_w: x2 = img_w
            if y2 > img_h: y2 = img_h
            
            if (x2 - x1) >= t_w and (y2 - y1) >= t_h:
                valid_bboxes.append(True)
            else:
                valid_bboxes.append(False)

        return valid_bboxes


    
    def pool_visual_content_and_depth(self, sorted_bboxes, embedding, target_size=(50, 50)):

        pool = nn.AdaptiveAvgPool2d(target_size)

        pooled_embs = []

        for bbox in sorted_bboxes:
            cropped_emb = embedding[:, bbox[1]:bbox[3], bbox[0]:bbox[2]]

            if cropped_emb.dim() == 2:
                cropped_emb = cropped_emb.unsqueeze(0)  

            pooled_emb = pool(cropped_emb.unsqueeze(0)).squeeze(0)

            flattened_emb = pooled_emb.view(-1)
            
            pooled_embs.append(flattened_emb)


        pooled_embs = torch.stack(pooled_embs) if pooled_embs else torch.empty(0)

        return pooled_embs


    def linear_visual_content_and_depth(self, sorted_bboxes, embedding, target_size=(50, 50)):

        pooled_embs = []

        for bbox in sorted_bboxes:

            cropped_emb = embedding[:, bbox[1]:bbox[3], bbox[0]:bbox[2]]

            if cropped_emb.dim() == 2:
                cropped_emb = cropped_emb.unsqueeze(0)

            pooled_emb = F.interpolate(cropped_emb.unsqueeze(0), size=target_size, mode='linear').squeeze(0)

            flattened_emb = pooled_emb.view(-1)

            pooled_embs.append(flattened_emb)

        pooled_embs = torch.stack(pooled_embs) if pooled_embs else torch.empty(0)

        return pooled_embs
    
    
    def calculate_iou(self, box1, box2):

        x1, y1, x2, y2 = box1
        x3, y3, x4, y4 = box2

        x_inter1 = max(x1, x3)
        y_inter1 = max(y1, y3)
        x_inter2 = min(x2, x4)
        y_inter2 = min(y2, y4)

        width_inter = max(0, x_inter2 - x_inter1)
        height_inter = max(0, y_inter2 - y_inter1)

        area_inter = width_inter * height_inter

        width_box1 = abs(x2 - x1)
        height_box1 = abs(y2 - y1)
        area_box1 = width_box1 * height_box1

        width_box2 = abs(x4 - x3)
        height_box2 = abs(y4 - y3)
        area_box2 = width_box2 * height_box2

        area_union = area_box1 + area_box2 - area_inter

        if area_union == 0:
            return 0  # avoid division by zero
        iou = area_inter / area_union

        return iou
    
    
    def assign_index(self, bounding_boxes, nodes, threshold=0.5):
        indices = []
        existing_boxes = nodes

        for box in bounding_boxes:
            found_match = False
            for idx, existing_box in enumerate(existing_boxes):
                if self.calculate_iou(box, existing_box) > threshold:
                    indices.append(idx)
                    found_match = True
                    break

            if not found_match:
                existing_boxes.append(box)
                indices.append(len(existing_boxes) - 1)

        return indices, existing_boxes
    
    
    def normalize(self, tensor):
        tensor_min = tensor.min()
        tensor_max = tensor.max()
        normalized_tensor = (tensor - tensor_min) / ((tensor_max - tensor_min) + 1e-8)
        return normalized_tensor
    
    
    
    def resize_depth_map(self, sorted_bboxes, embedding, target_size=(25, 25), method='nearest'):

        pooled_embs = []

        for bbox in sorted_bboxes:
            cropped_emb = embedding[bbox[1]:bbox[3], bbox[0]:bbox[2]]

            if method == 'nearest':
                pooled_emb = self.downsize_depth_map_bilinear(cropped_emb, target_size)
            else:
                pooled_emb = self.downsize_depth_map_mode(cropped_emb, target_size)

            flattened_emb = pooled_emb.view(-1)
            
            pooled_embs.append(flattened_emb)


        pooled_embs = torch.stack(pooled_embs) if pooled_embs else torch.empty(0)

        return pooled_embs
    
    
    def downsize_depth_map_mode(self, depth_map, new_size):
        
        original_height, original_width = depth_map.shape[-2], depth_map.shape[-1]
        new_height, new_width = new_size

        window_height = original_height // new_height
        window_width = original_width // new_width

        downsized_map = torch.zeros((new_height, new_width), dtype=depth_map.dtype, device=depth_map.device)

        for i in range(new_height):
            for j in range(new_width):

                start_row = i * window_height
                end_row = start_row + window_height
                start_col = j * window_width
                end_col = start_col + window_width

                window = depth_map[start_row:end_row, start_col:end_col]

                flat_window = window.flatten().cpu().numpy()  

                downsized_map[i, j] = mode(flat_window)

        return downsized_map
    
    def downsize_depth_map_bilinear(self, depth_map, new_size):

        if depth_map.dim() == 2:
            depth_map = depth_map.unsqueeze(0).unsqueeze(0)  # Add batch and channel dimensions

        new_height, new_width = new_size

        resized_map = F.interpolate(depth_map, size=(new_height, new_width), mode='blinear')

        return resized_map.squeeze(0).squeeze(0)

util.py

In [5]:
import torch
from torch_geometric.data import Batch, Data


def custom_collate(batch):
    """Collate ``DepthDataset`` samples into one batch dictionary.

    ``None`` samples are discarded first; if nothing is left the function
    returns ``None``. Fixed-size tensors are stacked, per-sample variable-size
    entries stay Python lists, and the graphs are merged with PyTorch
    Geometric's ``Batch``.
    """
    samples = [sample for sample in batch if sample is not None]
    if not samples:
        return None

    def _stack(key):
        return torch.stack([sample[key] for sample in samples])

    # Tensors that are stacked along a new leading batch dimension.
    collated = {key: _stack(key) for key in ('image', 'depth_emb', 'depth_map', 'depth')}

    # The number of rows differs per sample, so these stay lists of tensors.
    collated['pooled_act_depths'] = [
        sample['pooled_act_depths'].clone().detach().float() for sample in samples
    ]
    collated['bboxs'] = [torch.tensor(sample['bboxs'], dtype=torch.int) for sample in samples]

    # Batched graph data
    collated['gnndata'] = Batch.from_data_list([sample['gnndata'] for sample in samples])

    return collated

# Defining GNN

In [6]:
from torch_geometric.nn import MessagePassing
import torch.nn as nn

class DepthGNNModel(MessagePassing):
    """Single message-passing layer that predicts a flattened depth patch per node.

    Messages are computed by an MLP from the two endpoint feature vectors of an
    edge plus the edge attributes, summed per node (``aggr='add'``), and a second
    MLP maps the node's own features together with the aggregated message to
    ``output_size`` values.
    """

    def __init__(self, node_features_size, edge_features_size, hidden_channels, output_size):
        super().__init__(aggr='add')

        message_in = 2 * node_features_size + edge_features_size
        update_in = node_features_size + hidden_channels

        # Edge-wise MLP: [x_i, edge_attr, x_j] -> hidden_channels.
        self.message_mlp = nn.Sequential(
            nn.Linear(message_in, 1024),
            nn.LeakyReLU(inplace=True),
            nn.Dropout(p=0.3),
            nn.Linear(1024, hidden_channels),
        )

        # Node-wise MLP: [x, aggregated messages] -> flattened 25x25 depth map.
        self.node_mlp = nn.Sequential(
            nn.Linear(update_in, 2048),
            nn.LeakyReLU(inplace=True),
            nn.Dropout(p=0.3),
            nn.Linear(2048, 1024),
            nn.LeakyReLU(inplace=True),
            nn.Dropout(p=0.3),
            nn.Linear(1024, output_size),
        )

    def forward(self, x, edge_index, edge_attr):
        """Run one round of message passing and return the per-node predictions."""
        return self.propagate(edge_index, x=x, edge_attr=edge_attr)

    def message(self, x_i, x_j, edge_attr):
        """Message for each edge from both endpoint features and the edge attributes."""
        edge_input = torch.cat([x_i, edge_attr, x_j], dim=-1)
        return self.message_mlp(edge_input)

    def update(self, aggr_out, x):
        """Combine each node's features with its aggregated messages and decode the output."""
        node_input = torch.cat([x, aggr_out], dim=-1)
        return self.node_mlp(node_input)


## Training the whole dataset

### checkpoints

In [7]:
# Rolling checkpoint file for the "linear" run; the training loop overwrites it after every epoch.
checkpoint_path = './model_weights/current_linear_checkpoint.pth'

def save_checkpoint(epoch, model, optimizer, path):
    """Write the epoch number plus model and optimizer state dicts to ``path``.

    The parent directory of ``path`` is created when it does not exist, so the
    first save into a fresh ``./model_weights`` folder does not fail.
    """
    target_dir = os.path.dirname(path)
    if target_dir:  # empty when ``path`` is a bare file name
        os.makedirs(target_dir, exist_ok=True)

    torch.save({
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }, path)

def load_checkpoint(model, optimizer, path, map_location=None):
    """Restore model and optimizer state from a checkpoint written by ``save_checkpoint``.

    Args:
        model: Module whose ``load_state_dict`` receives ``'model_state_dict'``.
        optimizer: Optimizer whose ``load_state_dict`` receives ``'optimizer_state_dict'``.
        path: Checkpoint file.
        map_location: Forwarded to ``torch.load``; pass e.g. ``'cpu'`` to open a
            checkpoint saved from a GPU on a machine without one. ``None`` keeps
            ``torch.load``'s default placement.

    Returns:
        The epoch number stored in the checkpoint.
    """
    checkpoint = torch.load(path, map_location=map_location)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    return checkpoint['epoch']

### Early Stop

In [8]:
import torch
import numpy as np

class EarlyStopping:
    """Flag when a monitored loss has stopped improving.

    Call the instance once per epoch with the loss. ``early_stop`` becomes True
    after ``patience`` consecutive calls in which the loss failed to drop below
    the best value seen so far by more than ``min_delta``.
    """

    def __init__(self, patience=5, min_delta=0):
        self.patience = patience
        self.min_delta = min_delta
        self.counter = 0            # consecutive calls without improvement
        self.best_loss = np.inf
        self.early_stop = False

    def __call__(self, val_loss):
        improved = val_loss < self.best_loss - self.min_delta
        if improved:
            self.best_loss, self.counter = val_loss, 0
            return

        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True


### Loading data

In [10]:
# Dataset with its default arguments (mode='train'); one sample per step, assembled by custom_collate.
train_dataset = DepthDataset()
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=1,
    shuffle=True,
    num_workers=0,
    collate_fn=custom_collate,
)


### initialise model

In [11]:
# Model hyper-parameters and model construction.
# output_size 625 == 25 * 25, the target_size used by DepthDataset; 41875 == 67 * 625.
# NOTE(review): 67 is presumably image channels + depth-embedding channels - confirm against the saved embeddings.
node_features_size = 41875  # author's alternative: 82075
edge_features_size = 52
hidden_channels = 728
output_size = 625  # author's alternative: 1225

gnn_model = DepthGNNModel(
    node_features_size=node_features_size,
    edge_features_size=edge_features_size,
    hidden_channels=hidden_channels,
    output_size=output_size,
).cuda()
# torch.nn.utils.clip_grad_norm_(gnn_model.parameters(), max_norm=1.0)

### Testing one batch

In [28]:
# import torch.multiprocessing as mp

# if __name__ == '__main__':

#     mp.set_start_method('spawn', force=True)
#     single_batch = next(iter(train_dataloader))

In [29]:
# criterion = nn.MSELoss()
# optimizer = optim.Adam(gnn_model.parameters(), lr=1e-4)
# scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)
# # early_stopping = EarlyStopping(patience=5, min_delta=0.001)

In [30]:
# gnndata = single_batch['gnndata'].to('cuda')
# pooled_act_depths = [depth_map.cuda() for depth_map in single_batch['pooled_act_depths']]

# output = gnn_model(gnndata.x, gnndata.edge_index, gnndata.edge_attr)


In [31]:

# loss = 0
# for i, node_output in enumerate(output):
#     node_output_reshaped = node_output.view(-1, output_size)
#     ground_truth = pooled_act_depths[0][i].view(-1, output_size)

#     normalized_output = (node_output_reshaped - node_output_reshaped.min()) / (node_output_reshaped.max() - node_output_reshaped.min() + 1e-8)
#     normalized_ground_truth = (ground_truth - ground_truth.min()) / (ground_truth.max() - ground_truth.min() + 1e-8)

#     loss += criterion(normalized_output, normalized_ground_truth)

# loss = loss / len(output)

# if torch.isnan(loss):
#     print("Loss became NaN. Stopping training.")


# # Backward
# optimizer.zero_grad()
# loss.backward()
# optimizer.step()

# # epoch_loss += loss.item()

# # Print the loss for the epoch
# # print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.4f}")


In [32]:
# loss

tensor(0.1536, device='cuda:0', grad_fn=<DivBackward0>)

### Training

In [12]:


# # Load checkpoint if exists
# if os.path.exists(checkpoint_path):
#     start_epoch = load_checkpoint(gnn_model, optimizer, checkpoint_path)
#     print(f"Model loaded from checkpoint, starting from epoch {start_epoch + 1}")
# else:
#     start_epoch = 0

# Training run ("linear" checkpoints).
# Loss: MSE between per-node min-max normalised predictions and targets.
# The scheduler and early stopping are driven by the average training loss of each epoch.
gnn_model.train()

criterion = nn.MSELoss()
optimizer = optim.Adam(gnn_model.parameters(), lr=1e-4)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)
early_stopping = EarlyStopping(patience=5, min_delta=0.001)

num_epochs = 50
checkpoint_every = 8

# torch.save does not create missing directories.
os.makedirs(os.path.dirname(checkpoint_path) or '.', exist_ok=True)

for epoch in range(num_epochs):

    epoch_loss = 0.0
    num_batches = 0
    nan_detected = False

    for batch in tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):

        # custom_collate returns None when every sample of the batch was dropped.
        if batch is None:
            continue

        # Move to cuda
        gnndata = batch['gnndata'].to('cuda')
        # One (num_nodes, output_size) target block per graph; concatenated they
        # line up row by row with the node outputs of the batched graph.
        targets = torch.cat(batch['pooled_act_depths'], dim=0).to('cuda').view(-1, output_size)

        output = gnn_model(gnndata.x, gnndata.edge_index, gnndata.edge_attr).view(-1, output_size)
        if output.shape[0] == 0:
            continue  # graph without nodes: nothing to learn from

        # Min-max normalise every node (row) separately, predictions and targets alike.
        out_min = output.min(dim=1, keepdim=True).values
        out_max = output.max(dim=1, keepdim=True).values
        normalized_output = (output - out_min) / (out_max - out_min + 1e-8)

        tgt_min = targets.min(dim=1, keepdim=True).values
        tgt_max = targets.max(dim=1, keepdim=True).values
        normalized_ground_truth = (targets - tgt_min) / (tgt_max - tgt_min + 1e-8)

        # All rows have output_size elements, so the mean over every element
        # equals the average of the per-node MSE values.
        loss = criterion(normalized_output, normalized_ground_truth)

        if torch.isnan(loss):
            print("Loss became NaN. Stopping training.")
            nan_detected = True
            break

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1

    # Leave the epoch loop as well; a bare ``break`` above only ends the batch loop.
    if nan_detected:
        break

    # Average over the batches that actually contributed a loss.
    avg_loss = epoch_loss / max(num_batches, 1)
    print(f"Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}")

    early_stopping(avg_loss)
    scheduler.step(avg_loss)

    # Numbered checkpoint every ``checkpoint_every`` epochs
    if (epoch + 1) % checkpoint_every == 0:
        checkpoint_name = f'./model_weights/checkpoint_linear_epoch_{epoch+1}.pth'
        save_checkpoint(epoch, gnn_model, optimizer, checkpoint_name)
        print(f"Checkpoint saved: {checkpoint_name}")

    # Rolling checkpoint, written before the early-stop check so the final weights are kept too.
    save_checkpoint(epoch, gnn_model, optimizer, checkpoint_path)

    if early_stopping.early_stop:
        print("Early stopping triggered")
        break

Epoch 1/50: 100%|██████████| 795/795 [11:20<00:00,  1.17it/s]


Epoch [1/50], Average Loss: 0.1385


Epoch 2/50: 100%|██████████| 795/795 [04:03<00:00,  3.26it/s]


Epoch [2/50], Average Loss: 0.1376


Epoch 3/50: 100%|██████████| 795/795 [04:02<00:00,  3.27it/s]


Epoch [3/50], Average Loss: 0.1368


Epoch 4/50: 100%|██████████| 795/795 [03:59<00:00,  3.32it/s]


Epoch [4/50], Average Loss: 0.1349


Epoch 5/50: 100%|██████████| 795/795 [03:58<00:00,  3.33it/s]


Epoch [5/50], Average Loss: 0.1318


Epoch 6/50: 100%|██████████| 795/795 [03:53<00:00,  3.40it/s]


Epoch [6/50], Average Loss: 0.1285


Epoch 7/50: 100%|██████████| 795/795 [03:55<00:00,  3.38it/s]


Epoch [7/50], Average Loss: 0.1258


Epoch 8/50: 100%|██████████| 795/795 [03:54<00:00,  3.39it/s]


Epoch [8/50], Average Loss: 0.1243
Checkpoint saved: ./model_weights/checkpoint_linear_epoch_8.pth


Epoch 9/50: 100%|██████████| 795/795 [03:54<00:00,  3.39it/s]


Epoch [9/50], Average Loss: 0.1234


Epoch 10/50: 100%|██████████| 795/795 [03:56<00:00,  3.37it/s]


Epoch [10/50], Average Loss: 0.1235


Epoch 11/50: 100%|██████████| 795/795 [03:58<00:00,  3.33it/s]


Epoch [11/50], Average Loss: 0.1220


Epoch 12/50: 100%|██████████| 795/795 [03:56<00:00,  3.36it/s]


Epoch [12/50], Average Loss: 0.1200


Epoch 13/50: 100%|██████████| 795/795 [03:56<00:00,  3.37it/s]


Epoch [13/50], Average Loss: 0.1182


Epoch 14/50: 100%|██████████| 795/795 [03:56<00:00,  3.35it/s]


Epoch [14/50], Average Loss: 0.1165


Epoch 15/50: 100%|██████████| 795/795 [03:57<00:00,  3.35it/s]


Epoch [15/50], Average Loss: 0.1159


Epoch 16/50: 100%|██████████| 795/795 [03:59<00:00,  3.32it/s]


Epoch [16/50], Average Loss: 0.1149
Checkpoint saved: ./model_weights/checkpoint_linear_epoch_16.pth


Epoch 17/50: 100%|██████████| 795/795 [03:58<00:00,  3.34it/s]


Epoch [17/50], Average Loss: 0.1139


Epoch 18/50: 100%|██████████| 795/795 [03:58<00:00,  3.33it/s]


Epoch [18/50], Average Loss: 0.1137


Epoch 19/50: 100%|██████████| 795/795 [03:59<00:00,  3.32it/s]


Epoch [19/50], Average Loss: 0.1128


Epoch 20/50: 100%|██████████| 795/795 [03:58<00:00,  3.33it/s]


Epoch [20/50], Average Loss: 0.1128


Epoch 21/50: 100%|██████████| 795/795 [04:00<00:00,  3.31it/s]


Epoch [21/50], Average Loss: 0.1108


Epoch 22/50: 100%|██████████| 795/795 [03:59<00:00,  3.32it/s]


Epoch [22/50], Average Loss: 0.1100


Epoch 23/50: 100%|██████████| 795/795 [04:02<00:00,  3.28it/s]


Epoch [23/50], Average Loss: 0.1092


Epoch 24/50: 100%|██████████| 795/795 [03:59<00:00,  3.32it/s]


Epoch [24/50], Average Loss: 0.1081
Checkpoint saved: ./model_weights/checkpoint_linear_epoch_24.pth


Epoch 25/50: 100%|██████████| 795/795 [03:59<00:00,  3.32it/s]


Epoch [25/50], Average Loss: 0.1076


Epoch 26/50: 100%|██████████| 795/795 [04:00<00:00,  3.31it/s]


Epoch [26/50], Average Loss: 0.1077


Epoch 27/50: 100%|██████████| 795/795 [04:01<00:00,  3.29it/s]


Epoch [27/50], Average Loss: 0.1067


Epoch 28/50: 100%|██████████| 795/795 [03:59<00:00,  3.31it/s]


Epoch [28/50], Average Loss: 0.1062


Epoch 29/50: 100%|██████████| 795/795 [03:59<00:00,  3.32it/s]


Epoch [29/50], Average Loss: 0.1057


Epoch 30/50: 100%|██████████| 795/795 [03:57<00:00,  3.35it/s]


Epoch [30/50], Average Loss: 0.1060


Epoch 31/50: 100%|██████████| 795/795 [04:03<00:00,  3.27it/s]


Epoch [31/50], Average Loss: 0.1048


Epoch 32/50: 100%|██████████| 795/795 [03:59<00:00,  3.32it/s]


Epoch [32/50], Average Loss: 0.1051
Checkpoint saved: ./model_weights/checkpoint_linear_epoch_32.pth


Epoch 33/50: 100%|██████████| 795/795 [03:58<00:00,  3.33it/s]


Epoch [33/50], Average Loss: 0.1047


Epoch 34/50: 100%|██████████| 795/795 [04:00<00:00,  3.31it/s]


Epoch [34/50], Average Loss: 0.1039


Epoch 35/50: 100%|██████████| 795/795 [04:00<00:00,  3.30it/s]


Epoch [35/50], Average Loss: 0.1036


Epoch 36/50: 100%|██████████| 795/795 [03:59<00:00,  3.33it/s]


Epoch [36/50], Average Loss: 0.1038


Epoch 37/50: 100%|██████████| 795/795 [04:00<00:00,  3.30it/s]


Epoch [37/50], Average Loss: 0.1027


Epoch 38/50: 100%|██████████| 795/795 [04:00<00:00,  3.30it/s]


Epoch [38/50], Average Loss: 0.1031


Epoch 39/50: 100%|██████████| 795/795 [04:02<00:00,  3.28it/s]


Epoch [39/50], Average Loss: 0.1024


Epoch 40/50: 100%|██████████| 795/795 [04:01<00:00,  3.30it/s]


Epoch [40/50], Average Loss: 0.1026
Checkpoint saved: ./model_weights/checkpoint_linear_epoch_40.pth


Epoch 41/50: 100%|██████████| 795/795 [03:59<00:00,  3.32it/s]


Epoch [41/50], Average Loss: 0.1026


Epoch 42/50: 100%|██████████| 795/795 [04:01<00:00,  3.30it/s]


Epoch [42/50], Average Loss: 0.1015


Epoch 43/50: 100%|██████████| 795/795 [03:59<00:00,  3.32it/s]


Epoch [43/50], Average Loss: 0.1024


Epoch 44/50: 100%|██████████| 795/795 [04:02<00:00,  3.27it/s]


Epoch [44/50], Average Loss: 0.1012


Epoch 45/50: 100%|██████████| 795/795 [03:59<00:00,  3.32it/s]


Epoch [45/50], Average Loss: 0.1010


Epoch 46/50: 100%|██████████| 795/795 [04:01<00:00,  3.29it/s]


Epoch [46/50], Average Loss: 0.1004


Epoch 47/50: 100%|██████████| 795/795 [04:00<00:00,  3.31it/s]


Epoch [47/50], Average Loss: 0.1005


Epoch 48/50: 100%|██████████| 795/795 [03:59<00:00,  3.32it/s]


Epoch [48/50], Average Loss: 0.1010
Checkpoint saved: ./model_weights/checkpoint_linear_epoch_48.pth


Epoch 49/50: 100%|██████████| 795/795 [03:59<00:00,  3.32it/s]

Epoch [49/50], Average Loss: 0.1009
Early stopping triggered





### Nearest

In [12]:
 checkpoint_path = './model_weights/current_nearest_checkpoint.pth'

# # Load checkpoint if exists
# if os.path.exists(checkpoint_path):
#     start_epoch = load_checkpoint(gnn_model, optimizer, checkpoint_path)
#     print(f"Model loaded from checkpoint, starting from epoch {start_epoch + 1}")
# else:
#     start_epoch = 0

# Training run ("nearest" checkpoints).
# Loss: MSE between per-node min-max normalised predictions and targets.
# The scheduler and early stopping are driven by the average training loss of each epoch.
gnn_model.train()

criterion = nn.MSELoss()
optimizer = optim.Adam(gnn_model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)
early_stopping = EarlyStopping(patience=5, min_delta=0.001)

num_epochs = 15
checkpoint_every = 4

# torch.save does not create missing directories.
os.makedirs(os.path.dirname(checkpoint_path) or '.', exist_ok=True)

for epoch in range(num_epochs):

    epoch_loss = 0.0
    num_batches = 0
    nan_detected = False

    for batch in tqdm(train_dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):

        # custom_collate returns None when every sample of the batch was dropped.
        if batch is None:
            continue

        # Move to cuda
        gnndata = batch['gnndata'].to('cuda')
        # One (num_nodes, output_size) target block per graph; concatenated they
        # line up row by row with the node outputs of the batched graph.
        targets = torch.cat(batch['pooled_act_depths'], dim=0).to('cuda').view(-1, output_size)

        output = gnn_model(gnndata.x, gnndata.edge_index, gnndata.edge_attr).view(-1, output_size)
        if output.shape[0] == 0:
            continue  # graph without nodes: nothing to learn from

        # Min-max normalise every node (row) separately, predictions and targets alike.
        out_min = output.min(dim=1, keepdim=True).values
        out_max = output.max(dim=1, keepdim=True).values
        normalized_output = (output - out_min) / (out_max - out_min + 1e-8)

        tgt_min = targets.min(dim=1, keepdim=True).values
        tgt_max = targets.max(dim=1, keepdim=True).values
        normalized_ground_truth = (targets - tgt_min) / (tgt_max - tgt_min + 1e-8)

        # All rows have output_size elements, so the mean over every element
        # equals the average of the per-node MSE values.
        loss = criterion(normalized_output, normalized_ground_truth)

        if torch.isnan(loss):
            print("Loss became NaN. Stopping training.")
            nan_detected = True
            break

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1

    # Leave the epoch loop as well; a bare ``break`` above only ends the batch loop.
    if nan_detected:
        break

    # Average over the batches that actually contributed a loss.
    avg_loss = epoch_loss / max(num_batches, 1)
    print(f"Epoch [{epoch+1}/{num_epochs}], Average Loss: {avg_loss:.4f}")

    early_stopping(avg_loss)
    scheduler.step(avg_loss)

    # Numbered checkpoint every ``checkpoint_every`` epochs
    if (epoch + 1) % checkpoint_every == 0:
        checkpoint_name = f'./model_weights/checkpoint_nearest_epoch_{epoch+1}.pth'
        save_checkpoint(epoch, gnn_model, optimizer, checkpoint_name)
        print(f"Checkpoint saved: {checkpoint_name}")

    # Rolling checkpoint, written before the early-stop check so the final weights are kept too.
    save_checkpoint(epoch, gnn_model, optimizer, checkpoint_path)

    if early_stopping.early_stop:
        print("Early stopping triggered")
        break

Epoch 1/15: 100%|██████████| 795/795 [11:48<00:00,  1.12it/s]


Epoch [1/15], Average Loss: 0.1451


Epoch 2/15: 100%|██████████| 795/795 [06:01<00:00,  2.20it/s]


Epoch [2/15], Average Loss: 0.1440


Epoch 3/15: 100%|██████████| 795/795 [06:03<00:00,  2.19it/s]


Epoch [3/15], Average Loss: 0.1437


Epoch 4/15: 100%|██████████| 795/795 [05:58<00:00,  2.22it/s]


Epoch [4/15], Average Loss: 0.1435
Checkpoint saved: ./model_weights/checkpoint_nearest_epoch_4.pth


Epoch 5/15: 100%|██████████| 795/795 [05:50<00:00,  2.27it/s]


Epoch [5/15], Average Loss: 0.1434


Epoch 6/15: 100%|██████████| 795/795 [05:59<00:00,  2.21it/s]


Epoch [6/15], Average Loss: 0.1430


Epoch 7/15: 100%|██████████| 795/795 [05:54<00:00,  2.24it/s]


Epoch [7/15], Average Loss: 0.1429


Epoch 8/15: 100%|██████████| 795/795 [06:00<00:00,  2.21it/s]


Epoch [8/15], Average Loss: 0.1411
Checkpoint saved: ./model_weights/checkpoint_nearest_epoch_8.pth


Epoch 9/15: 100%|██████████| 795/795 [06:37<00:00,  2.00it/s]


Epoch [9/15], Average Loss: 0.1399


Epoch 10/15: 100%|██████████| 795/795 [06:30<00:00,  2.04it/s]


Epoch [10/15], Average Loss: 0.1388


Epoch 11/15: 100%|██████████| 795/795 [06:29<00:00,  2.04it/s]


Epoch [11/15], Average Loss: 0.1388


Epoch 12/15: 100%|██████████| 795/795 [06:15<00:00,  2.12it/s]


Epoch [12/15], Average Loss: 0.1366
Checkpoint saved: ./model_weights/checkpoint_nearest_epoch_12.pth


Epoch 13/15: 100%|██████████| 795/795 [06:28<00:00,  2.05it/s]


Epoch [13/15], Average Loss: 0.1346


Epoch 14/15: 100%|██████████| 795/795 [06:52<00:00,  1.93it/s]


Epoch [14/15], Average Loss: 0.1323


Epoch 15/15: 100%|██████████| 795/795 [06:04<00:00,  2.18it/s]


Epoch [15/15], Average Loss: 0.1316


In [37]:
# Display the loss value left in the kernel by the last executed training step.
loss

tensor(0.1220, device='cuda:0', grad_fn=<DivBackward0>)

In [21]:
!nvidia-smi

Wed Aug 14 18:41:04 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100 80GB PCIe          On  |   00000000:49:00.0 Off |                   On |
| N/A   36C    P0            107W /  300W |    6603MiB /  81920MiB |     N/A      Default |
|                                         |                        |              Enabled |
+-----------------------------------------+------------------------+----------------------+

+----------------------------------------------

In [15]:
# Manually persist the current model and optimizer state, recorded as epoch 10.
save_checkpoint(epoch=10, model=gnn_model, optimizer=optimizer, path=checkpoint_path)