In [1]:
import json
import random
from pathlib import Path

import numpy as np

In [2]:
# Project bootstrap. Kept in its own cell, ahead of the `config` / `src`
# imports in the next cell — presumably it makes those packages importable
# (e.g. sys.path / working-directory setup); TODO confirm against setup.py.
from setup import setup

setup()

In [3]:
from config import TRAIN_DATASET_JSON, TRAIN_DATASET_TXT
from src.augmentations.handler import AugmentationHandler
from src.dataclasses.datapoint import DataPoint

In [4]:
# Upper bound on augmented variants kept per original level; when an item
# yields more unique variants than this, a random subset is kept.
MAX_AUGS_PER_LEVEL = 5
# Seed for the notebook's random sampling.
RANDOM_SEED = 42

# Seed the global RNGs early so Restart & Run All is reproducible. This also
# covers any code that draws from the global generators (e.g. the
# augmentation handlers, if they are stochastic — unverified).
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)

In [5]:
# Build the training corpus: each original level text, followed by up to
# MAX_AUGS_PER_LEVEL unique augmented variants of it, one entry per line.
with open(TRAIN_DATASET_JSON, encoding="utf-8") as f:
    raw_data = json.load(f)

# Dedicated RNG seeded from the config cell, so the sampled subset of
# augmentations is identical on every run (including re-running only this
# cell). Previously RANDOM_SEED was defined but never used.
sample_rng = random.Random(RANDOM_SEED)


def _augmented_texts(datapoint, aug_specs, exclude):
    """Return the unique augmented texts of ``datapoint`` that are not in ``exclude``.

    Each spec is either ``"name"`` (handler called once, without ``param``)
    or ``"name-N"`` (handler called with ``param=1..N``). A spec that raises
    is reported and skipped; variants it produced before failing are kept.

    Args:
        datapoint: the DataPoint to augment.
        aug_specs: iterable of augmentation spec strings.
        exclude: container of texts already emitted (read only).

    Returns:
        List of new texts in first-seen order.
    """
    found = {}  # dict used as an insertion-ordered set
    for spec in aug_specs:
        try:
            if "-" in spec:
                name, param_str = spec.split("-", 1)
                for i in range(1, int(param_str) + 1):
                    text = AugmentationHandler.handle(
                        augmentation=name, datap=datapoint, param=i
                    ).text
                    if text not in exclude:
                        found[text] = None
            else:
                text = AugmentationHandler.handle(
                    augmentation=spec, datap=datapoint
                ).text
                if text not in exclude:
                    found[text] = None
        except Exception as e:
            # Best-effort: one bad spec must not abort the whole build.
            # Report the full spec (e.g. "name-3"), not just the stripped name.
            print(f"Skipping augmentation '{spec}' due to error: {e}")
    return list(found)


levels = []   # ordered output entries
seen = set()  # O(1) membership mirror of `levels` (list `in` made the build O(n^2))

for item in raw_data:
    datapoint = DataPoint(
        label=item["label"],
        level=np.array(item["level"]),
    )

    if datapoint.text not in seen:
        levels.append(datapoint.text)
        seen.add(datapoint.text)

    variants = _augmented_texts(datapoint, item.get("augmentations", []), seen)
    if len(variants) > MAX_AUGS_PER_LEVEL:
        variants = sample_rng.sample(variants, MAX_AUGS_PER_LEVEL)

    levels.extend(variants)
    seen.update(variants)

assert len(levels) == len(seen), "duplicate entries in output"

Path(TRAIN_DATASET_TXT).write_text("\n".join(levels), encoding="utf-8")
print(f"Construct {len(levels)} unique entries")

Construct 65 unique entries
