This notebook is used to recreate ForecastPFN in Pytorch using its Tensorflow saved weights and implementation provided by the authors.

In [1]:
import tensorflow as tf
from benchmark.utils.metrics import smape
import numpy as np
import torch

2025-06-04 00:06:30.063569: I tensorflow/core/util/util.cc:169] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2025-06-04 00:06:30.066011: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcudart.so.11.0'; dlerror: libcudart.so.11.0: cannot open shared object file: No such file or directory
2025-06-04 00:06:30.066018: I tensorflow/stream_executor/cuda/cudart_stub.cc:29] Ignore above cudart dlerror if you do not have a GPU set up on your machine.
  from .autonotebook import tqdm as notebook_tqdm


# Tensorflow Model

In [2]:
tf_model = tf.keras.models.load_model("saved_weights/", custom_objects={'smape': smape})
tf_model.summary()

2025-06-04 00:06:35.014395: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:975] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2025-06-04 00:06:35.015834: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcufft.so.10'; dlerror: libcufft.so.10: cannot open shared object file: No such file or directory
2025-06-04 00:06:35.044542: W tensorflow/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libcusparse.so.11'; dlerror: libcusparse.so.11: cannot open shared object file: No such file or directory
2025-06-04 00:06:35.044566: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1850] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the requir

