In [1]:
from transformers import BartTokenizer, BartForConditionalGeneration
from transformers.modeling_outputs import BaseModelOutput
from IPython.display import display, HTML
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

# BART backbone setup. Run either this cell or the FLAN-T5 cell: both assign
# the same globals (`mode`, `model_name`, `device`, `tokenizer`, `model`),
# and whichever cell runs last determines which model the helpers use.
mode = 'bart'
model_name = "facebook/bart-large"
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

tokenizer = BartTokenizer.from_pretrained(model_name)
model = BartForConditionalGeneration.from_pretrained(model_name)
model = model.to(device)  # place the weights on the selected device

In [None]:
from transformers import T5Tokenizer, T5ForConditionalGeneration
from transformers.modeling_outputs import BaseModelOutput
from IPython.display import display, HTML
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt

# FLAN-T5 backbone setup. Run either this cell or the BART cell: both assign
# the same globals (`mode`, `model_name`, `device`, `tokenizer`, `model`),
# and whichever cell runs last determines which model the helpers use.
mode = 't5'
model_name = "google/flan-t5-large"
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)
model = model.to(device)  # place the weights on the selected device

In [None]:
def add_noise_with_snr(encoder_output: torch.Tensor, target_snr_db: float) -> torch.Tensor:
    """
    Corrupt a tensor with additive white Gaussian noise at a requested SNR.

    The signal power is the mean of the squared entries of ``encoder_output``,
    taken over every element at once. The noise is scaled so that
    signal_power / noise_power == 10 ** (target_snr_db / 10).

    Args:
        encoder_output (torch.Tensor): Tensor to corrupt, e.g. an encoder's
            last_hidden_state.
        target_snr_db (float): Desired signal-to-noise ratio in decibels.

    Returns:
        torch.Tensor: A new tensor of the same shape: ``encoder_output`` plus
        noise drawn from the global torch RNG.
    """
    snr_ratio = 10 ** (target_snr_db / 10)  # dB -> linear power ratio
    power_of_signal = encoder_output.pow(2).mean()
    noise_std = torch.sqrt(power_of_signal / snr_ratio)
    return encoder_output + noise_std * torch.randn_like(encoder_output)

