In [1]:
import json
import numpy as np
from pathlib import Path

In [2]:
# Project bootstrap from the local `setup` module. It is run in its own cell,
# before the `config` / `src.*` imports in the next cell. Presumably it makes
# those packages importable (e.g. by adjusting sys.path or the working
# directory) — NOTE(review): confirm against the `setup` module.
from setup import setup

setup()

In [3]:
from config import TRAIN_DATASET_JSON, TRAIN_DATASET_TXT
from src.augmentations.handler import AugmentationHandler
from src.dataclasses.datapoint import DataPoint

In [4]:
def _iter_augmented_texts(datapoint, aug_spec):
    """Yield the text of every augmentation described by one spec string.

    `aug_spec` is either a bare augmentation name, applied once without a
    `param`, or ``"<name>-<N>"``, applied once for each ``param`` in 1..N.
    Only the first "-" separates name and count, as before.

    Texts are yielded one at a time. If variant k fails, variants 1..k-1 have
    already been consumed, which matches the behaviour of the original inline
    loop.
    """
    if "-" not in aug_spec:
        yield AugmentationHandler.handle(augmentation=aug_spec, datap=datapoint).text
        return
    aug_name, param_str = aug_spec.split("-", 1)
    for param in range(1, int(param_str) + 1):
        yield AugmentationHandler.handle(
            augmentation=aug_name, datap=datapoint, param=param
        ).text


def _iter_item_texts(item):
    """Yield a raw dataset item's base text, then the texts of its augmentations.

    Errors while building the base DataPoint propagate, as in the original cell.
    An augmentation spec that raises (bad count, handler failure, ...) is
    reported and skipped. This is best effort, so one bad spec does not abort
    the build.
    """
    datapoint = DataPoint(
        label=item["label"],
        level=np.array(item["level"]),
    )
    yield datapoint.text

    for aug_spec in item.get("augmentations", []):
        try:
            yield from _iter_augmented_texts(datapoint, aug_spec)
        except Exception as e:
            # Report the full spec as written in the JSON (e.g. "name-3"),
            # not a partially parsed name, so the bad entry is easy to find.
            print(f"Skipping augmentation '{aug_spec}' due to error: {e}")


# JSON text is expected to be UTF-8 (RFC 8259). Be explicit so the build does
# not depend on the platform's locale default encoding.
with open(TRAIN_DATASET_JSON, encoding="utf-8") as f:
    raw_data = json.load(f)

text_cache = set()  # texts already emitted, for O(1) duplicate checks
lines = []  # unique texts, in first-seen order

# Single dedup site for base and augmented texts alike.
for item in raw_data:
    for text in _iter_item_texts(item):
        if text not in text_cache:
            text_cache.add(text)
            lines.append(text)

# One entry per line, written as UTF-8 to match how the JSON was read.
Path(TRAIN_DATASET_TXT).write_text("\n".join(lines), encoding="utf-8")
print(f"Construct {len(lines)} unique entries")

Construct 8 unique entries
