In [15]:
import os
import random
import shutil
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
import torch.nn.functional as F  # Import for F.mse_loss
from torchvision import models
import cv2
import yaml
from tqdm import tqdm

In [16]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

class PoseNet6D_MLP_ConcatFusion(nn.Module):
    def __init__(self, pretrained=True, compress_rgb=True):
        super(PoseNet6D_MLP_ConcatFusion, self).__init__()

        # RGB  (ResNet50)
        resnet_rgb = models.resnet50(
            weights=models.ResNet50_Weights.IMAGENET1K_V2 if pretrained else None
        )
        self.rgb_backbone = nn.Sequential(*list(resnet_rgb.children())[:-1])  # (B, 2048, 1, 1)

        # convert RGB features to 512-dim
        self.compress_rgb = compress_rgb
        if compress_rgb:
            self.rgb_compress = nn.Linear(2048, 512)

        # Depth  (ResNet18, 1-channel input)
        resnet_depth = models.resnet18(weights=None)
        resnet_depth.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.depth_backbone = nn.Sequential(*list(resnet_depth.children())[:-1])  # (B, 512, 1, 1)

        # 2 layer MLP  for weight
        self.gate_mlp = nn.Sequential(
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Linear(256, 2),
            nn.Sigmoid()
        )

        # Pose regression layers (input = 1024  concat)
        self.fc_rot = nn.Linear(1024, 4)
        self.fc_trans = nn.Linear(1024, 3)

    def forward(self, rgb_img, depth_img):
        # RGB feature extraction
        rgb_feat = self.rgb_backbone(rgb_img).squeeze(-1).squeeze(-1)  # (B, 2048)
        if self.compress_rgb:
            rgb_feat = self.rgb_compress(rgb_feat)  # (B, 512)

        # Depth feature extraction
        depth_feat = self.depth_backbone(depth_img).squeeze(-1).squeeze(-1)  # (B, 512)

        # cat
        concat_feat = torch.cat([rgb_feat, depth_feat], dim=1)  # (B, 1024)
        gates = self.gate_mlp(concat_feat)                      # (B, 2)
        rgb_gate = gates[:, 0].unsqueeze(1)
        depth_gate = gates[:, 1].unsqueeze(1)

        # Apply weight
        gated_rgb_feat = rgb_feat * rgb_gate
        gated_depth_feat = depth_feat * depth_gate

        #  concatenation (final feature = 1024)
        fused_feat = torch.cat([gated_rgb_feat, gated_depth_feat], dim=1)  # (B, 1024)

        # Pose regression
        rot = self.fc_rot(fused_feat)
        trans = self.fc_trans(fused_feat)
        rot = F.normalize(rot, dim=1)

        return rot, trans


In [17]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

class PoseNet6D_MLP_ConcatFusion_512(nn.Module):
    def __init__(self, pretrained=True, compress_rgb=True):
        super(PoseNet6D_MLP_ConcatFusion_512, self).__init__()

        # RGB  (ResNet50)
        resnet_rgb = models.resnet50(
            weights=models.ResNet50_Weights.IMAGENET1K_V2 if pretrained else None
        )
        self.rgb_backbone = nn.Sequential(*list(resnet_rgb.children())[:-1])  # (B, 2048, 1, 1)

        # convert RGB features to 512-dim
        self.compress_rgb = compress_rgb
        if compress_rgb:
            self.rgb_compress = nn.Linear(2048, 512)

        # Depth  (ResNet18, 1-channel input)
        resnet_depth = models.resnet18(weights=None)
        resnet_depth.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.depth_backbone = nn.Sequential(*list(resnet_depth.children())[:-1])  # (B, 512, 1, 1)

        # 2 layer MLP for weight
        self.gate_mlp = nn.Sequential(
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Linear(256, 2),
            nn.Sigmoid()
        )

        #  (1024 to 512)
        self.fusion_fc = nn.Linear(1024, 512)

        # Pose regression layers
        self.fc_rot = nn.Linear(512, 4)
        self.fc_trans = nn.Linear(512, 3)

    def forward(self, rgb_img, depth_img):
        # RGB feature extraction
        rgb_feat = self.rgb_backbone(rgb_img).squeeze(-1).squeeze(-1)  # (B, 2048)
        if self.compress_rgb:
            rgb_feat = self.rgb_compress(rgb_feat)  # (B, 512)

        # Depth feature extraction
        depth_feat = self.depth_backbone(depth_img).squeeze(-1).squeeze(-1)  # (B, 512)

        # cat
        concat_feat = torch.cat([rgb_feat, depth_feat], dim=1)  # (B, 1024)
        gates = self.gate_mlp(concat_feat)                      # (B, 2)
        rgb_gate = gates[:, 0].unsqueeze(1)
        depth_gate = gates[:, 1].unsqueeze(1)

        # Apply weight
        gated_rgb_feat = rgb_feat * rgb_gate
        gated_depth_feat = depth_feat * depth_gate

        #  concatenation
        fused_feat = torch.cat([gated_rgb_feat, gated_depth_feat], dim=1)  # (B, 1024)
        fused_feat = self.fusion_fc(fused_feat)  # (B, 512)

        # Pose regression
        rot = self.fc_rot(fused_feat)
        trans = self.fc_trans(fused_feat)
        rot = F.normalize(rot, dim=1)

        return rot, trans


