In [2]:
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.vectorstores import Chroma
from langchain.text_splitter import CharacterTextSplitter
from langchain import OpenAI, VectorDBQA

from dotenv import load_dotenv
import os

load_dotenv()
API_KEY = os.getenv('OPENAI_API_KEY')
os.environ["OPENAI_API_KEY"] = API_KEY

embeddings = OpenAIEmbeddings()

In [3]:
llm = OpenAI(temperature=0)

In [4]:
from langchain.vectorstores import Chroma
from langchain.embeddings.openai import OpenAIEmbeddings

embeddings = OpenAIEmbeddings()
vectordb = Chroma(persist_directory="../../data/dev/cs_courses_vectorstore", embedding_function=embeddings)

from langchain.llms import OpenAI
from langchain.retrievers.self_query.base import SelfQueryRetriever
from langchain.chains.query_constructor.base import AttributeInfo

metadata_field_info=[
    AttributeInfo(
        name="department",
        description="The academic department", 
        type="string"
    ),
    AttributeInfo(
        name="course_number",
        description="The number that comes after the department in the course name", 
        type="string"
    ),
    AttributeInfo(
        name="year",
        description="The year this course is being offered", 
        type="string"
    ),
    AttributeInfo(
        name="semester",
        description="The semester this course is being offered (Fall or Spring)", 
        type="string"
    )
]

document_content_description = "Information about a college department's courses and sections of each course"

Using embedded DuckDB with persistence: data will be stored in: ../../data/dev/cs_courses_vectorstore


In [5]:
llm = OpenAI(temperature=0)
retriever = SelfQueryRetriever.from_llm(llm, vectordb, document_content_description, metadata_field_info, verbose=True, kwargs={"k": 6})
retriever.search_kwargs["k"] = 6
retriever.search_type = "mmr"

In [6]:
from pydantic import BaseModel, Field, root_validator
from typing import Any, Dict, Type
from langchain.tools import BaseTool, StructuredTool, Tool, tool


_VALID_DEPARTMENTS = {
    "CSCI",
    "MATH",
}

class CourseSearchInputSchema(BaseModel):
    department: str = Field(description="should be a Williams College department prefix")
    
    @root_validator
    def validate_query(cls, values: Dict[str, Any]) -> Dict:
        department = values["department"]
        if department not in _VALID_DEPARTMENTS:
            raise ValueError(f"Department {department} is not on the approved list:"
                             f" {sorted(_VALID_DEPARTMENTS)}")
        return values


class CourseSearchTool(BaseTool):
    name = "course_search"
    description = "useful for when you need to answer questions about specific courses offered by a department"
    args_schema: Type[CourseSearchInputSchema] = CourseSearchInputSchema

    def _run(self, query: str) -> str:
        """Use the tool."""
        return str(retriever.get_relevant_documents(query))
    
    async def _arun(self, query: str) -> str:
        """Use the tool asynchronously."""
        raise NotImplementedError("custom_search does not support async")

In [19]:
tool = CourseSearchTool()
tool._run(query="Tell me about machine learning")

query='machine learning' filter=None


'[Document(page_content=\'CSCI 374: Machine Learning\\n Course description: Machine learning is a field that derives from artificial intelligence and statistics, and is concerned with the design and analysis of computer algorithms that "learn" automatically through the use of data. Computer algorithms are capable of discerning subtle patterns and structure in the data that would be practically impossible for a human to find. As a result, real-world decisions, such as treatment options and loan approvals, are being increasingly automated based on predictions or factual knowledge derived from such algorithms. This course explores topics in supervised learning (e.g., random forests and neural networks), unsupervised learning (e.g., k-means clustering and expectation maximization), and possibly reinforcement learning (e.g., Q-learning and temporal difference learning.) It will also introduce methods for the evaluation of learning algorithms (with an emphasis on analysis of generalizability

In [7]:
tool = CourseSearchTool()
tool._run(query="Which courses does Rohit teach?")

query='Rohit' filter=None


'[Document(page_content="CSCI 374: Machine Learning, Section 02\\n Year: 2024, Semester: Fall, Section type: in-person, Class type: Lecture, Meetings: [{\'days\': \'MR\', \'start\': \'14:35\', \'end\': \'15:50\', \'facility\': \'\'}], Instructors: [\'Rohit Bhattacharya\']", metadata={\'source_url\': \'https://catalog.williams.edu/csci/detail/?strm=&cn=374&crsid=017427&req_year=0\', \'department: \': \'CSCI\', \'course_number\': 374, \'year\': 2024, \'semester\': \'Fall\'}), Document(page_content="CSCI 374: Machine Learning, Section 02\\n Year: 2024, Semester: Fall, Section type: in-person, Class type: Lecture, Meetings: [{\'days\': \'MR\', \'start\': \'14:35\', \'end\': \'15:50\', \'facility\': \'\'}], Instructors: [\'Rohit Bhattacharya\']", metadata={\'source_url\': \'https://catalog.williams.edu/csci/detail/?strm=&cn=374&crsid=017427&req_year=0\', \'department\': \'CSCI\', \'course_number\': 374, \'year\': 2024, \'semester\': \'Fall\'}), Document(page_content="CSCI 374: Machine Learn

In [8]:
retriever.get_relevant_documents("Which comp sci courses will Rohit teach in the fall?")

query='Rohit' filter=Operation(operator=<Operator.AND: 'and'>, arguments=[Comparison(comparator=<Comparator.EQ: 'eq'>, attribute='department', value='comp sci'), Comparison(comparator=<Comparator.EQ: 'eq'>, attribute='semester', value='fall')])


NoDatapointsException: No datapoints found for the supplied filter {"$and": [{"department": {"$eq": "comp sci"}}, {"semester": {"$eq": "fall"}}]}