In [1]:
import pandas as pd
import os

In [14]:
def _read_predictions(prediction_file):
    """
    Parse a ``filename<TAB>generated text`` file into a DataFrame.

    Blank lines are ignored. Non-blank lines without a tab-separated text
    part are skipped and reported on stdout instead of being dropped silently.

    Returns:
        pd.DataFrame with exactly the columns ``filename`` and ``generated``
        (present even when no line could be parsed).
    """
    records = []
    skipped_lines = []
    with open(prediction_file, 'r', encoding='utf-8') as file:
        for line_number, line in enumerate(file, start=1):
            stripped = line.strip()
            if not stripped:
                continue  # blank lines carry no prediction
            # Split on the first tab only: tabs inside the generated text must
            # not cause the whole prediction to be discarded.
            parts = stripped.split('\t', 1)
            if len(parts) == 2:
                article_id, generated = parts
                records.append({"filename": article_id, "generated": generated})
            else:
                skipped_lines.append(line_number)

    if skipped_lines:
        print(
            f"Skipped {len(skipped_lines)} malformed line(s) in {prediction_file} "
            f"(first line numbers: {skipped_lines[:10]})"
        )

    # Explicit columns keep the merge key available for an empty result.
    return pd.DataFrame(records, columns=["filename", "generated"])


def _prediction_filename(lang, mode=None, model=None):
    """Build ``[<mode>_][<model>_]pred_<lang>.txt``, leaving out unset prefix parts."""
    prefix = [str(part) for part in (mode, model) if part]
    return "_".join(prefix + [f"pred_{lang}.txt"])


def process_predictions(dev_file, prediction_file, output_dir, mode=None, model=None, expected_counts=None):
    """
    Join predictions onto the dev set and write one prediction file per language.

    Args:
        dev_file: Path to a CSV that has at least the columns ``filename``
            and ``language``.
        prediction_file: Path to a UTF-8 text file with one
            ``filename<TAB>generated text`` entry per line.
        output_dir: Directory for the per-language files; created if missing.
        mode: Optional prefix for the output file names.
        model: Optional prefix for the output file names (placed after ``mode``).
        expected_counts: Optional mapping ``language -> expected number of dev
            rows with a prediction``. Every mismatch is printed together with
            the dev rows of that language that have no prediction.

    Side effects:
        Writes ``<output_dir>/[<mode>_][<model>_]pred_<LANG>.txt`` for each
        language that has at least one prediction, as plain
        ``filename<TAB>generated`` lines without header or quoting, and prints
        the path of every file written.
    """
    df_dev = pd.read_csv(dev_file)
    df_predictions = _read_predictions(prediction_file)

    # Left join keeps every dev row; rows without a prediction get NaN.
    df_combined = pd.merge(df_dev, df_predictions, on='filename', how='left')
    filtered_df = df_combined[df_combined['generated'].notna()]

    if expected_counts:
        language_counts = filtered_df['language'].value_counts()
        for lang, expected in expected_counts.items():
            actual = language_counts.get(lang, 0)
            if actual != expected:
                print(f"Discrepancy for {lang}: Expected {expected}, Got {actual}")
                print(f"Missing data for {lang}:")
                print(df_combined[(df_combined['language'] == lang) & (df_combined['generated'].isna())])

    os.makedirs(output_dir, exist_ok=True)

    for lang, group in filtered_df.groupby('language'):
        lang_output_file = os.path.join(output_dir, _prediction_filename(lang, mode, model))
        # Written by hand rather than with to_csv: CSV quoting would wrap any
        # text containing '"' or a tab in quotes and double the inner quotes,
        # so the output would no longer match the input line format.
        with open(lang_output_file, 'w', encoding='utf-8') as out:
            for filename, generated in zip(group['filename'], group['generated']):
                out.write(f"{filename}\t{generated}\n")
        print(f"Saved: {lang_output_file}")

In [15]:
# Input and output locations for the gemma run on the dev set.
dev_file_path = './data/subtask_3_dev.csv'
# Alternative prediction source kept for reference:
# prediction_file_path = './predictions/generated_predictions/sft_predictions.txt'
prediction_file_path = './sub_dev/gemma/predictions.txt'
output_directory = './predictions/predictions_outputs'

# No mode prefix is passed, so output files are named gemma_pred_<LANG>.txt.
# expected_counts gives, per language, how many dev rows should have a prediction.
process_predictions(
    dev_file_path,
    prediction_file_path,
    output_directory,
    model='gemma',
    expected_counts=dict(EN=30, HI=29, BG=28, PT=25),
)

Saved: ./predictions/predictions_outputs/gemma_pred_BG.txt
Saved: ./predictions/predictions_outputs/gemma_pred_EN.txt
Saved: ./predictions/predictions_outputs/gemma_pred_HI.txt
Saved: ./predictions/predictions_outputs/gemma_pred_PT.txt
