## Fine-Tuning Multiclass Text Classification

In [None]:
# Install the notebook's dependencies (IPython/Jupyter shell escape, not plain Python).
# NOTE(review): scikit-learn, torch and google.colab are imported later but not listed here;
# presumably they come preinstalled in the runtime (e.g. Colab) — confirm for other setups.
!pip install transformers datasets accelerate seaborn bertviz umap-learn wandb

### Get and View Dataset

In [None]:
from datasets import load_dataset

# Load the dair-ai/emotion dataset by its Hugging Face Hub id (downloaded or read from cache).
dataset = load_dataset("dair-ai/emotion")
# Bare trailing expression so the notebook renders the loaded dataset object.
dataset

In [None]:
## Let's peek at the data
train = dataset["train"]
# Ordered class names from the ClassLabel feature; position in the list == integer label id.
label_names = train.features["label"].names
df = train.to_pandas()

# First rows of the frame, one blank line, then the class names.
print(df.head(), label_names, sep="\n\n")

In [None]:
## Let's add the class labels into the dataframe
# Translate every integer label id into its class name (label_names is indexed by id).
df["label_text"] = [label_names[label_id] for label_id in df["label"]]
df.head()

In [None]:
import matplotlib.pyplot as plt

# Examples per class, smallest count first so the largest class is drawn as the top bar.
label_counts = df["label_text"].value_counts(ascending=True)

label_counts.plot.barh(title="Class Distribution")
plt.show()

### Prepare Dataset

In [None]:
from transformers import AutoTokenizer

# Hub checkpoint id reused by the tokenizer and by both model loads further down.
MODEL = "google-bert/bert-base-uncased"

# Tokenizer from the same checkpoint as the model, so token ids match its vocabulary.
tokenizer = AutoTokenizer.from_pretrained(MODEL)

In [None]:
from sklearn.model_selection import train_test_split
from datasets import Dataset, DatasetDict

# Seed shared by both splits so reruns yield the same train/test/validation rows.
SPLIT_SEED = 42

# NOTE: `df` was built from the Hub's train split only; this cell re-splits it and rebinds
# `dataset`, so the Hub's own validation/test splits are not used below.

# Hold out 30% of the rows. Stratifying on the class id keeps the class proportions of
# `df` in both parts.
train_df, test_df = train_test_split(
    df, test_size=0.3, stratify=df["label"], random_state=SPLIT_SEED
)
# Split the held-out 30% into test (2/3, ~20% overall) and validation (1/3, ~10% overall),
# stratified on the same column as the first split.
test_df, val_df = train_test_split(
    test_df, test_size=1/3, stratify=test_df["label"], random_state=SPLIT_SEED
)

print(f"Train shape: {train_df.shape}")
print(f"Test shape: {test_df.shape}")
print(f"Validation shape: {val_df.shape}")
print()

# preserve_index=False keeps the shuffled pandas index from being stored as an extra column.
dataset = DatasetDict({
    'train': Dataset.from_pandas(train_df, preserve_index=False),
    'test': Dataset.from_pandas(test_df, preserve_index=False),
    'validation': Dataset.from_pandas(val_df, preserve_index=False),
})

dataset['train'][0], dataset['test'][0], dataset['validation'][0]

In [None]:
# Tokenize every split. With batched=True and batch_size=None each split is passed to the
# tokenizer as a single batch, so padding=True pads every example to the longest sequence
# in its split; truncation=True caps sequences at the model's maximum length.
encoded = dataset.map(lambda batch: tokenizer(batch["text"], padding=True, truncation=True), batched=True, batch_size=None)
encoded

### Get Model and Prepare for Training

In [None]:
from transformers import AutoModel

# Base model loaded through AutoModel (no task-specific head), used only for inspection with
# show_model_info; `model` is rebound to a sequence-classification model further down.
model = AutoModel.from_pretrained(MODEL)

