In [1]:
import sys
sys.path.insert(1, '/temporal-shift-module/online_demo')

#from mobilenet_v2_tsm_test import MobileNetV2
from arch_mobilenetv2 import MobileNetV2

from PIL import Image
import urllib.request
import os
import torch
import torchvision
import numpy as np
import cv2
import time


In [33]:


class GroupScale(object):
    """ Rescales the input PIL.Image to the given 'size'.
    'size' will be the size of the smaller edge.
    For example, if height > width, then image will be
    rescaled to (size * height / width, size)
    size: size of the smaller edge
    interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        self.worker = torchvision.transforms.Scale(size, interpolation)

    def __call__(self, img_group):
        return [self.worker(img) for img in img_group]



class GroupCenterCrop(object):
    def __init__(self, size):
        self.worker = torchvision.transforms.CenterCrop(size)

    def __call__(self, img_group):
        return [self.worker(img) for img in img_group]



class Stack(object):

    def __init__(self, roll=False):
        self.roll = roll

    def __call__(self, img_group):
        if img_group[0].mode == 'L':
            return np.concatenate([np.expand_dims(x, 2) for x in img_group], axis=2)
        
        elif img_group[0].mode == 'RGB':
            if self.roll:
                return np.concatenate([np.array(x)[:, :, ::-1] for x in img_group], axis=2)
            else:
                return np.concatenate(img_group, axis=2)



class ToTorchFormatTensor(object):
    """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255]
    to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0] """

    def __init__(self, div=True):
        self.div = div

    def __call__(self, pic):
        if isinstance(pic, np.ndarray):
            # handle numpy array
            img = torch.from_numpy(pic).permute(2, 0, 1).contiguous()
        else:
            # handle PIL Image
            img = torch.ByteTensor(torch.ByteStorage.from_buffer(pic.tobytes()))
            img = img.view(pic.size[1], pic.size[0], len(pic.mode))
            # put it from HWC to CHW format
            # yikes, this transpose takes 80% of the loading time/CPU
            img = img.transpose(0, 1).transpose(0, 2).contiguous()
        return img.float().div(255) if self.div else img.float()



class GroupNormalize(object):
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        rep_mean = self.mean * (tensor.size()[0] // len(self.mean))
        rep_std = self.std * (tensor.size()[0] // len(self.std))

        # TODO: make efficient
        for t, m, s in zip(tensor, rep_mean, rep_std):
            t.sub_(m).div_(s)

        return tensor


def process_output(idx_, history, num_classes):
    # idx_: the output of current frame
    # history: a list containing the history of predictions
    if not REFINE_OUTPUT:
        return idx_, history

    max_hist_len = int((20/27)*num_classes) # max history buffer

    # mask out illegal action
    
    if num_classes == 27:
        if idx_ in [7, 8, 21, 22, 1, 3]:
            idx_ = history[-1]

        if idx_ == 0:
            idx_ = 2

    # use only single no action class
    elif num_classes == 3: 
        if idx_ in [2]:
            idx_ = history[-1]
        
        if idx_ == 0:
            idx_ = 0
    
    # history smoothing

    if idx_ != history[-1] and len(history) != 1:
        if not (history[-1] == history[-2]): #  and history[-2] == history[-3]):
            idx_ = history[-1]
    

    history.append(idx_)
    history = history[-max_hist_len:]

    return history[-1], history


def get_categories(num_classes):

    if num_classes == 27:
        catigories = [
        "Doing other things",  # 0
        "Drumming Fingers",  # 1
        "No gesture",  # 2
        "Pulling Hand In",  # 3
        "Pulling Two Fingers In",  # 4
        "Pushing Hand Away",  # 5
        "Pushing Two Fingers Away",  # 6
        "Rolling Hand Backward",  # 7
        "Rolling Hand Forward",  # 8
        "Shaking Hand",  # 9
        "Sliding Two Fingers Down",  # 10
        "Sliding Two Fingers Left",  # 11
        "Sliding Two Fingers Right",  # 12
        "Sliding Two Fingers Up",  # 13
        "Stop Sign",  # 14
        "Swiping Down",  # 15
        "Swiping Left",  # 16
        "Swiping Right",  # 17
        "Swiping Up",  # 18
        "Thumb Down",  # 19
        "Thumb Up",  # 20
        "Turning Hand Clockwise",  # 21
        "Turning Hand Counterclockwise",  # 22
        "Zooming In With Full Hand",  # 23
        "Zooming In With Two Fingers",  # 24
        "Zooming Out With Full Hand",  # 25
        "Zooming Out With Two Fingers"  # 26
    ]

    elif num_classes == 9: 

        catigories = ["Fall", "SalsaSpin", "Taichi", "WallPushups", "WritingOnBoard", "Archery", "Hulahoop", "Nunchucks", "WalkingWithDog"]
    
    elif num_classes == 10:

        catigories = ["Fall", "SalsaSpin", "Taichi", "WallPushups", "WritingOnBoard", "Archery", "Hulahoop", "Nunchucks", "WalkingWithDog", "test"]

    elif num_classes == 3 :

        catigories = ['Fall', "Not Fall", "Test"]

    elif num_classes == 2:

        catigories = ["Fall", "Not Fall"]

    return catigories




In [35]:
def main(num_classes):


    if num_classes not in [2, 3, 9, 10, 27]:
        return "Can only handle 2, 10, and 27 classes"

    else:
        catigories = get_categories(num_classes)

    cropping = torchvision.transforms.Compose([
        GroupScale(256),
        GroupCenterCrop(224),
    ])


    transform = torchvision.transforms.Compose([
        cropping,
        Stack(roll=False),
        ToTorchFormatTensor(div=True),
        GroupNormalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ])

    from torch import nn
    torch_module = MobileNetV2(n_class=num_classes)
    #print(torch_module.state_dict().keys())
    
    
    if num_classes == 27:
        if not os.path.exists("mobilenetv2_jester_online.pth.tar"):  # checkpoint not downloaded
            print('Downloading PyTorch checkpoint...')
            url = 'https://hanlab.mit.edu/projects/tsm/models/mobilenetv2_jester_online.pth.tar'
            urllib.request.urlretrieve(url, './mobilenetv2_jester_online.pth.tar')
    

        torch_module.load_state_dict(torch.load("mobilenetv2_jester_online.pth.tar"))


    else:
        
        if num_classes == 9 or num_classes == 10:
            #KH model_new = torch.load("../../pretrained/9cat/ckpt.best.pth.tar")
            model_new = torch.load("/data/w251fall/checkpoints/9_Categories/TSM_w251fall_RGB_mobilenetv2_shift8_blockres_avg_segment8_e50/ckpt.best.pth.tar")
    
        elif num_classes == 2 or num_classes == 3:
            #KH model_new = torch.load("../../pretrained/2cat/ckpt.best.pth.tar")
            #model_new = torch.load("/data/w251fall/checkpoints/2_Categories/1_TSM_w251fall_RGB_mobilenetv2_shift8_blockres_avg_segment8_e25/ckpt.best.pth.tar")
            model_new = torch.load("/data/w251fall/checkpoints/2_Categories/TSM_w251fall_RGB_mobilenetv2_shift8_blockres_avg_segment8_e5/ckpt.best.pth.tar")
            
            
        # Fixing new model parameter mis-match
        state_dict = model_new['state_dict']
        #print(state_dict.keys())
    
        from collections import OrderedDict
        new_state_dict = OrderedDict()

        for k, v in state_dict.items():
            #name = k[7:] # remove `module.`

            if "module.base_model." in k:
                name = k.replace("module.base_model.", "")

                if ".net" in name:
                    name = name.replace(".net", "")


            elif "module." in k:
                name = k.replace("module.new_fc.", "classifier.")
        

            new_state_dict[name] = v

        # load params
        torch_module.load_state_dict(new_state_dict)


    torch_module.eval()

    #KH cap = cv2.VideoCapture(1)

    # set a lower resolution for speed up
    #KH cap.set(cv2.CAP_PROP_FRAME_WIDTH, 320)
    #KH cap.set(cv2.CAP_PROP_FRAME_HEIGHT, 240)
    
    # Load jpgs file names fully qualified
    
    
    # Falls
    #video_images_dir = '/data/w251fall/jpg/Fall/fall-25-front-urfall.val'
    #video_images_dir = '/data/w251fall/jpg/Fall/steph_2682_(1).train'
    video_images_dir = '/data/w251fall/jpg/Fall/ten_0002_(10).train'
    
    # Not Falls
    #video_images_dir = '/data/w251fall/jpg/NotFall/v_HulaHoop_g01_c01.val' # Nunchucks ✓
    #video_images_dir = '/data/w251fall/jpg/NotFall/v_HulaHoop_g02_c01.val' # Nunchucks ✓
    #video_images_dir = '/data/w251fall/jpg/NotFall/v_TaiChi_g23_c01.train' # WallPushUps ✓
    #video_images_dir = '/data/w251fall/jpg/NotFall/v_Archery_g25_c07.train' # Hulahoop ✓
    #video_images_dir = '/data/w251fall/jpg/NotFall/v_Nunchucks_g01_c01.val' # WritingOnBoard X HulaHoop X
    #video_images_dir = '/data/w251fall/jpg/NotFall/v_Nunchucks_g02_c06.val' # WritingOnBoard X
    #video_images_dir = '/data/w251fall/jpg/NotFall/v_WritingOnBoard_g05_c02.val' # Archery ✓
    #video_images_dir = '/data/w251fall/jpg/NotFall/v_WritingOnBoard_g07_c06.val' # Archery ✓
    #video_images_dir = '/data/w251fall/jpg/NotFall/v_Archery_g02_c07.val' # HulaHoop ✓ WritingOnBoard X
    
    
    jpg_filenames = [video_images_dir + '/' + s for s in os.listdir(video_images_dir)]

    
    #full_screen = False
    #KH cv2.namedWindow(WINDOW_NAME, cv2.WINDOW_NORMAL)
    #KH cv2.resizeWindow(WINDOW_NAME, 640, 480)
    #KH cv2.moveWindow(WINDOW_NAME, 0, 0)
    #KH cv2.setWindowTitle(WINDOW_NAME, WINDOW_NAME)


    shift_buffer = [torch.zeros([1, 3, 56, 56]),
                    torch.zeros([1, 4, 28, 28]),
                    torch.zeros([1, 4, 28, 28]),
                    torch.zeros([1, 8, 14, 14]),
                    torch.zeros([1, 8, 14, 14]),
                    torch.zeros([1, 8, 14, 14]),
                    torch.zeros([1, 12, 14, 14]),
                    torch.zeros([1, 12, 14, 14]),
                    torch.zeros([1, 20, 7, 7]),
                    torch.zeros([1, 20, 7, 7])]


    t = None    
    index = 0
    idx = 0
    history = [2]
    history_logit = []
    history_timing = []
    i_frame = -1

    #KH while True:
    for jpg_filename in jpg_filenames:
        
        img = cv2.imread(jpg_filename)
        
        i_frame += 1
        #KH _, img = cap.read()  # (480, 640, 3) 0 ~ 255

        if i_frame % 1 == 0:
            t1 = time.time()
            img_tran = transform([Image.fromarray(img).convert('RGB')])
            input_var = torch.autograd.Variable(img_tran.view(1, 3, img_tran.size(1), img_tran.size(2)))

            #prediction = torch_module(input_var, *shift_buffer) #demo mobilenet
            prediction = torch_module(input_var) #arch mobilenet


            feat, shift_buffer = prediction[0], prediction[1:]


            if SOFTMAX_THRES > 0:

                feat_np = feat.detach().numpy().reshape(-1)
                feat_np -= feat_np.max()

                softmax = np.exp(feat_np) / np.sum(np.exp(feat_np))

                #KH print(max(softmax))
        
                if max(softmax) > SOFTMAX_THRES:
                    idx_ = np.argmax(feat.detach().numpy(), axis=1)[0]
        
                else:
                    idx_ = idx
    
            else:
                #KH print(feat.detach().numpy())
                #idx_ = np.argmax(feat.detach().numpy(), axis=1)[0] For demo mobilenet
                idx_ = np.argmax(feat.detach().numpy()) # For archnet mobilenet


            if HISTORY_LOGIT:
                history_logit.append(feat.detach().numpy())
                history_logit = history_logit[-int(12/27*num_classes):]
                avg_logit = sum(history_logit)
                #idx_ = np.argmax(avg_logit, axis=1)[0] For demo mobilenet
                idx_ = np.argmax(avg_logit)  #For archnet mobilenet

            idx, history = process_output(idx_, history, num_classes)
            

            t2 = time.time()
            print(f"Final {index} Attempt {catigories[idx]}")

            
            current_time = t2 - t1

        
        img = cv2.resize(img, (640, 480))
        img = img[:, ::-1]
        height, width, _ = img.shape
        label = np.zeros([height // 10, width, 3]).astype('uint8') + 255

        #KH cv2.putText(label, 'Prediction: ' + catigories[idx], (0, int(height / 16)), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 2)
        #KH cv2.putText(label, '{:.1f} Vid/s'.format(1 / current_time), (width - 170, int(height / 16)), cv2.FONT_HERSHEY_SIMPLEX, 0.7, (0, 0, 0), 2)

        #KH img = np.concatenate((img, label), axis=0)
        #KH cv2.imshow(WINDOW_NAME, img)

        #KH key = cv2.waitKey(1)

        #KH if key & 0xFF == ord('q') or key == 27:  # exit
        #KH     break
        
        #KH elif key == ord('F') or key == ord('f'):  # full screen
        #KH     print('Changing full screen option!')
            
        #KH     full_screen = not full_screen
            
        #KH     if full_screen:
        #KH         print('Setting FS!!!')
        #KH         cv2.setWindowProperty(WINDOW_NAME, cv2.WND_PROP_FULLSCREEN, cv2.WINDOW_FULLSCREEN)
        #KH     
        #KH     else:
        #KH         cv2.setWindowProperty(WINDOW_NAME, cv2.WND_PROP_FULLSCREEN, cv2.WINDOW_NORMAL)


        if t is None:
            t = time.time()
        
        else:
            nt = time.time()
            index += 1
            t = nt


    #KH cap.release()
    #KH cv2.destroyAllWindows()



if __name__ == "__main__":
    print("Starting... \n")

    SOFTMAX_THRES = 0
    HISTORY_LOGIT = True
    REFINE_OUTPUT = True
    WINDOW_NAME = "GESTURE CAPTURE"

    #Modify number of classes here
    main(3)

    print("Done")

Starting... 

Final 0 Attempt Not Fall
Final 0 Attempt Not Fall
Final 1 Attempt Not Fall
Final 2 Attempt Not Fall
Final 3 Attempt Not Fall
Final 4 Attempt Not Fall
Final 5 Attempt Not Fall
Final 6 Attempt Not Fall
Final 7 Attempt Not Fall
Final 8 Attempt Not Fall
Final 9 Attempt Not Fall
Final 10 Attempt Not Fall
Final 11 Attempt Not Fall
Final 12 Attempt Not Fall
Final 13 Attempt Not Fall
Final 14 Attempt Not Fall
Final 15 Attempt Not Fall
Final 16 Attempt Not Fall
Final 17 Attempt Not Fall
Final 18 Attempt Not Fall
Final 19 Attempt Not Fall
Final 20 Attempt Not Fall
Final 21 Attempt Not Fall
Final 22 Attempt Not Fall
Final 23 Attempt Not Fall
Final 24 Attempt Not Fall
Final 25 Attempt Not Fall
Final 26 Attempt Not Fall
Final 27 Attempt Not Fall
Final 28 Attempt Not Fall
Final 29 Attempt Not Fall
Final 30 Attempt Not Fall
Final 31 Attempt Not Fall
Final 32 Attempt Not Fall
Final 33 Attempt Not Fall
Final 34 Attempt Not Fall
Final 35 Attempt Not Fall
Final 36 Attempt Not Fall
Final 37 