In [None]:
# Initialize the dataset
from dex_grasp.dataset.grasp_dataset import GraspDataset
from torch.utils.data import DataLoader
from torch.utils.data import ConcatDataset

pkl_dirs = [
    "/data_ssd/irmak/deft-data-all/ego4d-r3m/labels_obj_bbox",
    "/data_ssd/irmak/deft-data-all/ego4d-sta/labels_obj_bbox",
    "/data_ssd/irmak/deft-data-all/ek100/labels_obj_bbox", 
    "/data_ssd/irmak/deft-data-all/hoi4d/labels"
]

dsets = [] 
for pkl_dir in pkl_dirs:
    dsets.append(GraspDataset(
        pkl_dir=pkl_dir,
        return_cropped_image=False,
    ))

dataset = ConcatDataset(dsets)

# Initialize the dataloader
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=32)


In [15]:

from tqdm import tqdm 
import torch
from dex_grasp.models.grasp_transformer import AffordanceModel, GraspTransformer
from torch import nn

lambda_m = 1e-5 # mu lambda
lambda_g = 5 # grasp lambda
lambda_r = 0.1 # rotation lambda

affordance_model = AffordanceModel(
    src_in_features=512,
    freeze_rep=True
).to('cuda')
grasp_transformer = GraspTransformer(text_dim=512, image_dim=512).to('cuda')
optimizer = torch.optim.AdamW(list(affordance_model.parameters()) + list(grasp_transformer.parameters()), lr=1e-4)

# pbar = tqdm(len(dataloader))
for batch in dataloader:
    img = batch[0].to('cuda')
    task_description = batch[6]
    gt_mu = batch[2].to('cuda')[:,:,:2]
    gt_grasp_rotation = batch[3].to('cuda')
    gt_grasp_pose = batch[4].to('cuda')

    optimizer.zero_grad()

    img_feat, text_feat = affordance_model.get_clip_features(img, task_description)
    mu, cvar = affordance_model.get_mu_cvar(img_feat)
    grasp_rotation, grasp_pose = grasp_transformer(text_feat, img_feat)
    
    contact_loss = nn.functional.mse_loss(mu, gt_mu)
    grasp_rotation_loss = nn.functional.mse_loss(grasp_rotation, gt_grasp_rotation)
    grasp_pose_loss = nn.functional.mse_loss(grasp_pose, gt_grasp_pose)

    loss = lambda_m * contact_loss + lambda_r * grasp_rotation_loss + lambda_g * grasp_pose_loss
    loss.backward()
    optimizer.step()
    
    # pbar.update(1)
    print(f"Loss: {loss.item():.2f} | Contact: {lambda_m * contact_loss:.2f} | Rotation: {lambda_r * grasp_rotation_loss:.2f} | Pose: {lambda_g * grasp_pose_loss:.2f}")
    
# pbar.close()

    

Loss: 4.19 | Contact: 0.92 | Rotation: 0.26 | Pose: 3.01
Loss: 1.79 | Contact: 0.78 | Rotation: 0.28 | Pose: 0.72
Loss: 1.93 | Contact: 0.85 | Rotation: 0.24 | Pose: 0.83
Loss: 2.05 | Contact: 0.85 | Rotation: 0.27 | Pose: 0.94
Loss: 2.05 | Contact: 0.89 | Rotation: 0.25 | Pose: 0.91
Loss: 1.74 | Contact: 0.82 | Rotation: 0.25 | Pose: 0.68
Loss: 1.74 | Contact: 0.88 | Rotation: 0.26 | Pose: 0.60
Loss: 1.63 | Contact: 0.84 | Rotation: 0.26 | Pose: 0.53
Loss: 1.61 | Contact: 0.83 | Rotation: 0.26 | Pose: 0.53
Loss: 1.63 | Contact: 0.81 | Rotation: 0.27 | Pose: 0.55
Loss: 1.65 | Contact: 0.83 | Rotation: 0.27 | Pose: 0.56
Loss: 1.68 | Contact: 0.83 | Rotation: 0.27 | Pose: 0.58
Loss: 1.66 | Contact: 0.78 | Rotation: 0.26 | Pose: 0.62
Loss: 1.53 | Contact: 0.84 | Rotation: 0.27 | Pose: 0.42
Loss: 1.67 | Contact: 0.88 | Rotation: 0.24 | Pose: 0.54
Loss: 1.64 | Contact: 0.88 | Rotation: 0.25 | Pose: 0.51
Loss: 1.66 | Contact: 0.93 | Rotation: 0.25 | Pose: 0.48
Loss: 1.57 | Contact: 0.84 | Ro

KeyboardInterrupt: 