In [1]:
import torch
import transformers
import pandas as pd
import numpy as np

from datasets import load_from_disk, Value, Dataset, Features, Sequence, ClassLabel
from argparse import Namespace
from tqdm.auto import tqdm
from transformers import pipeline
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import DataCollatorWithPadding
from transformers.pipelines.pt_utils import KeyDataset
from torch.utils.data import DataLoader

2025-08-02 23:45:13.339030: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0


In [2]:
labels = ['post1geo10', 'post1geo20', 'post1geo30', 'post1geo50', 'post1geo70', 'post2geo10', 'post2geo20', 
          'post2geo30', 'post2geo50', 'post2geo70', 'post3geo10', 'post3geo20', 'post3geo30', 'post3geo50', 
          'post3geo70', 'post7geo10', 'post7geo20', 'post7geo30', 'post7geo50', 'post7geo70', 'pre1geo10', 
          'pre1geo20', 'pre1geo30', 'pre1geo50', 'pre1geo70', 'pre2geo10', 'pre2geo20', 'pre2geo30', 
          'pre2geo50', 'pre2geo70', 'pre3geo10', 'pre3geo20', 'pre3geo30', 'pre3geo50', 'pre3geo70', 
          'pre7geo10', 'pre7geo20', 'pre7geo30', 'pre7geo50', 'pre7geo70']

In [3]:
# Dynamically extract label names from dataset and create mappings
id2label = {i: label for i, label in enumerate(labels)}
label2id = {label: i for i, label in enumerate(labels)}

# 1. Prepare the dataset for inference

### Preprocess the dataset

In [None]:
# This is a very expensive operation.  It takes approximately 20 min to run
# At runtime, this command stores the entire dataset in /tmp.  This takes about 0.8 T of disk space.
# After using the dataset, delete the temp directory from /tmp, but this requires sudo priviledges
ds_splits = load_from_disk("/data3/mmendieta/Violence_data/geo_corpus.0.0.1_datasets")

In [None]:
ds_splits

In [None]:
# Peek at one sample
ds_splits["train"][0]

In [None]:
ds_test = ds_splits["test"]

In [None]:
ds_test

In [None]:
ds_test[0]

In [None]:
# Remove unncesary columns
columns_to_remove = ['retweetid', 'date', 'timestamp', 'username', 'key']

# Use the .remove_columns() method
ds_test = ds_test.remove_columns(columns_to_remove)

In [None]:
ds_test.features

In [None]:
new_features = ds_test.features.copy()

new_features['tweetid'] = Value(dtype='string')

# Iterate directly through your list of labels to cast them
for col_name in labels:
    # It's good practice to check if the column exists, though with a pre-defined list,
    # it's usually assumed all columns in the list exist.
    if col_name in new_features:
        new_features[col_name] = Value(dtype='float32')
    else:
        # This part will only execute if a column in your 'labels' list is missing
        print(f"Warning: Column '{col_name}' from the 'labels' list not found in dataset features.")

# Cast the dataset with the new features
# Since 'ds_test' is a single Dataset object, apply the .cast() method directly.
ds_test = ds_test.cast(new_features)

In [None]:
# This cell takes approximately 14 min to run
# It is important that the labels are float in order to calculate Binary Cross Entropy loss
# create 'labels' columm

# Define columns to ignore
ignore_columns = ["tweetid", "geo_x", "geo_y", "lang", "text"]

# Filter to only work on the test set
cols = [col for col in ds_test.column_names if col not in ignore_columns]

# Map function to create labels
ds_test = ds_test.map(lambda x: {"labels": [x[c] for c in cols]}, remove_columns=cols)

ds_test                                  


In [None]:
ds_test.features

In [None]:
ds_test[0]

In [None]:
ds_test.save_to_disk("/data4/mmendieta/data/geo_corpus.0.0.1_test_dataset_for_inference")

# 2. Tokenize the dataset

In [None]:
ds = load_from_disk("/data4/mmendieta/data/geo_corpus.0.0.1_test_dataset_for_inference")

In [None]:
# "model_ckpt": "/data3/mmendieta/models/ml_e5_large/"
# "model_ckpt": "setu4993/LaBSE"
# "model_ckpt": "setu4993/smaller-LaBSE"
# "model_ckpt": "cardiffnlp/twitter-xlm-roberta-base-sentiment"

# "fout": "/data3/mmendieta/Violence_data/geo_corpus.0.0.1_tok_test_ds_e5_inference"


config = {
    "model_ckpt": "setu4993/LaBSE",
    "batch_size": 1024,
    "num_labels" : 40,
    "max_length": 32,
    "seed": 42,
    "fout": "/data3/mmendieta/Violence_data/geo_corpus.0.0.1_tok_test_ds_labse_inference"
}

