In [None]:
import os
import sys
sys.path.append('..')
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import numpy as np
from tqdm.notebook import tqdm
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
from pathlib import Path
from glob import glob
from imageio import imsave
import cv2
import torch

from abcmodel.libs.cvat import JOINT_LABELS
from higher_hrnet.libs.utils import read_tarfile
from higher_hrnet.libs.predictor import Refine
from higher_hrnet.models.higher_hrnet import HigherHRNet
from higher_hrnet.libs.transforms import InferenceTransform
from higher_hrnet.libs.clustering import SpatialClustering
# #torch.set_grad_enabled(False)

## Load HigherHRNet (pig instance segmentation / keypoint estimation)

In [None]:
# Build HigherHRNet (13 keypoints, 2 segmentation classes, 5-dim tags), load the
# trained weights, move it to the GPU in eval mode and wrap it with `Refine`
# (project helper, configured with JOINT_LABELS and tag averaging).
# NOTE(review): an earlier variant loaded a TensorRT-converted model through
# `TRTModule` (presumably from torch2trt); that name is not imported here.
SEG_KEY_MODEL_PATH = os.path.join('/workspace', 'pig', 'model', 'seg_key_model.pth')

seg_key_model = HigherHRNet(num_keypoints=13,
                            num_seg_classes=2,
                            dim_tag=5)
# Deserialize onto the CPU: the parameters are moved to the GPU right below, so
# this avoids depending on the device the checkpoint was saved from.
seg_key_model.load_state_dict(torch.load(SEG_KEY_MODEL_PATH, map_location='cpu'))
seg_key_model = seg_key_model.cuda().eval()
seg_key_model = Refine(seg_key_model, JOINT_LABELS, average_tag=True)

# Pre-processing (input size 480) and tag-map clustering used by `predict`.
inference_transform = InferenceTransform(input_size=480)
clustering = SpatialClustering(threshold=0.05, min_pixels=20, margin=.5)

In [None]:
def predict(img, out_size=(640, 480)):
    """Run the segmentation/keypoint model on one image.

    Args:
        img: image accepted by ``inference_transform``; in this notebook it is
            the ``rgb`` array returned by ``read_tarfile``.
        out_size: ``(width, height)`` (OpenCV order) that the instance map is
            resized to with nearest-neighbour interpolation, so instance ids
            are never blended. Defaults to ``(640, 480)``.

    Returns:
        ins_map: ``np.ndarray`` of instance ids with shape ``(height, width)``;
            callers in this notebook treat id 0 as background.
        seg: channel 1 of the softmaxed segmentation output for the single
            batch element (presumably the foreground probability), not resized.
        hms: ``hm_preds[1][0]`` as a NumPy array, channel-first, not resized.
    """
    # Inference only. Without no_grad, autograd records a graph for every frame,
    # wasting GPU memory, and `.numpy()` raises on tensors that require grad.
    with torch.no_grad():
        x, _ = inference_transform(img)
        x = x.unsqueeze(0).cuda()

        seg_pred, hm_preds, tag_preds = seg_key_model(x)
        instance_map = clustering(tag_preds).cpu().squeeze()

        ins_map = cv2.resize(instance_map.numpy(), tuple(out_size),
                             interpolation=cv2.INTER_NEAREST)
        seg = seg_pred.softmax(dim=1)[0, 1].cpu().numpy()
        hms = hm_preds[1][0].cpu().numpy()

    return ins_map, seg, hms

## Read Dataset

In [None]:
# Dataset root: raw archives are read from `tar/`, crops are written to `images_20211125/`.
DIR = '/workspace/pig/data/'
data_dir, save_dir = (os.path.join(DIR, sub_dir)
                      for sub_dir in ('tar', 'images_20211125'))

In [None]:
date_id = 20211125
# Entries available under this capture date's directory (the room ids used next).
sorted(os.listdir(os.path.join(data_dir, str(date_id))))

In [None]:
room_id = '18-21-R8'
# Output folder for this room's crops.
os.makedirs(os.path.join(save_dir, str(date_id), room_id), exist_ok=True)

