# Stage 2: Entity Disambiguation

We use GENRE to disambiguate the answers obtained in the first stage [1]. GENRE (Generative ENtity REtrieval) treats entity retrieval tasks such as entity linking as sequence-to-sequence generation and is built on a fine-tuned BART model [2]. Given an input text, it generates the unique name of the referenced entity. The mention to be disambiguated is marked in the prompt with the special tokens `[START_ENT]` and `[END_ENT]`. To give the model more context, we combine the entity with the article text using the following prompt template:

```text
Discussing [START_ENT] [HERO | VILLAIN | VICTIM] [END_ENT] . [TEXT]
```

De Cao et al. originally combined constrained beam search with a prefix tree (trie) so that the model could only generate identifiers corresponding to existing Wikipedia titles. In our setting, many entities, such as lesser-known individuals, may not have a Wikipedia page, so we employed GENRE without the prefix-tree constraint. This refined the answers derived from the FLAN model. For high-profile entities such as politicians, the model resolved spelling variants (e.g., D. J. Trump → Donald Trump). We also observed that GENRE used this way produced less noisy output: it stripped redundant punctuation, resolved abbreviations, and at times recovered information from the original article text.
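For reference, the prefix-tree constraint we dropped restricts each decoding step to tokens that keep the output on a path of a trie built from tokenized titles. The library-free sketch below illustrates the idea under our own assumptions: titles are already tokenized into ID sequences, and `build_trie` and `allowed_next_tokens` are hypothetical helpers. In `transformers`, a function of this kind can be passed to `generate` through its `prefix_allowed_tokens_fn` argument; this notebook does not do so.

```python
from typing import Dict, List, Sequence

# Hypothetical trie type: each node maps a token ID to its child node.
Trie = Dict[int, "Trie"]


def build_trie(sequences: Sequence[Sequence[int]]) -> Trie:
    """Insert every tokenized title into a nested-dict prefix tree."""
    root: Trie = {}
    for seq in sequences:
        node = root
        for token in seq:
            node = node.setdefault(token, {})
    return root


def allowed_next_tokens(trie: Trie, prefix: Sequence[int]) -> List[int]:
    """Return the token IDs that keep `prefix` on a path of the trie.

    An empty list means the prefix has left the trie (no valid title).
    """
    node = trie
    for token in prefix:
        if token not in node:
            return []
        node = node[token]
    return sorted(node)


# Toy example: pretend these are tokenized titles ending in an EOS ID of 2.
titles = [[10, 11, 2], [10, 12, 2], [20, 2]]
trie = build_trie(titles)
print(allowed_next_tokens(trie, []))        # [10, 20]
print(allowed_next_tokens(trie, [10]))      # [11, 12]
print(allowed_next_tokens(trie, [10, 13]))  # []
```

Without this constraint, as in our setup, the model may generate names that are not Wikipedia titles, which is what allows lesser-known entities to survive.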

## Fetch Articles

In [1]:
from utils.preprocessing import *
from utils.accelerators import *
from utils.multithreading import *
from utils.database import *
from utils.files import *
from datasets import Dataset
import random


### Connect to Database

Credentials are sourced from the `.env` file.

In [2]:
_, db = getConnection(use_dotenv=True)

### Query Database

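Fetches a limited number of articles (50 in this run) that already have a first-stage `processing_result` and whose `parsing_result.text_length` is below 10,000, returning the url, title, parsed text, and processing result. The filter that would skip already-denoised articles is currently commented out.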

In [3]:
collection = "v2_sampled_articles"
fields = {"url": 1, "title": 1, "parsing_result.text": 1, 'processing_result': 1}
query = {
    "processing_result": {"$exists": True},
    "parsing_result.text_length": {"$lt": 10000},
    #"denoising_result": {"$exists": False},
    #"initial_subsample": False
}
articles = fetchArticleTexts(db, 50, 0, fields, query, collection)

Example article:

In [4]:
example_article = random.choice(articles)
title = example_article.get("title")
text = example_article.get("parsing_result").get("text")
print(f"Title: {title}\nText: {text}")
print(f"Processing Result: {example_article.get('processing_result')}")


Title: The New Civil Rights Movement
Text: Secretary of Housing and Urban Development Dr. Ben Carson, is the latest member of the Trump administration to test positive for coronavirus. CNBC reports the news, after... "This sounded like a slur to me." "But now, fortunately, God's given me a chance to do something about it," Carson added. "We are all now more stupid than we were when we came in the room today sir, thank you," Rep. Quigley concluded after going around and... If the government were not shutdown, taxpayers would be on the hook for the Secretary's travel to Missouri. 'It's a Very Complex Issue' Carson Claims Newly released emails show Ben Carson and his wife personally selected a $31,000 dining room set for his office at the Department of Housing and Urban... 'Sometimes I Get a Little Bit Tired of People Ascribing to Me Things That People Have Said That I Believe' Carson Says Here's What You'll Want to Know About Carson's Name Mysteriously Disappears From Schedule HUD Secret


Processes the 'parsing_result' of each article to clean the text, and filters out articles 
that lack a 'title' or 'parsing_result'.


In [5]:
# Basic text cleaning, e.g. removing newlines, tabs, etc.
articles = cleanArticles(articles)

Cleaning articles: 100%|██████████| 50/50 [00:00<00:00, 1599.85it/s]
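`cleanArticles` lives in `utils.preprocessing` and is not shown. Assuming the "basic text cleaning" amounts to collapsing newlines, tabs, and repeated spaces, a minimal sketch could look like this (`clean_text` is a hypothetical name):

```python
import re


def clean_text(text: str) -> str:
    """Collapse newlines, tabs and runs of spaces into single spaces."""
    # \s matches spaces, tabs, newlines, carriage returns and other whitespace
    return re.sub(r"\s+", " ", text).strip()


print(repr(clean_text("Sheriff Ward\n\tasks   Ammon Bundy\r\nto leave. ")))
# 'Sheriff Ward asks Ammon Bundy to leave.'
```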


In [6]:
# Filter out articles with no title or no parsing result 
articles = [article for article in articles if article.get(
    "title", "") and article.get("parsing_result", "")]

print("Number of articles:", len(articles))

Number of articles: 50


### Export as JSON

Saves the given data to a JSON file for optional visual inspection.

In [7]:
exportAsJSON("../data/input/articles.json",  articles)

### Convert to HF Dataset

Convert article IDs to strings and transform the list of articles into a dataset with the fields `_id`, `title`, `url`, `text` (from `parsing_result`), and `hero`, `villain`, `victim` (from the first-stage `processing_result`). The HuggingFace `datasets` library offers several advantages over plain JSON files:

- **Efficiency**: Datasets loaded from disk are memory-mapped, so you can work with data larger than your available RAM without loading all of it into memory.
- **Speed**: The Arrow-based HuggingFace format typically loads faster than large JSON files, which speeds up data operations.
- **Columnar Storage**: Apache Arrow's columnar format makes serialization and deserialization more efficient than row-based formats such as JSON.


In [8]:
column_names = ["_id", "title", "url", "parsing_result.text", "processing_result.hero", "processing_result.villain", "processing_result.victim"]
articles = convertListToDataset(articles, column_names)
describeDataset(articles)

Number of rows: 50
Column names: ['_id', 'title', 'url', 'text', 'hero', 'villain', 'victim']
Features (schema): {'_id': Value(dtype='string', id=None), 'title': Value(dtype='string', id=None), 'url': Value(dtype='string', id=None), 'text': Value(dtype='string', id=None), 'hero': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None), 'villain': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None), 'victim': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None)}
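`convertListToDataset` is another helper from `utils`. Judging from the printed column names, it selects the dotted paths and names each column after the last path segment. The hypothetical `flatten_records` below sketches that flattening step in plain Python; the resulting dict of columns could then be passed to `datasets.Dataset.from_dict`. Note that two paths ending in the same segment would collide in this scheme.

```python
from typing import Any, Dict, List


def flatten_records(records: List[Dict[str, Any]],
                    column_names: List[str]) -> Dict[str, List[Any]]:
    """Turn nested documents into a dict of columns keyed by the last path segment."""
    columns: Dict[str, List[Any]] = {path.split(".")[-1]: [] for path in column_names}
    for record in records:
        for path in column_names:
            value: Any = record
            for key in path.split("."):
                # A missing key anywhere along the path yields None
                value = value.get(key) if isinstance(value, dict) else None
            if path == "_id" and value is not None:
                value = str(value)  # IDs are stored as strings, as in the notebook
            columns[path.split(".")[-1]].append(value)
    return columns


records = [{
    "_id": 42,  # stands in for a BSON ObjectId
    "title": "Example",
    "parsing_result": {"text": "Article body"},
    "processing_result": {"hero": ["David Ward"], "villain": ["Ammon Bundy"]},
}]
cols = flatten_records(records, ["_id", "title", "parsing_result.text",
                                 "processing_result.hero", "processing_result.villain"])
print(cols)
# {'_id': ['42'], 'title': ['Example'], 'text': ['Article body'], 'hero': [['David Ward']], 'villain': [['Ammon Bundy']]}
```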


In [9]:
print("Example Article Text:", articles[42]["text"][:100])
print("Example Article Hero:", articles[42]["hero"])
print("Example Article Villain:", articles[42]["villain"])
print("Example Article Victim:", articles[42]["victim"])

Example Article Text: UPDATE: 6:52 p.m. The sheriff of the Oregon county where armed militia members have taken over a fed
Example Article Hero: ['David Ward']
Example Article Villain: ['Ammon Bundy']
Example Article Victim: ['Harney County residents']


Save dataset to disk:

In [10]:
articles.save_to_disk('../data/input/articles')

Saving the dataset (1/1 shards): 100%|██████████| 50/50 [00:00<00:00, 5860.75 examples/s]


***

## Prepare Dataset 

In [11]:
from langchain.text_splitter import RecursiveCharacterTextSplitter, CharacterTextSplitter
from datasets import Dataset, load_from_disk
from transformers import AutoTokenizer
from multiprocessing import Pool
from utils.preprocessing import *
from utils.database import *
from utils.files import *
import transformers

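Setting `os.environ["TOKENIZERS_PARALLELISM"] = "false"` disables the internal parallelism of HuggingFace's Rust-based `tokenizers` library. This suppresses the warning it emits when the process is forked after the tokenizer has already been used (as in the multiprocessing step later on) and avoids the deadlocks that warning guards against. `set_verbosity_error()` additionally limits `transformers` logging to errors.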
See: https://stackoverflow.com/questions/62691279/how-to-disable-tokenizers-parallelism-true-false-warning

In [12]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
transformers.utils.logging.set_verbosity_error()

### Import Raw Dataset

In [13]:
articles = load_from_disk('../data/input/articles')
describeDataset(articles)

Number of rows: 50
Column names: ['_id', 'title', 'url', 'text', 'hero', 'villain', 'victim']
Features (schema): {'_id': Value(dtype='string', id=None), 'title': Value(dtype='string', id=None), 'url': Value(dtype='string', id=None), 'text': Value(dtype='string', id=None), 'hero': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None), 'villain': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None), 'victim': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None)}


### Prepare Dataset

#### Define Prompt Template:

In [14]:
PROMPT_TEMPLATE_prefix = "Discussing [START_ENT] {entity} [END_ENT]"
PROMPT_TEMPLATE = "{prefix} . {article}"

# Test the template with a dummy text
prefix = PROMPT_TEMPLATE_prefix.format(entity='Donald Trump')
print(PROMPT_TEMPLATE.format(prefix=prefix,
      article='Lorem ipsum dolor sit amet, consectetur adipiscing elit.'))

Discussing [START_ENT] Donald Trump [END_ENT] . Lorem ipsum dolor sit amet, consectetur adipiscing elit.


#### Expand Dataset

Load the GENRE tokenizer and determine the model's input window and the token overhead of the prompt template:

In [15]:
# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("facebook/genre-kilt", add_prefix_space=True)
print("Input window length:", tokenizer.model_max_length)

Input window length: 1024


In [16]:
template_length = calcInputLength(tokenizer, PROMPT_TEMPLATE.format(prefix=prefix, article=' '))
print("Max length of empty prompt template:", template_length)

Max length of empty prompt template: 20


For each article, a separate prompt is created for every entity identified as a hero, villain, or victim in the first stage. The text is not split into chunks here: articles longer than the model's input window are truncated during tokenization. In this run, the 50 articles expand to 109 prompts, roughly two per article.
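One way to split an over-long article into chunks that fit the input window is to slice its token IDs into windows that leave room for the prompt overhead measured above (20 tokens of a 1024-token window). The sketch below is hypothetical: `chunk_token_ids` works on plain ID lists so it runs without a tokenizer, the overhead really depends on the entity name and would need measuring per prompt, and the chunks would have to be decoded back to text before formatting the prompt.

```python
from typing import List, Sequence


def chunk_token_ids(token_ids: Sequence[int], window: int, overhead: int,
                    stride: int = 0) -> List[List[int]]:
    """Split token IDs into chunks that still fit `window` after adding `overhead` tokens.

    Consecutive chunks share `stride` tokens, so an entity mention near a chunk
    boundary is less likely to be cut in half.
    """
    budget = window - overhead  # tokens left for article text in each prompt
    if budget <= stride:
        raise ValueError("window must exceed overhead + stride")
    step = budget - stride
    chunks: List[List[int]] = []
    start = 0
    while True:
        chunks.append(list(token_ids[start:start + budget]))
        if start + budget >= len(token_ids):
            break  # this chunk reaches the end of the article
        start += step
    return chunks


# A 2,500-token article with the 1024-token window and 20-token template overhead
print([len(c) for c in chunk_token_ids(list(range(2500)), window=1024, overhead=20)])
# [1004, 1004, 492]
```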

In [17]:
articles[42]

{'_id': '64d8eb39516b265872292f80',
 'title': 'Oregon Sheriff To Meet Again Friday With Armed Refuge Militia',
 'url': 'http://talkingpointsmemo.com/livewire/militia-sheriff-peaceful-resolution--2',
 'text': 'UPDATE: 6:52 p.m. The sheriff of the Oregon county where armed militia members have taken over a federal wildlife refuge met Thursday afternoon local time with the militia members and will meet again Friday. Here’s the update his office tweeted after the Thursday the discussion: Meeting now over. Plans to talk again tomorrow. Sheriff Ward called on them for a peaceful resolution. #HarneyCounty — Harney Cty. Sheriff (@HarneyCoSheriff) January 7, 2016 The office also stated that Harney County Sheriff David Ward asked Ammon Bundy, one of the men leading the militia takeover, to leave. Sheriff Ward asks Ammon Bundy to please leave and respect the wishes of Harney County residents. — Harney Cty. Sheriff (@HarneyCoSheriff) January 7, 2016 His office emphasized he is seeking a “peaceful 

In [18]:
def expandRow(row, template_prefix, template, col_name="text", roles=("hero", "villain", "victim")):
    """
    Generate one prompt per (role, entity) pair in the input row.

    The full article text is inserted; overly long prompts are truncated
    later by the tokenizer.
    """
    prompts = []
    text = row.get(col_name, "")

    for role in roles:
        # row[role] holds the first-stage entities for this role (may be empty)
        for entity in row.get(role) or []:
            prompt_prefix = template_prefix.format(entity=entity)
            prompt = template.format(prefix=prompt_prefix, article=text)
            prompts.append({
                **row,
                'prompt': prompt,
                'role': role,
                'entity': entity,
            })

    return prompts

In [19]:
roles=['hero', 'villain', 'victim']
col_name="text"

example_row = articles[42]
example_row_exp = expandRow(example_row, PROMPT_TEMPLATE_prefix, PROMPT_TEMPLATE, col_name, roles)

print("Example Prompt:", example_row_exp[0].get("prompt"))
print("Expanded row:", example_row_exp)
print("Expanded row length:", len(example_row_exp))

Example Prompt: Discussing [START_ENT] David Ward [END_ENT] . UPDATE: 6:52 p.m. The sheriff of the Oregon county where armed militia members have taken over a federal wildlife refuge met Thursday afternoon local time with the militia members and will meet again Friday. Here’s the update his office tweeted after the Thursday the discussion: Meeting now over. Plans to talk again tomorrow. Sheriff Ward called on them for a peaceful resolution. #HarneyCounty — Harney Cty. Sheriff (@HarneyCoSheriff) January 7, 2016 The office also stated that Harney County Sheriff David Ward asked Ammon Bundy, one of the men leading the militia takeover, to leave. Sheriff Ward asks Ammon Bundy to please leave and respect the wishes of Harney County residents. — Harney Cty. Sheriff (@HarneyCoSheriff) January 7, 2016 His office emphasized he is seeking a “peaceful resolution” and is not there to arrest anyone in a tweet. This meeting is called for a peaceful resolution. Sheriff Ward is NOT there to make an ar

Process the dataset using multiple processes:

In [20]:
PROMPT_TEMPLATE_prefix

'Discussing [START_ENT] {entity} [END_ENT]'

In [21]:
num_processes = 12
params = (PROMPT_TEMPLATE_prefix, PROMPT_TEMPLATE, col_name, roles,)
dataset_hvv = processDataset(articles, num_processes, expandRow, params)


In [22]:
dataset_hvv.save_to_disk('../data/input/articles_chunkified')

Saving the dataset (1/1 shards): 100%|██████████| 109/109 [00:00<00:00, 5859.09 examples/s]


***

## Tokenize Dataset

Tokenization converts input text into smaller units, such as words or subwords, which are represented as tokens. Each token is mapped to an index in a vocabulary the model understands. Hugging Face provides a variety of tokenizers, each suited to a different model family. For instance, `BertTokenizer` is designed for BERT-like models and splits text into WordPiece units, while `GPT2Tokenizer` is tailored to GPT-2-like models and splits text into subwords using Byte-Pair Encoding (BPE). GENRE, being BART-based, uses a byte-level BPE tokenizer like GPT-2's.
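To make the subword idea concrete, here is a toy version of the classic BPE training loop: count adjacent symbol pairs across a word-frequency table and repeatedly merge the most frequent pair. This is a simplified illustration (the function names are ours), not the byte-level implementation used by the GPT-2 or BART tokenizers.

```python
from collections import Counter
from typing import Dict, List, Tuple

Word = Tuple[str, ...]  # a word as a sequence of its current symbols


def most_frequent_pair(vocab: Dict[Word, int]) -> Tuple[str, str]:
    """Count adjacent symbol pairs, weighted by word frequency."""
    pairs: Counter = Counter()
    for word, freq in vocab.items():
        for a, b in zip(word, word[1:]):
            pairs[(a, b)] += freq
    # max() returns the first pair seen among equally frequent ones
    return max(pairs, key=pairs.__getitem__)


def merge_pair(vocab: Dict[Word, int], pair: Tuple[str, str]) -> Dict[Word, int]:
    """Replace every occurrence of `pair` with a single concatenated symbol."""
    merged: Dict[Word, int] = {}
    for word, freq in vocab.items():
        out: List[str] = []
        i = 0
        while i < len(word):
            if i + 1 < len(word) and (word[i], word[i + 1]) == pair:
                out.append(word[i] + word[i + 1])
                i += 2
            else:
                out.append(word[i])
                i += 1
        key = tuple(out)
        merged[key] = merged.get(key, 0) + freq
    return merged


def learn_bpe(word_freqs: Dict[str, int], num_merges: int) -> List[Tuple[str, str]]:
    """Learn up to `num_merges` merge rules, starting from single characters."""
    vocab: Dict[Word, int] = {tuple(word): freq for word, freq in word_freqs.items()}
    merges: List[Tuple[str, str]] = []
    for _ in range(num_merges):
        if all(len(word) == 1 for word in vocab):
            break  # nothing left to merge
        pair = most_frequent_pair(vocab)
        merges.append(pair)
        vocab = merge_pair(vocab, pair)
    return merges


print(learn_bpe({"low": 5, "lower": 2, "newest": 6, "widest": 3}, num_merges=3))
# [('e', 's'), ('es', 't'), ('l', 'o')]
```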

#### Parameters

Below are descriptions of key parameters helpful for using these tokenizers:

**`add_special_tokens`**:
* Whether to add the model's special tokens, e.g. `[CLS]`/`[SEP]` for BERT or `<s>`/`</s>` for BART (default is True).
* Special tokens are necessary for some models to function properly.

**`max_length`**:
* The maximum number of tokens in the output (default is None, which falls back to the model's maximum input length, 1024 for GENRE, when truncation or `'max_length'` padding is requested).
* Longer texts are truncated only if `truncation` is enabled.

**`padding`**:
* Whether and how to pad the output (default is False).
* `True` or `'longest'` pads to the longest sequence in the batch; `'max_length'` pads to `max_length`.

**`truncation`**:
* Whether to truncate sequences to `max_length` (default is False).

**`return_tensors`**:
* The framework for the returned tensors: `'pt'` for PyTorch, `'tf'` for TensorFlow, or `'np'` for NumPy (default is None, which returns plain lists).

**`return_token_type_ids`**:
* Whether to return token type IDs (default is None, meaning the tokenizer's model-specific default applies; BART-style tokenizers do not return them).
* Used by some models to distinguish input segments (e.g., question vs. context).

**`return_attention_mask`**:
* Whether to return the attention mask (default is None, meaning the model-specific default applies; most tokenizers, including BART's, return it).
* Attention masks tell the model which tokens to attend to and which (e.g., padding) to ignore.

**`verbose`**:
* Whether to print warnings during tokenization (default is True).

**`is_split_into_words`**:
* Whether the input is already split into words (default is False).

These parameters allow for fine-grained control over the tokenization process, ensuring the text is prepared in a way that's suitable for your model and task.

For more information, consider checking the [`encode` and `encode_plus` methods documentation](https://huggingface.co/transformers/main_classes/tokenizer.html).
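The interplay of `truncation=True`, `padding='max_length'` and the attention mask (the combination used in the next cell) can be pictured with plain token-ID lists. The hypothetical `pad_and_truncate` below illustrates those semantics, not the tokenizer's implementation. Special tokens are ignored here (the real tokenizer reserves room for them when truncating), and `pad_id=1` is an assumption that matches BART's `<pad>` ID.

```python
from typing import Dict, List, Sequence


def pad_and_truncate(token_ids: Sequence[int], max_length: int,
                     pad_id: int = 1) -> Dict[str, List[int]]:
    """Mimic truncation=True with padding='max_length' for a single sequence."""
    ids = list(token_ids[:max_length])   # truncation: keep the first max_length tokens
    attention_mask = [1] * len(ids)      # real tokens are attended to
    n_pad = max_length - len(ids)
    ids += [pad_id] * n_pad              # right-pad up to max_length
    attention_mask += [0] * n_pad        # padding positions are masked out
    return {"input_ids": ids, "attention_mask": attention_mask}


print(pad_and_truncate([0, 713, 16, 2], max_length=6))
# {'input_ids': [0, 713, 16, 2, 1, 1], 'attention_mask': [1, 1, 1, 1, 0, 0]}
print(pad_and_truncate([0, 713, 16, 2], max_length=3))
# {'input_ids': [0, 713, 16], 'attention_mask': [1, 1, 1]}
```

Padding every prompt to the same length is also what lets the prediction step batch rows with a plain `DataLoader`.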


In [23]:
# Parameters passed to the tokenizer
tokenizer_params = {"truncation": True, "is_split_into_words": False,
                    "add_special_tokens": True, "padding": "max_length"}

# Parameters passed to the tokenization function
params = {"tokenizer": tokenizer, "col_name": "prompt", "params": tokenizer_params}

# Tokenize the dataset
tokenized_dataset = dataset_hvv.map(tokenizeInputs, fn_kwargs=params)

Map: 100%|██████████| 109/109 [00:00<00:00, 280.56 examples/s]


In [24]:
tokenized_dataset.save_to_disk('../data/input/articles_tokenized')

Saving the dataset (1/1 shards): 100%|██████████| 109/109 [00:00<00:00, 9280.19 examples/s] 


***

## Make Predictions

In [25]:
from torch.utils.data import DataLoader, TensorDataset, SequentialSampler
from datasets import Dataset, load_from_disk, concatenate_datasets
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from tqdm import tqdm
import threading
import torch
import pickle
import time
import copy

In [26]:
dataset = load_from_disk('../data/input/articles_tokenized')
print("Dataset length:", len(dataset))

Dataset length: 109


### Split Dataset

List information about the available GPUs:

In [27]:
gpu_info_list = listAvailableGPUs()

GPU 0:
  Name: Tesla P100-PCIE-16GB
  Memory: 16276.00 MiB
  Compute Capability: 6.0

GPU 1:
  Name: Tesla P100-PCIE-16GB
  Memory: 16276.00 MiB
  Compute Capability: 6.0



Determine the number of available GPUs:

In [28]:
num_gpus = torch.cuda.device_count()
print(f'Number of available GPUs: {num_gpus}')


Number of available GPUs: 2


In [29]:
# Split the dataset into chunks (one for each GPU)
chunks = splitDataset(dataset, num_chunks=num_gpus)

# Print the length of each chunk
print("Number of chunks:", len(chunks))
for i, chunk in enumerate(chunks):
    print(f"Chunk {i} length:", len(chunk))


Number of chunks: 2
Chunk 0 length: 55
Chunk 1 length: 54
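`splitDataset` is not shown, but the 55/54 split matches contiguous shards whose sizes differ by at most one. The hypothetical helper below computes such `[start, end)` index ranges in plain Python. With the `datasets` library, each range could be materialized with `dataset.select(range(start, end))`, or the same split obtained with `dataset.shard(num_shards=num_gpus, index=i, contiguous=True)`. Contiguity matters later: `processResults` groups rows by consecutive `_id` values.

```python
from typing import List, Tuple


def contiguous_ranges(n_rows: int, num_chunks: int) -> List[Tuple[int, int]]:
    """Return [start, end) index ranges for near-equal contiguous chunks."""
    base, extra = divmod(n_rows, num_chunks)
    ranges = []
    start = 0
    for i in range(num_chunks):
        size = base + (1 if i < extra else 0)  # the first `extra` chunks get one more row
        ranges.append((start, start + size))
        start += size
    return ranges


print(contiguous_ranges(109, 2))  # [(0, 55), (55, 109)]
```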


### Process Articles

Check GPU utilization:

In [31]:
!nvidia-smi

Sun Oct 22 09:57:37 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.125.06   Driver Version: 525.125.06   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:17:00.0 Off |                    0 |
| N/A   44C    P0    27W / 250W |      2MiB / 16384MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  Tesla P100-PCIE...  Off  | 00000000:65:00.0 Off |                    0 |
| N/A   46C    P0    29W / 250W |      2MiB / 16384MiB |      0%      Default |
|       

#### Parameters for Text Generation

Each parameter influences text generation in a specific way. The defaults listed below are those of the `transformers` `GenerationConfig`; a model's own generation configuration can override them.

**`max_length`**:
* Sets the maximum length of the generated sequence in tokens (default is 20).
* Generation stops when this length is reached, even if the model has not produced an EOS token.
* `max_new_tokens` (used below) counts only newly generated tokens and overrides `max_length` when set.
* A higher limit allows longer outputs but may increase time and compute.

**`min_length`**:
* Sets the minimum length of the generated sequence (default is 0).
* The EOS token is suppressed until this length is reached.

**`num_beams`**:
* Number of beams (hypotheses) kept at each step of beam search (default is 1, i.e. no beam search).
* More beams increase the chance of finding a high-likelihood output but also increase computational cost.

**`num_return_sequences`**:
* Number of sequences returned per input (default is 1).
* With beam search it must not exceed `num_beams`; with sampling, each sequence is sampled independently.

**`early_stopping`**:
* Controls the stopping condition of beam search (default is False); it has no effect on greedy decoding or sampling.
* With `True`, beam search stops as soon as `num_beams` complete candidates exist. Independently of this flag, a sequence ends when the model emits the EOS token (often `</s>`).

**`do_sample`**:
* If True, tokens are sampled according to their probabilities instead of being chosen greedily (default is False).
* Introduces randomness, giving more diverse but non-deterministic outputs.
* The degree of randomness is shaped by `temperature`, `top_k`, and `top_p`.

**`temperature`**:
* Divides the logits before sampling (default is 1.0).
* Values above 1 flatten the distribution (more random); values below 1 sharpen it (more deterministic).

**`top_k`**:
* Restricts sampling at each step to the K most likely tokens (default is 50).
* Keeps generation focused by discarding the long tail of unlikely tokens.

**`top_p`**:
* Nucleus sampling threshold on cumulative probability (default is 1.0, i.e. disabled).
* Tokens are sampled only from the smallest set whose cumulative probability reaches this threshold.

**`repetition_penalty`**:
* Rescales the scores of tokens that have already appeared (default is 1.0, i.e. no penalty).
* Values greater than 1.0 discourage repetition; values less than 1.0 encourage it.
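The sampling parameters can be combined as follows: divide the logits by `temperature`, keep the `top_k` most likely tokens, keep the nucleus that reaches `top_p`, and sample from the renormalized survivors. The sketch below applies that filtering to a plain list of logits. It is a simplified illustration (`filter_distribution` is a hypothetical name), not the `transformers` logits-processor code, and it omits `repetition_penalty`.

```python
import math
from typing import Dict, List


def filter_distribution(logits: List[float], temperature: float = 1.0,
                        top_k: int = 0, top_p: float = 1.0) -> Dict[int, float]:
    """Return {token_id: probability} after temperature, top-k and top-p filtering."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")

    # Temperature: divide the logits, then apply a numerically stable softmax.
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]
    total = sum(exps)
    ranked = sorted(((e / total, i) for i, e in enumerate(exps)), reverse=True)

    # Top-k: keep only the k most likely tokens (0 disables the filter).
    if top_k > 0:
        ranked = ranked[:top_k]
        mass = sum(p for p, _ in ranked)
        ranked = [(p / mass, i) for p, i in ranked]

    # Top-p (nucleus): keep the smallest prefix whose cumulative mass reaches top_p.
    kept, cumulative = [], 0.0
    for p, i in ranked:
        kept.append((p, i))
        cumulative += p
        if cumulative >= top_p:
            break

    # Renormalize so the surviving probabilities sum to 1.
    mass = sum(p for p, _ in kept)
    return {i: p / mass for p, i in kept}


logits = [2.0, 1.0, 0.0, -1.0]
print(sorted(filter_distribution(logits, top_p=0.8)))  # [0, 1]
```

With the settings used in `generatePredictions` (`temperature=1.0`, `top_k=50`, `top_p=1.0`), only the top-k filter has any effect.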


In [32]:
def generatePredictions(process_id, dataset, device):
    """Generates GENRE predictions for a dataset chunk on the given device."""

    # Print some information about the process
    print(f"--------- Process {process_id:02} ---------")
    print(f"Dataset length: {len(dataset)}")
    print(f"Device: {device}")
    print(f"------------------------------")

    # Load tokenizer and model for generation
    tokenizer = AutoTokenizer.from_pretrained(
        "facebook/genre-kilt", add_prefix_space=True)
    model = AutoModelForSeq2SeqLM.from_pretrained("facebook/genre-kilt")
    model.eval()
    model.to(device)

    # Report the GPU this process actually uses (not the current default device)
    print("Device:", torch.cuda.get_device_name(device))

    # Torch-formatted view for the DataLoader; `dataset` itself keeps all columns
    dataset_torch = dataset.with_format(
        type='torch', columns=['input_ids', 'attention_mask'])

    # Unshuffled batches keep predictions aligned with the dataset rows
    BATCH_SIZE = 64
    dataloader = DataLoader(
        dataset_torch, batch_size=BATCH_SIZE, shuffle=False)

    # Sampling settings; other generation options keep their defaults.
    # Note: do_sample=True makes the outputs non-deterministic.
    params = {'do_sample': True,
              'early_stopping': False,
              'repetition_penalty': 1.0,
              'temperature': 1.0,
              'top_k': 50,
              'top_p': 1.0,
              'max_new_tokens': 10}

    # Make predictions
    predictions = []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Batches"):
            batch = {k: v.to(device) for k, v in batch.items()}

            # Generate outputs
            batch_outputs = model.generate(
                input_ids=batch['input_ids'], attention_mask=batch['attention_mask'], **params)

            # Decode and store predictions
            predictions.extend(tokenizer.batch_decode(
                batch_outputs, skip_special_tokens=True))

    # Ensure there is exactly one prediction per row
    assert len(dataset) == len(
        predictions), "The number of predictions must match the dataset's length"

    # Add the predictions as a new column and save this chunk
    # (the original code used the builtin `id` here instead of `process_id`)
    dataset_full = dataset.add_column('answer', predictions)
    dataset_full.save_to_disk(f'../data/output/articles_processed_{process_id:02}')

    return dataset_full
    

Start one thread per GPU before collecting and merging the results:

In [33]:
# One dataset chunk and one device per GPU
datasets = chunks
devices = [f'cuda:{i}' for i in range(num_gpus)]

# Calls the function to start the threads
returned_datasets = startThreads(len(datasets), datasets, devices, generatePredictions)
print("Number of returned datasets:", len(returned_datasets))

# Concatenate the returned datasets
merged_dataset = concatenate_datasets(returned_datasets)
merged_dataset.save_to_disk('../data/output/articles_processed')

# Print the length of the merged dataset
print("Processing on both GPUs completed!")
print("Results:", len(merged_dataset))

--------- Process 00 ---------
Dataset length: 55
Device: cuda:0
------------------------------
--------- Process 01 ---------
Dataset length: 54
Device: cuda:1
------------------------------
Device: Tesla P100-PCIE-16GB


Batches:   0%|          | 0/1 [00:00<?, ?it/s]

Device: Tesla P100-PCIE-16GB


Batches: 100%|██████████| 1/1 [00:04<00:00,  4.82s/it]
Batches: 100%|██████████| 1/1 [00:04<00:00,  4.27s/it] [00:00<?, ? examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 54/54 [00:00<00:00, 2474.03 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 55/55 [00:00<00:00, 2271.97 examples/s]


Number of returned datasets: 2


Saving the dataset (1/1 shards): 100%|██████████| 109/109 [00:00<00:00, 9289.62 examples/s]

Processing on both GPUs completed!
Results: 109
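`startThreads` comes from `utils.multithreading` and is not shown. Assuming the worker signature `fn(process_id, dataset, device)` used by `generatePredictions`, and that results must come back in chunk order so `concatenate_datasets` keeps each article's rows adjacent, a minimal sketch of the pattern could look like this (`start_threads` is a hypothetical name). Threads are sufficient here because PyTorch releases the GIL inside most of its operations, so each thread can keep its own GPU busy.

```python
import threading
from typing import Any, Callable, List, Sequence


def start_threads(datasets: Sequence[Any], devices: Sequence[str],
                  fn: Callable[[int, Any, str], Any]) -> List[Any]:
    """Run fn(i, datasets[i], devices[i]) in one thread per chunk; return results in order."""
    if len(datasets) != len(devices):
        raise ValueError("need exactly one device per dataset chunk")
    results: List[Any] = [None] * len(datasets)
    errors: List[BaseException] = []

    def worker(i: int) -> None:
        try:
            results[i] = fn(i, datasets[i], devices[i])  # slot i preserves chunk order
        except BaseException as exc:                     # surface failures to the caller
            errors.append(exc)

    threads = [threading.Thread(target=worker, args=(i,)) for i in range(len(datasets))]
    for t in threads:
        t.start()
    for t in threads:
        t.join()  # wait for every device to finish
    if errors:
        raise errors[0]
    return results


# Toy worker standing in for generatePredictions
print(start_threads([[1, 2], [3]], ["cuda:0", "cuda:1"], lambda i, ds, dev: (dev, len(ds))))
# [('cuda:0', 2), ('cuda:1', 1)]
```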





In [35]:
dataset = merged_dataset
dataset

Dataset({
    features: ['_id', 'title', 'url', 'text', 'hero', 'villain', 'victim', 'prompt', 'role', 'entity', 'input_ids', 'attention_mask', 'answer'],
    num_rows: 109
})

***

## Upload Results

In [36]:
from utils.preprocessing import *
from utils.database import *
from datasets import load_from_disk
from tqdm import tqdm

In [38]:
dataset = load_from_disk('../data/output/articles_processed')
describeDataset(dataset)

Number of rows: 109
Column names: ['_id', 'title', 'url', 'text', 'hero', 'villain', 'victim', 'prompt', 'role', 'entity', 'input_ids', 'attention_mask', 'answer']
Features (schema): {'_id': Value(dtype='string', id=None), 'title': Value(dtype='string', id=None), 'url': Value(dtype='string', id=None), 'text': Value(dtype='string', id=None), 'hero': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None), 'villain': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None), 'victim': Sequence(feature=Value(dtype='string', id=None), length=-1, id=None), 'prompt': Value(dtype='string', id=None), 'role': Value(dtype='string', id=None), 'entity': Value(dtype='string', id=None), 'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None), 'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None), 'answer': Value(dtype='string', id=None)}


In [39]:
# Example answer
answer = dataset[0].get("answer")
role = dataset[0].get("role")
print("Role:", role)
print("Answer:", answer)   

Role: villain
Answer: List of United States Republican Party presidential candidates


### Connect to Database

In [40]:
_, db = getConnection(use_dotenv=True)

### Update Documents in Database

In [43]:
def processResults(dataset):
    # Initial processing results
    processing_result = {"hero": [], "villain": [], "victim": []}
    object_id_prev = None

    for item in dataset:
        object_id = item['_id']
        role = item['role']
        answer = item['answer']

        # If the object_id changes, reset the processing_result
        if object_id_prev is not None and object_id_prev != object_id:
            yield object_id_prev, processing_result
            processing_result = {"hero": [], "villain": [], "victim": []}

        processing_result[role].append(answer)
        object_id_prev = object_id

    # Yield the final processing_result if any
    if processing_result["hero"] or processing_result["villain"] or processing_result["victim"]:
        yield object_id_prev, processing_result
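`processResults` starts a new group whenever `_id` changes, so it assumes all rows of one article are adjacent. That holds as long as the chunks are contiguous and concatenated in order. If that assumption might break (for example after shuffling), grouping by `_id` in a dict avoids it, at the cost of holding all groups in memory. A hedged sketch (`group_by_article` is a hypothetical name; the sample rows are toy data):

```python
from typing import Dict, Iterable, Iterator, List, Mapping, Tuple

ROLES = ("hero", "villain", "victim")


def group_by_article(rows: Iterable[Mapping[str, str]]) -> Iterator[Tuple[str, Dict[str, List[str]]]]:
    """Collect answers per article and role, regardless of row order."""
    groups: Dict[str, Dict[str, List[str]]] = {}
    for row in rows:
        result = groups.setdefault(row["_id"], {role: [] for role in ROLES})
        result[row["role"]].append(row["answer"])
    yield from groups.items()  # dicts keep the first-seen order of article IDs


rows = [
    {"_id": "a", "role": "hero", "answer": "David Ward"},
    {"_id": "b", "role": "villain", "answer": "Entity B"},
    {"_id": "a", "role": "villain", "answer": "Ammon Bundy"},
]
for article_id, result in group_by_article(rows):
    print(article_id, result)
# a {'hero': ['David Ward'], 'villain': ['Ammon Bundy'], 'victim': []}
# b {'hero': [], 'villain': ['Entity B'], 'victim': []}
```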

In [44]:
# Count the unique article IDs (used as the progress-bar total)
unique_ids = set(dataset["_id"])
count_unique_ids = len(unique_ids)

for object_id, result in tqdm(processResults(dataset), total=count_unique_ids, desc="Uploading results"):
    # The database write is disabled in this run; uncomment the next line to store results
    #updateProcessingResults(db, object_id, {"denoising_result": result})
    pass

Uploading results: 100%|██████████| 48/48 [00:00<00:00, 256.49it/s]
