In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import sentencepiece as spm
from datasets import load_dataset


#dataset = load_dataset("ade_corpus_v2", "Ade_corpus_v2_classification")
# Extract text from the dataset
#text_data = dataset["train"]["text"]

# Define the parameters for training
#spm.SentencePieceTrainer.Train(
    #input=text_data,  # Path to your training data
    #model_prefix='your_model',  # Prefix for model files
    #vocab_size=8000,  # Adjust based on your needs
    #character_coverage=0.98,  # Adjust based on your needs
    #model_type='unigram',  # Model type (unigram, bpe, char, or word)
#)
# Define a function to load a SentencePiece model and encode text
#def load_sentencepiece_model(model_file):
    #sp = spm.SentencePieceProcessor()
    #sp.load(model_file)
    #return sp


# Sample data: Triplets (query, positive_doc, negative_doc)
sample_data = [
    ("Tell me about the Eiffel Tower.", "The Eiffel Tower is an iconic landmark in Paris.", "Apples are a type of fruit."),
    ("What is the capital of France?", "Paris is the capital of France.", "Bananas are a yellow fruit."),
    ("Explain the theory of relativity.", "The theory of relativity was developed by Albert Einstein.", "Dogs are loyal animals."),
    ("Tell me about the Mona Lisa.", "The Mona Lisa is a famous painting by Leonardo da Vinci.", "Cats are independent animals."),
    ("What is the largest planet in our solar system?", "Jupiter is the largest planet in our solar system.", "Elephants are the largest land mammals."),
    ("Who wrote the play 'Romeo and Juliet'?", "William Shakespeare wrote the play 'Romeo and Juliet'.", "Roses are a type of flower."),
    ("What is the capital of Japan?", "Tokyo is the capital of Japan.", "Fish are cold-blooded animals."),
    ("Explain the concept of photosynthesis.", "Photosynthesis is the process by which plants convert sunlight into energy.", "Owls are nocturnal birds."),
    ("Tell me about the Great Wall of China.", "The Great Wall of China is a historic fortification in China.", "Kangaroos are marsupials."),
    ("Who is the current President of the United States?", "Joe Biden is the current President of the United States.", "Tigers are large cats.")
]

# Define the queries, positive_docs, and negative_docs lists
queries = [item[0] for item in sample_data]
positive_docs = [item[1] for item in sample_data]
negative_docs = [item[2] for item in sample_data]

# Load a pre-trained SentencePiece model and tokenize the data
#sp_model = load_sentencepiece_model("your_model.model")
#queries = [sp_model.encode_as_ids(query) for query in queries]
#positive_docs = [sp_model.encode_as_ids(doc) for doc in positive_docs]
#negative_docs = [sp_model.encode_as_ids(doc) for doc in negative_docs]

vocab = {"<UNK>", "<PAD>"}

# Populate the vocab from the data
for text in positive_docs + queries + negative_docs:
    words = text.split()
    vocab.update(words)

# Create a word-to-index mapping
word_to_idx = {word: idx for idx, word in enumerate(vocab)}

# Helper function to convert text to sequence of word indices
def text_to_sequence(text):
    return [word_to_idx.get(word, word_to_idx["<UNK>"]) for word in text.split()]

# Convert paragraphs and queries into sequences of word indices
p_paragraph_sequences = [text_to_sequence(paragraph) for paragraph in positive_docs]
n_paragraph_sequences = [text_to_sequence(paragraph) for paragraph in negative_docs]
query_sequences = [text_to_sequence(query) for query in queries]


import torch.nn.functional as F

# Define the Three-Tower LSTM model
class SiameseLSTM(nn.Module):
    def __init__(self, vocab_size, embedding_dim, hidden_dim, num_layers):
        super(SiameseLSTM, self).__init__() 

        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.query_lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers, batch_first=True, bidirectional=True)
        self.positive_lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers, batch_first=True, bidirectional=True)
        self.negative_lstm = nn.LSTM(embedding_dim, hidden_dim, num_layers, batch_first=True, bidirectional=True)
        self.cosine_similarity = nn.CosineSimilarity(dim=1, eps=1e-6)

    def forward(self, query, positive_doc, negative_doc):
        query_embedding = self.embedding(query)
        positive_embedding = self.embedding(positive_doc)
        negative_embedding = self.embedding(negative_doc)

        query_embedding, _= self.query_lstm(query_embedding)
        positive_embedding, _ = self.positive_lstm(positive_embedding)
        negative_embedding, _ = self.negative_lstm(negative_embedding)
        
        positive_similarity = self.cosine_similarity(query_embedding, positive_embedding)
        negative_similarity = self.cosine_similarity(query_embedding, negative_embedding)


        return positive_similarity, negative_similarity
    
            
    def contrastive_loss(self, positive_similarity, negative_similarity, margin=0.5):
        # The goal is to make positive similarity larger and negative similarity smaller
        loss = F.relu(margin - positive_similarity + negative_similarity).mean()
        return loss

        



