In [6]:
import torch
from tabpfn.model.loading import load_model_criterion_config
from tabpfn import TabPFNClassifier

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

---
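### PyTorch interface

Load the base TabPFN v2 classifier. For the TabPFN-Wide variants, overwrite its weights with the corresponding checkpoint.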

In [7]:
# Which model to run. "TabPFNv2" uses the base weights as-is; every other name
# selects a TabPFN-Wide checkpoint file under CHECKPOINT_DIR.
SUPPORTED_MODELS = ("TabPFN-Wide-1.5k", "TabPFN-Wide-5k", "TabPFN-Wide-8k", "TabPFNv2")
CHECKPOINT_DIR = "./models"

model_name = "TabPFN-Wide-8k"
# Explicit check instead of `assert`, which is removed when Python runs with -O.
if model_name not in SUPPORTED_MODELS:
    raise ValueError(
        f"Model name {model_name} not recognized. "
        f"Choose one of: {', '.join(SUPPORTED_MODELS)}."
    )
checkpoint_path = f"{CHECKPOINT_DIR}/{model_name}_submission.pt"

In [8]:
# Build the base TabPFN v2 classifier. Only the model is kept; the other two return
# values (presumably criterion and config, per the function name) are discarded.
# NOTE(review): download=True presumably fetches the base weights when they are not
# cached locally — confirm against tabpfn's loader.
model, _, _ = load_model_criterion_config(
    model_path=None,
    check_bar_distribution_criterion=False,
    cache_trainset_representation=False,
    which="classifier",
    version="v2",
    download=True,
)

if model_name != "TabPFNv2":
    # TabPFN-Wide variants run with one feature per group; it is set before the
    # checkpoint is loaded. NOTE(review): presumably this must match how the Wide
    # checkpoints were trained — confirm.
    model.features_per_group = 1
    # The checkpoint is handed directly to load_state_dict, so it only needs to
    # contain tensors. weights_only=True refuses to unpickle arbitrary objects,
    # which could otherwise execute code embedded in the .pt file.
    checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
    model.load_state_dict(checkpoint)

---
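### scikit-learn interface (requires the PyTorch model loaded above)

The patched `fit` takes the loaded PyTorch model through its `model` keyword, so TabPFN-Wide can be used with the familiar `fit` / `score` API.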

In [None]:
from tabpfnwide.patches import fit
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split

# A single estimator disables ensembling. ignore_pretraining_limits=True lifts the
# wrapper's built-in input-size limits (presumably needed for wide inputs — confirm).
tabpfn_classifier = TabPFNClassifier(
    n_estimators=1,
    device=device,
    ignore_pretraining_limits=True,
)

# Swap in TabPFN-Wide's fit on the class itself, so every TabPFNClassifier instance
# (including the one above) gets it. The patched fit receives the loaded PyTorch
# model via fit(..., model=...).
TabPFNClassifier.fit = fit

# Toy binary problem: 50 samples with 10 features, then an 80/20 train/test split.
X, y = make_classification(
    n_samples=50,
    n_features=10,
    n_informative=2,
    n_redundant=2,
    n_classes=2,
    random_state=42,
)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

In [5]:
# Fit through the patched sklearn interface using the model loaded above, then
# display the score on the held-out split.
tabpfn_classifier.fit(X_train, y_train, model=model)
test_score = tabpfn_classifier.score(X_test, y_test)
test_score

0.9