In [1]:
from transformers import AutoModel, AutoProcessor
from transformers import DetrImageProcessor, DetrForObjectDetection
import numpy as np
import torch
from matplotlib import pyplot as plt
import torchvision
import os
from PIL import Image

device = "mps"

In [2]:
detr = DetrForObjectDetection.from_pretrained("diliash/detr-2024-04-05-13-03-08").to(device)
detr_processor = DetrImageProcessor.from_pretrained("diliash/detr-2024-04-05-13-03-08")

In [3]:
class CocoDetection(torchvision.datasets.CocoDetection):
    def __init__(self, coco_folder, processor, train=True):
        ann_file = os.path.join(coco_folder, "coco_annotation", "MotionNet_train.json" if train else "MotionNet_valid.json")
        super(CocoDetection, self).__init__(os.path.join(coco_folder, "train/origin" if train else "valid/origin"), ann_file)
        self.processor = processor

    def __getitem__(self, idx):
        img, target = super(CocoDetection, self).__getitem__(idx)
        image_id = self.ids[idx]
        target = {'image_id': image_id, 'annotations': target}
        encoding = self.processor(images=img, annotations=target, return_tensors="pt")
        pixel_values = encoding["pixel_values"].squeeze()
        target = encoding["labels"][0]
        return pixel_values, target


In [4]:
val_dataset = CocoDetection(coco_folder='./partnetsim-256-fixed-viewpoints/coco', processor=detr_processor, train=False)

loading annotations into memory...
Done (t=0.01s)
creating index...
index created!


In [5]:
special_val_dataset = CocoDetection(coco_folder='./partnetsim-1024-fixed-viewpoints/coco', processor=detr_processor, train=False)

loading annotations into memory...
Done (t=0.01s)
creating index...
index created!


In [6]:
from torch.utils.data import DataLoader

def collate_fn(batch):
  pixel_values = [item[0] for item in batch]
  encoding = detr_processor.pad(pixel_values, return_tensors="pt")
  labels = [item[1] for item in batch]
  batch = {}
  batch['pixel_values'] = encoding['pixel_values']
  batch['pixel_mask'] = encoding['pixel_mask']
  batch['labels'] = labels
  return batch

val_dataloader = DataLoader(val_dataset, collate_fn=collate_fn, batch_size=1)

In [7]:
import matplotlib.pyplot as plt

PARTNETSIM_COLOR_MAP = {
    1: (0.0, 107.0, 164.0),
    2: (255.0, 128.0, 14.0),
    3: (200.0, 82.0, 0.0),
    4: (171.0, 171.0, 171.0),
}

def plot_results(pil_img, scores, labels, boxes):
    plt.figure(figsize=(16,10))
    plt.imshow(pil_img)
    ax = plt.gca()
    for score, label, (xmin, ymin, xmax, ymax)  in zip(scores.tolist(), labels.tolist(), boxes.tolist()):
        c = [x / 255 for x in PARTNETSIM_COLOR_MAP[label]]
        ax.add_patch(plt.Rectangle((xmin, ymin), xmax - xmin, ymax - ymin,
                                   fill=False, color=c, linewidth=3))
        text = f'{detr.config.id2label[label]}: {score:0.2f}'
        ax.text(xmin, ymin, text, fontsize=15,
                bbox=dict(facecolor='yellow', alpha=0.5))
    plt.axis('off')
    plt.show()

In [8]:
aggregated_results = {}

In [10]:

for pixel_values, target in val_dataset:
  pixel_values = pixel_values.unsqueeze(0).to(device)
  with torch.no_grad():
    outputs = detr(pixel_values=pixel_values, pixel_mask=None)

  image_id = target['image_id'].item()
  
  image = val_dataset.coco.loadImgs(image_id)[0]
  image = Image.open(os.path.join('./partnetsim-256-fixed-viewpoints/coco/valid/origin', image['file_name']))

  width, height = image.size
  postprocessed_outputs = detr_processor.post_process_object_detection(outputs,
                                                                  target_sizes=[(height, width)],
                                                                  threshold=0.8)
  results = postprocessed_outputs[0]
  aggregated_results[image_id] = results
  #print(image_id)

  # Uncomment to see the predicted bboxes
  #plot_results(image, results['scores'], results['labels'], results['boxes'])

In [16]:
aggregated_results

