# Integrated RAG-based Story Generation with ALLaM (on IBM watsonx.ai) and ChromaDB

This notebook combines data embedding, ChromaDB storage, and the ALLaM model (served through IBM watsonx.ai) for retrieval-augmented generation (RAG). It is organized as follows:
- Embedding children's stories and storing them in ChromaDB for efficient retrieval.
- Using the ALLaM model to generate stories, either enhanced by relevant retrieved story chunks (RAG) or through prompting alone.

## Steps
1. Install libraries.
2. Initialize Firebase and ChromaDB.
3. Retrieve stories from Firebase, chunk and embed them, and store them in ChromaDB.
4. Set up the ALLaM model on IBM watsonx.ai.
5. Generate a story with RAG, or with prompt engineering alone.
6. Test the generation functions.

### Step 1: Install Required Libraries

In [None]:
# Install libraries for ChromaDB, Firebase, IBM's ALLAM, and LangChain
# The leading `!` runs this as a shell command from a Jupyter/IPython cell;
# it only needs to run once per environment.
!pip install langchain chromadb requests ibm-watsonx-ai firebase-admin

In [None]:
# Import necessary libraries
from langchain import PromptTemplate
from ibm_watsonx_ai.foundation_models import ModelInference
from chromadb import Client
from firebase_admin import credentials, firestore
from chromadb.config import Settings

import os
import json
import firebase_admin
import getpass

### Step 2: Initialize Firebase and ChromaDB

In [None]:


# Initialize Firebase.
# The service-account key path can be supplied through the FIREBASE_CREDENTIALS
# environment variable; the fallback below is a placeholder to adjust.
FIREBASE_CREDENTIALS_PATH = os.environ.get(
    'FIREBASE_CREDENTIALS', 'path/to/firebase_credentials.json'
)

# initialize_app() raises ValueError when the default app already exists,
# which happens every time this notebook cell is re-run. Reuse the existing
# app in that case instead of failing.
try:
    firebase_app = firebase_admin.get_app()
except ValueError:
    cred = credentials.Certificate(FIREBASE_CREDENTIALS_PATH)
    firebase_app = firebase_admin.initialize_app(cred)
db = firestore.client(firebase_app)

# Initialize ChromaDB client.
# NOTE(review): Client(Settings()) with default settings is presumably an
# in-memory client, so stored chunks would not survive a kernel restart —
# confirm against the installed chromadb version.
chroma_client = Client(Settings())
collection = chroma_client.get_or_create_collection(name='story_chunks')


### Step 3: Retrieve stories and store them in ChromaDB
- Retrieve stories from Firebase.
- Process, chunk, and embed each story segment.
- Store the chunks along with embeddings and metadata in ChromaDB.

In [None]:
def retrieve_and_store_stories(embed_fn=None):
    """Load stories from Firestore, split them into tagged chunks and store them in ChromaDB.

    Every document in the Firestore ``stories`` collection is read. Its ``data``
    field is a list of ``{'prompt': ..., 'completion': ...}`` items:

    * the item whose prompt is ``'القصة:'`` ("The story:") holds the story text;
    * every other item is copied into the chunk metadata, keyed by its prompt.

    The story text is split into chunks matching ``<tag>...</tag>``. Each chunk is
    stored with the story's metadata plus ``story_id`` (1-based stream position)
    and ``chunk_id``.

    Args:
        embed_fn: Optional callable mapping a chunk string to an embedding vector.
            When omitted, no embeddings are passed and ChromaDB computes them
            with the collection's embedding function, so text queries against the
            collection use the same embedding space.

    Returns:
        None. Prints a status message.
    """
    import re  # used for chunking; not imported at the top of the notebook

    chunk_pattern = re.compile(r'(<[^>]+>.*?</[^>]+>)', flags=re.DOTALL)

    # Keep the per-chunk metadata list separate from each story's metadata dict
    # (previously one name was reused for both, which broke .append()).
    documents, embeddings, metadatas, ids = [], [], [], []

    for idx, story_doc in enumerate(db.collection('stories').stream(), start=1):
        story_data = (story_doc.to_dict() or {}).get('data', [])
        story_metadata = {}
        story_text = ''

        for item in story_data:
            if item.get('prompt') == 'القصة:':
                story_text = item.get('completion', '')
            else:
                # NOTE(review): ChromaDB metadata values must be scalars
                # (str/int/float/bool); assumes completions are strings — confirm.
                story_metadata[item.get('prompt', '')] = item.get('completion', '')

        if not story_text:
            continue

        for chunk_idx, chunk in enumerate(chunk_pattern.findall(story_text)):
            chunk_id = f'id_{idx}_{chunk_idx}'
            documents.append(chunk)
            ids.append(chunk_id)
            metadatas.append({**story_metadata, 'story_id': idx, 'chunk_id': chunk_id})
            if embed_fn is not None:
                embeddings.append(embed_fn(chunk))

    # ChromaDB rejects an empty batch; report instead of raising.
    if not ids:
        print('No story chunks found; nothing stored in ChromaDB.')
        return

    # upsert (instead of add) lets this cell be re-run: ids are derived from the
    # stream position, so a re-run over the same stories yields the same ids.
    collection.upsert(
        documents=documents,
        embeddings=embeddings if embed_fn is not None else None,
        metadatas=metadatas,
        ids=ids,
    )
    print('Data stored in ChromaDB.')