Model: "transformer_model"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 position_expansion (Positio  multiple                 0         
 nExpansion)                                                     
                                                                 
 position_expansion_1 (Posit  multiple                 0         
 ionExpansion)                                                   
                                                                 
 position_expansion_2 (Posit  multiple                 0         
 ionExpansion)                                                   
                                                                 
 position_expansion_3 (Posit  multiple                 0         
 ionExpansion)                                                   
                                                                 
 custom_scaling (CustomScali  multiple           

In [3]:
for layer in tf_model.layers:
    print(layer.name)
    for w in layer.weights:
        print(" ", w.name, w.shape)

position_expansion
position_expansion_1
position_expansion_2
position_expansion_3
custom_scaling
NoPosEnc
  transformer_model/NoPosEnc/kernel:0 (1, 36)
  transformer_model/NoPosEnc/bias:0 (36,)
ForPosEnc
  transformer_model/ForPosEnc/kernel:0 (1, 36)
  transformer_model/ForPosEnc/bias:0 (36,)
ConcatPos
ConcatEmbed
embedding
  transformer_model/embedding/embeddings:0 (10, 36)
AppendTarget
transformer_block
  transformer_model/transformer_block/transformer_block_attention/query/kernel:0 (72, 4, 72)
  transformer_model/transformer_block/transformer_block_attention/query/bias:0 (4, 72)
  transformer_model/transformer_block/transformer_block_attention/key/kernel:0 (72, 4, 72)
  transformer_model/transformer_block/transformer_block_attention/key/bias:0 (4, 72)
  transformer_model/transformer_block/transformer_block_attention/value/kernel:0 (72, 4, 72)
  transformer_model/transformer_block/transformer_block_attention/value/bias:0 (4, 72)
  transformer_model/transformer_block/transformer_block

In [4]:
# Load model
layers = {l.name: l for l in tf_model.layers}

In [5]:
def create_mock_input(B=2, context_len=100, target_len=1):
    # ts: (B, context_len, 5) — year, month, day, dow, placeholder
    ts = tf.constant(np.stack([
        np.stack([
            [2023, 1 + (i % 12), 1 + (i % 28), i % 7, 0]
            for i in range(context_len)
        ]) for _ in range(B)
    ]), dtype=tf.int64)

    # history: (B, context_len)
    history = tf.random.normal((B, context_len), dtype=tf.float32)

    # target_ts: (B, target_len, 5) — consistent with ts
    target_ts = tf.constant(np.stack([
        np.stack([
            [2023, 1, 1, 0, 0]  # dummy future time
        for _ in range(target_len)]) for _ in range(B)
    ]), dtype=tf.int64)

    # task: (B,)
    task = tf.constant(np.random.randint(0, 10, size=(B,)), dtype=tf.int32)

    return {
        "ts": ts,
        "history": history,
        "target_ts": target_ts,
        "task": task,
    }

In [6]:
mock_inputs = create_mock_input()
output = tf_model(mock_inputs, training=False)

# Extract from full model output
full_result = output["result"][:, :1]
full_scale  = output["scale"][:, :1]
print(full_result, full_scale)

tf.Tensor(
[[1.0238485]
 [1.0556059]], shape=(2, 1), dtype=float32) tf.Tensor(
[[1.2981894]
 [1.3832798]], shape=(2, 1), dtype=float32)


### Model Breakdown

In [7]:
ts        = mock_inputs["ts"]
history   = mock_inputs["history"]
target_ts = mock_inputs["target_ts"]
task      = mock_inputs["task"]
print(ts.shape, history.shape, target_ts.shape, task.shape)
epsilon = 1e-4
# Extract layers
position_expansion = layers["position_expansion"]
position_expansion_1 = layers["position_expansion_1"]
position_expansion_2 = layers["position_expansion_2"]
position_expansion_3 = layers["position_expansion_3"]
custom_scaling = layers["custom_scaling"]
no_pos_enc = layers["NoPosEnc"]
for_pos_enc = layers["ForPosEnc"]
concat_embed = layers["ConcatEmbed"]
append_target = layers["AppendTarget"]
embedding = layers["embedding"]
transformer_block_0 = layers["transformer_block"]
transformer_block_1 = layers["transformer_block_1"]
final_output = layers["FinalOutput"]

# ---- Block-by-block forward ----

# 1. Position encoding
year, month, day, dow = ts[:, :, 0], ts[:, :, 1], ts[:, :, 2], ts[:, :, 3]
delta_year = tf.clip_by_value(year[:, -1:] - year, 0, 9)
pos = tf.concat([
    position_expansion(delta_year),
    position_expansion_1(month),
    position_expansion_2(day),
    position_expansion_3(dow),
], axis=-1)

# 2. Scaling + history embedding
history_channels = tf.expand_dims(history, -1)
scale, scaled = custom_scaling(history_channels, epsilon)
embed_nopos = no_pos_enc(scaled)
embed_pos = for_pos_enc(scaled) + tf.cast(pos, tf.float32)
embedded = concat_embed([embed_nopos, embed_pos])

# 3. Target token
target_year = tf.clip_by_value(year[:, -1:] - target_ts[:, :, 0], 0, 9)
target_month, target_day, target_dow = target_ts[:, :, 1], target_ts[:, :, 2], target_ts[:, :, 3]
target_pos = tf.concat([
    position_expansion(target_year),
    position_expansion_1(target_month),
    position_expansion_2(target_day),
    position_expansion_3(target_dow),
], axis=-1)
task_embed = embedding(task)
target_token = concat_embed([task_embed, task_embed + tf.cast(tf.squeeze(target_pos, 1), tf.float32)])

# 4. Append target token
x = append_target([embedded, tf.expand_dims(target_token, axis=1)])

# 5. Mask
seq_mask = tf.cast(year > 0, tf.bool)
seq_mask = tf.pad(seq_mask, [[0, 0], [0, 1]], constant_values=True)
mask = tf.logical_and(tf.expand_dims(seq_mask, 1), tf.expand_dims(seq_mask, 2))

# 6. Transformer blocks
x = transformer_block_0(x, mask=mask, training=False)
x = transformer_block_1(x, mask=mask, training=False)
x = x[:, -1]

# 7. Output
rescaled = final_output(x) * tf.squeeze(scale[:, -1:], axis=-1)
print(final_output, rescaled)

(2, 100, 5) (2, 100) (2, 1, 5) (2,)
<keras.layers.core.dense.Dense object at 0x75ae957db820> tf.Tensor(
[[1.0238485]
 [1.0556059]], shape=(2, 1), dtype=float32)


Verify manually-created version vs. original tensorflow version

In [8]:
print("Manual forward output:", rescaled.numpy())
print("Full model output    :", full_result.numpy())
print("Match (result):", np.allclose(rescaled.numpy(), full_result.numpy(), atol=1e-5))

manual_scale = np.squeeze(scale.numpy(), axis=-1)  # → shape (2, 1)
print("Manual scale:", scale[:, -1:].numpy())
print("Full scale  :", full_scale.numpy())
print("Match (scale):", np.allclose(manual_scale, full_scale.numpy(), atol=1e-5))

Manual forward output: [[1.0238485]
 [1.0556059]]
Full model output    : [[1.0238485]
 [1.0556059]]
Match (result): True
Manual scale: [[[1.2981894]]

 [[1.3832798]]]
Full scale  : [[1.2981894]
 [1.3832798]]
Match (scale): True


# Recreate Components in Pytorch using Tensorflow Weights

In [9]:
import math
import torch.nn as nn

Mimic input in Pytorch

In [10]:
def convert_tf_to_torch(tf_batch):
    torch_batch = {}
    for k, v in tf_batch.items():
        torch_batch[k] = torch.from_numpy(v.numpy()).to(dtype=torch.float32 if v.dtype.is_floating else torch.long)
    return torch_batch

torch_input  = convert_tf_to_torch(mock_inputs)
ts_pt        = torch_input['ts']
history_pt   = torch_input['history']
target_ts_pt = torch_input['target_ts']
task_pt      = torch_input['task']
print(ts_pt.shape, history_pt.shape, target_ts_pt.shape, task_pt.shape)

torch.Size([2, 100, 5]) torch.Size([2, 100]) torch.Size([2, 1, 5]) torch.Size([2])


### Step 1: Input → Position Encodings

In [11]:
# tensorflow code
year, month, day, dow = ts[:, :, 0], ts[:, :, 1], ts[:, :, 2], ts[:, :, 3]
delta_year = tf.clip_by_value(year[:, -1:] - year, 0, 9)
pos_tf = tf.concat([
    position_expansion(delta_year),
    position_expansion_1(month),
    position_expansion_2(day),
    position_expansion_3(dow),
], axis=-1)

pos_tf.shape

TensorShape([2, 100, 36])

In [12]:
class PositionExpansion(nn.Module):
    def __init__(self, periods: int, freqs: int):
        super().__init__()
        self.periods = periods
        self.channels = freqs * 2

        i = torch.arange(periods + 1).unsqueeze(1)  # shape: [periods+1, 1]
        j = torch.arange(freqs).unsqueeze(0)        # shape: [1, freqs]
        angles = math.pi / periods * (2 ** j) * (i - 1)  # i-1 matches TF

        pe = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)  # [P+1, 2F]
        self.register_buffer("embedding", pe)

    def forward(self, tc):
        return self.embedding[tc]  # expects tc ∈ [0, periods]


