
Commit 80b5535

Update the package version and fix the dataset loader class to work with any Hugging Face dataset.
1 parent 2c006e9 commit 80b5535

File tree

2 files changed: +41 -19 lines changed

diffusionLM/utils/datasetANDtokenizer.py

Lines changed: 40 additions & 18 deletions
@@ -26,26 +26,19 @@ def prepare_dataset(
     tokenizer_name: str = "gpt2",
     max_length: int = 1024,
     cache_dir: Optional[str] = None,
-    num_proc: int = 4
+    num_proc: int = 4,
+    text_column: Optional[str] = None
 ) -> Tuple[PYTORCH_Dataset, Optional[PYTORCH_Dataset], AutoTokenizer]:
     """
     Prepare a Hugging Face dataset for training.
 
     Args:
-        dataset_name: Name of the dataset to load (e.g., "wikitext/wikitext-2-raw-v1").
-        tokenizer_name: Name of the tokenizer to use (e.g., "gpt2").
-        max_length: Maximum sequence length for tokenized inputs.
-        cache_dir: Directory to cache the dataset.
-        num_proc: Number of processes for tokenization.
-
-    Returns:
-        A tuple containing:
-        - Train dataset (PYTORCH_Dataset)
-        - Validation dataset (Optional[PYTORCH_Dataset])
-        - Tokenizer (AutoTokenizer)
-
-    Raises:
-        DatasetPreparationError: If there is an issue with loading or tokenizing the dataset.
+        dataset_name: Name of the dataset to load
+        tokenizer_name: Name of the tokenizer to use
+        max_length: Maximum sequence length
+        cache_dir: Directory to cache the dataset
+        num_proc: Number of processes for tokenization
+        text_column: Name of the text column (auto-detect if None)
     """
     try:
         # Load tokenizer
@@ -68,11 +61,37 @@ def prepare_dataset(
     except Exception as e:
         raise DatasetPreparationError(f"Failed to load dataset {dataset_name}: {str(e)}")
 
-    # Tokenize the dataset
+    # Auto-detect text column if not specified
+    if text_column is None:
+        # Common column names for text data
+        possible_columns = ['text', 'content', 'input_text', 'sentence', 'document']
+        available_columns = dataset['train'].column_names
+
+        # Find the first matching column
+        text_column = next(
+            (col for col in possible_columns if col in available_columns),
+            None
+        )
+
+        # If no standard column found, look for any string column
+        if text_column is None:
+            for col in available_columns:
+                if isinstance(dataset['train'][0][col], str):
+                    text_column = col
+                    break
+
+        if text_column is None:
+            raise DatasetPreparationError(
+                f"Could not detect text column. Available columns: {available_columns}"
+            )
+
+        logger.info(f"Auto-detected text column: {text_column}")
+
+    # Tokenize function with dynamic column handling
     def tokenize_function(examples):
         try:
             return tokenizer(
-                examples["text"],
+                examples[text_column],
                 padding="max_length",
                 truncation=True,
                 max_length=max_length,
@@ -82,11 +101,14 @@ def tokenize_function(examples):
             raise DatasetPreparationError(f"Tokenization failed: {str(e)}")
 
     logger.info("Tokenizing dataset")
+    # Remove only the text column used for tokenization
+    remove_columns = [text_column] if text_column in dataset['train'].column_names else None
+
     tokenized_dataset = dataset.map(
         tokenize_function,
         batched=True,
         num_proc=num_proc,
-        remove_columns=["text"]
+        remove_columns=remove_columns
     )
 
     # Convert to PyTorch datasets
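For reference, a minimal usage sketch of the updated loader. The import path follows the file tree above and the keyword arguments follow the new signature; the first dataset name is the example from the old docstring, while the second dataset and its "body" column are hypothetical.

from diffusionLM.utils.datasetANDtokenizer import prepare_dataset

# Auto-detection: wikitext exposes a "text" column, so text_column can stay None.
train_ds, val_ds, tokenizer = prepare_dataset(
    dataset_name="wikitext/wikitext-2-raw-v1",
    tokenizer_name="gpt2",
    max_length=1024,
)

# Non-standard schema: pass text_column explicitly to bypass detection.
train_ds, val_ds, tokenizer = prepare_dataset(
    dataset_name="my_org/my_corpus",  # hypothetical dataset with a "body" column
    tokenizer_name="gpt2",
    text_column="body",
)

When text_column is None, detection follows the order in the diff: the standard names ('text', 'content', 'input_text', 'sentence', 'document') first, then the first string-typed column in the train split, and a DatasetPreparationError if neither matches.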

setup.py

Lines changed: 1 addition & 1 deletion
@@ -5,7 +5,7 @@
 
 setup(
     name="diffusionLM",
-    version="0.1.7",
+    version="0.1.8",
     author="Dark Coder",
     author_email="codewithdark90@gmail.com",
     description="A diffusion-based language model implementation",
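Downstream code can confirm it picked up the bump; a small sketch, assuming the package is installed from this tree (e.g., via pip install .) so the version recorded by setup.py is visible to the standard library:

from importlib.metadata import version

# Distribution name comes from setup(name="diffusionLM") above.
assert version("diffusionLM") == "0.1.8"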