{1: {'scores': tensor([], device='mps:0'),
  'labels': tensor([], device='mps:0', dtype=torch.int64),
  'boxes': tensor([], device='mps:0', size=(0, 4))},
 2: {'scores': tensor([], device='mps:0'),
  'labels': tensor([], device='mps:0', dtype=torch.int64),
  'boxes': tensor([], device='mps:0', size=(0, 4))},
 3: {'scores': tensor([], device='mps:0'),
  'labels': tensor([], device='mps:0', dtype=torch.int64),
  'boxes': tensor([], device='mps:0', size=(0, 4))},
 4: {'scores': tensor([], device='mps:0'),
  'labels': tensor([], device='mps:0', dtype=torch.int64),
  'boxes': tensor([], device='mps:0', size=(0, 4))},
 5: {'scores': tensor([0.9933, 0.9502, 0.9799], device='mps:0'),
  'labels': tensor([1, 1, 1], device='mps:0'),
  'boxes': tensor([[ 65.6719,  45.2676, 215.8931, 163.9143],
          [ 94.6078, 135.5776, 189.9000, 201.0171],
          [ 85.1700, 105.2373, 206.4181, 180.6086]], device='mps:0')},
 6: {'scores': tensor([0.9947, 0.9962, 0.9940], device='mps:0'),
  'labels': tensor(

Note, if you are using CUDA device for inference, you can remove .to("cpu") in the following cells. I did it since some operations are still unsupported by mps.

In [15]:
from transformers import AutoModel, AutoProcessor
sam_processor = AutoProcessor.from_pretrained("diliash/sam-2024-04-08-01-25-31")
sam = AutoModel.from_pretrained("diliash/sam-2024-04-08-01-25-31").to("cpu")
sam_processor.image_processor.do_normalize = False

In [24]:
# With 0-70 bbox augmentation
from transformers import AutoModel, AutoProcessor
sam_processor = AutoProcessor.from_pretrained("diliash/sam-2024-04-08-10-30-40")
sam = AutoModel.from_pretrained("diliash/sam-2024-04-08-10-30-40").to("cpu")
sam_processor.image_processor.do_normalize = False

In [None]:
# With 0-20 bbox augmentation
from transformers import AutoModel, AutoProcessor
sam_processor = AutoProcessor.from_pretrained("diliash/sam-2024-04-08-06-34-33")
sam = AutoModel.from_pretrained("diliash/sam-2024-04-08-06-34-33").to("cpu")
sam_processor.image_processor.do_normalize = False

In [25]:
class AggregatedResultsDataset(torch.utils.data.Dataset):
    def __init__(self, aggregated_results, coco_folder, image_mean=[0, 0, 0], image_std=[0, 0, 0]):
        self.aggregated_results = aggregated_results
        self.coco_folder = coco_folder
        self.pixel_mean = torch.tensor(image_mean)
        self.pixel_std = torch.tensor(image_std)

    def __len__(self):
        return len(self.aggregated_results)

    def __getitem__(self, idx):
        image_id = list(self.aggregated_results.keys())[idx]
        results = self.aggregated_results[image_id]

        image_info = val_dataset.coco.loadImgs(image_id)[0]
        anns = image_info
        image = Image.open(os.path.join(self.coco_folder, image_info['file_name']))

        image = image.convert("RGB")

        # Scale the bounding boxes from 256x256 to 1024x1024
        scale_factor = 1024 / 256
        bboxes = results['boxes'].cpu().numpy().tolist()
        scores = results['scores'].cpu().numpy().tolist()
        labels = results['labels'].cpu().numpy().tolist()
        scaled_bboxes = [[x * scale_factor for x in bbox] for bbox in bboxes]

        return image, scaled_bboxes, anns, scores, labels


In [26]:
aggregated_results_dataset = AggregatedResultsDataset(aggregated_results, coco_folder='./partnetsim-1024-fixed-viewpoints/coco/valid/origin')


In [27]:
def aggregated_results_collate_fn(batch):
    imgs, bboxes, anns, scores, labels = zip(*batch)

    if any(len(bbox_list) > 0 for bbox_list in bboxes):
        # Pad bboxes to have the same length within a batch
        max_bboxes = max(len(bbox_list) for bbox_list in bboxes)
        padded_bboxes = [bbox_list + [[0, 0, 0, 0]] * (max_bboxes - len(bbox_list)) for bbox_list in bboxes]
        input = sam_processor(images=list(imgs), input_boxes=padded_bboxes, return_tensors="pt").to(torch.float32).to("cpu")
        input["anns"] = anns
        input["scores"] = scores
        input["labels"] = labels
    else:
        # If there are no bounding boxes, return None
        input = None

    return input

aggregated_results_dataloader = DataLoader(aggregated_results_dataset, collate_fn=aggregated_results_collate_fn, batch_size=1)

In [28]:
def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

In [None]:
# Visualize with pred bboxes

for idx, batch in enumerate(aggregated_results_dataloader):
    if batch is None:
        # If there are no bounding boxes, show the original image without masks
        img = aggregated_results_dataset[idx][0]
        plt.figure(figsize=(10, 10))
        plt.imshow(img)
        plt.axis('off')
        plt.show()
    else:
        outputs = sam(**batch, multimask_output=False)
        raw_imgs = batch["pixel_values"].cpu().numpy()

        masks = sam_processor.post_process_masks(outputs.pred_masks, batch["original_sizes"], batch["reshaped_input_sizes"])

        for i in range(np.shape(masks)[1]):
            plt.figure(figsize=(10, 10))
            img = np.transpose(raw_imgs[0], (1, 2, 0))
            plt.imshow(img)
            show_mask(masks[0][i][0], plt.gca())
            plt.axis('off')
            plt.show()

In [33]:
aggreagted_masks = {}

In [22]:
for idx, batch in enumerate(aggregated_results_dataloader):
    if batch is not None:
        outputs = sam(**batch, multimask_output=False)
        raw_imgs = batch["pixel_values"].cpu().numpy()
        
        masks = sam_processor.post_process_masks(outputs.pred_masks, batch["original_sizes"], batch["reshaped_input_sizes"])
        aggreagted_masks[idx] = {"masks": masks, "image_info": batch["anns"], "scores": batch["scores"], "lables": batch["labels"]}


In [23]:
torch.save(aggreagted_masks, 'aggregated_masks_detr-0.8-sam+70.pt')