In [4]:
# ──────────────────────────────────────────
# Cell 1: Install & configure NLTK/WordNet
# ──────────────────────────────────────────
!pip install nltk

import nltk
nltk.download('wordnet')
nltk.download('omw-1.4')
nltk.download('stopwords')

from nltk.corpus import wordnet, stopwords
import random
import pandas as pd

# load stop‐word set once
STOP_WORDS = set(stopwords.words('english'))



[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!
[nltk_data] Downloading package omw-1.4 to /root/nltk_data...
[nltk_data]   Package omw-1.4 is already up-to-date!
[nltk_data] Downloading package stopwords to /root/nltk_data...
[nltk_data]   Package stopwords is already up-to-date!


In [5]:
# ──────────────────────────────────────────
# Cell 2 (updated): Define safe EDA functions & augment train set
# ──────────────────────────────────────────
import copy
import random
import pandas as pd
from nltk.corpus import wordnet, stopwords

STOP_WORDS = set(stopwords.words('english'))

def get_synonyms(word):
    syns = set()
    for syn in wordnet.synsets(word):
        for lem in syn.lemmas():
            w = lem.name().replace('_',' ').lower()
            if w != word:
                syns.add(w)
    return list(syns)

def synonym_replacement(words, n):
    if not words:
        return words
    new_words = words.copy()
    candidates = [w for w in words if w not in STOP_WORDS]
    random.shuffle(candidates)
    num_replaced = 0
    for w in candidates:
        syns = get_synonyms(w)
        if syns:
            new_words = [random.choice(syns) if x == w else x for x in new_words]
            num_replaced += 1
        if num_replaced >= n:
            break
    return new_words

def random_insertion(words, n):
    if not words:
        return words
    new_words = words.copy()
    for _ in range(n):
        candidates = [w for w in new_words if w not in STOP_WORDS]
        if not candidates:
            break
        w = random.choice(candidates)
        syns = get_synonyms(w)
        if not syns:
            continue
        insert_word = random.choice(syns)
        idx = random.randint(0, len(new_words))
        new_words.insert(idx, insert_word)
    return new_words

def random_swap(words, n):
    if len(words) < 2:
        return words
    new_words = words.copy()
    for _ in range(n):
        i, j = random.sample(range(len(new_words)), 2)
        new_words[i], new_words[j] = new_words[j], new_words[i]
    return new_words

def random_deletion(words, p):
    if len(words) <= 1:
        return words
    new_words = [w for w in words if random.random() > p]
    return new_words if new_words else [random.choice(words)]

def eda(sentence, alpha=0.1, n_aug=4):
    """
    Perform EDA (Wei & Zou, 2019) on one sentence.
      alpha: percent of words to change for SR/RI/RS or deletion prob for RD
      n_aug: how many augmented samples to produce
    Returns: list of n_aug augmented sentences (strings).
    """
    words = sentence.split()
    l = len(words)
    if l == 0:
        return []
    # number of edits per operation
    n = max(1, int(alpha * l))
    augmented = []
    ops = []
    # only include ops that make sense
    if any(get_synonyms(w) for w in words):
        ops.append(lambda w: synonym_replacement(w, n))
    if len(words) >= 1:
        ops.append(lambda w: random_insertion(w, n))
    if len(words) >= 2:
        ops.append(lambda w: random_swap(w, n))
    ops.append(lambda w: random_deletion(w, alpha))

    for _ in range(n_aug):
        op = random.choice(ops)
        aug_words = op(words)
        augmented.append(" ".join(aug_words))
    return augmented

# ──────────────────────────────────────────
# Now load your training data, augment it, and save
# ──────────────────────────────────────────
train_df = pd.read_csv("data/training_split.csv")   # ["sentence","label"]
alpha    = 0.1    # ~10% of words changed
n_aug    = 4      # 3 new sentences per original

aug_sentences, aug_labels = [], []
for sent, lbl in zip(train_df["sentence"], train_df["label"]):
    seen = set([sent])        # track originals + any duplicates
    aug_count = 0
    while aug_count < n_aug:
        samples = eda(sent, alpha=alpha, n_aug=1)
        if not samples:
            aug_count += 1
            break
        aug = samples[0]
        if aug not in seen:
            aug_sentences.append(aug)
            aug_labels.append(lbl)
            seen.add(aug)
        aug_count += 1

aug_df = pd.DataFrame({"sentence": aug_sentences, "label": aug_labels})
full_train = pd.concat([train_df, aug_df], ignore_index=True)
full_train.to_csv("data/training_split_eda.csv", index=False)

print(f"Original samples: {len(train_df)}")
print(f"Augmented samples: {len(aug_df)}")
print(f"Total samples saved: {len(full_train)}")

Original samples: 91887
Augmented samples: 302148
Total samples saved: 394035
