In [42]:
import json
# Downloaded from https://datasets.d2.mpi-inf.mpg.de/rakshith/a4nt_usenix/dataset/dataset_blog.json
with open("blog.json", "r") as file:
    json_data = json.load(file)
docs = json_data['docs'][1:] # I don't want to see the first document

In [58]:
from torch.utils import data
from collections import Counter
from tqdm import tqdm
import re

# Took me 2 min to run

def stop(limit = 1, times = [0]):
    times[0] += 1
    assert times[0] < limit, "STOP HERE"


class GenderDataset(data.Dataset):

    def __init__(self, docs, PARAGRAPH_LENGTH = 128, UNK_THRESHOLD = 5) -> None:
        super().__init__()
        
        SEP_TOKEN = "<SEP>"
        SOS_TOKEN = "<SOS>" # This is a misnomer, it is actually the first token of the paragraph
        EOS_TOKEN = "<EOS>" # This is a misnomer, it is actually the last token of the paragraph

        def combine_and_cut(rawtext):
            """
            Generates finer grained sentences by splitting on punctuation.
            Used to cut the document into paragraphs.

            Input: list of strings
            Output: Generate sentences one by one
            """
            for sent in rawtext:
                # Split on punctuation
                sent = re.split(r' ([.!?])', sent)
                # Remove punctuation at the beginning and end of sentences
                sent = [re.sub(r'^[.!?]', '', s) for s in sent]
                sent = [re.sub(r'[.!?]$', '', s) for s in sent]
                # Remove empty strings and convert to lower case
                sent = [x.lower() for x in sent if x]
                yield from sent
        
        #---------------------------------------------------------------------------   
        #  Cut the documents into paragraphs of PARAGRAPH_LENGTH
        #---------------------------------------------------------------------------     
        doc_texts = [] # List of paragraphs
        gender_label = []

        print(f"Cutting documents into paragraphs of length {PARAGRAPH_LENGTH}...")
        for doc in tqdm(docs):

            gender = int(doc['gender'] == 'male')
            doc_texts.append(SOS_TOKEN)
            gender_label.append(gender)

            for sent in combine_and_cut(doc['rawtext']):
                if len(doc_texts[-1].split()) + len(sent.split()) > PARAGRAPH_LENGTH - 2: # Minus 2 for SEP and EOS tokens
                    doc_texts[-1] += " " + EOS_TOKEN
                    doc_texts.append(SOS_TOKEN)
                    gender_label.append(gender)
                doc_texts[-1] += sent + " " + SEP_TOKEN
            doc_texts[-1] += " " + EOS_TOKEN
        
        print(f"Number of documents: {len(doc_texts)}")

        print(f"Counting freqeuncies of words...")
        freq = Counter() # Count the number of times each word appears
        doc_texts_token = [] # This is a list of lists of tokens
        for doc_text in tqdm(doc_texts):
            if len(doc_text.split()) <= PARAGRAPH_LENGTH:
                doc_text_token = doc_text.split()
                freq.update(doc_text_token)
                doc_texts_token.append(doc_text_token)
        
        print(f"Number of documents with lengths <= {PARAGRAPH_LENGTH}: {len(doc_texts_token)}")

        #---------------------------------------------------------------------------   
        #  Convert words to <UNK>, then to indices
        #---------------------------------------------------------------------------     
        print("Number of unique words before converting to <UNK>: ", len(freq))
        before_occur = sum(freq.values())

        unique_words = set()

        print(f"Converting words with frequencies less than {UNK_THRESHOLD} to <UNK>...")
        total_occur = before_occur
        for i, doc_text_token in tqdm(enumerate(doc_texts_token)):
            # Replace words with less than 5 occurrences with <UNK>
            doc_text_token = [word if freq[word] > UNK_THRESHOLD else "<UNK>" for word in doc_text_token]
            unique_words.update(doc_text_token)
            total_occur -= doc_text_token.count("<UNK>")
            doc_texts_token[i] = doc_text_token
        print("Number of unique words after converting <UNK>: ", len(unique_words))
        print(f"Known occurrences rate {round(total_occur/before_occur * 100, 2)}%")
        
        self.raw_text = doc_texts # list of strings of ~128 words
        self.vocab_size = len(unique_words) # numbers of unique words == len(token2idx)
        self.vocab = unique_words # set of unique words
        self.token2idx = {token: idx for idx, token in enumerate(unique_words)} # dict of unique words to indices
        self.idx2token = {idx: token for idx, token in enumerate(unique_words)} # dict of indices to unique words
        self.data = [[self.token2idx[token] for token in doc] for doc in doc_texts_token]
        self.label = gender_label
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx], self.label[idx]
    
    def __iter__(self):
        for idx in range(len(self)):
            yield self[idx]
    

    def tokenize(self, idx):
        """
        Converts indice or a list of indices to tokens.
        Input: int or list of ints
        Output: string or list of strings
        """
        if isinstance(idx, int):
            return self.idx2token[idx]
        elif isinstance(idx, list):
            return " ".join([self.idx2token[i] for i in idx])
        else:
            raise TypeError(f"Expected int or list, got {type(idx)}")
    
    def detokenize(self, token):
        """
        Converts token or a list of tokens to indices.
        Input: string. Document/paragraph to be converted to indices
        Output: list of ints
        """
        if isinstance(token, str):
            return [self.token2idx[t] for t in token.split()]
        else:
            raise TypeError(f"Expected str, got {type(token)}")