### Step 4: Initialize the ALLaM Model (IBM watsonx.ai)

In [None]:

def get_credentials():
    """Return the watsonx.ai connection settings passed to ``ModelInference``.

    The API key is taken from the ``WATSONX_APIKEY`` environment variable when it
    is set (and non-empty); otherwise the user is prompted for it with
    ``getpass`` so the key never has to be written into the notebook.
    ``WATSONX_URL`` optionally overrides the default eu-de endpoint.

    Returns:
        dict: ``{'url': <endpoint URL>, 'apikey': <API key>}``.
    """
    # NOTE(review): the previous getpass() prompt text looked like a real API
    # key. If it was one, it is exposed in this notebook's history and should be
    # revoked/rotated.
    api_key = os.environ.get('WATSONX_APIKEY') or getpass.getpass(
        'Enter your IBM watsonx.ai API key: '
    )
    return {
        'url': os.environ.get('WATSONX_URL', 'https://eu-de.ml.cloud.ibm.com'),
        'apikey': api_key,
    }

# Define model parameters
model_id = "sdaia/allam-1-13b-instruct"  # ALLaM 13B instruct model ID on watsonx.ai
# The watsonx.ai project can be overridden via the WATSONX_PROJECT_ID env var.
project_id = os.environ.get("WATSONX_PROJECT_ID", "07954e26-b0dd-45e4-a22f-0ae8e0a8593a")
# Greedy (deterministic) decoding with a mild penalty against repeated text.
parameters = {'decoding_method': 'greedy', 'max_new_tokens': 3024, 'repetition_penalty': 1.2}

# Instantiate the ALLAM model
model = ModelInference(
    model_id=model_id,
    params=parameters,
    credentials=get_credentials(),
    project_id=project_id
)
# Also expose the same instance as `allam_model`, since generation code may
# refer to the model by that name.
allam_model = model
print('ALLAM model initialized.')

### Step 5: RAG-based Story Generation