In [None]:
def show_model_info(model, show_layers=False):
    """Comprehensive model inspection"""
    config = model.config
    architecture = None
    model_heads = []
    model_type = "Unknown"
    id2label = None
    label2id = None
    merged_labels = None
    quant_type = "None"
    q = 0

    gbs = model.get_memory_footprint() / 1e9
    param_count = model.num_parameters()

    # Model architecture
    if hasattr(config, 'architectures') and config.architectures:
        architecture = config.architectures[0]

    # Model heads
    try:
        if hasattr(model, 'base_model'):
            for module in model.modules():
                model_heads.append(type(module).__name__)

                if module == model.base_model:
                    break
        else:
            for name, module in model.named_children()[:5]:
                model_heads.append(f"{name}(type(module).__name__))")

        # Clean model head list
        model_heads = list(dict.fromkeys(model_heads))[:10]
    except Exception as e:
        model_heads = [f"Detection failed: {str(e)[:50]}"]

    # Detect quantization
    if hasattr(config, "quantization_config") and config.quantization_config is not None:
        q_config = config.quantization_config

        if hasattr(q_config, "load_in_4bit") and q_config.load_in_4bit == True:
            q = 4
            quant_type = f"4-bit ({getattr(q_config, 'bnb_4bit_quant_type', 'unknown')})"
        elif hasattr(q_config, "load_in_8bit") and q_config.load_in_8bit == True:
            q = 8
            quant_type = "8-bit"
    else:
        if hasattr(config, "torch_dtype") and config.torch_dtype is not None:
            q = config.torch_dtype.itemsize * 8
            quant_type = f"FP{q} ({config.torch_dtype})"

    # Model type detection
    if hasattr(model.config, 'model_type'):
        model_type = model.config.model_type

    # Label detection
    if  hasattr(model.config, 'label2id') and model.config.label2id is not None and hasattr(config, 'id2label') and config.id2label is not None:

        id2label = model.config.id2label
        label2id = model.config.label2id

        try:
            # Check for label consistency
            label2id_swap = {str(v): k for k, v in label2id.items()}
            id2label_str = {str(k): v for k, v in id2label.items()}

            if id2label_str != label2id_swap:
                merged_labels = {}
                for k, v in id2label.items():
                    key = int(k) if isinstance(k, str) else k
                    merged_labels[key] = [v]

                for k, v in label2id.items():
                    try:
                        v_int = int(v)

                        if v_int in merged_labels:
                            merged_labels[v_int].append(k)
                            merged_labels[v_int] = list(set(merged_labels[v_int]))
                        else:
                            merged_labels[v_int] = [k]
                    except ValueError:
                        continue

        except Exception as e:
            print(f"Label validation error: {e}")

    # Basic model info
    print(f"{'='*55}")
    print(f"MODEL: {getattr(config, '_name_or_path', 'Unknown')}")
    print(f"{'='*55}")

    print(f"Model Type: {model_type}")

    if architecture is not None:
        print(f"Architecture: {architecture}")

    if len(model_heads) > 0:
        print(f"Model Structure: {' → '.join(model_heads)}")

    if hasattr(config, "problem_type") and config.problem_type is not None:
        print(f"Problem Type: {config.problem_type}")

    if hasattr(config, "vocab_size"):
        print(f"Vocab Size: {config.vocab_size:,}")

    if id2label is not None:
        print("\nLabel Info:")

        if merged_labels is None:
            print("  ✅ id2label and label2id match")
            print(f"  Label count: {len(id2label)}")

            if len(id2label) <= 10:
                print(f"  Labels: {id2label}")
            else:
                sample_labels = dict(list(id2label.items())[:5])
                print(f"  Labels (sample): {sample_labels}... (+{len(id2label)-5} more)")
        else:
            print("  ⚠️ WARNING: Model id2label and label2id don't match")
            print(f"  Merged labels: {merged_labels}")

    print(f"\nParameters: {param_count:,}")
    print(f"Quantization: {quant_type}")
    print(f"Memory (actual): {gbs:.2f} GB")
    print(f"Memory (FP32 equiv): {param_count*4/1e9:.2f} GB")

    if gbs > 0:
        print(f"Memory savings: {((param_count*4/1e9 - gbs) / (param_count*4/1e9) * 100):.1f}%")

    # Device info
    device = next(model.parameters()).device
    print(f"\nDevice: {device}")

    # Check if all components on same device
    devices = set()
    for name, param in model.named_parameters():
        devices.add(str(param.device))
    for name, buffer in model.named_buffers():
        devices.add(str(buffer.device))

    if len(devices) > 1:
        print(f"⚠️  WARNING: Model spans multiple devices: {devices}")
    else:
        print(f"✅ All components on: {device}")

    # Add training state info
    if hasattr(model, 'training'):
        mode = "Training" if model.training else "Evaluation"
        print(f"\nMode: {mode}")

    # Memory per layer breakdown
    if show_layers:
        print(f"\n{'Layer Breakdown':^55}")
        print(f"{'Layer':<30} {'Parameters':<14} {'Device'}")
        print("-" * 55)
        for name, param in model.named_parameters():
            # Only show layers with >1M params
            if param.numel() > 1000000:
                print(f"{name[:28]:<30} {param.numel():>10,} {str(param.device):>10}")

