# Produce NY taxi trips to Kafka

## Setup

In [22]:
import os
from confluent_kafka import SerializingProducer, DeserializingConsumer
from confluent_kafka.serialization import StringSerializer, StringDeserializer
from confluent_kafka.admin import AdminClient, NewTopic
from uuid import uuid4
import sys, random
import csv, json
import time
from datetime import datetime 

In [23]:
# Kafka bootstrap server address(es), read from the environment so that no
# connection details are hard-coded in the notebook.
BOOTSTRAP_SERVERS = os.environ.get('BOOTSTRAP_SERVERS')
assert BOOTSTRAP_SERVERS is not None, 'BOOTSTRAP_SERVERS must be set'

# Path of the CSV file whose rows are published as trip messages.
TRIP_CSV = "sample.csv"

# TODO: Update this code such that
# 1. It automatically detects the CSV file in a given directory
# 2. It reads the CSV file and generates the JSON data
# FARE_CSV = "sample_fare.csv"

# Fail early, before any producer is created, if the input file is missing.
assert os.path.exists(TRIP_CSV), f'{TRIP_CSV} file not found'

# Kafka topic names. FARE_TOPIC is not referenced by the code shown in this
# section of the notebook.
TRIP_TOPIC = 'trips'
FARE_TOPIC = 'fares'

# Configuration passed to SerializingProducer: message keys and values are
# both serialized with StringSerializer('utf_8').
PRODUCER_CONFIG = {
    'bootstrap.servers': BOOTSTRAP_SERVERS,
    'partitioner': 'murmur2_random',
    'key.serializer': StringSerializer('utf_8'),
    'value.serializer':  StringSerializer('utf_8')
}

## Utility functions

In [24]:
def get_topics():
    """
    Return the topics reported by the Kafka cluster.

    An ``AdminClient`` is created for ``BOOTSTRAP_SERVERS`` and the
    ``topics`` attribute of its ``list_topics()`` result is returned.
    """
    client = AdminClient({'bootstrap.servers': BOOTSTRAP_SERVERS})
    cluster_metadata = client.list_topics()
    return cluster_metadata.topics

def delivery_report(err, msg):
    """
    Delivery callback: print a notice when a message could not be delivered.

    ``err`` is the delivery error (falsy on success). ``msg`` is accepted
    so the function matches the callback signature, but it is not used.
    """
    if not err:
        # Successful deliveries are intentionally silent.
        return
    failure_notice = 'Message delivery failed: {}'.format(err)
    print(failure_notice)

def convert_to_trip(row):
    """
    Convert one CSV row into a trip dictionary.

    The first 14 columns of ``row`` are mapped, in order, to the keys
    medallion, hack_license, vendor_id, rate_code, store_and_fwd_flag,
    pickup_datetime, dropoff_datetime, passenger_count, trip_time_in_secs,
    trip_distance, pickup_longitude, pickup_latitude, dropoff_longitude,
    dropoff_latitude. Values are copied unchanged; extra columns are ignored.

    A ``timestamp`` key is appended holding the time of conversion in UTC,
    formatted as ``YYYY-MM-DDTHH:MM:SSZ``.

    Raises ``IndexError`` if ``row`` has fewer than 14 columns.
    """
    field_names = (
        "medallion",
        "hack_license",
        "vendor_id",
        "rate_code",
        "store_and_fwd_flag",
        "pickup_datetime",
        "dropoff_datetime",
        "passenger_count",
        "trip_time_in_secs",
        "trip_distance",
        "pickup_longitude",
        "pickup_latitude",
        "dropoff_longitude",
        "dropoff_latitude",
    )

    # Index explicitly (rather than zip) so a short row fails loudly with
    # IndexError instead of silently producing a partial record.
    trip = {name: row[position] for position, name in enumerate(field_names)}

    # The trailing "Z" declares UTC, so the value must be built from
    # gmtime(); formatting local time here would mislabel it on any host
    # whose timezone is not UTC.
    trip["timestamp"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())

    return trip