args = Namespace(**config)

In [None]:
# Instantiate the tokenizer
model_ckpt = args.model_ckpt
tokenizer = AutoTokenizer.from_pretrained(model_ckpt, 
                                              model_max_length=args.max_length)

In [None]:
def tokenize(batch):
    return tokenizer(batch["text"], truncation=True)

In [None]:
# This code takes 2 min to run
%time tokenized_ds = ds.map(tokenize, batched=True)

In [None]:
tokenized_ds.set_format('torch')

In [None]:
tokenized_ds

In [None]:
tokenized_ds.features

In [None]:
tokenized_ds.save_to_disk(args.fout)

# 3. Inference

In [4]:
config = {
    "cuda_device": 2,
    "path_to_model_on_disk": "/data4/mmendieta/models/labse_finetuned_twitter_all_labels/legendary-eon-1/epoch_19/", 
    "model_ckpt": "",
    "max_length": 32,
    "dataset_name": "/data3/mmendieta/Violence_data/geo_corpus.0.0.1_tok_test_ds_labse_inference",
    "batch_size": 1024,
    "fout_inference": "/data4/mmendieta/data/geo_corpus.0.0.1_tok_test_ds_labse_inference_results"
}

args = Namespace(**config)

In [5]:
# Recall that the test dataset has 2.329.158 observations. The inference is done in batches to exploit GPU resources
print(f"Loading dataset from: {args.dataset_name}")
try:
    # LOAD YOUR DATASET FROM DISK
    ds_tok = load_from_disk(args.dataset_name)

    print("Features of the loaded dataset:", ds_tok.features)
    
    # --- NEW DEBUG: Inspect example 0 ---
    if len(ds_tok) > 0:
        first_example_labels = ds_tok[0]['labels']
        print(f"\nDEBUG: After casting, ds_tok[0]['labels'] type: {type(first_example_labels)}")
        print(f"DEBUG: After casting, ds_tok[0]['labels'] value: {first_example_labels}")
        if isinstance(first_example_labels, np.ndarray):
            print(f"DEBUG: After casting, ds_tok[0]['labels'] shape: {first_example_labels.shape}")
        if isinstance(first_example_labels, torch.Tensor):
            print(f"DEBUG: After casting, ds_tok[0]['labels'] .shape: {first_example_labels.shape}")
            print(f"DEBUG: After casting, ds_tok[0]['labels'] .ndim: {first_example_labels.ndim}")
    # --- END NEW DEBUG ---

except Exception as e:
    print(f"Failed to load or cast dataset from {args.dataset_name}: {e}")
    print("Please ensure your actual dataset exists at the specified path and its content is compatible with the defined 'initial_features'.")
    raise # Re-raise the exception to stop execution if actual loading fails

Loading dataset from: /data3/mmendieta/Violence_data/geo_corpus.0.0.1_tok_test_ds_labse_inference
Features of the loaded dataset: {'tweetid': Value(dtype='string', id=None), 'geo_x': Value(dtype='float64', id=None), 'geo_y': Value(dtype='float64', id=None), 'lang': Value(dtype='string', id=None), 'text': Value(dtype='string', id=None), 'labels': Sequence(feature=Value(dtype='float64', id=None), length=-1, id=None), 'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None), 'token_type_ids': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None), 'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None)}

DEBUG: After casting, ds_tok[0]['labels'] type: <class 'torch.Tensor'>
DEBUG: After casting, ds_tok[0]['labels'] value: tensor([0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0.,
        0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0.,
        0., 0., 0., 1.])
DEBUG: After casting,

In [None]:
ds_tok

In [6]:
print("Current features of your loaded dataset:")
print(ds_tok.features)

# Now, specifically look at the 'labels' entry
if 'labels' in ds_tok.features:
    labels_feature = ds_tok.features['labels']
    print(f"\nFeature for 'labels' column: {labels_feature}")

    if isinstance(labels_feature, Sequence):
        print("The 'labels' column is a Sequence (list-like).")
        inner_feature = labels_feature.feature
        print(f"  Inner feature type: {inner_feature}")
        if isinstance(inner_feature, ClassLabel):
            print(f"  It contains ClassLabel objects (strings that map to predefined names). Names: {inner_feature.names}")
        elif isinstance(inner_feature, Value):
            print(f"  It contains Value objects (e.g., integers, floats). Dtype: {inner_feature.dtype}")
        else:
            print("  Unknown inner feature type for Sequence.")
    elif isinstance(labels_feature, ClassLabel):
        print(f"The 'labels' column contains ClassLabel objects (single strings mapping to predefined names). Names: {labels_feature.names}")
    elif isinstance(labels_feature, Value):
        print(f"The 'labels' column contains single Value objects (e.g., integers, floats). Dtype: {labels_feature.dtype}")
    else:
        print("The 'labels' column has an unexpected feature type.")
