In [1]:
import logging
import re
from pathlib import Path
from typing import Optional

import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
import mlcroissant as mlc
from sklearn.model_selection import train_test_split

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class MBTIDataset(Dataset):
    """Map-style dataset that pairs each text post with its MBTI class index."""

    def __init__(self, texts, labels):
        """
        Args:
            texts: Indexable sequence of text posts (list or array).
            labels: Indexable sequence of integer class indices aligned
                position-by-position with ``texts`` (the DataModule passes a
                ``torch.long`` tensor).
        """
        self.texts, self.labels = texts, labels

    def __len__(self):
        # The dataset size is taken from the texts; labels are assumed aligned.
        n_items = len(self.texts)
        return n_items

    def __getitem__(self, idx):
        text = self.texts[idx]
        label = self.labels[idx]
        return dict(text=text, label=label)


class MBTIDataModule(pl.LightningDataModule):
    """
    PyTorch Lightning DataModule for MBTI personality classification.

    Handles data downloading, cleaning, preprocessing, stratified
    train/validation/test splitting, and DataLoader creation.
    """

    # Matches http(s) links and bare www. links inside posts.
    _URL_PATTERN = re.compile(r'http\S+|www\.\S+')

    def __init__(
        self,
        raw_data_path: str = "data/raw",
        processed_data_path: str = "data/processed",
        batch_size: int = 32,
        num_workers: int = 4,
        test_size: float = 0.2,
        val_size: float = 0.1,
        random_seed: int = 42
    ):
        """
        Args:
            raw_data_path: Directory holding (or receiving) ``mbti_1.csv``
            processed_data_path: Path to store processed data (kept on the
                instance; not otherwise used by this class)
            batch_size: Batch size for DataLoaders
            num_workers: Number of workers for DataLoaders
            test_size: Fraction of the full dataset held out as the test set
            val_size: Fraction of the remaining train+val data used for
                validation
            random_seed: Seed for both splits, so repeated ``setup`` calls
                produce identical partitions
        """
        super().__init__()
        self.raw_data_path = Path(raw_data_path)
        self.processed_data_path = Path(processed_data_path)
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.test_size = test_size
        self.val_size = val_size
        self.random_seed = random_seed

        # Populated by setup()
        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None
        self.type_to_idx = None
        self.idx_to_type = None
        self.num_classes = None

    def prepare_data(self):
        """
        Download the data if needed. Lightning calls this from a single
        process (per node by default), not on every GPU.
        """
        self._ensure_data()

    def setup(self, stage: Optional[str] = None):
        """
        Load the CSV, build the label mappings, and create the split datasets.
        Called on every process in distributed settings.

        Args:
            stage: 'fit' builds train+val, 'validate' builds val, 'test'
                builds test, and None builds all three. Other stages only
                refresh the label mappings.
        """
        df = self._load_and_clean_data()

        # Sorted so the type -> index mapping is deterministic across runs.
        unique_types = sorted(df['type'].unique())
        self.type_to_idx = {t: i for i, t in enumerate(unique_types)}
        self.idx_to_type = {i: t for t, i in self.type_to_idx.items()}
        self.num_classes = len(unique_types)

        df['type_idx'] = df['type'].map(self.type_to_idx)

        # Split into train+val and test (stratified by personality type)
        train_val_df, test_df = train_test_split(
            df,
            test_size=self.test_size,
            random_state=self.random_seed,
            stratify=df['type']
        )

        # Split train+val into train and validation
        train_df, val_df = train_test_split(
            train_val_df,
            test_size=self.val_size,
            random_state=self.random_seed,
            stratify=train_val_df['type']
        )

        if stage in ('fit', None):
            self.train_dataset = self._make_dataset(train_df)
        # 'validate' (trainer.validate) also needs the validation split.
        if stage in ('fit', 'validate', None):
            self.val_dataset = self._make_dataset(val_df)
        if stage in ('test', None):
            self.test_dataset = self._make_dataset(test_df)

        logger.info(f"Data split: Train={len(train_df)}, Val={len(val_df)}, Test={len(test_df)}")

    def train_dataloader(self):
        """Create training DataLoader (shuffled)."""
        return self._make_dataloader(self.train_dataset, shuffle=True, split='train')

    def val_dataloader(self):
        """Create validation DataLoader."""
        return self._make_dataloader(self.val_dataset, shuffle=False, split='val')

    def test_dataloader(self):
        """Create test DataLoader."""
        return self._make_dataloader(self.test_dataset, shuffle=False, split='test')

    @staticmethod
    def _make_dataset(split_df: pd.DataFrame) -> MBTIDataset:
        """Wrap a split DataFrame's posts and integer labels in an MBTIDataset."""
        return MBTIDataset(
            split_df['posts'].values,
            torch.tensor(split_df['type_idx'].values, dtype=torch.long)
        )

    def _make_dataloader(self, dataset, shuffle: bool, split: str) -> DataLoader:
        """Build a DataLoader for ``dataset``; raise if setup() has not created it."""
        if dataset is None:
            raise RuntimeError(
                f"The {split} dataset is not initialised; call setup() with a "
                f"stage that builds it before requesting its DataLoader."
            )
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            # Pinned memory only speeds up host->GPU copies; PyTorch warns
            # when it is requested without an accelerator.
            pin_memory=torch.cuda.is_available()
        )

    def _ensure_data(self):
        """Download the dataset if it is not present locally; return the CSV path."""
        csv_path = self.raw_data_path / "mbti_1.csv"

        if csv_path.exists():
            logger.info(f"Raw data found at: {csv_path}")
            return csv_path

        logger.info("Downloading dataset via mlcroissant...")
        self.raw_data_path.mkdir(parents=True, exist_ok=True)

        try:
            url = "https://www.kaggle.com/datasets/datasnaek/mbti-type/croissant/download"
            dataset = mlc.Dataset(url)
            record_sets = dataset.metadata.record_sets
            records = dataset.records(record_set=record_sets[0].uuid)
            df = pd.DataFrame(records)

            # Croissant column names are namespaced ("recordset/field"); keep the field.
            df.columns = [col.split("/")[-1] for col in df.columns]

            df.to_csv(csv_path, index=False)
            logger.info(f"Dataset saved to: {csv_path}")
        except Exception as e:
            logger.error(f"Error downloading dataset: {e}")
            raise

        return csv_path

    def _load_and_clean_data(self):
        """Load the MBTI CSV, clean 'type' and 'posts', and add binary trait columns."""
        csv_path = self.raw_data_path / "mbti_1.csv"
        df = pd.read_csv(csv_path)
        logger.info(f"Loaded {len(df)} rows")

        # Clean 'type' column (remove byte-string artifacts such as b'INTJ')
        if df['type'].dtype == object:
            df['type'] = df['type'].astype(str).str.replace(r"^b'|'$", "", regex=True)

        # Clean 'posts' column
        df['posts'] = df['posts'].astype(str).apply(self._clean_text)

        # Binary trait indicators (optional, usable for multi-task learning):
        # is_E, is_S, is_T, is_J are 1 when the letter appears in the type.
        for letter in "ESTJ":
            df[f'is_{letter}'] = df['type'].str.contains(letter, regex=False).astype(int)

        return df

    @staticmethod
    def _clean_text(text: str) -> str:
        """
        Clean text data by removing URLs, byte artifacts, and normalizing whitespace.

        Args:
            text: Raw text string

        Returns:
            Lower-cased text with URLs removed, '|||' separators replaced by
            spaces, and runs of whitespace collapsed to single spaces.
        """
        # Remove byte string wrapper: b'...' or b"..."
        if text.startswith("b'") or text.startswith('b"'):
            text = text[2:-1]

        text = MBTIDataModule._URL_PATTERN.sub('', text)

        # Posts are joined with '|||' in the raw data
        text = text.replace('|||', ' ')
        text = text.lower()
        text = ' '.join(text.split())

        return text