def convert_to_fare(row):
    """
    Convert one CSV row into a trip fare dictionary.

    The first 11 columns of ``row`` are mapped, in order, to the keys
    medallion, hack_license, vendor_id,
    pickup_datetime, payment_type, fare_amount,
    surcharge, mta_tax, tip_amount, tolls_amount, total_amount.
    Values are copied unchanged; extra columns are ignored.

    A ``timestamp`` key is appended holding the time of conversion in UTC,
    formatted as ``YYYY-MM-DDTHH:MM:SSZ``.

    Raises ``IndexError`` if ``row`` has fewer than 11 columns.
    """
    field_names = (
        "medallion",
        "hack_license",
        "vendor_id",
        "pickup_datetime",
        "payment_type",
        "fare_amount",
        "surcharge",
        "mta_tax",
        "tip_amount",
        "tolls_amount",
        "total_amount",
    )

    # Index explicitly (rather than zip) so a short row fails loudly with
    # IndexError instead of silently producing a partial record.
    fare = {name: row[position] for position, name in enumerate(field_names)}

    # The trailing "Z" declares UTC, so the value must be built from
    # gmtime(); formatting local time here would mislabel it on any host
    # whose timezone is not UTC.
    fare["timestamp"] = time.strftime("%Y-%m-%dT%H:%M:%SZ", time.gmtime())

    return fare

## Producer

In [25]:
# Show the topics currently reported by the cluster (cell output follows).
get_topics()

{'test-topic': TopicMetadata(test-topic, 1 partitions),
 'trips': TopicMetadata(trips, 1 partitions),
 'stock': TopicMetadata(stock, 1 partitions),
 'anotherone': TopicMetadata(anotherone, 1 partitions)}

In [29]:
def produce_trips(trip_csv, producer_config, topic, limit=-1, header=True, delay=0.5):
    """
    Publish the rows of a trip CSV file to a Kafka topic as JSON messages.

    Each row is converted with ``convert_to_trip`` and sent as the message
    value; delivery failures are reported through ``delivery_report``.

    Parameters:
        trip_csv: path of the CSV file to read.
        producer_config: configuration dict passed to ``SerializingProducer``.
        topic: name of the topic to publish to.
        limit: maximum number of messages to produce; the default of -1
            never matches the counter, so the whole file is sent.
        header: when true, the first CSV line is skipped.
        delay: seconds to sleep after each message; 0 or None disables
            the pause.

    Returns:
        The number of messages handed to the producer.
    """
    n = 0
    p = SerializingProducer(producer_config)
    try:
        # newline='' lets the csv module do its own line-ending handling.
        with open(trip_csv, newline='') as csv_file:
            csv_reader = csv.reader(csv_file, delimiter=',')
            if header:
                # The None default keeps an empty file from raising
                # StopIteration here.
                next(csv_reader, None)

            for row in csv_reader:
                if n == limit:
                    break
                payload = json.dumps(convert_to_trip(row))
                if n % 100 == 0:
                    print(f"Produced {n} messages")

                while True:
                    # Serve delivery callbacks for previously sent messages.
                    p.poll(0)
                    try:
                        p.produce(topic, value=payload, on_delivery=delivery_report)
                        break
                    except BufferError:
                        sys.stderr.write('%% Local producer queue is full (%d messages awaiting delivery): try again\n' % len(p))
                        # Wait for deliveries to free queue space, then
                        # retry this row instead of abandoning the file.
                        p.poll(1)

                n += 1
                if delay and delay > 0:
                    time.sleep(delay)
    finally:
        # Always drain the local queue so messages that were already
        # accepted are not lost when the loop exits early on an error.
        p.flush()

    return n
    

In [None]:
# Publish at most 10,000 rows of TRIP_CSV (header line skipped) to TRIP_TOPIC.
produce_trips(TRIP_CSV,PRODUCER_CONFIG,TRIP_TOPIC,limit=10_000)

### Cleanup

In [None]:
# admin_client = AdminClient({"bootstrap.servers":BOOTSTRAP_SERVERS})
# admin_client.delete_topics(topics=[TRIP_TOPIC])