## Import statements and constants

In [58]:
import json
import os
from collections import Counter

vid_keep_thresh = 0.6
class_keep_thresh = 50
person_detection_filepath = '/n/fs/visualai-scr/felixy/Classes/COS529/COS529_Project/YOLO/person_fraction.txt'

txt_file_dir = '/n/fs/visualai-scr/vramaswamy/COS529_project/data/UCF-101/ucfTrainTestlist'
txt_filepaths = [os.path.join(txt_file_dir, x) for x in ['testlist01.txt', 'trainlist01.txt']]

json_filepaths = ['/n/fs/visualai-scr/Data/UCF101Images/ucfTrainTestlist/ucf101_01.json'.format(x) for x in ['', 'Greyed']]

## Filter out videos and classes

In [18]:
def get_class_from_name(vid_name):
    return vid_name.split('_')[1]

def get_person_fractions(filepath):
    person_fraction_dict = {}
    with open(person_detection_filepath) as f:
        for line in f:
            vid_name, fraction = line.strip().split()
            person_fraction_dict[vid_name] = float(fraction)
    return person_fraction_dict

def get_name_blacklist(person_fraction_dict):
    print('Removing names with less than {}% frames kept'.format(vid_keep_thresh*100))
    print('-'*40)
    blacklist = []
    count_total = Counter()
    count_remove = Counter()
    count_keep = Counter()
    for vid_name, fraction in person_fraction_dict.items():
        vid_class = get_class_from_name(vid_name)
        count_total[vid_class] += 1
        if fraction < vid_keep_thresh:
            blacklist.append(vid_name)
            count_remove[vid_class] += 1
        else:
            count_keep[vid_class] += 1
    print('Total number of videos removed: {}\n'.format(len(blacklist)))
    return blacklist, (count_total, count_remove, count_keep)

def get_class_blacklist(count_remove):
    print('Removing classes with less than {} videos'.format(class_keep_thresh))
    print('-'*40)
    class_blacklist = []
    vid_remove_count = 0
    for class_name, c in counts[2].items():
        if c < class_keep_thresh:
            class_blacklist.append(class_name)
            vid_remove_count += c
    print('Total number of classes removed: {}'.format(len(class_blacklist)))
    print('Total number of videos removed: {}'.format(vid_remove_count))
    return class_blacklist

In [19]:
person_fraction_dict = get_person_fractions(person_detection_filepath)
name_blacklist, counts = get_name_blacklist(person_fraction_dict)
class_blacklist = get_class_blacklist(counts[2])

Removing names with less than 60.0% frames kept
----------------------------------------
Total number of videos removed: 3372

Removing classes with less than 50 videos
----------------------------------------
Total number of classes removed: 16
Total number of videos removed: 460


In [56]:
13320 - 3372 - 460

9488

## Filter from files

In [75]:
def filter_from_txt(filepath, name_blacklist, class_blacklist):
    filepath_split = os.path.split(filepath)
    out_filepath = os.path.join(filepath_split[0], 'filtered_{}'.format(filepath_split[-1]))
    out_file = open(out_filepath, 'w')
    
    print('Filtering videos for file {}'.format(filepath_split[-1]))
    print('-'*40)
    num_before = 0
    num_after = 0
    with open(filepath, 'r') as in_file:
        for line in in_file:
            num_before += 1
            path = line.strip().split(' ')[0]
            class_name, vid_name = os.path.split(path)
            vid_name = vid_name[:-4]
            if not (class_name in class_blacklist or vid_name in name_blacklist):
                out_file.write(line)
                num_after += 1
    print('Number of videos before filtering: {}'.format(num_before))
    print('Number of videos after filtering: {}\n'.format(num_after))
    out_file.close()
    
