In [3]:
import os
import json
# Downloaded from https://datasets.d2.mpi-inf.mpg.de/rakshith/a4nt_usenix/dataset/dataset_blog.json
with open(os.path.join(os.curdir, "data", "blog.json"), "r") as file:
    json_data = json.load(file)
docs = json_data['docs'][1:] # I don't want to see the first document

In [72]:
import torch
from torch.utils import data
from collections import Counter
from tqdm import tqdm
from sentence_transformers import CrossEncoder

tokenizer = CrossEncoder('cross-encoder/stsb-TinyBERT-L-4').tokenizer

# Took me 5 min to run

def stop(limit = 1, times = [0]):
    times[0] += 1
    assert times[0] < limit, "STOP HERE"


class GenderDataset(data.Dataset):

    def __init__(self, docs, PARAGRAPH_LENGTH=128, MIN_LENGTH=128, UNK_THRESHOLD=10) -> None:
        
        super().__init__()
        
        #---------------------------------------------------------------------------   
        #  Cut the documents into paragraphs of PARAGRAPH_LENGTH
        #---------------------------------------------------------------------------     
        doc_texts = [] # List of paragraphs
        gender_label = []

        print(f"Cutting documents into paragraphs of length {PARAGRAPH_LENGTH}...")
        freq = Counter() # Count the number of times each word appears
        discarded_doc = 0
        for doc in tqdm(docs):

            gender = int(doc['gender'] == 'male')
            # gender_label.append(gender)

            for text in (doc['rawtext']):
                tokens = self.str2token(text)
                i = 0
                while i + PARAGRAPH_LENGTH < len(tokens):
                    doc_texts.append(tokens[i : i + PARAGRAPH_LENGTH])
                    freq.update(doc_texts[-1])
                    gender_label.append(gender)
                    i += PARAGRAPH_LENGTH
                last_bit = tokens[-PARAGRAPH_LENGTH:]
                if len(last_bit) >= MIN_LENGTH:
                    doc_texts.append(last_bit)
                    gender_label.append(gender)
                    freq.update(doc_texts[-1])
                else:
                    discarded_doc += 1
        
        print(f"Number of documents: {len(doc_texts)}")
        print(f"Discarded ratio (due to MIN_LENGTH): {round((discarded_doc) / (discarded_doc + len(doc_texts)), 3)}")
        

        #---------------------------------------------------------------------------   
        #  Convert words to [UNK], then to indices
        #---------------------------------------------------------------------------     
        print("Number of unique words before converting to [UNK]: ", len(freq))
        before_occur = sum(freq.values())

        unique_words = set()

        ids = []

        print(f"Converting words with frequencies less than {UNK_THRESHOLD} to [UNK]...")
        total_occur = before_occur
        for i, doc_text in enumerate(tqdm(doc_texts)):
            # Replace words with less than 5 occurrences with [UNK]
            doc_text = [word if freq[word] > UNK_THRESHOLD else "[UNK]" for word in doc_text]
            unique_words.update(doc_text)
            total_occur -= doc_text.count("[UNK]")
            doc_texts[i] = doc_text
            ids.append(self.token2idx(doc_text))

        print("Number of unique words after converting [UNK]: ", len(unique_words))
        print(f"Known occurrences rate {round(total_occur/before_occur * 100, 2)}%")

        self._vocab_size = len(unique_words) # numbers of unique words == len(token2idx)
        self._vocab = unique_words # set of unique words
        self.raw_tokens = doc_texts # list of list of string tokens
        self.ids = ids # list of list of ints
        self.label = gender_label

    def __len__(self):
        return len(self.ids)
    
    def __getitem__(self, idx):
        return self.ids[idx], self.label[idx]
    
    def __iter__(self):
        for idx in range(len(self)):
            yield self[idx]

    def str2idx(self, s: str):
        return tokenizer.encode(s)[1:-1]
    
    def str2token(self, s: str):
        return tokenizer.convert_ids_to_tokens(self.str2idx(s))
    
    def token2idx(self, tokens : list[str]):
        return tokenizer.convert_tokens_to_ids(tokens)
    
    def idx2str(self, idx):
        return tokenizer.convert_tokens_to_string(tokenizer.convert_ids_to_tokens(idx))


def gender_data_collate_fn(gender_data):
    src_len = torch.tensor([len(gender_datum[0]) for gender_datum in gender_data], dtype=torch.int32)
    max_len = max(src_len)
    src_ids = torch.stack([torch.cat([
        torch.tensor(gender_datum[0], dtype=torch.int32), 
        torch.zeros(max_len - len(gender_datum[0]), dtype=torch.int32)
        ]) for gender_datum in gender_data])
    tgt = torch.tensor([gender_datum[1] for gender_datum in gender_data])
    return src_ids, src_len, tgt

gender_data = GenderDataset(docs)

Cutting documents into paragraphs of length 128...


  0%|          | 0/19676 [00:00<?, ?it/s]Token indices sequence length is longer than the specified maximum sequence length for this model (1154 > 512). Running this sequence through the model will result in indexing errors
100%|██████████| 19676/19676 [02:49<00:00, 115.81it/s]


Number of documents: 618377
Discarded ratio (due to MIN_LENGTH): 0.163
Number of unique words before converting to [UNK]:  27334
Converting words with frequencies less than 10 to [UNK]...


100%|██████████| 618377/618377 [00:53<00:00, 11658.32it/s]

Number of unique words after converting [UNK]:  25659
Known occurrences rate 99.99%





In [62]:
import numpy as np
lens =  [len(doc) for doc,_ in gender_data]
lens = np.array(lens)
lens.mean(), lens.std(), lens.max(), lens.min()

(128.0, 0.0, 128, 128)

In [47]:
print("Data is a pair of 128-dim vector of indices and a gender label:", gender_data[0], '\n')
tokenized_sample = gender_data.idx2str(gender_data[0][0])
print("You can use GenderData.idx2str to str-lize the data:", tokenized_sample, '\n')
detokenized_sample = gender_data.str2token(tokenized_sample)
print("Or use GenderData.str2token convert strings to tokens:",detokenized_sample, '\n')

Data is a pair of 128-dim vector of indices and a gender label: ([1045, 2165, 2019, 1045, 1012, 1053, 1012, 3231, 1996, 2060, 2154, 1012, 1012, 1012, 1012, 2026, 1045, 1012, 1053, 1012, 2003, 14748, 1012, 1998, 2000, 5587, 15301, 2000, 4544, 1010, 3984, 2054, 2026, 7789, 6412, 2003, 1012, 1000, 2017, 2024, 3811, 12785, 1998, 6037, 2000, 2022, 1037, 11067, 2000, 2087, 2111, 1012, 17012, 2213, 1012, 1012, 1012, 1045, 3711, 2000, 2022, 1037, 11067, 1029, 2054, 2515, 2008, 2812, 1029, 3046, 2009, 2041, 1012, 1998, 2074, 13012, 9035, 1024, 2632, 5677, 15313, 2018, 2019, 1045, 1012, 1053, 1012, 1997, 1010, 2066, 1010, 16923, 2000, 18582, 1012, 3786, 2008, 1012, 1024, 1011, 25269, 2497, 1011, 7479, 18447, 13348, 17905, 22199, 1012, 4012, 2030, 7479, 18515, 22199, 1012, 4012], 0) 

You can use GenderData.idx2str to str-lize the data: i took an i. q. test the other day.... my i. q. is 144. and to add insult to injury, guess what my intellectual description is. " you are highly gifted and appear