In [18]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

class PoseNet6D_MLP_Fusion(nn.Module):
    def __init__(self, pretrained=True, compress_rgb=True):
        super(PoseNet6D_MLP_Fusion, self).__init__()

        # RGB  (ResNet50)
        resnet_rgb = models.resnet50(
            weights=models.ResNet50_Weights.IMAGENET1K_V2 if pretrained else None
        )
        self.rgb_backbone = nn.Sequential(*list(resnet_rgb.children())[:-1])  # (B, 2048, 1, 1)

        # convert to 512-dim
        self.compress_rgb = compress_rgb
        if compress_rgb:
            self.rgb_compress = nn.Linear(2048, 512)

        # Depth  (ResNet18, 1-channel input)
        resnet_depth = models.resnet18(weights=None)
        resnet_depth.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.depth_backbone = nn.Sequential(*list(resnet_depth.children())[:-1])  # (B, 512, 1, 1)

        # 2layer MLP 
        self.gate_mlp = nn.Sequential(
            nn.Linear(1024, 256),
            nn.ReLU(),
            nn.Linear(256, 2),      #  one for RGB, one for Depth
            nn.Sigmoid()            # Output in [0,1]
        )

        # Pose regression layers 
        self.fc_rot = nn.Linear(512, 4)   # quaternion
        self.fc_trans = nn.Linear(512, 3) # translation

    def forward(self, rgb_img, depth_img):
        # RGB feature extraction
        rgb_feat = self.rgb_backbone(rgb_img).squeeze(-1).squeeze(-1)  # (B, 2048)
        if self.compress_rgb:
            rgb_feat = self.rgb_compress(rgb_feat)  # (B, 512)

        # Depth feature extraction
        depth_feat = self.depth_backbone(depth_img).squeeze(-1).squeeze(-1)  # (B, 512)

        # Concatenate features
        concat_feat = torch.cat([rgb_feat, depth_feat], dim=1)  # (B, 1024)
        gates = self.gate_mlp(concat_feat)                      # (B, 2), values in [0,1]

        # Split gating weights
        rgb_gate = gates[:, 0].unsqueeze(1)     # (B,1)
        depth_gate = gates[:, 1].unsqueeze(1)   # (B,1)

        # element-wise multiply
        gated_rgb_feat = rgb_feat * rgb_gate    # (B, 512)
        gated_depth_feat = depth_feat * depth_gate  # (B, 512)

        #  sum
        fused_feat = gated_rgb_feat + gated_depth_feat  # (B, 512)

        # Pose regression
        rot = self.fc_rot(fused_feat)       # (B, 4)
        trans = self.fc_trans(fused_feat)   # (B, 3)

        # Normalize quaternion to unit length
        rot = F.normalize(rot, dim=1)

        return rot, trans


In [19]:
import torch
from torch.utils.data import Dataset
import pandas as pd
import cv2
import numpy as np
from scipy.spatial.transform import Rotation as R
import torchvision.transforms as T
import os
import yaml
from PIL import Image


