In [1]:
# Load the SNLI dataset with the `datasets` library
# (the original comment was cut off at "from h" -- presumably "Hugging Face").
from datasets import load_dataset
dataset = load_dataset("snli")
# Keep a handle on each split for convenience
train_data = dataset['train']
validation_data = dataset['validation']
test_data = dataset['test']

In [2]:
train_data[0]

{'premise': 'A person on a horse jumps over a broken down airplane.',
 'hypothesis': 'A person is training his horse for a competition.',
 'label': 1}

In [148]:
import re
import string
# Punctuation marks recognised by ends_with_punctuation / contains_punctuation.
# Other marks (commas, colons, quotes, ...) are not counted by those checks.
PUNCT = {'.', '!', '?'}
# Maps a two-word expansion (key) to its contracted form (value).
# Keys are searched for when looking for text that *could* be contracted;
# values are searched for when looking for contractions already present.
# Every contracted form here is spelled with the straight apostrophe (').
common_contractions = {
    "do not": "don't",
    "is not": "isn't",
    "are not": "aren't",
    "it is": "it's",
    "that is": "that's",
    "we are": "we're",
    "you are": "you're",
    "I am": "I'm",
    "I will": "I'll",
    "I would": "I'd",
    "they are": "they're",
    "will not": "won't",
    "can not": "can't",
    "there is": "there's"
}

def encased_with_apostrophes(text):
    """Return True when ``text`` is wrapped in a pair of straight double quotes.

    Despite the name, the check is for the ``"`` character (an artifact in
    SNLI), not for apostrophes.

    A string consisting of a single ``"`` is *not* considered encased: the one
    character would otherwise serve as both the opening and the closing quote,
    and callers that strip the quotes with ``text[1:-1]`` would end up with an
    empty string.
    """
    return len(text) >= 2 and text.startswith('"') and text.endswith('"')

def starts_with_uppercase_word(text):
    """Tell whether the first non-whitespace character of ``text`` is uppercase.

    Empty or whitespace-only input yields False.
    """
    trimmed = text.lstrip()
    return bool(trimmed) and trimmed[0].isupper()

def ends_with_punctuation(text):
    """Tell whether the last non-whitespace character of ``text`` is in PUNCT.

    Empty or whitespace-only input yields False.
    """
    trimmed = text.rstrip()
    if not trimmed:
        return False
    return trimmed[-1] in PUNCT

def contains_punctuation(text):
    """Tell whether any character of ``text`` belongs to PUNCT.

    Only the marks listed in PUNCT count; the wider ``string.punctuation``
    set is intentionally not consulted here.
    """
    # The text shares at least one character with PUNCT exactly when the
    # two are not disjoint.
    return not PUNCT.isdisjoint(text)

def whitespace_encoding(text):
    """Collect the code points of every distinct whitespace character in ``text``.

    Working with code points separates look-alike characters such as
    U+0020 (regular space) and U+00A0 (no-break space).
    """
    return {ord(char) for char in text if char.isspace()}

def apostrophe_encoding(text):
    """Return the set of apostrophe-like characters that occur in ``text``.

    Three variants are recognised: the straight apostrophe, the right single
    quotation mark and the backtick. Extend the candidate set below to
    track further variants.
    """
    candidates = {"'", "’", "`"}
    return set(text) & candidates

def extract_number_patterns(text):
    """List every digit run in ``text`` together with its adjacent formatting.

    Each run matched by ``\\d+`` is widened in both directions over any
    directly neighbouring punctuation (``string.punctuation``) or whitespace,
    and the widened slice is whitespace-stripped before being recorded.
    Because widening stops at other digits, a number such as ``1,000`` is
    reported as two entries: ``"1,"`` and ``",000"``.

    Returns:
        list[str]: one entry per digit run, in order of appearance.
    """
    def _is_padding(char):
        return char in string.punctuation or char.isspace()

    size = len(text)
    found = []
    for digits in re.finditer(r"\d+", text):
        lo, hi = digits.start(), digits.end()
        # Grow the window leftwards, then rightwards, over padding characters.
        while lo > 0 and _is_padding(text[lo - 1]):
            lo -= 1
        while hi < size and _is_padding(text[hi]):
            hi += 1
        found.append(text[lo:hi].strip())
    return found