else:
    print("The 'labels' column does not exist in your dataset's features.")

# Also inspect the first few raw data entries to be absolutely sure
print("\nFirst 5 entries of the 'labels' column (raw data):")
for i in range(min(5, len(ds_tok))):
    print(f"  Example {i}: {ds_tok[i]['labels']}")

Current features of your loaded dataset:
{'tweetid': Value(dtype='string', id=None), 'geo_x': Value(dtype='float64', id=None), 'geo_y': Value(dtype='float64', id=None), 'lang': Value(dtype='string', id=None), 'text': Value(dtype='string', id=None), 'labels': Sequence(feature=Value(dtype='float64', id=None), length=-1, id=None), 'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None), 'token_type_ids': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None), 'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None)}

Feature for 'labels' column: Sequence(feature=Value(dtype='float64', id=None), length=-1, id=None)
The 'labels' column is a Sequence (list-like).
  Inner feature type: Value(dtype='float64', id=None)
  It contains Value objects (e.g., integers, floats). Dtype: float64

First 5 entries of the 'labels' column (raw data):
  Example 0: tensor([0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0.,
 

In [None]:
# If using the full dataset, make sure to change ds_tok_sample by ds_tok
# First, shuffle the dataset
ds_tok_shuffled = ds_tok.shuffle(seed=42) # Using a seed for reproducibility

# Then, select the first 10,000 observations from the shuffled dataset
ds_tok_sample = ds_tok_shuffled.select(range(10000))

In [7]:
# --- CRITICAL NEW STEP: Ensure labels are plain Python lists after casting ---
# This is the key fix given the PyTorch Tensor issue.
# We convert the labels column to a pure Python list of floats.
# This code takes approximately 30 min with the full test dataset

def convert_labels_to_list(example, idx): # Add idx for debugging
    original_labels = example['labels']

    if idx == 999: # Focus on the problematic example
        print(f"\n--- DEBUGGING EXAMPLE {idx} in convert_labels_to_list ---")
        print(f"Labels BEFORE conversion: Type={type(original_labels)}, Value={original_labels}")
        if isinstance(original_labels, (np.ndarray, torch.Tensor)):
            print(f"Labels BEFORE conversion: Shape={original_labels.shape}, NDIM={original_labels.ndim}")


    if isinstance(original_labels, torch.Tensor):
        # Always convert to list of floats, specifically handling 0-D tensors if they appear
        if original_labels.ndim == 0: # Handle scalar tensor
            converted_labels = [original_labels.item()]
        else: # Handle 1-D or higher-D tensors
            converted_labels = original_labels.detach().cpu().numpy().tolist()
    elif isinstance(original_labels, np.ndarray):
        if original_labels.ndim == 0: # Handle scalar numpy array
            converted_labels = [original_labels.item()]
        else: # Handle 1-D or higher-D numpy arrays
            converted_labels = original_labels.tolist()
    elif isinstance(original_labels, (float, int)): # Directly handle Python scalars
        converted_labels = [float(original_labels)]
    else: # Assume it's already a list or other iterable, or raise an error if unexpected
        converted_labels = original_labels # Keep as is, or convert if needed
        if not isinstance(converted_labels, list):
             # This means it's neither tensor, numpy array, float, int, nor list.
             # This case should ideally not happen if features are correctly casted.
             print(f"WARNING: Example {idx}, 'labels' is an unexpected type: {type(original_labels)}. Attempting to convert to list.")
             try:
                 converted_labels = list(original_labels)
             except TypeError:
                 raise TypeError(f"Example {idx}, 'labels' could not be converted to a list: {original_labels}")

    # Final check to ensure it's a list
    if not isinstance(converted_labels, list):
        raise TypeError(f"CRITICAL ERROR: Example {idx}, 'labels' is not a list AFTER conversion: "
                        f"Type={type(converted_labels)}, Value={converted_labels}")
    # Ensure all elements in the list are floats
    if not all(isinstance(x, (float, int)) for x in converted_labels):
        print(f"WARNING: Example {idx}, 'labels' list contains non-numeric elements: {converted_labels}")
        # Attempt to cast all elements to float if possible
        converted_labels = [float(x) for x in converted_labels]

    example['labels'] = converted_labels

    if idx == 999:
        print(f"Labels AFTER conversion: Type={type(example['labels'])}, Value={example['labels']}")
        print(f"--- END DEBUGGING EXAMPLE {idx} ---")

    return example

print("\nConverting 'labels' column to plain Python lists of floats...")
ds_tok = ds_tok.map(convert_labels_to_list, batched=False, with_indices=True, desc="Converting labels to list")
print("'labels' column conversion complete.")
print(f"Sample dataset features AFTER label conversion: {ds_tok.features}")


Converting 'labels' column to plain Python lists of floats...


Converting labels to list:   0%|          | 0/2329158 [00:00<?, ? examples/s]


--- DEBUGGING EXAMPLE 999 in convert_labels_to_list ---
Labels BEFORE conversion: Type=<class 'torch.Tensor'>, Value=tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0.,
        0., 0., 1., 1.])
