In [1]:
import json
import pandas as pd
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt
import seaborn as sns
import os
from collections import Counter, defaultdict
from datetime import datetime
import warnings
warnings.filterwarnings('ignore')

sns.set_style('whitegrid')
plt.rcParams['figure.figsize'] = (12, 6)

## 1. Dataset Overview & Statistics

In [None]:
data_root = "../data"

tasks = [
    #'Citation_Identification',
      'Email_Subject_Generation',
    # 'Movie_Tagging',
     'News_Headline_Generation',
#     'Product_Rating',
#     'Scholarly_Title_Generation',
#     'Tweet_Paraphrasing'
]

def load_json_sample(filepath):
    try:
        with open(filepath, 'r') as f:
            data = json.load(f)
            return data[:max_items] if isinstance(data, list) else data
    except Exception as e:
        print(f"Error loading {filepath}: {e}")
        return []

def get_file_stats(filepath):
    try:
        with open(filepath, 'r') as f:
            data = json.load(f)
            return len(data) if isinstance(data, list) else 1
    except:
        return 0
    
def get_dataset_stats(tasks, data_root, max_items=100):
    dataset_stats = []
    for task in tasks:
        task_path = os.path.join(data_root, task)
        
        for split in tqdm(['Train', 'Validation', 'Test'], desc=f"Collecting data from {task}"):
            for scenario in ['time_based', 'user_based']:

                if split == 'Test':
                    input_file = os.path.join(task_path ,split ,f'test_questions_{scenario}.json')
                else:
                    input_file = os.path.join(task_path, split, 'Inputs',  f'{split.lower()}_questions_{scenario}.json')

                if split != 'Test':
                    output_file = os.path.join(task_path, split, 'Outputs', f'{split.lower()}_outputs_{scenario}.json')
                else:
                    output_file = None
                
                if os.path.exists(input_file):
                    num_samples = pd.read_json(input_file).head(max_items)
                    num_outputs = get_file_stats(output_file) if output_file and os.path.exists(output_file) else 0
                    
                    dataset_stats.append({
                        'Task': task,
                        'Split': split,
                        'Scenario': scenario,
                        'Num_Samples': num_samples,
                        'Num_Outputs': num_outputs,
                        'Has_Labels': num_outputs > 0
                    })

        return dataset_stats

dataset_stats = get_dataset_stats(tasks, data_root)

stats_df = pd.DataFrame(dataset_stats)
print("Dataset overview:")
print("=" * 80)
print(stats_df)
print("\n" + "=" * 80)
print(f"\nTotal samples across all tasks: {int(stats_df['Num_Samples'].apply(len).sum())}")

Collecting data from Citation_Identification: 100%|██████████| 3/3 [00:55<00:00, 18.57s/it]


Dataset overview:
                      Task       Split    Scenario  \
0  Citation_Identification       Train  time_based   
1  Citation_Identification       Train  user_based   
2  Citation_Identification  Validation  time_based   
3  Citation_Identification  Validation  user_based   
4  Citation_Identification        Test  time_based   
5  Citation_Identification        Test  user_based   

                                         Num_Samples  Num_Outputs  Has_Labels  
0        id                                      ...            1        True  
1      id                                        ...            1        True  
2        id                                      ...            1        True  
3       id                                       ...            1        True  
4        id                                      ...            0       False  
5       id                                       ...            0       False  


Total samples across all tasks: 600


In [None]:
task_samples = stats_df.groupby("Task")

Unnamed: 0,Task,Split,Scenario,Num_Samples,Num_Outputs,Has_Labels
0,Citation_Identification,Train,time_based,id ...,1,True
1,Citation_Identification,Train,user_based,id ...,1,True
2,Citation_Identification,Validation,time_based,id ...,1,True
3,Citation_Identification,Validation,user_based,id ...,1,True
4,Citation_Identification,Test,time_based,id ...,0,False
5,Citation_Identification,Test,user_based,id ...,0,False


In [None]:
samples = stats_df['Num_Samples']
samples

0           id              input  \
0    1200  175...
1          id              input                   ...
2           id              input  \
0    1210  163...
3          id              input                   ...
4           id              input  \
0    1220  059...
5          id              input                   ...
Name: Num_Samples, dtype: object

: 