In [1]:
import sys
import os
sys.path.append('..')
os.chdir("../")

In [18]:
from transformers import AutoTokenizer

from shortcutfm.__main__ import parse_config
from shortcutfm.train.pl.trainer import create_dataloaders
from shortcutfm.train.pl.trainer_factory import create_criterion



config_path = "configs/training/qqp.yaml"
cfg = parse_config(config_path, [])

tokenizer = AutoTokenizer.from_pretrained(cfg.model.config_name)
criterion = create_criterion(cfg, tokenizer=tokenizer)

Tied lm_head.weight to word_embedding.weight after pretrained loading
word emebedding requires grad: False
lm head requires grad: False
lm_head tied to word_embedding: True


In [4]:
# Re-create dataloaders with num_workers=0 for Jupyter compatibility
train_dataloader, val_dataloader = create_dataloaders(cfg, num_workers=1)

2025-06-22 12:36:54,878 - INFO - Loading dataset...
2025-06-22 12:36:54,962 - INFO - Train dataset contains 144715 samples.
2025-06-22 12:36:54,971 - INFO - Validation dataset contains 2048 samples.


tensor([[ 101, 2054, 2515,  ...,    0,    0,    0],
        [ 101, 2054, 2024,  ...,    0,    0,    0],
        [ 101, 2054, 2024,  ...,    0,    0,    0],
        ...,
        [ 101, 2054, 2003,  ...,    0,    0,    0],
        [ 101, 2190, 2126,  ...,    0,    0,    0],
        [ 101, 2054, 2003,  ...,    0,    0,    0]])


In [6]:
batch = next(iter(train_dataloader))

In [7]:
print(batch.seqs.shape)

torch.Size([8, 128])


In [9]:
sentences = tokenizer.batch_decode(batch.seqs, skip_special_tokens=False)
print(sentences[0])

[CLS] would you rather vote for donald trump or hillary clinton? why? [SEP] [CLS] donald trump or hillary clinton? why? [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]


In [12]:
emebddings = criterion.flow_matching_criterion.model.get_embeddings(batch.seqs)
print(emebddings.shape)

torch.Size([8, 128, 768])


In [14]:
predicted_tokens = criterion.flow_matching_criterion.model.compute_logits(emebddings)
print(predicted_tokens.shape)

torch.Size([8, 128, 30522])


# Ce loss netween token ids and predicted tokens

In [29]:
from torch import nn


ce = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
loss = ce(predicted_tokens.view(-1, predicted_tokens.size(-1)), batch.seqs.view(-1))
print(loss.item())


9.334590911865234


In [30]:
# are word_emebddings and and lm_head the same?
import torch


word_embeddings = criterion.flow_matching_criterion.model.module.word_embedding
lm_head = criterion.flow_matching_criterion.model.module.lm_head

print(word_embeddings.weight is lm_head.weight)

# shapes
print(word_embeddings.weight.shape)
print(lm_head.weight.shape)
# check if weight tensors are identical 
print(torch.equal(word_embeddings.weight, lm_head.weight))

# tensors are the same
#weryfiy ce loss 
predicted_tokens = criterion.flow_matching_criterion.model.compute_logits(emebddings)
loss = ce(predicted_tokens.view(-1, predicted_tokens.size(-1)), batch.seqs.view(-1))
print(loss.item())

True
torch.Size([30522, 768])
torch.Size([30522, 768])
True
9.334590911865234


In [32]:
print(f"Embedding norm: {emebddings.norm(dim=-1).mean().item()}")

Embedding norm: 1.1510801315307617


In [34]:
embeddings_normalized = emebddings / (emebddings.norm(dim=-1, keepdim=True) + 1e-6)
logits = criterion.flow_matching_criterion.model.compute_logits(embeddings_normalized)
loss = ce(logits.view(-1, logits.size(-1)), batch.seqs.view(-1))
print(f"Loss with normalized embeddings: {loss.item()}")

Loss with normalized embeddings: 9.516925811767578


# FUcniton to test weights tying

In [41]:
def test_weight_tying(word_embedding, lm_head, vocab_size=10, test_tokens=None):
    """Test if weight tying is working correctly."""

    print("=== Weight Tying Test ===")

    # Check if weights are actually tied
    print(f"Weights are tied: {lm_head.weight is word_embedding.weight}")
    print(f"Weight shapes - Embedding: {word_embedding.weight.shape}, LM Head: {lm_head.weight.shape}")

    # Test with a few token IDs
    if test_tokens is None:
        test_tokens = torch.tensor([0, 1, 2, 100, 500])  # Adjust based on your vocab size

    print(f"\nTesting with tokens: {test_tokens.tolist()}")

    # Forward pass: tokens -> embeddings -> logits
    embeddings = word_embedding(test_tokens)  # [num_tokens, embed_dim]
    logits = lm_head(embeddings)  # [num_tokens, vocab_size]

    # Get predicted tokens (highest logit)
    predicted_tokens = torch.argmax(logits, dim=-1)

    print(f"Original tokens:   {test_tokens.tolist()}")
    print(f"Predicted tokens:  {predicted_tokens.tolist()}")
    print(f"Match rate: {(test_tokens == predicted_tokens).float().mean().item():.2%}")

    # Check logit values for original tokens
    print("\nLogit analysis:")
    for i, token_id in enumerate(test_tokens):
        original_logit = logits[i, token_id].item()
        max_logit = logits[i].max().item()
        max_token = logits[i].argmax().item()
        print(f"Token {token_id}: logit={original_logit:.3f}, max_logit={max_logit:.3f} (token {max_token})")

    # Test with bias removed (if bias exists)
    if lm_head.bias is not None:
        print("\n=== Test without bias ===")
        logits_no_bias = lm_head(embeddings) - lm_head.bias
        predicted_no_bias = torch.argmax(logits_no_bias, dim=-1)
        print(f"Predicted (no bias): {predicted_no_bias.tolist()}")
        print(f"Match rate (no bias): {(test_tokens == predicted_no_bias).float().mean().item():.2%}")

    #compute ce loss
    criterion = nn.CrossEntropyLoss()
    loss = criterion(logits, test_tokens)
    print(f"\nCross-Entropy Loss: {loss.item():.4f}")


