# nanoTabStar: Manual Inference & Inspection

This notebook demonstrates how to load a trained **nanoTabStar** model and run manual predictions on samples from the validation split of the pretraining corpus. This is useful for inspecting how the model interprets specific feature combinations and target descriptions.

In [7]:
import os
import sys
import torch
import numpy as np
from transformers import AutoTokenizer

# Add the project root to sys.path so that `nanotabstar` can be imported below.
# A notebook kernel has no `__file__`, so the root is derived from the working
# directory instead: the project root is taken to be the PARENT of the
# directory the kernel was started in (i.e. this assumes the notebook runs from
# a folder located directly under the project root).
project_root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if project_root not in sys.path:
    # Insert at the front so the local checkout wins over any installed copy.
    sys.path.insert(0, project_root)

from nanotabstar.model import TabSTARModel
from nanotabstar.data_loader import TabSTARDataLoader
from nanotabstar.metrics import calculate_metrics

In [8]:
# Configuration
# Both paths are relative, so they resolve against the kernel's working directory.
MODEL_NAME = "intfloat/e5-small-v2"  # identifier passed to AutoTokenizer.from_pretrained
H5_PATH = "../data/pretrain_corpus_tabstar.h5"
MODEL_PATH = "../best_model.pt"

# Use the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")

print("Using device:", DEVICE)

Using device: cuda


In [9]:
# 1. Initialize Tokenizer (identified by MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# 2. Initialize Model Architecture
# NOTE(review): these hyperparameters presumably have to match the ones used
# when the checkpoint was trained, otherwise loading the weights below will fail.
model = TabSTARModel(d_model=384, n_layers=6, n_heads=6)

# 3. Load Weights
if os.path.exists(MODEL_PATH):
    print(f"Loading weights from {MODEL_PATH}...")
    # The checkpoint is used purely as a state dict, so restrict unpickling to
    # tensors and plain containers: a full pickle load can execute arbitrary
    # code embedded in the file.
    state_dict = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state_dict)
else:
    # Deliberate best-effort: the notebook stays runnable without a checkpoint,
    # but predictions then come from randomly initialized weights.
    print(f"WARNING: {MODEL_PATH} not found. Using a fresh model (random weights).")

model.to(DEVICE)
model.eval()  # switch to evaluation mode before running predictions
print("Model ready for inference.")

Loading weights from ../best_model.pt...
Model ready for inference.


In [10]:
# Build the data loader for the validation split.
# The batch is kept small so the per-sample printout further down stays readable.
loader_options = {
    "h5_path": H5_PATH,
    "tokenizer": tokenizer,
    "batch_size": 10,
    "steps_per_epoch": 1,
    "split": "val",
}
data_loader = TabSTARDataLoader(**loader_options)

TabSTARDataLoader (val) initialized with 5 datasets.


In [11]:
def predict_batch(batch):
    """Run the notebook-level ``model`` on one batch with gradients disabled.

    Args:
        batch: Mapping yielded by the data loader. The keys read here are
            ``feature_input_ids``, ``feature_attention_mask``,
            ``feature_num_values``, ``target_token_ids``,
            ``target_attention_mask``, ``task_type`` and ``labels``.

    Returns:
        A tuple ``(logits, labels, task_type)``: the model output, the batch's
        ``labels`` entry exactly as stored (it is not moved to ``DEVICE``),
        and the batch's ``task_type`` entry.
    """
    tensor_keys = (
        "feature_input_ids",
        "feature_attention_mask",
        "feature_num_values",
        "target_token_ids",
        "target_attention_mask",
    )
    # Only the model inputs are transferred to DEVICE; labels stay where they are.
    model_inputs = {key: batch[key].to(DEVICE) for key in tensor_keys}
    task_type = batch["task_type"]

    with torch.no_grad():
        logits = model(**model_inputs, task_type=task_type)

    return logits, batch["labels"], task_type

In [12]:
# Fetch a single batch and print, for every sample in it, the ground truth next
# to the model's prediction.
batch = next(iter(data_loader), None)

if batch is None:
    print("The data loader yielded no batches; nothing to inspect.")
else:
    logits, labels, task_type = predict_batch(batch)
    # Copy the model output to the CPU once, instead of triggering a separate
    # device-to-host transfer for every `.item()` call in the loop below.
    logits = logits.cpu()

    print(f"--- Dataset: {batch['dataset_name']} ({task_type}) ---")

    for i in range(len(labels)):
        print(f"\nSample {i+1}:")

        # 1. Task context
        # The raw feature/target strings are not readily available from the
        # batch, so no feature values are printed; the dataset name is used as
        # a stand-in for the target description.
        print(f"  Target Description: {batch['dataset_name']} prediction")

        # 2. Show Ground Truth vs Prediction
        true_val = labels[i].item()

        if task_type == 'classification':
            # Turn this sample's logits into probabilities and report the most
            # likely class together with its probability.
            probs = torch.softmax(logits[i], dim=0)
            pred_class = torch.argmax(probs).item()
            conf = probs[pred_class].item()
            print(f"  Ground Truth: Class {true_val}")
            print(f"  Prediction:   Class {pred_class} (Conf: {conf:.2%})")
        else:
            # Any other task type is treated as regression: the model output
            # for the sample is printed directly as the predicted value.
            pred_val = logits[i].item()
            print(f"  Ground Truth: {true_val:.4f}")
            print(f"  Prediction:   {pred_val:.4f}")

--- Dataset: CONCRETE_STRENGTH (regression) ---

Sample 1:
  Target Description: CONCRETE_STRENGTH prediction
  Ground Truth: -0.3951
  Prediction:   -0.5467

Sample 2:
  Target Description: CONCRETE_STRENGTH prediction
  Ground Truth: -0.1250
  Prediction:   -1.0056

Sample 3:
  Target Description: CONCRETE_STRENGTH prediction
  Ground Truth: 1.0685
  Prediction:   -0.4397

Sample 4:
  Target Description: CONCRETE_STRENGTH prediction
  Ground Truth: -0.4718
  Prediction:   -0.7448

Sample 5:
  Target Description: CONCRETE_STRENGTH prediction
  Ground Truth: 0.4643
  Prediction:   -0.6659

Sample 6:
  Target Description: CONCRETE_STRENGTH prediction
  Ground Truth: -0.0035
  Prediction:   -0.1680

Sample 7:
  Target Description: CONCRETE_STRENGTH prediction
  Ground Truth: -1.3354
  Prediction:   -1.2445

Sample 8:
  Target Description: CONCRETE_STRENGTH prediction
  Ground Truth: -0.3646
  Prediction:   -0.7703

Sample 9:
  Target Description: CONCRETE_STRENGTH prediction
  Ground Tru