In [None]:
# Load the saved model and info for inference
import torch
import os

In [None]:

# Prefer the GPU whenever PyTorch reports a usable CUDA device; fall back to CPU otherwise.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

In [None]:
# Define the SimpleTransformer model class (required for loading the model)
import torch.nn as nn
class SimpleTransformer(nn.Module):
    """Small transformer encoder with one linear head per predicted output.

    Every sequence position carries several integer token ids. Their
    embeddings are averaged into one vector per position, the resulting
    sequence goes through an ``nn.TransformerEncoder``, and the encoding of
    the last position is fed to ``num_outputs`` independent linear layers,
    each emitting ``vocab_size`` logits.

    Args:
        vocab_size: Number of rows in the embedding table and width of every head.
        embed_dim: Embedding size, also used as the transformer ``d_model``.
        num_heads: Attention heads per encoder layer.
        num_layers: Number of stacked encoder layers.
        num_outputs: Number of independent linear heads.
    """

    def __init__(self, vocab_size=50, embed_dim=32, num_heads=2, num_layers=1, num_outputs=6):
        super().__init__()
        # Attribute names and creation order are kept stable: they determine the
        # state_dict keys that saved checkpoints are matched against.
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(d_model=embed_dim, nhead=num_heads, batch_first=True),
            num_layers=num_layers,
        )
        heads = []
        for _ in range(num_outputs):
            heads.append(nn.Linear(embed_dim, vocab_size))
        self.fc = nn.ModuleList(heads)
        self.num_outputs = num_outputs

    def forward(self, x):
        """Compute the logits of every head for a batch of sequences.

        Args:
            x: Tensor of token ids with three dimensions,
                ``[batch_size, seq_len, num_features]``.

        Returns:
            A list with one ``[batch_size, vocab_size]`` tensor per head.
        """
        # The unpack doubles as a rank check: a non-3-D input fails right here.
        _batch, _steps, _feats = x.shape
        embedded = self.embedding(x)          # [batch_size, seq_len, num_features, embed_dim]
        pooled = embedded.mean(dim=2)         # [batch_size, seq_len, embed_dim]
        encoded = self.transformer(pooled)
        final_step = encoded[:, -1, :]        # keep only the last position's encoding
        outs = []
        for head in self.fc:
            outs.append(head(final_step))     # each is [batch_size, vocab_size]
        return outs

In [None]:
# Paths to saved files
save_dir = './saved_model'
model_path = os.path.join(save_dir, 'simple_transformer.pth')
info_path = os.path.join(save_dir, 'model_info.pth')

# Load model info (hyperparameters, etc.).
# map_location mirrors the weights load below: if the info file contains any
# tensor that was saved from a CUDA device, loading it on a CPU-only machine
# would otherwise fail.
# NOTE(security): torch.load unpickles the file; only load checkpoints that
# come from a trusted source.
model_info = torch.load(info_path, map_location=device)
print("Loaded model info:", model_info)

# Recreate the model with loaded hyperparameters
model = SimpleTransformer(
    vocab_size=model_info['vocab_size'],
    embed_dim=model_info['embed_dim'],
    num_heads=model_info['num_heads'],
    num_layers=model_info['num_layers'],
    num_outputs=model_info['num_outputs']
    # seq_len and data_columns are for data prep, not model init
    # You can use them for preparing inference input
    ).to(device)

# Load model weights onto the selected device, then switch to inference mode
# (eval() disables dropout inside the transformer encoder layers).
model.load_state_dict(torch.load(model_path, map_location=device))
model.eval()
print("Model loaded and ready for inference.")