In [4]:
import os
import torch
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer, BertForSequenceClassification
from sklearn.model_selection import train_test_split


In [5]:
class THUCNewsDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_length):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_length = max_length

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        encoding = self.tokenizer(
            text,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors='pt'
        )
        item = {key: val.squeeze(0) for key, val in encoding.items()}
        item['labels'] = torch.tensor(label)
        return item


In [6]:
def load_data(data_dir, tokenizer, max_length, test_size=0.2):
    texts = []
    labels = []
    label_map = {}
    label_count = 0
    
    for label in os.listdir(data_dir):
        label_path = os.path.join(data_dir, label)
        if os.path.isdir(label_path):
            if label not in label_map:
                label_map[label] = label_count
                label_count += 1
            for file_name in os.listdir(label_path):
                file_path = os.path.join(label_path, file_name)
                with open(file_path, 'r', encoding='utf-8') as f:
                    text = f.read().strip()
                    texts.append(text)
                    labels.append(label_map[label])
    
    train_texts, test_texts, train_labels, test_labels = train_test_split(
        texts, labels, test_size=test_size, random_state=42
    )

    train_dataset = THUCNewsDataset(train_texts, train_labels, tokenizer, max_length)
    test_dataset = THUCNewsDataset(test_texts, test_labels, tokenizer, max_length)
    
    return train_dataset, test_dataset

tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
max_length = 128
data_dir = 'path/to/THUCNews'
train_dataset, test_dataset = load_data(data_dir, tokenizer, max_length)


OSError: Can't load tokenizer for 'bert-base-chinese'. If you were trying to load it from 'https://huggingface.co/models', make sure you don't have a local directory with the same name. Otherwise, make sure 'bert-base-chinese' is the correct path to a directory containing all relevant files for a BertTokenizer tokenizer.

In [None]:
train_dataloader = DataLoader(train_dataset, batch_size=8, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=8)