class PoseDataset(Dataset):
    def __init__(self,  rgb_dir, depth_dir,linemod_root, augment=False):

        self.rgb_dir = rgb_dir
        self.depth_dir = depth_dir
        self.linemod_root = linemod_root
        self.RGB_img_filenames = sorted([
            f for f in os.listdir(rgb_dir) if f.endswith(".png")
        ])
        self.depth_img_filenames = sorted([
            f for f in os.listdir(depth_dir) if f.endswith(".png")
        ])


        # Preload gt.yml data for all classes
        self.gt_data = {}
        for class_id in range(1, 16):
            class_str = f"{class_id:02d}"
            gt_path = os.path.join(linemod_root, class_str, "gt.yml")
            if os.path.exists(gt_path):
                with open(gt_path, 'r') as f:
                    self.gt_data[class_str] = yaml.safe_load(f)

        
        self.rgb_transform = T.Compose([
            T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
            T.RandomHorizontalFlip(),
            T.Resize((224, 224)),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225])
        ])
        self.depth_transform = T.Compose([
            # T.ToTensor(),
            # T.Resize((224, 224)),
            T.Resize((224, 224)),
            T.ToTensor()  # Converts PIL float32 to FloatTensor [1, H, W]
        ])

    def __len__(self):
        return len(self.RGB_img_filenames)

    def normalize_depth(self,depth):
        depth = np.array(depth).astype(np.float32)
        return (depth - depth.min()) / (depth.max() - depth.min() + 1e-8)

    def __getitem__(self, idx):

        RGB_filename = self.RGB_img_filenames[idx]
        class_id_str, img_id_str = RGB_filename.split("_")
        img_id = int(os.path.splitext(img_id_str)[0])

        # Load RGB image
        rgb_path = os.path.join(self.rgb_dir, RGB_filename)
        rgb_img = Image.open(rgb_path).convert("RGB")
        rgb_tensor = self.rgb_transform(rgb_img)

        # Load Depth image
        # depth_path = os.path.join(self.depth_dir, RGB_filename)
        # depth_img = Image.open(depth_path).convert("I")   # Single-channel

        # # Normalize and convert to PIL for transforms
        # depth_np_norm = self.normalize_depth(depth_img)
        # depth_img_norm = Image.fromarray((depth_np_norm * 255).astype(np.uint8))

        # depth_tensor = self.depth_transform(depth_img_norm)
           # Load Depth image (PIL single channel)
        depth_path = os.path.join(self.depth_dir, RGB_filename)
        depth_img = Image.open(depth_path).convert("I")  # 32-bit integer depth

        # Convert to float numpy, normalize (e.g., scale mm->meters or divide by max)
        depth_np = np.array(depth_img).astype(np.float32)
        depth_np /= 1000.0  # if in mm, convert to meters, adjust as per your data

        # Optional: clip depth values to a max distance (e.g., 2 meters)
        depth_np = np.clip(depth_np, 0, 2.0)

        # Normalize depth to [0,1] by dividing by max depth value (2.0)
        depth_np /= 2.0

        # Convert normalized float depth to PIL image in 'F' mode (32-bit float)
        depth_img_float = Image.fromarray(depth_np).convert('F')

        # Apply depth transforms (Resize -> ToTensor)
        depth_tensor = self.depth_transform(depth_img_float)  # [1, H, W], float32 in [0,1]


        # Load pose from GT file
        pose_list = self.gt_data[class_id_str][img_id]
        pose = next(item for item in pose_list if item['obj_id'] == int(class_id_str))

        R_mat = np.array(pose['cam_R_m2c']).reshape(3, 3).astype(np.float32)
        quat = R.from_matrix(R_mat).as_quat().astype(np.float32)
        quat /= np.linalg.norm(quat)
        t_vec = np.array(pose['cam_t_m2c'], dtype=np.float32) / 1000.0  #  to meters

        return {
            'RGB_image': rgb_tensor,
            'depth_image': depth_tensor,
            'rotation': torch.tensor(quat, dtype=torch.float32),
            'rotation_matrix': torch.tensor(R_mat, dtype=torch.float32),
            'translation': torch.tensor(t_vec, dtype=torch.float32),
            'class_id': int(class_id_str),
            'filename': RGB_filename
        }




