✅ Summary: What This Code Does

* Loads prompt→schema pairs

* Builds vocab, prepares model

* Trains a small transformer-style attention classifier to predict schema columns from a prompt

* Generates realistic fake data for the predicted columns

* Saves the trained weights for later reuse

In [1]:
import json
from pprint import pprint

import torch
import torch.nn as nn

📌 What it does:

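* `json`: loads the training data from a file.

* `pprint`: pretty-prints the generated fake data.

* `torch`: core PyTorch functionality (tensors, optimizers, saving and loading).

* `nn`: neural-network layers such as `Embedding`, `MultiheadAttention`, `LayerNorm` and `Linear`, plus loss functions.

✅ Result: everything needed to define the network and prepare the data is now imported.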

## Load Data

In [2]:
with open("training_data.json") as f:
    training_data = json.load(f)

📌 Loads your training prompts and expected schema columns.

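✅ Example input (`training_data.json`). Each prompt is stored as a list of tokens, because the next cell prepends `[CLS]` by list concatenation:

```json
[
  [["generate", "a", "bank", "dataset"], ["account_number", "name", "balance", "currency"]],
  [["create", "student", "data"], ["student_id", "name", "grade", "email"]]
]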
```
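The next cell evaluates `["[CLS]"] + prompt`, which raises a `TypeError` if `prompt` is a plain string. If a data file stores prompts as strings instead, a small normalizing loader can tokenize them the same way the full inference run does (`lower().split()`). This is a sketch; `load_training_pairs` is a hypothetical helper, not part of the notebook:

```python
import json


def load_training_pairs(raw_json):
    """Hypothetical loader: returns (tokens, columns) pairs.

    Prompts may be pre-tokenized lists or plain strings; strings are
    lower-cased and split on whitespace, like the inference cell does.
    """
    pairs = []
    for prompt, columns in json.loads(raw_json):
        tokens = prompt.lower().split() if isinstance(prompt, str) else list(prompt)
        if not all(isinstance(col, str) for col in columns):
            raise ValueError(f"column names must be strings: {columns!r}")
        pairs.append((tokens, list(columns)))
    return pairs


raw = """[
  [["generate", "a", "bank", "dataset"], ["account_number", "name", "balance", "currency"]],
  ["Create student data", ["student_id", "name", "grade", "email"]]
]"""
pairs = load_training_pairs(raw)
print(pairs[1][0])  # ['create', 'student', 'data']
```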

## Add [CLS] Token to Prompts

In [3]:
training_data = [
    (["[CLS]"] + prompt, columns)
    for prompt, columns in training_data
]

📌 Prepends a `[CLS]` token to every prompt.
The model reads its prediction from this position (index 0), so `[CLS]` serves as a summary of the whole prompt.

✅ Example:

```python
["generate", "a", "bank", "dataset"] → ["[CLS]", "generate", "a", "bank", "dataset"]
```

## Build Vocabulary

In [None]:
all_tokens = set()
all_columns = set()
for prompt, columns in training_data:
    all_tokens.update(prompt)
    all_tokens.update(columns)
    all_columns.update(columns)
vocab = sorted(all_tokens)
columns_vocab = sorted(all_columns)

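📌 Creates two sorted vocabularies:

* `vocab`: every unique prompt token, plus every column name

* `columns_vocab`: every unique schema column (the model's output labels)

Sorting makes the index order reproducible: iterating a set of strings directly can give a different order in each Python process because of hash randomization.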

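✅ Example (for the two-row JSON example above):

```python
vocab = ["[CLS]", "a", "account_number", "balance", "bank", "create", "currency", "data",
         "dataset", "email", "generate", "grade", "name", "student", "student_id"]
columns_vocab = ["account_number", "balance", "currency", "email", "grade", "name", "student_id"]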
```

## Map Words to Indices

In [None]:
token_to_idx = {token: idx for idx, token in enumerate(vocab)}
idx_to_token = {idx: token for token, idx in token_to_idx.items()}
column_to_idx = {col: i for i, col in enumerate(columns_vocab)}
idx_to_column = {i: col for col, i in column_to_idx.items()}

vocab_size = len(vocab)
num_columns = len(columns_vocab)

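📌 Creates lookup tables for encoding and decoding: `token_to_idx`/`idx_to_token` for prompt tokens, and `column_to_idx`/`idx_to_column` for schema columns.

✅ Example (vocabulary built from the two-row JSON example):

```python
token_to_idx["bank"] → 4
column_to_idx["email"] → 3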
```
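The index values can be checked by rebuilding the toy vocabulary. This sketch uses the two JSON example rows (already tokenized, with `[CLS]` prepended) and set comprehensions that are equivalent to the loops in the notebook cells; it then decodes the IDs back with the reverse maps, which is what `predict_columns` relies on:

```python
# Toy data: the two JSON example rows, tokenized, with [CLS] already prepended.
training_data = [
    (["[CLS]", "generate", "a", "bank", "dataset"], ["account_number", "name", "balance", "currency"]),
    (["[CLS]", "create", "student", "data"], ["student_id", "name", "grade", "email"]),
]

# Sorted vocabularies (prompt tokens + column names, and column names alone).
vocab = sorted({t for prompt, cols in training_data for t in prompt + cols})
columns_vocab = sorted({c for _, cols in training_data for c in cols})

# Forward and reverse lookup tables.
token_to_idx = {token: idx for idx, token in enumerate(vocab)}
idx_to_token = {idx: token for token, idx in token_to_idx.items()}
column_to_idx = {col: i for i, col in enumerate(columns_vocab)}
idx_to_column = {i: col for col, i in column_to_idx.items()}

ids = [token_to_idx[t] for t in ["[CLS]", "generate", "bank"]]
print(ids)                                        # [0, 10, 4]
print([idx_to_token[i] for i in ids])             # ['[CLS]', 'generate', 'bank']
print(column_to_idx["email"], idx_to_column[3])   # 3 email
```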

## Encode Functions

In [5]:
def encode(tokens):
    return [token_to_idx[t] for t in tokens]

def encode_columns(cols):
    vec = torch.zeros(num_columns)
    for col in cols:
        if col in column_to_idx:
            vec[column_to_idx[col]] = 1.0
    return vec


📌 Converts:

* Prompt tokens into integer IDs

* Schema column labels into multi-label vectors

✅ Example (vocabulary built from the two-row JSON example):

```python
encode(["[CLS]", "generate", "bank"]) → [0, 10, 4]
encode_columns(["account_number", "balance"]) →
    tensor([1., 1., 0., 0., 0., 0., 0.])  # 1 at each listed column's index
```
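`encode` looks every token up directly, so a word that never appeared in training (for example "weather") raises a `KeyError` inside `predict_columns`. A minimal tolerant variant, `encode_known` (a hypothetical helper), simply drops unknown words; `[CLS]` is always in the vocabulary, so the encoded sequence is never empty:

```python
# Excerpt of the toy vocabulary from the two-row example.
token_to_idx = {"[CLS]": 0, "a": 1, "bank": 4, "dataset": 8, "generate": 10}


def encode_known(tokens, token_to_idx):
    """Hypothetical tolerant encoder: skips tokens missing from the vocabulary."""
    return [token_to_idx[t] for t in tokens if t in token_to_idx]


print(encode_known(["[CLS]", "generate", "a", "weather", "dataset"], token_to_idx))  # [0, 10, 1, 8]
```

Dropping words is the simplest option; reserving an `[UNK]` token in the vocabulary is a common alternative, but it only helps if the model is retrained with that token present.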

## Define Transformer-Based Classifier

In [6]:
class SchemaPredictor(nn.Module):
    def __init__(self, vocab_size, num_labels, embed_dim=64, num_heads=4):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.ln = nn.LayerNorm(embed_dim)
        self.output_proj = nn.Linear(embed_dim, num_labels)

    def forward(self, x):
        x = self.embedding(x)
        attn_output, _ = self.attn(x, x, x)
        x = self.ln(x + attn_output)
        cls_token = x[:, 0, :]  # use [CLS] token
        return self.output_proj(cls_token)

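📌 This model is a single transformer-style attention block (without the feed-forward sublayer of a full transformer layer):

* Embeds each token ID into a 64-dimensional vector (`embed_dim=64`)

* Applies 4-head self-attention, followed by a residual connection and `LayerNorm`

* Reads the output at the `[CLS]` position (index 0) and projects it to one logit per schema column

* Has no positional encoding, so it ignores word order within the prompt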

✅ Example:
Prompt → embeddings → attention → [CLS] vector → output prediction
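A quick shape check for the flow above, assuming toy sizes (15 tokens, 7 columns, as in the two-row example) and re-declaring the class so the snippet runs on its own. The weights are random, so only shapes and structure are meaningful here. It also shows a consequence of having no positional encoding: shuffling the words after `[CLS]` leaves the `[CLS]` output unchanged.

```python
import torch
import torch.nn as nn


class SchemaPredictor(nn.Module):
    # Same architecture as the notebook cell, with shape comments.
    def __init__(self, vocab_size, num_labels, embed_dim=64, num_heads=4):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
        self.ln = nn.LayerNorm(embed_dim)
        self.output_proj = nn.Linear(embed_dim, num_labels)

    def forward(self, x):
        x = self.embedding(x)                # (batch, seq_len, embed_dim)
        attn_output, _ = self.attn(x, x, x)  # self-attention, same shape
        x = self.ln(x + attn_output)         # residual connection + LayerNorm
        return self.output_proj(x[:, 0, :])  # [CLS] position -> (batch, num_labels)


torch.manual_seed(0)
model = SchemaPredictor(vocab_size=15, num_labels=7)

x = torch.tensor([[0, 10, 1, 4, 8]])  # [CLS] generate a bank dataset (toy IDs)
logits = model(x)
print(logits.shape)  # torch.Size([1, 7])

# Same words after [CLS], different order: the [CLS] output does not change.
shuffled = torch.tensor([[0, 8, 4, 1, 10]])
print(torch.allclose(model(x), model(shuffled), atol=1e-5))  # True
```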

## Training Loop

In [7]:
def train_model(model, training_data, epochs=300, lr=0.01):
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.BCEWithLogitsLoss()  # independent sigmoid + binary cross-entropy per column
    for epoch in range(epochs):
        total_loss = 0.0
        for prompt, cols in training_data:  # one example per update (batch size 1)
            x = torch.tensor([encode(prompt)], dtype=torch.long)  # (1, seq_len)
            y = encode_columns(cols).unsqueeze(0)                  # (1, num_columns) multi-hot target

            logits = model(x)
            loss = loss_fn(logits, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        if epoch % 50 == 0:
            print(f"Epoch {epoch}: Loss = {total_loss:.4f}")  # summed over all examples
    return model

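📌 This function:

* Trains the model with **BCEWithLogitsLoss**, which treats each column as an independent yes/no decision (multi-label classification)

* Compares the predicted logits with the multi-hot target vector produced by `encode_columns`

* Uses the Adam optimizer (`lr=0.01`) and updates the weights after every single example (batch size 1)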

✅ Example:

```python
Prompt: ["[CLS]", "generate", "a", "bank", "dataset"]
Target: ["account_number", "name", "balance", "currency"]
```

The model learns to push the sigmoid probability toward 1.0 for these columns and toward 0.0 for every other column.
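To see what "multi-label" means for the loss, `BCEWithLogitsLoss` can be compared with binary cross-entropy written out by hand. The logits and target below are made-up values for three columns; with the default `reduction="mean"`, the loss is the per-column BCE averaged over all entries:

```python
import torch
import torch.nn as nn

logits = torch.tensor([[2.0, -1.0, 0.5]])  # raw model outputs for 3 columns
target = torch.tensor([[1.0, 0.0, 1.0]])   # multi-hot: columns 0 and 2 belong to the schema

loss = nn.BCEWithLogitsLoss()(logits, target)

# Same value by hand: sigmoid per column, binary cross-entropy per column, then the mean.
p = torch.sigmoid(logits)
manual = -(target * torch.log(p) + (1 - target) * torch.log(1 - p)).mean()
print(torch.allclose(loss, manual))  # True
```

Working on logits (rather than applying `sigmoid` in the model) lets PyTorch compute the loss in a numerically stable way.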

## Inference: Predict Columns

In [8]:
def predict_columns(model, prompt_tokens, threshold=0.5):
    with torch.no_grad():  # inference only: skip gradient tracking
        x = torch.tensor([encode(["[CLS]"] + prompt_tokens)], dtype=torch.long)
        logits = model(x)  # (1, num_columns)
    probs = torch.sigmoid(logits).squeeze(0)  # (num_columns,); stays 1-D even if num_columns == 1
    pred_indices = (probs > threshold).nonzero(as_tuple=True)[0].tolist()
    return [idx_to_column[i] for i in pred_indices]

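📌 Takes a tokenized prompt (for example `["generate", "a", "bank", "dataset"]`, without `[CLS]`), prepends `[CLS]`, runs it through the model, and returns every column whose sigmoid probability exceeds `threshold`. Columns come back in `columns_vocab` (alphabetical) order.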

✅ Example:

```python
["bank", "dataset"] → ["account_number", "balance", "name"]
```
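The decoding step can be mirrored in plain Python to see what `threshold` does. `decode_predictions` is a hypothetical helper and the logits are made up for illustration; lowering the threshold, as the full inference run does with 0.4, admits more columns:

```python
import math


def decode_predictions(logits, idx_to_column, threshold=0.5):
    """Hypothetical pure-Python mirror of predict_columns' decoding step."""
    probs = [1 / (1 + math.exp(-z)) for z in logits]  # sigmoid per column
    return [idx_to_column[i] for i, p in enumerate(probs) if p > threshold]


idx_to_column = {0: "account_number", 1: "balance", 2: "currency", 3: "email"}
logits = [2.0, -0.3, 0.1, -3.0]  # sigmoid ≈ 0.88, 0.43, 0.52, 0.05

print(decode_predictions(logits, idx_to_column))                 # ['account_number', 'currency']
print(decode_predictions(logits, idx_to_column, threshold=0.4))  # ['account_number', 'balance', 'currency']
```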

## Generate Fake Data

In [9]:
from fake_data_utils import FAKE_VALUE_FUNCTIONS, generate_fake_data

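📌 `fake_data_utils` is a local helper module (its source is not shown here). `FAKE_VALUE_FUNCTIONS` is a dictionary of lambda functions, one per column name (e.g. `fake.name()`, `fake.iban()`), and `generate_fake_data(columns, n)` calls them to build `n` rows, returning a dict with `columns` and `rows` keys.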

✅ Example:

```python
generate_fake_data(["account_number", "name"], n=2)
→ {'columns': ['account_number', 'name'],
   'rows': [['DE123...', 'John Doe'], ['GB432...', 'Alice Smith']]}
```
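Since the module's source is not included, here is a hypothetical stand-in that returns the same `{'columns': ..., 'rows': ...}` shape printed in the full inference run. It uses only the standard library (`uuid`, `random`) and a handful of illustrative column names; the real module may generate values differently:

```python
import random
import uuid

# Hypothetical column -> value-generator table (not the real fake_data_utils contents).
FAKE_VALUE_FUNCTIONS = {
    "member_id": lambda: str(uuid.uuid4()),
    "membership_status": lambda: random.choice(["active", "inactive"]),
    "name": lambda: random.choice(["Alice Smith", "John Doe", "Maria Garcia"]),
    "balance": lambda: f"{random.uniform(0, 10_000):.2f}",
    "currency": lambda: random.choice(["USD", "EUR", "GBP"]),
}


def generate_fake_data(columns, n=5):
    """Build n rows, calling each column's generator once per cell."""
    rows = [[FAKE_VALUE_FUNCTIONS[col]() for col in columns] for _ in range(n)]
    return {"columns": list(columns), "rows": rows}


data = generate_fake_data(["member_id", "membership_status"], n=3)
print(len(data["rows"]))  # 3
```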

## Train the Model

In [10]:
model = SchemaPredictor(vocab_size, num_columns)
model = train_model(model, training_data)

Epoch 0: Loss = 43.7987
Epoch 50: Loss = 0.0124
Epoch 100: Loss = 0.0003
Epoch 150: Loss = 0.0000
Epoch 200: Loss = 0.0000
Epoch 250: Loss = 0.0000


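📌 Instantiates and trains the model on the data. The printed loss is summed over all training examples in the epoch; reaching 0.0000 shows the model fits the training pairs, but with no held-out set it says nothing about how well it handles unseen prompts.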

## Save Trained Model

In [11]:
torch.save(model.state_dict(), "schema_model.pt")

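📌 Saves the trained weights so they can be loaded later without retraining. The state dict holds only the weights, so reusing them requires rebuilding the same `vocab` and `columns_vocab` (same order) and creating `SchemaPredictor(vocab_size, num_columns)` with matching sizes before calling `load_state_dict`.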
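One way to make reuse independent of re-reading the training file is to store the vocabularies next to the weights. A minimal sketch, assuming a hypothetical `schema_vocab.json` file and a small stand-in model (`build_model`, also hypothetical) in place of `SchemaPredictor`:

```python
import json
import os
import tempfile

import torch
import torch.nn as nn


def build_model(vocab_size, num_labels):
    # Hypothetical stand-in; in the notebook this would be SchemaPredictor(vocab_size, num_labels).
    return nn.Sequential(nn.Embedding(vocab_size, 8), nn.Linear(8, num_labels))


vocab = ["[CLS]", "a", "bank", "dataset", "generate"]
columns_vocab = ["account_number", "balance", "name"]

torch.manual_seed(0)
model = build_model(len(vocab), len(columns_vocab))

out_dir = tempfile.mkdtemp()
weights_path = os.path.join(out_dir, "schema_model.pt")
vocab_path = os.path.join(out_dir, "schema_vocab.json")  # hypothetical file name

# Save weights and vocabularies side by side.
torch.save(model.state_dict(), weights_path)
with open(vocab_path, "w") as f:
    json.dump({"vocab": vocab, "columns_vocab": columns_vocab}, f)

# Later (e.g. a new session): restore the vocab, rebuild the model, then load the weights.
with open(vocab_path) as f:
    saved = json.load(f)
restored = build_model(len(saved["vocab"]), len(saved["columns_vocab"]))
restored.load_state_dict(torch.load(weights_path))

x = torch.tensor([[0, 4, 2]])  # [CLS] generate bank
print(torch.allclose(model(x), restored(x)))  # True
```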

## Full Inference Run

In [13]:
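prompt = "generate a sport dataset"
tokens = prompt.lower().split()  # lower-case, split on whitespace
columns = predict_columns(model, tokens, threshold=0.4)  # lower threshold than the 0.5 default
# Keep columns that have a fake-value generator and are not simply words from the prompt
columns = [col for col in columns if col in FAKE_VALUE_FUNCTIONS and col not in tokens]
columns = list(dict.fromkeys(columns))  # drop duplicates, keep order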

print("🔍 Predicted schema columns:", columns)
pprint(generate_fake_data(columns, n=5))

🔍 Predicted schema columns: ['member_id', 'membership_status']
{'columns': ['member_id', 'membership_status'],
 'rows': [['03fff75e-4475-456a-8bf9-c9cc5065caf8', 'active'],
          ['82e314da-f918-4392-9c26-c22bf68af4f1', 'active'],
          ['8030b490-e90f-46c7-9490-21541179088c', 'inactive'],
          ['95ef72bf-f951-4e27-829c-538062874f70', 'inactive'],
          ['3d110616-98b9-4571-b0af-7d23d9ee7977', 'inactive']]}


✅ Full example (illustrative, for a bank prompt):

```
Prompt: "generate a bank dataset"
→ Predicted: ['account_number', 'balance', 'currency', 'name']
→ Fake Data:
[
  ['GB89...', '1234.56', 'USD', 'Alice'],
  ['GB90...', '7823.19', 'EUR', 'Bob'],
  ...
]
```