In [13]:
# Convert tf tensors to numpy, then to torch
year_pt, month_pt, day_pt, dow_pt = ts_pt[:, :, 0], ts_pt[:, :, 1], ts_pt[:, :, 2], ts_pt[:, :, 3]
delta_year_pt = (year_pt[:, -1:] - year_pt).clamp(min=0, max=9)

In [14]:
position_expansion_pt = PositionExpansion(10, 4)
position_expansion_1_pt = PositionExpansion(12, 4)
position_expansion_2_pt = PositionExpansion(31, 6)
position_expansion_3_pt = PositionExpansion(7, 4)

# Positional encoding
pos_pt = torch.cat([
    position_expansion_pt(delta_year_pt),
    position_expansion_1_pt(month_pt),
    position_expansion_2_pt(day_pt),
    position_expansion_3_pt(dow_pt),
], dim=-1)

pos_pt.shape

torch.Size([2, 100, 36])

In [15]:
np.allclose(pos_pt.numpy(), pos_tf.numpy(), atol=1e-5)

True

### Step 2: Scaling + History Embedding

In [18]:
# tensorflow code
history_channels_tf = tf.expand_dims(history, -1)
scale_tf, scaled_tf = custom_scaling(history_channels_tf, epsilon)
embed_nopos_tf = no_pos_enc(scaled_tf)
embed_pos_tf = for_pos_enc(scaled_tf) + tf.cast(pos_tf, tf.float32)
embedded_tf = concat_embed([embed_nopos_tf, embed_pos_tf])
embedded_tf.shape

TensorShape([2, 100, 72])

