In [2]:
import torch
from torchvision.ops import batched_nms
from torch.utils.data import DataLoader
from yolox.test_weights import download_weights, load_pretrained_weights
from yolox.model import create_yolox_s
from data_utils.ppe_dataset import PPE_DATA

In [3]:
num_classes = 80
weight_path = download_weights("yolox/yolox_s.pth")
model = create_yolox_s(num_classes)
model = load_pretrained_weights(model, weight_path, num_classes)

Loading weights from yolox/yolox_s.pth
Loaded weights:
  Missing keys: 0
  Unexpected keys: 0


In [4]:
model.eval().cuda()
dataset = PPE_DATA()

def yolo_collate(batch):
    imgs   = torch.stack([b[0] for b in batch])   #  (B,3,640,640)
    labels = [b[1] for b in batch]                #  list of length B
    return imgs, labels

loader = DataLoader(dataset, batch_size=16, shuffle = False, collate_fn=yolo_collate)

In [9]:
img, label = dataset[4]
output = model(img.unsqueeze(0).cuda())  # Add batch dimension
print(output.shape)

torch.Size([1, 8400, 85])


In [15]:
img, label = next(iter(loader))
output = model(img.cuda())  # Add batch dimension
print(output.shape)

torch.Size([16, 8400, 85])


In [None]:
for batch in loader:
    img, label = batch
    img = img.cuda()
    model = model.cuda()
    
    with torch.no_grad():
        # Forward pass through the model
        # noise = torch.randn_like(img).cuda()
        output = model(img)
        obj = output[..., 4:5]
        print(obj.max())

In [25]:
#TODO fix for batches
def post_process_img(output, confidence_threshold = 0.25, iou_threshold = 0.5):
    x1 = output[..., 0:1] - output[..., 2:3] / 2
    y1 = output[..., 1:2] - output[..., 3:4] / 2
    x2 = output[..., 0:1] + output[..., 2:3] / 2
    y2 = output[..., 1:2] + output[..., 3:4] / 2

    # boxes: (batch, num_anchors, 4)
    boxes = torch.cat([x1, y1, x2, y2], dim=-1)

    # (batch, num_anchors, 1)
    obj = output[..., 4:5]
    class_probs = output[..., 5:]

    scores = obj * class_probs
    best_scores, best_class = scores.max(dim=-1)

    mask = best_scores > confidence_threshold
    best_scores = best_scores[mask] 
    best_class = best_class[mask] 
    boxes = boxes[mask]
    keep = batched_nms(boxes, best_class, best_scores, iou_threshold = iou_threshold)
    print(boxes.shape, keep.shape)
    final_boxes = boxes[keep]
    final_classes = best_class[keep]
    final_scores = best_scores[keep]
    # final classes and final scores have shape (num_kept,), so unsqueeze to add the dim 1 again
    predictions = torch.cat((final_classes.unsqueeze(1), 
                             final_boxes, 
                             final_scores.unsqueeze(1)), dim=1)
    return predictions
preds = post_process_img(output[0])
print(preds.shape)

torch.Size([21, 4]) torch.Size([9])
torch.Size([9, 6])


In [None]:
# labels = torch.cat((final_classes.unsqueeze(1), final_boxes), dim=1)
# print(labels.shape)
# if labels.shape[0] == 0:
#     print("No detections found.")
# else:
#     PPE_DATA.show_img(img, labels[:20, :], rect_coords_centered=False)

torch.Size([3, 5])
tensor(0., device='cuda:0', grad_fn=<UnbindBackward0>)
tensor(0., device='cuda:0', grad_fn=<UnbindBackward0>)
tensor(0., device='cuda:0', grad_fn=<UnbindBackward0>)


In [27]:
PPE_DATA.show_img(img[0], preds[:, :5], rect_coords_centered=False, output_file="output2.png")

tensor(56., device='cuda:0', grad_fn=<UnbindBackward0>)
tensor(0., device='cuda:0', grad_fn=<UnbindBackward0>)
tensor(0., device='cuda:0', grad_fn=<UnbindBackward0>)
tensor(0., device='cuda:0', grad_fn=<UnbindBackward0>)
tensor(0., device='cuda:0', grad_fn=<UnbindBackward0>)
tensor(0., device='cuda:0', grad_fn=<UnbindBackward0>)
tensor(0., device='cuda:0', grad_fn=<UnbindBackward0>)
tensor(0., device='cuda:0', grad_fn=<UnbindBackward0>)
tensor(0., device='cuda:0', grad_fn=<UnbindBackward0>)