In [None]:
def generate_story_with_retrieval(characterType, characterCount, characters, storyLocation, storyMoral, otherThings, n_results=3, embed_fn=None):
    """
    Retrieve relevant story chunks from ChromaDB, format a structured prompt,
    and generate a story using the ALLaM model.

    Args:
        characterType, characterCount, characters, storyLocation, storyMoral,
        otherThings: Story details supplied by the user; they are interpolated
            verbatim into the retrieval query and the generation prompt.
        n_results: Number of chunks to retrieve from ChromaDB (default 3).
        embed_fn: Optional callable turning the query text into an embedding.
            It must be the same embedder used when the chunks were stored. When
            omitted, ChromaDB embeds the query text with the collection's own
            embedding function.

    Returns:
        str: The generated story text.
    """
    # Construct retrieval query based on the child's input details.
    retrieval_prompt = (
        f"A story with {characterType}, set in {storyLocation}, focusing on the theme of {storyMoral} and {otherThings}."
    )

    # Retrieve relevant story chunks from ChromaDB.
    if embed_fn is not None:
        results = collection.query(
            query_embeddings=[embed_fn(retrieval_prompt)],
            n_results=n_results
        )
    else:
        results = collection.query(
            query_texts=[retrieval_prompt],
            n_results=n_results
        )
    # query() returns one list of documents per query; a single query was issued.
    retrieved_docs = (results.get("documents") or [[]])[0]
    retrieved_content = " ".join(retrieved_docs)

    # Define the structured prompt template
    prompt_template = PromptTemplate(
        input_variables=["retrieved_content", "characterType", "characterCount", "characters", "storyLocation", "storyMoral", "otherThings"],
        template="""
        Based on the following input details:
        - Character Type: {characterType}
        - Number of characters: {characterCount}
        - Name of Characters: {characters}
        - Setting: {storyLocation}
        - Main Theme: {storyMoral}
        - More details: {otherThings}
        - Retrieved Content: {retrieved_content}

        Use the details provided to write a short Arabic story for children aged 8-11, emphasizing Arabic culture. 
        Ensure it has a clear beginning, middle, and end, and conclude with a unique ending type.
        """
    )

    # Format the final prompt using PromptTemplate
    formatted_prompt = prompt_template.format(
        retrieved_content=retrieved_content,
        characterType=characterType,
        characterCount=characterCount,
        characters=characters,
        storyLocation=storyLocation,
        storyMoral=storyMoral,
        otherThings=otherThings
    )

    # Generate the story. generate_text() returns the generated string directly;
    # generate() returns a response dict with no top-level 'text' key.
    print("Submitting generation request to IBM ALLAM...")
    return model.generate_text(prompt=formatted_prompt)


### Step 5 (alternative): Prompt-Engineering Story Generation

In [None]:
def generate_story_with_prompt_engineering(characterType, characterCount, characters, storyLocation, storyMoral, otherThings):
    """
    Generate a story from the user's input using prompt engineering only (no retrieval).

    The inputs are interpolated verbatim into a structured prompt, which is sent
    to the ALLaM model.

    Returns:
        str: The generated story text.
    """

    # Define the prompt based on input details. The placeholders must use the
    # parameter names (previously story_moral/other_things raised NameError).
    prompt = f"""
    Write a short Arabic children's story suitable for ages 8-11 with the following details:
    
    - **Character Type**: {characterType}
    - **Number of Characters**: {characterCount}
    - **Character Names**: {characters}
    - **Setting**: {storyLocation}
    - **Main Theme or Moral**: {storyMoral}
    - **Additional Elements**: {otherThings}
    
    Make sure the story has a clear beginning, middle, and end. Emphasize Arabic culture in a simple and age-appropriate way.
    Conclude the story with a unique ending type, such as a happy ending, a surprise, or a moral lesson.
    """

    # Send the prompt to the model. generate_text() returns the generated string
    # directly; generate() returns a response dict with no top-level 'text' key.
    print("Submitting generation request to IBM ALLAM...")
    return model.generate_text(prompt=prompt)


## Testing Functions

In [None]:
# Test RAG-based generation. Keyword names must match the function's camelCase
# parameters; snake_case names (character_type, ...) raise TypeError.
test_story = generate_story_with_retrieval(
    characterType="animals and humans",
    characterCount="3",
    characters="Ali, Lina, and a talking owl",
    storyLocation="in a village and forest",
    storyMoral="importance of honesty",
    otherThings="magical events and a surprise ending"
)

print("Generated Story:", test_story)


In [None]:
# Test prompt-engineering generation. Keyword names must match the function's
# camelCase parameters; snake_case names (character_type, ...) raise TypeError.
test_story = generate_story_with_prompt_engineering(
    characterType="animals and humans",
    characterCount="3",
    characters="Ali, Lina, and a talking owl",
    storyLocation="in a village and forest",
    storyMoral="importance of honesty",
    otherThings="magical events and a surprise ending"
)

print("Generated Story:", test_story)