In [2]:
# Find the maximum sequence length across all documents and queries
max_seq_length = max(max(len(seq) for seq in p_paragraph_sequences + n_paragraph_sequences), max(len(seq) for seq in query_sequences))

# Pad all sequences to this maximum length
def pad_sequence(seq, max_length):
    return seq + [0] * (max_length - len(seq))

padded_query_sequences = [pad_sequence(seq, max_seq_length) for seq in query_sequences]
padded_p_paragraph_sequences = [pad_sequence(seq, max_seq_length) for seq in p_paragraph_sequences]
padded_n_paragraph_sequences = [pad_sequence(seq, max_seq_length) for seq in n_paragraph_sequences]


In [3]:
max_seq_length

11

In [4]:
vocab_size = len(vocab) 
embedding_dim = 50
hidden_dim = 64
num_layers = 2
three_tower_model = SiameseLSTM(vocab_size, embedding_dim, hidden_dim, num_layers)

# Define an optimizer
optimizer = optim.SGD(three_tower_model.parameters(), lr=0.01)

# Training loop (you'll need to define your training data and loops)
batch_size = 3
total_loss = 0.0
total_iterations = len(query_sequences)
num_epochs = 5
for epoch in range(num_epochs):
    for i in range(len(query_sequences)):
        query = torch.tensor(padded_query_sequences[i]).unsqueeze(0)
        positive_doc = torch.tensor(padded_p_paragraph_sequences[i]).unsqueeze(0)
        negative_doc = torch.tensor(padded_n_paragraph_sequences[i]).unsqueeze(0)

        positive_similarity, negative_similarity = three_tower_model(query, positive_doc, negative_doc)
        loss = three_tower_model.contrastive_loss(positive_similarity, negative_similarity)
        
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Optionally, you can print the mean loss for the current epoch
    mean_loss = total_loss / total_iterations
    print(f"Epoch {epoch + 1}/{num_epochs}, Mean Loss: {mean_loss}")
    total_loss = 0.0
# Reset total_loss for the next epoch


Epoch 1/5, Mean Loss: 0.5771744608879089
Epoch 2/5, Mean Loss: 0.4293012320995331
Epoch 3/5, Mean Loss: 0.3462718665599823
Epoch 4/5, Mean Loss: 0.2857722580432892
Epoch 5/5, Mean Loss: 0.2384976327419281


In [5]:
# Sample data: Three queries and three sets of relevant documents
query = "What's the theory of relativity?"

relevant_docs = [
    "Relativity was by Einstein.",
    "Theory explains spacetime.",
    "Albert discussed relativity.",
    "Physics changed with relativity.",
    "Einstein's groundbreaking theory."
]



query_sequence = text_to_sequence(query)
relevant_doc_sequences = [text_to_sequence(doc) for doc in relevant_docs]



# Pad the sequences
padded_query_sequence = pad_sequence(query_sequence, max_seq_length)
padded_relevant_doc_sequences = [pad_sequence(seq, max_seq_length) for seq in relevant_doc_sequences]

# Check lengths of sequences
for seq in padded_relevant_doc_sequences:
    print(len(seq))


11
11
11
11
11


In [14]:
# Sample data: Three queries and three sets of relevant documents
query = "What's the theory of relativity?"

relevant_docs = [
    "Relativity was by Einstein.",
    "Theory explains spacetime.",
    "Albert discussed relativity."
]


query_sequence = text_to_sequence(query)
relevant_doc_sequences = [text_to_sequence(doc) for doc in relevant_docs]



# Pad the sequences
padded_query_sequence = pad_sequence(query_sequence, max_seq_length)
padded_relevant_doc_sequences = [pad_sequence(seq, max_seq_length) for seq in relevant_doc_sequences]

with torch.no_grad():
    similarity_scores = [] 
    for doc_tensor in padded_relevant_doc_sequences:
        negative_doc = torch.tensor(padded_n_paragraph_sequences[1]).unsqueeze(0)
        query_tensor = torch.tensor(padded_query_sequence).unsqueeze(0)
        positive_tensor = torch.tensor(doc_tensor).unsqueeze(0)
        
        # Use your model to get similarity scores
        positive_similarity,_ = three_tower_model(query_tensor, positive_tensor, negative_doc)
        
        
        
        similarity_scores.append(positive_similarity.mean().item())
        
        
ranked_docs = sorted(zip(similarity_scores, relevant_docs), key=lambda x: x[0], reverse=True)

# Print the ranked documents
for rank, (score,doc) in enumerate(ranked_docs, start=1):
    print(f"Rank {rank}: Document={doc} Similarity Score={score}")


Rank 1: Document=Albert discussed relativity. Similarity Score=0.269143670797348
Rank 2: Document=Theory explains spacetime. Similarity Score=0.24998407065868378
Rank 3: Document=Relativity was by Einstein. Similarity Score=0.24208283424377441