def generate_with_embeddings(
    input_text: str,
    encoder_outputs: torch.Tensor = None,
    mode: str = 't5',
    max_length: int = 500,
    min_length: int = 10,
    do_sample: bool = True,
    temperature: float = 0.2,
) -> (str, torch.Tensor, list):
    """
    Generates text and returns it together with per-step decoder embeddings.

    Uses the notebook-level globals ``tokenizer``, ``model`` and ``device``;
    ``mode`` must match the architecture currently held in ``model``.

    Args:
        input_text (str): Input text for the model. It is only tokenized and
            encoded when ``encoder_outputs`` is None.
        encoder_outputs (torch.Tensor or BaseModelOutput, optional): Custom
            encoder states to feed the decoder instead of encoding
            ``input_text``. May be a raw ``last_hidden_state`` tensor (such as
            the result of ``add_noise_with_snr``) or any object exposing
            ``.last_hidden_state``. A copy is used, so the caller's tensor is
            never modified.
        mode (str): The mode of the model, either 't5' or 'bart'. Only
            consulted when ``encoder_outputs`` is None.
        max_length (int): Forwarded to ``model.generate``.
        min_length (int): Forwarded to ``model.generate``.
        do_sample (bool): Forwarded to ``model.generate``. Sampling is random;
            seed torch for reproducible output.
        temperature (float): Forwarded to ``model.generate``.

    Returns:
        Tuple[str, torch.Tensor, list]:
            - the generated text with special tokens removed;
            - decoder embeddings: the last decoder layer's hidden state at the
              newest position of every generation step, one row per step
              (always 2-D, even for a single step);
            - the decoded string of every id in the generated sequence.
            NOTE(review): for encoder-decoder models the generated sequence
            normally starts with the decoder start token, so ``decoded_tokens``
            is usually one element longer than ``decoder_embeddings``. Align
            them (e.g. with ``align_tensors``) before pairing rows with tokens.

    Raises:
        ValueError: If ``encoder_outputs`` is None and ``mode`` is neither
            't5' nor 'bart'.
    """
    if encoder_outputs is None:
        if mode == 't5':
            encoder = model.encoder
        elif mode == 'bart':
            encoder = model.model.encoder
        else:
            raise ValueError("Mode must be 't5' or 'bart'")
        inputs = tokenizer(input_text, return_tensors="pt").to(device)
        # Inference only: don't keep an autograd graph for the encoder pass.
        with torch.no_grad():
            encoder_outputs = encoder(input_ids=inputs.input_ids)
    else:
        # Accept a bare tensor (as annotated) or a ModelOutput-like object.
        # Work on a copy placed on `device` so the caller's tensor is untouched.
        hidden_state = (
            encoder_outputs
            if isinstance(encoder_outputs, torch.Tensor)
            else encoder_outputs.last_hidden_state
        )
        encoder_outputs = BaseModelOutput(last_hidden_state=hidden_state.clone().to(device))

    outputs = model.generate(
        input_ids=None,  # the decoder is driven purely by encoder_outputs
        encoder_outputs=encoder_outputs,
        output_hidden_states=True,
        return_dict_in_generate=True,
        max_length=max_length,
        min_length=min_length,
        do_sample=do_sample,
        temperature=temperature,
    )

    generated_ids = outputs.sequences
    generated_text = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

    # decoder_hidden_states holds one entry per generation step, each a tuple of
    # per-layer states; [-1] selects the last layer. Assumes each state is
    # (batch, seq, hidden). Taking position -1 works with or without the KV
    # cache. Indexing batch 0 (instead of squeeze()) keeps the result 2-D even
    # when only one step was generated.
    last_layer_states = [step_states[-1][:, -1, :] for step_states in outputs.decoder_hidden_states]
    decoder_embeddings = torch.stack(last_layer_states, dim=0)[:, 0, :]

    # Decode each generated id on its own to get a per-token string.
    decoded_tokens = [tokenizer.decode([token_id]) for token_id in generated_ids[0].tolist()]

    return generated_text, decoder_embeddings, decoded_tokens

def align_tensors(
    tensor_a: torch.Tensor,
    tensor_b: torch.Tensor,
    mode: str = "truncate",
    pad_value: float = 0,
) -> (torch.Tensor, torch.Tensor):
    """
    Aligns two tensors along the first dimension by either truncating or padding.

    Args:
        tensor_a (torch.Tensor): The first tensor (at least 1-D).
        tensor_b (torch.Tensor): The second tensor (at least 1-D). Its trailing
            dimensions need not match tensor_a's.
        mode (str): "truncate" cuts both tensors to the shorter length; "pad"
            extends the shorter one with ``pad_value`` rows up to the longer
            length.
        pad_value (float): Fill value for padding rows in "pad" mode. Defaults
            to 0, matching the previous behavior.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The two tensors aligned along the
        first dimension. Padding rows are created with the dtype and device of
        the tensor they extend, so padding never changes a tensor's dtype. A
        tensor that already has the target length is returned unchanged.

    Raises:
        ValueError: If ``mode`` is neither "truncate" nor "pad".
    """
    m, n = tensor_a.size(0), tensor_b.size(0)

    if mode == "truncate":
        min_rows = min(m, n)
        return tensor_a[:min_rows], tensor_b[:min_rows]

    if mode == "pad":
        max_rows = max(m, n)

        def _pad_rows(tensor: torch.Tensor) -> torch.Tensor:
            # Append (max_rows - len) rows of pad_value in tensor's own dtype/device.
            missing = max_rows - tensor.size(0)
            if missing == 0:
                return tensor
            filler = tensor.new_full((missing, *tensor.shape[1:]), pad_value)
            return torch.cat([tensor, filler], dim=0)

        return _pad_rows(tensor_a), _pad_rows(tensor_b)

    raise ValueError("Mode must be 'truncate' or 'pad'")