def compare_number_formats(patterns1, patterns2):
    """Tell whether two lists of number patterns agree entry by entry.

    The lists must have the same length, and each aligned pair must carry the
    same digit sequence as well as the same non-digit formatting.
    """
    if len(patterns1) != len(patterns2):
        return False

    def _split(pattern):
        # (digits only, everything except digits)
        return re.sub(r"\D", "", pattern), re.sub(r"\d", "", pattern)

    return all(_split(left) == _split(right) for left, right in zip(patterns1, patterns2))

def contains_newline(text):
    """Tell whether ``text`` holds at least one line-feed character."""
    return text.count("\n") > 0

def contains_contractions(text):
    """Tell whether ``text`` contains one of the known contracted forms.

    The forms come from the values of ``common_contractions`` and are matched
    as whole words, ignoring case. Those values are spelled with the straight
    apostrophe, so a contraction written with a different apostrophe
    character is not detected.
    """
    alternatives = '|'.join(re.escape(form) for form in common_contractions.values())
    matcher = re.compile(r'\b(?:' + alternatives + r')\b', flags=re.IGNORECASE)
    return matcher.search(text) is not None

def can_form_contractions(text):
    """Tell whether ``text`` contains an expansion that could be contracted.

    Each key of ``common_contractions`` is searched for as a whole-word,
    case-insensitive phrase whose words may be separated by any run of
    whitespace. One hit is enough.
    """
    return any(
        re.search(r'\b' + r'\s+'.join(expansion.split()) + r'\b', text, flags=re.IGNORECASE)
        for expansion in common_contractions
    )


In [149]:
def compare_texts(text1, text2):
    """Score how closely two texts agree on surface style.

    Eight checks are evaluated: enclosing quotes, leading capital, trailing
    punctuation, presence of punctuation, whitespace code points, apostrophe
    characters, number formatting, and presence of contractions.

    Returns:
        float: the fraction of checks on which the texts agree (0.0 to 1.0).
    """
    # Features that are compared simply by evaluating them on both texts.
    paired_features = (
        encased_with_apostrophes,
        starts_with_uppercase_word,
        ends_with_punctuation,
        contains_punctuation,
        whitespace_encoding,
        apostrophe_encoding,
    )
    agreements = [feature(text1) == feature(text2) for feature in paired_features]
    agreements.append(
        compare_number_formats(extract_number_patterns(text1), extract_number_patterns(text2))
    )
    agreements.append(contains_contractions(text1) == contains_contractions(text2))
    return sum(agreements) / len(agreements)

