In [2]:
import os
import pandas as pd
import numpy as np
import nltk
from nltk.corpus import wordnet

# Project root. Set the INVOICE_CLS_BASE_PATH environment variable to run the
# notebook from a different checkout; when it is unset the original hard-coded
# location is used unchanged.
BASE_PATH = os.environ.get(
    "INVOICE_CLS_BASE_PATH",
    r"C:\Users\Impana\Downloads\invoice-classification\\",
)
d3_path = os.path.join(BASE_PATH, "data", "sroie", "D3_sroie_train.csv")

# The WordNet corpus backs the synonym lookups done later in the notebook.
nltk.download('wordnet')

d3 = pd.read_csv(d3_path)
# The augmentation cell reads exactly these two columns; fail here with a clear
# message instead of with a KeyError halfway through the augmentation loop.
missing_cols = {"text", "category"} - set(d3.columns)
assert not missing_cols, f"{d3_path} is missing columns: {sorted(missing_cols)}"
print(d3.shape, d3['category'].value_counts())

(762, 2) category
OTHER          498
RETAIL          78
FOOD            72
HARDWARE        50
STATIONERY      48
SERVICES         6
ELECTRONICS      6
HOTEL            4
Name: count, dtype: int64


[nltk_data] Downloading package wordnet to
[nltk_data]     C:\Users\Impana\AppData\Roaming\nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


In [3]:
def get_synonym(word):
    """Return a randomly chosen WordNet synonym of ``word``.

    Every lemma of every synset found for ``word`` is a candidate, with
    underscores in lemma names shown as spaces. Candidates that equal
    ``word`` case-insensitively are discarded. Duplicates are kept, so a
    lemma that appears in several synsets is proportionally more likely to
    be picked.

    The draw uses NumPy's global random state (``np.random``); seed it
    beforehand for reproducible results.

    Returns:
        One candidate synonym, or ``word`` itself when WordNet has no synsets
        for it or no candidate differs from ``word``.
    """
    matches = wordnet.synsets(word)
    if not matches:
        return word

    candidates = []
    for synset in matches:
        for lemma in synset.lemmas():
            surface = lemma.name().replace('_', ' ')
            # Skip lemmas that are just the input word in another casing.
            if surface.lower() != word.lower():
                candidates.append(surface)

    if candidates:
        return np.random.choice(candidates)
    return word

def augment_sentence_wn(text, frac=0.1):
    """Replace a random subset of the words in ``text`` with WordNet synonyms.

    ``text`` is split on whitespace and ``int(len(words) * frac)`` positions
    are chosen without replacement. The count is never fewer than one and
    never more than the number of words. Each chosen word is passed through
    ``get_synonym`` and the words are re-joined with single spaces.

    Args:
        text: The sentence to augment. Non-string values are stringified,
            except missing values (``None``, NaN, ``pd.NA``), which are
            returned unchanged.
        frac: Fraction of words to replace.

    Returns:
        The augmented sentence, or ``text`` unchanged when it is missing or
        contains no words.

    Randomness comes from NumPy's global random state (``np.random``).
    """
    # A missing value must stay missing: str(None) / str(nan) would otherwise
    # turn into the literal tokens "None" / "nan" and be sent to WordNet.
    if not isinstance(text, str) and pd.api.types.is_scalar(text) and pd.isna(text):
        return text

    words = str(text).split()
    if not words:
        return text

    # Clamp to the word count: with frac > 1 the without-replacement draw
    # below would otherwise raise ValueError.
    n_to_aug = min(len(words), max(1, int(len(words) * frac)))
    idxs = np.random.choice(len(words), n_to_aug, replace=False)
    for i in idxs:
        words[i] = get_synonym(words[i])
    return ' '.join(words)


In [4]:
# Grow the training set to `target_size` rows by appending synonym-augmented
# copies of rows sampled (with replacement) from the original data.
SEED = 42
target_size = 10_000
current = len(d3)
needed = target_size - current
print("Current:", current, "Needed:", needed)

# The augmentation helpers draw from NumPy's global RNG. Seeding it here makes
# the saved CSV reproducible; previously only the row sampling below was seeded.
np.random.seed(SEED)

aug_rows = []
if needed > 0:
    base_sample = d3.sample(needed, replace=True, random_state=SEED)
    # Zip the two columns instead of iterrows(), which builds a Series per row.
    aug_rows = [
        {"text": augment_sentence_wn(text), "category": category}
        for text, category in zip(base_sample["text"], base_sample["category"])
    ]

aug_df = pd.DataFrame(aug_rows)
d3_wn10k = pd.concat([d3, aug_df], ignore_index=True)
# Invariant: the result has exactly target_size rows, or the original rows when
# the data already met the target.
assert len(d3_wn10k) == max(current, target_size)
print("Final size:", len(d3_wn10k), d3_wn10k['category'].value_counts())

out_path = os.path.join(BASE_PATH, "data", "sroie", "D3_sroie_WNtrain10k.csv")
d3_wn10k.to_csv(out_path, index=False)
print("Saved:", out_path)


Current: 762 Needed: 9238
Final size: 10000 category
OTHER          6497
RETAIL         1033
FOOD            936
STATIONERY      670
HARDWARE        661
ELECTRONICS      84
SERVICES         76
HOTEL            43
Name: count, dtype: int64
Saved: C:\Users\Impana\Downloads\invoice-classification\\data\sroie\D3_sroie_WNtrain10k.csv
