In [13]:
from datasets import load_dataset
from collections import defaultdict

In [14]:
from pprint import pprint

In [15]:
# Load the "train" split of the "mrpc" configuration of nyu-mll/glue.
# `dt` is read again by the duplicate-sentence cell further down the notebook.
dt = load_dataset("nyu-mll/glue", "mrpc", split="train")

In [33]:
from datasets import load_dataset
from collections import defaultdict, deque

def count_positive_groups(dataset):
    """Count the connected components formed by positive (label == 1) pairs.

    Every sentence that occurs in a row with ``label == 1`` is a node, and each
    such row is an undirected edge between its ``sentence1`` and ``sentence2``.
    Rows with any other label are skipped, so sentences that only occur in
    those rows are not counted at all.

    Args:
        dataset: Iterable of mappings exposing the keys ``'label'``,
            ``'sentence1'`` and ``'sentence2'``.

    Returns:
        int: Number of connected components in the positive-pair graph
        (0 when there are no positive rows).
    """
    neighbours = defaultdict(set)
    for record in dataset:
        if record['label'] != 1:
            continue
        left, right = record['sentence1'], record['sentence2']
        neighbours[left].add(right)
        neighbours[right].add(left)

    seen = set()
    component_count = 0

    # Both endpoints of every positive pair are keys of `neighbours`, so its
    # keys are exactly the set of nodes to visit.
    for start in list(neighbours):
        if start in seen:
            continue

        # An unseen node starts a new component; flood-fill everything
        # reachable from it so the component is counted exactly once.
        component_count += 1
        seen.add(start)
        frontier = [start]
        while frontier:
            node = frontier.pop()
            unseen = neighbours[node] - seen
            seen |= unseen
            frontier.extend(unseen)

    return component_count

# --- Process all splits ---
splits = ['train', 'validation', 'test']

# Table header followed by a rule.
print(f"{'Split':<12} | {'Positive Pairs':<15} | {'Unique Positive Groups'}")
print("-" * 55)

for split_name in splits:
    ds = load_dataset("nyu-mll/glue", "mrpc", split=split_name)

    # Raw number of label-1 rows, shown next to the group count for context.
    pos_pairs = sum(row['label'] == 1 for row in ds)

    # Connected components of the positive-pair graph (transitive groups).
    groups = count_positive_groups(ds)

    print(f"{split_name:<12} | {pos_pairs:<15} | {groups}")

Split        | Positive Pairs  | Unique Positive Groups
-------------------------------------------------------
train        | 2474            | 2338
validation   | 279             | 276
test         | 1147            | 1127


In [37]:
from datasets import load_dataset
from collections import defaultdict, deque

def analyze_group_completeness(dataset):
    """Measure how far each positive-pair group is from being a full clique.

    Rows with ``label == 1`` are treated as undirected edges between
    ``sentence1`` and ``sentence2``. For every connected component with more
    than one sentence, the number of distinct positive pairs present is
    compared with the ``n * (n - 1) / 2`` pairs a fully transitive group of
    ``n`` sentences would contain.

    Args:
        dataset: Iterable of mappings exposing the keys ``'label'``,
            ``'sentence1'`` and ``'sentence2'``.

    Returns:
        tuple: ``(total_missing_pairs, groups_data)``.
            ``total_missing_pairs`` is the sum of ``"missing"`` over all groups.
            ``groups_data`` is a list with one dict per component of size > 1,
            with keys ``"size"``, ``"actual"``, ``"required"`` and
            ``"missing"``. Groups are listed in the order their first sentence
            appears among the positive rows of ``dataset``, so the list (and
            any tie-break a caller applies to it) is the same on every run.
    """
    adj = defaultdict(set)
    for row in dataset:
        if row['label'] == 1:
            s1, s2 = row['sentence1'], row['sentence2']
            adj[s1].add(s2)
            adj[s2].add(s1)

    visited = set()
    total_missing_pairs = 0
    groups_data = []

    # Seed components from the adjacency dict's keys, not from a set of
    # strings: dict keys keep insertion order, whereas set iteration order
    # depends on per-process string hashing and would shuffle `groups_data`
    # between kernels.
    for sentence in list(adj):
        if sentence in visited:
            continue

        # 1. Identify the group (connected component) with a BFS.
        component = []
        queue = deque([sentence])
        visited.add(sentence)
        while queue:
            curr = queue.popleft()
            component.append(curr)
            for neighbor in adj[curr]:
                if neighbor not in visited:
                    visited.add(neighbor)
                    queue.append(neighbor)

        # 2. Calculate completeness for this group.
        n = len(component)
        if n <= 1:
            # A lone sentence (only possible via a self-pair) has no pairs to miss.
            continue

        # Total pairs needed for a full clique: nC2.
        required_pairs = n * (n - 1) // 2

        # Every distinct positive pair inside the component shows up once in
        # each endpoint's neighbour set, so the degree sum is twice the pair
        # count. A sentence paired with itself is excluded because it is not
        # a pair of two different sentences.
        actual_pairs = sum(len(adj[member] - {member}) for member in component) // 2

        missing = required_pairs - actual_pairs
        total_missing_pairs += missing
        groups_data.append({
            "size": n,
            "actual": actual_pairs,
            "required": required_pairs,
            "missing": missing
        })

    return total_missing_pairs, groups_data

