In [6]:
import os
import json
import argparse
from typing import List, Dict, Tuple, Any
from tqdm import tqdm

In [7]:
def read_text_file(file_path: str) -> List[str]:
    """Read a UTF-8 text file and return its content split on "\\n".

    The file is opened with ``newline=""`` so line endings are not
    translated: only a literal "\\n" separates entries, and any "\\r"
    characters inside the text are kept as-is. Leading and trailing
    whitespace of the whole file is stripped before splitting, so an empty
    file yields ``[""]``.

    Args:
        file_path: Path of the text file to read.

    Returns:
        The list of "\\n"-separated segments of the stripped file content.
    """
    with open(file_path, mode="r", encoding="utf-8", newline="") as handle:
        raw_text = handle.read()
    trimmed = raw_text.strip()
    return trimmed.split("\n")

def read_truth_file(file_path: str) -> Dict[str, Any]:
    """Load a UTF-8 encoded JSON file and return the parsed content.

    Args:
        file_path: Path of the JSON file to read.

    Returns:
        Whatever ``json.load`` produces for the file (callers in this file
        treat it as a mapping and look up the "changes" key).
    """
    with open(file_path, mode="r", encoding="utf-8") as handle:
        payload = json.load(handle)
    return payload

In [None]:
def create_contrastive_pairs(text_lines: List[str], changes: List[int]) -> List[Dict[str, Any]]:
    """Create contrastive learning pairs from text lines and change information.

    Each pair joins two consecutive entries of ``text_lines`` with the
    matching entry of ``changes`` as its label.

    When ``len(changes)`` differs from ``len(text_lines) - 1`` a warning is
    printed and only the pairs for which both neighbours and a label exist
    are produced. An empty ``text_lines`` yields an empty list instead of
    raising ``IndexError``.

    Args:
        text_lines: Ordered text segments.
        changes: One label per adjacent pair of segments.

    Returns:
        A list of dicts with the keys "sentence1", "sentence2" and "label".
    """
    if len(text_lines) - 1 != len(changes):
        print(f"Warning: Mismatch between text lines ({len(text_lines)}) and changes ({len(changes)})")

    # zip() stops at the shortest input, which truncates both sequences to
    # the usable length and is also safe when text_lines is empty.
    return [
        {
            "sentence1": first,
            "sentence2": second,
            "label": label,  # 0: same author, 1: different author
        }
        for first, second, label in zip(text_lines, text_lines[1:], changes)
    ]

def process_directory(dir_path: str) -> List[Dict[str, Any]]:
    """Process all files in the directory and create contrastive pairs.

    Every ``problem-<id>.txt`` file in ``dir_path`` is matched with
    ``truth-problem-<id>.json`` from the same directory. Files without a
    truth file, or whose truth data has no "changes" key, are skipped with
    a printed warning.

    Args:
        dir_path: Directory containing the problem and truth files.

    Returns:
        All pairs from all processed files, each carrying an additional
        "problem_id" key. Files are handled in sorted filename order so the
        result does not depend on the order ``os.listdir`` happens to return.
    """
    prefix, suffix = "problem-", ".txt"
    all_pairs = []
    # os.listdir() returns entries in arbitrary order; sort for reproducible output.
    problem_files = sorted(
        f for f in os.listdir(dir_path) if f.startswith(prefix) and f.endswith(suffix)
    )

    for problem_file in tqdm(problem_files):
        # Strip exactly the leading prefix and trailing suffix; str.replace()
        # would also remove occurrences inside the id itself.
        problem_id = problem_file[len(prefix):-len(suffix)]
        truth_file = f"truth-problem-{problem_id}.json"

        text_path = os.path.join(dir_path, problem_file)
        truth_path = os.path.join(dir_path, truth_file)

        if not os.path.exists(truth_path):
            print(f"Warning: Truth file not found for {problem_file}")
            continue

        text_lines = read_text_file(text_path)
        truth_data = read_truth_file(truth_path)

        if "changes" not in truth_data:
            print(f"Warning: No 'changes' key in truth file for {problem_file}")
            continue

        changes = truth_data["changes"]
        pairs = create_contrastive_pairs(text_lines, changes)

        # Tag every pair with the problem it came from.
        for pair in pairs:
            pair["problem_id"] = problem_id

        all_pairs.extend(pairs)

    return all_pairs

In [None]:
def save_dataset(pairs: List[Dict[str, Any]], output_file: str):
    """Save the dataset to a JSON file.

    The pairs are written as ``{"data": pairs}`` with an indent of 2, UTF-8
    encoded. The parent directory of ``output_file`` is created if it does
    not exist, so a missing output folder does not abort the run.

    Args:
        pairs: JSON-serialisable list of pair dicts.
        output_file: Destination path of the JSON file.
    """
    parent_dir = os.path.dirname(output_file)
    # dirname is "" for a bare filename; makedirs("") would raise.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)

    with open(output_file, "w", encoding="utf-8") as f:
        json.dump({"data": pairs}, f, indent=2)

def main():
    """Build and save the pair datasets for every difficulty and split.

    For each of "easy", "medium" and "hard", the train directory and then
    the validation directory under ``../data/<difficulty>/`` are processed,
    and the resulting pairs are written to ``../data/processed/`` as
    ``<difficulty>_train.json`` and ``<difficulty>_valid.json``. The number
    of pairs is printed after creation and again after saving.
    """
    for difficulty in ("easy", "medium", "hard"):
        # (input directory, output file) for the two splits, train first.
        jobs = (
            (f"../data/{difficulty}/train", f"../data/processed/{difficulty}_train.json"),
            (f"../data/{difficulty}/validation", f"../data/processed/{difficulty}_valid.json"),
        )

        for source_dir, destination in jobs:
            split_pairs = process_directory(source_dir)
            print(f"Created {len(split_pairs)} contrastive pairs")
            save_dataset(split_pairs, destination)
            print(f"Saved {len(split_pairs)} pairs to {destination}")

# Run the whole preprocessing pipeline when this code is executed directly
# (as a script, or as a notebook cell, where __name__ is "__main__").
if __name__ == "__main__":
    main()

100%|██████████| 4200/4200 [00:01<00:00, 2434.99it/s]


Created 48402 contrastive pairs
Saved 48402 pairs to ../data/processed/easy_train.json


100%|██████████| 900/900 [00:26<00:00, 34.50it/s]


Created 10247 contrastive pairs
Saved 10247 pairs to ../data/processed/easy_valid.json


100%|██████████| 4200/4200 [00:23<00:00, 176.64it/s] 


Created 58817 contrastive pairs
Saved 58817 pairs to ../data/processed/medium_train.json


100%|██████████| 900/900 [00:01<00:00, 700.17it/s]


Created 12759 contrastive pairs
Saved 12759 pairs to ../data/processed/medium_valid.json


100%|██████████| 4200/4200 [02:28<00:00, 28.37it/s]


Created 51061 contrastive pairs
Saved 51061 pairs to ../data/processed/hard_train.json


100%|██████████| 900/900 [00:25<00:00, 34.73it/s]


Created 10648 contrastive pairs
Saved 10648 pairs to ../data/processed/hard_valid.json