In [19]:
class RobustScaler(nn.Module):
    """
    RobustScaler normalizes input time series while ignoring outliers and missing values.
    It masks out zeros, clips extreme values above mean + 2*std, rescales using
    mean + std of clipped data, and clips final output to [0, 3].
    This improves robustness in the presence of noise or missing data.
    """
    def forward(self, x, epsilon):
        # x: [B, T, 1]
        B, T, _ = x.shape
        x = x.squeeze(-1)  # → [B, T]
        scale = torch.zeros((B, 1, 1), device=x.device)
        scaled = torch.zeros((B, T, 1), device=x.device)

        for b in range(B):
            series = x[b]  # shape: [T]

            # First mask and stats
            non_zero = series[series != 0]
            if non_zero.numel() == 0:
                mean = std = torch.tensor(0.0, device=x.device)
            else:
                mean = non_zero.mean()
                std = non_zero.std(unbiased=False)

            upper = mean + 2 * std
            clipped = torch.clamp(series, min=0.0, max=upper)

            # Second pass stats
            non_zero_clipped = clipped[clipped != 0]
            if non_zero_clipped.numel() == 0:
                mean_clip = std_clip = torch.tensor(0.0, device=x.device)
            else:
                mean_clip = non_zero_clipped.mean()
                std_clip = non_zero_clipped.std(unbiased=False)

            s = mean_clip + std_clip + epsilon
            scale[b, 0, 0] = s
            scaled[b, :, 0] = torch.clamp(series / s, 0.0, 3.0)

        return scale, scaled

In [22]:
robust_scaler = RobustScaler()
# history_channels_torch = torch.from_numpy(history_channels.numpy())
history_channels_pt = history_pt.unsqueeze(-1)
scale_pt, scaled_pt = robust_scaler(history_channels_pt, epsilon)
np.allclose(scale_pt, scale_tf, atol=1e-5)

True

In [24]:
# 1. Create PyTorch Linear layer
expand_nopos_pt = nn.Sequential(
    nn.Linear(1, 36),
    nn.ReLU()
)

# 2. Get TF weights 
# tf_weights: shape (1, 36), tf_bias: shape (36,)
tf_weights = no_pos_enc.kernel.numpy().T  # TF: (1, 36) → PT: (36, 1)
tf_bias    = no_pos_enc.bias.numpy()      # shape (36,)

# 3. Copy weights to PyTorch layer
linear_nopos = expand_nopos_pt[0]
linear_nopos.weight.data.copy_(torch.from_numpy(tf_weights))
linear_nopos.bias.data.copy_(torch.from_numpy(tf_bias))

# 4. Test match
embed_nopos_pt = expand_nopos_pt(scaled_pt)  # [B, T, 36]

# 5. Compare
print("Match:", np.allclose(embed_nopos_pt.detach().numpy(), embed_nopos_tf.numpy(), atol=1e-5))

Match: True


In [25]:
expand_forpos_pt = nn.Sequential(
    nn.Linear(1, 36),
    nn.ReLU()
)

tf_weights = for_pos_enc.kernel.numpy().T
tf_bias    = for_pos_enc.bias.numpy()

linear_forpos = expand_forpos_pt[0]
linear_forpos.weight.data.copy_(torch.from_numpy(tf_weights))
linear_forpos.bias.data.copy_(torch.from_numpy(tf_bias))

embed_pos_pt = expand_forpos_pt(scaled_pt) + pos_pt.float()
np.allclose(embed_pos_pt.detach().numpy(), embed_pos_tf.numpy(), atol=1e-5)

True

In [26]:
embedded_pt = torch.cat([embed_nopos_pt, embed_pos_pt], dim=-1)
np.allclose(embedded.numpy(), embedded_pt.detach().numpy(), atol=1e-5)

True

### Step 3: Target Token & Mask

In [27]:
# tensorflow code
# 3. Target token
target_year = tf.clip_by_value(year[:, -1:] - target_ts[:, :, 0], 0, 9)
target_month, target_day, target_dow = target_ts[:, :, 1], target_ts[:, :, 2], target_ts[:, :, 3]
target_pos_tf = tf.concat([
    position_expansion(target_year),
    position_expansion_1(target_month),
    position_expansion_2(target_day),
    position_expansion_3(target_dow),
], axis=-1)
task_embed_tf = embedding(task)
target_token_tf = concat_embed([task_embed, task_embed + tf.cast(tf.squeeze(target_pos, 1), tf.float32)])
target_token_tf.shape

# 4. Append target token
x_tf = append_target([embedded, tf.expand_dims(target_token, axis=1)])

