# %% [markdown]

 # Section 1: Text Autoencoders - Exploring SONAR
 This notebook explores Meta's SONAR text autoencoder, which can encode text
 into fixed-size vectors and decode them back to (approximately) the original text.

 **Learning objectives:**
 1. Load and use SONAR for text encoding/decoding
 2. Understand the properties of text embeddings
 3. Test robustness to noise
 4. Explore how text length affects embeddings
 5. Experiment with token swapping and sentence combinations

# %% [markdown]

 ## Setup and Installation

 First, install SONAR and its dependencies. You can run this cell without reading it closely unless it raises errors.
 Note: You may need to adjust the fairseq2 installation to match your CUDA version.

# %%

!pip install -q fairseq2==0.4.5 sonar-space==0.4.0 torchvision==0.21.0 torch==2.6.0 torchaudio==2.6.0 plotly nbformat pandas scikit-learn tqdm matplotlib datasets jaxtyping

import json

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.express as px
import plotly.graph_objects as go
import torch
import torch.nn as nn
from datasets import load_dataset
from jaxtyping import Float
from sklearn.model_selection import train_test_split
from sonar.inference_pipelines.text import (
    EmbeddingToTextModelPipeline,
    TextToEmbeddingModelPipeline,
)
from tqdm import tqdm

# Use the GPU if CUDA is available, otherwise fall back to the CPU
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_grad_enabled(False)  # We're only doing inference
print(f"Using device: {DEVICE}")

# %% [markdown]

 ## Loading SONAR Models

 SONAR (Sentence-Level Multimodal and Language-Agnostic Representations) is Meta's text autoencoder
 that can encode entire sentences/paragraphs into fixed-size vectors and decode them back to approximately
 the original text.

 **What are Text Autoencoders?**

 Text autoencoders are models that compress an entire input sequence (a sentence or paragraph) into a single
 fixed-size vector representation (the "bottleneck"), then reconstruct the original text from that vector.
 Unlike typical text embedding models, which only encode, these models have both an encoder and a decoder.

 ![Text Autoencoder Architecture](https://39669.cdn.cke-cs.com/rQvD3VnunXZu34m86e5f/images/db8d350884974ce6dcb1281011c5053e11b65711c12a4556.png)

 **How Text Autoencoders Work:**
 1. **Encoder**: Takes input text → processes it through a Transformer → outputs a single fixed-size vector (1024-dimensional for SONAR)
 2. **Bottleneck**: The compressed representation that captures semantic meaning in a dense vector
 3. **Decoder**: Takes the vector → generates text that approximates the original input

 **Key Properties:**
 - **Lossy compression**: Some information (often the exact wording) is lost, but the semantic meaning is largely preserved
 - **Fixed-size representation**: Text of any length, up to the model's context limit, becomes a vector of the same size (useful for comparison/clustering)
 - **Cross-lingual**: Multilingual models such as SONAR can encode in one language and decode in another
 - **Reconstruction capability**: Unlike embedding-only models, you can decode back to text
 - **Semantic preservation**: The bottleneck captures the core meaning despite the compression
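
 Because every text maps to a vector of the same size, two texts can be compared with a single number. The sketch below is an illustration only: it implements cosine similarity in plain Python on made-up 4-dimensional vectors (`vec_cat`, `vec_kitten`, and `vec_invoice` are hypothetical stand-ins, not real SONAR embeddings).

```python
import math


def cosine_similarity(a: list[float], b: list[float]) -> float:
    """Cosine of the angle between two equal-length vectors."""
    if len(a) != len(b):
        raise ValueError("vectors must have the same dimension")
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    if norm_a == 0.0 or norm_b == 0.0:
        raise ValueError("cosine similarity is undefined for a zero vector")
    return dot / (norm_a * norm_b)


# Made-up 4-dimensional stand-ins for real 1024-dimensional embeddings
vec_cat = [0.9, 0.1, 0.0, 0.2]
vec_kitten = [0.8, 0.2, 0.1, 0.3]
vec_invoice = [0.0, 0.9, 0.8, 0.1]

print(f"cat vs kitten:  {cosine_similarity(vec_cat, vec_kitten):.3f}")
print(f"cat vs invoice: {cosine_similarity(vec_cat, vec_invoice):.3f}")
```

 On real embedding tensors, `torch.nn.functional.cosine_similarity` computes the same quantity.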

 **SONAR Specifically:**
 - Trained on ~100B tokens with denoising and translation objectives
 - Uses a 24-layer Transformer encoder and a 24-layer Transformer decoder, with mean-pooling over the encoder outputs to create the bottleneck vector
 - Supports 200+ languages and can handle up to 512 tokens of context
 - At the time of writing, one of the best-performing text autoencoders available
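
 Mean-pooling is what turns a variable number of per-token vectors into one fixed-size vector. The following is a minimal sketch of the idea in plain Python, not SONAR's actual implementation. The helper `mean_pool` and the toy 3-dimensional token vectors are assumptions made for illustration, and a real batched implementation would also need to ignore padding positions.

```python
def mean_pool(token_vectors: list[list[float]]) -> list[float]:
    """Average per-token vectors into one vector of the same dimension."""
    if not token_vectors:
        raise ValueError("need at least one token vector")
    dim = len(token_vectors[0])
    if any(len(vec) != dim for vec in token_vectors):
        raise ValueError("all token vectors must have the same dimension")
    n_tokens = len(token_vectors)
    # Average each coordinate over all tokens
    return [sum(vec[i] for vec in token_vectors) / n_tokens for i in range(dim)]


# Toy encoder outputs: 2 tokens and 4 tokens, each token a 3-dimensional vector
short_seq = [[1.0, 0.0, 2.0], [3.0, 4.0, 0.0]]
long_seq = [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0], [6.0, 2.0, 2.0]]

pooled_short = mean_pool(short_seq)
pooled_long = mean_pool(long_seq)

# Both results have length 3, regardless of how many tokens went in
print(pooled_short, len(pooled_short))
print(pooled_long, len(pooled_long))
```

 The pooled vector has the same size whatever the sequence length, which is why a short sentence and a long paragraph both end up as one 1024-dimensional SONAR embedding.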


# %% [markdown]

 We start by loading the models.

# %%
print("Loading SONAR models...")
text2vec = TextToEmbeddingModelPipeline(
    encoder="text_sonar_basic_encoder",
    tokenizer="text_sonar_basic_encoder",
    device=DEVICE
)
vec2text = EmbeddingToTextModelPipeline(
    decoder="text_sonar_basic_decoder",
    tokenizer="text_sonar_basic_encoder",
    device=DEVICE
)
print("Models loaded successfully!")

# %% [markdown]

 ## Basic Usage - Encoding and Decoding

 Let's test basic encoding and decoding: encode two sentences, inspect the resulting embeddings, then decode them back to text.

# %%

# Simple example sentences
sentences = [
    'My name is SONAR.',
    'I can embed sentences into vectorial space.'
]

# Encode sentences to vectors
embeddings = text2vec.predict(sentences, source_lang="eng_Latn")
print(f"Embeddings shape: {embeddings.shape}")  # Should be [2, 1024]
print(f"Embedding dimension: {embeddings.shape[1]}")
print(f"L2 norm of embeddings: {torch.norm(embeddings, dim=1).tolist()}")

# Decode vectors back to text
reconstructed = vec2text.predict(embeddings, target_lang="eng_Latn", max_seq_len=512)
print("\nReconstruction quality:")
for orig, rec in zip(sentences, reconstructed):
    print(f"Original:      {orig}")
    print(f"Reconstructed: {rec}")
    print()

# %% [markdown]

 ## Exercise 1: Testing with Longer, More Realistic Text
 Let's test how well SONAR handles paragraph-length text.

 Write a function to reconstruct text from SONAR embeddings, and try testing with some longer text.

# %%
def reconstruct_text(texts: list[str]) -> list[str]:
    """Reconstruct texts by encoding them with SONAR and decoding the embeddings.

    Args:
        texts: List of English strings to embed and then reconstruct.

    Returns:
        List of reconstructed strings, in the same order as `texts`.
    """
    # [your implementation here]
    # Encode each text into one 1024-dimensional embedding
    embeddings = text2vec.predict(texts, source_lang="eng_Latn")
    # Decode the embeddings back into English text
    return vec2text.predict(embeddings, target_lang="eng_Latn", max_seq_len=512)

# Longer example paragraphs
paragraph1 = """SONAR is a model from August 2023, trained as a semantic text auto-encoder,
converting text into semantic embed vectors, which can later be decoded back into text.
Additionally, the model is trained such that the semantic embed vectors are to some degree
"universal" for different languages, and one can embed in French and decode in English."""

paragraph2 = """I tried it, and SONAR seems to work surprisingly well. For example, the above
paragraph and this paragraph, if each are encoded into two 1024 dimensional vectors
(one for each paragraph), the model returns the following decoded outputs."""

paragraph3 = """\
Your text here.
"""

# Test with paragraphs
long_texts = [paragraph1, paragraph2, paragraph3]
long_reconstructed = reconstruct_text(long_texts)

print("Paragraph reconstruction:")
for i, (orig, rec) in enumerate(zip(long_texts, long_reconstructed)):
    print(f"\n--- Paragraph {i+1} ---")
    print(f"Original ({len(orig)} chars):")
    print((orig[:100] + "...") if len(orig) > 100 else orig)
    print(f"\nReconstructed ({len(rec)} chars):")
    print((rec[:100] + "...") if len(rec) > 100 else rec)
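
# %% [markdown]

 Comparing truncated printouts by eye only goes so far. One possible way to put a number on reconstruction quality is a character-level similarity ratio from Python's standard `difflib` module. This is a rough sketch rather than an established metric for SONAR: the helper name `reconstruction_similarity` and the example strings are assumptions, and the ratio measures surface overlap, so a faithful paraphrase can still score low.

```python
from difflib import SequenceMatcher


def reconstruction_similarity(original: str, reconstructed: str) -> float:
    """Return a character-level similarity ratio in [0, 1]; 1.0 means identical."""
    # Collapse runs of whitespace so line breaks do not count as differences
    a = " ".join(original.split())
    b = " ".join(reconstructed.split())
    return SequenceMatcher(None, a, b).ratio()


example_original = "My name is SONAR."
example_reconstructed = "My name is Sonar."  # hypothetical decoder output
score = reconstruction_similarity(example_original, example_reconstructed)
print(f"Similarity: {score:.3f}")
```

 Applied to the exercise, this could be called on each pair from `zip(long_texts, long_reconstructed)` to see how the score changes with paragraph length.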