<a href="https://colab.research.google.com/github/gqcpm/scholar_stream/blob/main/research_agent.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%%capture
# Dependency installation cell. The %%capture cell magic above has to remain
# the first line of the cell; it captures the output produced by the pip calls.
import os
# Packages the agent code imports directly: langgraph, langchain_core, arxiv.
!pip install langgraph langchain_core arxiv
# Colab is detected by searching the concatenated names of all environment
# variables for the substring "COLAB_".
if "COLAB_" not in "".join(os.environ.keys()):
    # Not running in Colab: a plain install of unsloth.
    !pip install unsloth
else:
    # Do this only in Colab notebooks! Otherwise use pip install unsloth
    # Packages are installed with --no-deps and some versions are pinned;
    # NOTE(review): presumably to avoid disturbing Colab's preinstalled stack — confirm.
    !pip install --no-deps bitsandbytes accelerate xformers==0.0.29.post3 peft trl==0.15.2 triton cut_cross_entropy unsloth_zoo
    !pip install sentencepiece protobuf datasets huggingface_hub hf_transfer
    !pip install --no-deps unsloth

In [None]:
from unsloth import FastLanguageModel
from google.colab import drive

# Mount Google Drive; the LoRA adapter directory loaded below lives under /content/drive.
drive.mount('/content/drive')

# 1. Load the BASE model (The big 14B one)
# NOTE(review): call_local_model (defined later in this notebook) requests up to
# 1024 new tokens by default, so prompt + reply can exceed max_seq_length below —
# confirm that 1024 is large enough for the writer prompts.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name = "unsloth/Qwen3-14B", # NOTE(review): an earlier comment claimed a switch to Qwen3-8B, but the 14B checkpoint is what gets loaded
    max_seq_length = 1024,   # Context length - can be longer, but uses more memory
    load_in_4bit = True,     # 4bit uses much less memory
    full_finetuning = False, # Full fine-tuning is turned off for this load
    # token = "hf_...",      # use one if using gated models
)

# 2. Load your ADAPTERS on top (the adapter files stored in Drive)
model.load_adapter("/content/drive/MyDrive/ai_models/lora_adapters")

# 3. Enable Inference Speedup
FastLanguageModel.for_inference(model)

print("Model loaded successfully from Drive!")

In [None]:
import json
import arxiv
from typing import TypedDict, List, Annotated
from langgraph.graph import StateGraph, END
from langchain_core.messages import SystemMessage, HumanMessage

# --- 1. DEFINE THE STATE ---
class ResearchState(TypedDict):
    """Shared dictionary that carries data from one agent node to the next."""

    # The user's original question
    task: str
    # The list of steps (search queries) to research
    plan: List[str]
    # The raw data gathered from ArXiv
    content: List[str]
    # The current written report
    draft: str
    # Feedback from the critic
    critique: str
    # Number of drafting iterations so far
    revision_number: int
    # Iteration limit that stops infinite loops
    max_revisions: int

# --- 2. HELPER: CONNECT UNSLOTH MODEL ---
# This function wraps your loaded 'model' and 'tokenizer' to work like a chat bot
def call_local_model(messages, max_tokens=1024, strip_thinking=True):
    """
    Format chat messages for the loaded Qwen/Unsloth model and generate a reply.

    Args:
        messages: List of {"role": ..., "content": ...} dicts, passed to the
            tokenizer's chat template.
        max_tokens: Upper bound on the number of newly generated tokens.
        strip_thinking: When True (default), a reasoning block that ends with
            "</think>" is removed from the reply so that callers only see the
            final answer. Pass False to get the raw decoded text.

    Returns:
        The decoded text of the newly generated tokens (the prompt is removed).
    """
    # Apply the specific chat template for your model (Qwen handles this well)
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )

    # Create inputs
    model_inputs = tokenizer([text], return_tensors="pt").to("cuda")

    # Generate
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=max_tokens,
        use_cache=True
    )

    # Drop the prompt tokens from each output row (so we only get the new response)
    new_token_ids = [
        output_ids[len(input_ids):]
        for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]

    response_text = tokenizer.batch_decode(new_token_ids, skip_special_tokens=True)[0]

    if strip_thinking:
        # Qwen3 chat models can emit "<think> ... </think>" before the answer.
        # NOTE(review): those markers appear not to be removed by
        # skip_special_tokens, so the reasoning text would otherwise leak into
        # the planner's list parsing, the critic's verdict and the written draft.
        # If the block was never closed (generation hit max_tokens first), the
        # text is returned unchanged rather than emptied.
        _reasoning, closing_tag, answer = response_text.partition("</think>")
        if closing_tag:
            response_text = answer.lstrip()

    return response_text

# --- 3. DEFINE THE NODES (AGENTS) ---

