# In [None]:
import glob
import json
import os

import cv2
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models

class GripDataset(Dataset):
    """Image + joint-state samples for grip-delta regression.

    Loads every ``*.jpg`` in ``dir_path``. Each image has a companion JSON
    file in the same directory, named by replacing ``image_`` with ``joints_``
    in the file stem and using a ``.json`` extension
    (``image_001.jpg`` -> ``joints_001.json``). The JSON must provide
    ``"joints"`` and ``"deltas"`` lists of numbers.

    Each item is ``(img, state, deltas)``:
        img:    normalized float tensor of shape ``(3, 224, 224)``.
        state:  ``img.flatten()`` followed by the joint values, so the joints
                are the trailing ``len(data['joints'])`` entries.
        deltas: the regression target, as a float32 tensor.
    """

    def __init__(self, dir_path):
        # Escape the directory so characters such as '[' in its name are not
        # treated as glob wildcards. Sort for a deterministic item order.
        pattern = os.path.join(glob.escape(dir_path), '*.jpg')
        self.images = sorted(glob.glob(pattern))
        self.transforms = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            # ImageNet statistics, matching the pretrained ResNet weights.
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def __len__(self):
        return len(self.images)

    @staticmethod
    def _json_path_for(img_path):
        """Return the joints JSON path paired with ``img_path``.

        Only the file name is rewritten. Directory components that happen to
        contain ``image_`` or ``.jpg`` are left untouched.
        """
        folder, name = os.path.split(img_path)
        stem, _ext = os.path.splitext(name)
        return os.path.join(folder, stem.replace('image_', 'joints_') + '.json')

    def __getitem__(self, idx):
        img_path = self.images[idx]
        bgr = cv2.imread(img_path)
        if bgr is None:
            # cv2.imread reports unreadable or missing files by returning None
            # instead of raising; fail here with the offending path.
            raise FileNotFoundError(f'Could not read image: {img_path}')
        img = self.transforms(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))

        with open(self._json_path_for(img_path), 'r', encoding='utf-8') as f:
            data = json.load(f)
        joints = torch.tensor(data['joints'], dtype=torch.float32)
        deltas = torch.tensor(data['deltas'], dtype=torch.float32)  # Target adjustments

        # Tuple layout kept for existing callers. Consumers that only need the
        # joints can slice the trailing entries of the second element.
        return img, torch.cat([img.flatten(), joints]), deltas

# Input geometry shared by training and ONNX export. IMG_SIZE must match the
# Resize((224, 224)) in GripDataset. JOINT_DIM is the length of each sample's
# "joints" list, and DELTA_DIM the length of its "deltas" target.
IMG_SIZE = 224
JOINT_DIM = 6
DELTA_DIM = 6


class GripNet(nn.Module):
    """Regress joint deltas from an image plus the current joint state.

    The image is encoded by ``backbone``, which is expected to yield
    ``feat_dim`` features per sample. Those features are concatenated with the
    joint vector and passed through a small MLP head.

    Args:
        backbone: Image feature extractor, e.g. ResNet-18 with ``fc`` replaced
            by ``nn.Identity()``.
        feat_dim: Number of features ``backbone`` produces per sample.
        joint_dim: Length of the joint-state vector.
        out_dim: Number of regressed outputs.
        hidden_dim: Width of the MLP hidden layer.
    """

    def __init__(self, backbone, feat_dim=512, joint_dim=JOINT_DIM,
                 out_dim=DELTA_DIM, hidden_dim=128):
        super().__init__()
        self.backbone = backbone
        self.head = nn.Sequential(
            nn.Linear(feat_dim + joint_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        )

    def forward(self, img, joints):
        """Return ``(N, out_dim)`` predictions for an image batch and ``(N, joint_dim)`` joints."""
        # flatten(1) also accepts backbones that return (N, C, 1, 1) maps.
        feats = self.backbone(img).flatten(1)
        return self.head(torch.cat([feats, joints], dim=1))


if __name__ == '__main__':
    # Guarded so GripNet can be imported without starting training. In a
    # notebook __name__ is '__main__', so the cell still runs end to end.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    dataset = GripDataset('dataset_grip')
    if len(dataset) == 0:
        raise FileNotFoundError("No *.jpg images found in 'dataset_grip'")
    dataloader = DataLoader(dataset, batch_size=8, shuffle=True)

    # ResNet only accepts an image tensor, so use it purely as a 512-d
    # feature extractor and fuse the joint state after pooling, in GripNet.
    backbone = models.resnet18(pretrained=True)
    backbone.fc = nn.Identity()
    model = GripNet(backbone, feat_dim=512, joint_dim=JOINT_DIM,
                    out_dim=DELTA_DIM).to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    # Train loop
    for epoch in range(50):
        model.train()
        loss_sum, n_seen = 0.0, 0
        for imgs, states, targets in dataloader:
            imgs = imgs.to(device)
            # The joints are the trailing JOINT_DIM entries of the state vector.
            joints = states[:, -JOINT_DIM:].to(device)
            targets = targets.to(device)

            preds = model(imgs, joints)
            loss = criterion(preds, targets)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            loss_sum += loss.item() * imgs.size(0)
            n_seen += imgs.size(0)
        if epoch % 10 == 0:
            # Report the mean loss over the epoch, not just the last batch.
            print(f'Epoch {epoch}, Loss: {loss_sum / n_seen:.4f}')

    torch.save(model.state_dict(), 'grip_model.pth')

    # Export to ONNX for TensorRT (use trtexec or jetson-inference tools).
    # The dummy inputs must match forward(img, joints); eval() freezes
    # BatchNorm statistics for inference.
    model.eval()
    dummy_img = torch.randn(1, 3, IMG_SIZE, IMG_SIZE, device=device)
    dummy_joints = torch.randn(1, JOINT_DIM, device=device)
    torch.onnx.export(
        model,
        (dummy_img, dummy_joints),
        'grip_model.onnx',
        input_names=['image', 'joints'],
        output_names=['deltas'],
    )
    # Then: trtexec --onnx=grip_model.onnx --saveEngine=grip_model.trt
