# HVAE EBM
This notebook trains a Hierarchical Variational Autoencoder (HVAE) combined with an Energy-Based Model (EBM) for natural language processing tasks. The key steps are:

**Data Preparation**: We load a precomputed dataset of questions and answers and use a DistilBERT model to encode the questions into variable-length embeddings.

**Custom Data Loading**: A custom collate_fn pads the variable-length question embeddings and creates the matching masks for batch processing.

**Model Architecture**:

AutoregressiveHVAE_Transform: Manages the transformation between noise and hierarchical latent variables.

CrossAttentionEBM and HierarchicalTransformerEBM: Implement the Energy-Based Model using cross-attention, designed to work with both global and local latent variables.

HT_HVAE_InferenceNetwork (Encoder): Uses DistilBert and a Transformer to encode text into global and local latent distributions.

HT_HVAE_GenerativeNetwork (Decoder): Employs a GRU for sentence planning and a modified GPT-2 for word generation, conditioned on the latent variables.

**Model Loading**: Loads pre-trained weights for the HVAE encoder and decoder from a Weights & Biases artifact.

**EBM Training Setup**: Initializes the HVAE transform and EBM models, sets up an optimizer, and defines a diffusion noise schedule for training the EBM.

## Setup and Data Loading

In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All"
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

/kaggle/input/h-latents-dataset/hvae_dataset_with_latents.pt


In [None]:
import os
import wandb

# Never hard-code the API key; read it from the WANDB_API_KEY environment variable.
wandb.login(key=os.environ["WANDB_API_KEY"])

[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33myasir-alam14[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [None]:
import gc
import os
import sys

import matplotlib.pyplot as plt
import nltk
import numpy as np
import pandas as pd
import seaborn as sns
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.utils as utils
from bs4 import BeautifulSoup
from datasets import load_dataset
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from sklearn.model_selection import train_test_split
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import (
    AutoTokenizer,
    DistilBertConfig,
    DistilBertModel,
    DistilBertTokenizer,
    GPT2Config,
    GPT2Model,
)

2025-12-09 02:30:43.390726: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1765247443.698298      47 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1765247443.788057      47 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'


In [None]:
class PrecomputedDataset(Dataset):
    """a custom PyTorch Dataset that loads precomputed data from a .pt file."""
    def __init__(self, pt_file_path):
        self.data = torch.load(pt_file_path)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]

In [None]:
train_dataset = PrecomputedDataset('/kaggle/input/h-latents-dataset/hvae_dataset_with_latents.pt')

In [None]:
len(train_dataset)

15259

In [None]:
# PyTorch DataLoader that loads one item at a time for initial data inspection.
train_loader = DataLoader(train_dataset, batch_size = 1)

## Initialize the DistilBertTokenizer and DistilBertModel from Hugging Face
They encode the question text that is later used as context alongside the HVAE latents.

In [None]:
import torch
from transformers import DistilBertTokenizer, DistilBertModel
from tqdm import tqdm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load pre-trained model and tokenizer
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
model = DistilBertModel.from_pretrained('distilbert-base-uncased').to(device)
model.eval()

DistilBertModel(
  (embeddings): Embeddings(
    (word_embeddings): Embedding(30522, 768, padding_idx=0)
    (position_embeddings): Embedding(512, 768)
    (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer): Transformer(
    (layer): ModuleList(
      (0-5): 6 x TransformerBlock(
        (attention): DistilBertSdpaAttention(
          (dropout): Dropout(p=0.1, inplace=False)
          (q_lin): Linear(in_features=768, out_features=768, bias=True)
          (k_lin): Linear(in_features=768, out_features=768, bias=True)
          (v_lin): Linear(in_features=768, out_features=768, bias=True)
          (out_lin): Linear(in_features=768, out_features=768, bias=True)
        )
        (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (ffn): FFN(
          (dropout): Dropout(p=0.1, inplace=False)
          (lin1): Linear(in_features=768, out_features=3072, bias=True)
          (lin2): L

In [None]:
for batch in train_loader:
    print(batch['question'])
    break

['I recently learnt the concept that electric field lines do not cut each other this statement was proved with the logic that :\n" If electric field lines would intersect then there would be ultimately two directions of tangents for that very point which contradicts the laws of electric field lines."\n\nBut I\'m wondering if there\'s only one line of tangent associated with the electric field lines then how it would not have two directions with it. Namely;\nparallel and antiparallel directions']


Next, we iterate through train_loader and encode each question with the loaded DistilBertModel. No padding is applied, so each embedding keeps its own sequence length. The augmented records are saved to a new file, hvae_dataset_variable_len.pt, so the question embeddings are computed once here rather than on every training step.

In [None]:
new_data_list = []

with torch.no_grad():
    for batch in tqdm(train_loader, desc="Encoding Questions"):

        raw_question = batch['question'][0]

        # CHANGE 1: padding=False.
        # The sequence length equals the token count (WordPiece tokens plus the
        # special [CLS] and [SEP] tokens), capped at 512 by truncation.
        inputs = tokenizer(
            raw_question,
            return_tensors="pt",
            padding=False,  # <--- Crucial change
            truncation=True,
            max_length=512
        ).to(device)

        outputs = model(**inputs)

        # Shape: [Real_Seq_Len, 768] (e.g., [15, 768])
        full_embeddings = outputs.last_hidden_state[0].cpu()

        new_item = {
            'question': raw_question,
            'answer': batch['answer'][0],
            'enc_inputs': batch['enc_inputs'][0],
            'enc_wordMask': batch['enc_wordMask'][0],
            'dec_inputs': batch['dec_inputs'][0],
            'dec_wordmask': batch['dec_wordmask'][0],
            'global_latents': batch['global_latents'][0],
            'local_latents': batch['local_latents'][0],
            'question_encoded': full_embeddings
        }

        new_data_list.append(new_item)

torch.save(new_data_list, 'hvae_dataset_variable_len.pt')


Encoding Questions: 100%|██████████| 15259/15259 [02:04<00:00, 122.34it/s]


In [None]:
class DistilBertAugmentedDataset(Dataset):
    """similar to PrecomputedDataset. Loads the newly created hvae_dataset_variable_len.pt file."""
    def __init__(self, path):
        self.data = torch.load(path)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx]
new_dataset = DistilBertAugmentedDataset('/kaggle/working/hvae_dataset_variable_len.pt')

In [None]:
del new_data_list
gc.collect()

NameError: name 'new_data_list' is not defined

Next we define hvae_collate_fn, a custom collate function for the DataLoader. When batching samples, it pads the variable-length question embeddings to the longest sequence in the batch and builds the corresponding masks (True for real tokens, False for padding). The cell then re-creates the loader as new_train_loader using this function.
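
As a minimal sketch of the padding step on its own, the toy example below pads two random "embeddings" of different lengths and builds their masks with `pad_sequence`. The tensor sizes (3 and 5 tokens, width 4) are made up for illustration and are not the notebook's real shapes.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

# Two toy "question embeddings" with different lengths (3 and 5 tokens, width 4).
embs = [torch.randn(3, 4), torch.randn(5, 4)]

# One all-True mask per sequence: every stored position is a real token.
masks = [torch.ones(e.shape[0], dtype=torch.bool) for e in embs]

# Pad both lists to the longest sequence in the batch.
padded_embs = pad_sequence(embs, batch_first=True, padding_value=0.0)      # (2, 5, 4)
padded_masks = pad_sequence(masks, batch_first=True, padding_value=False)  # (2, 5)

print(padded_embs.shape)
print(padded_masks)
```

The shorter sequence gets two rows of zeros in the embeddings and two False entries in its mask, which is what lets downstream attention layers ignore the padded positions.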

In [None]:
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

def hvae_collate_fn(batch):
    # Initialize lists
    question_embs = []
    question_masks = []  # We will generate these now

    # Lists for other fixed columns
    enc_inputs = []
    enc_wordMask = []
    dec_inputs = []
    dec_wordmask = []
    global_latents = []
    local_latents = []
    questions_text = []
    answers_text = []

    for item in batch:
        # 1. Get the variable-length embedding
        # Check both key names just in case
        q_emb = item.get('question_emb', item.get('question_encoded'))

        # 2. CREATE THE MASK ON THE FLY
        # Since 'q_emb' contains only real data, the mask is all 1s (True)
        seq_len = q_emb.shape[0]
        q_mask = torch.ones(seq_len, dtype=torch.bool)

        question_embs.append(q_emb)
        question_masks.append(q_mask)

        # Collect other items
        enc_inputs.append(item['enc_inputs'])
        enc_wordMask.append(item['enc_wordMask'])
        dec_inputs.append(item['dec_inputs'])
        dec_wordmask.append(item['dec_wordmask'])
        global_latents.append(item['global_latents'])
        local_latents.append(item['local_latents'])

        questions_text.append(item['question'])
        answers_text.append(item['answer'])

    # 3. Pad Sequences
    # Pad embeddings with 0.0
    padded_q_emb = pad_sequence(question_embs, batch_first=True, padding_value=0.0)

    # Pad masks with False (0).
    # Result: 11111000 (1=Real, 0=Pad)
    padded_q_mask = pad_sequence(question_masks, batch_first=True, padding_value=False)

    return {
        'question_encoded': padded_q_emb,
        'question_mask': padded_q_mask, # Now this exists!

        'enc_inputs': torch.stack(enc_inputs),
        'enc_wordMask': torch.stack(enc_wordMask),
        'dec_inputs': torch.stack(dec_inputs),
        'dec_wordmask': torch.stack(dec_wordmask),
        'global_latents': torch.stack(global_latents),
        'local_latents': torch.stack(local_latents),

        'question': questions_text,
        'answer': answers_text
    }

# Re-create the loader
new_train_loader = DataLoader(
    new_dataset,
    batch_size=1,
    shuffle=True,
    num_workers=0,
    collate_fn=hvae_collate_fn
)

In [None]:
# Fetch one batch from new_train_loader and print the shapes and lengths of its
# tensors and lists, to verify that the collate_fn produces the expected dimensions,
# including the padded question embeddings and masks.

for batch in new_train_loader:
    print(len(batch['question']))
    print(len(batch['answer']))
    print(batch['enc_inputs'].shape)
    print(batch['enc_wordMask'].shape)
    print(batch['dec_inputs'].shape)
    print(batch['dec_wordmask'].shape)
    print(batch['global_latents'].shape)
    print(batch['local_latents'].shape)
    print(batch['question_encoded'].shape)
    print(batch['question_mask'].shape)
    print(batch['global_latents'].shape)
    print(batch['local_latents'].shape)
    break

1
1
torch.Size([1, 11, 50])
torch.Size([1, 11, 50])
torch.Size([1, 11, 50])
torch.Size([1, 11, 50])
torch.Size([1, 32])
torch.Size([1, 11, 32])
torch.Size([1, 94, 768])
torch.Size([1, 94])
torch.Size([1, 32])
torch.Size([1, 11, 32])


## Model Architecture

In [None]:
import torch
import torch.nn as nn

class AutoregressiveHVAE_Transform(nn.Module):
    """
    class to handle the transformation between independent noise variables (u)
    and hierarchical latent variables (z), including both global and local latents,
    with an autoregressive dependency for local latents.
    """
    def __init__(self, prior_net, latent_dim=32, seq_len=10):
        super().__init__()
        self.seq_len = seq_len
        self.latent_dim = latent_dim

        # Prior network p(z_i | z_t, z_{i-1}); expected to be an MLPNetworkForPrior instance
        self.prior_net = prior_net

    def u_to_z(self, u_list):
        """
        Transform independent Noise (u) -> Hierarchical Latents (z)
        Must be SEQUENTIAL (Loop) because z_i depends on z_{i-1}
        """
        u_global, u_local_seq = u_list
        # u_global: (B, D)
        # u_local_seq: (B, Seq, D)

        batch_size = u_global.shape[0]

        # 1. Global Level (Standard Normal Prior)
        # z_t = u_t (since prior is N(0,I))
        z_t = u_global

        # 2. Local Level (Autoregressive Loop)
        z_local_list = []

        # Initialize z_{i-1} (z_prev) as zero vector for the first step
        z_prev = torch.zeros(batch_size, self.latent_dim, device=z_t.device)

        # Loop through time steps
        for i in range(self.seq_len):
            # Get the noise for this specific step
            u_i = u_local_seq[:, i, :]

            # Predict Prior Params: p(z_i | z_t, z_{i-1})
            mu_i, sigma2_i = self.prior_net(z_t, z_prev)
            sigma_i = torch.sqrt(sigma2_i)

            # Reparameterize: z_i = mu + sigma * u
            z_i = mu_i + sigma_i * u_i

            # Store and Update State
            z_local_list.append(z_i)
            z_prev = z_i

        # Stack list into tensor (B, Seq, D)
        z_local = torch.stack(z_local_list, dim=1)

        return [z_t, z_local]

    def z_to_u(self, z_list):
        """
        Inverse Transform: Recover Noise (u) from Latents (z)
        Can be PARALLELIZED using "Teacher Forcing" (Shifting inputs)
        """
        z_t, z_local = z_list
        # z_t: (B, D)
        # z_local: (B, Seq, D)

        batch_size = z_t.shape[0]
        device = z_t.device

        # 1. Global Level (z_t -> u_t)
        u_t = z_t

        # 2. Local Level (z_local -> u_local)
        # To compute u_i, we need the mean/std predicted by the prior.
        # The prior needs (z_t, z_{i-1}).

        # Create z_{i-1} sequence by shifting z_local to the right
        # [z1, z2, z3] -> [0, z1, z2]
        zeros = torch.zeros(batch_size, 1, self.latent_dim, device=device)
        z_prev_seq = torch.cat([zeros, z_local[:, :-1, :]], dim=1)

        # Expand z_t to match sequence length for batch processing
        # (B, D) -> (B, Seq, D)
        z_t_expanded = z_t.unsqueeze(1).expand(-1, self.seq_len, -1)

        # Run Prior Net in PARALLEL on the whole sequence
        # Note: Your MLPNetworkForPrior uses Linear layers, so it handles (B, Seq, D) automatically
        mu_seq, sigma2_seq = self.prior_net(z_t_expanded, z_prev_seq)
        sigma_seq = torch.sqrt(sigma2_seq)

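        # Inverse Reparameterization: u = (z - mu) / sigma
        # sigma2 already has a 1e-6 floor from the prior network, so no extra epsilon
        # is needed; dividing by the same sigma keeps this an exact inverse of u_to_z.
        u_local = (z_local - mu_seq) / sigma_seq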

        return [u_t, u_local]
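
The relationship between the sequential forward pass and the parallel inverse can be checked with a small round trip. The sketch below is self-contained and does not use the notebook's classes: `toy_prior` is a hypothetical stand-in for the prior network, and the sizes are arbitrary. It samples noise, maps it to latents step by step, then recovers the noise in one shifted, parallel call.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
B, S, D = 2, 4, 3  # batch, sequence length, latent width (toy sizes)

# Hypothetical stand-in for the prior network: any map from (z_t, z_prev)
# to a mean and a strictly positive variance works for this check.
lin = nn.Linear(2 * D, 2 * D)

def toy_prior(z_t, z_prev):
    mu, raw = lin(torch.cat([z_t, z_prev], dim=-1)).chunk(2, dim=-1)
    return mu, F.softplus(raw) + 1e-6

@torch.no_grad()
def u_to_z(u_global, u_local):
    z_t, z_prev, zs = u_global, torch.zeros(B, D), []
    for i in range(S):  # sequential: z_i needs z_{i-1}
        mu, var = toy_prior(z_t, z_prev)
        z_prev = mu + var.sqrt() * u_local[:, i]
        zs.append(z_prev)
    return z_t, torch.stack(zs, dim=1)

@torch.no_grad()
def z_to_u(z_t, z_local):
    # parallel: every z_{i-1} is already known, so shift right and evaluate once
    z_prev = torch.cat([torch.zeros(B, 1, D), z_local[:, :-1]], dim=1)
    mu, var = toy_prior(z_t.unsqueeze(1).expand(-1, S, -1), z_prev)
    return z_t, (z_local - mu) / var.sqrt()

u_g, u_l = torch.randn(B, D), torch.randn(B, S, D)
z_g, z_l = u_to_z(u_g, u_l)
u_g_rec, u_l_rec = z_to_u(z_g, z_l)

max_err = (u_l - u_l_rec).abs().max().item()
print(max_err)  # floating-point round-off only
```

The inverse can run in parallel because all previous latents are available at once, whereas sampling has to generate them one step at a time.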

In [None]:
import torch
import torch.nn as nn

class CrossAttentionEBM(nn.Module):
    """
    Unified EBM using Cross-Attention.
    Can be used for EBM1 (Global) or EBM2 (Sequence) by changing seq_len.
    Supports both single and sequence latent inputs and includes masked pooling.
    """
    def __init__(self, z_dim, context_dim, hidden_dim=128, num_heads=4, layers=2, seq_len=1):
        super().__init__()

        # 1. Projections
        self.z_proj = nn.Linear(z_dim, hidden_dim)
        self.ctx_proj = nn.Sequential(
            nn.Linear(context_dim, 512),       # Project context to an intermediate width
            nn.GELU(),                         # Non-linearity
            nn.LayerNorm(512),                 # Must match the 512-wide activations
            nn.Linear(512, hidden_dim)         # Project down to the model width
        )
        self.time_proj = nn.Linear(1, hidden_dim)

        # 2. Position Embedding (Only needed if seq_len > 1, e.g. for Local Latents)
        if seq_len > 1:
            self.pos_emb = nn.Parameter(torch.randn(1, seq_len, hidden_dim) * 0.02)
        else:
            self.pos_emb = None

        # 3. Transformer Decoder Layers
        # (Decoder = Self-Attn + Cross-Attn + FeedForward)
        # Note: We use 'TransformerDecoder' because it has Cross-Attention built-in.
        decoder_layer = nn.TransformerDecoderLayer(
            d_model=hidden_dim,
            nhead=num_heads,
            batch_first=True,
            norm_first=True # Usually stabilizes EBM training
        )
        self.transformer = nn.TransformerDecoder(decoder_layer, num_layers=layers)

        # 4. Output Head
        self.energy_head = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim), # Maintain width
            nn.GELU(),                         # Non-linearity
            nn.LayerNorm(hidden_dim),          # Optional: Stabilizes energy magnitudes
            nn.Linear(hidden_dim, 1)           # Squeeze to scalar
        )

    def forward(self, z, context, t, context_mask=None, latent_word_mask=None):
        """
        Args:
            z: [Batch, seq_len, z_dim]
            context: [Batch, context_len, context_dim]
            t: [Batch, 1]

            context_mask: [Batch, context_len]
                          1 = Real Token, 0 = Padding

            latent_word_mask: [Batch, max_sent, max_words] (from enc_wordMask)
                              1 = Real Word, 0 = Padding
                              Used to determine which sentence slots are padding.
        """

        # --- 1. Prepare Masks ---

        # A. Context Mask (for Cross-Attention)
        # PyTorch expects True for "PAD/IGNORE".
        # Your mask has 0 for Pad. So we invert it: (mask == 0) -> True.
        if context_mask is not None:
            memory_key_padding_mask = (context_mask == 0)
        else:
            memory_key_padding_mask = None

        # B. Latent Mask (for Self-Attention & Pooling)
        # Determine if a sentence is padding based on its FIRST word
        tgt_key_padding_mask = None
        sentence_mask_float = None

        if latent_word_mask is not None and z.size(1) > 1:
            # Extract the first word of every sentence: [B, Max_Sent]
            # If first word is 0, sentence is 0.
            sentence_mask = latent_word_mask[:, :, 0]

            # Create boolean mask for Transformer (True = Ignore/Pad)
            tgt_key_padding_mask = (sentence_mask == 0)

            # Keep a float version (1.0 = Keep, 0.0 = Ignore) for pooling later
            sentence_mask_float = sentence_mask.float().unsqueeze(-1) # [B, S, 1]

        # --- 2. Embeddings ---
        z_emb = self.z_proj(z) # [B, S, H]
        t_emb = self.time_proj(t).unsqueeze(1)
        z_input = z_emb + t_emb

        if self.pos_emb is not None:
            z_input = z_input + self.pos_emb

        ctx_emb = self.ctx_proj(context) # [B, C, H]

        # --- 3. Transformer (With Masks) ---
        # tgt_key_padding_mask    -> z looking at z (Self-Attn)
        # memory_key_padding_mask -> z looking at context (Cross-Attn)
        out = self.transformer(
            tgt=z_input,
            memory=ctx_emb,
            tgt_key_padding_mask=tgt_key_padding_mask,
            memory_key_padding_mask=memory_key_padding_mask
        )

        # --- 4. Masked Energy Pooling ---
        if out.size(1) > 1:
            if sentence_mask_float is not None:

                # Zero out the hidden states of padded sentences
                masked_out = out * sentence_mask_float

                # Sum valid energies
                sum_out = masked_out.sum(dim=1) # [B, H]

                # Count valid sentences (avoid div by zero)
                count_valid = sentence_mask_float.sum(dim=1) # [B, 1]
                count_valid = torch.clamp(count_valid, min=1.0)

                out_pooled = sum_out / count_valid
            else:
                # Fallback if no mask provided
                out_pooled = out.mean(dim=1)
        else:
            # For Global Latent (seq_len=1), just squeeze
            out_pooled = out.squeeze(1)

        return self.energy_head(out_pooled)
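
The two mask conventions used above (1 = real in the data, True = ignore in PyTorch attention) and the masked mean are easy to mix up, so here is a minimal sketch on hand-written numbers. The tensors are toy values chosen for illustration; the padded slot is filled with 100 to make its effect on a naive mean obvious.

```python
import torch

# Toy transformer output: batch of 2, up to 3 sentence slots, hidden width 2.
out = torch.tensor([[[1.0, 2.0], [3.0, 4.0], [100.0, 100.0]],
                    [[5.0, 6.0], [7.0, 8.0], [9.0, 10.0]]])

# 1 = real sentence, 0 = padded slot (the first example has only two sentences).
sentence_mask = torch.tensor([[1, 1, 0],
                              [1, 1, 1]])

# PyTorch attention layers use the opposite convention: True = ignore.
key_padding_mask = sentence_mask == 0

# Masked mean: zero out padded slots, then divide by the number of real ones.
m = sentence_mask.float().unsqueeze(-1)                       # (B, S, 1)
pooled = (out * m).sum(dim=1) / m.sum(dim=1).clamp(min=1.0)   # (B, H)

print(key_padding_mask)
print(pooled)           # ignores the padded slot
print(out.mean(dim=1))  # naive mean is skewed by the padded slot
```

Dividing by the count of real sentences rather than the fixed slot count keeps the pooled vector on the same scale for short and long documents.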

In [None]:
class HierarchicalTransformerEBM(nn.Module):
    """
    Class to combine two CrossAttentionEBM instances: one for global latents (ebm1)
    and one for local latents (ebm2).
    """
    def __init__(self, dim_z1=128, dim_z2=128, dim_context=768):
        super().__init__()

        self.ebm1 = CrossAttentionEBM(dim_z1, dim_context,seq_len=1)
        self.ebm2 = CrossAttentionEBM(dim_z2, dim_context,seq_len=11)

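    def forward(self, z_list, context, t, context_mask=None, latent_word_mask=None):
        z1, z2 = z_list

        # CrossAttentionEBM expects [B, seq_len, z_dim]; give a 2-D global latent
        # of shape [B, z_dim] a singleton sequence axis.
        if z1.dim() == 2:
            z1 = z1.unsqueeze(1)

        # EBM 1 Energy (global latent)
        energy1 = self.ebm1(z1, context, t, context_mask, latent_word_mask)

        # EBM 2 Energy (local latents, one per sentence)
        energy2 = self.ebm2(z2, context, t, context_mask, latent_word_mask)

        return energy1.sum() + energy2.sum()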


In [None]:
class MLPNetwork(nn.Module):
    """
    Optimized MLP for HVAE Encoder q(z|x).
    1. Removes bottlenecks (Maintains width).
    2. Uses GELU (Matches BERT/Transformer activations).
    3. Implements Near-Zero Initialization to prevent early KL shock.
    """
    def __init__(self, input_dim, latent_dim):
        super().__init__()

        # 1. Maintain Width: Do not compress to //2 or //4 immediately.
        # We want deep non-linearities, not compression.
        hidden_dim = input_dim

        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.ln1 = nn.LayerNorm(hidden_dim)

        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.ln2 = nn.LayerNorm(hidden_dim)

        # Output layer
        self.fc_out = nn.Linear(hidden_dim, 2 * latent_dim, bias=False)

        # Match activation to DistilBERT/GPT-2 (GELU is smoother than ReLU)
        self.activation = nn.GELU()

        # --- CRITICAL: Near-Zero Initialization ---
        # With a bias-free output layer and tiny weights, the posterior q(z|x) starts
        # near mu = 0, sigma^2 = softplus(0) ~ 0.69, i.e. close to the N(0, I) prior.
        # This keeps the initial KL loss small, so the optimizer is not pushed into
        # "killing" the latent variable immediately (posterior collapse).
        torch.nn.init.normal_(self.fc_out.weight, mean=0.0, std=0.001)


    def forward(self, h):
        # Two Linear -> LayerNorm -> GELU blocks
        x = self.fc1(h)
        x = self.ln1(x)
        x = self.activation(x)

        x = self.fc2(x)
        x = self.ln2(x)
        x = self.activation(x)

        output = self.fc_out(x)

        mu, raw_var_score = output.chunk(2, dim=-1)

        # Robust Softplus
        sigma2 = F.softplus(raw_var_score) + 1e-6

        return mu, sigma2

class HT_HVAE_InferenceNetwork(nn.Module):
    """
    The Hierarchical Transformer Encoder (Inference Network) q(z|x).
    Implements shared-parameter word-level and sentence-level Transformers.
    """
    def __init__(self,hyperparams):
        super().__init__()

        self.latent_dim = hyperparams['latent_dim']
        self.d_model = hyperparams['d_model']
        self.vocab_size = hyperparams['vocab_size']
        self.max_sentences = hyperparams['max_sentences']
        self.max_words = hyperparams['max_words']
        self.n_heads = hyperparams['encoder_heads']
        self.dropout = hyperparams['encoder_dropout']
        self.n_layers = hyperparams['encoder_layers']

        self.word_encoder = DistilBertModel.from_pretrained('distilbert-base-uncased')

        self.distilbert_dim = self.word_encoder.config.hidden_size
        self.word_projection = nn.Linear(self.distilbert_dim, self.d_model)

        self.sentence_position_embedding = nn.Embedding(self.max_sentences + 1, self.d_model)


        # 2.2. Sentence-Level Transformer Encoder
        # This is a separate Transformer stack for sentence-level attention.

        sentence_encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.d_model,
            nhead=self.n_heads,
            dim_feedforward=4*self.d_model,
            dropout=self.dropout,
            batch_first=True,
            # 1. Change activation from 'relu' (default) to 'gelu'
            activation='gelu',
            # 2. Enable Pre-Layer Normalization (Pre-LN)
            norm_first=True
        )
        self.transformer_sentence = nn.TransformerEncoder(sentence_encoder_layer, num_layers=self.n_layers)

        # 2.3. Latent Variable Networks (MLP)
        # These are used to fit the Gaussian distribution parameters.

        # MLP for Global Posterior q(zt | x): Learned from text-code H_s^0
        self.mlp_global = MLPNetwork(self.d_model, self.latent_dim)
        self.local_latent = self.latent_dim

        # MLP for Local Posterior q(zi | xi): Learned from sentence-code H_w^i
        # This MLP MUST SHARE PARAMETERS across all sentences.
        self.mlp_local = MLPNetwork(self.d_model, self.local_latent)
        self.doc_token = nn.Parameter(torch.randn(1, 1, self.d_model))


    def forward(self, input_ids, word_mask):
        """
        Args:
            input_ids (torch.LongTensor): Input tensor of shape
                (batch_size, MAX_SENTENCES, MAX_WORDS).
            word_mask (torch.Tensor): Mask tensor of shape
                (batch_size, MAX_SENTENCES, MAX_WORDS). 1 for real tokens, 0 for padding.

        Returns:
            tuple: (mu_t, sigma2_t, mu_i, sigma2_i), where mu_t and sigma2_t have shape
                (batch_size, LATENT_DIM) and mu_i and sigma2_i have shape
                (batch_size, MAX_SENTENCES, LATENT_DIM).
        """
        # word mask is 1 for real tokens, 0 for pad
        batch_size, max_sentences, max_words = input_ids.shape

        # --- 2.1. Word-Level Transformer Encoding (Shared) ---

        # Reshape to treat all sentences as one batch for parameter sharing
        # (batch_size * MAX_SENTENCES, MAX_WORDS)
        flat_input_ids = input_ids.view(-1, max_words)
        flat_attention_mask = word_mask.view(-1, max_words)
        sentence_flat_word_mask = (word_mask.view(-1, max_words) == 0)



        # DistilBERT's attention_mask uses 1 for real tokens and 0 for padding, which
        # matches word_mask, so it is passed through unchanged. Shape: (B*N, MAX_WORDS).
        # sentence_flat_word_mask is the inverted version (True = padding) and is used
        # below to detect fully padded sentences.

        distilbert_output = self.word_encoder(
            input_ids=flat_input_ids,
            attention_mask=flat_attention_mask
        )

        cls_embeddings = distilbert_output.last_hidden_state[:, 0, :]

        # Extract the Sentence Code (Hw_0): (B*N, D)
        # Hw_0 is the projected representation of the first token, which is [CLS]
        # when the inputs were tokenized with the DistilBERT tokenizer.
        H_w_0_flat = self.word_projection(cls_embeddings)

        is_padding_sentence = sentence_flat_word_mask[:, 0]

        # 2. "Firewall": overwrite the codes of fully padded sentences with 0.0 immediately.
        # This keeps NaNs or garbage values away from BOTH the MLP and the Sentence Transformer downstream
        H_w_0_flat = H_w_0_flat.masked_fill(is_padding_sentence.unsqueeze(1), 0.0)

        # Reshape back: (batch_size, MAX_SENTENCES, D)
        H_w_0 = H_w_0_flat.view(batch_size, max_sentences, self.d_model)

        # --- 2.3. Local Posterior (q(zi | xi)) using Shared MLP ---

        # Calculate local distribution parameters from Hw_0
        # The mlp_local shares parameters because it is called on all Hw_0_flat
        mu_i_flat, sigma2_i_flat = self.mlp_local(H_w_0_flat)

        # Reshape back to (batch_size, MAX_SENTENCES, LATENT_DIM)
        mu_i = mu_i_flat.view(batch_size, max_sentences, self.local_latent)
        sigma2_i = sigma2_i_flat.view(batch_size, max_sentences, self.local_latent)

        # --- 2.2. Sentence-Level Transformer Encoding ---

        # Add sentence position codes to the sentence-codes Hw_0
        # We assume position codes 1 to MAX_SENTENCES are used, and 0 is reserved for H_s^0 position.
        # Create position IDs: 0 (document token), 1, 2, ..., MAX_SENTENCES

        batch_doc_token = self.doc_token.expand(batch_size, -1, -1)
        H_sen_input_ = torch.cat([batch_doc_token, H_w_0], dim=1)
        position_ids = torch.arange(0, max_sentences + 1, device=input_ids.device)
        position_embeddings = self.sentence_position_embedding(position_ids) # (N+1, D)
        position_embeddings = position_embeddings.unsqueeze(0).expand(batch_size, -1, -1) # (B, N+1, D)

        H_sen_input = H_sen_input_ + position_embeddings

        # Need a sentence-level padding mask (True = padded sentence).
        # A sentence counts as padding when its first word is padding (i.e., word_mask[:, :, 0] == 0).
        sentence_mask = (word_mask[:, :, 0] == 0) # (B, N)

        doc_mask = torch.zeros((batch_size, 1), dtype=torch.bool, device=input_ids.device)

        full_sentence_mask = torch.cat([doc_mask, sentence_mask], dim=1)

        # Sentence-Level Encoding (Hs): (B, N+1, D), with the document token at index 0
        H_s = self.transformer_sentence(
            src=H_sen_input,
            src_key_padding_mask=full_sentence_mask
        )

        H_s_0 = H_s[:, 0, :] # (B, D)

        # --- 2.3. Global Posterior (q(zt | x)) ---

        # Calculate global distribution parameters from H_s^0
        mu_t, sigma2_t = self.mlp_global(H_s_0) # (B, LATENT_DIM)

        return mu_t, sigma2_t, mu_i, sigma2_i
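
To see what the near-zero initialization of `fc_out` does at step 0, the sketch below rebuilds only the output head in isolation. The sizes and the random input are assumptions for illustration, not the notebook's real activations. It shows that the mean starts near 0 and the variance near softplus(0) = ln 2, and it computes the resulting per-dimension KL divergence to a standard normal.

```python
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
hidden_dim, latent_dim = 64, 8  # toy sizes, not the notebook's

# Bias-free output layer with tiny weights, as in the MLP heads above.
fc_out = nn.Linear(hidden_dim, 2 * latent_dim, bias=False)
nn.init.normal_(fc_out.weight, mean=0.0, std=0.001)

# Stand-in input on roughly the scale of LayerNorm + GELU activations.
x = F.gelu(torch.randn(16, hidden_dim))

with torch.no_grad():
    mu, raw = fc_out(x).chunk(2, dim=-1)
    sigma2 = F.softplus(raw) + 1e-6
    # KL( N(mu, sigma2) || N(0, 1) ) per latent dimension
    kl = 0.5 * (sigma2 + mu ** 2 - 1.0 - torch.log(sigma2))

print(mu.abs().max().item())              # close to 0
print(sigma2.mean().item(), math.log(2))  # close to ln 2
print(kl.mean().item())                   # small but not exactly zero
```

Because the starting variance is about 0.69 rather than 1, the initial KL term is small but nonzero.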


In [None]:
class MLPNetworkForPrior(nn.Module):
    """
    Optimized Prior Network p(z_i | z_t, z_i-1).
    1. Widened layers (No bottleneck).
    2. Context Dropout (Forces reliance on z_t).
    3. Near-Zero Init (Prevents initial KL explosion).
    """
    def __init__(self, latent_dim, context_dropout_rate=0.05):
        super().__init__()
        self.context_dropout_rate = context_dropout_rate

        # Input: Global (32) + Local (32) = 64 (for latent_dim = 32)
        input_dim = latent_dim + latent_dim

        # 1. Maintain Width:
        # Instead of compressing, we project up or keep equal.
        # 128 gives enough capacity to mix Global and Previous-Local info.
        hidden_dim = 2*input_dim

        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.ln1 = nn.LayerNorm(hidden_dim)

        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.ln2 = nn.LayerNorm(hidden_dim)

        self.fc_out = nn.Linear(hidden_dim, 2 * latent_dim, bias=False) # Outputs (mu, var_param)

        self.activation = nn.GELU()

        # --- CRITICAL: Near-Zero Initialization ---
        # Initialize output to be very close to N(0, 1) parameters.
        # mu -> 0, sigma_param -> 0 (which becomes softplus(0) ~ 0.69)
        # This matches the initialization of your Encoder.
        torch.nn.init.normal_(self.fc_out.weight, mean=0.0, std=0.001)
        # torch.nn.init.constant_(self.fc_out.bias, 0)

    def forward(self, z_t, z_i_minus_1):
        """
        Args:
            z_t (Tensor): Global latent (B, D_global)
            z_i_minus_1 (Tensor): Previous local latent (B, D_local)
        """

        # --- 2. Context Dropout (The "Firewall") ---
        # Randomly zero out the previous sentence information.
        # This forces the network to look at z_t to guess the current sentence state.
        if self.training and self.context_dropout_rate > 0:
            mask_prob = torch.rand(z_t.shape[0], 1, device=z_t.device)
            # If random > rate, keep signal (1.0). Else drop (0.0).
            keep_mask = (mask_prob > self.context_dropout_rate).float()
            z_i_minus_1 = z_i_minus_1 * keep_mask

        # Concatenate inputs
        h = torch.cat([z_t, z_i_minus_1], dim=-1)

        # Block 1
        x = self.fc1(h)
        x = self.ln1(x)
        x = self.activation(x)

        # Block 2
        x = self.fc2(x)
        x = self.ln2(x)
        x = self.activation(x)

        # Output Head
        output = self.fc_out(x)

        mu, raw_var_score = output.chunk(2, dim=-1)

        # Robust Softplus
        sigma2 = F.softplus(raw_var_score) + 1e-6

        return mu, sigma2

class HT_HVAE_GenerativeNetwork(nn.Module):
    """
    The Hierarchical VAE Generative Network (Decoder) p(x|z).
    Uses GRU for sentence planning and modified GPT-2 for word generation.
    """
    def __init__(self,hyperparams):
        super().__init__()

        # 3.1. Global Prior: Standard Gaussian N(0, I) is defined conceptually.
        self.latent_dim = hyperparams['latent_dim']
        self.local_latent = self.latent_dim
        self.d_model = hyperparams['d_model']
        self.gpt2_model_name = hyperparams['gpt2_model_name']
        self.vocab_size = hyperparams['vocab_size']
        self.gru_layers = hyperparams['gru_layers']

        # 3.2. Local Prior Network (HMM integration)
        self.prior_mlp = MLPNetworkForPrior(self.latent_dim)
        self.gru_initial_projection = nn.Linear(self.latent_dim, self.d_model)

        # 3.3. Sentence-Level Decoder (GRU)
        # Input size is LATENT_DIM (z_i), hidden size matches D_MODEL (for h_i),
        # and the number of layers is set by 'gru_layers'.
        self.gru = nn.GRU(
            input_size=self.local_latent,
            hidden_size=self.d_model,
            num_layers=self.gru_layers,
            batch_first=True
        )

        # 3.4. Word-Level Decoder (Modified GPT-2)

        # Load pre-trained GPT-2 and modify its components for the VAE structure.
        # We load the full model to utilize its weights.
        self.gpt2_model = GPT2Model.from_pretrained(self.gpt2_model_name)

        self.gpt2_model.resize_token_embeddings(self.vocab_size)

        gpt2_hidden_size = self.gpt2_model.config.n_embd

        self.global_projector = nn.Linear(self.latent_dim, gpt2_hidden_size)


        # Define the projection layer here so weights are saved and trained
        self.latent_projection = nn.Linear(self.d_model, gpt2_hidden_size)

        # Set the embedding layer from GPT-2
        self.word_embedding = self.gpt2_model.wte

        self.word_dropout_rate = hyperparams['word_dropout_rate']
        self.plan_dropout_rate = hyperparams['plan_dropout_rate']
        self.mask_token_id = hyperparams.get('mask_token_id', self.gpt2_model.config.eos_token_id)


        # Get GPT-2's vocabulary head (unembedder)
        # In AutoModelForCausalLM, this is usually the same as wte.weight, but we need a specific layer
        # Since we modify the input to the final linear layer, we redefine the classification head.

        # Define the final linear layer: Input is [Hin || hi]
        # Output is the vocabulary size (Logits)
        self.final_linear = nn.Linear(gpt2_hidden_size + self.d_model, self.vocab_size, bias = False)



    def forward(self, input_ids, word_mask, z_t, z_i_samples, mu_i_prior=None, log_sigma2_i_prior=None):
        """
        Processes text sequentially for reconstruction loss calculation (training mode).

        Args:
            input_ids (torch.LongTensor): Word IDs of shape (B, N, S).
            word_mask (torch.Tensor): Mask of shape (B, N, S); 1 for valid words, 0 for padding.
            z_t (torch.Tensor): Sampled global latent variable (B, LATENT_DIM).
            z_i_samples (torch.Tensor): Sampled local latent variables (B, N, LATENT_DIM).
            mu_i_prior (torch.Tensor, optional): Currently unused; the prior mean is recomputed below.
            log_sigma2_i_prior (torch.Tensor, optional): Currently unused; the prior variance is recomputed below.

        Returns:
            tuple: (reconstruction_logits, mu_i_prior, sigma2_i_prior).
                reconstruction_logits has shape (B, N, S + 1, VOCAB_SIZE) because a global
                token is prepended to every sentence. sigma2_i_prior is a variance,
                not a log-variance.
        """

        batch_size, max_sentences, max_words = input_ids.shape
        global_drop_prob = 0.5

        if self.training and global_drop_prob > 0:
            # Create mask: 1 = Keep, 0 = Drop
            mask_prob = torch.rand(z_t.size(0), 1, device=z_t.device)
            keep_global_mask = (mask_prob > global_drop_prob).float()

            # Apply to z_t. The masked version feeds the GRU initial state and the local prior.
            # When z_t is dropped, h_0 reduces to the bias of gru_initial_projection.
            # The global token prepended for GPT-2 further below is built from the unmasked z_t.
            z_t_masked = z_t * keep_global_mask
        else:
            z_t_masked = z_t



        # --- 3.3. Sentence-Level Decoder (GRU) ---

        # The GRU processes the sequence of local latent variables (z_i_samples)
        # Sequence of z_i: (B, N, LATENT_DIM)

        # Initial hidden state for the GRU: the paper just uses z_i as input,
        # but here h_0 is additionally derived from the (possibly masked) global latent z_t.

        # GRU output: plan_vectors (h_i) is the sequence of outputs for each z_i

        h_0_projected = self.gru_initial_projection(z_t_masked)  # (B, D_MODEL)
        h_0_projected = h_0_projected.unsqueeze(0)               # (1, B, D_MODEL)
        # nn.GRU expects h_0 with shape (num_layers, B, D_MODEL)
        h_0 = h_0_projected.repeat(self.gru.num_layers, 1, 1)

        plan_vectors, _ = self.gru(z_i_samples, h_0) # (B, N, D_MODEL)

        # --- 3.2. Local Prior Network calculation (for KL divergence) ---

        # We need z_i-1, which is the previous sampled local latent variable.
        # Create z_i-1 by shifting z_i_samples:
        z_i_minus_1 = torch.cat([
            torch.zeros_like(z_i_samples[:, :1, :]), # Use zero vector for z_i-1 of the first sentence
            z_i_samples[:, :-1, :]
        ], dim=1).view(-1, self.local_latent) # Flatten to (B*N, D)

        # Flatten z_t and z_i_samples for MLP processing
        z_t_flat = z_t_masked.unsqueeze(1).repeat(1, max_sentences, 1).view(-1, self.latent_dim) # (B*N, D)

        # Calculate the conditional prior parameters for all sentences
        mu_i_prior, sigma2_i_prior = self.prior_mlp(z_t_flat, z_i_minus_1)
        # Reshape to (B, N, LATENT_DIM)
        mu_i_prior = mu_i_prior.view(batch_size, max_sentences, self.local_latent)
        sigma2_i_prior = sigma2_i_prior.view(batch_size, max_sentences, self.local_latent)

        # --- 3.4. Word-Level Decoder (Modified GPT-2) ---

        # Reshape everything to feed into the word-level loop
        flat_input_ids = input_ids.view(-1, max_words) # (B*N, S)
        flat_plan_vectors = plan_vectors.reshape(-1, self.d_model) # (B*N, D) (h_i for each sentence)

        # 1. Word Embeddings (e_ij): (B*N, S, D)
        if self.training and self.word_dropout_rate > 0:
            # 1. Create a mask of random probabilities
            rand_mask = torch.rand(flat_input_ids.shape, device=input_ids.device)

            # 2. Identify tokens to drop (probability < rate)
            # We also ensure we do NOT drop padding tokens (if 0 is your pad)
            # Assuming word_mask indicates valid words (1) and padding (0)
            flat_word_mask_bool = word_mask.view(-1, max_words).bool()

            # Drop mask: True where we should replace the token with mask_token_id
            drop_mask = (rand_mask < self.word_dropout_rate) & flat_word_mask_bool

            # 3. Create a clone to avoid in-place modification errors
            input_ids_for_decoder = flat_input_ids.clone()
            input_ids_for_decoder[drop_mask] = self.mask_token_id

        else:
            # No dropout during evaluation
            input_ids_for_decoder = flat_input_ids



        inputs_embeds = self.gpt2_model.wte(input_ids_for_decoder)

        # 2. Project Latents (Plan Vectors)
        # Ensure self.latent_projection maps from D_MODEL -> GPT_Hidden_Dim (e.g., 768)
        projected_latents = self.latent_projection(flat_plan_vectors) # (B*N, GPT_Hidden)

        # 3. Expand Latents to match Sequence Length
        # (B*N, 1, GPT_Hidden) -> (B*N, S, GPT_Hidden)
        latent_embeds = projected_latents.unsqueeze(1).expand(-1, max_words, -1)

        # 4. Additive Conditioning
        # The latent information is now fused into the input representation

        if self.training and self.plan_dropout_rate > 0:
            # We drop the ENTIRE plan vector for a sequence, not just random dimensions
            # Shape: (B*N, 1, 1) to broadcast over Sequence and Hidden dims
            p_mask = torch.rand(latent_embeds.size(0), 1, 1, device=latent_embeds.device)

            # Create boolean mask: 1 = Keep, 0 = Drop
            # If random number > dropout_rate, we keep the plan
            keep_plan_mask = (p_mask > self.plan_dropout_rate).float()

            # Apply mask: latent_embeds becomes all zeros where mask is 0
            latent_embeds = latent_embeds * keep_plan_mask

        inputs_embeds = inputs_embeds + latent_embeds

        # 5. Prepend a Global Token
        # Broadcast z_t to every sentence: (B, LATENT_DIM) -> (B*N, LATENT_DIM)
        # Note: this uses the unmasked z_t, not z_t_masked.
        z_t_expanded = z_t.unsqueeze(1).expand(-1, max_sentences, -1).reshape(-1, self.latent_dim)

        # Project z_t to the GPT-2 hidden dimension and add a sequence dim
        z_t_emb = self.global_projector(z_t_expanded).unsqueeze(1) # (B*N, 1, GPT_Hidden)

        # Concatenate the global token to the beginning of the inputs
        inputs_embeds = torch.cat([z_t_emb, inputs_embeds], dim=1) # (B*N, S+1, GPT_Hidden)

        # 6. Adjust Attention Mask
        flat_word_mask = word_mask.view(-1, max_words)
        # Create a column of 1s for the global token (same dtype as the word mask)
        global_mask_col = torch.ones(
            (flat_word_mask.size(0), 1),
            dtype=flat_word_mask.dtype,
            device=flat_word_mask.device
        )
        # Concatenate mask: (B*N, S+1)
        extended_attention_mask = torch.cat([global_mask_col, flat_word_mask], dim=1)

        # 7. Forward Pass
        # Note: We pass 'inputs_embeds' instead of 'input_ids'.
        # GPT-2 will automatically add Position Embeddings (wpe) to this internally.
        gpt2_output = self.gpt2_model(
            inputs_embeds=inputs_embeds,
            attention_mask=extended_attention_mask
        )
        # Extract the hidden states from the last layer (H_in)
        H_in = gpt2_output.last_hidden_state # (B*N, S+1, GPT_Hidden)

        # 8. Concatenate the Plan Vector to Every Position
        plan_unsqueezed = flat_plan_vectors.unsqueeze(1) # (B*N, 1, D_MODEL)

        # Expand to match the sequence length of H_in: (B*N, 1, D_MODEL) -> (B*N, S+1, D_MODEL)
        # Using .expand() is more memory efficient than .repeat()
        copied_plan_vector = plan_unsqueezed.expand(-1, H_in.size(1), -1)

        # Concatenate: Result shape is (B*N, S+1, GPT_Hidden + D_MODEL)
        final_input = torch.cat([H_in, copied_plan_vector], dim=-1)

        # Calculate Logits (Final output for reconstruction loss)
        reconstruction_logits = self.final_linear(final_input) # (B*N, S+1, VOCAB_SIZE)

        # Reshape logits back to (B, N, S+1, VOCAB_SIZE) for loss calculation
        reconstruction_logits = reconstruction_logits.view(
            batch_size, max_sentences, max_words + 1, self.vocab_size
        )

        return reconstruction_logits, mu_i_prior, sigma2_i_prior
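
The decoder returns the local prior parameters `mu_i_prior` and `sigma2_i_prior` as a mean and a variance (not a log-variance), so a KL term against the posterior can use the closed-form Gaussian expression directly. The sketch below is plain Python for readability and assumes diagonal Gaussians; the training loss itself is not shown in this section, and `gaussian_kl` and `softplus` are hypothetical helper names used only for illustration.

```python
import math

def gaussian_kl(mu_q, var_q, mu_p, var_p):
    """KL( N(mu_q, var_q) || N(mu_p, var_p) ) for diagonal Gaussians.

    All arguments are equal-length lists; variances (not log-variances) are expected.
    The result is summed over dimensions.
    """
    total = 0.0
    for mq, vq, mp, vp in zip(mu_q, var_q, mu_p, var_p):
        total += 0.5 * (math.log(vp / vq) + (vq + (mq - mp) ** 2) / vp - 1.0)
    return total

def softplus(x):
    """Scalar softplus, log(1 + exp(x))."""
    return math.log1p(math.exp(x))

# Identical distributions give zero divergence.
kl_same = gaussian_kl([0.0, 1.0], [1.0, 0.5], [0.0, 1.0], [1.0, 0.5])

# Unit variances with a mean shift of 1 give 0.5 * 1^2 = 0.5.
kl_shift = gaussian_kl([1.0], [1.0], [0.0], [1.0])

# With a near-zero output head, the prior variance starts near softplus(0) = ln 2.
init_var = softplus(0.0) + 1e-6
```

Working with variances avoids an extra `exp`, at the cost of needing the small additive floor (`1e-6` in the prior network) to keep the logarithm finite.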



In [None]:
import os
import torch
import wandb

def load_models_for_inference(
    artifact_path,
    inference_net,
    generative_net,
    filename="hvae_checkpoint.pth",
    device="cuda"
):
    """
    Downloads the artifact and loads ONLY the model weights.
    Ignores optimizer, scheduler, and training state.
    """
    print(f"Fetching inference models from: {artifact_path}")

    # 1. Initialize WandB if not active (required to download artifacts)
    if wandb.run is None:
        # You can use "anonymous" or your specific entity/project
        wandb.init(project="my_project", job_type="inference", mode="online")

    # 2. Download the artifact
    artifact = wandb.use_artifact(artifact_path, type='model')
    artifact_dir = artifact.download()
    filepath = os.path.join(artifact_dir, filename)

    # 3. Load the dictionary
    # We load the whole file into RAM, but we only extract what we need
    if torch.cuda.is_available() and device == 'cuda':
        checkpoint = torch.load(filepath)
    else:
        checkpoint = torch.load(filepath, map_location=torch.device('cpu'))

    # 4. Load ONLY the model states
    # load_state_dict defaults to strict=True, so the keys must match exactly (good for safety)
    inference_net.load_state_dict(checkpoint['inference_state_dict'])
    generative_net.load_state_dict(checkpoint['generative_state_dict'])

    # 5. Set to Evaluation Mode
    # Critical: This turns off Dropout and fixes Batch Norm layers
    inference_net.eval()
    generative_net.eval()

    print(f"Successfully loaded models from Epoch {checkpoint.get('epoch', 'Unknown')}")

    return inference_net, generative_net






In [None]:
hyperParams = {
    # --- Architecture Constraints ---
    'gpt2_model_name': 'gpt2',   # Start with 'gpt2' (124M params). Use 'gpt2-medium' only if you have high VRAM.
    'd_model': 768,              # MUST be 768 for 'gpt2', 1024 for 'gpt2-medium', 1280 for 'gpt2-large'.
    'vocab_size': 50257,         # Standard GPT-2 vocabulary size.

    # --- Latent Space ---
    'latent_dim': 32,           # 32-128 is standard. Too large = posterior collapse; Too small = poor reconstruction.

    # --- Data Dimensions (Abstract specific) ---
    'max_sentences': 11,         # Average abstract is 5-8 sentences; 11 provides a safety buffer.
    'max_words': 50,             # Average sentence is 20-30 words; 50 covers outliers.

    # --- Inference Network (Encoder) ---
    'encoder_layers': 2,         # Keep encoder shallow (2-4) so the heavy lifting happens in the decoder/latent space.
    'encoder_heads': 8,          # Standard for d_model=768 (768 / 8 = 96 dim per head).
    'encoder_dropout': 0.1,      # Standard transformer dropout.

    # --- Sentence Decoder (GRU) ---
    'gru_layers': 1,             # 1 layer is sufficient for the high-level plan; more layers add unnecessary complexity.

    # --- Special Tokens ---
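    'pad_index': 50256,          # GPT-2 has no PAD token; EOS (50256) is a common stand-in. Overwritten once the tokenizer is extended.

    # --- Regularization ---
    'word_dropout_rate': 0.5,    # Probability of replacing a valid input token with the mask token.
    'plan_dropout_rate': 0.15,   # Probability of dropping a sentence's entire plan vector.
    'mask_token_id': 10          # Placeholder; overwritten once the tokenizer is extended.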
}
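
The comments in the dictionary above state a few constraints (the hidden size expected for each GPT-2 variant, and a head count that divides `d_model`). A small check like the following can catch a mismatch before any weights are loaded. This is only a sketch: `check_hyperparams` and `GPT2_HIDDEN_SIZES` are hypothetical names, and the size table simply restates the values given in the comments.

```python
# Hidden sizes as listed in the hyperparameter comments (assumed, not queried from the library).
GPT2_HIDDEN_SIZES = {"gpt2": 768, "gpt2-medium": 1024, "gpt2-large": 1280}

def check_hyperparams(hp):
    """Return a list of human-readable problems; an empty list means the checks passed."""
    problems = []

    expected = GPT2_HIDDEN_SIZES.get(hp["gpt2_model_name"])
    if expected is not None and hp["d_model"] != expected:
        problems.append(
            f"d_model={hp['d_model']} does not match {hp['gpt2_model_name']} ({expected})"
        )

    if hp["d_model"] % hp["encoder_heads"] != 0:
        problems.append("d_model must be divisible by encoder_heads")

    for key in ("word_dropout_rate", "plan_dropout_rate"):
        if not 0.0 <= hp[key] < 1.0:
            problems.append(f"{key} must lie in [0, 1)")

    return problems

example = {
    "gpt2_model_name": "gpt2",
    "d_model": 768,
    "encoder_heads": 8,
    "word_dropout_rate": 0.5,
    "plan_dropout_rate": 0.15,
}
ok_report = check_hyperparams(example)
bad_report = check_hyperparams({**example, "d_model": 1024})
```

Returning a list rather than raising on the first failure makes it easy to print every problem at once in a notebook cell.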


In [None]:
from transformers import GPT2Tokenizer, DistilBertTokenizer

# 1. Load standard tokenizers
dec_tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
enc_tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

# 2. Add Special Tokens
# We use 'additional_special_tokens' for custom tokens like EOT
special_tokens_dict = {
    'bos_token': '<|BOS|>',
    'eos_token': '<|EOS|>',
    'pad_token': '<|PAD|>',
    'mask_token': '<|MASK|>',
    'additional_special_tokens': ['<|EOT|>']
}

# This returns the number of added tokens
num_added_toks = dec_tokenizer.add_special_tokens(special_tokens_dict)

# 3. Get IDs
# Standard attributes exist for bos/eos/pad/mask
pad_idx = dec_tokenizer.pad_token_id
eos_idx = dec_tokenizer.eos_token_id
mask_idx = dec_tokenizer.mask_token_id

# For EOT, we must look it up manually since .eot_token_id doesn't exist
eot_token_id = dec_tokenizer.convert_tokens_to_ids('<|EOT|>')

new_vocab_size = len(dec_tokenizer)

print(f"New Vocab Size: {new_vocab_size}")
print(f"PAD ID: {pad_idx} | EOS ID: {eos_idx}")
print(f"MASK ID: {mask_idx} | EOT ID: {eot_token_id}")

# 4. Update hyperparameters
hyperParams['vocab_size'] = new_vocab_size
hyperParams['pad_index'] = pad_idx
hyperParams['mask_token_id'] = mask_idx
hyperParams['eot_token_id'] = eot_token_id

New Vocab Size: 50262
PAD ID: 50259 | EOS ID: 50258
MASK ID: 50260 | EOT ID: 50261


In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
inference_net = HT_HVAE_InferenceNetwork(hyperParams).to(device)
generative_net = HT_HVAE_GenerativeNetwork(hyperParams).to(device)

The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. To disable this, use `mean_resizing=False`


In [None]:
wandb.init(
    project="HDHEBM",
    config=hyperParams,
)

In [None]:
# 2. Load weights

load_models_for_inference(
    artifact_path="yasir-alam14/HVAE-distilbert_plan_masking_full/hvae-model:v6",
    inference_net=inference_net,
    generative_net=generative_net
)


Fetching inference models from: yasir-alam14/HVAE-distilbert_plan_masking_full/hvae-model:v6


wandb: Downloading large artifact hvae-model:v6, 2895.48MB. 1 files...
wandb:   1 of 1 files downloaded.
Done. 0:0:7.1 (408.3MB/s)


Successfully loaded models from Epoch 34


(HT_HVAE_InferenceNetwork(
   (word_encoder): DistilBertModel(
     (embeddings): Embeddings(
       (word_embeddings): Embedding(30522, 768, padding_idx=0)
       (position_embeddings): Embedding(512, 768)
       (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
       (dropout): Dropout(p=0.1, inplace=False)
     )
     (transformer): Transformer(
       (layer): ModuleList(
         (0-5): 6 x TransformerBlock(
           (attention): DistilBertSdpaAttention(
             (dropout): Dropout(p=0.1, inplace=False)
             (q_lin): Linear(in_features=768, out_features=768, bias=True)
             (k_lin): Linear(in_features=768, out_features=768, bias=True)
             (v_lin): Linear(in_features=768, out_features=768, bias=True)
             (out_lin): Linear(in_features=768, out_features=768, bias=True)
           )
           (sa_layer_norm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
           (ffn): FFN(
             (dropout): Dropout(p=0.1, inp

In [None]:
prio_net = generative_net.prior_mlp

In [None]:
# --- Setup ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Instantiate Models
hvae_transform = AutoregressiveHVAE_Transform(prio_net).to(device) # Pre-trained/Fixed backbone
ebm = HierarchicalTransformerEBM().to(device)           # The model we are training

optimizer = torch.optim.Adam(ebm.parameters(), lr=1e-4)

# --- Diffusion Noise Schedule ---
# Simple linear schedule for alpha/sigma
T_steps = 100
betas = torch.linspace(1e-4, 0.02, T_steps).to(device)
alphas = 1.0 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
sigmas = torch.sqrt(1.0 - alphas_cumprod) # Approx noise level

# --- Training Loop ---
# Assume data_loader provides z_posterior samples from your trained Encoder
# real_z_list = [z1_batch, z2_batch]
# Assume optimizer, ebm, hvae_transform are defined
# Assume standard scheduler vars (alphas, betas, alphas_cumprod) are defined

for epoch in range(10):
    for batch in new_train_loader: # Changed to 'batch' (dict) to access masks

        # 0. Prepare Data & Masks
        # Extract components from batch
        # Note: We need 'enc_wordMask' to determine which sentences are Zombies
        enc_word_mask = batch['enc_wordMask'].to(device) # [B, Max_Sent, Max_Words]

        # Create Sentence Mask: 1.0 = Real, 0.0 = Zombie
        # Logic: If first word is padded (0), the whole sentence is padded
        sentence_mask = enc_word_mask[:, :, 0].float() # [B, Max_Sent]

        # Expand for broadcasting against u_local [B, Max_Sent, 1]
        local_valid_mask = sentence_mask.unsqueeze(-1)

        # Get Latents
        real_z_list = [
            batch['global_latents'].to(device),
            batch['local_latents'].to(device)
        ]

        # Get Context for EBM
        context = batch['question_encoded'].to(device)
        context_mask = batch['question_mask'].to(device)

        batch_size = real_z_list[0].shape[0]

        # 1. Inverse Transform: Get Real u_0
        with torch.no_grad():
            u0_list = hvae_transform.z_to_u(real_z_list)

        # 2. Sample Random Time Step t
        t_indices = torch.randint(0, T_steps, (batch_size,), device=device)

        # Get params (Simplified for brevity)
        a_t = torch.sqrt(alphas_cumprod[t_indices]).view(-1, 1)
        s_t = torch.sqrt(1 - alphas_cumprod[t_indices]).view(-1, 1)

        # Noise for Forward Process
        # We generally don't need to mask this noise (u_t for zombies can be noisy),
        # BUT technically u_0 for zombies is garbage anyway.
        noise_list = [torch.randn_like(u) for u in u0_list]

        # Create Noisy Target (u_t)
        u_t_list = []
        for i, (u0, eps) in enumerate(zip(u0_list, noise_list)):
            # Broadcast helper
            if u0.ndim == 3:
                a_view, s_view = a_t.unsqueeze(1), s_t.unsqueeze(1)
            else:
                a_view, s_view = a_t, s_t

            u_t_val = a_view * u0 + s_view * eps

            # OPTIONAL: Force u_t zombies to be pure noise?
            # Not strictly necessary if we mask updates, but cleaner.
            u_t_list.append(u_t_val)

        # Create Anchor (u_{t+1})
        step_alpha = torch.sqrt(alphas[t_indices]).view(-1, 1)
        step_sigma = torch.sqrt(betas[t_indices]).view(-1, 1)

        u_anchor_list = []
        for i, u_t in enumerate(u_t_list):
            eps = torch.randn_like(u_t)
            if u_t.ndim == 3:
                sa_view, ss_view = step_alpha.unsqueeze(1), step_sigma.unsqueeze(1)
            else:
                sa_view, ss_view = step_alpha, step_sigma

            u_anchor = sa_view * u_t + ss_view * eps
            u_anchor_list.append(u_anchor)

        # 3. Negative Sampling (Langevin Dynamics)
        u_neg_list = [u.clone().detach().requires_grad_(True) for u in u_anchor_list]

        # MCMC Loop
        for k in range(10): # k steps
            # A. Transform u -> z
            z_neg_list = hvae_transform.u_to_z(u_neg_list)

            # B. Compute Energy
            t_input = t_indices.float().view(-1, 1)

            # Pass Masks to EBM! (Using the logic we defined in CrossAttentionEBM)
            energy = ebm(
                z_neg_list,
                context,
                t_input,
                context_mask=context_mask,
                latent_word_mask=enc_word_mask # EBM uses this to mask pooling/attention
            )

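            # C. Gradients
            # .sum() guarantees a scalar for autograd, whether the EBM returns
            # a scalar or per-sample energies.
            grads = torch.autograd.grad(energy.sum(), u_neg_list)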

            # D. Update with Masking Logic
            new_u_list = []
            for i, (u, g, anchor) in enumerate(zip(u_neg_list, grads, u_anchor_list)):

                # Check dimensions to set views
                if u.ndim == 3:
                    ss_view = step_sigma.unsqueeze(1)
                else:
                    ss_view = step_sigma

                # Calculate Forces
                grad_anchor = -(u - anchor) / (ss_view**2 + 1e-6)
                grad_prior = -u
                total_grad = g + grad_anchor + grad_prior

                # Step Size
                step_size = 1e-2 * (ss_view**2)

                # Noise
                noise = torch.randn_like(u)

                # --- CRITICAL: MASKING LOGIC ---
                # Check if this is the Local Latent tensor (Index 1)
                if i == 1:
                    # 1. Mask the Gradient (Forces become 0 for Zombies)
                    total_grad = total_grad * local_valid_mask

                    # 2. Mask the Noise (No Brownian drift for Zombies)
                    noise = noise * local_valid_mask
                # -------------------------------

                # Langevin Update
                u_new = u + step_size * total_grad + torch.sqrt(2*step_size) * noise
                new_u_list.append(u_new.detach().requires_grad_(True))

            u_neg_list = new_u_list

        # 4. Compute Loss
        # Positive Energy
        u_t_list = [u.detach().requires_grad_(True) for u in u_t_list]
        z_pos_list = hvae_transform.u_to_z(u_t_list)
        pos_energy = ebm(z_pos_list, context, t_input, context_mask, enc_word_mask)

        # Negative Energy
        z_neg_list = hvae_transform.u_to_z(u_neg_list)
        neg_energy = ebm(z_neg_list, context, t_input, context_mask, enc_word_mask)

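        # Loss: raise the EBM output on data samples and lower it on Langevin negatives.
        # .mean() reduces per-sample values to a scalar (a no-op if already scalar).
        loss = -(pos_energy - neg_energy).mean()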

        optimizer.zero_grad()
        loss.backward()

        # Optional: Clip Gradients
        torch.nn.utils.clip_grad_norm_(ebm.parameters(), 1.0)

        optimizer.step()

        print(f"Loss: {loss.item()}")
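
The forward-process coefficients used in the loop above come from the cumulative product of `1 - beta`. As a quick sanity check of that arithmetic, here is a dependency-free sketch that rebuilds the same linear schedule with plain Python lists. `linear_schedule` and `forward_coeffs` are hypothetical helper names; the constants mirror the values in the training cell.

```python
import math

def linear_schedule(t_steps=100, beta_start=1e-4, beta_end=0.02):
    """Linearly spaced betas (endpoints included) and their cumulative alpha product."""
    betas = [
        beta_start + (beta_end - beta_start) * i / (t_steps - 1)
        for i in range(t_steps)
    ]
    alphas_cumprod = []
    running = 1.0
    for beta in betas:
        running *= 1.0 - beta
        alphas_cumprod.append(running)
    return betas, alphas_cumprod

def forward_coeffs(alphas_cumprod, t):
    """Coefficients of u_t = a_t * u_0 + s_t * eps at step t."""
    a_t = math.sqrt(alphas_cumprod[t])
    s_t = math.sqrt(1.0 - alphas_cumprod[t])
    return a_t, s_t

betas, alphas_cumprod = linear_schedule()
a_first, s_first = forward_coeffs(alphas_cumprod, 0)
a_last, s_last = forward_coeffs(alphas_cumprod, len(betas) - 1)

print(f"t=0:  a={a_first:.4f}, s={s_first:.4f}")
print(f"t=99: a={a_last:.4f}, s={s_last:.4f}")
```

By construction `a_t**2 + s_t**2 == 1`, so the marginal variance of `u_t` stays at one when `u_0` has unit variance. With only 100 steps and these betas, the cumulative product stays well above zero at the final step (roughly 0.36 by a quick estimate), so the last noise level still carries some signal from `u_0`.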

In [None]:
# The generate_samples function outlines the reverse diffusion process for
# sampling new latent variables (u) from the EBM. It initializes pure noise
# and then iteratively refines the samples using Langevin dynamics based on the EBM's energy function.
import torch

def generate_samples(
    hvae,           # The AutoregressiveHVAE_Transform module
    ebm,            # The HierarchicalTransformerEBM module
    batch_size,
    device,
    n_langevin_steps=30, # Steps per diffusion level (Refinement)
    step_size_base=1e-2, # Base learning rate for Langevin
    latent_dim=32,       # Must match hyperParams['latent_dim']
    n_sentences=10,      # Number of local latents per sample
    context=None,        # Encoded question, passed to the EBM as in the training loop
    context_mask=None,
    latent_word_mask=None
):
    # --- 1. Define Noise Schedule (Must match training) ---
    T = 100
    betas = torch.linspace(1e-4, 0.02, T).to(device)
    # In reverse, we need sigma_{t+1} for the anchor term
    # For simplicity in this script, we approximate sigma ~ sqrt(beta)

    # --- 2. Initialize u_T (Pure Noise) ---
    # u1: Global [B, latent_dim]
    # u2: Sequence [B, n_sentences, latent_dim]
    u1 = torch.randn(batch_size, latent_dim, device=device)
    u2 = torch.randn(batch_size, n_sentences, latent_dim, device=device)
    u_list = [u1, u2]

    print("Starting Reverse Diffusion...")

    # --- 3. Reverse Diffusion Loop (T-1 -> 0) ---
    for t in range(T - 1, -1, -1):

        # A. Define the Anchor (u_{t+1})
        # The current state is u_{t+1}. We want to find u_t.
        u_anchor_list = [u.clone().detach() for u in u_list]

        # Get noise level for this step (sigma_{t+1})
        # Used for the "Leash" strength
        curr_beta = betas[t] # Scalar

        # B. Inner Langevin Refinement Loop
        for k in range(n_langevin_steps):
            # Enable gradients
            for u in u_list:
                u.requires_grad_(True)

            # 1. Transform u -> z (uses the prior network inside the transform)
            z_list = hvae.u_to_z(u_list)

            # 2. Compute Energy
            # Pass t as float tensor [B, 1]. The argument order mirrors the training loop.
            # Assumption: the EBM takes the same arguments here as during training,
            # so supply the context tensors if the EBM requires them.
            t_tensor = torch.full((batch_size, 1), t, device=device).float()
            energy = ebm(
                z_list,
                context,
                t_tensor,
                context_mask=context_mask,
                latent_word_mask=latent_word_mask
            )

            # 3. Compute Gradients (.sum() guarantees a scalar for autograd)
            grads = torch.autograd.grad(energy.sum(), u_list)

            # 4. Update u
            new_u_list = []
            for u, g, anchor in zip(u_list, grads, u_anchor_list):
                # Handle broadcasting
                beta_view = curr_beta if u.ndim == 2 else curr_beta.view(1, 1, 1)

                # Gradients
                term_energy = g
                term_prior = -u
                # Anchor: -(u - anchor) / sigma^2
                term_anchor = -(u - anchor) / (beta_view + 1e-6)

                total_grad = term_energy + term_prior + term_anchor

                # Scale step size by noise level (Heuristic from paper/Song et al.)
                # Steps should be smaller when noise is small
                s = step_size_base * beta_view

                # Langevin Update
                noise = torch.randn_like(u)
                u_new = u + s * total_grad + torch.sqrt(2 * s) * noise

                new_u_list.append(u_new.detach())

            u_list = new_u_list

        if t % 10 == 0:
            print(f"Step {t} finished.")

    # --- 4. Final Transform u_0 -> z_0 ---
    # These are your high-quality latent codes ready for the decoder
    final_z_list = hvae.u_to_z(u_list)

    return final_z_list
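
Both the training loop and `generate_samples` rely on the same Langevin update, `u + s * grad + sqrt(2 * s) * noise`, where `grad` is the gradient of a log-density. To see that this rule samples from the intended distribution, the toy sketch below applies it to a one-dimensional Gaussian target with a known score. It uses only the standard library; `langevin_sample` and the target parameters are hypothetical and not part of the notebook's models.

```python
import math
import random

def langevin_sample(grad_log_p, x0, step_size, n_steps, rng):
    """Run unadjusted Langevin dynamics from x0 and return the final state."""
    x = x0
    for _ in range(n_steps):
        noise = rng.gauss(0.0, 1.0)
        x = x + step_size * grad_log_p(x) + math.sqrt(2.0 * step_size) * noise
    return x

rng = random.Random(0)
target_mean = 2.0

def score(x):
    # Gradient of log N(x; target_mean, 1) with respect to x.
    return -(x - target_mean)

# Start each chain from standard normal noise, as the sampler starts from u_T.
samples = [
    langevin_sample(score, rng.gauss(0.0, 1.0), step_size=0.05, n_steps=200, rng=rng)
    for _ in range(2000)
]

sample_mean = sum(samples) / len(samples)
sample_var = sum((s - sample_mean) ** 2 for s in samples) / len(samples)
print(f"mean={sample_mean:.3f}, var={sample_var:.3f}")
```

The sample mean should land near the target mean and the variance near one, with a small upward bias in the variance that comes from the finite step size. The same trade-off applies to the `step_size_base` argument above: larger steps mix faster but distort the stationary distribution more.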