In [6]:
import torch
import torch.nn.functional as F

def attention_mechanism(queries, keys, values):
    """Compute cosine-similarity attention of queries over key/value pairs.

    Every query is compared with every key using cosine similarity, the
    similarities are normalized with a softmax across the keys, and the
    resulting weights are used to form a weighted sum of the value vectors.

    Args:
        queries: Tensor of shape (num_queries, dim).
        keys: Tensor of shape (num_keys, dim).
        values: Tensor of shape (num_keys, value_dim).

    Returns:
        A tuple ``(output, attention_weights)``. ``output`` has shape
        (num_queries, value_dim); ``attention_weights`` has shape
        (num_queries, num_keys) and each of its rows sums to 1.
    """
    # Broadcast (num_queries, 1, dim) against (1, num_keys, dim) so that every
    # query/key pair is compared; reducing over dim=2 leaves one score per pair.
    sim_scores = F.cosine_similarity(queries.unsqueeze(1), keys.unsqueeze(0), dim=2)

    # Normalize across the keys so each query's weights form a distribution.
    attention_weights = F.softmax(sim_scores, dim=1)

    # A plain matrix product, (num_queries, num_keys) @ (num_keys, value_dim),
    # handles any number of queries. torch.bmm does not broadcast its batch
    # dimension, so a batched formulation fails for more than one query.
    output = torch.matmul(attention_weights, values)
    return output, attention_weights

# Toy word representations (simplified example).
# A real application would use pretrained embeddings (Word2Vec, GloVe, etc.).
word_vectors = {
    word: torch.tensor(components)
    for word, components in (
        ("happy", [1.0, 0.0, 0.0]),
        ("joyful", [1.0, 0.0, 0.1]),
        ("sad", [0.0, 1.0, 0.0]),
        ("unhappy", [0.0, 1.0, 0.1]),
    )
}

# Query: the single word "happy", given a leading axis so it forms a 1 x 3 matrix.
queries = word_vectors["happy"][None, :]

# Keys and values: the same three candidate words, one row per word.
keys = torch.stack(
    [word_vectors[word] for word in ("joyful", "sad", "unhappy")]
)
values = torch.stack(
    [word_vectors[word] for word in ("joyful", "sad", "unhappy")]
)

# Run the attention step to get the weighted sum and the per-key weights.
output, attention_weights = attention_mechanism(queries, keys, values)

# Show the results.
print("Output (Weighted Sum):", output)
print("Attention Weights:\n", attention_weights)


Output (Weighted Sum): tensor([[0.5749, 0.4251, 0.0787]])
Attention Weights:
 tensor([[0.5749, 0.2125, 0.2125]])
