In [1]:
import os
import random

In [2]:
# Location of the original GLUE download. Set the GLUE_DATA_DIR environment
# variable to run the notebook on another machine; the fallback is the path
# this notebook was originally run with.
src_data_dir = os.environ.get('GLUE_DATA_DIR', '/Users/feifang/Desktop/Dev/glue_data')
# Root of the downsampled copy; files are written to <target_data_dir>/<task>/.
target_data_dir = './glue_data'

In [3]:
# Task directory names to process; files found under any other directory
# name are skipped during sampling.
tasks = [
    'SST-2',
    'RTE',
    'MRPC',
]

### Downsample datasets 

In [4]:
def read_data(path):
    """Read a text file and split it into its header line and data lines.

    Args:
        path: Path of the UTF-8 encoded file to read.

    Returns:
        A ``(header, all_samples)`` tuple. ``header`` is the first line of the
        file (an empty string if the file is empty) and ``all_samples`` is a
        list holding every remaining line. Each returned line ends with a
        newline character, even if the file's last line did not have one.
    """
    with open(path, 'r', encoding='utf-8') as f:
        header = f.readline()
        all_samples = f.readlines()
    # Only the very last line of a file can lack its newline. Left as-is, that
    # row would be glued onto whichever row is written after it once the rows
    # are shuffled, sampled or appended to.
    if all_samples:
        if not all_samples[-1].endswith('\n'):
            all_samples[-1] += '\n'
    elif header and not header.endswith('\n'):
        header += '\n'
    return header, all_samples

In [5]:
def write_data(task, fn, header, samples, output_root=None):
    """Write a header and data lines to ``<output_root>/<task>/<fn>``.

    The task directory is created if it does not exist yet, and an existing
    file of the same name is overwritten.

    Args:
        task: Name of the sub-directory to write into.
        fn: Name of the file to create inside the task directory.
        header: Header line; written first, exactly as given.
        samples: Iterable of lines written after the header, exactly as given
            (no newlines are added).
        output_root: Directory holding the per-task directories. Defaults to
            the module-level ``target_data_dir``.
    """
    if output_root is None:
        output_root = target_data_dir
    output_dir = os.path.join(output_root, task)
    # exist_ok replaces the separate isdir() check, which could race with
    # another process creating the directory between the check and the call.
    os.makedirs(output_dir, exist_ok=True)
    output_fn = os.path.join(output_dir, fn)
    with open(output_fn, 'w', encoding='utf-8') as f:
        f.write(header)
        f.writelines(samples)

In [6]:
def sample_data(input_dir, input_fn, num_samples, seed=42):
    """Copy one dataset file into the target directory, optionally downsampled.

    The task name is taken from the last component of ``input_dir`` and the
    result is written through ``write_data`` under that task, keeping the
    original file name and header.

    Args:
        input_dir: Directory containing the file; its last path component is
            used as the task name.
        input_fn: Name of the file inside ``input_dir``. ``'dev_ids.tsv'`` is
            ignored and nothing is written for it.
        num_samples: Number of data lines to draw without replacement. Values
            larger than the file are capped at the file size; ``0`` or a
            negative value copies every line in its original order.
        seed: Seed for the sampling RNG. The default reproduces the draw the
            notebook has always used.
    """
    if input_fn == 'dev_ids.tsv':
        return

    # normpath + basename copes with a trailing separator and with the
    # platform's own separator, unlike splitting on '/'.
    task = os.path.basename(os.path.normpath(input_dir))

    header, all_data = read_data(os.path.join(input_dir, input_fn))
    if num_samples > 0:
        num_samples = min(num_samples, len(all_data))
        # A private Random instance yields the same draw as seeding the global
        # generator with the same seed, without resetting the global RNG state
        # for unrelated code.
        samples = random.Random(seed).sample(all_data, num_samples)
    else:
        samples = all_data
    write_data(task, input_fn, header, samples)

