In [1]:
import os
import pickle
import uuid
import glob 
from tqdm import tqdm
import numpy as np
from PIL import Image
from maskrcnn_benchmark.config import cfg
from maskrcnn_benchmark.utils.modanetDrawer_test import ModaNetDrawerTest
from PIL import Image
from pycocotools.mask import encode, decode
import matplotlib.pyplot as plt

In [2]:
# Input root holding the scraped look images (<collection>/<designer>/<look>.jpg
# beneath it) and output root under which per-look garment crops and
# predictions are written, respectively.
BASE_DIR, GARMENT_BASE_DIR = (
    '/data/hautell-looks/',
    '/data/hautell-looks-garments',
)

In [6]:
# Collect every look image of one collection under BASE_DIR. Each `*` matches
# exactly one path component (a bare `**` without recursive=True means the
# same thing), giving BASE_DIR/<collection>/<designer>/<look>.jpg -- the depth
# the per-look path parsing downstream relies on. glob's order is arbitrary,
# so sort for a deterministic processing order across runs.
COLLECTION = 'spring-rtw-2019'
images = sorted(glob.glob(os.path.join(BASE_DIR, COLLECTION, '*', '*.jpg')))
# Fail loudly instead of letting the long inference loop silently do nothing.
assert images, 'no .jpg images found under ' + os.path.join(BASE_DIR, COLLECTION)

In [7]:
# Build the ModaNet inference wrapper from the training config plus a trained
# checkpoint. All tunables are named here so a reader can change them in one place.
CONFIG_FILE = '../configs/modanet/modanet-8gpu.yaml'
CHECKPOINT = '/data/logs/modanet/modanet-8gpu/model_0035000.pth'
MIN_IMAGE_SIZE = 800
# NOTE(review): presumably detections scoring below this are dropped by the
# drawer -- confirm in ModaNetDrawerTest.
CONFIDENCE_THRESHOLD = 0.7

cfg.merge_from_file(CONFIG_FILE)
# Merged after the yaml, so this overrides any MODEL.WEIGHT the yaml sets.
cfg.merge_from_list(['MODEL.WEIGHT', CHECKPOINT])
drawer = ModaNetDrawerTest(
    cfg,
    min_image_size=MIN_IMAGE_SIZE,
    confidence_threshold=CONFIDENCE_THRESHOLD,
)

In [8]:
# For every look: run the ModaNet detector, save one JPEG crop per detected
# garment, and pickle a `predictions` dict mapping each garment id -- which is
# also the crop's file name, `<id>.jpg` -- to its mask RLE, label, bbox, score.
for image_path in tqdm(images):
    # Images live at BASE_DIR/<collection>/<designer>/<look>.jpg; take the
    # parts relative to BASE_DIR rather than hard-coding '/'-split indices
    # that only hold for one particular BASE_DIR depth.
    collection, designer, file_name = os.path.relpath(image_path, BASE_DIR).split(os.sep)
    # splitext (not split('.')[0]) keeps e.g. 'look.1.jpg' and 'look.2.jpg' distinct.
    look_num = os.path.splitext(file_name)[0]

    LOOKDIR = os.path.join(GARMENT_BASE_DIR, collection, designer, look_num)
    os.makedirs(LOOKDIR, exist_ok=True)

    image = np.array(Image.open(image_path).convert('RGB'))
    # NOTE(review): crops below are cut from `result`, not `image`. If the
    # drawer returns the image with predictions drawn on it (as
    # maskrcnn_benchmark's COCODemo.run_on_opencv_image does), the crops will
    # include those overlays -- confirm. The "opencv" name also suggests BGR
    # input while `image` is RGB -- confirm which the drawer expects.
    result, top_predictions = drawer.run_on_opencv_image(image)

    bbox = top_predictions.bbox.numpy()
    labels = top_predictions.get_field("labels").numpy().tolist()
    masks = top_predictions.get_field('mask').numpy()
    scores = top_predictions.get_field("scores").numpy()

    predictions = {}
    for pred_ix in range(len(top_predictions)):
        # Box rows are (x1, y1, x2, y2) -- maskrcnn_benchmark's 'xyxy' BoxList
        # mode (TODO confirm top_predictions.mode) -- while numpy images are
        # indexed [row=y, col=x]. Clamp at 0 so a slightly negative coordinate
        # cannot wrap around as a negative slice index.
        x1, y1, x2, y2 = (max(int(v), 0) for v in bbox[pred_ix, :])
        garment = result[y1:y2, x1:x2, :]
        # pycocotools' encode requires a Fortran-ordered uint8 binary mask.
        mask = np.asfortranarray(masks[pred_ix, 0, :], dtype=np.uint8)
        rle = encode(mask)
        # One id serves as both the dict key and the crop's file name, so each
        # crop on disk can be matched back to its prediction.
        garment_id = str(uuid.uuid4())
        predictions[garment_id] = {'mask_encoded': rle,
                                   'label': labels[pred_ix],
                                   'bbox': bbox[pred_ix, :].tolist(),
                                   # plain float, like the other tolist()-ed fields
                                   'score': float(scores[pred_ix])}
        garment_save_path = os.path.join(LOOKDIR, garment_id + '.jpg')
        # A box under one pixel wide/high truncates to an empty crop, which PIL
        # refuses to write as JPEG; keep its prediction but skip the file.
        if garment.size:
            Image.fromarray(garment).save(garment_save_path)

    with open(os.path.join(LOOKDIR, 'predictions.pkl'), 'wb') as f:
        pickle.dump(predictions, f)

100%|██████████| 8900/8900 [2:31:28<00:00,  1.09it/s]  
