In [1]:
import argparse
import logging
import os
import pathlib
import sys

from datasets import Dataset
from datasets import DatasetDict
from datasets import list_datasets
from datasets import load_dataset

In [2]:
def arg_parse():
    """Build the experiment configuration from command-line style options.

    Outside IPython the real command line (``sys.argv``) is parsed. Inside
    IPython/Jupyter ``sys.argv`` belongs to the kernel launcher, so the
    built-in defaults are returned instead, and inline matplotlib rendering is
    requested on a best-effort basis.

    Returns:
        argparse.Namespace: carries ``project``, ``orig_data_dir``,
        ``pos_tag_data_dir``, ``pos_tag``, ``shift_type`` and ``seed``.

    Raises:
        SystemExit: raised by argparse on ``--help`` or on invalid arguments
            when the real command line is parsed.
    """
    parser = argparse.ArgumentParser(description='semantic shift config.')
    # Experiment management:
    parser.add_argument('--project', type=str, default="wikitext-15M",
                        help='Project name.')
    parser.add_argument('--orig_data_dir', type=str, default="../../data-files/wikitext-15M/",
                        help='Original data path.')
    parser.add_argument('--pos_tag_data_dir', type=str, default="../../data-files/wikitext-15M-pos/",
                        help='Path to the POS-tagged data.')
    parser.add_argument('--pos_tag', type=str, default="NOUN",
                        choices=["NOUN"],
                        help='Which pos-tag are you scrambling.')
    parser.add_argument('--shift_type', type=str, default="random",
                        choices=["random", "ft", "random_merge",
                                 "random_split", "ft_merge",
                                 "ft_split"],
                        help='Which type of scrambling methods are you using.')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed.')

    # `get_ipython` only exists as a builtin inside IPython/Jupyter, so a
    # NameError is the signal that this is a plain Python process.
    try:
        ipython_shell = get_ipython()
    except NameError:
        ipython_shell = None

    if ipython_shell is None:
        return parser.parse_args()

    try:
        ipython_shell.run_line_magic('matplotlib', 'inline')
    except Exception:
        # Inline plotting is a convenience only; a missing or broken
        # matplotlib must not change how arguments are parsed.
        pass
    # Ignore the kernel's own argv and fall back to the declared defaults.
    return parser.parse_args([])

In [3]:
if __name__ == "__main__":

    # Loading arguments
    args = arg_parse()

    # `get_ipython` only exists as a builtin inside IPython/Jupyter, so a
    # NameError means this is a plain Python process. Detection deliberately
    # does not depend on matplotlib being importable.
    try:
        ipython_shell = get_ipython()
    except NameError:
        ipython_shell = None
    is_jupyter = ipython_shell is not None

    if is_jupyter:
        # Interactive runs always use the fixed seed.
        args.seed = 42
        try:
            ipython_shell.run_line_magic('matplotlib', 'inline')
        except Exception:
            # Inline plotting is a convenience only; keep going without it.
            pass

    output_dir = f"../../data-files/{args.project}-{args.pos_tag}-{args.shift_type}-{args.seed}"
    # Create output directory if not exists.
    pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)

    # Send INFO-level records to <output_dir>/training.log.
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s %(levelname)-8s %(message)s',
        datefmt='%a, %d %b %Y %H:%M:%S',
        filename=os.path.join(output_dir, "training.log"),
    )
    logger = logging.getLogger(__name__)

    # Echo log records to stdout as well. Re-running this cell must not stack
    # another stdout handler, otherwise every message is printed once per run.
    root_logger = logging.getLogger()
    already_echoing = any(
        isinstance(handler, logging.StreamHandler)
        and getattr(handler, "stream", None) is sys.stdout
        for handler in root_logger.handlers
    )
    if not already_echoing:
        root_logger.addHandler(logging.StreamHandler(sys.stdout))

    logger.info("Running Pos-tagging with data lives in:")
    logger.info(args.orig_data_dir)

Running Pos-tagging with data lives in:
../../data-files/wikitext-15M/


In [4]:
# Load the dataset saved under args.pos_tag_data_dir; the alignment check
# below reads its 'upos_str' column.
wiki_datasets = DatasetDict.load_from_disk(args.pos_tag_data_dir)

In [None]:
# Load the dataset saved under args.orig_data_dir; the alignment check below
# reads its 'text' column.
wiki_datasets_orig = DatasetDict.load_from_disk(args.orig_data_dir)

In [None]:
# Spot-check alignment between the raw text and the POS-tagged data: for the
# first 10 non-blank rows of the split, the number of space-separated tokens
# should equal the number of comma-separated tags. Mismatching rows are printed.
split = "train"
matched_index = 0  # row cursor into the POS-tagged split; advances only on non-blank text
orig_examples = wiki_datasets_orig[split]
tagged_examples = wiki_datasets[split]
for i in range(len(orig_examples)):
    example_orig = orig_examples[i]["text"]
    stripped_text = example_orig.strip()
    if not stripped_text:
        # Blank rows are skipped without advancing matched_index.
        continue
    tokens_by_space = stripped_text.split(" ")
    pos_tags = tagged_examples[matched_index]['upos_str'].split(",")
    n_tokens, n_tags = len(tokens_by_space), len(pos_tags)
    if n_tokens != n_tags:
        print(tokens_by_space)
        print(pos_tags)
        print(n_tokens, n_tags)
    matched_index += 1
    if matched_index == 10:
        break