In [20]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models

class PoseNet6D(nn.Module):
    def __init__(self, pretrained=True, compress_rgb=True):
        super(PoseNet6D, self).__init__()

        # ==== RGB branch: ResNet50 ====
        resnet_rgb = models.resnet50(
            weights=models.ResNet50_Weights.IMAGENET1K_V2 if pretrained else None
        )
        self.rgb_backbone = nn.Sequential(*list(resnet_rgb.children())[:-1])  # Output: (B, 2048, 1, 1)

        #  compression layer for RGB features
        self.compress_rgb = compress_rgb
        if compress_rgb:
            self.rgb_compress = nn.Linear(2048, 512)  # Match depth feature dim

        # ==== Depth branch: ResNet18 modified for 1-channel input ====
        resnet_depth = models.resnet18(weights=None)  # No pretrained weights for depth
        resnet_depth.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.depth_backbone = nn.Sequential(*list(resnet_depth.children())[:-1])  # Output: (B, 512, 1, 1)

        # ==== Pose regression ====
        self.fc_rot = nn.Linear(1024, 4)   # 512 + 512 -> Quaternion
        self.fc_trans = nn.Linear(1024, 3) # 512 + 512 -> Translation

    def forward(self, rgb_img, depth_img):
        # RGB features
        rgb_feat = self.rgb_backbone(rgb_img).squeeze(-1).squeeze(-1)  # (B, 2048)
        if self.compress_rgb:
            rgb_feat = self.rgb_compress(rgb_feat)  # (B, 512)

        # Depth features
        depth_feat = self.depth_backbone(depth_img).squeeze(-1).squeeze(-1)  # (B, 512)

        # Concatenate features
        feat = torch.cat([rgb_feat, depth_feat], dim=1)  # (B, 1024)

        # Predict pose
        rot = self.fc_rot(feat)         # (B, 4)
        trans = self.fc_trans(feat)     # (B, 3)
        rot = F.normalize(rot, dim=1)   # Normalize quaternion
        return rot, trans


In [21]:
import torch
import torch.nn.functional as F

def mse_pose_loss(pred_q, pred_t, gt_q, gt_t):
    return torch.mean((pred_q - gt_q)**2) + torch.mean((pred_t - gt_t)**2)

def angle_pose_loss(pred_q, pred_t, gt_q, gt_t):
    pred_q = F.normalize(pred_q, dim=1)
    gt_q = F.normalize(gt_q, dim=1)
    cos_sim = torch.sum(pred_q * gt_q, dim=1).clamp(-1+1e-7, 1-1e-7)
    angle_loss = torch.mean(1 - cos_sim.abs())
    trans_loss = torch.mean((pred_t - gt_t)**2)
    return angle_loss + trans_loss

def smooth_l1_pose_loss(pred_q, pred_t, gt_q, gt_t):
    return F.smooth_l1_loss(pred_q, gt_q) + F.smooth_l1_loss(pred_t, gt_t)

def pose_loss(pred_q, pred_t, gt_q, gt_t):
    rot_loss = 1 - torch.sum(pred_q * gt_q, dim=1)**2
    trans_loss = torch.mean((pred_t - gt_t)**2, dim=1)
    return rot_loss.mean() + trans_loss.mean()


