In [1]:
import torch.hub
import torchvision
import numpy as np
from PIL import Image, ImageOps
import os

In [2]:
repo = 'epic-kitchens/action-models'

class_counts = (125, 352)
segment_count = 8
base_model = 'resnet50'

tsn = torch.hub.load(repo, 'TSN', class_counts, segment_count, 'RGB',
                     base_model=base_model, 
                     pretrained='epic-kitchens', force_reload=True)
trn = torch.hub.load(repo, 'TRN', class_counts, segment_count, 'RGB',
                     base_model=base_model, 
                     pretrained='epic-kitchens')
mtrn = torch.hub.load(repo, 'MTRN', class_counts, segment_count, 'RGB',
                     base_model=base_model, 
                      pretrained='epic-kitchens')
tsm = torch.hub.load(repo, 'TSM', class_counts, segment_count, 'RGB',
                     base_model=base_model, 
                     pretrained='epic-kitchens')

Downloading: "https://github.com/epic-kitchens/action-models/archive/master.zip" to /home/dimitri/.cache/torch/hub/master.zip
Using cache found in /home/dimitri/.cache/torch/hub/epic-kitchens_action-models_master
Using cache found in /home/dimitri/.cache/torch/hub/epic-kitchens_action-models_master


Multi-Scale Temporal Relation Network Module in use ['8-frame relation', '7-frame relation', '6-frame relation', '5-frame relation', '4-frame relation', '3-frame relation', '2-frame relation']


Using cache found in /home/dimitri/.cache/torch/hub/epic-kitchens_action-models_master


In [3]:
# Transforms
class GroupScale(object):
    """ Rescales the input PIL.Image to the given 'size'.
    'size' will be the size of the smaller edge.
    For example, if height > width, then image will be
    rescaled to (size * height / width, size)
    size: size of the smaller edge
    interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        self.worker = torchvision.transforms.Scale(size, interpolation)

    def __call__(self, img_group):
        return [self.worker(img) for img in img_group]
       

class GroupCenterCrop(object):
    def __init__(self, size):
        self.worker = torchvision.transforms.CenterCrop(size)

    def __call__(self, img_group):
        return [self.worker(img) for img in img_group]
    

class Stack(object):    
    def __call__(self, img_group):
        return np.array([np.array(i) for i in img_group])
    

class ToTorchFormatTensor(object):    
    def __call__(self, pic):
        img = torch.from_numpy(pic).permute(0, 3, 1, 2).contiguous()
        return img.float().div(255) 
    

class GroupNormalize(object):
    def __init__(self, mean, std):
        self.mean = mean
        self.std = std

    def __call__(self, tensor):
        rep_mean = self.mean * (tensor.size()[0]//len(self.mean))
        rep_std = self.std * (tensor.size()[0]//len(self.std))

        # TODO: make efficient
        for t, m, s in zip(tensor, rep_mean, rep_std):
            t.sub_(m).div_(s)

        return tensor

In [4]:
net = tsn
scale_size = net.scale_size
crop_size = net.input_size

transform=torchvision.transforms.Compose([
                       GroupScale(int(scale_size)),
                       GroupCenterCrop(crop_size),
                       Stack(),
                       ToTorchFormatTensor(),
                       GroupNormalize(net.input_mean, net.input_std),
                   ])



In [5]:
dir_name = os.path.join("data", "P01_11")
segment_size = 30
segment_num = 8
step = 30
frames_n = len(os.listdir(dir_name))
for k in range(1, frames_n - segment_num * segment_size, step):
    image_paths = [os.path.join(dir_name, "frame_{}.jpg".format(str(k + i * segment_size).zfill(10))) for i in range(segment_num)]
    images = [Image.open(image_p).convert('RGB') for image_p in image_paths]
    inputs = transform(images)
    features = tsn.features(inputs)
    verb_logits, noun_logits = tsn.logits(features)
    print("{}.{}:".format((k // segment_size) // 60, (k // segment_size) % 60), torch.argmax(noun_logits).item(), torch.argmax(verb_logits).item())

0.0: 1 1
0.1: 1 0
0.2: 1 1
0.3: 1 1
0.4: 1 1
0.5: 1 1
0.6: 6 1
0.7: 1 1
0.8: 6 1
0.9: 1 1
0.10: 4 1
0.11: 4 1
0.12: 1 1
0.13: 1 0
0.14: 7 1
0.15: 1 1
0.16: 1 1
0.17: 1 0
0.18: 4 1
0.19: 4 1
0.20: 1 1
0.21: 7 0
0.22: 1 1
0.23: 1 1
0.24: 1 1
0.25: 1 1
0.26: 1 1
0.27: 4 1
0.28: 1 1
0.29: 6 0
0.30: 5 1
0.31: 6 1
0.32: 1 1
0.33: 1 1
0.34: 1 1
0.35: 1 1
0.36: 6 1


KeyboardInterrupt: 