In [None]:
# ! pip install --upgrade pip -q
# ! pip install cohere python-dotenv -q
# ! pip install qdrant-client -q
# ! pip install gtfs-realtime-bindings -q
# ! pip install pandas -q
# ! pip install tqdm ipywidgets -q
# ! pip freeze > ../requirements.txt

# Download the latest MTA New York City Subway static GTFS reference data
* https://new.mta.info/developers

In [None]:
# ! rm -rf ../gtfs-reference/*
# ! curl -O http://web.mta.info/developers/data/nyct/subway/google_transit.zip
# ! unzip ./google_transit.zip -d ../gtfs-reference/
# ! rm -f ./google_transit.zip

### Load reference data into pandas DataFrames

In [None]:
import os
import pandas as pd
from pandas import DataFrame
from IPython.display import display


def load_ref(path: str, cols: list[str] = ()) -> DataFrame:
    """Load a GTFS reference CSV file into a DataFrame.

    Args:
        path: Path of the CSV file, e.g. ``../gtfs-reference/routes.txt``.
        cols: Column names to keep, in this order. Any sequence (list or
            tuple) is accepted; when empty (the default) every column is kept.

    Returns:
        The parsed file, restricted to ``cols`` when given.
    """
    df = pd.read_csv(path)
    if len(cols) > 0:
        # Index with a list: a tuple would be looked up as a single
        # (MultiIndex-style) column key and raise KeyError.
        df = df[list(cols)]
    return df


base_dir = "../gtfs-reference"

# ROUTES
routes_cols: list[str] = ["route_id", "route_long_name", "route_desc"]
routes_df: DataFrame = load_ref(
    os.path.join(base_dir, "routes.txt"), routes_cols
)
pd.set_option("display.max_colwidth", None)
print(f"Routes: {len(routes_df)}")
display(routes_df.head())

# TRIPS
trips_cols: list[str] = ["route_id", "trip_id", "trip_headsign", "service_id", "direction_id"]
trips_df: DataFrame = load_ref(os.path.join(base_dir, "trips.txt"), trips_cols)
# Replace the numeric direction_id with a readable label (1 -> "South",
# anything else -> "North"). Vectorised rather than a per-row apply, and
# assign() returns a new frame instead of writing into a column subset.
trips_df = trips_df.assign(
    direction_id=trips_df["direction_id"].eq(1).map({True: "South", False: "North"})
)
print(f"Trips: {len(trips_df)}")
display(trips_df.head())

# STOPS
stops_cols: list[str] = ["stop_id", "stop_name"]
stops_df: DataFrame = load_ref(os.path.join(base_dir, "stops.txt"), stops_cols)
print(f"Stops: {len(stops_df)}")
display(stops_df.head())

# STOP TIMES
stop_times_df: DataFrame = load_ref(os.path.join(base_dir, "stop_times.txt"))

# JOURNEYS: one row per stop time, enriched with trip, stop and route details,
# ordered by trip and then by position along the trip.
journeys_df = (
    stop_times_df.merge(trips_df, on="trip_id", how="left")
    .merge(stops_df, on="stop_id", how="left")
    .merge(routes_df, on="route_id", how="left")
    .sort_values(by=["trip_id", "stop_sequence"])
)

# ROUTE 1 ON WEEKDAYS
# .copy() gives an independent frame so later cells can add columns to it
# without SettingWithCopyWarning.
journeys_df = journeys_df[
    (journeys_df["service_id"] == "Weekday") & (journeys_df["route_id"] == "1")
].copy()
print(f"Journeys: {len(journeys_df)}")
display(journeys_df.head())

***
## Journey Enrichment
### Add next stop

In [None]:
# Name of the following stop on the same trip. Shifting within each trip_id
# group means the last stop of a trip gets NaN instead of the first stop of
# the next trip.
next_stop_name = journeys_df.groupby("trip_id")["stop_name"].shift(-1)
# Drop a column left by an earlier run (insert() refuses duplicates), which
# also leaves us working on a fresh frame.
journeys_df = journeys_df.drop(columns="next_stop_name", errors="ignore")
# Column position 10 only affects display order; later cells use names.
journeys_df.insert(10, "next_stop_name", next_stop_name)
display(journeys_df.head())

### Add journey times between stops

In [None]:
# GTFS stop times are "HH:MM:SS" offsets into the service day and may exceed
# 24:00:00 for trips that run past midnight (e.g. "25:13:00"). Parse them as
# timedeltas: a "%H:%M:%S" datetime parse with errors="coerce" turned those
# rows into NaT. astype(str) lets datetime.time values left by an earlier run
# of this cell parse as well.
arrival_td = pd.to_timedelta(journeys_df["arrival_time"].astype(str), errors="coerce")
departure_td = pd.to_timedelta(journeys_df["departure_time"].astype(str), errors="coerce")

# Seconds from departing the previous stop to arriving at this one; NaN for
# the first stop of each trip (no previous stop on the same trip_id).
prev_departure_td = departure_td.groupby(journeys_df["trip_id"]).shift()
journey_seconds = (arrival_td - prev_departure_td).dt.total_seconds()

# Work on a new frame, dropping a journey_time column from an earlier run
# because insert() refuses to add a duplicate column.
journeys_df = journeys_df.drop(columns="journey_time", errors="ignore")

# Wrap offsets into a time of day (25:13:00 -> 01:13:00) for display and for
# the time-of-day comparisons used by the lookup cells.
day_start = pd.Timestamp("1900-01-01")
journeys_df["arrival_time"] = (day_start + arrival_td).dt.time
journeys_df["departure_time"] = (day_start + departure_td).dt.time

# Column position 4 only affects display order; later cells use names.
journeys_df.insert(4, "journey_time", journey_seconds)

print(f"Journeys: {len(journeys_df)}")
display(journeys_df.head())

### Add text descriptions

In [None]:
from pandas import Series


def row2text(row: Series) -> str:
    """Describe one journeys row as a natural-language sentence.

    The sentence names the route, its long name, the lower-cased direction,
    the stop and the arrival time. When the row has a ``next_stop_name``,
    a second sentence naming that stop is appended.
    """
    heading = str(row.direction_id).lower()
    sentence = (
        f'Route {row.route_id} "{row.route_long_name}" travelling {heading} '
        f'will arrive at stop "{row.stop_name}" at {row.arrival_time}.'
    )
    if pd.isnull(row.next_stop_name):
        # Last stop of the trip: nothing follows.
        return sentence
    return sentence + f" The next stop is {row.next_stop_name}."


# Show the generated sentences in full rather than truncated.
pd.set_option("display.max_colwidth", None)
# One descriptive sentence per journeys row; these are the texts embedded later.
journeys_df["text"] = journeys_df.apply(row2text, axis=1)
display(journeys_df["text"].head())

***
# Routes and stops lookup

In [None]:
from typing import Optional
from datetime import datetime, timedelta


def get_trips(
    route: str,
    ts: datetime,
    direction: Optional[str] = None,
    window: timedelta = timedelta(minutes=15),
    journeys: Optional[DataFrame] = None,
) -> list[str]:
    """Return trips on ``route`` with an arrival inside ``[ts, ts + window]``.

    Args:
        route: GTFS route_id, e.g. "1".
        ts: Start of the window; only its time of day is used, truncated to
            whole seconds.
        direction: Optional "North" or "South" filter. Any other value
            (including None) disables the direction filter.
        window: Look-ahead length (default 15 minutes). Must be shorter than
            a day; windows crossing midnight wrap around.
        journeys: Frame with trip_id, route_id, direction_id and arrival_time
            (datetime.time) columns. Defaults to the module-level journeys_df.

    Returns:
        Unique trip IDs, in the order they first appear in ``journeys``.
    """
    df = journeys_df if journeys is None else journeys
    df = df[["trip_id", "route_id", "direction_id", "arrival_time"]]

    mask = df["route_id"] == route
    if direction in ("North", "South"):
        mask &= df["direction_id"] == direction

    start = ts.replace(microsecond=0).time()
    end = (ts + window).replace(microsecond=0).time()
    after_start = df["arrival_time"] >= start
    before_end = df["arrival_time"] <= end
    # A window crossing midnight (e.g. 23:50-00:05) has end < start: match
    # times after the start OR before the end instead of requiring both.
    mask &= (after_start & before_end) if start <= end else (after_start | before_end)

    return df.loc[mask, "trip_id"].drop_duplicates().tolist()


def get_stops(
    trip: str,
    from_stop: Optional[str] = None,
    journeys: Optional[DataFrame] = None,
) -> dict:
    """Summarise a trip as text: its stops and scheduled arrival times.

    Args:
        trip: The trip_id to summarise.
        from_stop: Optional stop_id or stop_name. When it matches a stop on
            the trip, the list restarts at that stop (its last occurrence).
            When it matches nothing, all stops are listed.
        journeys: Frame with trip_id, trip_headsign, route_id, direction_id,
            stop_id, stop_name and arrival_time columns. Defaults to the
            module-level journeys_df.

    Returns:
        ``{"Route": route_id, "text": summary_sentence}``. Stops without a
        scheduled time are listed as "unknown".

    Raises:
        ValueError: if ``trip`` has no rows in ``journeys``.
    """
    df = journeys_df if journeys is None else journeys
    df = df[
        [
            "trip_id",
            "trip_headsign",
            "route_id",
            "direction_id",
            "stop_id",
            "stop_name",
            "arrival_time",
        ]
    ]
    df = df[df["trip_id"] == trip]
    if df.empty:
        raise ValueError(f"No stops found for trip {trip!r}")

    stops_at: list[str] = []
    for stop_id, stop_name, arrival in zip(df["stop_id"], df["stop_name"], df["arrival_time"]):
        if from_stop and from_stop in (stop_id, stop_name):
            # Restart the list so it begins at the requested stop.
            stops_at.clear()
        # Missing times (NaT/None) have no strftime; label them instead.
        when = "unknown" if pd.isnull(arrival) else arrival.strftime("%H:%M:%S")
        stops_at.append(f'"{stop_name}" ({when})')

    route_id = df["route_id"].iloc[0]
    res = {"Route": route_id}
    res["text"] = 'Route {} travelling {} with the headsign "{}" stops at, {}'.format(
        route_id,
        str(df["direction_id"].iloc[0]).lower(),
        df["trip_headsign"].iloc[0],
        ", ".join(stops_at),
    )
    return res


# Test: summarise the next southbound route 1 trip from 50 St, if one is due.
# get_trips returns an empty list when nothing is scheduled in its window.
upcoming = get_trips("1", datetime.now(), "South")
if upcoming:
    t = upcoming[0]
    s = get_stops(t, "50 St")
    print(s)
else:
    print("No southbound route 1 trips in the next 15 minutes.")

***
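# Generate embeddings and populate the Qdrant collection: `mta-gtfs-timetable`

The collection name is read from the `QDRANT_COLLECTION` environment variable. The cells below use an embedded, on-disk Qdrant store at `../qdrant`. The endpoints listed here apply only when running a Qdrant server, i.e. after switching to the commented-out `QdrantClient(url=...)` line:

* Qdrant Web UI: http://localhost:6333/dashboard
* Qdrant REST API: http://localhost:6333
* Qdrant gRPC API: localhost:6334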

In [None]:
%%html
<style>
/* Map Jupyter widget colours and font size onto the VS Code editor theme and
   make widget output backgrounds transparent (presumably so the tqdm.notebook
   progress bar used below stays legible in dark themes). */
:root {
    --jp-widgets-color: var(--vscode-editor-foreground);
    --jp-widgets-font-size: var(--vscode-editor-font-size);
}
.cell-output-ipywidget-background {
    background-color: transparent !important;
}
</style>

In [None]:
from datetime import date, datetime, timedelta

# ------------------------------------
# Filter journeys for the next 3 hours
# ------------------------------------
# NOTE: this rebinds the global journeys_df, so get_trips/get_stops and the
# embedding cell below only see this window. Re-running the cell narrows it
# further; re-run the loading cells to reset.
now = datetime.now().replace(microsecond=0)
later = now + timedelta(hours=3)
start = now.strftime("%H:%M:%S")
end = later.strftime("%H:%M:%S")

arrivals = journeys_df["arrival_time"]
after_start = arrivals >= now.time()
before_end = arrivals <= later.time()
# After 21:00 the window wraps past midnight (end < start): keep times after
# the start OR before the end instead of requiring both.
in_window = (
    (after_start & before_end) if now.time() <= later.time() else (after_start | before_end)
)
journeys_df = journeys_df.loc[in_window]
print(f"Number of journeys between {start} and {end}: {len(journeys_df)}")
# ------------------------------------

In [None]:
import cohere
from dotenv import load_dotenv
from tqdm.notebook import tqdm
from qdrant_client import QdrantClient
from qdrant_client.models import Batch, Distance, VectorParams
from datetime import date, datetime, timedelta

load_dotenv()
collection = os.getenv("QDRANT_COLLECTION")
cohere_key = os.getenv("COHERE_API_KEY")

co = cohere.Client(cohere_key)
client = QdrantClient(path="../qdrant")
# client = QdrantClient(url="http://localhost:6333")

# Cohere's embed endpoint accepts at most 96 texts per call.
buf_max: int = 96
id_buf: list[int] = []
text_buf: list[str] = []
meta_buf: list[dict] = []


def _flush_buffers() -> None:
    """Embed the buffered texts, upsert them into the collection, clear the buffers.

    Does nothing when the buffers are empty.
    """
    if not text_buf:
        return
    embeddings = co.embed(
        model="embed-english-v3.0",  # 1024-dimensional, matching the collection
        input_type="search_document",
        texts=text_buf,
    ).embeddings
    client.upsert(
        collection_name=collection,
        points=Batch(ids=id_buf, vectors=embeddings, payloads=meta_buf),
    )
    id_buf.clear()
    text_buf.clear()
    meta_buf.clear()


try:
    # Recreate collection
    client.delete_collection(collection)
    client.create_collection(
        collection_name=collection,
        vectors_config=VectorParams(size=1024, distance=Distance.COSINE),
    )

    for point_id, row in tqdm(journeys_df.iterrows(), total=journeys_df.shape[0]):
        id_buf.append(int(point_id))
        text_buf.append(row.text)
        meta_buf.append(
            {
                "trip_id": row.trip_id,
                "stop_id": row.stop_id,
                # Time of day pinned to date.min, stored as the
                # "%Y-%m-%dT%H:%M:%S" string that the retrieval cell parses.
                "arrival_time": datetime.combine(date.min, row.arrival_time).isoformat(
                    timespec="seconds"
                ),
                "text": row.text,
            }
        )
        if len(text_buf) == buf_max:
            _flush_buffers()
    # The last batch is usually smaller than buf_max; without this flush
    # those rows would never be indexed.
    _flush_buffers()
finally:
    client.close()

***
# Test semantic search

In [None]:
from cohere import ChatDocument
from qdrant_client import QdrantClient
from qdrant_client.models import Filter, FieldCondition, DatetimeRange, ScoredPoint

# Load variables from a .env file (if any) into the environment.
load_dotenv()
collection, cohere_key = (
    os.getenv("QDRANT_COLLECTION"),
    os.getenv("COHERE_API_KEY"),
)

co = cohere.Client(cohere_key)
# Embedded, on-disk Qdrant store; swap in the URL client to use a server.
client = QdrantClient(path="../qdrant")
# client = QdrantClient(url="http://localhost:6333")


def _arrival_range(start: datetime, end: datetime) -> FieldCondition:
    """Inclusive ``arrival_time`` range condition between two datetimes."""
    return FieldCondition(
        key="arrival_time",
        range=DatetimeRange(gte=start.isoformat(), lte=end.isoformat()),
    )


def _arrival_window_filter(start: datetime, end: datetime) -> Filter:
    """Filter matching arrival times in ``[start, end]``, wrapping past midnight.

    Arrival times are indexed as times of day on ``date.min``, so a window that
    ends on the following day is split into two ranges on the start's date,
    combined with ``should`` (match either).
    """
    if end.date() == start.date():
        return Filter(must=[_arrival_range(start, end)])
    day_first = datetime.combine(start.date(), datetime.min.time())
    day_last = datetime.combine(start.date(), datetime.max.time())
    return Filter(
        should=[
            _arrival_range(start, day_last),
            _arrival_range(day_first, end - timedelta(days=1)),
        ]
    )


def _time_of_day(value) -> Optional[str]:
    """Format a stored arrival_time payload (ISO string or datetime) as HH:MM:SS.

    Empty or missing values are returned unchanged.
    """
    if not value:
        return value
    if isinstance(value, str):
        value = datetime.fromisoformat(value)
    if isinstance(value, datetime):
        return value.strftime("%H:%M:%S")
    return value


def retrieve(
    client: QdrantClient,
    query: str,
    window: timedelta = timedelta(minutes=15),
    limit: int = 5,
) -> list[ChatDocument]:
    """Retrieve journeys similar to ``query`` that arrive within ``window`` from now.

    Args:
        client: Open Qdrant client holding the journeys collection.
        query: Natural-language search text, embedded with Cohere.
        window: Look-ahead window for arrival times (default 15 minutes,
            shorter than a day).
        limit: Maximum number of search hits (default 5).

    Returns:
        Two documents per hit, ordered by arrival time: the hit's text with its
        arrival time, followed by get_stops' summary of the trip from that stop.
    """
    time_now = datetime.combine(date.min, datetime.now().time())
    query_filter = _arrival_window_filter(time_now, time_now + window)
    query_vector = co.embed(
        model="embed-english-v3.0",
        input_type="search_query",
        texts=[query],
    ).embeddings[0]
    results = client.search(
        collection_name=collection,
        query_vector=query_vector,
        query_filter=query_filter,
        limit=limit,
    )
    results = sorted(results, key=lambda x: x.payload.get("arrival_time"))

    docs = []
    for result in results:
        payload = result.payload
        docs.append(
            {
                "arrival_time": _time_of_day(payload.get("arrival_time")),
                "text": payload.get("text"),
            }
        )
        # Add the trip's remaining stops, starting at this hit's stop.
        docs.append(get_stops(payload.get("trip_id"), payload.get("stop_id")))
    return docs


# The question has no placeholders, so it is a plain string rather than an f-string.
q = "What is the next train to arrive at 50 St?"
print(f"Question: {q} \n")
try:
    results = retrieve(client, q)
    for r in results:
        print(r)
finally:
    # Release the embedded Qdrant store so the next cell can open it again.
    client.close()

***
# Cohere MTA GTFS Chat

In [None]:
load_dotenv()
co = cohere.Client(os.getenv("COHERE_API_KEY"))
client = QdrantClient(path="../qdrant")
# client = QdrantClient(url="http://localhost:6333")
# Reusing a fixed conversation ID lets Cohere continue the same stored
# conversation across turns.
conversation_id = "3"
chat_model = "command-r"

# Questions:
#   What is the next train to arrive at 50 St?
#   Does the train travelling south stop at Franklin St?
#   What is the headsign on this train?

try:
    while True:
        # Strip so " quit " still exits; skip blank lines instead of sending
        # an empty message to the API.
        human_message = input("User: ").strip()
        if not human_message:
            continue
        if human_message.lower() == "quit":
            print("Ending chat.")
            break

        # Step 1: ask only for search queries to decide whether retrieval is needed.
        search_response = co.chat(
            message=human_message,
            model=chat_model,
            search_queries_only=True,
        )
        chat_kwargs = {
            "message": human_message,
            "model": chat_model,
            "conversation_id": conversation_id,
        }
        if search_response.search_queries:
            print("Retrieving information...", end="")
            # Step 2: ground the answer in documents from the vector store.
            documents = []
            for query in search_response.search_queries:
                documents.extend(retrieve(client, query.text))
            chat_kwargs["documents"] = documents

        # Step 3: stream the answer, printing only the generated text.
        ai_response = co.chat_stream(**chat_kwargs)
        print("\nChatbot:")
        for event in ai_response:
            if event.event_type == "text-generation":
                print(event.text, end="")
        print(f"\n{'-'*75}\n")
finally:
    client.close()