# All files recorded for this room on this date, in filename order;
# each one is later read with `read_tarfile`.
file_list = sorted(glob(os.path.join(data_dir, str(date_id), room_id, '*')))
print(len(file_list))

## Check data and inference results

Inspect one sample frame and its per-instance crops to find the pig whose ID is marked on its body.

In [None]:
# Pick one sample file and show its RGB frame and depth map side by side.
idx = 50
file_path = file_list[idx]

fp = os.path.basename(file_path)
# Everything before the first 'Z' becomes the frame id in crop filenames
# (presumably a UTC timestamp ending in 'Z' — TODO confirm).
data_id = fp.split('Z')[0]
rgb, depth = read_tarfile(file_path)
# cv2 takes (width, height): depth becomes 640x480, presumably the RGB resolution.
depth = cv2.resize(depth, (640, 480))

fig, (ax_rgb, ax_depth) = plt.subplots(ncols=2, figsize=(12, 12))
ax_rgb.imshow(rgb)
ax_rgb.set_title('RGB')
ax_depth.imshow(depth)
ax_depth.set_title('Depth')
plt.show()

In [None]:
ins_map, seg, hms = predict(rgb)

# Preview every detected instance (id 0 is background): its crop, the filename
# it would be saved under, and its bounding-box area in pixels.
for instance_id in np.unique(ins_map):
    if instance_id == 0:
        continue
    # `predict` already returns the map at 640x480, so no extra resize is needed.
    ys, xs = np.nonzero(ins_map == instance_id)
    xmin, xmax = xs.min(), xs.max()
    ymin, ymax = ys.min(), ys.max()

    file_id = f'{data_id}_{xmin}_{ymin}_{xmax}_{ymax}.png'

    # xmax/ymax are the instance's last pixel (inclusive), so the slice end is +1.
    print(file_id, (xmax - xmin + 1) * (ymax - ymin + 1))
    plt.imshow(rgb[ymin:ymax + 1, xmin:xmax + 1])
    plt.show()

In [None]:
# Instance-id map of the inspected frame (id 0 is treated as background).
fig, ax = plt.subplots()
ax.imshow(ins_map)
ax.set_title('Instance map')
plt.show()

## Save per-instance crops

In [None]:
# Crop every sufficiently large instance from each frame and optionally save it.
# SAVE_CROPS defaults to False so re-running the notebook does not overwrite
# existing crops; set it to True to write the PNGs.
SAVE_CROPS = False
MIN_BOX_AREA = 10000  # instances whose bounding box has <= this many pixels are skipped
# NOTE(review): processing starts at `idx` (the frame inspected above), so the
# files before it are skipped; use start_idx = 0 to process the whole room.
start_idx = int(idx)
room_save_dir = os.path.join(save_dir, str(date_id), room_id)
if SAVE_CROPS:
    os.makedirs(room_save_dir, exist_ok=True)

progress = tqdm(file_list[start_idx:], desc=room_id)
for file_path in progress:
    fp = os.path.basename(file_path)
    data_id = fp.split('Z')[0]
    progress.set_postfix_str(fp)

    rgb, _ = read_tarfile(file_path)
    ins_map, seg, hms = predict(rgb)

    for instance_id in np.unique(ins_map):
        if instance_id == 0:
            continue
        # `predict` already returns the map at 640x480, so no extra resize is needed.
        ys, xs = np.nonzero(ins_map == instance_id)
        xmin, xmax = xs.min(), xs.max()
        ymin, ymax = ys.min(), ys.max()

        # xmax/ymax are inclusive, hence the +1 in both the area and the crop.
        if (xmax - xmin + 1) * (ymax - ymin + 1) <= MIN_BOX_AREA:
            continue

        file_id = f'{data_id}_{xmin}_{ymin}_{xmax}_{ymax}.png'
        if SAVE_CROPS:
            imsave(os.path.join(room_save_dir, file_id),
                   rgb[ymin:ymax + 1, xmin:xmax + 1])