In [1]:
import os
import subprocess
import tempfile

import cv2
import numpy as np
import pandas as pd
import torch
import youtube_dl
from PIL import Image
from torch.utils.data import Dataset, DataLoader, ConcatDataset
from torchvision import transforms

In [5]:

class CountixDataset(Dataset):
    """Countix repetition-counting dataset that fetches clips from YouTube on demand.

    Each CSV row describes one repetition segment of a YouTube video
    (``video_id``, ``repetition_start``, ``repetition_end``, ``count``) together
    with the bounds of the surrounding clip (``kinetics_start``,
    ``kinetics_end``). Times appear to be seconds -- they are passed straight to
    ffmpeg's ``-ss``/``-to`` options.

    NOTE: samples are downloaded with youtube_dl and cut with ffmpeg (which must
    be on PATH), so every ``__getitem__`` call needs network access.
    """

    # Context added before/after a segment (same units as the CSV times) when
    # it does not directly abut a neighbouring segment of the same video.
    contextPad = 1.5

    def __init__(self, df_path, framesPerVid):
        """Load the annotation CSV.

        Args:
            df_path: path to a Countix CSV (e.g. countix_train.csv).
            framesPerVid: number of frames returned per sample.
        """
        # Rows are used in file order (no sorting). _paddedBounds compares each
        # row with its neighbours, so the CSV is assumed to list the segments
        # of one video consecutively and in time order.
        self.df = pd.read_csv(df_path)
        self.framesPerVid = framesPerVid

    def __getitem__(self, index):
        """Return ``(frames, periodicity, periodLength)`` for CSV row ``index``.

        frames: the tensor produced by ``generateRepVid`` (framesPerVid frames).
        periodicity: (framesPerVid, 1) array; 1 for frames inside the
            repetition segment, 0 for the padding context around it.
        periodLength: (framesPerVid, 1) array; the period length in frames
            (segment length in frames / count) inside the segment, 0 elsewhere.
        """
        video_id = self.df.loc[index, 'video_id']
        start = self.df.loc[index, 'repetition_start']
        end = self.df.loc[index, 'repetition_end']
        count = self.df.loc[index, 'count']

        X, newStart, newEnd = self.generateRepVid(start, end, video_id, index)

        # Time covered by each sampled frame of the padded clip [newStart, newEnd].
        durationScaleFac = (newEnd - newStart) / self.framesPerVid
        # Offsets are measured from the padded start (newStart <= start), so
        # they are non-negative frame indices.
        repStartFrame = int((start - newStart) / durationScaleFac)
        repEndFrame = int((end - newStart) / durationScaleFac)

        # Half-open [repStartFrame, repEndFrame) so the number of marked frames
        # equals repEndFrame - repStartFrame, the length used for the period.
        frameIdx = np.arange(self.framesPerVid).reshape(-1, 1)
        inRep = (frameIdx >= repStartFrame) & (frameIdx < repEndFrame)

        periodicity = np.zeros((self.framesPerVid, 1))
        periodicity[inRep] = 1

        periodLength = np.zeros((self.framesPerVid, 1))
        periodLength[inRep] = (repEndFrame - repStartFrame) / count

        return X, periodicity, periodLength

    def generateRepVid(self, start, end, id, index, maxRetries=3):
        """Download, decode and sub-sample the (padded) clip for row ``index``.

        Args:
            start, end: repetition segment bounds from the CSV.
            id: YouTube video id.
            index: CSV row index (used to inspect neighbouring rows).
            maxRetries: download attempts before giving up.

        Returns:
            ``(frames, newStart, newEnd)`` where ``frames`` stacks exactly
            ``framesPerVid`` preprocessed frames sampled uniformly over the
            padded clip ``[newStart, newEnd]``.

        Raises:
            RuntimeError: if the clip cannot be downloaded or yields no frames.
        """
        newStart, newEnd = self._paddedBounds(start, end, index)
        fps = self.framesPerVid // (end - start) + 1

        # A private temporary directory per call: rows of the same video may be
        # loaded concurrently by different DataLoader workers, so a file name
        # derived only from the video id would race. The directory and clip are
        # removed on exit, even if decoding raises.
        with tempfile.TemporaryDirectory() as tmpDir:
            path_to_video = os.path.join(tmpDir, 'video_to_train' + id + '.mp4')
            self._downloadClip(id, newStart, newEnd, fps, path_to_video, maxRetries)
            frames = self._readFrames(path_to_video)

        if not frames:
            raise RuntimeError('no frames decoded for video %s' % id)

        # Sample over the whole padded clip so output frame i corresponds to
        # time newStart + i * (newEnd - newStart) / framesPerVid -- the mapping
        # __getitem__ relies on when placing the labels. Short clips repeat frames.
        picked = [frames[i] for i in self._uniformIndices(len(frames), self.framesPerVid)]
        return torch.stack(picked), newStart, newEnd

    def _paddedBounds(self, start, end, index):
        """Return ``(newStart, newEnd)``: the segment widened by ``contextPad`` on
        each side (clamped to the kinetics bounds), except on a side where the
        neighbouring row of the same video starts/ends at the same (integer) time.
        """
        df = self.df
        abutsPrev = (index > 0
                     and df.loc[index - 1, 'video_id'] == df.loc[index, 'video_id']
                     and int(df.loc[index - 1, 'repetition_end']) == int(df.loc[index, 'repetition_start']))
        abutsNext = (index + 1 < len(df)
                     and df.loc[index + 1, 'video_id'] == df.loc[index, 'video_id']
                     and int(df.loc[index, 'repetition_end']) == int(df.loc[index + 1, 'repetition_start']))

        newStart = start if abutsPrev else max(start - self.contextPad, df.loc[index, 'kinetics_start'])
        newEnd = end if abutsNext else min(end + self.contextPad, df.loc[index, 'kinetics_end'])
        return newStart, newEnd

    @staticmethod
    def _downloadClip(videoId, newStart, newEnd, fps, path_to_video, maxRetries):
        """Cut ``[newStart, newEnd]`` of a YouTube video into ``path_to_video``.

        Raises RuntimeError after ``maxRetries`` failed ffmpeg attempts.
        """
        pageUrl = 'https://www.youtube.com/watch?v=' + videoId
        opts = {'format': 'worst',
                'quiet': True,
                }
        for _ in range(maxRetries):
            # Always resolve from the page URL; the direct media URL returned
            # below must not replace it for the next attempt.
            with youtube_dl.YoutubeDL(opts) as ydl:
                result = ydl.extract_info(pageUrl, download=False)
            video = result['entries'][0] if 'entries' in result else result

            # Argument list instead of shell=True so neither the media URL nor
            # the id can be interpreted by a shell. '-y' lets a retry overwrite
            # a partial file instead of blocking on ffmpeg's overwrite prompt.
            # NOTE(review): setpts stretches timestamps by 30/fps and -r
            # resamples to fps -- this presumably assumes a ~30 fps source; confirm.
            returnCode = subprocess.call([
                'ffmpeg', '-y',
                '-i', video['url'],
                '-ss', str(newStart), '-to', str(newEnd),
                '-vf', 'setpts=%s*PTS' % (30 / fps),
                '-r', str(fps),
                path_to_video])
            if returnCode == 0 and os.path.exists(path_to_video):
                return
        raise RuntimeError('could not download %s [%s, %s] after %d attempts'
                           % (videoId, newStart, newEnd, maxRetries))

    @staticmethod
    def _readFrames(path_to_video):
        """Decode every frame of a video file into normalised (3, 112, 112) tensors."""
        # Built once per clip rather than once per frame.
        preprocess = transforms.Compose([
            transforms.Resize((112, 112), 2),  # 2 == PIL bilinear filter
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])])
        frames = []
        cap = cv2.VideoCapture(path_to_video)
        try:
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                # OpenCV decodes to BGR; the mean/std above are the usual
                # ImageNet statistics in RGB order.
                rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frames.append(preprocess(Image.fromarray(rgb)))
        finally:
            cap.release()
        return frames

    @staticmethod
    def _uniformIndices(numFrames, numSamples):
        """Indices of ``numSamples`` frames spread evenly over ``numFrames``.

        Sample i maps to floor(i * numFrames / numSamples); frames repeat when
        the clip has fewer than ``numSamples`` frames.
        """
        return [(i * numFrames) // numSamples for i in range(numSamples)]

    def __len__(self):
        """Number of annotated repetition segments (CSV rows)."""
        return len(self.df)

print("done")

done


In [6]:
ds = CountixDataset('countix/countix_train.csv', 64)

# Smoke-test a few samples. Every item must have exactly framesPerVid frames,
# otherwise the default DataLoader collate cannot stack them into a batch.
for i in range(4):
    frames, periodicity, periodLength = ds[i]
    print("frames:", frames.shape, "periodicity:", periodicity.shape, "periodLength:", periodLength.shape)
    assert frames.shape[0] == ds.framesPerVid
    assert periodicity.shape == periodLength.shape == (ds.framesPerVid, 1)

frames: torch.Size([64, 3, 112, 112]) dur: (64, 1) reps: (64, 1)
frames: torch.Size([64, 3, 112, 112]) dur: (64, 1) reps: (64, 1)
frames: torch.Size([64, 3, 112, 112]) dur: (64, 1) reps: (64, 1)
frames: torch.Size([64, 3, 112, 112]) dur: (64, 1) reps: (64, 1)


In [7]:
# Pull one batch through the loader to check that default collation stacks the
# samples into (batch, framesPerVid, ...) tensors.
trainLoader = DataLoader(dataset=ds, batch_size=4, num_workers=4)
X, y1, y2 = next(iter(trainLoader))


In [8]:
# Batched shapes: frames, periodicity labels, period-length labels (one per line).
print(X.shape, y1.shape, y2.shape, sep="\n")

torch.Size([4, 64, 3, 112, 112])
torch.Size([4, 64, 1])
torch.Size([4, 64, 1])
