In [None]:
import os
import argparse
from datasets import DatasetDict
from datasets import Dataset
from datasets import load_dataset
from datasets import list_datasets
import logging
import pathlib
from tqdm import tqdm, trange
import random
import copy
import fasttext
import fasttext.util
import torch
import torch.nn as nn
from nltk.stem import PorterStemmer

In [None]:
def arg_parse(argv=None):
    """Build and parse the configuration for the semantic-shift run.

    Args:
        argv: Optional explicit list of argument strings. When None, a normal
            script run reads ``sys.argv``; inside IPython/Jupyter the defaults
            are used instead, because the kernel's own argv (e.g. ``-f <file>``)
            is not meant for this parser.

    Returns:
        argparse.Namespace with ``project``, ``orig_data_dir``,
        ``pos_tag_data_dir``, ``pos_tag``, ``shift_type`` and ``seed``.
    """
    parser = argparse.ArgumentParser(description='semantic shift config.')
    # Experiment management:
    parser.add_argument('--project', type=str, default="wikitext-15M",
                        help='Project name; used in the output directory name.')
    parser.add_argument('--orig_data_dir', type=str, default="../../data-files/wikitext-15M/",
                        help='Original data path.')
    parser.add_argument('--pos_tag_data_dir', type=str, default="../../data-files/wikitext-15M-pos/",
                        help='POS-tagged data path (loaded with DatasetDict.load_from_disk).')
    parser.add_argument('--pos_tag', type=str, default="NOUN",
                        choices=["NOUN"],
                        help='Which pos-tag are you scrambling.')
    parser.add_argument('--shift_type', type=str, default="random",
                        choices=["random", "ft", "random_merge",
                                 "ft_merge", "split"],
                        help='Which type of scrambling methods are you using.')
    parser.add_argument('--seed', type=int, default=42,
                        help='Random seed.')

    if argv is None:
        # `get_ipython` only exists inside IPython/Jupyter; probing the name has
        # no side effects, unlike running a magic.
        try:
            get_ipython  # noqa: F821
        except NameError:
            pass  # plain script run: argparse reads sys.argv
        else:
            argv = []  # notebook: ignore the kernel's argv, use defaults
    return parser.parse_args(argv)

def cosine_sim_distance(v1, v2, dim=-1, eps=1e-6):
    """Cosine similarity between ``v1`` and ``v2`` along ``dim``.

    Despite the name this is a *similarity*, not a distance: values lie in
    [-1, 1] and larger means more similar, so callers rank candidates by the
    highest score.

    Args:
        v1, v2: Tensors of compatible shape (as accepted by
            ``torch.nn.functional.cosine_similarity``).
        dim: Dimension along which the vectors are compared (default: last).
        eps: Small value guarding against division by zero for near-zero
            vectors.

    Returns:
        Tensor of similarities with ``dim`` reduced away.
    """
    # Functional form avoids constructing a new nn.Module on every call; it is
    # what nn.CosineSimilarity.forward delegates to.
    return nn.functional.cosine_similarity(v1, v2, dim=dim, eps=eps)

In [None]:
if __name__ == "__main__":

    # Loading arguments
    args = arg_parse()

    # `get_ipython` only exists inside IPython/Jupyter. `is_jupyter` is exported
    # for later cells; this cell does not use it.
    try:
        get_ipython  # noqa: F821
        is_jupyter = True
    except NameError:
        is_jupyter = False

    # Apply the seed to the RNGs: the seed is part of output_dir's name, so it
    # must actually govern the (random-based) shift performed later.
    random.seed(args.seed)
    torch.manual_seed(args.seed)

    output_dir = f"../../data-files/{args.project}-{args.pos_tag}-{args.shift_type}-{args.seed}"
    # Create output directory if not exists.
    pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)

    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s %(levelname)-8s %(message)s',
        datefmt='%a, %d %b %Y %H:%M:%S',
    )
    logger = logging.getLogger(__name__)

    # Log the directory that is actually loaded below.
    logging.info("Loading POS-tagged data from:")
    logging.info(args.pos_tag_data_dir)

    wiki_datasets = DatasetDict.load_from_disk(args.pos_tag_data_dir)

    # Collect every distinct (whitespace-stripped) token of the train split whose
    # UPOS tag equals args.pos_tag.
    logging.info(f"Extract all the {args.pos_tag} from the dataset...")
    words = set()
    train_split = wiki_datasets["train"]
    total_count = len(train_split)
    for i, example in enumerate(train_split):
        if i % 10000 == 0:
            logging.info(f"processed {i}/{total_count}")
        tokens, tags = example["sentence_str"], example["upos_str"]
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if len(tokens) != len(tags):
            raise ValueError(
                f"train example {i}: {len(tokens)} tokens but {len(tags)} POS tags"
            )
        for token, tag in zip(tokens, tags):
            token = token.strip()
            # Skip whitespace-only tokens: "" is not a word.
            if tag == args.pos_tag and token:
                words.add(token)

    # Pretrained fastText model (presumably 300-d English Common Crawl vectors,
    # judging by the filename); used by the "ft" / "ft_merge" shift methods.
    ft = fasttext.load_model('../../data-files/cc.en.300.bin')