# # Example usage
# if __name__ == "__main__":
#     # Initialize DataModule
#     dm = MBTIDataModule(
#         batch_size=32,
#         num_workers=4
#     )
    
#     # Prepare and setup data
#     dm.prepare_data()
#     dm.setup()
    
#     # Access dataloaders
#     train_loader = dm.train_dataloader()
#     val_loader = dm.val_dataloader()
#     test_loader = dm.test_dataloader()
    
#     # Print information
#     print(f"Number of classes: {dm.num_classes}")
#     print(f"Type to index mapping: {dm.type_to_idx}")
#     print(f"Train batches: {len(train_loader)}")
#     print(f"Val batches: {len(val_loader)}")
#     print(f"Test batches: {len(test_loader)}")
    
#     # Example: Iterate through one batch
#     for batch in train_loader:
#         print(f"\nBatch text sample: {batch['text'][0][:100]}...")
#         print(f"Batch labels shape: {batch['label'].shape}")
#         break

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
import logging
from typing import Optional, Dict, Any

import torch
import torch.nn as nn
import pytorch_lightning as pl
from transformers import (
    DistilBertTokenizer,
    DistilBertForSequenceClassification,
    get_linear_schedule_with_warmup
)
from torchmetrics import Accuracy, Precision, Recall, F1Score, ConfusionMatrix
# Import metrics collection from torchmetrics
from torchmetrics import MetricCollection
import wandb
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class MBTIClassifier(pl.LightningModule):
    """
    DistilBERT-based classifier for MBTI personality type prediction.

    Raw text is tokenized inside each step. The model is fine-tuned with
    cross-entropy loss, and accuracy plus macro precision/recall/F1 are
    logged through Lightning. When the attached logger is a WandbLogger,
    per-class accuracy and a test confusion matrix are also sent to WandB.
    """

    def __init__(
        self,
        num_classes: int = 16,
        model_name: str = "distilbert-base-uncased",
        learning_rate: float = 2e-5,
        weight_decay: float = 0.01,
        warmup_steps: int = 500,
        max_length: int = 512,
        dropout_rate: float = 0.1,
        freeze_encoder: bool = False,
        freeze_layers: int = 0
    ):
        """
        Args:
            num_classes: Number of personality types (16 for MBTI)
            model_name: HuggingFace model identifier
            learning_rate: Peak learning rate for AdamW
            weight_decay: Weight decay for AdamW (not applied to biases or
                LayerNorm weights)
            warmup_steps: Number of linear warmup steps for the scheduler
            max_length: Maximum sequence length for the tokenizer; longer
                texts are truncated
            dropout_rate: Passed to the DistilBERT config as ``dropout``,
                i.e. the encoder's dropout. The classification head uses the
                separate ``seq_classif_dropout`` config field, which is left
                at its default.
            freeze_encoder: Whether to freeze the entire DistilBERT encoder
            freeze_layers: Number of leading transformer layers to freeze
                (0 = none); ignored when ``freeze_encoder`` is True
        """
        super().__init__()
        # Stores every __init__ argument in self.hparams (and in checkpoints).
        self.save_hyperparameters()

        # Load tokenizer and model
        self.tokenizer = DistilBertTokenizer.from_pretrained(model_name)
        self.model = DistilBertForSequenceClassification.from_pretrained(
            model_name,
            num_labels=num_classes,
            dropout=dropout_rate
        )

        # Freeze layers if specified
        if freeze_encoder:
            for param in self.model.distilbert.parameters():
                param.requires_grad = False
        elif freeze_layers > 0:
            for layer in self.model.distilbert.transformer.layer[:freeze_layers]:
                for param in layer.parameters():
                    param.requires_grad = False

        # Loss function
        self.criterion = nn.CrossEntropyLoss()

        # Metrics for each split
        self.train_acc = Accuracy(task="multiclass", num_classes=num_classes)
        self.val_acc = Accuracy(task="multiclass", num_classes=num_classes)
        self.test_acc = Accuracy(task="multiclass", num_classes=num_classes)

        self.val_precision = Precision(task="multiclass", num_classes=num_classes, average='macro')
        self.val_recall = Recall(task="multiclass", num_classes=num_classes, average='macro')
        self.val_f1 = F1Score(task="multiclass", num_classes=num_classes, average='macro')

        self.test_precision = Precision(task="multiclass", num_classes=num_classes, average='macro')
        self.test_recall = Recall(task="multiclass", num_classes=num_classes, average='macro')
        self.test_f1 = F1Score(task="multiclass", num_classes=num_classes, average='macro')

        # Confusion matrix for test; the aggregated result is stored in
        # self.test_confusion_matrix by on_test_epoch_end.
        self.test_confusion = ConfusionMatrix(task="multiclass", num_classes=num_classes)
        self.test_confusion_matrix = None

        # Per-batch outputs collected for epoch-end analysis
        self.validation_step_outputs = []
        self.test_step_outputs = []

    def forward(self, input_ids, attention_mask):
        """Run the sequence classifier and return its logits."""
        outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
        return outputs.logits

    def _shared_step(self, batch, batch_idx):
        """Tokenize the batch texts, run the model, and return (loss, preds, labels, logits)."""
        texts = batch['text']
        labels = batch['label']

        # Tokenize on the fly; padding=True pads only to the longest text in the batch.
        encoding = self.tokenizer(
            texts,
            padding=True,
            truncation=True,
            max_length=self.hparams.max_length,
            return_tensors='pt'
        )

        input_ids = encoding['input_ids'].to(self.device)
        attention_mask = encoding['attention_mask'].to(self.device)

        logits = self(input_ids, attention_mask)
        loss = self.criterion(logits, labels)
        preds = torch.argmax(logits, dim=1)

        return loss, preds, labels, logits

    def _per_class_accuracy(self, preds, labels, prefix):
        """Return {'<prefix>/class_<i>_acc': accuracy} for each class present in ``labels``."""
        per_class = {}
        for class_idx in range(self.hparams.num_classes):
            mask = labels == class_idx
            if mask.any():
                class_acc = (preds[mask] == labels[mask]).float().mean()
                per_class[f'{prefix}/class_{class_idx}_acc'] = class_acc.item()
        return per_class

    def training_step(self, batch, batch_idx):
        """Training step: compute loss and log loss/accuracy."""
        loss, preds, labels, _ = self._shared_step(batch, batch_idx)

        self.train_acc(preds, labels)

        self.log('train/loss', loss, on_step=True, on_epoch=True, prog_bar=True)
        self.log('train/acc', self.train_acc, on_step=True, on_epoch=True, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_idx):
        """Validation step: update and log epoch-level metrics, keep preds for analysis."""
        loss, preds, labels, logits = self._shared_step(batch, batch_idx)

        self.val_acc(preds, labels)
        self.val_precision(preds, labels)
        self.val_recall(preds, labels)
        self.val_f1(preds, labels)

        self.log('val/loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('val/acc', self.val_acc, on_step=False, on_epoch=True, prog_bar=True)
        self.log('val/precision', self.val_precision, on_step=False, on_epoch=True)
        self.log('val/recall', self.val_recall, on_step=False, on_epoch=True)
        self.log('val/f1', self.val_f1, on_step=False, on_epoch=True, prog_bar=True)

        self.validation_step_outputs.append({
            'preds': preds,
            'labels': labels,
            'loss': loss
        })

        return loss

    def on_validation_epoch_end(self):
        """Log per-class validation accuracy to WandB (if used) and clear stored outputs."""
        if not self.validation_step_outputs:
            return

        all_preds = torch.cat([x['preds'] for x in self.validation_step_outputs])
        all_labels = torch.cat([x['labels'] for x in self.validation_step_outputs])

        if isinstance(self.logger, WandbLogger):
            self.logger.experiment.log(self._per_class_accuracy(all_preds, all_labels, 'val'))

        self.validation_step_outputs.clear()

    def test_step(self, batch, batch_idx):
        """Test step: update and log epoch-level metrics, keep preds/logits for analysis."""
        loss, preds, labels, logits = self._shared_step(batch, batch_idx)

        self.test_acc(preds, labels)
        self.test_precision(preds, labels)
        self.test_recall(preds, labels)
        self.test_f1(preds, labels)
        self.test_confusion(preds, labels)

        self.log('test/loss', loss, on_step=False, on_epoch=True)
        self.log('test/acc', self.test_acc, on_step=False, on_epoch=True)
        self.log('test/precision', self.test_precision, on_step=False, on_epoch=True)
        self.log('test/recall', self.test_recall, on_step=False, on_epoch=True)
        self.log('test/f1', self.test_f1, on_step=False, on_epoch=True)

        self.test_step_outputs.append({
            'preds': preds,
            'labels': labels,
            'logits': logits
        })

        return loss

    def on_test_epoch_end(self):
        """Store the test confusion matrix, log per-class results to WandB, and clear outputs."""
        if not self.test_step_outputs:
            return

        all_preds = torch.cat([x['preds'] for x in self.test_step_outputs])
        all_labels = torch.cat([x['labels'] for x in self.test_step_outputs])

        # test_confusion is never passed to self.log, so Lightning does not
        # reset it; reset it explicitly so a second trainer.test() starts clean.
        self.test_confusion_matrix = self.test_confusion.compute().cpu()
        self.test_confusion.reset()

        if isinstance(self.logger, WandbLogger):
            self.logger.experiment.log(self._per_class_accuracy(all_preds, all_labels, 'test'))

            # Log confusion matrix as an interactive WandB plot
            self.logger.experiment.log({
                "test/confusion_matrix": wandb.plot.confusion_matrix(
                    probs=None,
                    y_true=all_labels.cpu().numpy(),
                    preds=all_preds.cpu().numpy(),
                    class_names=[str(i) for i in range(self.hparams.num_classes)]
                )
            })

        self.test_step_outputs.clear()

    def configure_optimizers(self):
        """AdamW with decoupled weight decay and a per-step linear warmup/decay schedule."""
        # Parameters exempt from weight decay. DistilBERT names its embedding
        # norm 'LayerNorm' but its per-layer norms 'sa_layer_norm' /
        # 'output_layer_norm', so both spellings are needed.
        no_decay = ['bias', 'LayerNorm.weight', 'layer_norm.weight']
        optimizer_grouped_parameters = [
            {
                'params': [p for n, p in self.model.named_parameters()
                           if not any(nd in n for nd in no_decay) and p.requires_grad],
                'weight_decay': self.hparams.weight_decay,
            },
            {
                'params': [p for n, p in self.model.named_parameters()
                           if any(nd in n for nd in no_decay) and p.requires_grad],
                'weight_decay': 0.0,
            }
        ]

        optimizer = torch.optim.AdamW(
            optimizer_grouped_parameters,
            lr=self.hparams.learning_rate
        )

        # Linear warmup then linear decay to 0 over all optimizer steps of the run
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.hparams.warmup_steps,
            num_training_steps=self.trainer.estimated_stepping_batches
        )

        return {
            'optimizer': optimizer,
            'lr_scheduler': {
                'scheduler': scheduler,
                'interval': 'step',
                'frequency': 1
            }
        }


