<a href="https://colab.research.google.com/github/jagonmoy/Creative-Requirement-Generation-with-RAG-and-MDLM/blob/main/RE_RAG_MDLM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -q transformers sentence-transformers faiss-cpu

In [None]:
# import libraries
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import torch

In [None]:
# Sentence-embedding model used to vectorize both the requirement corpus and
# incoming queries for nearest-neighbour retrieval.
RETRIEVER_CHECKPOINT = 'all-MiniLM-L6-v2'
retriever = SentenceTransformer(RETRIEVER_CHECKPOINT)

In [None]:
# Seq2seq text generator: FLAN-T5 (base) plus its matching tokenizer.
# Both are loaded from the same checkpoint name so they always agree.
GENERATOR_CHECKPOINT = "google/flan-t5-base"
tokenizer = AutoTokenizer.from_pretrained(GENERATOR_CHECKPOINT)
generator = AutoModelForSeq2SeqLM.from_pretrained(GENERATOR_CHECKPOINT)
# Switch to inference mode; this notebook only generates, it never trains.
generator.eval()

In [None]:
# Sample requirement corpus (expand this as needed).
# Each string is one example requirement. The whole list is embedded into the
# FAISS index below, and the entries closest to a query are pasted into the
# generation prompt as examples — so adding entries broadens what retrieval can
# surface.
docs = [
    "Users can pause subscriptions anytime.",
    "Subscribers should be able to gift their plan to friends.",
    "Users can manage billing preferences via the account dashboard.",
    "The app allows subscription rescheduling and reminders.",
    "Users can view their usage statistics over time.",
    "The system should allow exporting invoices monthly.",
    "Users can set custom payment thresholds.",
]

In [None]:
# Embed every corpus entry and store the vectors in a flat L2 FAISS index,
# which `retrieve_context` searches for the requirements nearest a query.
doc_embeddings = retriever.encode(docs)
embedding_dim = doc_embeddings.shape[1]
index = faiss.IndexFlatL2(embedding_dim)
corpus_matrix = np.array(doc_embeddings)
index.add(corpus_matrix)

In [None]:
def retrieve_context(query, top_k=2):
    """Return up to ``top_k`` corpus entries nearest to ``query``.

    The query is embedded with ``retriever`` and looked up in the module-level
    FAISS ``index``; hits are mapped back to their text in ``docs``.

    Args:
        query: Free-text requirement to find neighbours for.
        top_k: Maximum number of entries to return. It is clamped to the number
            of vectors stored in ``index``; values <= 0 yield ``[]``.

    Returns:
        A list of corpus strings, closest first, with no padding entries.
    """
    # FAISS pads the result with label -1 when k exceeds the number of stored
    # vectors, and docs[-1] would then silently repeat the last document.
    # Clamp k up front (and skip the search entirely for k <= 0).
    k = min(top_k, index.ntotal)
    if k <= 0:
        return []

    # FAISS expects a float32 (n_queries, dim) matrix.
    query_vec = np.asarray(retriever.encode([query]), dtype="float32")
    _, indices = index.search(query_vec, k)
    # Defensive: drop any remaining -1 "no result" labels.
    return [docs[i] for i in indices[0] if i >= 0]

In [None]:
def generate_creatives(base_req, n_variants=3, n_context=2, max_new_tokens=50,
                       top_k=50, temperature=0.9, seed=None):
    """Sample creative alternative requirements for ``base_req`` (RAG + FLAN-T5).

    The ``n_context`` most similar corpus entries from ``retrieve_context`` are
    joined into the prompt as examples. ``generator`` then samples
    ``n_variants`` sequences, which are decoded with ``tokenizer``.

    Args:
        base_req: Requirement text to generate alternatives for.
        n_variants: Number of sampled alternatives to return. Values <= 0
            return ``[]`` without calling the model.
        n_context: Number of retrieved corpus entries to include as examples.
        max_new_tokens: Cap on generated tokens per alternative.
        top_k: Top-k sampling cutoff for ``generate``. It is unrelated to the
            retrieval ``top_k`` of ``retrieve_context``.
        temperature: Sampling temperature for ``generate``.
        seed: If not None, ``torch.manual_seed(seed)`` is called before sampling
            so results repeat across runs on the same setup. Note that this
            reseeds torch's global RNG.

    Returns:
        A list of decoded strings, one per returned sequence.
    """
    if n_variants <= 0:
        return []

    context = retrieve_context(base_req, top_k=n_context)
    context_str = " | ".join(context)

    prompt = f"""Given the requirement: "{base_req}" And examples: "{context_str}" Generate a creative and useful alternative requirement."""

    if seed is not None:
        torch.manual_seed(seed)

    # Put the input tensors on the model's device so this still works after
    # e.g. generator.to("cuda").
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True).to(generator.device)
    output = generator.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        num_return_sequences=n_variants,
        do_sample=True,
        top_k=top_k,
        temperature=temperature,
    )

    return tokenizer.batch_decode(output, skip_special_tokens=True)

In [None]:
# Try your own requirement.
# Generation samples randomly, so seed torch's RNG first. The printed
# alternatives then stay the same on Restart & Run All (on the same
# hardware/library versions). Change or remove the seed to explore other outputs.
torch.manual_seed(42)
base = "Users should be able to manage their subscriptions."

results = generate_creatives(base)
print("Creative Alternatives:\n")
for res in results:
    print("-", res)