In [1]:
import os
import sys
import traceback

import torch
from torch.utils.data import DataLoader

# Add the parent directory to sys.path, if needed, so that src can be imported
# Absolute path of the directory one level above the current working directory.
# It is appended to the import search path only when it is not listed yet, so
# re-running this cell never adds a duplicate entry.
module_path = os.path.abspath(os.pardir)
if sys.path.count(module_path) == 0:
    sys.path.append(module_path)

from config import main_config
from src.modeling.dataset import ContractChunkDataset # Импортируем наш новый Dataset

In [2]:
# --------------------------------------------------------------------------------
# Settings (make sure they match the files that were actually saved)
# --------------------------------------------------------------------------------
TOKENIZER_NAME = "microsoft/codebert-base"  # name of the tokenizer that was used
# The '/' in the tokenizer name is swapped for '_' before it is embedded in file names.
TOKENIZER_NAME_FOR_PATH = "_".join(TOKENIZER_NAME.split("/"))

PROCESSED_DATA_DIR_CHUNKS = main_config.PROCESSED_DATA_DIR / "chunked_data"

# Common tail of every chunk file name, assembled from the MAX_TOTAL_TOKENS,
# MODEL_CHUNK_SIZE and CHUNK_OVERLAP settings plus the path-safe tokenizer name.
file_suffix = "_".join((
    f"t{main_config.MAX_TOTAL_TOKENS}",
    f"c{main_config.MODEL_CHUNK_SIZE}",
    f"o{main_config.CHUNK_OVERLAP}",
    f"{TOKENIZER_NAME_FOR_PATH}.pt",
))


def _chunk_file(stem):
    """Return the path of the ``<stem>_<file_suffix>`` file in the chunk directory."""
    return PROCESSED_DATA_DIR_CHUNKS / f"{stem}_{file_suffix}"


# Training split: chunks, per-chunk labels, and original indices.
PATH_TRAIN_CHUNKS = _chunk_file("train_chunks")
PATH_TRAIN_CHUNK_LABELS = _chunk_file("train_chunk_labels")
PATH_TRAIN_ORIGINAL_INDICES = _chunk_file("train_original_indices")

# Test split: same three files.
PATH_TEST_CHUNKS = _chunk_file("test_chunks")
PATH_TEST_CHUNK_LABELS = _chunk_file("test_chunk_labels")
PATH_TEST_ORIGINAL_INDICES = _chunk_file("test_original_indices")

In [3]:
# Create the Dataset instances.
# A split is loaded only when both its chunk file and its label file exist;
# otherwise the searched paths are reported and the dataset variable is set to None.
# The original-indices file is optional: None is passed when it is absent.

# --- train split ---
if not (PATH_TRAIN_CHUNKS.exists() and PATH_TRAIN_CHUNK_LABELS.exists()):
    print(f"ERROR: Training data files for chunks not found. Searched for:")
    print(f"  Chunks: {PATH_TRAIN_CHUNKS}")
    print(f"  Labels: {PATH_TRAIN_CHUNK_LABELS}")
    train_dataset = None  # or assert False
else:
    _train_indices_arg = None
    if PATH_TRAIN_ORIGINAL_INDICES.exists():
        _train_indices_arg = str(PATH_TRAIN_ORIGINAL_INDICES)
    train_dataset = ContractChunkDataset(
        chunk_list_path=str(PATH_TRAIN_CHUNKS),
        labels_path=str(PATH_TRAIN_CHUNK_LABELS),
        original_indices_path=_train_indices_arg,
    )
    print(f"\nTrain dataset created. Length: {len(train_dataset)}")

# --- test split ---
if not (PATH_TEST_CHUNKS.exists() and PATH_TEST_CHUNK_LABELS.exists()):
    print(f"ERROR: Test data files for chunks not found. Searched for:")
    print(f"  Chunks: {PATH_TEST_CHUNKS}")
    print(f"  Labels: {PATH_TEST_CHUNK_LABELS}")
    test_dataset = None  # or assert False
else:
    _test_indices_arg = None
    if PATH_TEST_ORIGINAL_INDICES.exists():
        _test_indices_arg = str(PATH_TEST_ORIGINAL_INDICES)
    test_dataset = ContractChunkDataset(
        chunk_list_path=str(PATH_TEST_CHUNKS),
        labels_path=str(PATH_TEST_CHUNK_LABELS),
        original_indices_path=_test_indices_arg,
    )
    print(f"Test dataset created. Length: {len(test_dataset)}")

Loading chunk list from: E:\Code\diplom\ml_service\data\processed\chunked_data\train_chunks_t4096_c512_o64_microsoft_codebert-base.pt
Loading labels from: E:\Code\diplom\ml_service\data\processed\chunked_data\train_chunk_labels_t4096_c512_o64_microsoft_codebert-base.pt
Loading original indices from: E:\Code\diplom\ml_service\data\processed\chunked_data\train_original_indices_t4096_c512_o64_microsoft_codebert-base.pt
Dataset loaded. Number of chunks: 154231. Labels shape: torch.Size([154231, 7])

Train dataset created. Length: 154231
ERROR: Test data files for chunks not found. Searched for:
  Chunks: E:\Code\diplom\ml_service\data\processed\chunked_data\test_chunks_t4096_c512_o64_microsoft_codebert-base.pt
  Labels: E:\Code\diplom\ml_service\data\processed\chunked_data\test_chunk_labels_t4096_c512_o64_microsoft_codebert-base.pt


In [4]:
# Create the DataLoaders
BATCH_SIZE = 16  # tunable

# Bind both loader names up front. Without this, a missing or empty split leaves
# the name undefined (NameError in later cells) or, after a partial re-run,
# silently pointing at a stale loader from an earlier execution.
train_dataloader = None
test_dataloader = None

# A loader is built only for a dataset that exists and is non-empty. This is
# what the bare truthiness check did implicitly, assuming the dataset class
# defines __len__ and no __bool__ - TODO confirm.
if train_dataset is not None and len(train_dataset) > 0:
    # num_workers=0 to start with, for Windows
    train_dataloader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=0)
    print(f"Train DataLoader created. Batches per epoch: ~{len(train_dataloader)}")
if test_dataset is not None and len(test_dataset) > 0:
    test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0)
    print(f"Test DataLoader created. Batches per epoch: ~{len(test_dataloader)}")

# Sanity check: pull one batch from the train loader and report tensor shapes.
if train_dataloader is not None:
    try:
        print("\nSample batch from train_dataloader:")
        sample_train_batch = next(iter(train_dataloader))
        print(f"  Input IDs batch shape: {sample_train_batch['input_ids'].shape}")
        print(f"  Attention Mask batch shape: {sample_train_batch['attention_mask'].shape}")
        print(f"  Labels batch shape: {sample_train_batch['labels'].shape}")
        # The batch is probed for 'original_index' because that key is not always present.
        if 'original_index' in sample_train_batch:
            print(f"  Original Index batch shape: {sample_train_batch['original_index'].shape}")
    except Exception as e:
        # Diagnostic cell: report the failure with a full traceback instead of
        # aborting the notebook run.
        print(f"Error fetching a batch from train_dataloader: {e}")
        traceback.print_exc()
elif train_dataset is None:
    print("Train dataset not created, cannot fetch a sample batch.")
else:
    print("Train dataset is empty, cannot fetch a sample batch.")

Train DataLoader created. Batches per epoch: ~9640

Sample batch from train_dataloader:
  Input IDs batch shape: torch.Size([16, 512])
  Attention Mask batch shape: torch.Size([16, 512])
  Labels batch shape: torch.Size([16, 7])
  Original Index batch shape: torch.Size([16])