# Tied pretrained embedding weights

In [None]:
import torch
import torch.nn as nn
from transformers import AutoTokenizer, BertModel

# Initialize tokenizer and BERT model
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
bert = BertModel.from_pretrained("bert-base-uncased")

# Create embedding layer and initialize with pretrained weights
embedding = nn.Embedding(30522, 768)
with torch.no_grad():
    embedding.weight.copy_(bert.embeddings.word_embeddings.weight)

# Create lm_head and tie weights to embedding
lm_head = nn.Linear(768, 30522, bias=True)
with torch.no_grad():
    lm_head.weight = embedding.weight  # Tie weights
    lm_head.bias.zero_()  # Initialize bias to zero

# Freeze weights
embedding.weight.requires_grad = False
lm_head.weight.requires_grad = False

# Prepare input
text = ["This is a test sentence."]
encoding = tokenizer(text, padding=True, truncation=True, return_tensors="pt")
input_ids = encoding["input_ids"]

# Compute embeddings and logits
embeddings = embedding(input_ids)
logits = lm_head(embeddings)

# Compute cross-entropy loss
ce = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
loss = ce(logits.view(-1, logits.size(-1)), input_ids.view(-1))
print(f"Cross-entropy loss: {loss.item()}")

Cross-entropy loss: 9.256082534790039


In [42]:
test_weight_tying(embedding, lm_head, vocab_size=30522, test_tokens=input_ids[0])

=== Weight Tying Test ===
Weights are tied: True
Weight shapes - Embedding: torch.Size([30522, 768]), LM Head: torch.Size([30522, 768])

Testing with tokens: [101, 2023, 2003, 1037, 3231, 6251, 1012, 102]
Original tokens:   [101, 2023, 2003, 1037, 3231, 6251, 1012, 102]
Predicted tokens:  [101, 2023, 2003, 1037, 3231, 6251, 1012, 101]
Match rate: 87.50%

Logit analysis:
Token 101: logit=4.124, max_logit=4.124 (token 101)
Token 2023: logit=0.816, max_logit=0.816 (token 2023)
Token 2003: logit=0.760, max_logit=0.760 (token 2003)
Token 1037: logit=0.758, max_logit=0.758 (token 1037)
Token 3231: logit=1.310, max_logit=1.310 (token 3231)
Token 6251: logit=1.555, max_logit=1.555 (token 6251)
Token 1012: logit=0.623, max_logit=0.623 (token 1012)
Token 102: logit=0.586, max_logit=0.643 (token 101)

=== Test without bias ===
Predicted (no bias): [101, 2023, 2003, 1037, 3231, 6251, 1012, 101]
Match rate (no bias): 87.50%

Cross-Entropy Loss: 9.2561


In [38]:
import torch
import torch.nn as nn
from transformers import AutoTokenizer

# Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

# Create embedding layer with random weights
embedding = nn.Embedding(30522, 768)
nn.init.normal_(embedding.weight, mean=0.0, std=0.02)  # Random initialization with small std
# embedding.weight.requires_grad = False

# Create lm_head and tie weights to embedding
lm_head = nn.Linear(768, 30522, bias=False)
with torch.no_grad():
    lm_head.weight = embedding.weight  # Tie weights
    # lm_head.bias.zero_()  # Initialize bias to zero
# lm_head.weight.requires_grad = False


# Prepare input
text = ["This is a test sentence."]
encoding = tokenizer(text, padding=True, truncation=True, return_tensors="pt")
input_ids = encoding["input_ids"]

# Compute embeddings and logits
embeddings = embedding(input_ids)
logits = lm_head(embeddings)

# Compute cross-entropy loss
ce = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
loss = ce(logits.view(-1, logits.size(-1)), input_ids.view(-1))
print(f"Cross-entropy loss: {loss.item()}")

Cross-entropy loss: 10.014132499694824


In [None]:
test_weight_tying(embedding, lm_head, vocab_size=30522, test_tokens=input_ids[0])

In [36]:
print(f"Embedding norm: {embeddings.norm(dim=-1).mean().item()}")

Embedding norm: 1.077894687652588