def make_texts_similar(text1, text2):
    """Nudge ``text2`` towards the surface style of ``text1``.

    Deterministic steps (only ``text2`` is edited): enclosing double quotes,
    case of the first non-whitespace character, and trailing punctuation are
    made to match ``text1``.

    Random steps (applied to BOTH texts, each with probability 0.5): regular
    spaces are replaced with no-break spaces; and, when both texts contain an
    apostrophe-like character, straight apostrophes are swapped for the right
    single quotation mark (or the reverse when no straight one is present).

    Whitespace and apostrophe encodings are not aligned beforehand, and
    number formats / contractions are left alone, so a perfect similarity
    score is not guaranteed.

    Returns:
        tuple[str, str]: the (possibly re-encoded) ``text1`` and the adjusted ``text2``.
    """
    # --- Enclosing quotes -------------------------------------------------
    first_quoted = encased_with_apostrophes(text1)
    second_quoted = encased_with_apostrophes(text2)
    if first_quoted and not second_quoted:
        text2 = '"' + text2 + '"'
    elif second_quoted and not first_quoted:
        text2 = text2[1:-1]

    # --- Case of the first non-whitespace character -------------------------
    first_upper = starts_with_uppercase_word(text1)
    if first_upper != starts_with_uppercase_word(text2):
        body = text2.lstrip()
        if body:
            indent = text2[:len(text2) - len(body)]
            head = body[0].upper() if first_upper else body[0].lower()
            text2 = indent + head + body[1:]

    # --- Trailing punctuation ---------------------------------------------
    first_ends = ends_with_punctuation(text1)
    if first_ends != ends_with_punctuation(text2):
        if first_ends:
            # Borrow the exact closing mark used by text1.
            text2 = text2.rstrip() + text1.rstrip()[-1]
        else:
            # Drop trailing whitespace, then every trailing PUNCT character.
            text2 = text2.rstrip().rstrip("".join(PUNCT))

    # --- Random whitespace re-encoding (both texts) -------------------------
    # Replacing is a no-op when a text has no regular spaces.
    if random.random() < 0.5:
        text1 = text1.replace(" ", "\u00A0")
        text2 = text2.replace(" ", "\u00A0")

    # --- Random apostrophe re-encoding (both texts) -------------------------
    both_have_apostrophes = bool(apostrophe_encoding(text1)) and bool(apostrophe_encoding(text2))
    if random.random() < 0.5 and both_have_apostrophes:
        combined = text1 + text2
        if "'" in combined:
            swap = ("'", "’")
        elif "’" in combined:
            swap = ("’", "'")
        else:
            swap = None  # only backticks present: nothing to swap
        if swap is not None:
            text1 = text1.replace(*swap)
            text2 = text2.replace(*swap)

    return text1, text2
    

In [150]:
import random

def flip_quotes(t):
    """Toggle enclosing double quotes on ``t``.

    Quoted text loses its first and last character; unquoted text gets
    wrapped in ``"``. The second element of the result is always True.
    """
    toggled = t[1:-1] if encased_with_apostrophes(t) else '"' + t + '"'
    return toggled, True
    
def flip_capitalization(t):
    """Swap the case of the first non-whitespace character of ``t``.

    Returns:
        tuple[str, bool]: the resulting text and whether it differs from the
        input. Blank input, or input whose first visible character is not
        alphabetic, is returned unchanged with False.
    """
    body = t.lstrip()
    if not body or not body[0].isalpha():
        return t, False
    lead = body[0]
    swapped = lead.lower() if lead.isupper() else lead.upper()
    result = t[:len(t) - len(body)] + swapped + body[1:]
    return result, result != t

def toggle_end_punctuation(t):
    """Add or remove trailing punctuation on ``t``.

    Text that does not end with a PUNCT mark gets a ``.`` appended. Otherwise
    trailing whitespace and every trailing PUNCT character are removed.

    Returns:
        tuple[str, bool]: the resulting text and whether it differs from the input.
    """
    if not ends_with_punctuation(t):
        return t + ".", True
    trimmed = t.rstrip().rstrip("".join(PUNCT))
    return trimmed, trimmed != t

# def toggle_punctuation_presence(t):
#     if contains_punctuation(t):
#         original = t
#         t = "".join(ch for ch in t if ch not in PUNCT).rstrip()
#         changed = (t != original)
#         return t, changed
#     else:
#         return t, False

def toggle_whitespace_encoding(t):
    """Re-encode regular spaces in ``t`` as no-break spaces (U+00A0).

    The conversion is one-directional: text without a regular space is
    returned untouched.

    Returns:
        tuple[str, bool]: the resulting text and whether it differs from the input.
    """
    if " " not in t:
        return t, False
    converted = t.replace(" ", "\u00A0")
    return converted, converted != t