# 5. Mask
seq_mask = tf.cast(year > 0, tf.bool)
seq_mask = tf.pad(seq_mask, [[0, 0], [0, 1]], constant_values=True)
mask_tf = tf.logical_and(tf.expand_dims(seq_mask, 1), tf.expand_dims(seq_mask, 2))

In [63]:
# Extract components from target_ts
target_year_pt = (year_pt[:, -1:] - target_ts_pt[:, :, 0]).clamp(min=0, max=10)
target_month_pt = target_ts_pt[:, :, 1]
target_day_pt = target_ts_pt[:, :, 2]
target_dow_pt = target_ts_pt[:, :, 3]

In [64]:
# Position Encodings
target_pos_pt = torch.cat([
    position_expansion_pt(target_year_pt),
    position_expansion_1_pt(target_month_pt),
    position_expansion_2_pt(target_day_pt),
    position_expansion_3_pt(target_dow_pt),
], dim=-1)  # shape: [B, 1, 36]
np.allclose(target_pos_pt, target_pos_tf, atol=1e-5)

True

In [31]:
# Task embedding
embedding_pt = nn.Embedding(num_embeddings=10, embedding_dim=36)
embedding_pt.weight.data.copy_(torch.from_numpy(embedding.get_weights()[0]))
task_embed_pt = embedding_pt(task_pt)
np.allclose(task_embed_pt.detach().numpy(), task_embed_tf.numpy(), atol=1e-5)

True

In [32]:
# Target token
target_token_pt = torch.cat([task_embed_pt,
                             task_embed_pt + target_pos_pt.squeeze(1).float()], dim=-1).detach()
np.allclose(target_token_pt.numpy(), target_token_tf.numpy(), atol=1e-5)

True

In [33]:
# x: [B, T, D], target_token: [B, D]
x_pt = torch.cat([embedded_pt, target_token_pt.unsqueeze(1)], dim=1) # -> [B, T+1, D]
np.allclose(x_tf.numpy(), x_pt.detach().numpy(), atol=1e-5)

True

In [34]:
# Assume `year` is shape [B, T] from ts[:, :, 0]
seq_mask_pt = (year_pt > 0)  # → [B, T], bool
seq_mask_pt = torch.cat([seq_mask_pt, torch.ones_like(seq_mask_pt[:, :1], dtype=torch.bool)], dim=1)  # [B, T+1]

# Broadcast to [B, T+1, T+1]
mask_pt = seq_mask_pt.unsqueeze(1) & seq_mask_pt.unsqueeze(2)

np.allclose(mask_tf.numpy(), mask_pt.numpy(), atol=1e-5)

True

### Step 4: Transformer Blocks

In [35]:
x_tf = transformer_block_0(x_tf, mask=mask_tf, training=False)
# x_tf = transformer_block_1(x_tf, mask=mask_tf, training=False)
# x_tf = x_tf[:, -1]
x_tf.shape

TensorShape([2, 101, 288])

In [36]:
class CustomSelfAttention(nn.Module):
    def __init__(self, embed_dim=72, num_heads=4, value_dim=72):
        super().__init__()
        self.q_proj = nn.Linear(embed_dim, num_heads * value_dim)
        self.k_proj = nn.Linear(embed_dim, num_heads * value_dim)
        self.v_proj = nn.Linear(embed_dim, num_heads * value_dim)
        self.out_proj = nn.Linear(num_heads * value_dim, embed_dim)
        self.num_heads = num_heads
        self.value_dim = value_dim

    def forward(self, x, mask=None):
        B, T, D = x.size()
        H, V = self.num_heads, self.value_dim

        q = self.q_proj(x).view(B, T, H, V).transpose(1, 2)
        k = self.k_proj(x).view(B, T, H, V).transpose(1, 2)
        v = self.v_proj(x).view(B, T, H, V).transpose(1, 2)

        attn_scores = (q @ k.transpose(-2, -1)) / math.sqrt(V)

        if mask is not None:
            attn_scores = attn_scores.masked_fill(~mask[:, None, :, :], float('-inf'))

        attn_weights = torch.softmax(attn_scores, dim=-1)
        context = attn_weights @ v
        context = context.transpose(1, 2).reshape(B, T, H * V)

        return self.out_proj(context)

