# Task 5: Text Embedding using Hugging Face

This notebook demonstrates how to tokenize text descriptions and generate contextual token embeddings with a pretrained BERT model from Hugging Face Transformers, as a text-encoding step for text-to-image generation.


In [None]:
import matplotlib.pyplot as plt
from transformers import BertTokenizer, BertModel
import torch
import numpy as np


In [None]:
# Load the pretrained tokenizer and encoder from the same checkpoint so that
# the vocabulary used for tokenization matches the model's embedding table.
checkpoint = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(checkpoint)
# .eval() returns the module itself after switching it to evaluation mode,
# so the model is ready for inference as soon as it is bound.
model = BertModel.from_pretrained(checkpoint).eval()

print("Model and tokenizer loaded successfully!")


In [None]:
# Embed one text description and chart the token IDs it maps to.
text = "A beautiful sunset over the mountains"

# Tokenize into PyTorch tensors; inputs longer than 77 tokens are truncated.
encoded = tokenizer(
    text,
    return_tensors='pt',
    padding=True,
    truncation=True,
    max_length=77,
)
# The batch holds a single sequence, so row 0 carries all of its token IDs.
token_ids = encoded['input_ids'][0].tolist()
tokens = tokenizer.convert_ids_to_tokens(token_ids)

print(f"Original Text: {text}")
print(f"Tokens: {tokens}")
print(f"Token IDs: {token_ids}")

# Forward pass without gradient tracking: this is inference only.
with torch.no_grad():
    outputs = model(**encoded)
embeddings = outputs.last_hidden_state

print(f"\nEmbedding shape: {embeddings.shape}")

# Bar chart with one bar per token, labelled with the token string.
positions = range(len(token_ids))
fig, ax = plt.subplots(figsize=(12, 4))
ax.bar(positions, token_ids)
ax.set_xticks(positions)
ax.set_xticklabels(tokens, rotation=45, ha='right')
ax.set_title("Token IDs for Text Description")
ax.set_xlabel("Tokens")
ax.set_ylabel("Token IDs")
fig.tight_layout()
plt.show()


In [None]:
# Example with multiple text descriptions: each one is tokenized, embedded,
# and its token IDs are plotted in a separate figure.
text_samples = [
    "Generate an image of a cat",
    "A red car driving on a highway",
    "Beautiful landscape with mountains and lake"
]

# Fixed sequence length every sample is padded/truncated to.
MAX_LENGTH = 77
# Longest description shown in a plot title before it is cut off.
TITLE_CHARS = 40

for text in text_samples:
    # padding='max_length' pads every sample to MAX_LENGTH tokens, so all
    # samples yield the same sequence length (the plot therefore also shows
    # the trailing padding tokens).
    encoded = tokenizer(
        text,
        return_tensors='pt',
        padding='max_length',
        truncation=True,
        max_length=MAX_LENGTH,
    )
    tokens = tokenizer.convert_ids_to_tokens(encoded['input_ids'][0])
    token_ids = encoded['input_ids'][0].tolist()

    # Inference only: skip gradient tracking.
    with torch.no_grad():
        outputs = model(**encoded)
        embeddings = outputs.last_hidden_state

    print(f"\nText: {text}")
    print(f"Embedding shape: {embeddings.shape}")

    # Only append an ellipsis when the description was actually shortened;
    # previously "..." was added even to texts that fit in full.
    if len(text) > TITLE_CHARS:
        title_text = f"{text[:TITLE_CHARS]}..."
    else:
        title_text = text

    plt.figure(figsize=(12, 4))
    plt.bar(range(len(token_ids)), token_ids)
    plt.xticks(range(len(token_ids)), tokens, rotation=45, ha='right')
    plt.title(f"Token IDs: {title_text}")
    plt.xlabel("Tokens")
    plt.ylabel("Token IDs")
    plt.tight_layout()
    plt.show()