def toggle_apostrophe_encoding(t):
    """Swap straight apostrophes and right single quotation marks in ``t``.

    Every ``'`` becomes ``’`` and every ``’`` becomes ``'``; other characters
    (including backticks) are left untouched.

    The swap is performed in a single ``str.translate`` pass. The previous
    implementation routed one of the characters through a ``"\\uFFFF"``
    placeholder, which corrupted any text that already contained U+FFFF.

    Returns:
        tuple[str, bool]: the resulting text and whether it differs from the input.
    """
    swap_table = str.maketrans("'’", "’'")
    swapped = t.translate(swap_table)
    return swapped, swapped != t

def toggle_number_format(t):
    """Strip the commas from the first comma-bearing number pattern in ``t``.

    Patterns come from ``extract_number_patterns``. Only the first pattern
    that contains a comma is rewritten, and only its first occurrence in the
    text is replaced.

    Returns:
        tuple[str, bool]: the resulting text and whether a replacement was made.
    """
    for pattern in extract_number_patterns(t):
        if "," not in pattern:
            continue
        without_commas = pattern.replace(",", "")
        # Each pattern is a stripped slice of the text, so it is always found.
        return t.replace(pattern, without_commas, 1), True
    return t, False

def maybe_add_contraction(text1, text2):
    """Contract one expansion in ``text2``, provided neither text has a contraction yet.

    Nothing is changed unless all of the following hold:
      - ``text1`` contains an expansion that could be contracted,
      - ``text1`` contains no contraction,
      - ``text2`` contains no contraction.

    The known expansions are tried in random order; the first one found in
    ``text2`` is replaced by its contracted form. When the matched text
    starts with an uppercase letter, the contraction is capitalised as well.

    Returns:
        tuple[str, bool]: the resulting ``text2`` and whether it differs from the input.
    """
    eligible = (
        can_form_contractions(text1)
        and not contains_contractions(text1)
        and not contains_contractions(text2)
    )
    if not eligible:
        return text2, False

    candidates = list(common_contractions)
    random.shuffle(candidates)

    for expansion in candidates:
        regex = r'\b' + r'\s+'.join(expansion.split()) + r'\b'
        hit = re.search(regex, text2, flags=re.IGNORECASE)
        if hit is None:
            continue
        short_form = common_contractions[expansion]
        if hit.group(0)[0].isupper():
            short_form = short_form[0].upper() + short_form[1:]
        rewritten = text2[:hit.start()] + short_form + text2[hit.end():]
        return rewritten, rewritten != text2

    return text2, False

def make_texts_distinct(text1, text2, max_attempts=20):
    """Return a variant of ``text2`` that scores lower against ``text1`` than ``text2`` does.

    Up to ``max_attempts`` randomly chosen transformations are tried, each on
    the original ``text2``. The first one that both changes the text and
    lowers ``compare_texts(text1, ...)`` below the initial score is returned.

    Args:
        text1: reference text; it is never modified.
        text2: text to perturb.
        max_attempts: upper bound on the number of random transformations
            tried (defaults to 20, the previously hard-coded limit).

    Returns:
        str: the perturbed text, or ``text2`` unchanged when no attempted
        transformation lowered the similarity.
    """
    initial_similarity = compare_texts(text1, text2)

    transformations = [
        flip_quotes,
        flip_capitalization,
        toggle_end_punctuation,
        toggle_whitespace_encoding,
        toggle_apostrophe_encoding,
        toggle_number_format,
        lambda t: maybe_add_contraction(text1, t),
    ]

    # A candidate is only ever accepted when it scores strictly below the
    # initial similarity, so the first accepted candidate is the answer.
    # Until then the working text is still ``text2``, so it does not need
    # to be re-scored on every iteration.
    for _ in range(max_attempts):
        transform = random.choice(transformations)
        candidate, changed = transform(text2)
        if changed and compare_texts(text1, candidate) < initial_similarity:
            return candidate

    # No attempted transformation reduced the similarity.
    return text2

In [151]:
# The two strings differ only in the space character (U+0020 vs U+00A0),
# so only the whitespace-encoding check disagrees: 7 of 8 checks match.
text_a = "Hello, world!"
text_b = "Hello,\u00a0world!"
score = compare_texts(text_a, text_b)
print("Similarity score:", score)