In [None]:
# Notebook-only override of the shift method (the argparse default is "random").
# Guarded so script runs keep honouring `--shift_type`. output_dir is rebuilt
# because it was derived from the shift type *before* this override.
if is_jupyter:
    args.shift_type = 'ft_merge'
    output_dir = f"../../data-files/{args.project}-{args.pos_tag}-{args.shift_type}-{args.seed}"
    pathlib.Path(output_dir).mkdir(parents=True, exist_ok=True)

In [None]:
# Build `words_pair`, the replacement map for args.pos_tag tokens, according to
# args.shift_type:
#   random       - shuffle the raw surface forms (no lowercasing/singularizing).
#   ft           - map each lowercased, singularized form to its most
#                  fastText-similar *other* form.
#   random_merge - pick ~1/merge_ratio of the forms as pivots and give every
#                  form a random pivot.
#   ft_merge     - same pivot sampling, but give every form its most
#                  fastText-similar pivot other than itself.
#   split        - no precomputed mapping.
merge_ratio = 10  # roughly how many forms collapse onto one pivot in *_merge modes
similarity_fn = cosine_sim_distance


def _normalize_forms(raw_words):
    """Return the set of lowercased, textblob-singularized forms of `raw_words`."""
    from textblob import Word
    return {Word(w.lower()).singularize() for w in raw_words}


def _fasttext_matrix(forms):
    """Stack the fastText vector of each entry of `forms` into one tensor, one row per form."""
    return torch.stack([torch.tensor(ft.get_word_vector(f)) for f in forms], dim=0)


def _best_match(query_vec, candidates, exclude_idx=None):
    """Row index of `candidates` most cosine-similar to the 1-D vector `query_vec`.

    `exclude_idx`, when given and there is more than one candidate, is never
    returned; this is what keeps a word from being mapped onto itself.
    """
    # (1, D) query broadcast against (N, D) candidates -> N scores.
    scores = similarity_fn(query_vec.unsqueeze(0), candidates)
    if exclude_idx is not None and scores.numel() > 1:
        scores[exclude_idx] = float("-inf")
    return int(torch.argmax(scores))


def _sample_pivots(forms):
    """Randomly choose max(1, len(forms) // merge_ratio) distinct pivots from `forms`."""
    return random.sample(forms, max(1, len(forms) // merge_ratio))


# Re-seed so that re-running this cell on its own reproduces the same mapping.
random.seed(args.seed)
words_pair = {}
logging.info(f"Shifting {args.pos_tag} with method type {args.shift_type}")
if args.shift_type != "split" and not words:
    raise ValueError(f"No {args.pos_tag} tokens were extracted; nothing to shift.")

if args.shift_type == "random":
    # Raw surface forms, sorted so the starting order (and therefore the seeded
    # shuffle) does not depend on Python's per-process string hashing.
    words_orig = sorted(words)
    words_shuffled = list(words_orig)
    random.shuffle(words_shuffled)
    words_pair = dict(zip(words_orig, words_shuffled))
elif args.shift_type == "ft":
    stems = _normalize_forms(words)
    words_orig = sorted(stems)
    embeddings = _fasttext_matrix(words_orig)
    for i, token in enumerate(tqdm(words_orig)):
        # Nearest neighbour other than the token itself.
        words_pair[token] = words_orig[_best_match(embeddings[i], embeddings, exclude_idx=i)]
elif args.shift_type == "random_merge":
    stems = _normalize_forms(words)
    words_orig = sorted(stems)
    pivots = _sample_pivots(words_orig)
    # Repeat the pivot list enough times that every form gets a slot
    # (merge_ratio + 1 copies suffice for large vocabularies but not e.g. for
    # 19 forms), then shuffle so each form receives a random pivot.
    copies = max(merge_ratio + 1, -(-len(words_orig) // len(pivots)))
    _pivots = pivots * copies
    random.shuffle(_pivots)
    words_pair = dict(zip(words_orig, _pivots))
elif args.shift_type == "ft_merge":
    stems = _normalize_forms(words)
    words_orig = sorted(stems)
    pivots = _sample_pivots(words_orig)
    pivot_idx = {p: i for i, p in enumerate(pivots)}
    pivot_embeddings = _fasttext_matrix(pivots)
    for token in tqdm(words_orig):
        token_embed = torch.tensor(ft.get_word_vector(token))
        # NOTE(review): assumes the intended rule is "most similar pivot other
        # than the token itself" (pivots are thus always shifted) — confirm.
        best = _best_match(token_embed, pivot_embeddings, exclude_idx=pivot_idx.get(token))
        words_pair[token] = pivots[best]
elif args.shift_type == "split":
    # split is handled on the fly; no predefined mapping is needed -- a simple
    # counter should be enough.
    pass
else:
    raise ValueError(f"Unknown shift_type: {args.shift_type!r}")