<a href="https://colab.research.google.com/github/jagonmoy/Creative-Requirement-Generation-with-RAG-and-MDLM/blob/main/RE_RAG_LLM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -q transformers sentence-transformers faiss-cpu

In [None]:
# import libraries
from transformers import BertTokenizer, BertForMaskedLM
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import torch

In [None]:
# RAG (retrieval-augmented generation) -- retrieval side

# Sentence-embedding model used to turn every requirement (and, later, every
# query) into a dense vector so that similar requirements can be looked up.
retriever = SentenceTransformer('all-MiniLM-L6-v2')


# Sample requirement corpus. It is intentionally tiny because this is a
# prototype; a real-world corpus would be far larger.
docs = [
    "Users can pause subscriptions anytime.",
    "Subscribers should be able to gift their plan to friends.",
    "Users can manage billing preferences via the account dashboard.",
    "The app allows subscription rescheduling and reminders.",
    "Users can view their usage statistics over time.",
    "The system should allow exporting invoices monthly.",
    "Users can set custom payment thresholds.",
]


# Embed the whole corpus in one batch call.
doc_embeddings = retriever.encode(docs)
embedding_dim = doc_embeddings.shape[1]

# Exact (brute-force) FAISS index that ranks neighbours by L2 (Euclidean)
# distance between embedding vectors.
index = faiss.IndexFlatL2(embedding_dim)

# Load the corpus vectors into the index; after this, nearest-neighbour
# searches can be run for any new query vector.
index.add(np.array(doc_embeddings))

def retrieve_context(query, top_k=2):
    """Return the corpus requirements most similar to *query*.

    Args:
        query: Requirement text to find neighbours for.
        top_k: Maximum number of requirements to return. Values larger than
            the number of indexed requirements are clamped; values <= 0
            yield an empty list.

    Returns:
        A list of up to ``top_k`` strings from ``docs``, ordered from the
        nearest (smallest L2 distance in embedding space) to the farthest.
    """
    # Never ask FAISS for more neighbours than the index holds: it pads the
    # missing results with -1, and docs[-1] would silently return the last
    # requirement instead of signalling "no result".
    k = min(top_k, index.ntotal)
    if k <= 0:
        return []

    # Encode the query with the same model used for the corpus so the two
    # sets of vectors are comparable.
    query_vec = retriever.encode([query])

    # FAISS returns (distances, indices), one row per query vector.
    _, indices = index.search(np.array(query_vec), k)

    # Defensively drop any -1 padding entries.
    return [docs[i] for i in indices[0] if i >= 0]

In [None]:
# RAG with FLAN-T5
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM


# Hugging Face checkpoint for the instruction-tuned T5 generator.
FLAN_T5_CHECKPOINT = "google/flan-t5-base"

# Load the generative (sequence-to-sequence) language model and its tokenizer.
tokenizer = AutoTokenizer.from_pretrained(FLAN_T5_CHECKPOINT)
generator = AutoModelForSeq2SeqLM.from_pretrained(FLAN_T5_CHECKPOINT)
# Put the model in inference mode: eval() disables dropout layers. It does not
# make the model smaller or faster. It also does not remove sampling
# randomness: generation with do_sample=True still varies between calls.
generator.eval()

def generate_creatives_using_FLAN_T5(base_req, n_variants=3):
    """Generate creative alternatives to *base_req* with FLAN-T5 plus RAG context.

    Requirements similar to ``base_req`` are retrieved from the corpus and
    placed in the prompt as examples. ``n_variants`` completions are then
    sampled from the model.

    Args:
        base_req: Requirement text to build alternatives for.
        n_variants: Number of sampled alternatives to return.

    Returns:
        A list of ``n_variants`` decoded strings with special tokens removed.
    """
    # Get the requirements most similar to the base one (RAG retrieval). Join
    # them with a space, as the BERT variant does, so the sentences are not
    # glued together as "...anytime.,Subscribers...".
    context = retrieve_context(base_req)
    context_str = " ".join(context)

    # Construct the full prompt from the base requirement and the examples.
    prompt = f"""Given the requirement: "{base_req}" And examples: "{context_str}" Generate a creative and useful alternative requirement."""

    # Convert the prompt to PyTorch tensors; truncate safely if it exceeds
    # the tokenizer's maximum length.
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    output = generator.generate(
        **inputs,
        max_new_tokens=50,
        num_return_sequences=n_variants,
        do_sample=True,  # sample instead of greedy decoding, for variety
        top_k=50,  # when sampling, only consider the 50 most likely next tokens
        temperature=0.9,  # < 1 sharpens the distribution slightly; higher = bolder
    )
    # Decode each generated sequence, dropping special tokens (pad/eos).
    return [tokenizer.decode(o, skip_special_tokens=True) for o in output]

In [None]:
# Demo: replace `base` with any requirement you want alternatives for.
base = "Users should be able to manage their subscriptions."