class TransformerBlock(nn.Module):
    def __init__(self, d_model=72, heads=4, value_dim=72):
        super().__init__()
        self.attn = CustomSelfAttention(d_model, heads, value_dim)
        self.ff1 = nn.Linear(d_model, 4 * heads * value_dim)  # 4×288=1152
        self.ff2 = nn.Linear(4 * heads * value_dim, heads * value_dim)  # → 288
        self.activation = nn.GELU()

    def forward(self, x, mask=None):
        x = self.attn(x, mask=mask)
        x = self.activation(self.ff1(x))
        x = self.activation(self.ff2(x))
        return x


In [37]:
block = TransformerBlock(d_model=72, heads=4, value_dim=72)

In [38]:
def copy_mha_weights(pt_mha, tf_weights):
    def reshape_tf_to_linear(w_keras, b_keras):
        # w_keras: (in_dim, n_heads, head_dim) → (in_dim, out_dim)
        w = w_keras.numpy().reshape(w_keras.shape[0], -1)
        b = b_keras.numpy().reshape(-1)
        return torch.from_numpy(w.T), torch.from_numpy(b)

    # Query
    q_w, q_b = reshape_tf_to_linear(tf_weights[0], tf_weights[1])
    pt_mha.q_proj.weight.data.copy_(q_w)
    pt_mha.q_proj.bias.data.copy_(q_b)

    # Key
    k_w, k_b = reshape_tf_to_linear(tf_weights[2], tf_weights[3])
    pt_mha.k_proj.weight.data.copy_(k_w)
    pt_mha.k_proj.bias.data.copy_(k_b)

    # Value
    v_w, v_b = reshape_tf_to_linear(tf_weights[4], tf_weights[5])
    pt_mha.v_proj.weight.data.copy_(v_w)
    pt_mha.v_proj.bias.data.copy_(v_b)

    # Output projection: (n_heads, head_dim, out_dim)
    out_w_tf = tf_weights[6].numpy().reshape(-1, tf_weights[6].shape[-1])  # (288, 72)
    pt_mha.out_proj.weight.data.copy_(torch.from_numpy(out_w_tf.T))       # (72, 288)
    pt_mha.out_proj.bias.data.copy_(torch.from_numpy(tf_weights[7].numpy()))

def copy_transformer_block_weights(block_pt, block_tf):
    copy_mha_weights(block_pt.attn, block_tf.attention.weights)

    block_pt.ff1.weight.data.copy_(torch.from_numpy(block_tf.ff1.kernel.numpy().T))
    block_pt.ff1.bias.data.copy_(torch.from_numpy(block_tf.ff1.bias.numpy()))

    block_pt.ff2.weight.data.copy_(torch.from_numpy(block_tf.ff2.kernel.numpy().T))
    block_pt.ff2.bias.data.copy_(torch.from_numpy(block_tf.ff2.bias.numpy()))

copy_transformer_block_weights(block, transformer_block_0)

In [41]:
x_pt = block(x_pt, mask=mask_pt)

In [42]:
np.allclose(x_pt.detach().numpy(), x_tf.numpy(), atol=1e-5)

True

In [43]:
x_tf = transformer_block_1(x_tf, mask=mask_tf, training=False)
x_tf.shape

TensorShape([2, 101, 288])

In [44]:
block_1 = TransformerBlock(d_model=288, heads=4, value_dim=72)
copy_transformer_block_weights(block_1, transformer_block_1)

In [45]:
x_pt = block_1(x_pt, mask=mask_pt)

In [46]:
np.allclose(x_pt.detach().numpy(), x_tf.numpy(), atol=1e-5)

True

In [47]:
x_tf = x_tf[:, -1]
x_tf.shape

TensorShape([2, 288])

In [49]:
x_pt = x_pt[:, -1]
x_pt.shape

torch.Size([2, 288])

In [50]:
np.allclose(x_pt.detach().numpy(), x_tf.numpy(), atol=1e-5) 

True

### Step 5: Final Output

In [51]:
# tensorflow code
rescaled_tf = final_output(x_tf) * tf.squeeze(scale_tf[:, -1:], axis=-1)
rescaled_tf

<tf.Tensor: shape=(2, 1), dtype=float32, numpy=
array([[1.0238485],
       [1.0556059]], dtype=float32)>

In [52]:
final_output_pt = nn.Sequential(
    nn.Linear(288, 1),
    nn.ReLU()
)