def planner_node(state: ResearchState):
    """
    Ask the local model for search queries and store them as the plan.

    The model is asked for a Python list literal. The reply is scanned for the
    outermost '[' ... ']' span, which is parsed with ast.literal_eval. If no
    usable list of non-empty strings comes out of that, two generic queries
    derived from the task are used instead.

    Returns:
        {"plan": [...]} with a non-empty list of query strings.
    """
    import ast

    print("--- üß† PLANNER IS THINKING ---")

    # Construct the prompt
    messages = [
        {"role": "system", "content": "You are a Research Planner. Return a Python list of 3 short, specific search queries related to the user's task. Example format: ['query 1', 'query 2', 'query 3']. Do not explain, just return the list."},
        {"role": "user", "content": f"Task: {state['task']}"}
    ]

    # Get response from your local model
    response = call_local_model(messages)

    # If the model chats too much, only the bracketed part of the reply is parsed
    start = response.find('[')
    end = response.rfind(']') + 1

    parsed = None
    if start != -1 and end > start:
        try:
            parsed = ast.literal_eval(response[start:end])
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
            # These are the failures literal_eval raises on malformed input;
            # anything else is a real bug and should surface.
            parsed = None

    # Keep only usable queries: the result must be a sequence of non-empty
    # strings, otherwise the researcher would search for junk or nothing at all.
    plan = []
    if isinstance(parsed, (list, tuple)):
        plan = [item.strip() for item in parsed if isinstance(item, str) and item.strip()]

    if not plan:
        # Fallback if model fails to output strict list
        plan = [f"{state['task']} generic analysis", f"{state['task']} method comparison"]

    return {"plan": plan}

def researcher_node(state: ResearchState):
    """
    Run every planned query against arXiv and collect short paper summaries.

    Each result becomes a "Title / Abstract" string with the abstract cut to
    500 characters. A paper that is returned for more than one query is only
    collected once, so the writer's context is not filled with repeats.

    Returns:
        {"content": [...]} with one summary string per distinct paper.
    """
    print("--- üïµÔ∏è RESEARCHER IS SEARCHING ARXIV ---")

    abstract_limit = 500
    collected_content = []
    seen_entry_ids = set()
    client = arxiv.Client()

    # Iterate through the plan generated by the previous node
    for query in state['plan']:
        print(f"Searching for: {query}")
        search = arxiv.Search(
            query=query,
            max_results=2, # Keep low for speed in demo
            sort_by=arxiv.SortCriterion.Relevance
        )

        for r in client.results(search):
            # Overlapping queries tend to return the same top papers
            if r.entry_id in seen_entry_ids:
                continue
            seen_entry_ids.add(r.entry_id)

            abstract = r.summary[:abstract_limit]
            if len(r.summary) > abstract_limit:
                # Only mark the abstract as cut off when it actually was
                abstract += "..."
            collected_content.append(f"Title: {r.title}\nAbstract: {abstract}")

    return {"content": collected_content}

def writer_node(state: ResearchState):
    """
    Draft (or revise) the report from the collected research material.

    On the first pass the prompt holds only the task and the research
    summaries. When the state already carries a draft and a critique — i.e.
    the critic sent the work back — both are appended to the prompt so the
    model revises against the feedback instead of regenerating from the same
    prompt as before.

    Returns:
        {"draft": <new text>, "revision_number": <previous value + 1>}
    """
    print("--- ‚úçÔ∏è WRITER IS DRAFTING ---")

    # Combine all research into one context string
    context_str = "\n\n".join(state['content'])
    user_prompt = f"Task: {state['task']}\n\nResearch Materials:\n{context_str}"

    previous_draft = state.get("draft") or ""
    critique = state.get("critique") or ""
    if previous_draft and critique:
        # Revision pass: give the model what it wrote and what the reviewer said
        user_prompt += (
            f"\n\nPrevious Draft:\n{previous_draft}"
            f"\n\nReviewer Feedback:\n{critique}"
            "\n\nRevise the report so that it addresses the reviewer feedback."
        )

    messages = [
        {"role": "system", "content": "You are a Research Analyst. Synthesize the provided research summaries into a clear, structured report."},
        {"role": "user", "content": user_prompt}
    ]

    draft = call_local_model(messages)

    return {
        "draft": draft,
        "revision_number": state.get("revision_number", 0) + 1
    }

def critic_node(state: ResearchState):
    """Have the local model review the current draft and return its reply as the critique."""
    print("--- üßê CRITIC IS REVIEWING ---")

    reviewer_instructions = "You are a strict Academic Reviewer. Check the draft. If it is high quality, reply with only the word 'APPROVE'. If it needs work, provide 1 sentence of feedback."
    draft_under_review = f"Draft: {state['draft']}"

    review_request = [
        {"role": "system", "content": reviewer_instructions},
        {"role": "user", "content": draft_under_review},
    ]

    # The model's reply is stored verbatim under "critique"
    return {"critique": call_local_model(review_request)}

