In [1]:
import sys
from pathlib import Path

# Add project root to Python path
current_dir = Path.cwd()
project_root = (
    current_dir.parent.parent if current_dir.name == "datasets" else current_dir
)
sys.path.insert(0, str(project_root))

print(f"Added to Python path: {project_root}")
print(f"Current working directory: {current_dir}")

# Verify the fix worked
if (project_root / "src").exists():
    print("✅ 'src' directory found - imports should work now")
else:
    print("❌ 'src' directory not found - check your project structure")

Added to Python path: /Users/shaneryan_1/Downloads/binary_align_zh
Current working directory: /Users/shaneryan_1/Downloads/binary_align_zh/src/datasets
✅ 'src' directory found - imports should work now


In [2]:
print(Path.cwd())

/Users/shaneryan_1/Downloads/binary_align_zh/src/datasets


In [3]:
import torch
from src.datasets.datasets import AlignmentDataset
from src.utils.helpers import load_data
from transformers import XLMRobertaTokenizer
from configs.pipeline_configs import PipelineConfig

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
data_path_dict = {
    "src_data": "../../data/english.txt",
    "tgt_data": "../../data/chinese.txt",
    "align_data": "../../data/alignment.txt",
}

In [5]:
src_data, tgt_data, align_data = load_data(paths=data_path_dict)

In [6]:
src_data, tgt_data, align_data = (
    src_data[:21999],
    tgt_data[:21999],
    align_data[:21999],
)

In [7]:
type(tgt_data[0])

str

In [8]:
src_data[0]

"the 66-year-old geoghan was recently sentenced to 10 years ' imprisonment for his sexual molestation of a 10-year-old boy in 1991 ."

In [9]:
p_config = PipelineConfig(
    output_dir=Path("output"), log_dir=Path("logs"), save_checkpoint=False
)

In [10]:
a = AlignmentDataset(
    tokenizer=XLMRobertaTokenizer.from_pretrained("xlm-roberta-base"),
    source_lines=src_data,
    target_lines=tgt_data,
    alignments=align_data,
    config=p_config,
    save_data=True,
)

[32m 13:50:00[0m | [32m[1mSUCCESS [0m |                 [36msrc.utils.logger_config[0m:[36msetup_logger[0m:[36m64[0m - [32m[1mLogger initialized. Logs will be saved to logs[0m
[32m 13:50:00[0m | [1mINFO    [0m |                 [36msrc.datasets.datasets[0m:[36m__post_init__[0m:[36m38[0m - [1mStarting AlignmentDataset step...[0m
[32m 13:50:00[0m | [32m[1mSUCCESS [0m |                 [36msrc.datasets.datasets[0m:[36m__post_init__[0m:[36m39[0m - [32m[1mAlignmentDataset initialised.[0m
[32m 13:50:00[0m | [1mINFO    [0m |                 [36msrc.datasets.datasets[0m:[36m__post_init__[0m:[36m41[0m - [1mPreparing dataset...[0m
100%|██████████| 21999/21999 [11:24<00:00, 32.14it/s]   
100%|██████████| 21999/21999 [21:13<00:00, 17.27it/s]  


In [11]:
a.data

[{'input_ids': [0,
   378,
   1456,
   36639,
   454,
   294,
   21290,
   268,
   70,
   378,
   1456,
   36639,
   454,
   294,
   21290,
   268,
   11251,
   9,
   46799,
   9,
   18345,
   20787,
   80225,
   509,
   78684,
   149357,
   71,
   47,
   209,
   5369,
   242,
   566,
   22876,
   191,
   674,
   100,
   1919,
   17688,
   185679,
   1363,
   111,
   10,
   12417,
   46799,
   9,
   18345,
   25299,
   23,
   12898,
   6,
   5,
   2,
   6,
   8375,
   100775,
   6,
   13288,
   6,
   43,
   6,
   14460,
   30431,
   12442,
   6,
   4673,
   6376,
   45690,
   11669,
   11669,
   12615,
   6,
   1278,
   6,
   160822,
   45690,
   6,
   1795,
   6,
   11196,
   6,
   13288,
   6,
   8553,
   26578,
   6,
   4,
   47504,
   6,
   1317,
   6,
   32230,
   7646,
   6,
   11196,
   6,
   470,
   6,
   41051,
   42510,
   6,
   30,
   2],
  'attention_mask': [1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,
   1,


In [None]:
data = torch.load("data.pt")