# Generate Zip MIDI List and Clean
Generate a list of the MIDI files inside a zip archive, then clean it by removing invalid files and duplicate files (files with identical content).

## Settings

In [1]:
# Input archive, and which member suffixes are treated as MIDI files.
zip_file_path = '../../data/merged-dataset.zip'
suffixes = ('mid',)

# Validation / cleaning options.
midi_checker = 'default'
remove_invalid = True
remove_duplicated = True

# Destination for the cleaned, deduplicated member list.
save_path = '../../processed_data/info_note/merged-dataset/zip.merged-dataset.clean.deduplicated.txt'

# Worker-process count (None = executor default).
num_workers = None

## Preprocessing

In [2]:
import os
from midiprocessor import midi_utils, data_utils
import zipfile
from concurrent.futures import ProcessPoolExecutor
from collections import defaultdict
from tqdm import tqdm

In [3]:
# Collect every archive member whose name ends with one of `suffixes`.
file_path_list = data_utils.get_zip_file_paths(zip_file_path, suffixes=suffixes)
num_files = len(file_path_list)

# Report the total, then preview the first ten member paths.
print(f'{num_files} files')
for path in file_path_list[:10]:
    print(path)

2111510 files
bitmidi.com/000001.mid
bitmidi.com/000002.mid
bitmidi.com/000003.mid
bitmidi.com/000004.mid
bitmidi.com/000005.mid
bitmidi.com/000006.mid
bitmidi.com/000007.mid
bitmidi.com/000008.mid
bitmidi.com/000009.mid
bitmidi.com/000010.mid


## Processing

In [4]:
def batch_process_zip(zip_file_path, file_path_list, midi_checker='default', order=0):
    """Hash and validate one batch of MIDI members stored inside a zip archive.

    Intended to run in a worker process: each call opens its own handle on the
    archive, so several batches can be processed in parallel.

    Args:
        zip_file_path: Path to the zip archive.
        file_path_list: Member paths (inside the archive) to process.
        midi_checker: Checker name forwarded to ``midi_utils.load_midi``.
        order: Row position of this batch's tqdm progress bar.

    Returns:
        A tuple ``(file_list, invalid_file_list, md5_dict)``:
        ``file_list`` is the set of members that were read and loaded without
        raising; ``invalid_file_list`` is the set of members for which opening,
        hashing or loading raised; ``md5_dict`` maps each MD5 digest to the list
        of member paths with that digest, in ``file_path_list`` order (a member
        is included whenever its hash was computed, even if loading then failed).
    """
    file_list = set()
    invalid_file_list = set()
    md5_dict = defaultdict(list)

    with zipfile.ZipFile(zip_file_path, 'r') as zip_obj:
        for file_path in tqdm(file_path_list, position=order):
            try:
                with zip_obj.open(file_path, 'r') as f:
                    md5 = data_utils.get_md5_sum(file_obj=f)
                    md5_dict[md5].append(file_path)

                    # Hashing consumed the stream; rewind so the loader reads
                    # the member from its first byte.
                    f.seek(0)
                    # Only success/failure matters here; the parsed object is discarded.
                    midi_utils.load_midi(file=f, midi_checker=midi_checker)
            except Exception:
                # Any read/parse failure marks the member invalid. Catching
                # Exception (not a bare except) lets KeyboardInterrupt and
                # SystemExit still stop the worker.
                invalid_file_list.add(file_path)
            else:
                file_list.add(file_path)

    return file_list, invalid_file_list, md5_dict

