In [1]:
import argparse
import cv2
import glob
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
from PIL import Image
import pickle
from skimage.transform import resize
import sys
import torch
import torchvision
from torchvision.transforms import transforms as transforms
%matplotlib inline

In [3]:
# Everything runs on the CPU; use torch.device(0) instead to target the first GPU.
device = torch.device('cpu')

# Keypoint R-CNN (ResNet-50 FPN backbone) with pretrained weights, predicting
# 17 keypoints per detected instance.
# NOTE(review): `pretrained=True` is deprecated in newer torchvision releases
# in favour of `weights=...`; kept here for compatibility with older versions.
model = torchvision.models.detection.keypointrcnn_resnet50_fpn(
    pretrained=True,
    num_keypoints=17,
)
model = model.to(device)
model.eval()  # inference mode: the model returns detections, not losses

# The only preprocessing the detector needs: PIL image -> float tensor.
transform = transforms.Compose([transforms.ToTensor()])

In [4]:
# For every raw image: run the keypoint detector, keep confident detections,
# crop each one with a margin, resize the crop to a square and pickle it with
# its metadata. A summary table of all kept detections goes to all_figures.pkl.
# Relies on `model`, `device` and `transform` from the model-setup cell.
image_files = sorted(glob.glob('../data/raw/*.jpg'))

out_dir = '../data/figures/'
# makedirs creates missing parent directories and tolerates an existing one
os.makedirs(out_dir, exist_ok=True)

MIN_SCORE = 0.9            # minimum detection score for a box to be kept
CROP_MARGIN = 0.25         # border added on each side, as a fraction of box width/height
OUT_SIZE = (128, 128)      # (rows, cols) of the saved square crop
MAX_FIGURES_PER_IMAGE = 1  # stop after this many saved crops per image


def _expanded_crop_bounds(box, width, height, margin=CROP_MARGIN):
    """Grow a box by `margin` of its size on each side, clipped to the image.

    Parameters
    ----------
    box : sequence of 4 floats
        (x0, y0, x1, y1) pixel coordinates of the detection box.
    width, height : int
        Image size used to clip the expanded bounds.
    margin : float
        Fraction of the box width/height added on every side.

    Returns
    -------
    tuple of int
        (x_start, x_stop, y_start, y_stop) suitable for ``img[y_start:y_stop, x_start:x_stop]``.
    """
    x0, y0, x1, y1 = box
    w = x1 - x0
    h = y1 - y0
    return (
        max(int(x0 - w * margin), 0),
        min(int(x1 + w * margin), width),
        max(int(y0 - h * margin), 0),
        min(int(y1 + h * margin), height),
    )


all_outputs = []

try:
    for i, fname in enumerate(image_files):
        sys.stdout.write('\r\t%d\t' % i)

        try:
            pil_image = Image.open(fname).convert('RGB')
            # NumPy copy of the full image scaled to [0, 1]; crops are cut from it
            img = np.array(pil_image, dtype=np.float32) / 255
            H, W, _ = img.shape
            # splitext removes only the extension (rstrip('.jpg') strips characters)
            stem = os.path.splitext(os.path.basename(fname))[0]

            batch = transform(pil_image).unsqueeze(0).to(device)  # add batch dim
            with torch.no_grad():
                output = model(batch)
            det = output[0]

            scores = det['scores'].cpu().numpy()
            key_scores = det['keypoints_scores'].cpu()

            n_saved = 0
            for j, box in enumerate(det['boxes'].cpu().tolist()):
                if scores[j] < MIN_SCORE:
                    continue
                # Skip this detection if any of ITS keypoint scores is negative
                # (presumably a low-confidence keypoint — confirm threshold intent).
                if key_scores[j].min() < 0:
                    continue

                x_start, x_stop, y_start, y_stop = _expanded_crop_bounds(box, W, H)
                crop = img[y_start:y_stop, x_start:x_stop]
                if crop.size == 0:
                    continue  # degenerate box: nothing to resize

                record = {
                    'fname': fname,
                    'fig_index': j,
                    'score': scores[j],
                    'key_scores': key_scores[j].tolist(),
                    'points': det['keypoints'][j].tolist(),
                    'box': box,
                }

                # save the square-resized crop together with its metadata
                outfile = os.path.join(out_dir, '%s_%04d.pkl' % (stem, j))
                with open(outfile, 'wb') as fout:
                    pickle.dump({'img': resize(crop, OUT_SIZE), **record}, fout)

                all_outputs.append(record)
                n_saved += 1
                if n_saved >= MAX_FIGURES_PER_IMAGE:
                    break

        except Exception as e:
            # best effort: report the failing file and move on to the next image
            print('\nskipping %s: %s' % (fname, e))
            continue
finally:
    # Written once (not per detection), and also on KeyboardInterrupt so a
    # partial run keeps its results. Skipped when nothing was found so an
    # earlier summary is not overwritten by an empty table.
    if all_outputs:
        pd.DataFrame(all_outputs).to_pickle('all_figures.pkl')

	0	

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


	20	

KeyboardInterrupt: 