# --- Execution ---
for split in ['train', 'validation', 'test']:
    # Bind each split to a loop-local name instead of `dt`. The notebook-level
    # `dt` holds the train split and is read by the duplicate-sentence cell
    # further down; rebinding it here would leave it pointing at the test
    # split once this loop finishes.
    split_ds = load_dataset("nyu-mll/glue", "mrpc", split=split)
    total_missing, details = analyze_group_completeness(split_ds)

    print(f"=== Results for {split.upper()} ===")
    print(f"Total Missing Pairs to achieve Group Transitivity: {total_missing}")

    # Show the most incomplete group as an example (first one wins on ties).
    if details:
        worst_group = max(details, key=lambda x: x['missing'])
        print(f"Largest Group Gap: Size {worst_group['size']} has {worst_group['actual']}/{worst_group['required']} pairs (Missing {worst_group['missing']})")
    print("-" * 40)

=== Results for TRAIN ===
Total Missing Pairs to achieve Group Transitivity: 149
Largest Group Gap: Size 5 has 4/10 pairs (Missing 6)
----------------------------------------
=== Results for VALIDATION ===
Total Missing Pairs to achieve Group Transitivity: 3
Largest Group Gap: Size 3 has 2/3 pairs (Missing 1)
----------------------------------------
=== Results for TEST ===
Total Missing Pairs to achieve Group Transitivity: 20
Largest Group Gap: Size 3 has 2/3 pairs (Missing 1)
----------------------------------------


In [22]:
from collections import Counter

# Step 1: pool both sentence columns so a sentence is counted wherever it appears.
all_sentences = dt['sentence1'] + dt['sentence2']

# Step 2: a sentence is a "duplicate" when it occurs at least twice in the pool.
counts = Counter(all_sentences)
duplicate_sentences = {s for s, count in counts.items() if count >= 2}

# Step 3: tally the rows in which either side of the pair is a duplicate sentence.
rows_with_duplicates = 0
for row in dt:
    pair = (row['sentence1'], row['sentence2'])
    if not duplicate_sentences.isdisjoint(pair):
        rows_with_duplicates += 1

print(f"Total rows in dataset: {len(dt)}")
print(f"Rows containing at least one duplicate sentence: {rows_with_duplicates}")
print(f"Percentage of dataset impacted: {(rows_with_duplicates / len(dt)) * 100:.2f}%")

Total rows in dataset: 3668
Rows containing at least one duplicate sentence: 527
Percentage of dataset impacted: 14.37%


In [32]:
from datasets import load_dataset

# Splits covered by the label-balance table
splits = ['train', 'validation', 'test']

print(f"{'Split':<12} | {'Positive (1)':<12} | {'Negative (0)':<12} | {'Total':<8} | {'% Positive':<10}")
print("-" * 70)

for split_name in splits:
    dataset = load_dataset("nyu-mll/glue", "mrpc", split=split_name)

    # Pull the label column out once, then tally each class from it.
    # label 1 = paraphrase (positive), label 0 = not paraphrase (negative)
    labels = [x['label'] for x in dataset]
    pos_count = labels.count(1)
    neg_count = labels.count(0)
    total = len(dataset)

    # The percentage is taken over rows labelled 0 or 1 only; fall back to 0
    # when there are none so the division cannot fail.
    labeled_total = pos_count + neg_count
    if labeled_total > 0:
        pos_percentage = pos_count / labeled_total * 100
    else:
        pos_percentage = 0

    print(f"{split_name:<12} | {pos_count:<12} | {neg_count:<12} | {total:<8} | {pos_percentage:>9.2f}%")

Split        | Positive (1) | Negative (0) | Total    | % Positive
----------------------------------------------------------------------
train        | 2474         | 1194         | 3668     |     67.45%
validation   | 279          | 129          | 408      |     68.38%
test         | 1147         | 578          | 1725     |     66.49%