In [22]:
def train_model(model, dataloader, optimizer, device, scaler=None):
    model.train()
    total_loss = 0.0

    for batch in dataloader:
        rgb = batch['RGB_image'].to(device)
        depth = batch['depth_image'].to(device)
        gt_q = batch['rotation'].to(device)
        gt_t = batch['translation'].to(device)

        optimizer.zero_grad()

        if scaler:  # Mixed precision
            with torch.cuda.amp.autocast():
                pred_q, pred_t = model(rgb, depth)
                # print(f"pred_q={pred_q} and pred_t={pred_t}")
                # print(f"gt_q={gt_q} and gt_t={gt_t}")
                # loss = pose_loss(pred_q, pred_t, gt_q, gt_t)
                # loss = smooth_l1_pose_loss(pred_q, pred_t, gt_q, gt_t)
                loss = angle_pose_loss(pred_q, pred_t, gt_q, gt_t)
                # loss = mse_pose_loss(pred_q, pred_t, gt_q, gt_t)
                # print(f"losss={loss}")
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            pred_q, pred_t = model(rgb, depth)
            # print(f"pred_q={pred_q} and pred_t={pred_t}")
            # print(f"gt_q={gt_q} and gt_t={gt_t}")
            # loss = pose_loss(pred_q, pred_t, gt_q, gt_t)
            # loss = smooth_l1_pose_loss(pred_q, pred_t, gt_q, gt_t)
            loss = angle_pose_loss(pred_q, pred_t, gt_q, gt_t)
            # loss = mse_pose_loss(pred_q, pred_t, gt_q, gt_t)
            # print(f"losss={loss}")
               
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # optional
            optimizer.step()
        
        total_loss += loss.item()

    return total_loss / len(dataloader)


In [23]:
def validate_model(model, dataloader, device):
    model.eval()
    total_loss = 0.0

    with torch.no_grad():
        for batch in dataloader:
            rgb = batch['RGB_image'].to(device)
            depth = batch['depth_image'].to(device)
            gt_q = batch['rotation'].to(device)
            gt_t = batch['translation'].to(device)

            pred_q, pred_t = model(rgb, depth)
            # loss = pose_loss(pred_q, pred_t, gt_q, gt_t)
            # loss = smooth_l1_pose_loss(pred_q, pred_t, gt_q, gt_t)
            loss = angle_pose_loss(pred_q, pred_t, gt_q, gt_t)
            # loss = mse_pose_loss(pred_q, pred_t, gt_q, gt_t)
            total_loss += loss.item()

    return total_loss / len(dataloader)


In [24]:
# ###test different optimiser using validation

# import os
# import random
# import shutil
# from sklearn.model_selection import train_test_split
# from torch.utils.data import DataLoader
# import torch
# import torch.nn as nn
# from torch.optim.lr_scheduler import ReduceLROnPlateau
# # ----- Paths -----
# RGB_cropped_dir = "/kaggle/input/rgboutput/RGB_crop/train/train_cropped_objects"
# depth_cropped_dir = "/kaggle/input/depthoutput/depth_crop/train/train_cropped_objects"
# linemod_root = "/kaggle/input/linemod/Linemod_preprocessed/data"
# working_dir = "/kaggle/working"
# RGB_train_dir = os.path.join(working_dir, "RGB/train")
# RGB_val_dir = os.path.join(working_dir, "RGB/val")
# depth_train_dir = os.path.join(working_dir, "depth/train")
# depth_val_dir = os.path.join(working_dir, "depth/val")
# split_ratio = 0.8

# os.makedirs(RGB_train_dir, exist_ok=True)
# os.makedirs(RGB_val_dir, exist_ok=True)
# os.makedirs(depth_train_dir, exist_ok=True)
# os.makedirs(depth_val_dir, exist_ok=True)

# image_files = [f for f in os.listdir(RGB_cropped_dir) if f.endswith(".png")]
# train_files, val_files = train_test_split(image_files, train_size=split_ratio, random_state=42)

# for file in train_files:
#     shutil.copy(os.path.join(RGB_cropped_dir, file), os.path.join(RGB_train_dir, file))
#     shutil.copy(os.path.join(depth_cropped_dir, file), os.path.join(depth_train_dir, file))

# for file in val_files:
#     shutil.copy(os.path.join(RGB_cropped_dir, file), os.path.join(RGB_val_dir, file))
#     shutil.copy(os.path.join(depth_cropped_dir, file), os.path.join(depth_val_dir, file))

# # ----- Load datasets -----
# train_dataset = PoseDataset(rgb_dir=RGB_train_dir, depth_dir=depth_train_dir, linemod_root=linemod_root)
# val_dataset = PoseDataset(rgb_dir=RGB_val_dir, depth_dir=depth_val_dir, linemod_root=linemod_root)
# train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
# val_loader = DataLoader(val_dataset, batch_size=16, shuffle=True)

# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# # ----- Model -----
# model = PoseNet6D_MLP_Fusion(pretrained=True).to(device)   # PoseNet6D_MLP_Fusion  PoseNet6D_MLP_ConcatFusion PoseNet6D_MLP_ConcatFusion_512 

# # ----- Optimizer -----
# optimizer_type = "SGD"  # Choice: "SGD", "Adam", "AdamW", "RMSprop"
# learning_rate = 1e-4

# if optimizer_type == "SGD":
#     optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9)
# elif optimizer_type == "Adam":
#     optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
# elif optimizer_type == "AdamW":
#     optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)
# elif optimizer_type == "RMSprop":
#     optimizer = torch.optim.RMSprop(model.parameters(), lr=learning_rate)
# else:
#     raise ValueError(f"Unsupported optimizer: {optimizer_type}")

# # ----- Save Paths -----
# save_path = os.path.join(working_dir, "extension_model")
# os.makedirs(save_path, exist_ok=True)
# best_model_path = os.path.join(save_path, "best_model.pth")
# checkpoint_path = os.path.join(save_path, "checkpoint.pth")

# # ----- Resume checkpoint -----
# start_epoch = 0
# best_val_loss = float('inf')
# if os.path.exists(checkpoint_path):
#     checkpoint = torch.load(checkpoint_path)
#     model.load_state_dict(checkpoint['model_state_dict'])
#     optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
#     best_val_loss = checkpoint['best_val_loss']
#     start_epoch = checkpoint['epoch'] + 1
#     print(f"Resumed training from epoch {start_epoch}")
# else:
#     print("No checkpoint found, starting from epoch 0")

# # ----- Training Loop -----
# patience = 20
# no_improve_counter = 0
# epoch_num = 130
# scaler = torch.cuda.amp.GradScaler()

# for epoch in range(start_epoch, epoch_num):
#     print(f"\nEpoch {epoch + 1}/{epoch_num}")

#     train_loss = train_model(model, train_loader, optimizer, device, scaler=scaler)
#     val_loss = validate_model(model, val_loader, device)

#     print(f"Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}")

#     if val_loss < best_val_loss:
#         best_val_loss = val_loss
#         torch.save(model.state_dict(), best_model_path)
#         print(f"✅ Saved new best model at epoch {epoch + 1} with val loss {val_loss:.4f}")
#         no_improve_counter = 0
#     # else:
#     #     no_improve_counter += 1
#     #     if no_improve_counter >= patience:
#     #         print("⏹️ Early stopping triggered")
#     #         break

#     checkpoint = {
#         'epoch': epoch,
#         'model_state_dict': model.state_dict(),
#         'optimizer_state_dict': optimizer.state_dict(),
#         'best_val_loss': best_val_loss
#     }
#     torch.save(checkpoint, checkpoint_path)
#     print(f"Checkpoint saved at epoch {epoch + 1}")

In [25]:

###test different learning rate using validation
import os
import random
import shutil
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import ReduceLROnPlateau



# Paths 
RGB_cropped_dir = "/kaggle/input/rgboutput/RGB_crop/train/train_cropped_objects"
depth_cropped_dir = "/kaggle/input/depthoutput/depth_crop/train/train_cropped_objects"
linemod_root = "/kaggle/input/linemod/Linemod_preprocessed/data"


working_dir = "/kaggle/working"
RGB_train_dir = os.path.join(working_dir, "RGB/train")
RGB_val_dir = os.path.join(working_dir, "RGB/val")
depth_train_dir = os.path.join(working_dir, "depth/train")
depth_val_dir = os.path.join(working_dir, "depth/val")
split_ratio = 0.8

# Create folders
os.makedirs(RGB_train_dir, exist_ok=True)
os.makedirs(RGB_val_dir, exist_ok=True)
os.makedirs(depth_train_dir, exist_ok=True)
os.makedirs(depth_val_dir, exist_ok=True)

# Get image list and split
image_files = [f for f in os.listdir(RGB_cropped_dir) if f.endswith(".png")]
train_files, val_files = train_test_split(image_files, train_size=split_ratio, random_state=42)

