In [1]:
# Step 1: Install and Import Required Libraries

!pip install torch
!pip install keras keras-nlp huggingface-hub tensorflow sentence-transformers


[0m

In [2]:
import os
from google.colab import userdata

# Note: `userdata.get` is a Colab API. If you're not using Colab, set the env
# vars as appropriate for your system.

os.environ["KAGGLE_USERNAME"] = userdata.get('KAGGLE_USERNAME')
os.environ["KAGGLE_KEY"] = userdata.get('KAGGLE_KEY')

In [3]:
# Install Keras 3 last. See https://keras.io/getting_started/ for more details.
!pip install -q -U keras-nlp
!pip install -q -U "keras>=3"

[0m

In [4]:
os.environ["KERAS_BACKEND"] = "jax"  # Or "torch" or "tensorflow".
# Avoid memory fragmentation on JAX backend.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]="1.00"

In [5]:
import keras
import keras_nlp
import tensorflow as tf

In [6]:
import json
import os
import numpy as np
import torch
from sentence_transformers import SentenceTransformer, util
import keras_nlp
import keras
import tensorflow as tf

# Step 2: Load and Preprocess the Dataset
data = []
with open("/content/formatted_data.jsonl", "r") as file:
    for line in file:
        entry = json.loads(line)
        data.append({
            "instruction": entry["instruction"],
            "response": entry["response"],
            "combined": entry["instruction"] + " " + entry["response"]
        })

# Step 3: Set Up Semantic Retrieval Using Embeddings
embedder = SentenceTransformer("all-MiniLM-L6-v2")

knowledge_embeddings = embedder.encode([entry["combined"] for entry in data], convert_to_tensor=True)

def retrieve_documents(prompt, num_docs=5, keyword="Lymphocytic Choriomeningitis"):
    filtered_data = [entry for entry in data if keyword in entry["combined"]]
    filtered_embeddings = embedder.encode([entry["combined"] for entry in filtered_data], convert_to_tensor=True)

    prompt_embedding = embedder.encode(prompt, convert_to_tensor=True)
    scores = util.pytorch_cos_sim(prompt_embedding, filtered_embeddings)[0]
    top_results = torch.topk(scores, k=num_docs)

    retrieved_docs = [filtered_data[idx]["combined"] for idx in top_results.indices]
    return retrieved_docs

# Step 4: Load and Configure the Gemma Model for Fine-tuning with LoRA
os.environ["KERAS_BACKEND"] = "jax"
gemma_lm = keras_nlp.models.GemmaCausalLM.from_preset("gemma2_2b_en")
gemma_lm.backbone.enable_lora(rank=4)

optimizer = keras.optimizers.AdamW(learning_rate=5e-5, weight_decay=0.01)
optimizer.exclude_from_weight_decay(var_names=["bias", "scale"])

gemma_lm.compile(
    loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    optimizer=optimizer,
    weighted_metrics=[keras.metrics.SparseCategoricalAccuracy()],
)

# Step 5: Define the Generation Function with Enhanced Filtering

def prune_response(response_text):
    # Remove repetitive phrases and keep only relevant information
    relevant_lines = []
    for line in response_text.splitlines():
        if "LCMV" in line or "Lymphocytic Choriomeningitis" in line:
            relevant_lines.append(line)
    return " ".join(relevant_lines[:4])  # Limit to the first 4 relevant lines for conciseness

def generate_text(prompt):
    retrieved_docs = retrieve_documents(prompt)
    context = " ".join(retrieved_docs)
    full_prompt = f"{context}\nInstruction: Provide concise and relevant information on LCMV only.\nPrompt:\n{prompt}\n\nResponse:\n"

    sampler = keras_nlp.samplers.TopKSampler(k=5, seed=2)
    gemma_lm.compile(sampler=sampler)
    generated_text = gemma_lm.generate(full_prompt, max_length=256)

    # Prune irrelevant or redundant content
    pruned_response = prune_response(generated_text)

    if not pruned_response:
        pruned_response = "The response does not contain relevant LCMV information."

    return pruned_response

# Step 6: Define Questions and Generate Responses for Each

questions = [
    "What common risk factors for Lymphocytic Choriomeningitis (LCMV) should be highlighted in patient education materials?",
    "What are the primary diagnostic steps for LCMV, and what challenges may arise in accurately diagnosing it?",
    "What diagnostic tests are most effective in early detection of viral infections with neurological symptoms?"
]

# Generate and Display Responses for Each Question
for question in questions:
    print(f"Question: {question}")
    response = generate_text(question)
    print(f"Response: {response}\n")


Question: What common risk factors for Lymphocytic Choriomeningitis (LCMV) should be highlighted in patient education materials?
Response: Who is at risk for Lymphocytic Choriomeningitis (LCM)? ? Individuals of all ages who come into contact with urine, feces, saliva, or blood of wild mice are potentially at risk for infection. Owners of pet mice or hamsters may be at risk for infection if these animals originate from colonies that were contaminated with LCMV, or if their animals are infected from other wild mice. Human fetuses are at risk of acquiring infection vertically from an infected mother.  Laboratory workers who work with the virus or handle infected animals are also at risk. However, this risk can be minimized by utilizing animals from sources that regularly test for the virus, wearing proper protective laboratory gear, and following appropriate safety precautions. Who is at risk for Lymphocytic Choriomeningitis (LCM)? ? LCMV infections can occur after exposure to fresh urine