Labels BEFORE conversion: Shape=torch.Size([40]), NDIM=1
Labels AFTER conversion: Type=<class 'list'>, Value=[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 1.0, 1.0]
--- END DEBUGGING EXAMPLE 999 ---
'labels' column conversion complete.
Sample dataset features AFTER label conversion: {'tweetid': Value(dtype='string', id=None), 'geo_x': Value(dtype='float64', id=None), 'geo_y': Value(dtype='float64', id=None), 'lang': Value(dtype='string', id=None), 'text': Value(dtype='string', id=None), 'labels': Sequence(feature=Value(dtype='float64', id=None), 

In [8]:
# Verify the conversion for example 0
if len(ds_tok) > 0:
    first_example_labels_after_conversion = ds_tok[0]['labels']
    print(f"DEBUG: After conversion, ds_tok[0]['labels'] type: {type(first_example_labels_after_conversion)}")
    print(f"DEBUG: After conversion, ds_tok[0]['labels'] value: {first_example_labels_after_conversion}")

    # *** MODIFICATION HERE ***
    # Change the type check to expect torch.Tensor, which is what datasets is producing.
    if not isinstance(first_example_labels_after_conversion, torch.Tensor):
        # If it's not a torch.Tensor, then there's still an unexpected type.
        raise TypeError(f"CRITICAL ERROR: ds_tok[0]['labels'] is NOT a torch.Tensor after conversion: "
                        f"Type={type(first_example_labels_after_conversion)}")
    else:
        # Optional: Verify the dtype of the tensor
        if first_example_labels_after_conversion.dtype != torch.float32:
            print(f"WARNING: ds_tok_sample[0]['labels'] is a Tensor but its dtype is {first_example_labels_after_conversion.dtype}, not float32. Consider casting if needed downstream.")

print("\nVerification of labels column type complete (expecting torch.Tensor).")

DEBUG: After conversion, ds_tok[0]['labels'] type: <class 'torch.Tensor'>
DEBUG: After conversion, ds_tok[0]['labels'] value: tensor([0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0.,
        0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0.,
        0., 0., 0., 1.])

Verification of labels column type complete (expecting torch.Tensor).


In [9]:
# If using a sample dataset, change the number of the line ' if i % 50000 == 0' to 1000
# --- Data Integrity Check for Sequence Columns (Keep this for robustness) ---
print("\n--- Starting Data Integrity Check for Sequence Columns ---")
problem_found = False
problematic_cols = ['labels', 'input_ids', 'attention_mask']

for i, example in enumerate(ds_tok):
    if i % 50000 == 0:
        print(f"Checking example {i}/{len(ds_tok)}")

    for col in problematic_cols:
        if col not in example:
            print(f"ERROR: Column '{col}' missing from example {i}")
            problem_found = True
            continue

        item = example[col]
        if col == 'labels':
            # Allow for both list and torch.Tensor, as datasets might re-tensorize
            if not isinstance(item, (list, torch.Tensor)):
                print(f"ERROR: Example {i}, column '{col}': Expected a list or torch.Tensor, but found type {type(item)} with value {item}")
                problem_found_in_initial_data = True
            elif isinstance(item, torch.Tensor):
                if item.ndim == 0:
                    print(f"ERROR: Example {i}, column '{col}': Found a 0-dimensional torch.Tensor. Value: {item}")
                    problem_found_in_initial_data = True
                elif item.dtype not in [torch.float32, torch.float64]: # Check for float dtype
                    print(f"WARNING: Example {i}, column '{col}': torch.Tensor has dtype {item.dtype}, expected float. Value: {item}")
                elif item.shape[-1] != 40: # Assuming 40 is the expected length
                    print(f"WARNING: Example {i}, column '{col}': torch.Tensor length is {item.shape[-1]}, expected 40. Value: {item}")
            elif isinstance(item, list):
                if not all(isinstance(x, (float, int)) for x in item):
                    print(f"ERROR: Example {i}, column '{col}': List contains non-numeric elements. Value: {item}")
                    problem_found_in_initial_data = True
                elif len(item) == 0:
                    print(f"WARNING: Example {i}, column '{col}': Found an empty list. Value: {item}")
                elif len(item) != 40:
                    print(f"WARNING: Example {i}, column '{col}': List length is {len(item)}, expected 40. Value: {item}")

        else: # For input_ids, attention_mask, etc.
            # Keep these checks largely the same, but remember ds_tok_sample.features
            # indicates 'input_ids' and 'attention_mask' are Sequence(Value(dtype='int32/int8')).
            # Datasets will likely keep these as lists or numpy arrays unless set_format("torch") is used.
            if not isinstance(item, (list, np.ndarray, torch.Tensor)): # Add torch.Tensor just in case
                print(f"ERROR: Example {i}, column '{col}': Expected a list, array, or tensor, but found scalar type {type(item)} with value {item}")
                problem_found_in_initial_data = True
            elif isinstance(item, np.ndarray) and item.ndim == 0:
                print(f"ERROR: Example {i}, column '{col}': Found a 0-dimensional NumPy array. Value: {item}")
                problem_found_in_initial_data = True
            elif isinstance(item, torch.Tensor) and item.ndim == 0:
                 print(f"ERROR: Example {i}, column '{col}': Found a 0-dimensional torch.Tensor. Value: {item}")
                 problem_found_in_initial_data = True
            elif isinstance(item, (list, np.ndarray, torch.Tensor)) and (len(item) == 0 if isinstance(item, list) else item.size == 0):
                print(f"WARNING: Example {i}, column '{col}': Found an empty list/array/tensor. Value: {item}")
            elif isinstance(item, list) and not all(isinstance(x, (float, int)) for x in item): # assuming int types for ids/mask
                print(f"ERROR: Example {i}, column '{col}': List contains non-numeric elements. Value: {item}")
            # For tensors/arrays, you might want to check dtype too
            elif isinstance(item, (np.ndarray, torch.Tensor)) and item.dtype not in [np.int32, np.int8, torch.int32, torch.int8, torch.long]:
                print(f"WARNING: Example {i}, column '{col}': Array/Tensor has unexpected dtype {item.dtype}. Value: {item}")
        
if problem_found:
    print("\n--- Data integrity issues found and attempted to fix. Please review warnings/errors. ---")
else:
    print("\n--- Data integrity check passed. No problematic scalar or 0-d array entries found. ---")

# --- END Data Integrity Check ---


--- Starting Data Integrity Check for Sequence Columns ---
Checking example 0/2329158
Checking example 50000/2329158
Checking example 100000/2329158
Checking example 150000/2329158
Checking example 200000/2329158
Checking example 250000/2329158
Checking example 300000/2329158
Checking example 350000/2329158
Checking example 400000/2329158
Checking example 450000/2329158
Checking example 500000/2329158
Checking example 550000/2329158
Checking example 600000/2329158
Checking example 650000/2329158
Checking example 700000/2329158
Checking example 750000/2329158
Checking example 800000/2329158
Checking example 850000/2329158
Checking example 900000/2329158
Checking example 950000/2329158
Checking example 1000000/2329158
Checking example 1050000/2329158
Checking example 1100000/2329158
Checking example 1150000/2329158
Checking example 1200000/2329158
Checking example 1250000/2329158
Checking example 1300000/2329158
Checking example 1350000/2329158
Checking example 1400000/2329158
Checking 

In [10]:
# Instantiate the pipeline with the model of choice
print(f"Initializing pipeline with model: {args.path_to_model_on_disk} on device: {args.cuda_device}")
violence_pipe = pipeline(model=args.path_to_model_on_disk,
                         task="text-classification", # This line helps with e5. For the other models is not necessary
                         device=args.cuda_device,
                         framework="pt",
                         return_all_scores=True)

Initializing pipeline with model: /data4/mmendieta/models/labse_finetuned_twitter_all_labels/legendary-eon-1/epoch_19/ on device: 2




In [11]:
# This code takes approximately 1h 45 min to run on the full test dataset
# Perform Inference
preds = []

for i, outputs in enumerate(tqdm(violence_pipe(KeyDataset(ds_tok, "text"), batch_size=args.batch_size,
                                              truncation=True),
                                 total=len(ds_tok))):
    preds.append(outputs)

  0%|          | 0/2329158 [00:00<?, ?it/s]

In [12]:
preds[0]

[{'label': 'post1geo10', 'score': 0.5103994607925415},
 {'label': 'post1geo20', 'score': 0.5083104968070984},
 {'label': 'post1geo30', 'score': 0.505756139755249},
 {'label': 'post1geo50', 'score': 0.4982481896877289},
 {'label': 'post1geo70', 'score': 0.6435137987136841},
 {'label': 'post2geo10', 'score': 0.48857489228248596},
 {'label': 'post2geo20', 'score': 0.48798850178718567},
 {'label': 'post2geo30', 'score': 0.4826910197734833},
 {'label': 'post2geo50', 'score': 0.4746701419353485},
 {'label': 'post2geo70', 'score': 0.6336734890937805},
 {'label': 'post3geo10', 'score': 0.47698694467544556},
 {'label': 'post3geo20', 'score': 0.47777342796325684},
 {'label': 'post3geo30', 'score': 0.46957069635391235},
 {'label': 'post3geo50', 'score': 0.45954760909080505},
 {'label': 'post3geo70', 'score': 0.6274043321609497},
 {'label': 'post7geo10', 'score': 0.4379730522632599},
 {'label': 'post7geo20', 'score': 0.43846064805984497},
 {'label': 'post7geo30', 'score': 0.42034977674484253},
 {'

In [13]:
processed_data = []
for pred in preds:
    scores = {item['label']: item['score'] for item in pred}
    processed_data.append(scores)

In [14]:
processed_data[0]

{'post1geo10': 0.5103994607925415,
 'post1geo20': 0.5083104968070984,
 'post1geo30': 0.505756139755249,
 'post1geo50': 0.4982481896877289,
 'post1geo70': 0.6435137987136841,
 'post2geo10': 0.48857489228248596,
 'post2geo20': 0.48798850178718567,
 'post2geo30': 0.4826910197734833,
 'post2geo50': 0.4746701419353485,
 'post2geo70': 0.6336734890937805,
 'post3geo10': 0.47698694467544556,
 'post3geo20': 0.47777342796325684,
 'post3geo30': 0.46957069635391235,
 'post3geo50': 0.45954760909080505,
 'post3geo70': 0.6274043321609497,
 'post7geo10': 0.4379730522632599,
 'post7geo20': 0.43846064805984497,
 'post7geo30': 0.42034977674484253,
 'post7geo50': 0.39889204502105713,
 'post7geo70': 0.6028570532798767,
 'pre1geo10': 0.5066482424736023,
 'pre1geo20': 0.5058010816574097,
 'pre1geo30': 0.5055418610572815,
 'pre1geo50': 0.496917188167572,
 'pre1geo70': 0.6394330263137817,
 'pre2geo10': 0.4947558343410492,
 'pre2geo20': 0.4973897933959961,
 'pre2geo30': 0.4956568777561188,
 'pre2geo50': 0.48467

In [15]:
# --- CRITICAL: Aligning Pipeline Output Labels to Your True Labels ---
# Get the generic labels (e.g., 'LABEL_0', 'LABEL_1', ...) from the pipeline's output.
if processed_data:
    if isinstance(processed_data[0], list):
        # Expected format: [{'label': 'post1geo10', 'score': 0.9}, {'label': 'pre2geo20', 'score': 0.1}]
        pipeline_output_labels = [item['label'] for item in processed_data[0]]
    elif isinstance(processed_data[0], dict):
        # Less common for return_all_scores=True but possible: {'post1geo10': 0.9, 'pre2geo20': 0.1}
        pipeline_output_labels = list(processed_data[0].keys())
    else:
        raise TypeError(f"Unexpected format for processed_data[0]: {type(processed_data[0])}")

    # Create a mapping from pipeline label (string) to its index in your predefined `labels` list
    # This helps sort them consistently even if pipeline output order varies.
    pipeline_label_to_id = {label_name: i for i, label_name in enumerate(labels)}

    # Sort the pipeline output labels based on your predefined order
    # Filter out any pipeline labels that are not in your `labels` list to avoid KeyError
    pipeline_output_labels_ordered = sorted(
        [lbl for lbl in pipeline_output_labels if lbl in pipeline_label_to_id],
        key=lambda lbl: pipeline_label_to_id[lbl]
    )

    if not pipeline_output_labels_ordered:
        print("WARNING: No pipeline output labels matched your predefined 'labels' list. Check model's output labels.")

else:
    raise ValueError("No predictions were processed. Ensure your inference code runs correctly.")


if len(pipeline_output_labels_ordered) != len(labels):
    print(f"WARNING: Mismatch in label count! Pipeline outputs {len(pipeline_output_labels_ordered)} labels, "
          f"but you have {len(labels)} true labels defined. "
          "This might indicate some labels were not predicted or are named differently.")
    print(f"Predefined labels: {labels}")
    print(f"Pipeline discovered labels (ordered): {pipeline_output_labels_ordered}")


pipeline_label_to_pred_col_name_map = {}
for actual_label_name in pipeline_output_labels_ordered:
    # Ensure only labels that are in your original `labels` list are mapped
    if actual_label_name in labels:
        pipeline_label_to_pred_col_name_map[actual_label_name] = f"pred_{actual_label_name}"
    else:
        print(f"WARNING: Skipping pipeline label '{actual_label_name}' as it's not in your predefined 'labels' list.")

print("\nPipeline output label to 'pred_' column name mapping created:")
print(pipeline_label_to_pred_col_name_map)


Pipeline output label to 'pred_' column name mapping created:
{'post1geo10': 'pred_post1geo10', 'post1geo20': 'pred_post1geo20', 'post1geo30': 'pred_post1geo30', 'post1geo50': 'pred_post1geo50', 'post1geo70': 'pred_post1geo70', 'post2geo10': 'pred_post2geo10', 'post2geo20': 'pred_post2geo20', 'post2geo30': 'pred_post2geo30', 'post2geo50': 'pred_post2geo50', 'post2geo70': 'pred_post2geo70', 'post3geo10': 'pred_post3geo10', 'post3geo20': 'pred_post3geo20', 'post3geo30': 'pred_post3geo30', 'post3geo50': 'pred_post3geo50', 'post3geo70': 'pred_post3geo70', 'post7geo10': 'pred_post7geo10', 'post7geo20': 'pred_post7geo20', 'post7geo30': 'pred_post7geo30', 'post7geo50': 'pred_post7geo50', 'post7geo70': 'pred_post7geo70', 'pre1geo10': 'pred_pre1geo10', 'pre1geo20': 'pred_pre1geo20', 'pre1geo30': 'pred_pre1geo30', 'pre1geo50': 'pred_pre1geo50', 'pre1geo70': 'pred_pre1geo70', 'pre2geo10': 'pred_pre2geo10', 'pre2geo20': 'pred_pre2geo20', 'pre2geo30': 'pred_pre2geo30', 'pre2geo50': 'pred_pre2geo50

In [16]:
# --- STEP 1: Append prediction columns to the Dataset ---
def add_predictions_to_example_batched(examples, indices): 
    # 'examples' is a dictionary like {'text': [...], 'input_ids': [...], ...}
    # 'indices' is a list of indices for the current batch
    batch_size = len(indices)
    new_columns_data = {col_name: [] for col_name in pipeline_label_to_pred_col_name_map.values()}

    for i, global_idx in enumerate(indices):
        # Access predictions for the current example within the batch
        raw_prediction_output_for_example = processed_data[global_idx]

        # Debugging for a specific example (e.g., global_idx 999)
        # Remember: 'labels' here would be examples['labels'][i] for the current example in the batch
        if global_idx == 999 and i < batch_size: # Ensure index is within current batch
            print(f"\n--- DEBUGGING EXAMPLE {global_idx} in add_predictions_to_example_batched ---")
            print(f"Labels column content (for this example in batch): {examples['labels'][i]}")
            print(f"Labels column type (for this example in batch): {type(examples['labels'][i])}")
            if isinstance(examples['labels'][i], np.ndarray):
                print(f"Labels column shape (for this example in batch): {examples['labels'][i].shape}")
            print(f"--- END DEBUGGING EXAMPLE {global_idx} ---")


        for original_label_str, pred_col_name in pipeline_label_to_pred_col_name_map.items():
            score = raw_prediction_output_for_example.get(original_label_str, 0.0)

            if not isinstance(score, (float, int, np.number)):
                raise ValueError(f"CRITICAL ERROR: Example {global_idx}, Column '{pred_col_name}': "
                                 f"Score is not a recognized numeric type! Type={type(score)}, Value={score}")

            if isinstance(score, np.number):
                score = score.item()

            new_columns_data[pred_col_name].append(float(score))

    # Return a dictionary where keys are new column names and values are lists of scores for the batch
    return new_columns_data

In [17]:
# --- STEP 2: Define final_dataset_features (TOP-LEVEL SCRIPT CODE) ---
final_dataset_features = ds_tok.features.copy()
for pred_col_name in pipeline_label_to_pred_col_name_map.values():
    final_dataset_features[pred_col_name] = Value("float32")

print("\nDefined final features for dataset after adding prediction columns:")
print(final_dataset_features)


Defined final features for dataset after adding prediction columns:
{'tweetid': Value(dtype='string', id=None), 'geo_x': Value(dtype='float64', id=None), 'geo_y': Value(dtype='float64', id=None), 'lang': Value(dtype='string', id=None), 'text': Value(dtype='string', id=None), 'labels': Sequence(feature=Value(dtype='float64', id=None), length=-1, id=None), 'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None), 'token_type_ids': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None), 'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None), 'pred_post1geo10': Value(dtype='float32', id=None), 'pred_post1geo20': Value(dtype='float32', id=None), 'pred_post1geo30': Value(dtype='float32', id=None), 'pred_post1geo50': Value(dtype='float32', id=None), 'pred_post1geo70': Value(dtype='float32', id=None), 'pred_post2geo10': Value(dtype='float32', id=None), 'pred_post2geo20': Value(dtype='float32', id=None), 'pred_post2geo30': Value

In [18]:
# --- STEP 3: Call .map using the batched function (TOP-LEVEL SCRIPT CODE) ---
ds_with_predictions = ds_tok.map(
    add_predictions_to_example_batched, # This is the *function object* passed to map
    with_indices=True,
    batched=True,                       # Critical for performance
    batch_size=args.batch_size,                    # Adjust as needed
    features=final_dataset_features,
    # Remove writer_batch_size=1
)

print("Prediction columns added to the dataset.")
print(ds_with_predictions.features)

Map:   0%|          | 0/2329158 [00:00<?, ? examples/s]


--- DEBUGGING EXAMPLE 999 in add_predictions_to_example_batched ---
Labels column content (for this example in batch): tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 1., 0.,
        0., 0., 1., 1.])
Labels column type (for this example in batch): <class 'torch.Tensor'>
--- END DEBUGGING EXAMPLE 999 ---
Prediction columns added to the dataset.
{'tweetid': Value(dtype='string', id=None), 'geo_x': Value(dtype='float64', id=None), 'geo_y': Value(dtype='float64', id=None), 'lang': Value(dtype='string', id=None), 'text': Value(dtype='string', id=None), 'labels': Sequence(feature=Value(dtype='float64', id=None), length=-1, id=None), 'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None), 'token_type_ids': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None), 'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None), 'pred_

In [19]:
# --- STEP 4: Sample of the resulting dataset (TOP-LEVEL SCRIPT CODE) ---
print("\nSample of the resulting dataset (first row):")
if len(ds_with_predictions) > 0:
    first_sample = ds_with_predictions[0]
    for key, value in first_sample.items():
        if key in ['labels', 'text'] or key.startswith('pred_'):
            print(f"{key}: {value}")
else:
    print("Dataset is empty after mapping.")


Sample of the resulting dataset (first row):
text: talking abt my case ☺️
labels: tensor([0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0.,
        0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 1., 0.,
        0., 0., 0., 1.])
pred_post1geo10: 0.5103994607925415
pred_post1geo20: 0.5083104968070984
pred_post1geo30: 0.505756139755249
pred_post1geo50: 0.4982481896877289
pred_post1geo70: 0.6435137987136841
pred_post2geo10: 0.48857489228248596
pred_post2geo20: 0.48798850178718567
pred_post2geo30: 0.4826910197734833
pred_post2geo50: 0.4746701419353485
pred_post2geo70: 0.6336734890937805
pred_post3geo10: 0.47698694467544556
pred_post3geo20: 0.47777342796325684
pred_post3geo30: 0.46957069635391235
pred_post3geo50: 0.45954760909080505
pred_post3geo70: 0.6274043321609497
pred_post7geo10: 0.4379730522632599
pred_post7geo20: 0.43846064805984497
pred_post7geo30: 0.42034977674484253
pred_post7geo50: 0.39889204502105713
pred_post7geo70: 0.6028570532798767
pred

In [20]:
# --- STEP 5: Save the dataset ---
try:
    ds_with_predictions.save_to_disk(args.fout_inference)
    print(f"Dataset with predictions successfully saved to: {args.fout_inference}")
except Exception as e:
    print(f"ERROR: Failed to save dataset with predictions to disk: {e}")

Saving the dataset (0/4 shards):   0%|          | 0/2329158 [00:00<?, ? examples/s]

Dataset with predictions successfully saved to: /data4/mmendieta/data/geo_corpus.0.0.1_tok_test_ds_labse_inference_results


In [21]:
ds_with_predictions

Dataset({
    features: ['tweetid', 'geo_x', 'geo_y', 'lang', 'text', 'labels', 'input_ids', 'token_type_ids', 'attention_mask', 'pred_post1geo10', 'pred_post1geo20', 'pred_post1geo30', 'pred_post1geo50', 'pred_post1geo70', 'pred_post2geo10', 'pred_post2geo20', 'pred_post2geo30', 'pred_post2geo50', 'pred_post2geo70', 'pred_post3geo10', 'pred_post3geo20', 'pred_post3geo30', 'pred_post3geo50', 'pred_post3geo70', 'pred_post7geo10', 'pred_post7geo20', 'pred_post7geo30', 'pred_post7geo50', 'pred_post7geo70', 'pred_pre1geo10', 'pred_pre1geo20', 'pred_pre1geo30', 'pred_pre1geo50', 'pred_pre1geo70', 'pred_pre2geo10', 'pred_pre2geo20', 'pred_pre2geo30', 'pred_pre2geo50', 'pred_pre2geo70', 'pred_pre3geo10', 'pred_pre3geo20', 'pred_pre3geo30', 'pred_pre3geo50', 'pred_pre3geo70', 'pred_pre7geo10', 'pred_pre7geo20', 'pred_pre7geo30', 'pred_pre7geo50', 'pred_pre7geo70'],
    num_rows: 2329158
})