tf_weights = final_output.kernel.numpy().T
tf_bias    = final_output.bias.numpy()

linear_final = final_output_pt[0]
linear_final.weight.data.copy_(torch.from_numpy(tf_weights))
linear_final.bias.data.copy_(torch.from_numpy(tf_bias))

rescaled_pt = final_output_pt(x_pt) * scale_pt[:, -1, 0:1]

np.allclose(rescaled_pt.detach().numpy(), rescaled_tf.numpy(), atol=1e-5)

True

In [53]:
np.allclose(rescaled_pt.detach().numpy(), full_result.numpy(), atol=1e-5)

True

In [54]:
np.allclose(scale_pt[:, -1, 0:1].detach().numpy(), full_scale.numpy(), atol=1e-5)

True

# Full Pytorch Implementation

In [57]:
NUM_TASKS = 10
YEAR = 0
MONTH = 1
DAY = 2
DOW = 3

In [56]:
class CustomScaling(nn.Module):
    """
    Used to normalize the historical input series before encoding.
    It ensures that time series values are on a comparable scale across samples
    """
    def __init__(self, method='robust'):
        super().__init__()
        if method == 'robust':
            self.scaler = RobustScaler()
        elif method == 'max':
            self.scaler = MaxScaler()
        else:
            raise ValueError(f"Unknown scaling method: {method}")

    def forward(self, history_channels, epsilon):
        return self.scaler(history_channels, epsilon)

In [79]:
class ForecastPFN(nn.Module):
    def __init__(self, epsilon=1e-4, scaler='robust'):
        super().__init__()
        self.epsilon = epsilon
        self.pos_year = PositionExpansion(10, 4)
        self.pos_month = PositionExpansion(12, 4)
        self.pos_day = PositionExpansion(31, 6)
        self.pos_dow = PositionExpansion(7, 4)
        self.scaler = CustomScaling(scaler)
        self.embed_size = sum(emb.channels for emb in (self.pos_year, self.pos_month, self.pos_day, self.pos_dow))
        self.expand_target_nopos = nn.Sequential(nn.Linear(1, 36),
                                                 nn.ReLU())
        self.expand_target_forpos = nn.Sequential(nn.Linear(1, 36),
                                                  nn.ReLU())
        self.target_marker = nn.Embedding(NUM_TASKS, self.embed_size)
        # Transformer Blocks
        self.d_model = self.embed_size * 2
        self.encoder0 = TransformerBlock(d_model=self.d_model)
        self.encoder1 = TransformerBlock(d_model=self.d_model * 4)
        self.final_output = nn.Sequential(
            nn.Linear(self.d_model * 4, 1),
            nn.ReLU()
        )
        
        
    @staticmethod
    def tc(ts, time_index):
        return ts[:, :, time_index]
    def forward(self, x):
        ts, history, target_ts, task = x['ts'], x['history'], x['target_ts'], x['task']
        
        # Build position encodings
        year = self.tc(ts, YEAR)
        delta_year = (year[:, -1:] - year).clamp(min=0, max=self.pos_year.periods)
        pos_embedding = torch.cat([
            self.pos_year(delta_year_pt),
            self.pos_month(self.tc(ts, MONTH)),
            self.pos_day(self.tc(ts, DAY)),
            self.pos_dow(self.tc(ts, DOW)),
            ], dim=-1)
        
        # Embed history
        history_channels = history.unsqueeze(-1)
        scale, scaled = self.scaler(history_channels, self.epsilon)
        embed_nopos = self.expand_target_nopos(scaled)
        embed_pos = self.expand_target_forpos(scaled) + pos_embedding
        embedded = torch.cat([embed_nopos, embed_pos], dim=-1)
        
        # Embed target
        target_year = (year[:, -1:] - self.tc(target_ts, YEAR)).clamp(min=0, max=self.pos_year.periods)
        target_pos_embed = torch.cat([
            self.pos_year(target_year),
            self.pos_month(self.tc(target_ts, MONTH)),
            self.pos_day(self.tc(target_ts, DAY)),
            self.pos_dow(self.tc(target_ts, DOW))
            ], dim=-1)
        target_pos_embed = target_pos_embed.squeeze(1)
        task_embed = self.target_marker(task)
        target = torch.cat([task_embed, task_embed + target_pos_embed], dim=-1)
        
        # Mask
        seq_mask = (year > 0)  # → [B, T], bool
        seq_mask = torch.cat([seq_mask, torch.ones_like(seq_mask[:, :1], dtype=torch.bool)], dim=1)  # [B, T+1]

        # Broadcast to [B, T+1, T+1]
        mask = seq_mask.unsqueeze(1) & seq_mask.unsqueeze(2)
        
        x = torch.cat([embedded, target.unsqueeze(1)], dim=1)
        x = self.encoder0(x, mask=mask)
        x = self.encoder1(x, mask=mask)
        scale = scale[:, -1, 0:1]
        result = self.final_output(x[:, -1, :]) * scale
        return {'result': result, 'scale': scale}



