@@ -26,26 +26,19 @@ def prepare_dataset(
     tokenizer_name: str = "gpt2",
     max_length: int = 1024,
     cache_dir: Optional[str] = None,
-    num_proc: int = 4
+    num_proc: int = 4,
+    text_column: Optional[str] = None
 ) -> Tuple[PYTORCH_Dataset, Optional[PYTORCH_Dataset], AutoTokenizer]:
     """
     Prepare a Hugging Face dataset for training.
 
     Args:
-        dataset_name: Name of the dataset to load (e.g., "wikitext/wikitext-2-raw-v1").
-        tokenizer_name: Name of the tokenizer to use (e.g., "gpt2").
-        max_length: Maximum sequence length for tokenized inputs.
-        cache_dir: Directory to cache the dataset.
-        num_proc: Number of processes for tokenization.
-
-    Returns:
-        A tuple containing:
-        - Train dataset (PYTORCH_Dataset)
-        - Validation dataset (Optional[PYTORCH_Dataset])
-        - Tokenizer (AutoTokenizer)
-
-    Raises:
-        DatasetPreparationError: If there is an issue with loading or tokenizing the dataset.
+        dataset_name: Name of the dataset to load
+        tokenizer_name: Name of the tokenizer to use
+        max_length: Maximum sequence length
+        cache_dir: Directory to cache the dataset
+        num_proc: Number of processes for tokenization
+        text_column: Name of the text column (auto-detect if None)
     """
     try:
         # Load tokenizer
@@ -68,11 +61,37 @@ def prepare_dataset(
     except Exception as e:
         raise DatasetPreparationError(f"Failed to load dataset {dataset_name}: {str(e)}")
 
-    # Tokenize the dataset
+    # Auto-detect text column if not specified
+    if text_column is None:
+        # Common column names for text data
+        possible_columns = ['text', 'content', 'input_text', 'sentence', 'document']
+        available_columns = dataset['train'].column_names
+
+        # Find the first matching column
+        text_column = next(
+            (col for col in possible_columns if col in available_columns),
+            None
+        )
+
+        # If no standard column found, look for any string column
+        if text_column is None:
+            for col in available_columns:
+                if isinstance(dataset['train'][0][col], str):
+                    text_column = col
+                    break
+
+        if text_column is None:
+            raise DatasetPreparationError(
+                f"Could not detect text column. Available columns: {available_columns}"
+            )
+
+        logger.info(f"Auto-detected text column: {text_column}")
+
+    # Tokenize function with dynamic column handling
     def tokenize_function(examples):
         try:
             return tokenizer(
-                examples["text"],
+                examples[text_column],
                 padding="max_length",
                 truncation=True,
                 max_length=max_length,
@@ -82,11 +101,14 @@ def tokenize_function(examples):
             raise DatasetPreparationError(f"Tokenization failed: {str(e)}")
 
     logger.info("Tokenizing dataset")
+    # Remove only the text column used for tokenization
+    remove_columns = [text_column] if text_column in dataset['train'].column_names else None
+
     tokenized_dataset = dataset.map(
         tokenize_function,
         batched=True,
         num_proc=num_proc,
-        remove_columns=["text"]
+        remove_columns=remove_columns
     )
 
     # Convert to PyTorch datasets
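For reference, a minimal usage sketch of the updated signature. The dataset name below is illustrative (any Hugging Face dataset with a string column should work), and passing text_column=None, the default, exercises the new auto-detection path:

    # Explicit column name: auto-detection is skipped entirely
    train_ds, val_ds, tokenizer = prepare_dataset(
        dataset_name="ag_news",    # illustrative dataset; has a "text" column
        tokenizer_name="gpt2",
        max_length=512,
        text_column="text",
    )

    # Default text_column=None: the first match among
    # ['text', 'content', 'input_text', 'sentence', 'document'] is used,
    # falling back to the first string-typed column in the train split
    train_ds, val_ds, tokenizer = prepare_dataset(dataset_name="ag_news")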