gender_data = GenderDataset(docs)

Cutting documents into paragraphs of length 128...


100%|██████████| 19676/19676 [00:20<00:00, 982.21it/s] 


Number of documents: 559126
Counting freqeuncies of words...


100%|██████████| 559126/559126 [00:10<00:00, 53425.92it/s]


Number of documents with lengths <= 128: 554016
Number of unique words before converting to <UNK>:  505954
Converting words with frequencies less than 5 to <UNK>...


554016it [00:09, 59662.08it/s]


Number of unique words after converting <UNK>:  85906
Known occurrences rate 99.01%


In [59]:
import numpy as np
lens =  [len(doc) for doc,_ in gender_data]
lens = np.array(lens)
lens.mean(), lens.std(), lens.max(), lens.min()

(114.59587087737538, 17.034593549416673, 128, 2)

In [57]:
print("Data is a pair of 128-dim vector of indices and a gender label:", gender_data[0], '\n')
tokenized_sample = gender_data.tokenize(gender_data[0][0])
print("You can use GenderData.tokenize to retokenize the data:", tokenized_sample, '\n')
detokenized_sample = gender_data.detokenize(tokenized_sample)
print("Or use GenderData.detokenize detokenize it back to indices:",detokenized_sample, '\n')

Data is a pair of 128-dim vector of indices and a gender label: ([58430, 45267, 58809, 77229, 21467, 6736, 32732, 73874, 58077, 59426, 68442, 21467, 11245, 32399, 59426, 31956, 55413, 49670, 53510, 55413, 76617, 59804, 28171, 34420, 68442, 43731, 9120, 11245, 59426, 81851, 18060, 77727, 15045, 51681, 31956, 16452, 55413, 4830, 16893, 63852, 55413, 36453, 60389, 59426, 45972, 70433, 45267, 13442, 55413, 4830, 16893, 63852, 59426, 34420, 40269, 3801, 21553, 59426, 61123, 27520, 6464, 59426, 31956, 84366, 28681, 69027, 32085, 27756, 76076, 77229, 21467, 23903, 59804, 15000, 59804, 8692, 55413, 1895, 59426, 67367, 3801, 59426, 51487, 32085, 21752, 32085, 59426, 32085, 59426, 18060, 28993, 45267, 57143, 79666, 72935, 33527, 55413, 77562, 59804, 57143, 12282, 67423, 59426, 57143, 18060, 31871, 69027, 10767, 58528, 82382, 81851, 7752, 80672, 81851, 59426, 66565], 0) 

You can use GenderData.tokenize to retokenize the data: <SOS> i took an i.q. test the other day <SEP> my i.q. is 144 <SEP> and