In [80]:
model_pt = ForecastPFN()
model_pt

ForecastPFN(
  (pos_year): PositionExpansion()
  (pos_month): PositionExpansion()
  (pos_day): PositionExpansion()
  (pos_dow): PositionExpansion()
  (scaler): CustomScaling(
    (scaler): RobustScaler()
  )
  (expand_target_nopos): Sequential(
    (0): Linear(in_features=1, out_features=36, bias=True)
    (1): ReLU()
  )
  (expand_target_forpos): Sequential(
    (0): Linear(in_features=1, out_features=36, bias=True)
    (1): ReLU()
  )
  (target_marker): Embedding(10, 36)
  (encoder0): TransformerBlock(
    (attn): CustomSelfAttention(
      (q_proj): Linear(in_features=72, out_features=288, bias=True)
      (k_proj): Linear(in_features=72, out_features=288, bias=True)
      (v_proj): Linear(in_features=72, out_features=288, bias=True)
      (out_proj): Linear(in_features=288, out_features=72, bias=True)
    )
    (ff1): Linear(in_features=72, out_features=1152, bias=True)
    (ff2): Linear(in_features=1152, out_features=288, bias=True)
    (activation): GELU(approximate='none')
  )
 

In [81]:
# Copy weights from TF model
linear_nopos = model_pt.expand_target_nopos[0]
linear_nopos.weight.data.copy_(torch.from_numpy(no_pos_enc.kernel.numpy().T))
linear_nopos.bias.data.copy_(torch.from_numpy(no_pos_enc.bias.numpy().T))

linear_forpos = model_pt.expand_target_forpos[0]
linear_forpos.weight.data.copy_(torch.from_numpy(for_pos_enc.kernel.numpy().T))
linear_forpos.bias.data.copy_(torch.from_numpy(for_pos_enc.bias.numpy()))

model_pt.target_marker.weight.data.copy_(torch.from_numpy(embedding.get_weights()[0]))

pt_block0 = model_pt.encoder0
pt_block1 = model_pt.encoder1
copy_transformer_block_weights(pt_block0, transformer_block_0)
copy_transformer_block_weights(pt_block1, transformer_block_1)

linear_final = model_pt.final_output[0]
linear_final.weight.data.copy_(torch.from_numpy(final_output.kernel.numpy().T))
linear_final.bias.data.copy_(torch.from_numpy(final_output.bias.numpy()))

model_pt(torch_input)


{'result': tensor([[1.0238],
         [1.0556]], grad_fn=<MulBackward0>),
 'scale': tensor([[1.2982],
         [1.3833]])}

In [82]:
output

{'scale': <tf.Tensor: shape=(2, 1), dtype=float32, numpy=
 array([[1.2981894],
        [1.3832798]], dtype=float32)>,
 'result': <tf.Tensor: shape=(2, 1), dtype=float32, numpy=
 array([[1.0238485],
        [1.0556059]], dtype=float32)>}

Verify pytorch version vs. original tensorflow version

In [83]:
pt_output = model_pt(torch_input)
# TensorFlow output
tf_result = output['result'].numpy()
tf_scale  = output['scale'].numpy()

# PyTorch output
pt_result = pt_output['result'].detach().numpy()
pt_scale  = pt_output['scale'].detach().numpy()

# Comparison
print("Result match:", np.allclose(tf_result, pt_result, atol=1e-5))
print("Scale match :", np.allclose(tf_scale, pt_scale, atol=1e-5))

Result match: True
Scale match : True


### Save Pytorch model

In [84]:
torch.save(model_pt.state_dict(), "forecastpfn_pytorch.pth")