Similarity score: 0.875


In [152]:
# These texts differ in leading capital, trailing punctuation, whitespace
# (text_a has a newline), apostrophe character and number formatting.
# NOTE(review): the contraction check also disagrees, apparently because
# contains_contractions only knows straight-apostrophe forms and so does not
# see the curly-apostrophe contraction in text_a -- confirm this is intended.
text_a = "Hello, world!\nThe price is 1,000 dollars. It’s great."
text_b = "hello world. The price is 1000 dollars It's great"
score = compare_texts(text_a, text_b)
print("Similarity score:", score)

Similarity score: 0.25


In [155]:
text_a = "The two farmers are working on a piece of John Deere equipment."
text_b = "Men are working on John Deere equipment"
score = compare_texts(text_a, text_b)
print("Similarity score:", score)
# make_texts_similar may randomly re-encode whitespace/apostrophes in BOTH
# texts, so keep the returned first text as well and score the returned pair
# together. Scoring against the original text_a would report a spurious
# mismatch whenever the random re-encoding fired.
text_a_synth, text_b_synth = make_texts_similar(text_a, text_b)
print("Synthesized text:", text_b_synth)
score = compare_texts(text_a_synth, text_b_synth)
print("Similarity score:", score)

Similarity score: 0.75
Synthesized text: Men are working on John Deere equipment.
Similarity score: 1.0


In [161]:
# Start from two identical texts (similarity 1.0) and let make_texts_distinct
# perturb the second one. The transformation is picked at random, so the
# synthesized text and the final score can differ between runs.
text_a = "There is a party"
text_b = "There is a party"
score = compare_texts(text_a, text_b)
print("Similarity score:", score)
text_b_synth = make_texts_distinct(text_a, text_b)
print("Synthesized text:", text_b_synth)
score = compare_texts(text_a, text_b_synth)
print("Similarity score:", score)

Similarity score: 1.0
Synthesized text: "There is a party"
Similarity score: 0.75


## Data Augmentation for SNLI

In [163]:
import pandas as pd
import os

# Seed the RNG so the similar/distinct coin flips and the random
# transformations -- and therefore the exported CSV files -- are reproducible.
SEED = 42
random.seed(SEED)

# Labels kept in the output: 0 entailment, 1 neutral, 2 contradiction.
VALID_LABELS = {0, 1, 2}

dataset = load_dataset("snli")
os.makedirs("snli_modified", exist_ok=True)

for split in dataset.keys():
    data = dataset[split]

    rows = []
    for example in data:
        premise = example["premise"]
        hypothesis = example["hypothesis"]
        label = example["label"]

        # Skip examples whose label is not one of the three valid classes
        if label not in VALID_LABELS:
            continue

        # Flip a coin for similar/distinct
        want_similar = random.choice([True, False])

        # Check current similarity
        initial_sim = compare_texts(premise, hypothesis)
        currently_similar = (initial_sim == 1.0)

        if want_similar and not currently_similar:
            # Make them similar (both texts may be re-encoded, so keep both)
            premise, hypothesis = make_texts_similar(premise, hypothesis)
        elif not want_similar and currently_similar:
            # Make them distinct (only the hypothesis is perturbed)
            hypothesis = make_texts_distinct(premise, hypothesis)

        # The style flag is derived from the similarity actually achieved,
        # not from the coin flip, because neither transformation is
        # guaranteed to reach its goal.
        final_sim = compare_texts(premise, hypothesis)
        style = 1 if final_sim == 1.0 else 0 # 1 for similar, 0 for distinct

        rows.append({
            "premise": premise,
            "hypothesis": hypothesis,
            "label": label, # 0 entailment, 1 neutral, 2 contradiction
            "style": style # 0 distinct, 1 similar
        })

    df = pd.DataFrame(rows, columns=["premise", "hypothesis", "label", "style"])
    output_file = f"snli_modified/{split}_modified.csv"
    df.to_csv(output_file, index=False, encoding='utf-8')