In [5]:
def split_list(file_path_list, num):
    """Split ``file_path_list`` into exactly ``num`` contiguous, near-equal batches.

    Batch sizes differ by at most one. When ``num`` exceeds the number of items,
    some batches are empty, so there is still exactly one batch per worker.

    Args:
        file_path_list: Sequence to split (must support ``len`` and slicing).
        num: Number of batches to produce; must be at least 1.

    Returns:
        A list of ``num`` slices of ``file_path_list`` whose in-order
        concatenation equals the input.

    Raises:
        ValueError: If ``num`` is less than 1.
    """
    if num < 1:
        raise ValueError('num must be a positive integer, got %r' % (num,))

    total = len(file_path_list)
    # Boundaries total*i//num are non-decreasing, start at 0 and end at total,
    # so every item falls in exactly one batch and there are always `num` batches.
    bounds = [total * i // num for i in range(num + 1)]
    batch_file_lists = [file_path_list[bounds[i]:bounds[i + 1]] for i in range(num)]
    assert len(batch_file_lists) == num

    return batch_file_lists

In [9]:
def process(zip_file_path, file_path_list, midi_checker='default', remove_invalid=True, remove_duplicated=True, num_workers=None):
    """Validate and deduplicate MIDI members of a zip archive using worker processes.

    The member list is split into one batch per worker with ``split_list``. Each
    batch is handled by ``batch_process_zip`` in a separate process, and the
    per-batch results are merged here.

    Args:
        zip_file_path: Path to the zip archive.
        file_path_list: Member paths inside the archive to check.
        midi_checker: Checker name forwarded to ``batch_process_zip``.
        remove_invalid: If True, members that failed to load are excluded from
            the returned list; otherwise they are kept.
        remove_duplicated: If True, only one member per MD5 digest is kept, and
            every digest shared by more than one member is reported in
            ``duplicated_groups``.
        num_workers: Number of worker processes; None uses the executor default.

    Returns:
        ``(final_list, final_invalid, duplicated_groups)``:
        ``final_list`` is the set of retained member paths;
        ``final_invalid`` is the list of members that failed to load;
        ``duplicated_groups`` is a list of path lists that share one MD5 digest
        (only groups with more than one path; empty unless ``remove_duplicated``).
    """
    final_valid = []
    final_invalid = []
    duplicated_groups = []
    final_md5_dict = defaultdict(list)

    with ProcessPoolExecutor(max_workers=num_workers) as pool:
        if num_workers is None:
            # NOTE(review): `_max_workers` is a private attribute; it is read only
            # to learn the default pool size the executor chose.
            num_workers = pool._max_workers
        print('Using %d to process %d files...' % (num_workers, len(file_path_list)))

        batch_file_lists = split_list(file_path_list, num_workers)

        results = pool.map(
            batch_process_zip,
            [zip_file_path] * num_workers,
            batch_file_lists,
            [midi_checker] * num_workers,
            range(num_workers),
        )

        # pool.map yields results in submission order, and the batches are
        # contiguous, so each digest's path list stays in file_path_list order.
        for file_list, invalid_file_list, md5_dict in results:
            final_valid.extend(file_list)
            final_invalid.extend(invalid_file_list)
            for key, paths in md5_dict.items():
                final_md5_dict[key].extend(paths)

    final_list = set(final_valid)
    if not remove_invalid:
        final_list.update(final_invalid)

    if remove_duplicated:
        for group in final_md5_dict.values():
            # A digest is a duplicate only if more than one member has it.
            if len(group) > 1:
                duplicated_groups.append(group)
                # Keep the earliest member that is still retained and drop the
                # others (comparing by value, so a path listed twice survives).
                kept = next((path for path in group if path in final_list), None)
                final_list.difference_update(path for path in group if path != kept)

    return final_list, final_invalid, duplicated_groups

In [10]:
# Pass every option from the Settings cell explicitly so edits there take effect.
final_list, final_invalid, duplicated_groups = process(
    zip_file_path,
    file_path_list,
    midi_checker=midi_checker,
    remove_invalid=remove_invalid,
    remove_duplicated=remove_duplicated,
    num_workers=num_workers,
)
print('%d files kept, %d invalid, %d duplicate groups'
      % (len(final_list), len(final_invalid), len(duplicated_groups)))

Using 8 to process 2111510 files...


## Save

In [None]:
# `final_list` is the cleaned result from `process` (the old `file_list` name
# was undefined at top level). Sort it so the saved file is identical across runs.
# NOTE(review): dump_list is assumed to accept any list of strings.
save_dir = os.path.dirname(save_path)
if save_dir:
    os.makedirs(save_dir, exist_ok=True)
data_utils.dump_list(sorted(final_list), save_path)