In [1]:
import torch, torchvision
import torchvision.transforms as T
import matplotlib.pyplot as plt
from PIL import Image
import glob
import io
import matplotlib.pyplot as plt

In [7]:
finetuned_model = torch.hub.load('facebookresearch/detr',
                       'detr_resnet50',
                       pretrained=False,
                       num_classes=1) 

checkpoint = torch.load('./detr/outputs/checkpoint.pth', map_location='cpu')
finetuned_model.load_state_dict(checkpoint['model'], strict=False)
finetuned_model.eval()

Using cache found in /root/.cache/torch/hub/facebookresearch_detr_master


DETR(
  (transformer): Transformer(
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0): TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): Linear(in_features=256, out_features=256, bias=True)
          )
          (linear1): Linear(in_features=256, out_features=2048, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (linear2): Linear(in_features=2048, out_features=256, bias=True)
          (norm1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.1, inplace=False)
          (dropout2): Dropout(p=0.1, inplace=False)
        )
        (1): TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): Linear(in_features=256, out_features=256, bias=True)
          )
          (linear1): Linear(in_features=256, out_features=2048, bias=True)
          (dropout): Drop

In [8]:
labels = [
  'racket'
]

transform = T.Compose([
    T.Resize(800),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
COLORS = [[0.000, 0.447, 0.741], [0.850, 0.325, 0.098], [0.929, 0.694, 0.125],
          [0.494, 0.184, 0.556], [0.466, 0.674, 0.188], [0.301, 0.745, 0.933]]


In [9]:
def box_cxcywh_to_xyxy(x):
    x_c, y_c, w, h = x.unbind(1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=1)

def rescale_bboxes(out_bbox, size):
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_bbox)
    b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
    return b

def filter_bboxes_from_outputs(im,outputs, threshold):
    probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
    keep = probas.max(-1).values > threshold
    probas_to_keep = probas[keep]
    bboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], im.size)
    return probas_to_keep, bboxes_scaled

def plot_result(folder_path,filename,pil_img, prob=None, boxes=None, labels=None):
    plt.figure(figsize=(16, 10))
    plt.imshow(pil_img)
    ax = plt.gca()
    colors = COLORS * 100
    if prob is not None and boxes is not None:
        for p, (xmin, ymin, xmax, ymax),c in zip(prob, boxes.tolist(),colors):
            cl = p.argmax()
            ax.add_patch(plt.Rectangle((xmin, ymin), xmax-xmin, ymax-ymin,
                                       fill=False, color=c, linewidth=3))
            text = f'{labels[cl]}: {p[cl]:0.2f}'
            ax.text(xmin, ymin, text, fontsize=8,bbox=dict(facecolor='yellow', alpha=0.4))
    plt.axis('off')
    plt.savefig(folder_path+"/"+filename.split("/")[-1]+"_detection.jpg", bbox_inches='tight',pad_inches=0)

# 物体検出
def inference(folder_path,filename,input_image, model, labels, threshold):

  img = transform(input_image).unsqueeze(0)

  outputs =model(img)

  probas_to_keep, bboxes_scaled = filter_bboxes_from_outputs(input_image,outputs, threshold=threshold)

  plot_result(folder_path,filename,input_image, probas_to_keep, bboxes_scaled, labels)

In [None]:
folder_path="./data"
paths=sorted(glob.glob(folder_path+"/test/*.jpg"))

folder_path+="/result"

for path in paths:
    filename=path.split("jpg")[0][:-1]
    print(filename)
    input_image = Image.open(filename+".jpg").convert('RGB')
    inference(folder_path,filename,input_image,finetuned_model,labels, threshold=0.6)