In [None]:
# Summarise the base model loaded via AutoModel (structure, parameters, memory, device).
show_model_info(model)

In [None]:
# id -> name and its inverse, in label_names order; these go into the model config.
id2label = dict(enumerate(label_names))
label2id = {name: idx for idx, name in id2label.items()}

label2id, id2label

In [None]:
from transformers import AutoModelForSequenceClassification, AutoConfig
import torch

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Put the dataset's label mapping into the config so the classification head gets one output
# per class and predictions carry readable label names.
config = AutoConfig.from_pretrained(MODEL, label2id=label2id, id2label=id2label)
# Pretrained encoder weights plus a classification head that is not in the checkpoint
# (it starts untrained); the whole model is placed on `device`.
model = AutoModelForSequenceClassification.from_pretrained(MODEL, config=config, device_map=device)

# NOTE(review): prints a line of 80 spaces (visually just a blank line) as a separator.
print(" " * 80)
show_model_info(model)

### Model Training

In [None]:
from transformers import TrainingArguments, Trainer
import wandb

batch_size = 64
# Local checkpoint directory; also reused as the W&B project name and the Hub repo name below.
training_dir = "bert-base-uncased-class-trained"

training_args = TrainingArguments(
    output_dir=training_dir,
    overwrite_output_dir=True,
    num_train_epochs=2,
    learning_rate=2e-5,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    weight_decay=0.01,
    eval_strategy="epoch",  # evaluate on eval_dataset at the end of every epoch
    disable_tqdm=False,
    logging_steps=10  # log training metrics every 10 update steps
)

In [None]:
from sklearn.metrics import accuracy_score, f1_score

def compute_metrics(pred):
    """Compute evaluation metrics for the Trainer.

    Args:
        pred: the evaluation prediction passed by ``Trainer``. ``pred.predictions`` holds the
            logits, or a tuple whose first element is the logits when the model returns
            extra outputs. ``pred.label_ids`` holds the gold class ids.

    Returns:
        dict with "Accuracy" and the support-weighted "F1" over all classes.
    """
    labels = pred.label_ids
    logits = pred.predictions
    # Models configured to also return hidden states/attentions yield a tuple; logits come first.
    if isinstance(logits, tuple):
        logits = logits[0]
    preds = logits.argmax(-1)
    f1 = f1_score(labels, preds, average="weighted")
    acc = accuracy_score(labels, preds)

    return {"Accuracy": acc, "F1": f1}

In [None]:
# Train on the "train" split and evaluate on "validation"; the "test" split is not used here.
# processing_class=tokenizer hands the tokenizer to the Trainer (used for batching/saving).
trainer = Trainer(
    model=model,
    processing_class=tokenizer,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=encoded["train"],
    eval_dataset=encoded["validation"]
)

# Start a W&B run in a project named after the output directory.
# NOTE(review): presumably the Trainer's W&B integration logs into this run — depends on
# TrainingArguments.report_to; confirm.
wandb.init(project=training_dir)

In [None]:
# Fine-tune for num_train_epochs, evaluating on the validation split after each epoch.
trainer.train()

In [None]:
## Log into Hugging Face
from google.colab import userdata
from huggingface_hub import login
# Reads the HF_TOKEN secret through google.colab.userdata, so this cell only runs in Colab.
login(token=userdata.get('HF_TOKEN'))

In [None]:
# Upload the trained model to the Hugging Face Hub (needs the login from the previous cell).
trainer.push_to_hub()

In [None]:
from transformers import pipeline

HF_USER = "shayharding"
# Repo id "<user>/<repo>". The repo name is assumed to match the output_dir that
# push_to_hub uploaded to — TODO confirm if hub_model_id is ever set explicitly.
classifier = pipeline(
    task="text-classification",
    model=f"{HF_USER}/{training_dir}",
)

In [None]:
# Smoke test: classify one sentence with the fine-tuned model pulled from the Hub.
classifier("Oh my God!")