In [1]:
import os
from typing import Callable, Optional

import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

In [2]:
# Sanity check: True when PyTorch can see a CUDA-capable GPU.
torch.cuda.is_available()

True

In [3]:
# Device-agnostic setup: prefer the GPU when PyTorch can see one,
# otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cuda'

In [None]:
# Shared preprocessing pipeline: resize every image to a fixed
# (height=20, width=10) grid, then convert the PIL image to a tensor.
bar_transform = transforms.Compose(
    [
        transforms.Resize(size=(20, 10)),
        transforms.ToTensor(),
    ]
)

class TabMagicDataset(Dataset):
    """Detection dataset pairing images in a folder with CSV box annotations.

    Each item is ``(image, target)`` where ``target`` is a dict with:

    * ``'boxes'``: float32 tensor of shape ``(N, 4)`` holding
      ``[bbox_x, bbox_y, bbox_x + bbox_width, bbox_y + bbox_height]`` per
      annotation row (corner form, assuming ``bbox_x``/``bbox_y`` is the
      top-left corner). ``N`` may be 0 for images without annotation rows.
    * ``'labels'``: int64 tensor of shape ``(N,)`` from :meth:`get_label_id`.

    The annotation CSV must provide the columns ``image_name``, ``bbox_x``,
    ``bbox_y``, ``bbox_width``, ``bbox_height`` and ``label_name``.
    ``annotations_df`` is treated as read-only after construction.
    """

    # File extensions (compared case-insensitively) treated as images.
    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    # Annotation label name -> integer class id. Id 0 is not assigned here
    # (presumably reserved for background, as torchvision detection models
    # expect -- confirm).
    LABEL_IDS = {'string': 1, 'number': 2}

    def __init__(self,
                 data_path: str,
                 annotation_path: str,
                 transform: Optional[Callable] = None):
        """Index the images in ``data_path`` and load ``annotation_path``.

        Args:
            data_path: Directory scanned (non-recursively) for image files.
            annotation_path: Path to the annotation CSV.
            transform: Optional callable applied to the RGB PIL image. When
                given it must return a tensor whose last two dims are
                ``(H, W)``; boxes are rescaled to that size.
        """
        self.transform = transform
        # Sorted so the index -> file mapping is reproducible; os.listdir
        # returns entries in arbitrary order.
        self.image_paths = [
            os.path.join(data_path, name)
            for name in sorted(os.listdir(data_path))
            if name.lower().endswith(self.IMAGE_EXTENSIONS)
        ]
        self.annotations_df = pd.read_csv(annotation_path)
        # Group rows once so __getitem__ does a dict lookup instead of
        # scanning the whole DataFrame for every item.
        self._annotations_by_image = {
            name: rows
            for name, rows in self.annotations_df.groupby('image_name')
        }

    def __len__(self):
        """Number of image files found in ``data_path``."""
        return len(self.image_paths)

    def __getitem__(self, idx: int):
        """Return ``(image, target)`` for the image at position ``idx``."""
        img_path = self.image_paths[idx]
        # The context manager closes the file handle; convert() returns an
        # independent image that does not depend on it.
        with Image.open(img_path) as img:
            image = img.convert('RGB')

        img_name = os.path.basename(img_path)
        annotations = self._annotations_by_image.get(img_name)

        boxes = []
        labels = []
        if annotations is not None:
            for _, annotation in annotations.iterrows():
                x_min = annotation['bbox_x']
                y_min = annotation['bbox_y']
                boxes.append([
                    x_min,
                    y_min,
                    x_min + annotation['bbox_width'],
                    y_min + annotation['bbox_height'],
                ])
                labels.append(self.get_label_id(annotation['label_name']))

        # reshape keeps the (N, 4) layout even when N == 0: a bare
        # torch.as_tensor([]) is 1-D and would break the column indexing below.
        boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
        labels = torch.as_tensor(labels, dtype=torch.int64)

        if self.transform is not None:
            # PIL reports size as (width, height).
            original_width, original_height = image.size
            image = self.transform(image)
            # Assumes the transform returns a tensor laid out (..., H, W).
            new_height, new_width = image.shape[-2], image.shape[-1]
            # Scale x coordinates and y coordinates independently, because
            # the resize need not preserve the aspect ratio.
            boxes[:, [0, 2]] *= new_width / original_width
            boxes[:, [1, 3]] *= new_height / original_height

        target = {
            'boxes': boxes,
            'labels': labels,
        }
        return image, target

    @classmethod
    def get_label_id(cls, label_name):
        """Map an annotation ``label_name`` to its integer class id.

        Returns 1 for ``'string'`` and 2 for ``'number'``.

        Raises:
            ValueError: if ``label_name`` is not a known label.
        """
        try:
            return cls.LABEL_IDS[label_name]
        except KeyError:
            raise ValueError(f"Unknown label: {label_name}") from None