def filter_from_json(filepath, name_blacklist, class_blacklist):
    filepath_split = os.path.split(filepath)
    out_filepath = os.path.join(filepath_split[0], 'filtered_{}'.format(filepath_split[-1]))
    print('Filtering videos for file {}'.format(filepath_split[-1]))
    print('-'*40)
    num_train = [0, 0] # [Before, After]
    num_test = [0, 0] # [Before, After]
    data = json.load(open(filepath))
    database = data['database']
    new_database = {}
    for vid_name in database.keys():
        class_name = get_class_from_name(vid_name)
        split = database[vid_name]['subset']
        if split == 'training': num_train[0] += 1
        elif split == 'validation': num_test[0] += 1
        if not (class_name in class_blacklist or vid_name in name_blacklist):
            new_database[vid_name] = database[vid_name]
            if split == 'training': num_train[1] += 1
            elif split == 'validation': num_test[1] += 1
    data['database'] = new_database
    print('Number of training videos before/after filtering: {}/{}'.format(num_train[0], num_train[1]))
    print('Number of testing videos before/after filtering: {}/{}'.format(num_test[0], num_test[1]))
    print('Total number of new samples: {}\n'.format(len(new_database)))
    with open(out_filepath, 'w') as out_f:
        json.dump(data, out_f)

In [54]:
for txt_filepath in txt_filepaths:
    filter_from_txt(txt_filepath, name_blacklist, class_blacklist)

Filtering videos for file testlist01.txt
----------------------------------------
Number of videos before filtering: 3783
Number of videos after filtering: 2717

Filtering videos for file trainlist01.txt
----------------------------------------
Number of videos before filtering: 9537
Number of videos after filtering: 6771



In [76]:
for json_filepath in json_filepaths:
    filter_from_json(json_filepath, name_blacklist, class_blacklist)

Filtering videos for file ucf101_01.json
----------------------------------------
Number of training videos before/after filtering: 9537/6771
Number of testing videos before/after filtering: 3783/2717
Total number of new samples: 9488

Filtering videos for file ucf101_01.json
----------------------------------------
Number of training videos before/after filtering: 9537/6771
Number of testing videos before/after filtering: 3783/2717
Total number of new samples: 9488



In [77]:
data = json.load(open('/n/fs/visualai-scr/Data/UCF101Images/ucfTrainTestlist/ucf101_01.json'))

In [79]:
print(len(data['database']))

13320


In [80]:
data2 = json.load(open('/n/fs/visualai-scr/Data/UCF101Images/ucfTrainTestlist/filtered_ucf101_01.json'))

In [81]:
data2

{'labels': ['ApplyEyeMakeup',
  'ApplyLipstick',
  'Archery',
  'BabyCrawling',
  'BalanceBeam',
  'BandMarching',
  'BaseballPitch',
  'Basketball',
  'BasketballDunk',
  'BenchPress',
  'Biking',
  'Billiards',
  'BlowDryHair',
  'BlowingCandles',
  'BodyWeightSquats',
  'Bowling',
  'BoxingPunchingBag',
  'BoxingSpeedBag',
  'BreastStroke',
  'BrushingTeeth',
  'CleanAndJerk',
  'CliffDiving',
  'CricketBowling',
  'CricketShot',
  'CuttingInKitchen',
  'Diving',
  'Drumming',
  'Fencing',
  'FieldHockeyPenalty',
  'FloorGymnastics',
  'FrisbeeCatch',
  'FrontCrawl',
  'GolfSwing',
  'Haircut',
  'Hammering',
  'HammerThrow',
  'HandstandPushups',
  'HandstandWalking',
  'HeadMassage',
  'HighJump',
  'HorseRace',
  'HorseRiding',
  'HulaHoop',
  'IceDancing',
  'JavelinThrow',
  'JugglingBalls',
  'JumpingJack',
  'JumpRope',
  'Kayaking',
  'Knitting',
  'LongJump',
  'Lunges',
  'MilitaryParade',
  'Mixing',
  'MoppingFloor',
  'Nunchucks',
  'ParallelBars',
  'PizzaTossing',
  '