def should_continue(state: ResearchState):
    """
    Decide where the graph goes after the critic has spoken.

    The draft counts as approved only when the critic's reply *starts* with
    the word APPROVE (or APPROVED), ignoring case, surrounding quotes and
    punctuation, and any reasoning block closed by "</think>". A plain
    substring test would also accept feedback such as "cannot be approved"
    or "I disapprove".

    Returns:
        "end" when the draft is approved or the revision limit is reached,
        otherwise "writer".
    """
    import re

    critique = state.get('critique', '') or ''
    rev_num = state.get('revision_number', 0)
    max_rev = state.get('max_revisions', 2)

    # Only the text after a closed reasoning block is the actual verdict
    _reasoning, closing_tag, answer = critique.partition("</think>")
    verdict_text = answer if closing_tag else critique

    first_word = re.search(r"[A-Za-z]+", verdict_text)
    approved = first_word is not None and first_word.group().upper() in ("APPROVE", "APPROVED")

    # Approval is checked first so that an approved final revision is reported
    # as approved rather than as having run out of revisions.
    if approved:
        print("--- ‚úÖ DRAFT APPROVED ---")
        return "end"

    if rev_num >= max_rev:
        print("--- üõë MAX REVISIONS REACHED ---")
        return "end"

    print("--- üîÑ LOOPING BACK TO WRITER ---")
    return "writer" # In a complex app, this might go back to researcher

# --- 4. BUILD THE GRAPH ---

workflow = StateGraph(ResearchState)

# Register each agent function under the node name used by the edges below
for agent_name, agent_fn in (
    ("planner", planner_node),
    ("researcher", researcher_node),
    ("writer", writer_node),
    ("critic", critic_node),
):
    workflow.add_node(agent_name, agent_fn)

# The run always starts with the planner
workflow.set_entry_point("planner")

# Fixed pipeline: planner -> researcher -> writer -> critic
for upstream, downstream in (
    ("planner", "researcher"),
    ("researcher", "writer"),
    ("writer", "critic"),
):
    workflow.add_edge(upstream, downstream)

# After the critic, should_continue picks the route (The Logic Loop):
# "writer" sends the draft back for another pass, "end" finishes the run.
critic_routes = {
    "writer": "writer",
    "end": END,
}
workflow.add_conditional_edges("critic", should_continue, critic_routes)

# Compile
app = workflow.compile()

print("Graph compiled! Ready to run.")

In [None]:
# Starting values for every field of the graph state
initial_state = dict(
    task="Compare the performance of Mamba vs Transformers in 2024",
    max_revisions=2,
    revision_number=0,
    content=[],
    plan=[],
    draft="",
    critique="",
)

# Holds the most recent draft seen while streaming
final_draft = ""

print("Starting the Research Agent...")

# Stream the graph; each streamed item looks like {'node_name': {'key': 'value'}}
for step in app.stream(initial_state):
    for finished_node, update in step.items():
        print(f"--- Finished running: {finished_node} ---")

        # Only nodes that produced a draft are of interest here
        if 'draft' not in update:
            continue

        final_draft = update['draft']
        print(f"Draft updated (Length: {len(final_draft)} chars)")

# Print Final Result
print("\n\n=== FINAL RESEARCH REPORT ===")
print(final_draft if final_draft else "No draft was generated (Did the loop crash or max out?)")

In [None]:
from google.colab import drive
import json
import os

# 1. Mount Drive (if not already) to access the file
# The notebook path used further down in this cell is located under /content/drive.
drive.mount('/content/drive')

# 2. Define a function to clean the specific 'widgets' error
def clean_notebook_metadata(file_path):
    """
    Remove the top-level metadata 'widgets' entry from a notebook file in place.

    The file is parsed as JSON. If notebook['metadata']['widgets'] exists it is
    deleted and the file is rewritten; otherwise the file is left untouched.

    Args:
        file_path: Path of the .ipynb file to clean.
    """
    with open(file_path, 'r', encoding='utf-8') as f:
        notebook = json.load(f)

    # Nothing to do unless the 'widgets' key is present in the metadata
    if 'metadata' not in notebook or 'widgets' not in notebook['metadata']:
        print("NOTEBOOK IS ALREADY CLEAN (No widget metadata found).")
        return

    del notebook['metadata']['widgets']

    # Overwrite the file with the clean version. ensure_ascii=False keeps
    # non-ASCII text (emoji, accented characters) as written instead of
    # turning every such character into a \uXXXX escape.
    with open(file_path, 'w', encoding='utf-8') as f:
        json.dump(notebook, f, indent=1, ensure_ascii=False)

    # Report success only once the file has actually been rewritten
    print(f"‚úÖ Cleaned widgets metadata from: {file_path}")

# 3. Run the cleanup on this notebook's copy in Drive
# NOTE: Point the path below at your actual notebook file in Drive.
# To find it: Folder icon on the left -> drive -> MyDrive -> Colab Notebooks
notebook_path = "/content/drive/MyDrive/Colab Notebooks/research_agent.ipynb"

if not os.path.exists(notebook_path):
    # Nothing at that path, so there is nothing to clean
    print(f"‚ùå File not found at {notebook_path}. Please check the path.")
else:
    clean_notebook_metadata(notebook_path)
    print("üëâ Now go to File -> Save a copy in GitHub")