# Copy data to train/val folders
for file in train_files:
    shutil.copy(os.path.join(RGB_cropped_dir, file), os.path.join(RGB_train_dir, file))
    shutil.copy(os.path.join(depth_cropped_dir, file), os.path.join(depth_train_dir, file))

for file in val_files:
    shutil.copy(os.path.join(RGB_cropped_dir, file), os.path.join(RGB_val_dir, file))
    shutil.copy(os.path.join(depth_cropped_dir, file), os.path.join(depth_val_dir, file))

# Load datasets
train_dataset = PoseDataset(rgb_dir=RGB_train_dir, depth_dir=depth_train_dir, linemod_root=linemod_root)
val_dataset = PoseDataset(rgb_dir=RGB_val_dir, depth_dir=depth_val_dir, linemod_root=linemod_root)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize model
model = PoseNet6D_MLP_ConcatFusion(pretrained=True).to(device)   #PoseNet6D_MLP_Fusion PoseNet6D_MLP_ConcatFusion   
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-2)
scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=5, factor=0.5, verbose=True)

# Save directory
save_path = os.path.join(working_dir, "extension_model")
os.makedirs(save_path, exist_ok=True)
best_model_path = os.path.join(save_path, "best_model.pth")
checkpoint_path = os.path.join(save_path, "checkpoint.pth")

# Load checkpoint if exists
start_epoch = 0
best_val_loss = float('inf')
if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    best_val_loss = checkpoint['best_val_loss']
    start_epoch = checkpoint['epoch'] + 1
    print(f"Resumed training from epoch {start_epoch}")
else:
    print("No checkpoint found, starting from epoch 0")

# Training loop
patience = 20
no_improve_counter = 0
epoch_num = 100

for epoch in range(start_epoch, epoch_num):
    print(f"\nEpoch {epoch + 1}/{epoch_num}")

    train_loss = train_model(model, train_loader, optimizer, device)
    val_loss = validate_model(model, val_loader, device)

    print(f"Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}")
    scheduler.step(val_loss)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), best_model_path)
        print(f"✅ Saved new best model at epoch {epoch + 1} with val loss {val_loss:.4f}")

    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'best_val_loss': best_val_loss
    }
    torch.save(checkpoint, checkpoint_path)
    print(f"Checkpoint saved at epoch {epoch + 1}")


No checkpoint found, starting from epoch 0

Epoch 1/100
Train Loss: 0.3716, Validation Loss: 0.3217
✅ Saved new best model at epoch 1 with val loss 0.3217
Checkpoint saved at epoch 1

Epoch 2/100
Train Loss: 0.2667, Validation Loss: 0.2679
✅ Saved new best model at epoch 2 with val loss 0.2679
Checkpoint saved at epoch 2

Epoch 3/100
Train Loss: 0.2212, Validation Loss: 0.2319
✅ Saved new best model at epoch 3 with val loss 0.2319
Checkpoint saved at epoch 3

Epoch 4/100
Train Loss: 0.1943, Validation Loss: 0.2119
✅ Saved new best model at epoch 4 with val loss 0.2119
Checkpoint saved at epoch 4

Epoch 5/100
Train Loss: 0.1871, Validation Loss: 0.2068
✅ Saved new best model at epoch 5 with val loss 0.2068
Checkpoint saved at epoch 5

Epoch 6/100
Train Loss: 0.1594, Validation Loss: 0.1818
✅ Saved new best model at epoch 6 with val loss 0.1818
Checkpoint saved at epoch 6

Epoch 7/100
Train Loss: 0.1541, Validation Loss: 0.1537
✅ Saved new best model at epoch 7 with val loss 0.1537
Check

In [26]:
# import os
# import shutil

# # Delete a specific file
# file_path = '/kaggle/working/extension_model/best_model.pth'
# if os.path.exists(file_path):
#     os.remove(file_path)
# file_path = '/kaggle/working/extension_model/checkpoint.pth'
# if os.path.exists(file_path):
#     os.remove(file_path)


In [27]:

# # Delete all files in /kaggle/working
# for f in os.listdir('/kaggle/working'):
#     path = os.path.join('/kaggle/working', f)
#     if os.path.isfile(path):
#         os.remove(path)
#     else:
#         shutil.rmtree(path)  # If it's a directory
