# Transfer Learning Stack Demo

This notebook demonstrates the models and functionality introduced across the first five diffs in the transfer learning refactoring stack:

| # | Diff | What it introduces |
|---|------|--------------------|
| 1 | D92844568 | `FullyBayesianMultiTaskGP` as public API + Ax downstream |
| 2 | D92844565 | `SaasFullyBayesianMultiTaskGP` as thin subclass of the new base |
| 3 | D92844566 | Generalized internals — any `PyroModel(is_multitask=True)` works |
| 4 | D92844567 | `is_multitask` kwarg on `PyroModel` base class |
| 5 | D92836693 | `HeterogeneousMTGP` with inferred noise (`HadamardGaussianLikelihood`) |

In [20]:
import torch
from botorch.fit import fit_fully_bayesian_model_nuts
from botorch.models.fully_bayesian import (
    MaternPyroModel,
    SaasPyroModel,
)
from botorch.models.fully_bayesian_multitask import (
    FullyBayesianMultiTaskGP,
    SaasFullyBayesianMultiTaskGP,
)
from botorch.models.heterogeneous_mtgp import HeterogeneousMTGP

NUM_SAMPLES = 16
WARMUP = 32
THINNING = 1

print("Imports OK")

Imports OK


## Shared synthetic data

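Generate a simple 2-task dataset with 3 input features and 20 points per task. Task 0 is a sinusoid of the feature sum, `sin(x1 + x2 + x3)`, and task 1 is the linear function `x1 + x2 + x3`. Both tasks add Gaussian noise with standard deviation 0.1 (variance 0.01, matching `train_Yvar`). The task index is stored as the last column of `train_X`.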

In [21]:
torch.manual_seed(0)

d = 3  # input dimension
n_per_task = 20

# Task 0 data
X0 = torch.rand(n_per_task, d)
Y0 = torch.sin(X0.sum(dim=-1, keepdim=True)) + 0.1 * torch.randn(n_per_task, 1)
i0 = torch.zeros(n_per_task, 1)

# Task 1 data
X1 = torch.rand(n_per_task, d)
Y1 = X1.sum(dim=-1, keepdim=True) + 0.1 * torch.randn(n_per_task, 1)
i1 = torch.ones(n_per_task, 1)

# Combined training data with task feature as last column
train_X = torch.cat([torch.cat([X0, i0], dim=-1), torch.cat([X1, i1], dim=-1)])
train_Y = torch.cat([Y0, Y1])

# Known observation noise
train_Yvar = 0.01 * torch.ones_like(train_Y)

# Test points for task 0
test_X = torch.cat([torch.rand(5, d), torch.zeros(5, 1)], dim=-1)

print(f"train_X shape: {train_X.shape}")
print(f"train_Y shape: {train_Y.shape}")
print(f"test_X shape:  {test_X.shape}")

train_X shape: torch.Size([40, 4])
train_Y shape: torch.Size([40, 1])
test_X shape:  torch.Size([5, 4])
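
The combined `train_X` stores the task index in its last column, which is what `task_feature=-1` refers to in the cells below. As a minimal, framework-free sketch (plain Python lists stand in for tensors, and `split_by_task` is a hypothetical helper, not part of BoTorch), recovering the per-task groups from this layout looks like this:

```python
from collections import defaultdict


def split_by_task(rows, task_feature=-1):
    """Group rows by the integer task index stored in column `task_feature`.

    Hypothetical helper for illustration only. Returns {task: [features, ...]},
    where each feature list has the task column removed.
    """
    groups = defaultdict(list)
    for row in rows:
        col = task_feature % len(row)  # normalize negative column indices
        task = int(row[col])
        features = row[:col] + row[col + 1:]
        groups[task].append(features)
    return dict(groups)


# Two toy rows per task; the last column is the task index, as in train_X above.
rows = [
    [0.1, 0.2, 0.3, 0.0],
    [0.4, 0.5, 0.6, 0.0],
    [0.7, 0.8, 0.9, 1.0],
    [0.2, 0.1, 0.0, 1.0],
]
groups = split_by_task(rows)
print({task: len(feats) for task, feats in groups.items()})  # {0: 2, 1: 2}
```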


## 1. `SaasFullyBayesianMultiTaskGP` — the SAAS subclass (Diffs 1 + 2)

The existing `SaasFullyBayesianMultiTaskGP` now inherits from the new public `FullyBayesianMultiTaskGP` base and defaults to `SaasPyroModel(is_multitask=True)`. All existing behavior is preserved.

In [22]:
# The classic SAAS multi-task GP — unchanged API
model_saas = SaasFullyBayesianMultiTaskGP(
    train_X=train_X,
    train_Y=train_Y,
    train_Yvar=train_Yvar,
    task_feature=-1,
)

# Verify it is an instance of the new public base class
is_base_instance = isinstance(model_saas, FullyBayesianMultiTaskGP)
assert is_base_instance
print(f"SaasFullyBayesianMultiTaskGP is an instance of FullyBayesianMultiTaskGP: {is_base_instance}")
print(f"PyroModel type: {type(model_saas.pyro_model).__name__}")

fit_fully_bayesian_model_nuts(
    model_saas,
    warmup_steps=WARMUP,
    num_samples=NUM_SAMPLES,
    thinning=THINNING,
    disable_progbar=True,
)

posterior_saas = model_saas.posterior(test_X)
print(f"Posterior mean shape: {posterior_saas.mean.shape}")
print(f"Median lengthscale:   {model_saas.median_lengthscale}")
print("✅ SaasFullyBayesianMultiTaskGP works as before")

[W 260210 12:34:24 autoreload:1553] The model inputs are of type torch.float32. It is strongly recommended to use double precision in BoTorch, as this improves both precision and stability and can help avoid numerical errors. See https://github.com/meta-pytorch/botorch/discussions/1444


TypeError: super(type, obj): obj must be an instance or subtype of type
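
Section 1 relies on a common pattern: a general base class that accepts a pluggable prior object, plus a thin subclass that only supplies a default. The sketch below mimics that shape with hypothetical stand-in classes (`ToyPyroModel`, `ToySaasPyroModel`, `ToyMultiTaskGP`, `ToySaasMultiTaskGP`). It is not the BoTorch implementation; it assumes only that the subclass forwards a default SAAS prior built with `is_multitask=True` to the base constructor.

```python
class ToyPyroModel:
    """Stand-in for a PyroModel; only tracks the is_multitask flag."""

    def __init__(self, is_multitask=False):
        self.is_multitask = is_multitask


class ToySaasPyroModel(ToyPyroModel):
    """Stand-in for a SAAS prior."""


class ToyMultiTaskGP:
    """Stand-in for a generic fully Bayesian multi-task GP base class."""

    def __init__(self, train_X, task_feature, pyro_model):
        self.train_X = train_X
        self.task_feature = task_feature
        self.pyro_model = pyro_model


class ToySaasMultiTaskGP(ToyMultiTaskGP):
    """Thin subclass: identical to the base except for the default prior."""

    def __init__(self, train_X, task_feature, pyro_model=None):
        if pyro_model is None:
            pyro_model = ToySaasPyroModel(is_multitask=True)
        super().__init__(train_X, task_feature, pyro_model)


X = [[0.1, 0.0], [0.2, 1.0]]
subclass_model = ToySaasMultiTaskGP(X, task_feature=-1)
base_model = ToyMultiTaskGP(
    X, task_feature=-1, pyro_model=ToySaasPyroModel(is_multitask=True)
)
print(isinstance(subclass_model, ToyMultiTaskGP))  # True
print(type(subclass_model.pyro_model) is type(base_model.pyro_model))  # True
```

In this sketch the subclass only fills in a default, so constructing the base class with an explicit SAAS prior (the Section 3 approach) gives the same prior type.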

## 2. `FullyBayesianMultiTaskGP` with `MaternPyroModel` (Diffs 3 + 4)

The refactored base class now accepts *any* `PyroModel` constructed with `is_multitask=True`. Here we use `MaternPyroModel`, a simpler prior without SAAS sparsity.

In [None]:
# Use the public base class directly with a non-SAAS PyroModel
model_matern = FullyBayesianMultiTaskGP(
    train_X=train_X,
    train_Y=train_Y,
    train_Yvar=train_Yvar,
    task_feature=-1,
    pyro_model=MaternPyroModel(is_multitask=True),
)

print(f"PyroModel type: {type(model_matern.pyro_model).__name__}")
print(f"is_multitask:   {model_matern.pyro_model.is_multitask}")

fit_fully_bayesian_model_nuts(
    model_matern,
    warmup_steps=WARMUP,
    num_samples=NUM_SAMPLES,
    thinning=THINNING,
    disable_progbar=True,
)

posterior_matern = model_matern.posterior(test_X)
print(f"Posterior mean shape: {posterior_matern.mean.shape}")
print(f"Num MCMC samples:     {model_matern.num_mcmc_samples}")
print("✅ FullyBayesianMultiTaskGP + MaternPyroModel works")
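
`num_mcmc_samples` reports how many posterior draws the model keeps after fitting. Assuming the NUTS fitting routine discards the warmup draws and then keeps every `thinning`-th draw of the `num_samples` remaining ones, the count can be computed as below. `retained_mcmc_samples` is a hypothetical helper, not a BoTorch function.

```python
def retained_mcmc_samples(num_samples: int, thinning: int) -> int:
    """Number of draws kept when slicing num_samples draws with step `thinning`.

    Assumes warmup draws are already discarded; equals ceil(num_samples / thinning).
    """
    if num_samples < 0 or thinning < 1:
        raise ValueError("num_samples must be >= 0 and thinning >= 1")
    return len(range(0, num_samples, thinning))


print(retained_mcmc_samples(16, 1))    # 16 (NUM_SAMPLES=16, THINNING=1 as above)
print(retained_mcmc_samples(256, 16))  # 16
```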

## 3. `FullyBayesianMultiTaskGP` with `SaasPyroModel` explicitly (Diff 3)

We can also pass `SaasPyroModel(is_multitask=True)` explicitly to the base class — this is equivalent to using `SaasFullyBayesianMultiTaskGP` but demonstrates the generalized constructor.

In [None]:
# Explicit SaasPyroModel passed to the base class
model_explicit_saas = FullyBayesianMultiTaskGP(
    train_X=train_X,
    train_Y=train_Y,
    train_Yvar=train_Yvar,
    task_feature=-1,
    pyro_model=SaasPyroModel(is_multitask=True),
)

print(f"PyroModel type: {type(model_explicit_saas.pyro_model).__name__}")

fit_fully_bayesian_model_nuts(
    model_explicit_saas,
    warmup_steps=WARMUP,
    num_samples=NUM_SAMPLES,
    thinning=THINNING,
    disable_progbar=True,
)

posterior_explicit = model_explicit_saas.posterior(test_X)
print(f"Posterior mean shape: {posterior_explicit.mean.shape}")
print("✅ FullyBayesianMultiTaskGP + explicit SaasPyroModel works")

## 4. Validation: `is_multitask=True` is enforced (Diff 4)

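Passing a single-task `PyroModel` (one constructed with `is_multitask=False`) to `FullyBayesianMultiTaskGP` should raise a `ValueError`. Similarly, calling `set_inputs` on a `PyroModel` constructed with `is_multitask=True` while passing `task_feature=None` should raise a `ValueError` with a helpful message.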

In [None]:
# 4a. Verify that a non-multitask PyroModel is rejected
rejected = False
try:
    FullyBayesianMultiTaskGP(
        train_X=train_X,
        train_Y=train_Y,
        train_Yvar=train_Yvar,
        task_feature=-1,
        pyro_model=MaternPyroModel(is_multitask=False),  # Wrong!
    )
except ValueError as e:
    rejected = True
    print(f"✅ Non-multitask PyroModel correctly rejected: {e}")

assert rejected, "Should have raised ValueError"

# 4b. Verify that task_feature=None with is_multitask=True raises ValueError
pm = MaternPyroModel(is_multitask=True)
task_feature_rejected = False
try:
    pm.set_inputs(
        train_X=train_X,
        train_Y=train_Y,
        train_Yvar=train_Yvar,
        task_feature=None,  # Missing!
    )
except ValueError as e:
    task_feature_rejected = True
    print(f"✅ Missing task_feature correctly rejected: {e}")

assert task_feature_rejected, "Should have raised ValueError"
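
The two checks above follow a fail-fast pattern: validate the prior's `is_multitask` flag when the model is constructed, and validate `task_feature` when inputs are set. A minimal sketch with hypothetical stand-in classes (not the BoTorch code; the error messages are invented for illustration):

```python
class ToyPyroModel:
    """Stand-in prior that knows whether it was built for multi-task data."""

    def __init__(self, is_multitask: bool = False):
        self.is_multitask = is_multitask
        self.task_feature = None

    def set_inputs(self, train_X, train_Y, task_feature=None):
        # A multi-task prior cannot locate the task column without task_feature.
        if self.is_multitask and task_feature is None:
            raise ValueError("task_feature must be provided when is_multitask=True.")
        self.train_X, self.train_Y = train_X, train_Y
        self.task_feature = task_feature


class ToyMultiTaskGP:
    """Stand-in model that only accepts multi-task priors."""

    def __init__(self, train_X, train_Y, task_feature, pyro_model):
        if not pyro_model.is_multitask:
            raise ValueError("pyro_model must be constructed with is_multitask=True.")
        pyro_model.set_inputs(train_X, train_Y, task_feature=task_feature)
        self.pyro_model = pyro_model


def raises_value_error(fn) -> bool:
    """Return True if calling fn() raises ValueError."""
    try:
        fn()
    except ValueError:
        return True
    return False


X, Y = [[0.5, 0.0], [0.5, 1.0]], [[1.0], [2.0]]
ok = ToyMultiTaskGP(X, Y, task_feature=-1, pyro_model=ToyPyroModel(is_multitask=True))
print(ok.pyro_model.task_feature)  # -1
print(raises_value_error(
    lambda: ToyMultiTaskGP(X, Y, task_feature=-1, pyro_model=ToyPyroModel(False))
))  # True
print(raises_value_error(
    lambda: ToyPyroModel(is_multitask=True).set_inputs(X, Y, task_feature=None)
))  # True
```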

## 5. `HeterogeneousMTGP` with inferred noise (Diff 5)

Diff 5 (D92836693) lets `HeterogeneousMTGP` infer the observation noise by using a `HadamardGaussianLikelihood`. We first construct a model with known noise for comparison, then a second model with `train_Yvars=None` so that the noise is inferred.
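
Inferring noise per task means learning one noise variance per task and assigning it to each observation according to that observation's task index. Assuming `HadamardGaussianLikelihood` follows this task-indexed scheme, the bookkeeping can be sketched in plain Python. `per_observation_noise` is a hypothetical helper, and the noise values below are made up rather than learned.

```python
from typing import List, Sequence


def per_observation_noise(
    task_indices: Sequence[int], task_noise: Sequence[float]
) -> List[float]:
    """Look up each observation's noise variance from its task's noise level."""
    for t in task_indices:
        if not 0 <= t < len(task_noise):
            raise IndexError(f"task index {t} has no noise parameter")
    return [task_noise[t] for t in task_indices]


# Three observations from task 0 and two from task 1 (illustrative values).
tasks = [0, 0, 0, 1, 1]
noise_by_task = [0.01, 0.05]  # one would-be-learned variance per task
print(per_observation_noise(tasks, noise_by_task))
# [0.01, 0.01, 0.01, 0.05, 0.05]
```

With known noise, these per-observation variances are instead supplied directly through `train_Yvars`.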

In [None]:
torch.manual_seed(42)

# Task 0: 3 features [0, 1, 2]
X0_het = torch.rand(15, 3)
Y0_het = torch.sin(X0_het.sum(-1, keepdim=True)) + 0.1 * torch.randn(15, 1)

# Task 1: 2 features [0, 2] (different search space)
X1_het = torch.rand(15, 2)
Y1_het = X1_het.sum(-1, keepdim=True) + 0.1 * torch.randn(15, 1)

feature_indices = [[0, 1, 2], [0, 2]]  # task 1 shares features 0 and 2
full_feature_dim = 3

# With known noise
model_het_known = HeterogeneousMTGP(
    train_Xs=[X0_het, X1_het],
    train_Ys=[Y0_het, Y1_het],
    train_Yvars=[0.01 * torch.ones(15, 1), 0.01 * torch.ones(15, 1)],
    feature_indices=feature_indices,
    full_feature_dim=full_feature_dim,
)
print(f"HeterogeneousMTGP (known noise) likelihood: {type(model_het_known.likelihood).__name__}")

# With inferred noise (new in Diff 5)
model_het_inferred = HeterogeneousMTGP(
    train_Xs=[X0_het, X1_het],
    train_Ys=[Y0_het, Y1_het],
    train_Yvars=None,  # Infer noise!
    feature_indices=feature_indices,
    full_feature_dim=full_feature_dim,
)
print(f"HeterogeneousMTGP (inferred noise) likelihood: {type(model_het_inferred.likelihood).__name__}")

# Test posterior for the inferred-noise model
test_X_het = torch.rand(5, 3)  # task 0 features only
posterior_het = model_het_inferred.posterior(test_X_het)
print(f"Posterior mean shape: {posterior_het.mean.shape}")
print("✅ HeterogeneousMTGP with inferred noise (HadamardGaussianLikelihood) works")
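
`feature_indices` tells the model which columns of the `full_feature_dim`-dimensional space each task observes. The plain-Python sketch below uses hypothetical helpers and does not reproduce how `HeterogeneousMTGP` consumes these indices internally. It embeds a task-local point into the full space and lists the features that every task shares:

```python
from math import nan


def embed_point(local_x, indices, full_feature_dim):
    """Place a task-local point into the full feature space.

    Features the task does not observe are left as NaN (an illustrative choice).
    """
    if len(local_x) != len(indices):
        raise ValueError("local_x and indices must have the same length")
    full = [nan] * full_feature_dim
    for value, idx in zip(local_x, indices):
        full[idx] = value
    return full


def shared_features(feature_indices):
    """Sorted indices of the features observed by every task."""
    return sorted(set.intersection(*(set(ix) for ix in feature_indices)))


feature_indices = [[0, 1, 2], [0, 2]]
print(shared_features(feature_indices))  # [0, 2]
print(embed_point([0.3, 0.7], feature_indices[1], 3))  # [0.3, nan, 0.7]
```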

## Summary

All five diffs are exercised:

| Diff | Model / Feature | Status |
|------|-----------------|--------|
| D92844568 | `FullyBayesianMultiTaskGP` as public API | ✅ |
| D92844565 | `SaasFullyBayesianMultiTaskGP` as subclass | ✅ |
| D92844566 | Generalized internals (`MaternPyroModel`, `SaasPyroModel`) | ✅ |
| D92844567 | `is_multitask` on `PyroModel` + validation | ✅ |
| D92836693 | `HeterogeneousMTGP` inferred noise | ✅ |