results = generate_creatives_using_FLAN_T5(base)
print("Creative Alternatives:\n")
for alternative in results:
    print(f"- {alternative}")

In [None]:
# RAG with BERT
import spacy
from transformers import BertTokenizer, BertForMaskedLM
import random

# Small English spaCy pipeline; its part-of-speech tags are used to pick
# which word of a requirement to mask.
nlp = spacy.load("en_core_web_sm")

# Uncased BERT masked-language model and its matching WordPiece tokenizer.
# eval() switches off dropout so inference is not affected by it.
BERT_CHECKPOINT = "bert-base-uncased"
bert_tokenizer = BertTokenizer.from_pretrained(BERT_CHECKPOINT)
bert_model = BertForMaskedLM.from_pretrained(BERT_CHECKPOINT)
bert_model.eval()

def mask_important_word_randomized(text):
    """Replace one randomly chosen verb or noun in *text* with ``[MASK]``.

    Args:
        text: Sentence to mask.

    Returns:
        The sentence with exactly one VERB/NOUN token replaced by ``[MASK]``,
        keeping the surrounding punctuation and whitespace; or ``None`` when
        the sentence contains no verb or noun.
    """
    doc = nlp(text)
    # spaCy token indices of every verb and noun.
    candidates = [token.i for token in doc if token.pos_ in ("VERB", "NOUN")]

    if not candidates:
        return None  # fallback if nothing to mask

    index_to_mask = random.choice(candidates)

    # Rebuild the sentence from spaCy's own tokens rather than text.split().
    # spaCy splits punctuation and contractions into separate tokens, so
    # token.i does not line up with whitespace-split words. Indexing
    # text.split() would mask "app." instead of "app", or hit the wrong word.
    # text_with_ws / whitespace_ restore the original spacing exactly.
    pieces = []
    for token in doc:
        if token.i == index_to_mask:
            pieces.append("[MASK]" + token.whitespace_)
        else:
            pieces.append(token.text_with_ws)
    return "".join(pieces)


def generate_creatives_using_BERT_SMART_MASK(base_req, n_variants=3):
    """Create variants of *base_req* by masking one word and letting BERT refill it.

    A verb or noun of ``base_req`` is masked. BERT predicts replacements,
    conditioned on the most similar corpus requirements as context.

    Args:
        base_req: Requirement text to vary.
        n_variants: Maximum number of variants to return.

    Returns:
        Up to ``n_variants`` variant sentences, ordered by BERT's ranking.
        Predictions that are subword pieces, non-alphabetic, duplicates or
        identical to ``base_req`` are skipped. Returns ``[base_req]`` when
        nothing can be masked or no usable prediction is found, and ``[]``
        when ``n_variants <= 0``.
    """
    if n_variants <= 0:
        return []

    # Retrieve context: the top 2 most similar requirements.
    context = retrieve_context(base_req)
    context_str = " ".join(context[:2])

    # Smart-mask a key verb or noun in base_req.
    masked_req = mask_important_word_randomized(base_req)
    if not masked_req:
        return [base_req]  # Nothing to mask
    print(masked_req)

    # Encode as a sentence pair: [CLS] context [SEP] masked_req [SEP].
    # truncation="only_first" trims only the context segment. The [MASK] in
    # the second segment can therefore never be cut off, as it could when
    # the whole concatenated prompt was truncated from the right.
    inputs = bert_tokenizer(
        context_str, masked_req, return_tensors="pt", truncation="only_first"
    )
    mask_positions = torch.where(
        inputs["input_ids"][0] == bert_tokenizer.mask_token_id
    )[0]
    if mask_positions.numel() == 0:
        return [base_req]

    with torch.no_grad():
        logits = bert_model(**inputs).logits
    mask_logits = logits[0, mask_positions[0], :]  # scores over the vocabulary

    # Rank a wider pool than requested so that unusable predictions can be
    # skipped while still returning n_variants results when possible.
    pool_size = min(mask_logits.shape[-1], max(50, n_variants * 10))
    ranked_ids = torch.topk(mask_logits, pool_size).indices.tolist()

    variants = []
    for token_id in ranked_ids:
        if len(variants) >= n_variants:
            break
        predicted_word = bert_tokenizer.decode([token_id]).strip()
        # Skip WordPiece continuations ("##ing"), punctuation and special
        # tokens: none of them is a standalone alphabetic word.
        if not predicted_word.isalpha():
            continue
        filled = masked_req.replace("[MASK]", predicted_word)
        # The uncased model predicts lowercase, so compare case-insensitively
        # to drop predictions that just restore the original word.
        if filled.lower() == base_req.lower() or filled in variants:
            continue
        variants.append(filled)

    return variants if variants else [base_req]

In [None]:
# Demo: generate BERT mask-fill variants for a sample requirement.
base_req = "Users can schedule appointments through the app."
results = generate_creatives_using_BERT_SMART_MASK(base_req)
for rank, variant in enumerate(results, start=1):
    print(f"{rank}. {variant}")