def train_mbti_classifier(
    datamodule,
    project_name: str = "mbti-classification",
    experiment_name: Optional[str] = None,
    max_epochs: int = 10,
    gpus: int = 1,
    **model_kwargs
):
    """
    Train the MBTI classifier with WandB logging, then test the best checkpoint.

    Args:
        datamodule: MBTIDataModule instance. If ``setup`` has not been called
            yet (``num_classes`` is None), data is prepared and set up here.
        project_name: WandB project name
        experiment_name: WandB run name (optional)
        max_epochs: Maximum number of training epochs
        gpus: Number of GPUs to use (0 = CPU)
        **model_kwargs: Additional arguments for MBTIClassifier

    Returns:
        ``(model, trainer)``. The test pass runs with ``ckpt_path="best"``,
        i.e. the checkpoint with the highest ``val/f1``.
    """
    # The classifier's output size depends on the label mapping built in setup().
    if datamodule.num_classes is None:
        datamodule.prepare_data()
        datamodule.setup("fit")

    wandb_logger = WandbLogger(
        project=project_name,
        name=experiment_name,
        log_model=True  # Save model checkpoints to WandB
    )

    try:
        # Log run-level hyperparameters
        wandb_logger.experiment.config.update({
            "batch_size": datamodule.batch_size,
            "max_epochs": max_epochs,
            **model_kwargs
        })

        model = MBTIClassifier(
            num_classes=datamodule.num_classes,
            **model_kwargs
        )

        # The metric is logged as 'val/f1'; it must be referenced by that exact
        # name in the filename template. auto_insert_metric_name=False keeps
        # the '/' out of the generated file name (it would create a directory).
        checkpoint_callback = ModelCheckpoint(
            monitor='val/f1',
            mode='max',
            dirpath='checkpoints/',
            filename='mbti-epoch={epoch:02d}-val_f1={val/f1:.3f}',
            auto_insert_metric_name=False,
            save_top_k=3,
            verbose=True
        )

        early_stop_callback = EarlyStopping(
            monitor='val/f1',
            patience=3,
            mode='max',
            verbose=True
        )

        lr_monitor = LearningRateMonitor(logging_interval='step')

        trainer = pl.Trainer(
            max_epochs=max_epochs,
            accelerator='gpu' if gpus > 0 else 'cpu',
            devices=gpus if gpus > 0 else 'auto',
            logger=wandb_logger,
            callbacks=[checkpoint_callback, early_stop_callback, lr_monitor],
            log_every_n_steps=10,
            gradient_clip_val=1.0,
            precision='16-mixed' if gpus > 0 else 32,  # Mixed precision on GPU for speed
        )

        trainer.fit(model, datamodule=datamodule)

        # Evaluate the best validation checkpoint, not the last epoch's weights
        # (early stopping means the last epoch is usually past the best one).
        trainer.test(model, datamodule=datamodule, ckpt_path="best")
    finally:
        # Close the WandB run even if training or testing raised.
        wandb.finish()

    return model, trainer



