# HVAE training
This notebook trains a Hierarchical Variational Autoencoder (HVAE) for abstractive text generation on question-answer pairs from the Physics Stack Exchange dataset.

We have tried to share as many outputs as possible. However, because we pretrained the model over many re-attempts and saved epochs to wandb whenever possible, it is difficult to provide the outputs of many cells.

Additional references to some of the training runs we did on Kaggle:

EBMs:
https://www.kaggle.com/code/syedyasir/trans-time-replybuffer-trunc-predictor

https://www.kaggle.com/code/macbro27/trans-time-reply-buffer-no-trunc-predictor?scriptVersionId=275546445

HVAE:
https://www.kaggle.com/code/macbro27/hvae-distill-attempt21


## Setup and Data Preprocessing

In [None]:

import numpy as np
import pandas as pd
# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All"
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [None]:
import wandb
wandb.login(key="hidden")

[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33myasir-alam14[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import sys
import os
import pandas as pd
import nltk
from transformers import AutoTokenizer
import numpy as np
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
import torch.nn.utils as utils
from tqdm import tqdm
from transformers import GPT2Model, GPT2Config
from sklearn.model_selection import train_test_split
from bs4 import BeautifulSoup
from transformers import DistilBertModel, DistilBertConfig
from transformers import DistilBertTokenizer
import matplotlib.pyplot as plt
import seaborn as sns



In [None]:
from datasets import load_dataset

# Load only the Physics Stack Exchange data
dataset = load_dataset(
    "HuggingFaceH4/stack-exchange-preferences",
    data_dir="data/physics.stackexchange.com",
    split="train"
)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.



In [None]:
def get_best_qa_pair(row):
    answers = row['answers']
    best_answer = None

    # Priority 1: Find the answer marked as "Selected" (Accepted)
    for ans in answers:
        if ans['selected']:
            best_answer = ans['text']
            break

    # Priority 2: If none selected, pick the highest score
    if not best_answer and answers:
        # Sort by score descending and take the top one
        best_answer = sorted(answers, key=lambda x: x['pm_score'], reverse=True)[0]['text']

    return {
        'question': row['question'],
        'answer': best_answer
    }
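The answer-selection rule in get_best_qa_pair can also be written more compactly with next() and max(). The sketch below is intended to be behaviorally equivalent to the original answer choice, including tie-breaking, since max() returns the first maximal element just as a stable descending sort places it first. The helper name pick_best_answer and the toy answers are hypothetical; the field names (text, selected, pm_score) are the ones the code above reads.

```python
def pick_best_answer(answers):
    """Hypothetical helper: accepted answer's text, else highest pm_score text, else None."""
    # Priority 1: first answer marked as selected (accepted), if its text is non-empty
    accepted = next((a["text"] for a in answers if a["selected"]), None)
    if accepted:
        return accepted
    if not answers:
        return None
    # Priority 2: highest pm_score; max() keeps the first of any tied answers
    return max(answers, key=lambda a: a["pm_score"])["text"]


toy_answers = [
    {"text": "low", "pm_score": 1, "selected": False},
    {"text": "high", "pm_score": 5, "selected": False},
]
print(pick_best_answer(toy_answers))  # high
```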


In [None]:
qa_list = [get_best_qa_pair(row) for row in dataset]

df = pd.DataFrame(qa_list)

# Optional: Remove entries where no answer was found
df = df.dropna(subset=['answer'])

In [None]:
def clean_html_smart(text):
    if not text:
        return ""

    soup = BeautifulSoup(text, "html.parser")

    # 1. Handle explicit line breaks
    for br in soup.find_all("br"):
        br.replace_with("\n")

    # 2. Add newlines ONLY after block-level tags
    # We define what constitutes a "block" that needs separation
    block_tags = ['p', 'div', 'h1', 'h2', 'h3', 'h4', 'h5', 'h6', 'li', 'blockquote']

    for tag in soup.find_all(block_tags):
        # Insert a newline character immediately after the block tag
        # You can use '\n\n' if you prefer wider spacing between paragraphs
        tag.insert_after('\n')

    # 3. Extract text with NO separator
    # This allows inline tags (em, a, strong) to flow naturally into the text
    clean_text = soup.get_text()

    # 4. Final cleanup: strip leading/trailing whitespace only; internal newlines are kept
    return clean_text.strip()


In [None]:
df['question_clean'] = df['question'].apply(clean_html_smart)
df['answer_clean'] = df['answer'].apply(clean_html_smart)


In [None]:
import nltk

# Download the required resources
nltk.download('punkt')
nltk.download('punkt_tab')  # <--- REQUIRED for newer NLTK versions

# Define the counting function
def count_sentences(text):
    if not text:
        return 0
    return len(nltk.sent_tokenize(text))

# Apply it to your cleaned column
df['num_sentences'] = df['answer_clean'].apply(count_sentences)

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.
[nltk_data] Downloading package punkt_tab to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt_tab.zip.


In [None]:
df_filtered = df[(df['num_sentences'] >= 5) & (df['num_sentences'] <= 11)]

# Print the length of the filtered DataFrame
print(len(df_filtered))

35618


In [None]:
df_filtered.head()

| | question | answer | question_clean | answer_clean | num_sentences |
|---|---|---|---|---|---|
| 1 | `<p>How would you explain string theory to non-...` | `<p>I've noticed that none of these answers act...` | How would you explain string theory to non-phy... | I've noticed that none of these answers actual... | 6 |
| 4 | `<p>Hamilton's principle states that a dynamic ...` | `<p>The notes from week 1 of <a href="http://ma...` | Hamilton's principle states that a dynamic sys... | The notes from week 1 of John Baez's course in... | 6 |
| 7 | `<p>Why does the sky change color? Why is the s...` | `<p>The keywords here are <a href="http://en.wi...` | Why does the sky change color? Why is the sky ... | The keywords here are Rayleigh scattering. Se... | 5 |
| 9 | `<p>Where is the <a href="https://en.wikipedia....` | `<p><a href="http://en.wikipedia.org/wiki/Monte...` | Where is the Monte Carlo method used in physics? | Monte Carlo is a particular numerical techniqu... | 6 |
| 10 | `<p>I think it's clear enough that if you turn ...` | `<p>The simple answer is that the angle between...` | I think it's clear enough that if you turn you... | The simple answer is that the angle between th... | 5 |


In [None]:
df_filtered = df_filtered.drop(columns=['question', 'answer'])

In [None]:
import nltk
import pandas as pd

# 1. Download NLTK sentence tokenizer data if not already present
try:
    # Check if the data exists
    nltk.data.find('tokenizers/punkt')
except LookupError:
    # If LookupError is raised, it means the data is missing, so we download it
    nltk.download('punkt')

# --- Configuration ---
ABSTRACT_COLUMN = 'answer_clean'
PREPROCESSED_COLUMN = 'answer_clean_preprocessed'

# Define the delimiter (We only need EOS for splitting later)
EOS_TOKEN = '<EOS>'
# ---------------------

# --- Corrected Preprocessing ---

def preprocess_for_hvae(text):
    """
    Segments text into sentences and joins them with a delimiter <EOS>.
    Does NOT add BOS or EOT tokens (handled by Tokenizer later).
    """
    if pd.isna(text) or not text:
        return ""

    # 1. Use NLTK for sentence segmentation
    sentences = nltk.sent_tokenize(text)

    if not sentences:
        return ""

    # 2. Join with <EOS> so process_dual_stream can split them later
    # Result: "Sentence 1. <EOS> Sentence 2. <EOS> Sentence 3."
    return f" {EOS_TOKEN} ".join([s.strip() for s in sentences])

# Apply the preprocessing
# Make sure df_filtered is defined before running this
df_filtered[PREPROCESSED_COLUMN] = df_filtered[ABSTRACT_COLUMN].apply(preprocess_for_hvae)

# Verify one example
print(df_filtered[PREPROCESSED_COLUMN].iloc[0])

I've noticed that none of these answers actually answer the question. <EOS> The simplest explanation of string theory I can think of:

Particles we currently consider "point particles" (electrons, quarks, photons, etc.) <EOS> are actually tiny pieces of string with each a characteristic vibration. <EOS> They interact in a sort of harmony that results in/manifests as the physical laws we observe. <EOS> If anyone with more knowledge in the field can correct me, I ask for improvements. <EOS> This is just how I personally explain it to people who ask, and I'd hate to give out false information.
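Downstream, the `<EOS>` markers are used to split each answer back into sentences (the comment above mentions process_dual_stream, which is not shown in this section). A minimal sketch of that step, assuming whitespace tokenization and a toy vocabulary in place of the real subword tokenizers: it builds the (MAX_SENTENCES, MAX_WORDS) layout that HT_HVAE_InferenceNetwork.forward expects, with 1 marking real tokens and 0 marking padding. The names to_sentence_grid, pad_id, and unk_id are hypothetical.

```python
EOS_TOKEN = "<EOS>"


def to_sentence_grid(text, vocab, max_sentences, max_words, pad_id=0, unk_id=1):
    """Hypothetical helper: split an <EOS>-joined answer into an id grid plus a mask.

    Whitespace tokenization and the toy `vocab` dict stand in for the real
    tokenizers (an assumption for illustration only).
    """
    sentences = [s.strip() for s in text.split(EOS_TOKEN) if s.strip()][:max_sentences]
    ids = [[pad_id] * max_words for _ in range(max_sentences)]
    mask = [[0] * max_words for _ in range(max_sentences)]
    for i, sentence in enumerate(sentences):
        # Truncate each sentence to max_words tokens
        for j, word in enumerate(sentence.split()[:max_words]):
            ids[i][j] = vocab.get(word, unk_id)
            mask[i][j] = 1  # 1 = real token, 0 = padding (same convention as word_mask)
    return ids, mask


vocab = {"Light": 2, "bends.": 3, "Waves": 4, "interfere.": 5}
ids, mask = to_sentence_grid("Light bends. <EOS> Waves interfere.", vocab, max_sentences=3, max_words=4)
print(ids)   # [[2, 3, 0, 0], [4, 5, 0, 0], [0, 0, 0, 0]]
print(mask)  # [[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 0, 0]]
```

A fully padded sentence has mask 0 at position 0, which is exactly what the encoder tests (`word_mask[:, :, 0] == 0`) to detect padded sentences.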


In [None]:
def count_max_words(text):
    if not text:
        return 0
    # 1. Split by <EOS> to get sentences
    sentences = text.split('<EOS>')

    # 2. Count words in each valid sentence
    # split() without arguments splits by any whitespace
    counts = [len(s.strip().split()) for s in sentences if s.strip()]

    # 3. Return the maximum count (or 0 if list is empty)
    return max(counts) if counts else 0

In [None]:
df_filtered['max_sent_word_count'] = df_filtered['answer_clean_preprocessed'].apply(count_max_words)

In [None]:
df_filtered['max_sent_word_count'].describe()

| | max_sent_word_count |
|---|---|
| count | 35618.0 |
| mean | 44.494862 |
| std | 24.429469 |
| min | 6.0 |
| 25% | 30.0 |
| 50% | 39.0 |
| 75% | 51.0 |
| max | 821.0 |


In [None]:
max_row = df_filtered[df_filtered['max_sent_word_count'] == 30]


# Print an answer whose longest sentence is exactly 30 words (the 25th percentile) to inspect typical sentence lengths
print("\n--- Full Text of Longest 'Sentence' ---")
# Access the preprocessed text of the first matching row
full_text = max_row['answer_clean_preprocessed'].iloc[0]
print(full_text)


--- Full Text of Longest 'Sentence' ---
Your ear is an effective Fourier transformer. <EOS> An ear contains many small hair cells. <EOS> The hair cells differ in length, tension, and thickness, and therefore respond to different frequencies. <EOS> Different hair cells are mechanically linked to ion channels in different neurons, so different neurons in the brain get activated depending on the Fourier transform of the sound you're hearing. <EOS> A piano is a Fourier analyzer for a similar reason. <EOS> A prism or diffraction grating would be a Fourier analyzer for light. <EOS> It spreads out light of different frequencies, allowing us to analyze how much of each frequency is present in a given source.


In [None]:
df_filtered = df_filtered[(df_filtered['max_sent_word_count'] <= 40)]

In [None]:
df_filtered['max_sent_word_count'].describe()

| | max_sent_word_count |
|---|---|
| count | 19074.0 |
| mean | 30.527629 |
| std | 6.293533 |
| min | 6.0 |
| 25% | 26.0 |
| 50% | 31.0 |
| 75% | 36.0 |
| max | 40.0 |


In [None]:
train_df, valid_df = train_test_split(df_filtered, test_size=0.2, random_state=42)

print(f"Train shape: {train_df.shape}")
print(f"Valid shape: {valid_df.shape}")

Train shape: (15259, 5)
Valid shape: (3815, 5)


## HVAE Model Architecture

The cell below defines MLPNetwork, a multi-layer perceptron with GELU activations and layer normalization that maps a hidden state to the mean and variance of a diagonal Gaussian latent, and HT_HVAE_InferenceNetwork, the hierarchical Transformer encoder q(z|x) of the HVAE. Together they form the encoder of the model.
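Both posterior heads return a mean mu and a variance sigma2 (passed through softplus), not a log-variance, so sampling uses the square root of sigma2. The plain-Python sketch below shows the standard reparameterization step and the closed-form KL to N(0, I), and checks the near-zero initialization claim: with output weights near zero and no bias, mu starts near 0 and sigma2 near softplus(0) = ln 2, which costs only about 0.03 nats of KL per latent dimension. The function names and latent_dim = 32 are illustrative assumptions; the notebook's own sampling and loss code is not part of this section.

```python
import math
import random


def softplus(x):
    # Numerically stable log(1 + exp(x))
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def reparameterize(mu, sigma2, rng):
    # z = mu + sigma * eps with eps ~ N(0, 1); sigma2 is a variance, so take its square root
    return [m + math.sqrt(s2) * rng.gauss(0.0, 1.0) for m, s2 in zip(mu, sigma2)]


def kl_to_standard_normal(mu, sigma2):
    # KL(N(mu, diag(sigma2)) || N(0, I)) = 0.5 * sum(sigma2 + mu^2 - 1 - log(sigma2))
    return 0.5 * sum(s2 + m * m - 1.0 - math.log(s2) for m, s2 in zip(mu, sigma2))


# Near-zero init: raw outputs are ~0, so mu ~ 0 and sigma2 ~ softplus(0) + 1e-6 (about 0.693)
latent_dim = 32  # hypothetical size
mu0 = [0.0] * latent_dim
sigma2_0 = [softplus(0.0) + 1e-6] * latent_dim

print(round(kl_to_standard_normal(mu0, sigma2_0) / latent_dim, 4))  # 0.0298 nats per dimension
z = reparameterize(mu0, sigma2_0, random.Random(0))
print(len(z))  # 32
```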

In [None]:
class MLPNetwork(nn.Module):
    """
    Optimized MLP for HVAE Encoder q(z|x).
    1. Removes bottlenecks (Maintains width).
    2. Uses GELU (Matches BERT/Transformer activations).
    3. Implements Near-Zero Initialization to prevent early KL shock.
    """
    def __init__(self, input_dim, latent_dim):
        super().__init__()

        # 1. Maintain Width: Do not compress to //2 or //4 immediately.
        # We want deep non-linearities, not compression.
        hidden_dim = input_dim

        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.ln1 = nn.LayerNorm(hidden_dim)

        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.ln2 = nn.LayerNorm(hidden_dim)

        # Output layer
        self.fc_out = nn.Linear(hidden_dim, 2 * latent_dim, bias=False)

        # Match activation to DistilBERT/GPT-2 (GELU is smoother than ReLU)
        self.activation = nn.GELU()

        # --- CRITICAL: Near-Zero Initialization ---
        # This ensures that at step 0, mu is ~0 and sigma2 is ~softplus(0) = ln 2 (about 0.69),
        # so the posterior q(z|x) starts close to the N(0, I) prior. A small initial KL loss
        # keeps the optimizer from "killing" the latent variable immediately (Collapse).
        torch.nn.init.normal_(self.fc_out.weight, mean=0.0, std=0.001)


    def forward(self, h):
        # Linear -> LayerNorm -> GELU blocks (no residual connections)
        x = self.fc1(h)
        x = self.ln1(x)
        x = self.activation(x)

        x = self.fc2(x)
        x = self.ln2(x)
        x = self.activation(x)

        output = self.fc_out(x)

        mu, raw_var_score = output.chunk(2, dim=-1)

        # Robust Softplus
        sigma2 = F.softplus(raw_var_score) + 1e-6

        return mu, sigma2

class HT_HVAE_InferenceNetwork(nn.Module):
    """
    The Hierarchical Transformer Encoder (Inference Network) q(z|x).
    A pretrained DistilBERT word-level encoder is shared across all sentences,
    followed by a separate sentence-level Transformer.
    """
    def __init__(self,hyperparams):
        super().__init__()

        self.latent_dim = hyperparams['latent_dim']
        self.d_model = hyperparams['d_model']
        self.vocab_size = hyperparams['vocab_size']
        self.max_sentences = hyperparams['max_sentences']
        self.max_words = hyperparams['max_words']
        self.n_heads = hyperparams['encoder_heads']
        self.dropout = hyperparams['encoder_dropout']
        self.n_layers = hyperparams['encoder_layers']

        self.word_encoder = DistilBertModel.from_pretrained('distilbert-base-uncased')

        self.distilbert_dim = self.word_encoder.config.hidden_size
        self.word_projection = nn.Linear(self.distilbert_dim, self.d_model)

        self.sentence_position_embedding = nn.Embedding(self.max_sentences + 1, self.d_model)


        # 2.2. Sentence-Level Transformer Encoder
        # This is a separate Transformer stack for sentence-level attention.

        sentence_encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.d_model,
            nhead=self.n_heads,
            dim_feedforward=4*self.d_model,
            dropout=self.dropout,
            batch_first=True,
            # 1. Change activation from 'relu' (default) to 'gelu'
            activation='gelu',
            # 2. Enable Pre-Layer Normalization (Pre-LN)
            norm_first=True
        )
        self.transformer_sentence = nn.TransformerEncoder(sentence_encoder_layer, num_layers=self.n_layers)

        # 2.3. Latent Variable Networks (MLP)
        # These are used to fit the Gaussian distribution parameters.

        # MLP for Global Posterior q(zt | x): Learned from text-code H_s^0
        self.mlp_global = MLPNetwork(self.d_model, self.latent_dim)
        self.local_latent = self.latent_dim

        # MLP for Local Posterior q(zi | xi): Learned from sentence-code H_w^i
        # This MLP MUST SHARE PARAMETERS across all sentences.
        self.mlp_local = MLPNetwork(self.d_model, self.local_latent)
        self.doc_token = nn.Parameter(torch.randn(1, 1, self.d_model))


    def forward(self, input_ids, word_mask):
        """
        Args:
            input_ids (torch.LongTensor): Input tensor of shape
                (batch_size, MAX_SENTENCES, MAX_WORDS).
            word_mask (torch.LongTensor): Mask tensor of shape
                (batch_size, MAX_SENTENCES, MAX_WORDS). 1 for real tokens, 0 for padding.

        Returns:
            tuple: (mu_t, sigma2_t, mu_i_list, sigma2_i_list)
        """
        # word mask is 1 for real tokens, 0 for pad
        batch_size, max_sentences, max_words = input_ids.shape

        # --- 2.1. Word-Level Transformer Encoding (Shared) ---

        # Reshape to treat all sentences as one batch for parameter sharing
        # (batch_size * MAX_SENTENCES, MAX_WORDS)
        flat_input_ids = input_ids.view(-1, max_words)
        flat_attention_mask = word_mask.view(-1, max_words)
        sentence_flat_word_mask = (word_mask.view(-1, max_words) == 0)



        # DistilBERT follows the Hugging Face convention: attention_mask is 1 for tokens
        # to attend to and 0 for padding, so word_mask can be passed in directly.
        # PyTorch's own TransformerEncoder uses the opposite convention (True = ignore),
        # which is why the sentence-level padding mask below is built with `== 0`.

        distilbert_output = self.word_encoder(
            input_ids=flat_input_ids,
            attention_mask=flat_attention_mask
        )

        cls_embeddings = distilbert_output.last_hidden_state[:, 0, :]

        # Extract the Sentence Code (Hw_0): (B*N, D)
        # Hw_0 is the projected representation of the first token, DistilBERT's [CLS]
        H_w_0_flat = self.word_projection(cls_embeddings)

        is_padding_sentence = sentence_flat_word_mask[:, 0]

        # 2. "Firewall": zero out the codes of fully padded sentences, whose DistilBERT
        # output is meaningless (and may contain NaNs). This protects BOTH the MLP and the Sentence Transformer downstream
        H_w_0_flat = H_w_0_flat.masked_fill(is_padding_sentence.unsqueeze(1), 0.0)

        # Reshape back: (batch_size, MAX_SENTENCES, D)
        H_w_0 = H_w_0_flat.view(batch_size, max_sentences, self.d_model)

        # --- 2.3. Local Posterior (q(zi | xi)) using Shared MLP ---

        # Calculate local distribution parameters from Hw_0
        # The mlp_local shares parameters because it is called on all Hw_0_flat
        mu_i_flat, sigma2_i_flat = self.mlp_local(H_w_0_flat)

        # Reshape back to (batch_size, MAX_SENTENCES, LATENT_DIM)
        mu_i = mu_i_flat.view(batch_size, max_sentences, self.local_latent)
        sigma2_i = sigma2_i_flat.view(batch_size, max_sentences, self.local_latent)

        # --- 2.2. Sentence-Level Transformer Encoding ---

        # Prepend a learned document token and add sentence position codes.
        # Position 0 is reserved for the document token (whose output is H_s^0);
        # positions 1..MAX_SENTENCES are used for the sentence codes Hw_0.

        batch_doc_token = self.doc_token.expand(batch_size, -1, -1)
        H_sen_input_ = torch.cat([batch_doc_token, H_w_0], dim=1)
        position_ids = torch.arange(0, max_sentences + 1, device=input_ids.device)
        position_embeddings = self.sentence_position_embedding(position_ids) # (N+1, D)
        position_embeddings = position_embeddings.unsqueeze(0).expand(batch_size, -1, -1) # (B, N+1, D)

        H_sen_input = H_sen_input_ + position_embeddings

        # Need a sentence-level padding mask (True = ignore, PyTorch convention).
        # A sentence is padding when its first token is padding (word_mask[:, :, 0] == 0).
        sentence_mask = (word_mask[:, :, 0] == 0) # (B, N)

        doc_mask = torch.zeros((batch_size, 1), dtype=torch.bool, device=input_ids.device)

        full_sentence_mask = torch.cat([doc_mask, sentence_mask], dim=1)

        # Sentence-Level Encoding (Hs): (B, N+1, D)
        H_s = self.transformer_sentence(
            src=H_sen_input,
            src_key_padding_mask=full_sentence_mask
        )

        H_s_0 = H_s[:, 0, :] # (B, D)

        # --- 2.3. Global Posterior (q(zt | x)) ---

        # Calculate global distribution parameters from H_s^0
        mu_t, sigma2_t = self.mlp_global(H_s_0) # (B, LATENT_DIM)

        return mu_t, sigma2_t, mu_i, sigma2_i
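The sentence-level padding mask mixes two conventions: word_mask uses 1 for real tokens, while src_key_padding_mask uses True for positions to ignore, and an extra always-visible slot is prepended for the document token. The list-based sketch below (a hypothetical helper that mirrors sentence_mask and full_sentence_mask in forward) shows the resulting layout. Because the document-token slot is never masked, every row keeps at least one key to attend to, even for an item with no real sentences.

```python
def sentence_padding_mask(word_mask):
    """Hypothetical helper mirroring full_sentence_mask in HT_HVAE_InferenceNetwork.forward.

    word_mask: nested lists of shape (B, N, S) with 1 = real token, 0 = padding.
    Returns (B, N + 1) booleans where True = ignore (PyTorch key-padding convention).
    """
    full_mask = []
    for doc in word_mask:
        # A sentence counts as padding when its first token is padding
        sentence_is_pad = [sentence[0] == 0 for sentence in doc]
        # Slot 0 is the learned document token and is never masked
        full_mask.append([False] + sentence_is_pad)
    return full_mask


word_mask = [
    [[1, 1, 0], [1, 0, 0], [0, 0, 0]],  # two real sentences, one padded
]
print(sentence_padding_mask(word_mask))  # [[False, False, False, True]]
```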


Next, we define MLPNetworkForPrior, an MLP that parameterizes the conditional prior p(z_i | z_t, z_{i-1}) and applies context dropout to z_{i-1}, and HT_HVAE_GenerativeNetwork, the hierarchical generative network that uses a GRU for sentence planning and a modified GPT-2 for word generation. Together they form the decoder of the HVAE model.
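The conditional prior p(z_i | z_t, z_{i-1}) is compared against the local posterior q(z_i | x_i) through a KL term. The plain-Python sketch below covers the two pieces involved, assuming both distributions are diagonal Gaussians parameterized by variances, as MLPNetwork and MLPNetworkForPrior return them: shift_right builds z_{i-1} the same way as the zero-padded shift in HT_HVAE_GenerativeNetwork.forward, and gaussian_kl is the standard closed form. Skipping padded sentences in masked_local_kl is an assumption; padded sentences still receive generally non-zero parameters from the MLPs, so some masking is likely needed, but the notebook's actual loss is not shown in this section.

```python
import math


def shift_right(z_seq, latent_dim):
    """Build z_{i-1} for each sentence: zeros for the first sentence, then z_1..z_{N-1}."""
    return [[0.0] * latent_dim] + z_seq[:-1]


def gaussian_kl(mu_q, s2_q, mu_p, s2_p):
    """KL(N(mu_q, diag(s2_q)) || N(mu_p, diag(s2_p))) with variances (not log-variances)."""
    return 0.5 * sum(
        math.log(sp / sq) + (sq + (mq - mp) ** 2) / sp - 1.0
        for mq, sq, mp, sp in zip(mu_q, s2_q, mu_p, s2_p)
    )


def masked_local_kl(mu_q, s2_q, mu_p, s2_p, sentence_is_real):
    """Sum per-sentence KL terms, skipping padded sentences (hypothetical masking choice)."""
    return sum(
        gaussian_kl(a, b, c, d)
        for a, b, c, d, real in zip(mu_q, s2_q, mu_p, s2_p, sentence_is_real)
        if real
    )


z = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
print(shift_right(z, 2))  # [[0.0, 0.0], [1.0, 2.0], [3.0, 4.0]]
print(gaussian_kl([0.0], [1.0], [0.0], [1.0]))  # 0.0
```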

In [None]:
class MLPNetworkForPrior(nn.Module):
    """
    Optimized Prior Network p(z_i | z_t, z_i-1).
    1. Widened layers (No bottleneck).
    2. Context Dropout (Forces reliance on z_t).
    3. Near-Zero Init (Prevents initial KL explosion).
    """
    def __init__(self, latent_dim, context_dropout_rate=0.05):
        super().__init__()
        self.context_dropout_rate = context_dropout_rate

        # Input: global latent z_t (latent_dim) + previous local latent z_{i-1} (latent_dim)
        input_dim = latent_dim + latent_dim

        # 1. Maintain Width:
        # Instead of compressing, we project up (hidden_dim = 2 * input_dim).
        # This gives enough capacity to mix Global and Previous-Local info.
        hidden_dim = 2*input_dim

        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.ln1 = nn.LayerNorm(hidden_dim)

        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.ln2 = nn.LayerNorm(hidden_dim)

        self.fc_out = nn.Linear(hidden_dim, 2 * latent_dim, bias=False) # Outputs (mu, var_param)

        self.activation = nn.GELU()

        # --- CRITICAL: Near-Zero Initialization ---
        # Initialize output to be very close to N(0, 1) parameters.
        # mu -> 0, sigma_param -> 0 (which becomes softplus(0) ~ 0.69)
        # This matches the initialization of the encoder's MLPNetwork.
        torch.nn.init.normal_(self.fc_out.weight, mean=0.0, std=0.001)
        # torch.nn.init.constant_(self.fc_out.bias, 0)

    def forward(self, z_t, z_i_minus_1):
        """
        Args:
            z_t (Tensor): Global latent (B, D_global)
            z_i_minus_1 (Tensor): Previous local latent (B, D_local)
        """

        # --- 2. Context Dropout (The "Firewall") ---
        # Randomly zero out the previous sentence information.
        # This forces the network to look at z_t to guess the current sentence state.
        if self.training and self.context_dropout_rate > 0:
            mask_prob = torch.rand(z_t.shape[0], 1, device=z_t.device)
            # If random > rate, keep signal (1.0). Else drop (0.0).
            keep_mask = (mask_prob > self.context_dropout_rate).float()
            z_i_minus_1 = z_i_minus_1 * keep_mask

        # Concatenate inputs
        h = torch.cat([z_t, z_i_minus_1], dim=-1)

        # Block 1
        x = self.fc1(h)
        x = self.ln1(x)
        x = self.activation(x)

        # Block 2
        x = self.fc2(x)
        x = self.ln2(x)
        x = self.activation(x)

        # Output Head
        output = self.fc_out(x)

        mu, raw_var_score = output.chunk(2, dim=-1)

        # Robust Softplus
        sigma2 = F.softplus(raw_var_score) + 1e-6

        return mu, sigma2

class HT_HVAE_GenerativeNetwork(nn.Module):
    """
    The Hierarchical VAE Generative Network (Decoder) p(x|z).
    Uses GRU for sentence planning and modified GPT-2 for word generation.
    """
    def __init__(self,hyperparams):
        super().__init__()

        # 3.1. Global Prior: Standard Gaussian N(0, I) is defined conceptually.
        self.latent_dim = hyperparams['latent_dim']
        self.local_latent = self.latent_dim
        self.d_model = hyperparams['d_model']
        self.gpt2_model_name = hyperparams['gpt2_model_name']
        self.vocab_size = hyperparams['vocab_size']
        self.gru_layers = hyperparams['gru_layers']

        # 3.2. Local Prior Network (HMM integration)
        self.prior_mlp = MLPNetworkForPrior(self.latent_dim)
        self.gru_initial_projection = nn.Linear(self.latent_dim, self.d_model)

        # 3.3. Sentence-Level Decoder (GRU)
        # Input size is LATENT_DIM (z_i), hidden size matches D_MODEL (for h_i), gru_layers layers.
        self.gru = nn.GRU(
            input_size=self.local_latent,
            hidden_size=self.d_model,
            num_layers=self.gru_layers,
            batch_first=True
        )

        # 3.4. Word-Level Decoder (Modified GPT-2)

        # Load pre-trained GPT-2 (the base GPT2Model, without its LM head)
        # and reuse its weights inside the VAE decoder.
        self.gpt2_model = GPT2Model.from_pretrained(self.gpt2_model_name)

        self.gpt2_model.resize_token_embeddings(self.vocab_size)

        gpt2_hidden_size = self.gpt2_model.config.n_embd

        self.global_projector = nn.Linear(self.latent_dim, gpt2_hidden_size)


        # Define the projection layer here so weights are saved and trained
        self.latent_projection = nn.Linear(self.d_model, gpt2_hidden_size)

        # Set the embedding layer from GPT-2
        self.word_embedding = self.gpt2_model.wte

        self.word_dropout_rate = hyperparams['word_dropout_rate']
        self.plan_dropout_rate = hyperparams['plan_dropout_rate']
        self.mask_token_id = hyperparams.get('mask_token_id', self.gpt2_model.config.eos_token_id)


        # GPT2Model has no vocabulary head (unembedder). In GPT2LMHeadModel that head is
        # tied to wte.weight, but since the input to our final linear layer is [H_in || h_i],
        # we define a new, untied classification head.

        # Define the final linear layer: Input is [Hin || hi]
        # Output is the vocabulary size (Logits)
        self.final_linear = nn.Linear(gpt2_hidden_size + self.d_model, self.vocab_size, bias = False)



    def forward(self, input_ids,word_mask, z_t, z_i_samples, mu_i_prior=None, log_sigma2_i_prior=None):
        """
        Processes text sequentially for reconstruction loss calculation (training mode).

        Args:
            input_ids (torch.LongTensor): Word IDs of shape (B, N, S).
            word_mask (torch.Tensor): Mask of shape (B, N, S); 1 for real tokens, 0 for padding.
            z_t (torch.Tensor): Sampled global latent variable (B, LATENT_DIM).
            z_i_samples (torch.Tensor): Sampled local latent variables (B, N, LATENT_DIM).
            mu_i_prior (torch.Tensor, optional): Currently unused; the prior is recomputed below.
            log_sigma2_i_prior (torch.Tensor, optional): Currently unused; the prior is recomputed below.

        Returns:
            tuple: (reconstruction_logits of shape (B, N, S+1, VOCAB_SIZE), mu_i_prior, sigma2_i_prior)
        """

        batch_size, max_sentences, max_words = input_ids.shape
        global_drop_prob = 0.5

        if self.training and global_drop_prob > 0:
            # Create mask: 1 = Keep, 0 = Drop
            mask_prob = torch.rand(z_t.size(0), 1, device=z_t.device)
            keep_global_mask = (mask_prob > global_drop_prob).float()

            # Apply to z_t. z_t_masked feeds the GRU's initial state (through a biased
            # projection, so h_0 is not exactly zero) and the conditional prior; the
            # GPT-2 global token below is built from the unmasked z_t.
            z_t_masked = z_t * keep_global_mask
        else:
            z_t_masked = z_t



        # --- 3.2. Sentence-Level Decoder (GRU) ---

        # The GRU processes the sequence of local latent variables (z_i_samples)
        # Sequence of z_i: (B, N, LATENT_DIM)

        # Initial hidden state for the GRU: the paper just uses z_i as input, but here
        # h_0 is derived from the (possibly dropped) global latent z_t.

        # GRU output: plan_vectors (h_i) is the sequence of outputs for each z_i

        h_0_projected = self.gru_initial_projection(z_t_masked)
        h_0_projected = h_0_projected.unsqueeze(0)
        h_0 = h_0_projected.repeat(self.gru.num_layers, 1, 1) # (num_layers, B, D_MODEL), as nn.GRU expects

        plan_vectors, _ = self.gru(z_i_samples, h_0) # (B, N, D_MODEL)

        # --- 3.2. Local Prior Network calculation (for KL divergence) ---

        # We need z_i-1, which is the previous sampled local latent variable.
        # Create z_i-1 by shifting z_i_samples:
        z_i_minus_1 = torch.cat([
            torch.zeros_like(z_i_samples[:, :1, :]), # Use zero vector for z_i-1 of the first sentence
            z_i_samples[:, :-1, :]
        ], dim=1).view(-1, self.local_latent) # Flatten to (B*N, D)

        # Flatten z_t and z_i_samples for MLP processing
        z_t_flat = z_t_masked.unsqueeze(1).repeat(1, max_sentences, 1).view(-1, self.latent_dim) # (B*N, D)

        # Calculate the conditional prior parameters for all sentences
        mu_i_prior, sigma2_i_prior = self.prior_mlp(z_t_flat, z_i_minus_1)
        # Reshape to (B, N, LATENT_DIM)
        mu_i_prior = mu_i_prior.view(batch_size, max_sentences, self.local_latent)
        sigma2_i_prior = sigma2_i_prior.view(batch_size, max_sentences, self.local_latent)

        # --- 3.4. Word-Level Decoder (Modified GPT-2) ---

        # Reshape everything to feed into the word-level loop
        flat_input_ids = input_ids.view(-1, max_words) # (B*N, S)
        flat_plan_vectors = plan_vectors.reshape(-1, self.d_model) # (B*N, D) (h_i for each sentence)

        # 1. Word Embeddings (e_ij): (B*N, S, D)
        if self.training and self.word_dropout_rate > 0:
            # 1. Create a mask of random probabilities
            rand_mask = torch.rand(flat_input_ids.shape, device=input_ids.device)

            # 2. Identify tokens to drop (probability < rate)
            # We also ensure we do NOT drop padding tokens (if 0 is your pad)
            # Assuming word_mask indicates valid words (1) and padding (0)
            flat_word_mask_bool = word_mask.view(-1, max_words).bool()

            # Drop mask: True where the token is replaced with mask_token_id (GPT-2's EOS by default)
            drop_mask = (rand_mask < self.word_dropout_rate) & flat_word_mask_bool

            # 3. Create a clone to avoid in-place modification errors
            input_ids_for_decoder = flat_input_ids.clone()
            input_ids_for_decoder[drop_mask] = self.mask_token_id

        else:
            # No dropout during evaluation
            input_ids_for_decoder = flat_input_ids



        inputs_embeds = self.gpt2_model.wte(input_ids_for_decoder)

        # 2. Project Latents (Plan Vectors)
        # Ensure self.latent_projection maps from D_MODEL -> GPT_Hidden_Dim (e.g., 768)
        projected_latents = self.latent_projection(flat_plan_vectors) # (B*N, GPT_Hidden)

        # 3. Expand Latents to match Sequence Length
        # (B*N, 1, GPT_Hidden) -> (B*N, S, GPT_Hidden)
        latent_embeds = projected_latents.unsqueeze(1).expand(-1, max_words, -1)

        # 4. Additive Conditioning
        # The latent information is now fused into the input representation

        if self.training and self.plan_dropout_rate > 0:
            # We drop the ENTIRE plan vector for a sequence, not just random dimensions
            # Shape: (B*N, 1, 1) to broadcast over Sequence and Hidden dims
            p_mask = torch.rand(latent_embeds.size(0), 1, 1, device=latent_embeds.device)

            # Create boolean mask: 1 = Keep, 0 = Drop
            # If random number > dropout_rate, we keep the plan
            keep_plan_mask = (p_mask > self.plan_dropout_rate).float()

            # Apply mask: latent_embeds becomes all zeros where mask is 0
            latent_embeds = latent_embeds * keep_plan_mask

        inputs_embeds = inputs_embeds + latent_embeds

        # 5. Prepend a global token built from z_t
        # Note: We pass 'inputs_embeds' instead of 'input_ids'.
        # GPT-2 will automatically add Position Embeddings (wpe) to this internally.
        z_t_expanded = z_t.unsqueeze(1).expand(-1, max_sentences, -1).reshape(-1, self.latent_dim)

        # 5a. Project z_t to the GPT hidden dimension and add a sequence dim
        z_t_emb = self.global_projector(z_t_expanded).unsqueeze(1) # (B*N, 1, GPT_Hidden)

        # 5b. Concatenate the global token to the beginning of the inputs
        inputs_embeds = torch.cat([z_t_emb, inputs_embeds], dim=1) # (B*N, S+1, GPT_Hidden)

        # 5c. Adjust the attention mask
        flat_word_mask = word_mask.view(-1, max_words)
        # Create a column of 1s for the global token, matching the word mask's dtype
        global_mask_col = torch.ones(
            (flat_word_mask.size(0), 1), dtype=flat_word_mask.dtype, device=flat_word_mask.device
        )
        # Concatenate mask: (B*N, S+1)
        extended_attention_mask = torch.cat([global_mask_col, flat_word_mask], dim=1)

        # 6. Forward pass through GPT-2
        gpt2_output = self.gpt2_model(
            inputs_embeds=inputs_embeds,
            attention_mask=extended_attention_mask
        )
        # Extract the hidden states from the last layer (H_in)
        H_in = gpt2_output.last_hidden_state # (B*N, S, D)

        plan_unsqueezed = flat_plan_vectors.unsqueeze(1)

        # 2. Expand to match the sequence length 'S' of H_in: (B*N, 1, D) -> (B*N, S, D)
        # Using .expand() is more memory efficient than .repeat()
        copied_plan_vector = plan_unsqueezed.expand(-1, H_in.size(1), -1)

        # 3. Concatenate: Result shape is (B*N, S, GPT_Hidden + Plan_Dim)
        final_input = torch.cat([H_in, copied_plan_vector], dim=-1)


        # Calculate Logits (Final output for reconstruction loss)
        reconstruction_logits = self.final_linear(final_input) # (B*N, S, VOCAB_SIZE)

        # Reshape logits back to (B, N, S, VOCAB_SIZE) for loss calculation
        reconstruction_logits = reconstruction_logits.view(
            batch_size, max_sentences, max_words + 1, self.vocab_size
        )

        return reconstruction_logits, mu_i_prior, sigma2_i_prior

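    def inference(self, z_t, z_i, max_length=50, bos_token_id=50256, eos_token_id=50256):
        """
        Autoregressively generates text (greedy decoding) following the architecture of forward():
        1. The global token z_t is the very first input.
        2. The local latent (plan) is ADDED to the word embeddings (not concatenated).
        3. The plan vector is concatenated to the GPT-2 output (skip connection).

        The default IDs (50256) are GPT-2's original <|endoftext|> token. After the
        custom special tokens are added, pass dec_tokenizer.bos_token_id and
        dec_tokenizer.eos_token_id explicitly so generation matches training.
        """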
        self.eval()
        batch_size, num_sentences, _ = z_i.shape
        device = z_t.device

        generated_sentences = []

        with torch.no_grad():
            # --- 1. Sentence-Level Decoder (GRU) ---
            # Initialize GRU with z_t (same as forward)
            h_0_projected = self.gru_initial_projection(z_t).unsqueeze(0)
            h_0 = h_0_projected.repeat(self.gru.num_layers, 1, 1)

            # Get plan vectors (h_i)
            plan_vectors, _ = self.gru(z_i, h_0) # (B, N, D_MODEL)

            # --- Pre-calculate Global Token Embedding ---
            # Logic matches Forward 3.4 Step 2 & 3: Project z_t and use as first token
            z_t_emb = self.global_projector(z_t).unsqueeze(1) # (B, 1, GPT_Hidden)

            # --- 2. Word-Level Generation Loop ---
            for n in range(num_sentences):
                # Extract plan for current sentence
                current_plan = plan_vectors[:, n, :] # (B, D_MODEL)

                # Prepare Additive Latent (Logic matches Forward 3.4 Step 2)
                # Project Plan: D_MODEL -> GPT_Hidden
                latent_emb = self.latent_projection(current_plan).unsqueeze(1) # (B, 1, GPT_Hidden)

                # A. Initialize Context with Global Token z_t
                # We feed z_t first to prime the cache (past_key_values)
                outputs = self.gpt2_model(inputs_embeds=z_t_emb)
                past_key_values = outputs.past_key_values

                # B. Initialize Generation with BOS
                curr_input_ids = torch.full((batch_size, 1), bos_token_id, dtype=torch.long, device=device)
                sentence_tokens = [curr_input_ids]

                for _ in range(max_length):
                    # 1. Word Embeddings
                    text_embeds = self.gpt2_model.wte(curr_input_ids) # (B, 1, GPT_Hidden)

                    # 2. Additive Conditioning (Matches Forward 3.4 Step 4)
                    # Instead of prefixing, we ADD the latent to the embedding
                    inputs_embeds = text_embeds + latent_emb

                    # 3. Forward Pass (using cache)
                    outputs = self.gpt2_model(
                        inputs_embeds=inputs_embeds,
                        past_key_values=past_key_values
                    )
                    past_key_values = outputs.past_key_values
                    last_hidden_state = outputs.last_hidden_state # (B, 1, GPT_Hidden)

                    # 4. Skip Connection (Matches Forward Final Step)
                    # Concatenate plan vector to the GPT output
                    plan_unsqueezed = current_plan.unsqueeze(1) # (B, 1, D_MODEL)
                    final_input = torch.cat([last_hidden_state, plan_unsqueezed], dim=-1)

                    # 5. Project to Vocab
                    logits = self.final_linear(final_input) # (B, 1, VOCAB_SIZE)

                    # Greedy decoding
                    next_token = torch.argmax(logits, dim=-1) # (B, 1)

                    sentence_tokens.append(next_token)
                    curr_input_ids = next_token

                    # Stop generation (Basic check for batch_size=1)
                    if next_token[0].item() == eos_token_id:
                        break

                # Concatenate all tokens for this sentence
                full_sentence = torch.cat(sentence_tokens, dim=1)
                generated_sentences.append(full_sentence)

        return generated_sentences

    def generate_latents(self, batch_size, num_sentences, device):
        """
        Generates z_t from standard normal, then autoregressively generates
        z_i sequence using the learned HMM prior.
        """
        # 1. Sample Global Latent z_t ~ N(0, I)
        # Shape: (B, LATENT_DIM)
        z_t = torch.randn(batch_size, self.latent_dim, device=device)

        # 2. Autoregressively Sample Local Latents z_i
        z_i_list = []

        # Initialize z_{i-1} for the first sentence as a zero vector
        # Shape: (B, LATENT_DIM)
        z_prev = torch.zeros(batch_size, self.latent_dim, device=device)

        for _ in range(num_sentences):
            # Calculate parameters for p(z_i | z_{i-1}, z_t)
            # Training passes flattened (B*N, D) inputs; here we pass (B, D) directly
            mu_i_prior, sigma2_i_prior = self.prior_mlp(z_t, z_prev)

            # Sample z_i using the reparameterization trick
            eps = torch.randn_like(mu_i_prior)
            z_i = mu_i_prior + torch.sqrt(sigma2_i_prior) * eps

            # Store and update state
            z_i_list.append(z_i)
            z_prev = z_i

        # Stack to match shape (B, N, LATENT_DIM)
        z_i_samples = torch.stack(z_i_list, dim=1)

        return z_t, z_i_samples
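In forward(), plan dropout zeroes the whole projected plan of a sentence instead of individual dimensions. There is one keep/drop draw per sentence, and it is broadcast over tokens and hidden units. The NumPy sketch below reproduces that masking rule on toy shapes. The function name drop_whole_plans and all sizes are illustrative assumptions, not the model code.

```python
import numpy as np

def drop_whole_plans(latent_embeds, dropout_rate, rng):
    """Zero out entire plan embeddings, with one keep/drop decision per sentence.

    latent_embeds: array of shape (B*N, S, H), the plan broadcast over S tokens.
    """
    num_sentences = latent_embeds.shape[0]
    # One uniform draw per sentence, shaped (B*N, 1, 1) so it broadcasts over S and H
    p_mask = rng.random((num_sentences, 1, 1))
    # Keep the plan when the draw exceeds the dropout rate (1 = keep, 0 = drop)
    keep_plan_mask = (p_mask > dropout_rate).astype(latent_embeds.dtype)
    return latent_embeds * keep_plan_mask

rng = np.random.default_rng(0)
# Toy plan vectors (B*N=6, H=3) expanded to S=4 tokens, like .expand(-1, S, -1)
plans = rng.normal(size=(6, 1, 3))
latent_embeds = np.broadcast_to(plans, (6, 4, 3))
dropped = drop_whole_plans(latent_embeds, dropout_rate=0.5, rng=rng)

# Each sentence either keeps its full plan or loses all of it
for row, original in zip(dropped, latent_embeds):
    assert np.allclose(row, 0.0) or np.allclose(row, original)
print("sentences that kept their plan:", int(sum(not np.allclose(r, 0.0) for r in dropped)))
```

Dropping the whole vector means that, for some sentences, the word decoder has to rely on z_t and the preceding tokens alone.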

We define utility functions for the VAE model: reparameterize for sampling from a Gaussian via the reparameterization trick, kl_divergence_two_gaussians for the KL divergence between two diagonal Gaussians, and kl_divergence_standard_gaussian for the KL divergence from a standard normal distribution. We also define the HT_HVAE_Loss class, which implements the full (negated) ELBO objective: reconstruction loss + weighted global KL + weighted local KL.
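For reference, below are the closed-form expressions these helpers compute. They assume diagonal Gaussians $q = \mathcal{N}(\mu_q, \operatorname{diag}\sigma_q^2)$ and $p = \mathcal{N}(\mu_p, \operatorname{diag}\sigma_p^2)$, with each KL summed over the latent dimension $d$. This restates the code below and does not necessarily follow the paper's notation.

$$
z = \mu + \sigma \odot \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I)
$$

$$
D_{\mathrm{KL}}(q \,\|\, p) = \frac{1}{2} \sum_{d} \left[ \log \sigma_{p,d}^2 - \log \sigma_{q,d}^2 + \frac{\sigma_{q,d}^2 + (\mu_{q,d} - \mu_{p,d})^2}{\sigma_{p,d}^2} - 1 \right]
$$

Setting $\mu_p = 0$ and $\sigma_p^2 = 1$ gives the standard-normal case used for the global latent:

$$
D_{\mathrm{KL}}\big(q \,\|\, \mathcal{N}(0, I)\big) = \frac{1}{2} \sum_{d} \left[ \sigma_{q,d}^2 + \mu_{q,d}^2 - 1 - \log \sigma_{q,d}^2 \right]
$$

The loss module then minimizes the following over a batch of $B$ documents with $N$ sentence slots each, where $m_{b,i}$ is 1 for real sentences and 0 for padding:

$$
\mathcal{L} = \frac{1}{B}\sum_{b=1}^{B}\Big[ \mathrm{NLL}_b + \beta_g\, D_{\mathrm{KL}}\big(q(z_t \mid x_b) \,\|\, \mathcal{N}(0, I)\big) + \beta_l \sum_{i=1}^{N} m_{b,i}\, D_{\mathrm{KL}}\big(q(z_i \mid x_b) \,\|\, p(z_i \mid z_{i-1}, z_t)\big) \Big]
$$

Here $\mathrm{NLL}_b$ is the token-level negative log-likelihood summed over all non-padding tokens of document $b$.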

In [None]:

def reparameterize(mu: torch.Tensor, sigma2: torch.Tensor) -> torch.Tensor:
    """
    Applies the reparameterization trick to sample z from N(mu, sigma^2).

    Args:
        mu (torch.Tensor): Mean vector (mu').
        sigma2 (torch.Tensor): Variance vector (sigma^2').

    Returns:
        torch.Tensor: Sampled latent vector z.
    """
    # Calculate standard deviation (sigma) from variance (sigma^2)
    std = torch.sqrt(sigma2)

    # Sample noise epsilon from Standard Gaussian N(0, I)
    epsilon = torch.randn_like(std)

    # z = mu + sigma * epsilon
    z = mu + std * epsilon
    return z

def kl_divergence_two_gaussians(
    mu_p: torch.Tensor, sigma2_p: torch.Tensor,
    mu_q: torch.Tensor, sigma2_q: torch.Tensor
) -> torch.Tensor:
    """
    Calculates D_KL(q || p) with input clamping for stability.
    """
    # 1. Clamp variances to a minimum value to avoid div by zero
    eps = 1e-6
    sigma2_p = torch.clamp(sigma2_p, min=eps)
    sigma2_q = torch.clamp(sigma2_q, min=eps)

    # 2. Closed-form KL for diagonal Gaussians (variances are already clamped, so no epsilon is needed)
    kl_loss = 0.5 * (
        torch.log(sigma2_p) - torch.log(sigma2_q) +
        (sigma2_q + (mu_q - mu_p).pow(2)) / sigma2_p - 1
    )

    # 3. Ensure non-negative output (handling remaining float precision noise)
    return torch.clamp(torch.sum(kl_loss, dim=-1), min=0.0)

def kl_divergence_standard_gaussian(
    mu_q: torch.Tensor, sigma2_q: torch.Tensor
) -> torch.Tensor:
    """
    Calculates the closed-form KL divergence D_KL(q || p) where
    p is the Standard Gaussian N(0, I). (Used for Global KL Term).
    """
    # Equivalent to kl_divergence_two_gaussians(0, 1, mu_q, sigma2_q)
    # The sum is over the latent dimension (LATENT_DIM)
    kl_loss = 0.5 * (
        sigma2_q + mu_q.pow(2) - 1 - torch.log(sigma2_q)
    )
    # sum across the latent dimension
    return torch.sum(kl_loss, dim=-1)


class HT_HVAE_Loss(nn.Module):
    """
    Implements the full ELBO objective L_HT-HVAE.
    NOTE: The formula in the paper is for maximization. We implement the negative
    of the formula for minimization, common in PyTorch loss functions.

    L_minimize = -L_HT-HVAE = Reconstruction_Loss + Global_KL + Local_KL
    """
    def __init__(self, hyperparameters):
        super().__init__()
        self.vocab_size = hyperparameters['vocab_size']
        self.pad_idx = hyperparameters['pad_index']
        self.latent_dim = hyperparameters['latent_dim']
        self.local_latent = self.latent_dim
        # Use CrossEntropyLoss for the negative log-likelihood (Reconstruction Term)
        # Note: We ignore the padding index during loss calculation
        self.cross_entropy_loss = nn.CrossEntropyLoss(ignore_index=self.pad_idx, reduction='none')


    def forward(self, mu_t_q, sigma2_t_q, mu_i_q, sigma2_i_q,
                reconstruction_logits, mu_i_p, sigma2_i_p,
                target_ids, word_mask, local_kl_beta=0.5, global_kl_beta=0.1):


        batch_size, max_sentences, max_words = target_ids.shape

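        # The decoder prepends z_t, so the logits have S+1 positions per sentence.
        # Dropping the last position aligns logit k with label k (the z_t slot predicts BOS).
        shift_logits = reconstruction_logits[..., :-1, :].contiguous()
        # Keep ALL labels (token 1 ... token S)
        shift_labels = target_ids.contiguous()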

        # --- A. Reconstruction (Sum over words, Mean over batch) ---
        # 1. Calculate raw NLL per token
        nll_loss_per_token = self.cross_entropy_loss(
            shift_logits.view(-1, self.vocab_size),
            shift_labels.view(-1)
        )

        # 2. Sum ALL losses (numerator)
        total_nll = nll_loss_per_token.sum()

        # 3. Divide by BATCH_SIZE (not by the token count)
        # This gives the summed NLL per document (all N sentences of one sample), averaged over the batch
        reconstruction_loss = total_nll / batch_size

        num_active_tokens = (shift_labels != self.pad_idx).sum()

        # Calculate mean (add epsilon to avoid div by zero)
        mean_token_loss = total_nll / (num_active_tokens + 1e-9)


        # --- B. Global KL (Sum over dim, Mean over batch) ---
        # kl_divergence_standard_gaussian now returns (B,) due to sum(dim=-1)
        global_kl_raw = kl_divergence_standard_gaussian(mu_t_q, sigma2_t_q)


        global_kl_loss = global_kl_raw.mean()

        # --- C. Local KL (Sum over dim, Mean over batch) ---
        # kl_divergence_two_gaussians now returns (B*N,) due to sum(dim=-1)

        local_kl_flat = kl_divergence_two_gaussians(
            mu_p=mu_i_p.reshape(-1, self.local_latent),
            sigma2_p=sigma2_i_p.reshape(-1, self.local_latent),
            mu_q=mu_i_q.reshape(-1, self.local_latent),
            sigma2_q=sigma2_i_q.reshape(-1, self.local_latent)
            )

        # Reshape to (B, N) so padded sentences can be masked out (no free-bits threshold is applied here)

        local_kl_reshaped = local_kl_flat.view(batch_size, max_sentences)
        # Mask padding sentences
        sentence_mask = word_mask[:, :, 0].float()
        local_kl_masked = local_kl_reshaped * sentence_mask

        # Sum over sentences (N), then Mean over batch (B)
        local_kl_loss = local_kl_masked.sum(dim=1).mean()

        # --- D. Total ---
        # Each KL term is weighted by its (possibly annealed) beta.
        total_loss = reconstruction_loss + global_kl_beta* global_kl_loss + local_kl_beta*local_kl_loss
        total_kl_unweighted = global_kl_loss + local_kl_loss
        kl_ratio = (reconstruction_loss / (total_kl_unweighted + 1e-8)).detach()

        # kl_ratio = recon / unweighted KL; as a rough heuristic, 10-100 is treated as a healthy range
        return total_loss, reconstruction_loss, global_kl_loss, local_kl_loss, kl_ratio, mean_token_loss
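The decoder prepends the projected z_t as an extra position. As a result, reconstruction_logits has S+1 positions per sentence while target_ids has S. Dropping the last logit position lines up position k with token k: the z_t slot predicts BOS, and each later position predicts the next token. The pure-Python sketch below only lists those (input, target) pairs for made-up tokens. aligned_pairs is a hypothetical helper.

```python
def aligned_pairs(decoder_tokens):
    """Pair each kept logit position with the label it is scored against.

    decoder_tokens: the S target tokens of one sentence (BOS ... EOS).
    The decoder input is [z_t] + decoder_tokens, i.e. S + 1 positions.
    Under causal attention, the logit at position k has seen inputs[0..k].
    """
    inputs = ["<z_t>"] + list(decoder_tokens)   # S + 1 input positions
    logit_positions = inputs[:-1]               # drop the last -> S positions
    return list(zip(logit_positions, decoder_tokens))

tokens = ["<|BOS|>", "light", "bends", "<|EOS|>"]
for seen, target in aligned_pairs(tokens):
    print(f"last input seen: {seen:>8} -> predicts {target}")
```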



We define a save_checkpoint function that saves the model, optimizer and scheduler states together with the epoch and validation loss, then uploads the file to Weights & Biases as a versioned artifact.

In [None]:
def save_checkpoint(
    inference_net,
    generative_net,
    optimizer,
    epoch,
    val_loss,
    scheduler=None,
    is_best=False,
    filename="hvae_checkpoint.pth"
):
    """
    Saves model state, optimizer state, and epoch to a local file
    and uploads it to WandB as a versioned artifact.
    """

    # 1. Consolidate the state dictionary
    state = {
        'epoch': epoch,
        'inference_state_dict': inference_net.state_dict(),
        'generative_state_dict': generative_net.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
        'val_loss': val_loss
    }

    # 2. Save locally first
    torch.save(state, filename)

    # 3. Create a WandB Artifact
    # We name the artifact generally (e.g., 'hvae-model') so versions accumulate under one entry
    artifact = wandb.Artifact(
        name="hvae-model",
        type="model",
        metadata={"epoch": epoch, "val_loss": val_loss}
    )

    artifact.add_file(filename)

    # 4. Determine aliases (tags)
    # 'latest' allows you to always resume from the most recent upload
    # 'best' marks the model with the lowest loss for evaluation
    aliases = ["latest"]
    if is_best:
        aliases.append("best")

    # 5. Upload
    wandb.log_artifact(artifact, aliases=aliases)
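save_checkpoint leaves it to the caller to decide when to pass is_best=True. Below is a minimal sketch of that bookkeeping. It assumes lower validation loss is better. The helper update_best and the loss values are hypothetical and not part of the notebook.

```python
import math

def update_best(val_loss, best_val_loss):
    """Return (is_best, new_best) for a new validation loss (lower is better)."""
    is_best = val_loss < best_val_loss
    return is_best, (val_loss if is_best else best_val_loss)

best = math.inf
history = [4.1, 3.7, 3.9, 3.2]  # made-up validation losses
for epoch, val_loss in enumerate(history):
    is_best, best = update_best(val_loss, best)
    aliases = ["latest"] + (["best"] if is_best else [])
    # In the notebook this would be save_checkpoint(..., epoch=epoch, val_loss=val_loss, is_best=is_best)
    print(epoch, val_loss, aliases)
```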

Here we define the hyperParams dictionary. It holds the HVAE hyperparameters: architecture constraints, latent and data dimensions, encoder and sentence-decoder settings, dropout rates, and special-token IDs.

In [None]:
hyperParams = {
    # --- Architecture Constraints ---
    'gpt2_model_name': 'gpt2',   # Start with 'gpt2' (124M params). Use 'gpt2-medium' only if you have high VRAM.
    'd_model': 768,              # MUST be 768 for 'gpt2', 1024 for 'gpt2-medium', 1280 for 'gpt2-large'.
    'vocab_size': 50257,         # Standard GPT-2 vocabulary size (updated after special tokens are added).

    # --- Latent Space ---
    'latent_dim': 32,           # 32-128 is standard. Too large = posterior collapse; Too small = poor reconstruction.

    # --- Data Dimensions (Abstract specific) ---
    'max_sentences': 11,         # Average abstract is 5-8 sentences; 11 provides a safety buffer.
    'max_words': 50,             # Average sentence is 20-30 words; 50 covers outliers.

    # --- Inference Network (Encoder) ---
    'encoder_layers': 2,         # Keep encoder shallow (2-4) so the heavy lifting happens in the decoder/latent space.
    'encoder_heads': 8,          # Standard for d_model=768 (768 / 8 = 96 dim per head).
    'encoder_dropout': 0.1,      # Standard transformer dropout.

    # --- Sentence Decoder (GRU) ---
    'gru_layers': 1,             # 1 layer is sufficient for the high-level plan; more layers add unnecessary complexity.

    # --- Special Tokens ---
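    'pad_index': 50256,          # Placeholder: GPT-2 has no PAD token, so <|endoftext|> (50256) is a common stand-in; overwritten with the <|PAD|> ID later.
    'word_dropout_rate': 0.5,
    'plan_dropout_rate': 0.15,
    'mask_token_id': 10          # Placeholder: overwritten with the <|MASK|> ID once special tokens are added.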
}
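The comments in hyperParams encode a few constraints: d_model must match the chosen GPT-2 checkpoint, the head count must divide d_model, and the dropout rates must be valid probabilities. A small, hypothetical check_hparams helper can catch mismatches before the model is built. It checks only the keys shown, and the example dict copies a subset of the values above.

```python
# d_model per GPT-2 checkpoint, as listed in the hyperParams comments
GPT2_HIDDEN = {"gpt2": 768, "gpt2-medium": 1024, "gpt2-large": 1280}

def check_hparams(hp):
    """Return a list of problems found in an HVAE hyperparameter dict (empty if none)."""
    problems = []
    expected = GPT2_HIDDEN.get(hp["gpt2_model_name"])
    if expected is not None and hp["d_model"] != expected:
        problems.append(f"d_model={hp['d_model']} but {hp['gpt2_model_name']} needs {expected}")
    if hp["d_model"] % hp["encoder_heads"] != 0:
        problems.append("d_model must be divisible by encoder_heads")
    for key in ("word_dropout_rate", "plan_dropout_rate", "encoder_dropout"):
        if not 0.0 <= hp[key] < 1.0:
            problems.append(f"{key} must be in [0, 1)")
    return problems

example = {"gpt2_model_name": "gpt2", "d_model": 768, "encoder_heads": 8,
           "word_dropout_rate": 0.5, "plan_dropout_rate": 0.15, "encoder_dropout": 0.1}
print(check_hparams(example))  # -> []
```

In the notebook, the same call could be run on hyperParams directly.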


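Next, we define the process_dual_stream function. It tokenizes text into two streams, one for the DistilBERT encoder and one for the GPT-2 decoder. It adds the special tokens ([CLS]/[SEP] for the encoder; BOS, EOS and, on the final sentence, EOT for the decoder) and pads everything to a fixed max_sentences by max_words grid.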

In [None]:
def process_dual_stream(text, enc_tokenizer, dec_tokenizer, max_sentences, max_words, eot_token_id):
    """
    Parses text into TWO sets of tensors.
    Appends eot_token_id to the very last sentence of the Decoder stream.
    """
    # 1. Clean and Split
    text = text.replace('<|endoftext|>', '').strip()
    if text.startswith('<BOS>'):
        text = text[5:]
    raw_sentences = text.split('<EOS>')
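    sentences = [s.strip() for s in raw_sentences if s.strip()]
    # Truncate before the loop so EOT is attached to the last sentence that is actually kept
    sentences = sentences[:max_sentences]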

    enc_ids_rows, enc_mask_rows = [], []
    dec_ids_rows, dec_mask_rows = [], []

    num_sentences = len(sentences)

    for i, sent in enumerate(sentences):
        # --- A. ENCODER (DistilBERT) ---
        # [CLS] + tokens + [SEP], truncated and right-padded to max_words
        enc_tokens = enc_tokenizer.encode(sent, add_special_tokens=False)
        enc_tokens = enc_tokens[:max_words - 2]
        full_enc = [enc_tokenizer.cls_token_id] + enc_tokens + [enc_tokenizer.sep_token_id]

        enc_mask = [1] * len(full_enc)
        pad_len_enc = max_words - len(full_enc)
        if pad_len_enc > 0:
            full_enc += [enc_tokenizer.pad_token_id] * pad_len_enc
            enc_mask += [0] * pad_len_enc

        enc_ids_rows.append(full_enc)
        enc_mask_rows.append(enc_mask)

        # --- B. DECODER (GPT-2) ---
        dec_tokens = dec_tokenizer.encode(sent, add_special_tokens=False)

        # Check if this is the absolute last sentence in the valid text
        is_last_sentence = (i == num_sentences - 1)

        if is_last_sentence:
            # Reserve 3 spots: BOS, EOS, and EOT
            dec_tokens = dec_tokens[:max_words - 3]
            full_dec = [dec_tokenizer.bos_token_id] + dec_tokens + [dec_tokenizer.eos_token_id] + [eot_token_id]
        else:
            # Reserve 2 spots: BOS, EOS
            dec_tokens = dec_tokens[:max_words - 2]
            full_dec = [dec_tokenizer.bos_token_id] + dec_tokens + [dec_tokenizer.eos_token_id]

        dec_mask = [1] * len(full_dec)
        pad_len_dec = max_words - len(full_dec)

        if pad_len_dec > 0:
            full_dec += [dec_tokenizer.pad_token_id] * pad_len_dec
            dec_mask += [0] * pad_len_dec

        dec_ids_rows.append(full_dec)
        dec_mask_rows.append(dec_mask)

    # --- Vertical Padding ---
    while len(enc_ids_rows) < max_sentences:
        enc_ids_rows.append([enc_tokenizer.pad_token_id] * max_words)
        enc_mask_rows.append([0] * max_words)

        dec_ids_rows.append([dec_tokenizer.pad_token_id] * max_words)
        dec_mask_rows.append([0] * max_words)

    return (
        enc_ids_rows[:max_sentences], enc_mask_rows[:max_sentences],
        dec_ids_rows[:max_sentences], dec_mask_rows[:max_sentences]
    )
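The decoder branch of process_dual_stream reserves two slots (BOS, EOS) per sentence, plus a third for EOT on the final sentence, before padding. The pure-Python sketch below reproduces only that row layout, using small made-up token IDs instead of a real tokenizer. decoder_row is a hypothetical helper, not part of the notebook.

```python
def decoder_row(tokens, is_last, max_words, bos_id, eos_id, eot_id, pad_id):
    """Build one padded decoder row and its mask, mirroring the GPT-2 branch above."""
    if is_last:
        # Reserve 3 slots: BOS, EOS and EOT
        body = tokens[:max_words - 3]
        row = [bos_id] + body + [eos_id, eot_id]
    else:
        # Reserve 2 slots: BOS and EOS
        body = tokens[:max_words - 2]
        row = [bos_id] + body + [eos_id]
    mask = [1] * len(row) + [0] * (max_words - len(row))
    row = row + [pad_id] * (max_words - len(row))
    return row, mask

# Toy IDs for illustration only: PAD=0, BOS=1, EOS=2, EOT=3
row, mask = decoder_row([11, 12, 13], is_last=True, max_words=8,
                        bos_id=1, eos_id=2, eot_id=3, pad_id=0)
print(row)   # [1, 11, 12, 13, 2, 3, 0, 0]
print(mask)  # [1, 1, 1, 1, 1, 1, 0, 0]
```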

We define the HVAEDataset class, which tokenizes each answer with process_dual_stream, and the create_dataloaders function, which wraps the dataset in a shuffled DataLoader.

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd


class HVAEDataset(Dataset):
    def __init__(self, dataframe, enc_tokenizer, dec_tokenizer, max_sentences, max_words, eot_token_id):
        self.data = dataframe['answer_clean_preprocessed'].tolist()
        self.enc_tokenizer = enc_tokenizer # DistilBERT
        self.dec_tokenizer = dec_tokenizer # GPT-2
        self.max_sentences = max_sentences
        self.max_words = max_words
        # Store the EOT ID instead of relying on a global variable
        self.eot_token_id = eot_token_id

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        text = self.data[idx]

        enc_ids, enc_mask, dec_ids, dec_mask = process_dual_stream(
            text,
            self.enc_tokenizer,
            self.dec_tokenizer,
            self.max_sentences,
            self.max_words,
            self.eot_token_id
        )

        return {
            'enc_input_ids': torch.tensor(enc_ids, dtype=torch.long),
            'enc_word_mask': torch.tensor(enc_mask, dtype=torch.long), # DistilBERT mask
            'dec_input_ids': torch.tensor(dec_ids, dtype=torch.long),
            'dec_word_mask': torch.tensor(dec_mask, dtype=torch.long)  # GPT-2 mask
        }

def create_dataloaders(df, enc_tokenizer, dec_tokenizer, hyperparams, batch_size=32):
    # 1. Create Dataset
    dataset = HVAEDataset(
        dataframe=df,
        enc_tokenizer=enc_tokenizer,
        dec_tokenizer=dec_tokenizer,
        max_sentences=hyperparams['max_sentences'],
        max_words=hyperparams['max_words'],
        eot_token_id=hyperparams['eot_token_id']
    )

    # 2. Create DataLoader
    # num_workers=0 is safer for debugging; increase for speed later
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=0
    )

    return dataloader




In [None]:
wandb.init(
    project="HVAE-distilbert_plan_masking_full",
    config=hyperParams,
    name="max_0.2_cyclic_aggressive_better_encoder_prior_mlps_no_zt_prepend"
)

In [None]:
def get_monotonic_beta(
    epoch_index,
    batch_idx,
    steps_per_epoch,
    WARMUP_EPOCHS=4,
    ANNEAL_EPOCHS=10,  # Duration of the ramp
    MAX_BETA=0.1
):
    """
    Calculates Beta with a smooth step-wise ramp defined by epoch boundaries.

    Schedule:
    1. Warmup: Beta = 0 for WARMUP_EPOCHS.
    2. Anneal: Linear ramp from 0 to MAX_BETA over ANNEAL_EPOCHS.
    3. Plateau: Hold at MAX_BETA forever after.
    """

    # 1. Warm-up Phase: Strictly 0
    if epoch_index < WARMUP_EPOCHS:
        return 0.0

    # 2. Calculate steps relative to the end of warmup
    # How many steps have passed since we started the annealing phase?
    epochs_since_warmup = epoch_index - WARMUP_EPOCHS
    steps_into_annealing = (epochs_since_warmup * steps_per_epoch) + batch_idx

    # Total steps allocated for the ramp
    total_annealing_steps = ANNEAL_EPOCHS * steps_per_epoch

    # 3. Calculate Progress (0.0 -> 1.0)
    # If we are past the annealing time, this will be > 1.0
    progress = steps_into_annealing / total_annealing_steps

    # 4. Calculate Beta
    if progress >= 1.0:
        return MAX_BETA
    else:
        return progress * MAX_BETA
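Written as a formula, get_monotonic_beta computes the following. Let $W$ be WARMUP_EPOCHS, $A$ be ANNEAL_EPOCHS and $T$ be steps_per_epoch, with epoch index $e$ and batch index $b$:

$$
s = (e - W)\,T + b, \qquad
\beta(e, b) = \begin{cases} 0 & e < W \\[4pt] \beta_{\max} \cdot \min\!\left(1, \dfrac{s}{A\,T}\right) & e \ge W \end{cases}
$$

For example, with the defaults ($W = 4$, $A = 10$, $\beta_{\max} = 0.1$), the first batch of epoch index 9 has $s = 5T$. That gives $\beta = 0.1 \cdot 5T / 10T = 0.05$. From epoch index 14 onward, $\beta$ stays at 0.1.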

In [None]:
def get_kl_beta(
    epoch_index,
    batch_idx,
    steps_per_epoch,
    WARMUP_EPOCHS=10,
    MIN_BETA=0.1,   # Lowest non-zero beta in each cycle
    MAX_BETA=0.2,
    CYCLE_EPOCHS=50,
    LOW_RATIO=0,    # Fraction of each cycle spent holding MIN_BETA
    RAMP_RATIO=1    # Fraction of each cycle spent ramping from MIN_BETA to MAX_BETA
):
    """
    Returns (local_beta, global_beta) for the KL terms. Both are 0 during warmup.
    After warmup they follow a repeating trapezoid: hold MIN_BETA, ramp linearly
    to MAX_BETA, then hold MAX_BETA until the next cycle starts.
    """

    # 1. Warm-up Phase: Strictly 0
    if epoch_index < WARMUP_EPOCHS:
        return (0.0, 0.0)

    # 2. Cycle Logic
    effective_epoch = epoch_index - WARMUP_EPOCHS
    current_step_in_cycle_phase = effective_epoch * steps_per_epoch + batch_idx

    epochs_per_cycle = CYCLE_EPOCHS
    steps_per_cycle = steps_per_epoch * epochs_per_cycle

    # Progress 0.0 -> 1.0 within the current cycle
    cycle_progress = (current_step_in_cycle_phase % steps_per_cycle) / steps_per_cycle

    # 3. Trapezoidal Schedule
    if cycle_progress < LOW_RATIO:
        # Phase A: Hold at MIN_BETA
        current_beta = MIN_BETA

    elif cycle_progress < (LOW_RATIO + RAMP_RATIO):
        # Phase B: Linear Ramp from MIN to MAX
        # Normalize progress 0->1 for this specific segment
        segment_progress = (cycle_progress - LOW_RATIO) / RAMP_RATIO
        current_beta = MIN_BETA + (segment_progress * (MAX_BETA - MIN_BETA))

    else:
        # Phase C: Hold at MAX_BETA
        current_beta = MAX_BETA

    return (current_beta, current_beta) # Returning same for local/global
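train_one_epoch calls get_kl_beta with WARMUP_EPOCHS=0, MIN_BETA=0.01, MAX_BETA=0.07, CYCLE_EPOCHS=10, LOW_RATIO=0 and RAMP_RATIO=0.9. The sketch below restates only the trapezoidal part of get_kl_beta as a hypothetical cyclic_beta helper and evaluates it at the first batch of a few epochs. steps_per_epoch=100 is an assumed value; the real one is len(dataloader).

```python
def cyclic_beta(step, steps_per_cycle, min_beta, max_beta, low_ratio, ramp_ratio):
    """Trapezoidal beta for a step counted from the end of warmup."""
    progress = (step % steps_per_cycle) / steps_per_cycle
    if progress < low_ratio:
        return min_beta                      # Phase A: hold MIN_BETA
    if progress < low_ratio + ramp_ratio:
        segment = (progress - low_ratio) / ramp_ratio
        return min_beta + segment * (max_beta - min_beta)  # Phase B: linear ramp
    return max_beta                          # Phase C: hold MAX_BETA

steps_per_epoch = 100                  # hypothetical
steps_per_cycle = 10 * steps_per_epoch  # CYCLE_EPOCHS=10
for epoch in (0, 3, 9, 10):
    beta = cyclic_beta(epoch * steps_per_epoch, steps_per_cycle, 0.01, 0.07, 0.0, 0.9)
    print(epoch, round(beta, 4))
```

With these settings, beta climbs from 0.01 to 0.07 over the first nine epochs of each ten-epoch cycle, holds at 0.07 for the tenth, and drops back to 0.01 when the next cycle begins.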

In [None]:
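import matplotlib.pyplot as plt
import seaborn as sns

def compute_active_units(mu_global, mu_local, threshold=0.01):
    """
    Returns a dictionary of Active Unit metrics and the heatmap image.
    A latent dimension counts as active if the variance of its posterior
    mean across the batch exceeds `threshold`. Does NOT call wandb.log().
    """

    # --- 1. Global Active Units ---
    global_vars = torch.var(mu_global, dim=0)
    num_active_global = (global_vars > threshold).sum().item()

    # --- 2. Local Active Units (aggregated over all sentence positions) ---
    local_vars_agg = torch.var(mu_local.reshape(-1, mu_local.size(-1)), dim=0)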
    num_active_local = (local_vars_agg > threshold).sum().item()

    # --- 3. Local Activity Map ---
    local_activity_map = torch.var(mu_local, dim=0).cpu().numpy()


    # --- Prepare Dictionary ---
    metrics = {
        "AU/Global_Count": num_active_global,
        "AU/Local_Count": num_active_local,
        "AU/Global_Variance_Avg": global_vars.mean().item(),
    }

    # --- Visualization ---
    fig, ax = plt.subplots(figsize=(10, 6))
    sns.heatmap(local_activity_map > threshold, ax=ax, cbar=False, cmap="vlag")
    ax.set_title("Active Units per Sentence Position (Binary)")
    ax.set_ylabel("Sentence Index")
    ax.set_xlabel("Latent Dimension")

    # Add image to dictionary
    metrics["AU/Local_Heatmap"] = wandb.Image(fig)
    plt.close(fig)

    return metrics
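compute_active_units uses the usual active-unit statistic: a latent dimension is active when the variance of its posterior mean across the batch exceeds a threshold (0.01 here). The NumPy sketch below applies that rule to synthetic posterior means with two informative dimensions and two near-constant ones. The data and the helper name count_active_units are made up for illustration.

```python
import numpy as np

def count_active_units(mu, threshold=0.01):
    """mu: (batch, latent_dim) posterior means. Returns the number of dims with Var > threshold."""
    variances = mu.var(axis=0, ddof=1)  # ddof=1 matches torch.var's default (unbiased)
    return int((variances > threshold).sum())

rng = np.random.default_rng(0)
batch = 256
mu = np.stack([
    rng.normal(0.0, 1.0, batch),     # informative: varies across inputs
    rng.normal(0.0, 0.5, batch),     # informative
    rng.normal(0.0, 0.01, batch),    # collapsed: variance around 1e-4
    np.full(batch, 0.3),             # collapsed: constant
], axis=1)
print(count_active_units(mu))  # 2
```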

In [None]:
def load_checkpoint_from_wandb(
    artifact_path,
    inference_net,
    generative_net,
    optimizer,
    scheduler=None,
    filename="hvae_checkpoint.pth"
):
    """
    Downloads a specific artifact from WandB, loads the state dictionaries,
    and returns the epoch to resume from.
    """
    print(f"Resuming from WandB artifact: {artifact_path}")

    # 1. Download the artifact
    # wandb.use_artifact needs an active run (call wandb.init first).
    artifact = wandb.use_artifact(artifact_path, type='model')
    artifact_dir = artifact.download()
    filepath = os.path.join(artifact_dir, filename)

    # 2. Load the file
    if torch.cuda.is_available():
        checkpoint = torch.load(filepath)
    else:
        checkpoint = torch.load(filepath, map_location=torch.device('cpu'))

    # 3. Restore states
    inference_net.load_state_dict(checkpoint['inference_state_dict'])
    generative_net.load_state_dict(checkpoint['generative_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

    if scheduler and checkpoint.get('scheduler_state_dict'):
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

    # 4. Return the next epoch
    # If we saved at epoch 9 (finished), we want to start at 10.
    start_epoch = checkpoint['epoch'] + 1
    val_loss = checkpoint.get('val_loss', 'N/A')

    print(f"Checkpoint loaded. Resuming from Epoch {start_epoch} (Last Val Loss: {val_loss})")
    return start_epoch

In [None]:
from tqdm.auto import tqdm
import torch.nn.utils as utils
def reparameterize(mu, sigma2):
    """Sampling z ~ N(mu, sigma^2) using the reparameterization trick."""
    std = torch.sqrt(sigma2)
    eps = torch.randn_like(std)
    return mu + eps * std

def train_one_epoch(
    inference_net,
    generative_net,
    loss_module,
    dataloader,
    optimizer,
    device,
    epoch_index,
    total_epochs,
    scheduler = None
):
    """
    Encapsulates the training loop for a single epoch. It performs forward/backward passes,
    calculates loss, applies KL annealing, logs metrics to Weights & Biases,
    and includes gradient clipping and learning rate scheduling.
    """
    inference_net.train()
    generative_net.train()

    total_epoch_loss = 0
    total_recon_loss = 0
    total_kl_loss = 0


    progress_bar = tqdm(dataloader, desc=f"Epoch {epoch_index+1}")
    steps_per_epoch = len(dataloader)

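    # Gradient accumulation: effective batch size = target_batch_size
    target_batch_size = 64
    batch_size = dataloader.batch_size
    # Guard against batch_size > target_batch_size, which would otherwise give 0 steps
    accumulation_steps = max(1, target_batch_size // batch_size)

    optimizer.zero_grad()
    if epoch_index == 0:
        # Unfreeze the last two DistilBERT layers and the whole GPT-2 decoder
        for param in inference_net.word_encoder.transformer.layer[-2:].parameters():
            param.requires_grad = True
        for param in generative_net.gpt2_model.parameters():
            param.requires_grad = True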



    for batch_idx, batch in enumerate(progress_bar):

        # KL weights for this step. Only global_beta is used below;
        # the local KL weight is fixed at 0.02 in this run.
        local_beta, global_beta = get_kl_beta(
            epoch_index,
            batch_idx,
            steps_per_epoch,
            WARMUP_EPOCHS=0,
            MIN_BETA=0.01,
            MAX_BETA=0.07,
            CYCLE_EPOCHS=10,
            LOW_RATIO=0,    # Fraction of each cycle spent holding MIN_BETA
            RAMP_RATIO=0.9  # Fraction of each cycle spent ramping up
        )
        # 1. Prepare Data (the DataLoader yields dicts built by HVAEDataset)
        enc_input_ids = batch['enc_input_ids'].to(device)
        enc_word_mask = batch['enc_word_mask'].to(device)
        dec_input_ids = batch['dec_input_ids'].to(device)
        dec_word_mask = batch['dec_word_mask'].to(device)

        # Targets are the decoder inputs themselves. HT_HVAE_Loss aligns them by
        # dropping the last logit position (the z_t prefix shifts the outputs by one).
        target_ids = dec_input_ids.clone()


        # 2. Inference Network (Encoder)
        # Forward pass to get posterior parameters q(z|x)
        mu_t_q, sigma2_t_q, mu_i_q, sigma2_i_q = inference_net(enc_input_ids, enc_word_mask)

        # 3. Sampling (Reparameterization)
        z_t = reparameterize(mu_t_q, sigma2_t_q)
        z_i_samples = reparameterize(mu_i_q, sigma2_i_q)

        # 4. Generative Network (Decoder)
        # Reconstruct inputs based on samples and calculate priors p(z)
        reconstruction_logits, mu_i_p, sigma2_i_p = generative_net(
            dec_input_ids,
            dec_word_mask,
            z_t,
            z_i_samples
        )

        # 5. Loss Calculation
        loss, recon, global_kl, local_kl, kl_ratio, per_token_loss = loss_module(
            # From Encoder (Posterior q)
            mu_t_q, sigma2_t_q, mu_i_q, sigma2_i_q,
            # From Decoder (Likelihood + Prior p)
            reconstruction_logits, mu_i_p, sigma2_i_p,
            # Targets
            target_ids, dec_word_mask,
            # Annealing
            local_kl_beta=0.02,
            global_kl_beta = global_beta,

        )

        # 6. Backward pass (loss scaled for gradient accumulation)
        scaled_loss = loss / accumulation_steps
        scaled_loss.backward()
        if batch_idx == 0 and epoch_index == 0:
            # One-off sanity check of the global posterior at the first step
            print(f"Global Mu Mean:   {mu_t_q.mean().item():.5f} | Std: {mu_t_q.std().item():.5f}")
            print(f"Global Sigma2 Mean: {sigma2_t_q.mean().item():.5f}")

        # Optimizer step once every accumulation_steps batches
        if (batch_idx + 1) % accumulation_steps == 0:
            # Gradient clipping (each network clipped separately)
            utils.clip_grad_norm_(inference_net.parameters(), max_norm=1.0)
            utils.clip_grad_norm_(generative_net.parameters(), max_norm=1.0)

            optimizer.step()
            if scheduler is not None:
                scheduler.step()

            optimizer.zero_grad()  # Reset for the next accumulation window



        # Assumes the optimizer has (at least) two parameter groups
        current_lr_1 = optimizer.param_groups[0]['lr']
        current_lr_2 = optimizer.param_groups[1]['lr']

        # 7. Logging
        total_epoch_loss += loss.item()
        total_recon_loss += recon.item()
        # Sum global and local KL for display
        current_kl = global_kl.item() + local_kl.item()
        total_kl_loss += current_kl

        # Log to W&B every 230 batches
        if (batch_idx + 1) % 230 == 0:
            log_dict = {
                "Losses/loss": loss.item(),
                "Losses/recon": recon.item(),
                "KL/kl_ratio": kl_ratio.item(),
                "KL/kl_local": local_kl.item(),
                "KL/kl_total": current_kl,
                "KL/kl_global": global_kl.item(),
                "Losses/per_token_loss": per_token_loss.item(),
                "Betas/local_kl_beta": 0.02,
                "Betas/global_kl_beta": global_beta,
                "LRs/lr1": current_lr_1,
                "LRs/lr2": current_lr_2,
            }
            # Active-unit counts and heatmap
            au_metrics = compute_active_units(
                mu_t_q.detach().cpu(),
                mu_i_q.detach().cpu(),
                threshold=0.01
            )

            # Merge dictionaries and log everything once
            log_dict.update(au_metrics)
            wandb.log(log_dict)

    # save_checkpoint(
    #     inference_net=inference_net,
    #     generative_net=generative_net,
    #     optimizer=optimizer,
    #     epoch=epoch_index,
    #     val_loss=0.1,
    #     scheduler=scheduler,
    #     is_best=False
    # )
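train_one_epoch divides each micro-batch loss by accumulation_steps and steps the optimizer only every accumulation_steps batches. The accumulated gradient should then match one step on target_batch_size examples. This holds exactly when the loss is a per-example average and the micro-batches have equal size. Both the reconstruction term (summed NLL divided by batch size) and the batch-averaged KL terms have that form. The pure-Python sketch below checks this for a one-parameter least-squares model, using toy data and a hypothetical helper. Note that with batch_size=64 and target_batch_size=64, accumulation_steps is 1, so this particular run does not accumulate.

```python
def grad_mean_sq_loss(w, xs, ys):
    """d/dw of mean((w*x - y)^2) over the given examples."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

xs = [0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5, 4.0]
ys = [1.1, 1.9, 3.2, 3.9, 5.1, 6.0, 7.2, 7.8]
w = 0.3

full_batch_grad = grad_mean_sq_loss(w, xs, ys)

micro_batch_size = 2
accumulation_steps = len(xs) // micro_batch_size
accumulated = 0.0
for start in range(0, len(xs), micro_batch_size):
    mb_x = xs[start:start + micro_batch_size]
    mb_y = ys[start:start + micro_batch_size]
    # Same scaling as `scaled_loss = loss / accumulation_steps` before backward()
    accumulated += grad_mean_sq_loss(w, mb_x, mb_y) / accumulation_steps

print(full_batch_grad, accumulated)
```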

## Model Execution
1. Initialize GPT2Tokenizer and DistilBertTokenizer, add the custom special tokens (BOS, EOS, PAD, MASK, EOT) to the GPT-2 tokenizer, and update the hyperParams dictionary with the new vocabulary size and token IDs.

In [None]:
from transformers import GPT2Tokenizer, DistilBertTokenizer

# 1. Load standard tokenizers
dec_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
enc_tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

# 2. Add Special Tokens
# We use 'additional_special_tokens' for custom tokens like EOT
special_tokens_dict = {
    'bos_token': '<|BOS|>',
    'eos_token': '<|EOS|>',
    'pad_token': '<|PAD|>',
    'mask_token': '<|MASK|>',
    'additional_special_tokens': ['<|EOT|>']
}

# This returns the number of added tokens
num_added_toks = dec_tokenizer.add_special_tokens(special_tokens_dict)

# 3. Get IDs
# Standard attributes exist for bos/eos/pad/mask
pad_idx = dec_tokenizer.pad_token_id
eos_idx = dec_tokenizer.eos_token_id
mask_idx = dec_tokenizer.mask_token_id

# For EOT, we must look it up manually since .eot_token_id doesn't exist
eot_token_id = dec_tokenizer.convert_tokens_to_ids('<|EOT|>')

new_vocab_size = len(dec_tokenizer)

print(f"New Vocab Size: {new_vocab_size}")
print(f"PAD ID: {pad_idx} | EOS ID: {eos_idx}")
print(f"MASK ID: {mask_idx} | EOT ID: {eot_token_id}")

# 4. Update hyperparameters
hyperParams['vocab_size'] = new_vocab_size
hyperParams['pad_index'] = pad_idx
hyperParams['mask_token_id'] = mask_idx
hyperParams['eot_token_id'] = eot_token_id

New Vocab Size: 50262
PAD ID: 50259 | EOS ID: 50258
MASK ID: 50260 | EOT ID: 50261
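Adding the five special tokens grows the GPT-2 vocabulary from 50,257 to 50,262 entries, and the new tokens take the IDs right after the original vocabulary (the printed PAD, EOS, MASK and EOT IDs are 50258 to 50261). The decoder's token-embedding matrix therefore needs five extra rows; the `mean_resizing` warning in the training output further down shows that `transformers` fills them by sampling from a distribution fit to the old embeddings. The plain-PyTorch sketch below only illustrates that idea: `grow_embedding_matrix` is a hypothetical helper, and the small diagonal jitter is an assumption added to keep the covariance positive definite, so it should not be read as the library's exact implementation.

```python
import torch


def grow_embedding_matrix(old_weight: torch.Tensor, num_new: int, jitter: float = 1e-5) -> torch.Tensor:
    """Hypothetical helper: append `num_new` rows sampled from a Gaussian
    whose mean and covariance are estimated from the existing rows."""
    vocab_size, dim = old_weight.shape
    mean = old_weight.mean(dim=0)
    centered = old_weight - mean
    cov = centered.T @ centered / vocab_size
    cov = cov + jitter * torch.eye(dim, dtype=old_weight.dtype)  # keep it positive definite
    dist = torch.distributions.MultivariateNormal(mean, covariance_matrix=cov)
    new_rows = dist.sample((num_new,))
    # Old rows are kept unchanged; new rows are appended at the end
    return torch.cat([old_weight, new_rows], dim=0)


# Toy stand-in for GPT-2's 50257 x 768 table: 100 tokens, 8 dimensions, 5 new tokens
torch.manual_seed(0)
old = torch.randn(100, 8)
new = grow_embedding_matrix(old, num_new=5)
print(new.shape)  # torch.Size([105, 8])
```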


In [None]:
train_subset = train_df.sample(frac=1, random_state=42)
train_loader = create_dataloaders(train_subset, enc_tokenizer,dec_tokenizer, hyperParams, batch_size=64)
valid_loader = create_dataloaders(valid_df, enc_tokenizer,dec_tokenizer ,hyperParams, batch_size=8)

### Training process:
Define the device, the number of epochs, and the learning-rate schedule (LinearLR for warmup followed by CosineAnnealingLR for decay), which is applied to every parameter group of the optimizer.

Next, initialize HT_HVAE_InferenceNetwork, HT_HVAE_GenerativeNetwork, and HT_HVAE_Loss, then start the training loop by calling train_one_epoch once per epoch. The cell can also resume training from a Weights & Biases artifact, and it saves a checkpoint every 5 epochs.
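The warmup-then-cosine schedule can be sanity-checked on a toy optimizer before launching a 75-epoch run. The sketch below (hypothetical helper `lr_trace`, using 10 warmup and 90 decay steps instead of the real 1792 and 16133) mirrors the three parameter groups and base LRs used in the training cell and records each group's LR after every step. It is an illustration under those assumptions, not part of the notebook's pipeline.

```python
import torch
from torch.optim.lr_scheduler import SequentialLR, LinearLR, CosineAnnealingLR


def lr_trace(base_lrs=(1e-5, 5e-5, 5e-4), warmup_steps=10, decay_steps=90, eta_min=1e-6):
    """Record every group's LR under linear warmup followed by cosine decay."""
    groups = [{"params": [torch.nn.Parameter(torch.zeros(1))], "lr": lr} for lr in base_lrs]
    opt = torch.optim.AdamW(groups, weight_decay=0.01)
    warmup = LinearLR(opt, start_factor=0.01, end_factor=1.0, total_iters=warmup_steps)
    decay = CosineAnnealingLR(opt, T_max=decay_steps, eta_min=eta_min)
    sched = SequentialLR(opt, schedulers=[warmup, decay], milestones=[warmup_steps])
    trace = [[g["lr"] for g in opt.param_groups]]
    for _ in range(warmup_steps + decay_steps):
        opt.step()    # no gradients exist, so parameters are left unchanged
        sched.step()
        trace.append([g["lr"] for g in opt.param_groups])
    return trace


trace = lr_trace()
print(trace[0])    # each group starts at 1% of its own base LR
print(trace[10])   # end of warmup: each group at its full base LR
print(trace[-1])   # end of decay: every group at eta_min
```

Because the schedulers scale each group's own base LR, one `start_factor` yields three proportionally scaled warmups, while the single `eta_min` gives all groups the same floor.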


In [None]:
from torch.optim.lr_scheduler import SequentialLR, LinearLR, CosineAnnealingLR

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

WANDB_ARTIFACT_PATH = "yasir-alam14/HVAE-distilbert_freeze_encoder/hvae-model:v16" # Set this to None if starting fresh
RESUME_TRAINING = False

num_epochs = 75
# 1. Calculate total optimizer steps
target_batch_size = 64
batch_size = train_loader.batch_size  # 64 with the loader above, so accumulation_steps == 1
accumulation_steps = target_batch_size // batch_size

total_steps = (num_epochs * len(train_loader)) // accumulation_steps

# 2. Define warmup (10% of total steps)
warmup_steps = int(0.1 * total_steps)
decay_steps = total_steps - warmup_steps

print(f"Total Steps: {total_steps} | Warmup Steps: {warmup_steps}")

# Instantiate Models
inference_net = HT_HVAE_InferenceNetwork(hyperParams).to(device)
generative_net = HT_HVAE_GenerativeNetwork(hyperParams).to(device)
loss_module = HT_HVAE_Loss(hyperParams).to(device)

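# Freeze the DistilBERT word encoder
for param in inference_net.word_encoder.parameters():
    param.requires_grad = False

# Freeze GPT-2, then re-enable gradients for its embedding tables only
for param in generative_net.gpt2_model.parameters():
    param.requires_grad = False

# WTE: token embeddings, including the rows for the newly added special tokens
generative_net.gpt2_model.wte.weight.requires_grad = True

# WPE: so it learns to handle the global prefix at position 0
generative_net.gpt2_model.wpe.weight.requires_grad = True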
# --- KEY MODIFICATION START: Parameter Splitting ---

# Define your Learning Rates
LR_GPT2 = 1e-5   # Low LR for pre-trained weights
LR_BERT = 5e-5   # No effect while the DistilBERT encoder is frozen (its params receive no grads)
LR_REST = 5e-4   # High LR for new Encoders/GRU/MLPs

# 1. Pre-trained parameters
gpt2_params = list(generative_net.gpt2_model.parameters())
distilbert_params = list(inference_net.word_encoder.parameters())

# 2. Identify "Everything Else"
# Create a set of IDs for parameters already assigned to avoid duplication
assigned_ids = set(map(id, gpt2_params + distilbert_params))
scratch_params = []

# Iterate through both networks and collect unassigned parameters
for model in [inference_net, generative_net]:
    for param in model.parameters():
        if id(param) not in assigned_ids:
            scratch_params.append(param)

# 3. Create 3-Group Optimizer
optimizer = torch.optim.AdamW([
    {'params': gpt2_params, 'lr': LR_GPT2},       # Group 1: GPT-2
    {'params': distilbert_params, 'lr': LR_BERT}, # Group 2: DistilBERT
    {'params': scratch_params, 'lr': LR_REST}     # Group 3: Scratch/Rest
], weight_decay=0.01)
# --- KEY MODIFICATION END ---


# 4. Create the schedulers
# Note: schedulers scale each group's base LR multiplicatively.
# start_factor=0.01 means each group starts at 1% of its base LR:
#   - Group 0 (GPT-2) starts at 1e-7 (1% of 1e-5)
#   - Group 1 (DistilBERT) starts at 5e-7 (1% of 5e-5)
#   - Group 2 (scratch) starts at 5e-6 (1% of 5e-4)

# Phase 1: Linear Warmup
scheduler_warmup = LinearLR(
    optimizer,
    start_factor=0.01,
    end_factor=1.0,
    total_iters=warmup_steps
)

# Phase 2: Cosine Decay
# Note on eta_min: CosineAnnealingLR takes a single float that applies to all
# three groups; 1e-6 sits below every group's base LR.
scheduler_decay = CosineAnnealingLR(
    optimizer,
    T_max=decay_steps,
    eta_min=1e-6
)

# Chain the phases: linear warmup for warmup_steps steps, then cosine decay
scheduler = SequentialLR(
    optimizer,
    schedulers=[scheduler_warmup, scheduler_decay],
    milestones=[warmup_steps]
)

start_epoch = 0

if RESUME_TRAINING and WANDB_ARTIFACT_PATH:
    # Ensure wandb is initialized if it hasn't been already
    if wandb.run is None:
        wandb.init(project="my_project", resume="allow")

    start_epoch = load_checkpoint_from_wandb(
        artifact_path=WANDB_ARTIFACT_PATH,
        inference_net=inference_net,
        generative_net=generative_net,
        optimizer=optimizer,
        scheduler=scheduler
    )


# Run loop
# Special tokens to strip during evaluation; they must match the tokens added to dec_tokenizer
special_tokens_set = {'<|PAD|>', '<|BOS|>', '<|EOS|>', '<|EOT|>', '<|endoftext|>'}

if start_epoch >= num_epochs:
    print("Training already completed for this number of epochs.")

for epoch in range(start_epoch, num_epochs):
    train_one_epoch(inference_net, generative_net, loss_module, train_loader, optimizer, device, epoch, num_epochs, scheduler)
    #results = evaluate_model(inference_net, generative_net, loss_module, valid_loader, tokenizer, device, special_tokens_set)
    #validation_logs = {f"validation/{k}": v for k, v in results.items()}
    #wandb.log(validation_logs)
    if (epoch + 1) % 5 == 0:
        save_checkpoint(
            inference_net=inference_net,
            generative_net=generative_net,
            optimizer=optimizer,
            epoch=epoch,
            val_loss=0.1,  # placeholder: validation is disabled above
            scheduler=scheduler,
            is_best=False
        )


Total Steps: 17925 | Warmup Steps: 1792


The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`

Global Mu Mean:   -0.00036 | Std: 0.01853
Global Sigma2 Mean: 0.69238


Token indices sequence length is longer than the specified maximum sequence length for this model (584 > 512). Running this sequence through the model will result in indexing errors
Epoch 1: 100%|██████████| 239/239 [11:11<00:00,  2.81s/it]
Epoch 2: 100%|██████████| 239/239 [11:06<00:00,  2.79s/it]
Epoch 3: 100%|██████████| 239/239 [11:06<00:00,  2.79s/it]
Epoch 4: 100%|██████████| 239/239 [11:06<00:00,  2.79s/it]
Epoch 5: 100%|██████████| 239/239 [11:06<00:00,  2.79s/it]
Epoch 6: 100%|██████████| 239/239 [11:07<00:00,  2.79s/it]
Epoch 7: 100%|██████████| 239/239 [11:08<00:00,  2.80s/it]
Epoch 8: 100%|██████████| 239/239 [11:07<00:00,  2.79s/it]
Epoch 9: 100%|██████████| 239/239 [11:07<00:00,  2.79s/it]
Epoch 10: 100%|██████████| 239/239 [11:07<00:00,  2.79s/it]
Epoch 11: 100%|██████████| 239/239 [11:07<00:00,  2.79s/it]
Epoch 12: 100%|██████████| 239/239 [11:06<00:00,  2.79s/it]
Epoch 13: 100%|██████████| 239/239 [11:06<00:00, 
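The loop saves a checkpoint every 5 epochs and can resume through `load_checkpoint_from_wandb`; both helpers are defined elsewhere in the notebook. The sketch below shows, as an assumption about what such a pair needs to round-trip, a local-file version: `save_checkpoint_sketch` and `load_checkpoint_sketch` are hypothetical names, the W&B artifact upload and download are omitted, and the loader returns the saved epoch plus one because the loop saves after an epoch has finished.

```python
import os
import tempfile

import torch
import torch.nn as nn


def save_checkpoint_sketch(path, models, optimizer, scheduler, epoch):
    """Hypothetical helper: bundle every state needed to resume into one file."""
    torch.save({
        "epoch": epoch,
        "models": {name: m.state_dict() for name, m in models.items()},
        "optimizer": optimizer.state_dict(),
        "scheduler": scheduler.state_dict(),
    }, path)


def load_checkpoint_sketch(path, models, optimizer, scheduler):
    """Restore all states in place and return the epoch to resume from."""
    ckpt = torch.load(path, map_location="cpu")
    for name, m in models.items():
        m.load_state_dict(ckpt["models"][name])
    optimizer.load_state_dict(ckpt["optimizer"])
    scheduler.load_state_dict(ckpt["scheduler"])
    return ckpt["epoch"] + 1  # the saved epoch already finished
```

Saving the scheduler state alongside the optimizer matters here: without it, a resumed run would restart the warmup instead of continuing the cosine decay.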

In [None]:
# for/else: the else branch runs only if the loop finished without hitting `break`
for name, param in inference_net.named_parameters():
    if torch.isnan(param).any():
        print(f"FATAL: Found NaN in model weights: {name}")
        break
else:
    print("No NaNs in inference_net weights; any NaN must arise in the forward pass.")
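If the weights are clean, the next step is to find which submodule first produces a non-finite output. One option, sketched below with a hypothetical helper `attach_nonfinite_hooks` (not part of the notebook), registers a forward hook on every submodule and records the names of modules whose tensor output contains NaN or Inf.

```python
import torch
import torch.nn as nn


def attach_nonfinite_hooks(model: nn.Module, found: list):
    """Hypothetical helper: register a forward hook on every submodule that
    appends the module's name to `found` when its output contains NaN or Inf.
    Returns the hook handles so they can be removed afterwards."""
    handles = []
    for name, module in model.named_modules():
        def hook(mod, inputs, output, name=name):
            outputs = output if isinstance(output, (tuple, list)) else (output,)
            for t in outputs:
                if torch.is_tensor(t) and not torch.isfinite(t).all():
                    found.append(name)
                    break
        handles.append(module.register_forward_hook(hook))
    return handles
```

Typical use: create `found = []`, call `handles = attach_nonfinite_hooks(inference_net, found)`, run the forward pass, inspect `found[0]`, then call `h.remove()` on every handle. A module's hook fires after its children finish, so the first entry is the earliest module in execution order to emit a non-finite value; the root module appears as `''`. Outputs that are dicts or Hugging Face `ModelOutput` objects are skipped by this simple check, although the layers inside those models are still hooked individually.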

In [None]:
first_batch = next(iter(train_loader))

### Single forward pass on one batch
Run one forward pass and loss computation on the first training batch to trace the data flow through the encoder, the reparameterization step, and the decoder, and to inspect every loss component.

In [None]:
# 1. Move the batch to the device
enc_input_ids = first_batch['enc_input_ids'].to(device)
enc_word_mask = first_batch['enc_word_mask'].to(device)
dec_input_ids = first_batch['dec_input_ids'].to(device)
dec_word_mask = first_batch['dec_word_mask'].to(device)

# Target IDs are typically the same as input_ids for reconstruction
# The Loss module should handle shifting (input[t] -> target[t+1]) internally
# or via masking.
target_ids = dec_input_ids.clone()


# 2. Inference Network (Encoder)
# Forward pass to get posterior parameters q(z|x)
mu_t_q, sigma2_t_q, mu_i_q, sigma2_i_q = inference_net(enc_input_ids, enc_word_mask)

# 3. Sampling (Reparameterization)
z_t = reparameterize(mu_t_q, sigma2_t_q)
z_i_samples = reparameterize(mu_i_q, sigma2_i_q)

# 4. Generative Network (Decoder)
# Reconstruct inputs based on samples and calculate priors p(z)
reconstruction_logits, mu_i_p, sigma2_i_p = generative_net(
    dec_input_ids,
    dec_word_mask,
    z_t,
    z_i_samples
)

# 5. Loss Calculation
loss, recon, global_kl, local_kl, kl_ratio, per_token_loss = loss_module(
    # From Encoder (Posterior q)
    mu_t_q, sigma2_t_q, mu_i_q, sigma2_i_q,
    # From Decoder (Likelihood + Prior p)
    reconstruction_logits, mu_i_p, sigma2_i_p,
    # Targets
    target_ids, dec_word_mask,
    # Annealing: full KL weights for this diagnostic pass
    local_kl_beta=1,
    global_kl_beta=1,
)
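The cell above relies on `reparameterize` and on the KL terms inside `HT_HVAE_Loss`, both defined earlier in the notebook. Judging by the argument names, the encoder returns variances (`sigma2_*`) rather than log-variances. Under that assumption, the sketch below shows the two pieces for diagonal Gaussians: `reparameterize_sketch` and `gaussian_kl` are hypothetical helpers, the global KL is assumed to be taken against a standard normal N(0, I), the local KL against the learned prior (`mu_i_p`, `sigma2_i_p`), and the notebook's own loss may sum, average, or mask differently.

```python
import torch


def reparameterize_sketch(mu: torch.Tensor, sigma2: torch.Tensor) -> torch.Tensor:
    """z = mu + sigma * eps with eps ~ N(0, I); `sigma2` is a variance, not a log-variance."""
    eps = torch.randn_like(mu)
    return mu + torch.sqrt(sigma2) * eps


def gaussian_kl(mu_q, sigma2_q, mu_p=None, sigma2_p=None):
    """KL(q || p) between diagonal Gaussians, summed over the last dimension.
    With no prior arguments, p is the standard normal N(0, I)."""
    if mu_p is None:
        mu_p = torch.zeros_like(mu_q)
    if sigma2_p is None:
        sigma2_p = torch.ones_like(sigma2_q)
    kl = 0.5 * (torch.log(sigma2_p / sigma2_q)
                + (sigma2_q + (mu_q - mu_p) ** 2) / sigma2_p
                - 1.0)
    return kl.sum(dim=-1)


# Hypothetical shapes: batch of 2, 3 sentences, 4-dim global and local latents
torch.manual_seed(0)
mu_t, s2_t = torch.randn(2, 4), torch.rand(2, 4) + 0.1
mu_i, s2_i = torch.randn(2, 3, 4), torch.rand(2, 3, 4) + 0.1
prior_mu_i, prior_s2_i = torch.randn(2, 3, 4), torch.rand(2, 3, 4) + 0.1
z_t_demo = reparameterize_sketch(mu_t, s2_t)
global_kl_demo = gaussian_kl(mu_t, s2_t).sum()                          # vs N(0, I)
local_kl_demo = gaussian_kl(mu_i, s2_i, prior_mu_i, prior_s2_i).sum()   # vs learned prior
```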


In [None]:
z_i_samples.shape

In [None]:
print(f"Per-token recon loss: {per_token_loss}\nGlobal KL: {global_kl}\nLocal KL: {local_kl}")

In [None]:
# Perplexity, assuming per_token_loss is the mean per-token negative log-likelihood
np.exp(per_token_loss.item())
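If `per_token_loss` is the mean next-token negative log-likelihood over non-padding tokens (an assumption, since it is computed inside `HT_HVAE_Loss`), its exponential is the per-token perplexity. The sketch below (hypothetical helper `masked_token_nll`) shows one way to compute that quantity, including the input-to-target shift mentioned in the forward-pass cell, for one 2-D token sequence per example. The notebook's decoder inputs are 3-D (batch, sentences, words), so the sentence dimension would first be folded into the batch dimension.

```python
import torch
import torch.nn.functional as F


def masked_token_nll(logits, target_ids, mask):
    """Mean next-token NLL over non-padding positions.

    logits: (B, T, V); target_ids and mask: (B, T). Position t predicts token t+1.
    """
    shift_logits = logits[:, :-1, :]
    shift_targets = target_ids[:, 1:]
    shift_mask = mask[:, 1:].float()
    nll = F.cross_entropy(shift_logits.reshape(-1, shift_logits.size(-1)),
                          shift_targets.reshape(-1), reduction="none")
    nll = nll.view_as(shift_mask)
    # Padding positions contribute nothing to the numerator or the count
    return (nll * shift_mask).sum() / shift_mask.sum()


# Uniform logits over a 10-token vocabulary give NLL = log(10), i.e. perplexity 10
torch.manual_seed(0)
logits = torch.zeros(2, 5, 10)
targets = torch.randint(0, 10, (2, 5))
mask = torch.ones(2, 5)
ppl = torch.exp(masked_token_nll(logits, targets, mask)).item()
```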

In [None]:
mu_t_q.shape

In [None]:
mu_i_q.shape

Compute and print the active-unit metrics for the posterior means mu_t_q and mu_i_q of the first batch.

In [None]:
au_metrics = compute_active_units(
    mu_t_q.detach().cpu(),
    mu_i_q.detach().cpu(),
    threshold=0.01,
)

In [None]:
print(au_metrics)
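`compute_active_units` is defined earlier in the notebook. A common definition counts a latent dimension as active when the variance, across data points, of its posterior mean exceeds a threshold (0.01 here); dimensions that collapse to the prior have a posterior mean that barely changes between inputs. The sketch below assumes that definition; `active_units_sketch` is a hypothetical helper, and for the local latents every sentence slot (including padded ones, unless masked beforehand) counts as a separate data point.

```python
import torch


def active_units_sketch(mu: torch.Tensor, threshold: float = 0.01):
    """Count latent dimensions whose posterior mean varies across data points.

    mu: posterior means with shape (..., D); every leading dimension is treated
    as a separate data point. Returns (num_active, fraction_active).
    """
    flat = mu.reshape(-1, mu.shape[-1])
    # Population variance of each dimension's posterior mean across data points
    variance = ((flat - flat.mean(dim=0)) ** 2).mean(dim=0)
    num_active = int((variance > threshold).sum().item())
    return num_active, num_active / mu.shape[-1]


# Two dimensions vary across the 100 "data points", two are constant
mu = torch.zeros(100, 4)
mu[:, 0] = torch.linspace(-1, 1, 100)
mu[:, 1] = torch.linspace(0, 2, 100)
print(active_units_sketch(mu))  # (2, 0.5)
```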

In [None]:
test_zt = z_t[0]
test_zi = z_i_samples[0]
test_zt = test_zt.unsqueeze(0)
test_zi = test_zi.unsqueeze(0)

In [None]:
def decode_generated_sequences(generated_sequences, tokenizer, skip_special_tokens=True):
    """
    Decodes the list of tensors returned by the inference method into text.

    Args:
        generated_sequences (list[torch.Tensor]): List of tensors where each tensor
                                                  has shape (Batch, Seq_Len).
        tokenizer: A tokenizer with a .decode() method (e.g., HuggingFace).
        skip_special_tokens (bool): If True, removes BOS/EOS/PAD tokens.

    Returns:
        list[str]: A flat list of decoded strings, ordered sentence-first: every
                   batch item for sentence 0, then every batch item for sentence 1, etc.
    """
    decoded_texts = []

    # Iterate through each sentence step (N sentences)
    for seq_tensor in generated_sequences:
        # Iterate through the batch dimension (B)
        for i in range(seq_tensor.size(0)):
            # Convert tensor to python list of IDs
            token_ids = seq_tensor[i].tolist()

            # Decode
            text = tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens)
            decoded_texts.append(text)

    return decoded_texts


def decode_original_batch(input_ids, tokenizer):
    """
    Decodes a 3D tensor of input_ids into text.

    Args:
        input_ids (torch.Tensor): Shape (Batch_Size, Max_Sentences, Max_Words)
        tokenizer: HuggingFace tokenizer

    Returns:
        list[list[str]]: Nested list of strings [Batch][Sentence]
    """
    # Move to CPU if it's a tensor
    if hasattr(input_ids, 'cpu'):
        input_ids = input_ids.cpu()

    batch_size, num_sentences, max_words = input_ids.shape
    all_decoded_text = []

    for b in range(batch_size):
        batch_sentences = []
        print(f"\n--- Batch Sample {b} ---")

        for n in range(num_sentences):
            # Get the sequence for this specific sentence
            sequence = input_ids[b, n]

            # Remove padding (optional, depending on tokenizer behavior)
            # sequence = sequence[sequence != tokenizer.pad_token_id]

            text = tokenizer.decode(sequence, skip_special_tokens=True)

            # Filter out empty strings if entire row was padding
            if text.strip():
                print(f"  Sentence {n}: {text}")
                batch_sentences.append(text)

        all_decoded_text.append(batch_sentences)

    return all_decoded_text

In [None]:
first_batch['dec_input_ids'][0].shape

In [None]:
predicted_ids = generative_net.inference(test_zt, test_zi, max_length=50, bos_token_id=dec_tokenizer.bos_token_id, eos_token_id=dec_tokenizer.eos_token_id)


In [None]:
real_ids = first_batch['dec_input_ids'][0].unsqueeze(0)

In [None]:
print(decode_original_batch(real_ids, dec_tokenizer))
print(decode_generated_sequences(predicted_ids,dec_tokenizer))

In [None]:
# generated_zt, generated_zi = generative_net.generate_latents(1, 11, device)

In [None]:
# generated_zi.shape

In [None]:
# generated_text = generative_net.inference(generated_zt, generated_zi, max_length=50, bos_token_id=tokenizer.bos_token_id, eos_token_id=tokenizer.eos_token_id)

In [None]:
# print(decode_generated_sequences(generated_text,tokenizer))

# References

1. Physics Stack Exchange Dataset
@online{h4stackexchange,
  author = {Lambert, Nathan and Tunstall, Lewis and Rajani, Nazneen and Thrush, Tristan},
  title = {HuggingFace H4 Stack Exchange Preference Dataset},
  year = 2023,
  url = {https://huggingface.co/datasets/HuggingFaceH4/stack-exchange-preferences},
}

2. torch.nn (PyTorch):
@article{paszke2019pytorch,
  title={PyTorch: An Imperative Style, High-Performance Deep Learning Library},
  author={Paszke, Adam and Gross, Sam and Massa, Francisco and Lerer, Adam and Bradbury, James and Chanan, Gregory and Killeen, Trevor and Lin, Zeming and Gimelshein, Natalya and Antiga, Luca and others},
  journal={Advances in Neural Information Processing Systems},
  volume={32},
  year={2019}
}

3. transformers (Hugging Face Transformers library):
@article{wolf2020transformers,
  title={Transformers: State-of-the-Art Natural Language Processing},
  author={Wolf, Thomas and Debut, Lysandre and Sanh, Victor and Chaumond, Julien and Delangue, Clement and Moi, Anthony and Cistac, Pierric and Rault, Timoth{\'e}e and Santus, Luca and Pereira, Oriol and others},
  journal={Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing: System Demonstrations},
  pages={38--45},
  year={2020}
}

4. Sentence BLEU (from nltk.translate.bleu_score):
@article{papineni2002bleu,
  title={BLEU: a method for automatic evaluation of machine translation},
  author={Papineni, Kishore and Roukos, Salim and Ward, Todd and Zhu, Wei-Jing},
  journal={Proceedings of the 40th annual meeting on association for computational linguistics},
  pages={311--318},
  year={2002}
}

5. NLTK:
@book{bird2009natural,
  title={Natural Language Processing with Python: Analyzing Text with the Natural Language Toolkit},
  author={Bird, Steven and Klein, Ewan and Loper, Edward},
  year={2009},
  publisher={O'Reilly Media, Inc.}
}

6. GPT-2:
@article{radford2019language,
  title={Language Models are Unsupervised Multitask Learners},
  author={Radford, Alec and Wu, Jeffrey and Child, Rewon and Luan, David and Amodei, Dario and Sutskever, Ilya},
  year={2019},
  url={https://openai.com/research/language-models-are-unsupervised-multitask-learners}
}

7. DistilBertModel (DistilBERT):
@article{sanh2019distilbert,
  title={DistilBERT, a distilled version of BERT: smaller, faster, cheaper and lighter},
  author={Sanh, Victor and Debut, Lysandre and Wolf, Thomas and Lhoest, Gobert},
  journal={arXiv preprint arXiv:1910.01108},
  year={2019}
}