In [83]:
def sample():
    """Interactively downsample the ``.tsv`` files found under ``src_data_dir``.

    Walks ``src_data_dir`` recursively. For each ``.tsv`` file whose parent
    directory name is listed in ``tasks``, the user is asked how many samples
    to keep and the file is passed to ``sample_data`` with that number.

    Prompt behaviour:
        * an empty (or whitespace-only) answer stops the whole walk at once;
        * an answer that is not an integer is rejected and the prompt repeats.
    """
    for root, _, files in os.walk(src_data_dir):
        # Portable replacement for root.split('/')[-1].
        task = os.path.basename(root)
        for f in files:
            # Match the real extension; a bare 'tsv' suffix would also accept
            # names such as 'notes_tsv'.
            if not f.endswith('.tsv'):
                continue
            print(f'{task}: {f}')
            if task not in tasks:
                print('Skipping: task not in domain.')
                continue

            num_samples = None
            while num_samples is None:
                answer = input('Enter number of samples for this dataset: ').strip()
                print('\n')
                if not answer:
                    return
                try:
                    num_samples = int(answer)
                except ValueError:
                    # A typo should not abort a long interactive session.
                    print('Please enter a whole number.')

            sample_data(root, f, num_samples)

In [84]:
# Interactive step: prompts once per matching .tsv file; an empty answer ends the run.
sample()

SST-2: train.tsv
Enter number of samples for this dataset: 9128


SST-2: dev.tsv
Enter number of samples for this dataset: 0


SST-2: test.tsv
Enter number of samples for this dataset: 




### For SST-2 only: move 1128 samples from train to dev

In [85]:
# Load the downsampled SST-2 train split from the target directory,
# keeping the header line separate from the data lines.
task = 'SST-2'
sst_train_fp = os.path.join(target_data_dir, task, 'train.tsv')
sst_header, sst_samples = read_data(sst_train_fp)

In [86]:
# Sanity check: number of data lines (header excluded) currently in train.
len(sst_samples)

9128

In [87]:
# Fixed seed so the train/dev reassignment below is reproducible.
random.seed(42)
# Shuffle positions rather than the rows themselves.
indices_shuffled = [*range(len(sst_samples))]
random.shuffle(indices_shuffled)

In [88]:
# The first num_train shuffled positions stay in train; the remainder moves to dev.
num_train = 8000
indices_to_keep = indices_shuffled[:num_train]
indices_to_move = indices_shuffled[num_train:]

In [89]:
# Look up the actual rows for each group of shuffled positions.
new_train_samples = list(map(sst_samples.__getitem__, indices_to_keep))
new_dev_samples = list(map(sst_samples.__getitem__, indices_to_move))

In [90]:
# Sanity check: number of rows that remain in train.
len(new_train_samples)

8000

In [91]:
# Overwrite train.tsv in place: the header first, then only the rows that stay in train.
with open(sst_train_fp, 'w', encoding='utf-8') as f:
    f.writelines([sst_header, *new_train_samples])

In [92]:
# Path of the dev split inside the target directory for the current task.
sst_dev_fp = os.path.join(target_data_dir, task, 'dev.tsv')

In [93]:
# Append the rows taken out of train to the end of the existing dev split.
# NOTE: append mode makes this cell non-idempotent - running it a second time
# adds the same rows again.
dev_ends_mid_line = False
if os.path.exists(sst_dev_fp):
    with open(sst_dev_fp, 'r', encoding='utf-8') as f:
        dev_text = f.read()
    # A non-empty file whose last row has no newline would otherwise have the
    # first appended row glued onto that last row.
    dev_ends_mid_line = bool(dev_text) and not dev_text.endswith('\n')
with open(sst_dev_fp, 'a', encoding='utf-8') as f:
    if dev_ends_mid_line:
        f.write('\n')
    f.writelines(new_dev_samples)

## Move the second half of each task's dev set to test

In [7]:
# For every task, keep the first half of dev.tsv as the new dev split and write
# the second half to test.tsv under the dev header. With an odd number of rows
# the test split receives the extra one.
# NOTE: this rewrites both files in place, so re-running the cell halves dev
# again and replaces test with the smaller remainder.
for task in tasks:
    dev_fp = os.path.join(target_data_dir, task, 'dev.tsv')
    test_fp = os.path.join(target_data_dir, task, 'test.tsv')
    header, samples = read_data(dev_fp)
    num_samples = len(samples)
    dev_samples = samples[:num_samples // 2]
    test_samples = samples[num_samples // 2:]
    # dev is written before test, each as header followed by its rows.
    for split_fp, split_samples in ((dev_fp, dev_samples), (test_fp, test_samples)):
        with open(split_fp, 'w', encoding='utf-8') as f:
            f.writelines([header, *split_samples])