In [1]:
import numpy as np
import pandas as pd
from neo4j import GraphDatabase
from typing import List, Dict
import logging
import streamlit as st

In [2]:
class CourseDatabase:
    """Small wrapper around a Neo4j driver for a course/prerequisite graph.

    Courses are stored as ``(:Course {code, name})`` nodes, and
    "``course`` requires ``prereq``" is stored as
    ``(course)-[:REQUIRES]->(prereq)``.

    The object can be used as a context manager so the driver is closed even
    if a cell raises::

        with CourseDatabase(uri, username, password) as db:
            db.add_course("CS101", "Intro to Programming")
    """

    def __init__(self, uri: str, username: str, password: str):
        """Initialize connection to Neo4j database."""
        self.driver = GraphDatabase.driver(uri, auth=(username, password))

    def close(self):
        """Close the database connection."""
        self.driver.close()

    def __enter__(self) -> "CourseDatabase":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # Always release the driver; returning None never suppresses the
        # caller's exception.
        self.close()

    def add_course(self, code: str, name: str) -> None:
        """Create course ``code``, or update its ``name`` if it already exists."""
        with self.driver.session() as session:
            session.execute_write(self._create_course, code, name)

    @staticmethod
    def _create_course(tx, code: str, name: str):
        # MERGE on the code keeps re-runs idempotent; SET refreshes the name.
        query = """
        MERGE (c:Course {code: $code})
        SET c.name = $name
        RETURN c
        """
        result = tx.run(query, code=code, name=name)
        return result.single()

    def add_prerequisite(self, course_code: str, prereq_code: str) -> None:
        """Record that ``course_code`` requires ``prereq_code``.

        Both courses must already exist (see ``add_course``). If either is
        missing, nothing is written and a warning is logged instead of the
        call silently doing nothing.
        """
        with self.driver.session() as session:
            linked = session.execute_write(
                self._create_prerequisite, course_code, prereq_code
            )
        if not linked:
            logging.getLogger(__name__).warning(
                "Prerequisite %s -> %s not created: course not found",
                course_code,
                prereq_code,
            )

    @staticmethod
    def _create_prerequisite(tx, course_code: str, prereq_code: str) -> int:
        """Link the two courses; return how many links were matched/created (0 if either is missing)."""
        # count(*) with no grouping key always yields exactly one row, so a
        # failed MATCH shows up as 0 rather than as an empty result.
        query = """
        MATCH (c:Course {code: $course_code})
        MATCH (p:Course {code: $prereq_code})
        MERGE (c)-[:REQUIRES]->(p)
        RETURN count(*) AS linked
        """
        record = tx.run(
            query, course_code=course_code, prereq_code=prereq_code
        ).single()
        return record["linked"] if record else 0

    def get_prerequisites(self, course_code: str, recursive: bool = True) -> List[Dict]:
        """Get prerequisites for a course.

        Args:
            course_code: code of the course whose prerequisites are wanted.
            recursive: if True (default), include indirect prerequisites
                (prerequisites of prerequisites); otherwise only direct ones.

        Returns:
            A list of ``{"code": ..., "name": ...}`` dicts, each course listed
            once.
        """
        with self.driver.session() as session:
            return session.execute_read(self._get_prerequisites, course_code, recursive)

    @staticmethod
    def _get_prerequisites(tx, course_code: str, recursive: bool):
        # "*" = any number of REQUIRES hops; "" = exactly one hop.
        # Note: a bare "1" here would produce ``:REQUIRES1``, which Cypher
        # parses as a *different relationship type* and never matches.
        hops = "*" if recursive else ""
        # DISTINCT: a course reachable via several paths (e.g. a shared
        # prerequisite of two prerequisites) must be reported only once.
        query = f"""
        MATCH (c:Course {{code: $code}})-[:REQUIRES{hops}]->(p:Course)
        RETURN DISTINCT p.code AS code, p.name AS name
        """
        result = tx.run(query, code=course_code)
        return [dict(record) for record in result]

    def get_available_courses(self, completed_courses: List[str]) -> List[Dict]:
        """Get courses where all prerequisites have been completed.

        Args:
            completed_courses: codes of courses already completed. Any
                iterable is accepted; ``None`` means "nothing completed".

        Returns:
            ``{"code": ..., "name": ...}`` dicts for courses not yet completed
            whose prerequisites are all in ``completed_courses``.
        """
        # Send a real list: a null parameter would make every `IN` test null
        # (filtering out all courses), and non-list iterables such as sets
        # are normalized before reaching the driver.
        completed = [] if completed_courses is None else list(completed_courses)
        with self.driver.session() as session:
            return session.execute_read(self._get_available_courses, completed)

    @staticmethod
    def _get_available_courses(tx, completed_courses: List[str]):
        query = """
        MATCH (c:Course)
        // Courses already completed are not "available".
        WHERE NOT c.code IN $completed_courses
          // ALL() over an empty list is true, so courses without any
          // prerequisites pass without a separate existence check.
          AND ALL(prereq IN [(c)-[:REQUIRES]->(p:Course) | p.code]
                  WHERE prereq IN $completed_courses)
        RETURN c.code AS code, c.name AS name
        """
        result = tx.run(query, completed_courses=completed_courses)
        return [dict(record) for record in result]

In [4]:
# Connection settings are read from Streamlit's secrets store so that no
# credentials are written into the notebook itself.
uri, username, password = (st.secrets[key] for key in ("uri", "username", "password"))

db = CourseDatabase(uri, username, password)

In [5]:
# Course catalogue as (code, name) pairs. add_course MERGEs on the code,
# so re-running this cell does not create duplicate nodes.
courses = [
    ("CS101", "Intro to Programming"),
    ("CS201", "Data Structures"),
    ("CS301", "Algorithms"),
    ("MATH101", "Calculus I"),
    ("MATH201", "Calculus II"),
]

for code, name in courses:
    db.add_course(code=code, name=name)

In [6]:
# (course, prerequisite) pairs: the first course REQUIRES the second.
# Both codes must already exist (created in the previous cell), since the
# relationship is only created between MATCHed course nodes.
prerequisites = [
    ("CS201", "CS101"),
    ("CS301", "CS201"),
    ("MATH201", "MATH101"),
    ("CS301", "MATH201"),
]

for course, prereq in prerequisites:
    db.add_prerequisite(course_code=course, prereq_code=prereq)

In [8]:
print("Prerequisites for CS301:")
# recursive=True is the default, so indirect prerequisites are included.
prereqs = db.get_prerequisites("CS301")
for prereq in prereqs:
    print("- {code}: {name}".format(**prereq))

Prerequisites for CS301:
- MATH201: Calculus II
- MATH101: Calculus I
- CS201: Data Structures
- CS101: Intro to Programming


In [9]:
print("Available courses (completed: CS101, MATH101):")
# Courses not yet taken whose every prerequisite is in the completed list.
available = db.get_available_courses(completed_courses=["CS101", "MATH101"])
for course in available:
    print("- {code}: {name}".format(**course))

Available courses (completed: CS101, MATH101):
- CS201: Data Structures
- MATH201: Calculus II


In [10]:
# Release the Neo4j driver; `db` cannot run further queries after this.
db.driver.close()