In [None]:
import torch
from torch.utils.data.dataset import Dataset
from datasets import load_dataset 
from sklearn.model_selection import train_test_split 

from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
)
from sklearn.model_selection import train_test_split

# Load the Jigsaw toxic-comment data from local files and keep only the first
# 100 training rows so the experiment runs quickly. Slicing the split returns
# a dict mapping column name -> list of values for those rows.
dataset = load_dataset("jigsaw_toxicity_pred", data_dir="/workspaces/LLM-Experimentation-Capstone/00_source_data/jigsaw_toxicity")
dataset = dataset['train'][0:100]

texts = dataset['comment_text']

# Six binary toxicity columns; each example becomes a multi-hot vector in this
# order. Values are stored as floats because multi-label classification in
# transformers uses BCEWithLogitsLoss, which requires floating-point targets.
label_columns = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
labels = [
    [float(value) for value in row]
    for row in zip(*(dataset[column] for column in label_columns))
]

# Hold out a third of the sample for evaluation; fixed seed for reproducibility.
train_texts, eval_texts, train_labels, eval_labels = train_test_split(
    texts, labels, test_size=0.33, random_state=42
)

# Uncased BERT tokenizer; lower-casing is requested explicitly.
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased', do_lower_case=True)

# Encode both splits identically: truncate to 512 tokens and pad every example
# to exactly 512 so all sequences share one length.
train_encodings, eval_encodings = (
    tokenizer(split_texts, padding="max_length", truncation=True, max_length=512)
    for split_texts in (train_texts, eval_texts)
)


class TextClassifierDataset(Dataset):
    """Map-style torch dataset pairing tokenizer output with per-example labels.

    Args:
        encodings: Mapping from feature name (e.g. ``input_ids``,
            ``attention_mask``) to a sequence indexed by example, such as the
            output of a Hugging Face tokenizer called on a list of texts.
        labels: One label entry per example (in this script, a multi-hot list).
        label_dtype: dtype of the returned ``labels`` tensor. Defaults to float
            because transformers' ``multi_label_classification`` loss
            (``BCEWithLogitsLoss``) rejects integer targets.
    """

    def __init__(self, encodings, labels, label_dtype=torch.float):
        self.encodings = encodings
        self.labels = labels
        self.label_dtype = label_dtype

    def __len__(self):
        # Dataset size is defined by the number of label entries.
        return len(self.labels)

    def __getitem__(self, idx):
        """Return example ``idx`` as a dict of tensors, including ``labels``."""
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        # Explicit dtype: without it, a list of ints becomes int64, which
        # BCEWithLogitsLoss cannot consume.
        item["labels"] = torch.tensor(self.labels[idx], dtype=self.label_dtype)
        return item

train_dataset = TextClassifierDataset(train_encodings, train_labels)
eval_dataset = TextClassifierDataset(eval_encodings, eval_labels)

# The classification head needs exactly one output per label column. The
# label vectors built earlier have six entries (toxic, severe_toxic, obscene,
# threat, insult, identity_hate). Deriving the size from the data keeps the
# logits and targets shapes in agreement inside the multi-label loss.
model = AutoModelForSequenceClassification.from_pretrained(
    "bert-base-uncased",
    problem_type="multi_label_classification",
    num_labels=len(train_labels[0]),
)

# Single-epoch fine-tune with batch size 16 for both training and evaluation.
# The held-out split is scored at the end of each epoch, and Trainer output
# goes under the current working directory (output_dir=".").
# NOTE(review): newer transformers releases rename `evaluation_strategy` to
# `eval_strategy`; confirm against the installed version before upgrading.
training_arguments = TrainingArguments(
    output_dir=".",
    num_train_epochs=1,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    evaluation_strategy="epoch",
)

trainer = Trainer(
    model=model,
    args=training_arguments,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
)
trainer.train()

: 