In [None]:

# Notebook driver: build the DataModule, load and split the data, then
# fine-tune DistilBERT with WandB logging. This runs at cell top level (not
# under an `if __name__ == "__main__":` guard) so that `dm`, `model` and
# `trainer` stay available to later cells.

# Transformer fine-tuning is memory-hungry, hence the smaller batch size.
dm = MBTIDataModule(batch_size=16, num_workers=4)

dm.prepare_data()
dm.setup()

model, trainer = train_mbti_classifier(
    datamodule=dm,
    # WandB run identification
    project_name="mbti-distilbert",
    experiment_name="distilbert-finetuning-v1",
    # Trainer settings
    max_epochs=10,
    gpus=1,
    # Hyperparameters forwarded to MBTIClassifier
    learning_rate=2e-5,
    weight_decay=0.01,
    warmup_steps=500,
    max_length=512,
    dropout_rate=0.1,
    freeze_layers=0,  # 0 -> every transformer layer is fine-tuned
)

print("Training completed!")
print(f"Best model checkpoint: {trainer.checkpoint_callback.best_model_path}")

INFO:__main__:Raw data found at: data/raw/mbti_1.csv
INFO:__main__:Loaded 8675 rows
INFO:__main__:Data split: Train=6246, Val=694, Test=1735
[34m[1mwandb[0m: [wandb.login()] Loaded credentials for https://api.wandb.ai from /home/prg/.netrc.
[34m[1mwandb[0m: Currently logged in as: [33mpablorocg10[0m ([33mpablorg[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
You are using a CUDA device ('NVIDIA GeForce RTX 3070 Ti Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
INFO:__main__:Raw data found at: data/raw/mbti_1.csv
INFO:__main__:Loaded 8675 rows
INFO:__main__:Data split: Train=6246, Val=694, Test=1735
LOCA

Output()