# ✂️ Snorkel Intro Tutorial: _Data Slicing_

In real-world applications, some model outcomes are more important than others. Examples include detecting vulnerable cyclists in an autonomous driving task, or flagging links that redirect to potentially malicious external websites in a spam detection application.

Traditional machine learning systems optimize for overall quality, which may be too coarse-grained. A model with high overall performance can still have unacceptable failure rates on critical _slices_ of the data, meaning the subsets that matter most to the application. In this notebook the running application is **sentiment analysis** of tweets, so the slices are subsets of tweets: short tweets, tweets with negation, questions, and strongly positive tweets.

In this tutorial, we:
1. **Introduce _Slicing Functions (SFs)_** as a programming interface
2. **Monitor** application-critical data subsets
3. **Improve model performance** on slices

## 1. Load Labeled Data and Define Slicing Functions (SFs)

In [1]:
# --- Initial Setup ---
%matplotlib inline
import os
import re
import pandas as pd
import numpy as np
import random
import torch
import utils # Your utility functions
import logging

In [2]:
# For reproducibility
# (setting PYTHONHASHSEED here only affects subprocesses started later, not this running kernel)
os.environ["PYTHONHASHSEED"] = "0"
SEED = 123
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

<torch._C.Generator at 0x154248d30>

In [3]:

# Configure logging and display
logger = logging.getLogger()
logger.setLevel(logging.WARNING) # Reduce verbose Snorkel logging
pd.set_option("display.max_colwidth", 0) # Display full text

In [4]:
# --- Load Labeled Data ---
# Note: load_dataset keeps labels for both train and test here
df_train, df_test = utils.load_dataset(csv_path="data/sentiment_analysis.csv")

# Clean the text (important for SFs too)
def clean_text(text):
    text = text.lower()
    text = re.sub(r'https?://\S+|www\.\S+', '', text)
    text = re.sub(r'@[^\s]+', '', text)
    text = re.sub(r'#([^\s]+)', r'\1', text)
    text = re.sub(r'[^\w\s]', '', text)
    return text

df_train['text'] = df_train['text'].apply(clean_text)
df_test['text'] = df_test['text'].apply(clean_text)


In [5]:
# Extract labels
Y_train = df_train["label"].values
Y_test = df_test["label"].values

# Define labels
ABSTAIN = -1
NEGATIVE = 0
POSITIVE = 1

print(f"Loaded {len(df_train)} training and {len(df_test)} test examples.")
display(df_train.head())

Loaded 1280000 training and 320000 test examples.


|   | text | label |
|---|------|-------|
| 0 | ahhh i hope your ok | 0 |
| 1 | cool i have no tweet apps for my razr 2 | 0 |
| 2 | i know just family drama its lamehey next time u hang out with kim n u guys like have a sleepover or whatever ill call u | 0 |
| 3 | school email wont open and i have geography stuff on there to revise stupid school | 0 |
| 4 | upper airways problem | 0 |


### Writing Slicing Functions (SFs)

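Slicing functions (SFs) are written much like labeling functions (LFs). The difference is that an SF returns a boolean indicating whether a data point belongs to the slice, not a label. Applied to a whole dataset, an SF produces a slice membership mask. We'll define a few SFs relevant to sentiment analysis.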
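Without Snorkel's decorator, preprocessing, and memoization, applying an SF amounts to evaluating a predicate once per row. The pandas sketch below uses plain functions, not Snorkel's API, on three cleaned tweets from the training-set preview above. It shows the resulting mask and the rows it selects, which is roughly what `PandasSFApplier` and `slice_dataframe` compute for a single SF.

```python
import pandas as pd

# Three cleaned tweets taken from the training-set preview
df = pd.DataFrame({"text": [
    "ahhh i hope your ok",
    "upper airways problem",
    "cool i have no tweet apps for my razr 2",
]})

def short_tweet(x):
    """Same rule as the short_tweet SF: fewer than 5 words."""
    return len(x.text.split()) < 5

mask = df.apply(short_tweet, axis=1)  # one boolean per row (x is the row as a Series)
print(mask.tolist())                  # [False, True, False]
print(df[mask].text.tolist())         # ['upper airways problem']
```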

In [6]:
from snorkel.slicing import slicing_function

# SF for short tweets (analogous to short_comment in the original Snorkel spam tutorial)
@slicing_function()
def short_tweet(x):
    """Tweets with fewer than 5 words."""
    return len(x.text.split()) < 5

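# SF for tweets containing negation words.
# Caveats: `word in x.text` is a substring test, so "no" also matches inside "now" or "know",
# and the apostrophe forms ("ain't", "don't", ...) never match because clean_text strips apostrophes.
negation_words = ["not", "no", "never", "ain't", "don't", "isn't", "can't", "won't"]
@slicing_function()
def has_negation(x):
    """Tweets containing common negation words (substring match)."""
    return any(word in x.text for word in negation_words)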

# SF for tweets that are questions
@slicing_function()
def is_question(x):
    """Tweets ending with a question mark."""
    # Caveat: clean_text removes all punctuation, including "?", so on the cleaned
    # text this slice is always empty (see the coverage check in Section 2).
    return x.text.strip().endswith("?")

# SF using a preprocessor (e.g., TextBlob polarity)
from snorkel.preprocess import preprocessor
from textblob import TextBlob

@preprocessor(memoize=True)
def textblob_polarity_score(x):
    """Adds TextBlob polarity score."""
    scores = TextBlob(x.text)
    x.polarity = scores.sentiment.polarity
    return x

@slicing_function(pre=[textblob_polarity_score])
def high_positive_polarity(x):
    """Tweets with TextBlob polarity > 0.8"""
    return x.polarity > 0.8

# List of all SFs
sfs = [
    short_tweet,
    has_negation,
    is_question,
    high_positive_polarity
]

### Visualize Slices

We can use `slice_dataframe` to see examples from a specific slice.

In [7]:
from snorkel.slicing import slice_dataframe

print("Examples from the 'has_negation' slice:")
negation_df = slice_dataframe(df_test.sample(500, random_state=SEED), has_negation) # Sample for faster display
display(negation_df[['text', 'label']].head())

Examples from the 'has_negation' slice:


100%|██████████| 500/500 [00:00<00:00, 32020.52it/s]


|   | text | label |
|---|------|-------|
| 57814 | went to a harley davidson dealer to show some of my art this weekend allot of looks but no sales ill be judging a tattoo comp next | 0 |
| 77033 | itâs start raining now hmm second cloudy day | 0 |
| 171509 | just saw quotwallk to rememberquot again and its never get old | 1 |
| 136333 | i was when i found out you were in the same house im a huge fan toolol through my tv show im getting racing known in main stream | 0 |
| 173418 | ahh thank god he has leftan angry gay man with a hang over is not my idea of fun on a saturday morninghi 2 all in twitt land | 1 |
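These examples show that the substring test in `has_negation` also picks up tweets that contain no negation. "no" matches inside "now" (row 77033) and inside "known" (row 136333). A token-level check avoids this. The sketch below is a plain-Python variant, not the tutorial's SF. Its apostrophe-free word list is an assumption, based on `clean_text` removing punctuation. To use it as an SF, it could be wrapped with `@slicing_function()` and applied to `x.text` like the SFs above.

```python
# Substring test, as in the tutorial's has_negation
NEGATION_WORDS = ["not", "no", "never", "ain't", "don't", "isn't", "can't", "won't"]

# Token-level variant (illustrative word list): apostrophe-free spellings,
# because clean_text strips apostrophes
NEGATION_TOKENS = {"not", "no", "never", "aint", "dont", "isnt", "cant", "wont"}

def has_negation_substring(text: str) -> bool:
    """Original logic: True if any negation word appears anywhere in the text."""
    return any(word in text for word in NEGATION_WORDS)

def has_negation_tokens(text: str) -> bool:
    """True only if a whole whitespace-separated token is a negation word."""
    return any(token in NEGATION_TOKENS for token in text.split())

for tweet in [
    "itâs start raining now hmm second cloudy day",  # True False ("no" inside "now")
    "allot of looks but no sales",                  # True True
    "i dont know",                                  # True True
]:
    print(has_negation_substring(tweet), has_negation_tokens(tweet), tweet)
```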


## 2. Monitor Slice Performance with `Scorer.score_slices`

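Now, let's train a baseline model and see how it performs overall and on our defined slices. Slice scoring only needs gold labels, model predictions (and optionally probabilities), and the slice membership matrix, so this works with any model framework.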


### Train a Baseline Classifier

- We'll use scikit-learn's `LogisticRegression` with TF-IDF features as our baseline.
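With `ngram_range=(1, 2)`, the vectorizer indexes both single words and adjacent word pairs. `fit_transform` returns a SciPy sparse matrix, which is why the baseline keeps its features sparse. The toy two-document corpus below is made up for illustration.

```python
import scipy.sparse as sp
from sklearn.feature_extraction.text import TfidfVectorizer

docs = ["no sales", "no sales today"]  # toy corpus
vec = TfidfVectorizer(ngram_range=(1, 2))
X = vec.fit_transform(docs)

print(sorted(vec.vocabulary_))  # ['no', 'no sales', 'sales', 'sales today', 'today']
print(sp.issparse(X), X.shape)  # True (2, 5)
```

On the full tweet corpus the vocabulary (and so the feature width) is far larger, and storing the matrix sparsely matters much more.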

In [8]:
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression

print("Featurizing data with TF-IDF...")
vectorizer = TfidfVectorizer(ngram_range=(1, 2))
# Important: Use the sparse matrix output, do NOT use .todense()
X_train, _ = utils.df_to_features(vectorizer, df_train, "train")
X_test, _ = utils.df_to_features(vectorizer, df_test, "test")
print("Featurization complete.")

print("Training baseline Logistic Regression model...")
baseline_model = LogisticRegression(solver="liblinear", C=0.1, random_state=SEED)
baseline_model.fit(X=X_train, y=Y_train)
print("Baseline model trained.")

# Get predictions and probabilities (needed for scorer)
preds_test = baseline_model.predict(X_test)
probs_test = baseline_model.predict_proba(X_test)

# Calculate overall accuracy
accuracy_baseline = baseline_model.score(X_test, Y_test)
print(f"\nBaseline Model Overall Test Accuracy: {accuracy_baseline * 100:.1f}%")

Featurizing data with TF-IDF...
Featurization complete.
Training baseline Logistic Regression model...
Baseline model trained.

Baseline Model Overall Test Accuracy: 80.3%


### Apply SFs and Score Slices

We apply our SFs to the test set to create the slice matrix `S_test`, then use `Scorer` to evaluate the baseline model on each slice.
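Conceptually, slice scoring computes the chosen metric on the subset of rows that each slice selects. The numpy sketch below illustrates that idea and is not Snorkel's implementation. `slice_accuracies` is a hypothetical helper that takes toy labels and a dict of per-slice masks. For an empty slice it returns `None` instead of raising. An empty slice is the situation behind the "Cannot score empty labels" error in the next cell's output.

```python
import numpy as np

def slice_accuracies(golds, preds, slice_masks):
    """Accuracy overall and on each slice; None for an empty slice instead of an error."""
    golds = np.asarray(golds)
    preds = np.asarray(preds)
    results = {"overall": float(np.mean(golds == preds))}
    for name, mask in slice_masks.items():
        mask = np.asarray(mask, dtype=bool)
        # Restrict both arrays to the rows in this slice before computing accuracy
        results[name] = float(np.mean(golds[mask] == preds[mask])) if mask.any() else None
    return results

golds = np.array([0, 1, 1, 0, 1])
preds = np.array([0, 1, 0, 0, 0])
masks = {"short_tweet": [1, 0, 1, 1, 0], "is_question": [0, 0, 0, 0, 0]}
print(slice_accuracies(golds, preds, masks))
# {'overall': 0.6, 'short_tweet': 0.6666666666666666, 'is_question': None}
```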

In [9]:
from snorkel.slicing import PandasSFApplier
from snorkel.analysis import Scorer


print("\nApplying Slicing Functions to test set...")
applier = PandasSFApplier(sfs)
S_test = applier.apply(df_test) # Apply SFs to get slice membership matrix
print("SFs applied.")


print("\nChecking slice coverage on the test set (S_test):")
empty_slices = []
if hasattr(S_test, 'dtype') and S_test.dtype.names: # Check if S_test is a structured array with names
    slice_names_in_S_test = S_test.dtype.names
    for slice_name in slice_names_in_S_test:
        try:
            coverage = S_test[slice_name].sum()
            print(f"- Slice '{slice_name}': {coverage} examples")
            if coverage == 0:
                empty_slices.append(slice_name)
        except KeyError:
            print(f"- Warning: Slice name '{slice_name}' not found in S_test dtype names.")
else:
    print("Warning: S_test might not be in the expected format (numpy structured array). Cannot reliably check coverage.")
    # Fallback for a plain 2-D array: assume column i corresponds to sfs[i]
    if isinstance(S_test, np.ndarray) and S_test.ndim == 2 and len(sfs) == S_test.shape[1]:
        slice_names_in_S_test = [sf.name for sf in sfs]
        for i, slice_name in enumerate(slice_names_in_S_test):
            coverage = S_test[:, i].sum()
            print(f"- Slice '{slice_name}' (assumed index {i}): {coverage} examples")
            if coverage == 0:
                empty_slices.append(slice_name)


if empty_slices:
    print(f"\nWarning: The following slices are empty in the test set: {empty_slices}")
    print("Scoring will likely fail or produce NaN for these slices. Consider removing them from 'sfs' list if unused.")
else:
    print("\nAll slices appear to have coverage on the test set.")

# Initialize scorer with desired metrics (e.g., accuracy)
scorer = Scorer(metrics=["accuracy"])

# Score model performance on slices (handles potential errors gracefully)
print("\nScoring model performance on slices:")
try:
    slice_scores = scorer.score_slices(
        S=S_test,
        golds=Y_test,       # Ground truth labels
        preds=preds_test,   # Model's hard predictions
        probs=probs_test,   # Model's predicted probabilities
        as_dataframe=True   # Return results as a pandas DataFrame
    )
    display(slice_scores)
except ValueError as e:
    print(f"\nError during scoring: {e}")
    print("This often happens if one or more slices are empty in the test set.")
except Exception as e:
    print(f"\nAn unexpected error occurred during scoring: {e}")


Applying Slicing Functions to test set...


100%|██████████| 320000/320000 [01:02<00:00, 5094.53it/s]


SFs applied.

Checking slice coverage on the test set (S_test):
- Slice 'short_tweet': 39573 examples
- Slice 'has_negation': 78197 examples
- Slice 'is_question': 0 examples
- Slice 'high_positive_polarity': 4099 examples

Warning: The following slices are empty in the test set: ['is_question']
Scoring will likely fail or produce NaN for these slices. Consider removing them from 'sfs' list if unused.

Scoring model performance on slices:

Error during scoring: Cannot score empty labels
This often happens if one or more slices are empty in the test set.


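In this run, scoring failed because the `is_question` slice is empty. `clean_text` removes all punctuation (`[^\w\s]`), including the trailing `?` that `is_question` checks for. Removing `is_question` from `sfs`, or applying it to the uncleaned text, lets `score_slices` return a table of accuracy on the overall test set and on each slice. In that table, look for slices whose accuracy is well below the overall accuracy. These are the candidates for improvement.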
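Here is a minimal sketch of the raw-text option. It assumes a hypothetical `text_raw` column saved before `clean_text` runs, and the two tweets are made up. On the cleaned text, the original `is_question` check never finds a `?`. The same check on the raw text does.

```python
import re
import pandas as pd

def clean_text(text):
    # Same cleaning as in the tutorial; the last step drops all punctuation, including "?"
    text = text.lower()
    text = re.sub(r'https?://\S+|www\.\S+', '', text)
    text = re.sub(r'@[^\s]+', '', text)
    text = re.sub(r'#([^\s]+)', r'\1', text)
    text = re.sub(r'[^\w\s]', '', text)
    return text

df = pd.DataFrame({"text": ["Is anyone still awake?", "Good night all"]})  # made-up tweets
df["text_raw"] = df["text"]                  # hypothetical column: keep the uncleaned text
df["text"] = df["text"].apply(clean_text)

def is_question(x):
    return x.text.strip().endswith("?")      # original logic: reads the cleaned text

def is_question_raw(x):
    return x.text_raw.strip().endswith("?")  # variant: reads the raw text

print(df.apply(is_question, axis=1).tolist())      # [False, False]
print(df.apply(is_question_raw, axis=1).tolist())  # [True, False]
```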

## 3. Improve Slice Performance with `SliceAwareClassifier`

Once slice scoring has identified slices where the baseline performs poorly compared to its overall accuracy, we can use Slice-based Learning. This technique adds slice-specific components to the model to improve performance on those challenging subsets. Snorkel implements it via the `SliceAwareClassifier`.

### Constructing the SliceAwareClassifier

First, we need a base PyTorch architecture: the simple multi-layer perceptron (MLP) body returned by `get_pytorch_mlp_base` in your `utils.py`. Then we initialize the `SliceAwareClassifier` with:

- `base_architecture`: the core PyTorch model (our MLP body).
- `head_dim`: the output dimension of `base_architecture`, which is also the input dimension of the slice-specific heads. Because the MLP body has no final classification layer, this is its hidden dimension.
- `slice_names`: the names of the slices the model should be aware of.
- `scorer`: the `Scorer` used for evaluation.
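This notebook does not show `utils.get_pytorch_mlp_base`. The sketch below is a hypothetical stand-in, an assumption about what such a helper might look like rather than its actual code. It makes the `head_dim` requirement concrete. The body ends with a hidden layer and a ReLU and has no classification layer, so its output width equals `hidden_dim`.

```python
import torch
import torch.nn as nn

def get_pytorch_mlp_base(input_dim: int, hidden_dim: int, num_layers: int = 1) -> nn.Module:
    """Hypothetical MLP body: num_layers of Linear + ReLU, no final classification layer."""
    layers = []
    in_dim = input_dim
    for _ in range(num_layers):
        layers.append(nn.Linear(in_dim, hidden_dim))
        layers.append(nn.ReLU())
        in_dim = hidden_dim
    return nn.Sequential(*layers)

base = get_pytorch_mlp_base(input_dim=1000, hidden_dim=128, num_layers=1)
out = base(torch.randn(4, 1000))  # a dense batch of 4 feature vectors
print(out.shape)                  # torch.Size([4, 128]), so head_dim would be 128
```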

In [10]:
import torch
from snorkel.slicing import SliceAwareClassifier
import utils  # utils.py must define get_pytorch_mlp_base

# --- Define Model Architecture Parameters ---
if 'X_train' not in locals():
    print("Error: X_train (TF-IDF features) not defined. Please run Step 2 first.")
elif 'sfs' not in locals():
    print("Error: 'sfs' list not defined. Please define slicing functions first.")
elif 'scorer' not in locals():
    print("Error: 'scorer' object not defined. Please run Step 2 first.")
else:
    bow_dim = X_train.shape[1]  # Input dimension from TF-IDF features
    hidden_dim = 128            # Dimension of the MLP's hidden layer and base output

    # --- Create the base MLP body ---
    # utils.get_pytorch_mlp_base should return the MLP *without* a final classification layer
    base_mlp = utils.get_pytorch_mlp_base(input_dim=bow_dim, hidden_dim=hidden_dim, num_layers=1)  # one hidden layer for simplicity
    head_dim_to_use = hidden_dim  # must equal the output dimension of base_mlp

    # --- Initialize SliceAwareClassifier ---
    slice_model = SliceAwareClassifier(
        base_architecture=base_mlp,
        head_dim=head_dim_to_use,            # Dimension the slice heads take as input
        slice_names=[sf.name for sf in sfs], # List of slice names
        scorer=scorer,                       # Scorer object for evaluation
    )
    print("SliceAwareClassifier initialized successfully using get_pytorch_mlp_base.")

SliceAwareClassifier initialized successfully using get_pytorch_mlp_base.


### Prepare Slice-Aware DataLoaders

We need to apply our slicing functions to the training data (`df_train`) to get `S_train`. Then we create slice-aware PyTorch `DataLoader` objects that carry this slice information alongside the features and labels.

In [11]:
from snorkel.classification.data import DictDataset, DictDataLoader
import torch
import numpy as np
import scipy.sparse # Needed for sparse matrix check and conversion

# --- Helper Function to Convert SciPy Sparse to PyTorch Sparse ---
def scipy_sparse_to_torch_sparse(sparse_mx):
    """Converts a SciPy sparse matrix (COO format) to a PyTorch sparse tensor."""
    # Ensure matrix is in COOrdinate format
    if not isinstance(sparse_mx, scipy.sparse.coo_matrix):
        sparse_mx = sparse_mx.tocoo()
    sparse_mx = sparse_mx.astype(np.float32)
    indices = torch.from_numpy(
        np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64)
    )
    values = torch.from_numpy(sparse_mx.data)
    shape = torch.Size(sparse_mx.shape)
    # Use torch.sparse_coo_tensor for modern PyTorch versions
    return torch.sparse_coo_tensor(indices, values, shape)

# --- Code Block for Preparing Slice-Aware DataLoaders (Sparse Version) ---

print("Applying Slicing Functions to train set...")
if 'applier' not in locals():
    print("Error: SF Applier not defined. Please run Step 2 first.")
else:
    S_train = applier.apply(df_train)  # slice membership matrix for the training set
    print("SFs applied to train set.")

    # Convert features to PyTorch *SPARSE* tensors
    print("Converting features to PyTorch SPARSE tensors...")
    try:
        # Check if X_train/X_test are SciPy sparse matrices
        if isinstance(X_train, scipy.sparse.spmatrix) and isinstance(X_test, scipy.sparse.spmatrix):
            X_train_tensor = scipy_sparse_to_torch_sparse(X_train)
            X_test_tensor = scipy_sparse_to_torch_sparse(X_test)
        else:
            # Handle case where input might not be sparse (e.g., if using dense features)
            # This fallback might still cause memory errors if X_train/X_test are large dense numpy arrays
            print("Warning: X_train or X_test is not a SciPy sparse matrix. Attempting direct FloatTensor conversion.")
            X_train_tensor = torch.FloatTensor(X_train)
            X_test_tensor = torch.FloatTensor(X_test)

    except Exception as e:
        print(f"Error converting features to sparse tensors: {e}.")
        raise # Re-raise error to stop execution

    Y_train_tensor = torch.LongTensor(Y_train)
    Y_test_tensor = torch.LongTensor(Y_test)

    # Note: DictDataset accepts sparse feature tensors here, but the training step later fails on them (aten::view error)
    train_dataset = DictDataset.from_tensors(X_train_tensor, Y_train_tensor, "train")
    test_dataset = DictDataset.from_tensors(X_test_tensor, Y_test_tensor, "test")
    print("PyTorch sparse datasets created.")

    # Create slice-aware dataloaders
    BATCH_SIZE = 64
    print("Creating slice-aware dataloaders...")
    if 'slice_model' not in locals():
        print("Error: slice_model not initialized.")
    else:
        # The dataloader should yield batches where X is a sparse tensor
        train_dl_slice = slice_model.make_slice_dataloader(
            train_dataset, S_train, shuffle=True, batch_size=BATCH_SIZE
        )
        if 'S_test' not in locals():
            print("Error: S_test not defined. Please apply SFs to test set first.")
        else:
            test_dl_slice = slice_model.make_slice_dataloader(
                test_dataset, S_test, shuffle=False, batch_size=BATCH_SIZE
            )
            print("Dataloaders ready.")

print("\n--- IMPORTANT ---")
print("DataLoaders now use SPARSE tensors for features.")
print("Ensure your PyTorch model (`base_mlp` in SliceAwareClassifier) is designed to accept sparse input.")
print("Standard `nn.Linear` layers may need modification (e.g., using `torch.sparse.mm` or specific layers like `nn.EmbeddingBag`).")

Applying Slicing Functions to train set...


100%|██████████| 1280000/1280000 [05:17<00:00, 4032.51it/s]


SFs applied to train set.
Converting features to PyTorch SPARSE tensors...
PyTorch sparse datasets created.
Creating slice-aware dataloaders...
Dataloaders ready.

--- IMPORTANT ---
DataLoaders now use SPARSE tensors for features.
Ensure your PyTorch model (`base_mlp` in SliceAwareClassifier) is designed to accept sparse input.
Standard `nn.Linear` layers may need modification (e.g., using `torch.sparse.mm` or specific layers like `nn.EmbeddingBag`).


### Train the Slice-Aware Model

Now we train the `SliceAwareClassifier` using Snorkel's `Trainer`, which handles the multi-task learning process automatically.

In [12]:
from snorkel.classification import Trainer

print("\nTraining SliceAwareClassifier...")
if 'slice_model' not in locals() or 'train_dl_slice' not in locals():
    print("Error: slice_model or train_dl_slice not initialized. Cannot train.")
else:
    trainer = Trainer(n_epochs=3, lr=1e-3, progress_bar=True)
    trainer.fit(slice_model, [train_dl_slice])
    print("SliceAwareClassifier training complete.")


Training SliceAwareClassifier...


Epoch 0::   0%|          | 0/20000 [00:01<?, ?it/s]


NotImplementedError: Could not run 'aten::view' with arguments from the 'SparseCPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::view' is only available for these backends: [CPU, MPS, Meta, QuantizedCPU, MkldnnCPU, NestedTensorCPU, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradHIP, AutogradXLA, AutogradMPS, AutogradIPU, AutogradXPU, AutogradHPU, AutogradVE, AutogradLazy, AutogradMTIA, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, AutogradMeta, AutogradNestedTensor, Tracer, AutocastCPU, AutocastXPU, AutocastMPS, AutocastCUDA, FuncTorchBatched, BatchedNestedTensor, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PreDispatch, PythonDispatcher].

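The traceback shows that one operation in the training step (`aten::view`) has no kernel for sparse CPU tensors. One workaround is offered here as an untested suggestion, not a verified fix for Snorkel's `DictDataLoader`. The idea is to keep the TF-IDF matrix in SciPy's sparse format and densify only the current mini-batch. Memory then grows with the batch size rather than with the 1.28M training rows. `iter_dense_batches` below is a hypothetical helper. Each dense batch could be converted with `torch.from_numpy` before being passed to a model.

```python
import numpy as np
import scipy.sparse as sp

def iter_dense_batches(X, y, batch_size, shuffle=False, seed=0):
    """Yield (X_batch, y_batch) pairs, densifying one batch of rows at a time."""
    X = sp.csr_matrix(X)  # CSR format supports efficient row slicing
    n_rows = X.shape[0]
    rng = np.random.default_rng(seed)
    order = rng.permutation(n_rows) if shuffle else np.arange(n_rows)
    for start in range(0, n_rows, batch_size):
        idx = order[start:start + batch_size]
        # Only these rows are materialized as a dense float32 array
        yield X[idx].toarray().astype(np.float32), y[idx]

# Usage on a small synthetic sparse matrix (10 rows, 50 columns, one nonzero per row)
dense = np.zeros((10, 50), dtype=np.float32)
dense[np.arange(10), np.arange(10) * 3] = 1.0
X_small = sp.csr_matrix(dense)
y_small = np.arange(10)

batches = list(iter_dense_batches(X_small, y_small, batch_size=4))
print([xb.shape for xb, _ in batches])  # [(4, 50), (4, 50), (2, 50)]
```

Another option is a smaller feature space, for example by passing `max_features` to `TfidfVectorizer`, which shrinks each dense batch further.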


### Evaluate the Slice-Aware Model

Finally, evaluate the trained `SliceAwareClassifier` on the test-set slices. Its `score_slices` method reports metrics for the main task and for each slice-specific head. Because training did not complete above, this cell has not been executed.
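The comparison code in the next cell maps the slice-aware model's score labels back to the baseline's slice names. It expects labels of the form `task` and `task_slice:<name>_pred`, and that format is assumed here rather than verified against Snorkel. Under that assumption, the mapping can be pulled out into a small named function. The inline `str.replace('_pred', '')` removes every occurrence of `_pred`. This hypothetical helper strips it only as a suffix, so a slice name such as the made-up `is_prediction` survives intact.

```python
def label_to_slice_name(label: str) -> str:
    """Map 'task' -> 'overall' and 'task_slice:<name>_pred' -> '<name>'; pass others through."""
    prefix, suffix = "task_slice:", "_pred"
    if label == "task":
        return "overall"
    if label.startswith(prefix):
        name = label[len(prefix):]
        return name[: -len(suffix)] if name.endswith(suffix) else name
    return label

for label in ["task", "task_slice:short_tweet_pred", "task_slice:is_prediction_pred"]:
    print(label, "->", label_to_slice_name(label))
# task -> overall
# task_slice:short_tweet_pred -> short_tweet
# task_slice:is_prediction_pred -> is_prediction
```

In the notebook it could replace the lambda, e.g. `slice_aware_scores['label'].apply(label_to_slice_name)`.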

In [None]:
print("\nEvaluating SliceAwareClassifier on test slices:")
if 'slice_model' not in locals() or 'test_dl_slice' not in locals():
    print("Error: slice_model or test_dl_slice not initialized. Cannot evaluate.")
else:
    slice_aware_scores = slice_model.score_slices([test_dl_slice], as_dataframe=True)

    if 'slice_scores' not in locals():
        print("\nBaseline scores not found for comparison. Displaying SliceAwareClassifier scores:")
        display(slice_aware_scores)
    elif 'score' not in slice_aware_scores.columns:
        print("Could not find score column 'score' in slice_aware_scores.")
        display(slice_aware_scores)
    else:
        print("\nComparison with Baseline Model:")
        # Map labels such as "task" / "task_slice:<name>_pred" to the baseline's slice names
        slice_aware_scores['slice_name'] = slice_aware_scores['label'].apply(
            lambda x: x.replace('task_slice:', '').replace('_pred', '') if 'task_slice:' in x
            else ('overall' if x == 'task' else x)
        )
        comparison_df = slice_scores.rename(columns={"accuracy": "baseline_accuracy"})
        comparison_df.index.name = 'slice_name'  # baseline index holds "overall" and the slice names
        comparison_df = (
            comparison_df.reset_index()
            .merge(slice_aware_scores[['slice_name', 'score']], on='slice_name', how='left')
            .rename(columns={'score': 'slice_aware_accuracy'})
            .set_index('slice_name')
        )
        display(comparison_df[['baseline_accuracy', 'slice_aware_accuracy']])