diff --git a/throughputd.c b/throughputd.c index 746a970..fe6199e 100644 --- a/throughputd.c +++ b/throughputd.c @@ -3,7 +3,7 @@ This file is part of throughputd. - This program is free software; you can redistribute it and/or modify it + This program is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License version 2 as published by the Free Software Foundation. */ @@ -40,28 +40,30 @@ #define RECORDING_THREAD_SLEEP_TIME 1 #define SQL_SCHEMA_CREATE_STMT \ -"CREATE TABLE IF NOT EXISTS %s ( \ - id INTEGER PRIMARY KEY AUTOINCREMENT, \ - ip TEXT NOT NULL, \ - interface TEXT NOT NULL, \ - timestamp INTEGER NOT NULL, \ +"CREATE TABLE IF NOT EXISTS `%s` ( \ + id INTEGER PRIMARY KEY AUTOINCREMENT, \ + ip TEXT NOT NULL, \ + interface TEXT NOT NULL, \ + timestamp INTEGER NOT NULL, \ send_total INTEGER NOT NULL, \ - recv_total INTEGER NOT NULL \ + recv_total INTEGER NOT NULL \ );" -#define SQL_INSERT_RECORD_STMT "INSERT INTO %s(ip, interface, timestamp, send_total, recv_total) VALUES(?, ?, ?, ?, ?);" - -#define SQL_CREATE_INDEX_STMT "CREATE INDEX IF NOT EXISTS nt_timestamp ON %s(timestamp);" +#define SQL_CREATE_INDEX_STMT "CREATE INDEX IF NOT EXISTS `nt_timestamp` ON `%s`(timestamp);" +#define SQL_CREATE_SAVEPOINT_STMT "SAVEPOINT `%s`;" +#define SQL_RELEASE_SAVEPOINT_STMT "RELEASE SAVEPOINT `%s`;" +#define SQL_ROLLBACK_SAVEPOINT_STMT "ROLLBACK TO SAVEPOINT `%s`;" +#define SQL_INSERT_RECORD_STMT "INSERT INTO `%s`(ip, interface, timestamp, send_total, recv_total) VALUES(?, ?, ?, ?, ?);" #define DEFAULT_DBFILE_NAME "throughputd.db" #define DEFAULT_TABLE_NAME "network_traffic" #define USAGE \ -"Usage: %s [options...] []\n" \ -"Valid options are:\n" \ +"Usage: %s [options...] []\n" \ +"Valid options are:\n" \ " -t integer Interval between writes in seconds (default: 5)\n" \ " -f path Path to sqlite database (default: throughputd.db)\n" \ -" -p path Path to PID file (default: none)\n" \ +" -p path Path to PID file (default: none)\n" \ " -a table Name of database table (default: network_traffic)\n" \ " -d Daemonize after starting (only if debugging disabled)\n" @@ -134,37 +136,39 @@ static struct sigaction signal_action = { /*************************** Recording Logic ****************************/ +static int free_entry(struct hashtable *records, struct hashtable_link *hl, void *unused){ + hashtable_delete(records, hl->key); + free(container_of(hl, struct throughputd_record, link)); + return 0; +} + static int record_entry(struct hashtable *records, struct hashtable_link *hl, void *data){ int ret; sqlite3_stmt *stmt = NULL; struct throughputd_context *ctx = data; struct throughputd_record *record = container_of(hl, struct throughputd_record, link); - + ret = sqlite3_prepare_v2(db, insert_stmt, strlen(insert_stmt), &stmt, NULL); if(ret != SQLITE_OK){ PRINT_ERROR(ret, "error preparing insert statement: %s", sqlite3_errmsg(db)); goto error; } - + sqlite3_bind_text(stmt, 1, record->lan_ip, strlen(record->lan_ip), NULL); sqlite3_bind_text(stmt, 2, ctx->ifaddr->ifa_name, strlen(ctx->ifaddr->ifa_name), NULL); sqlite3_bind_int64(stmt, 3, ctx->cur_time); sqlite3_bind_int64(stmt, 4, record->send_total); sqlite3_bind_int64(stmt, 5, record->recv_total); - + ret = sqlite3_step(stmt); - if(ret != SQLITE_DONE) { - PRINT_ERROR(ret, "error executing insert statement"); + if(ret != SQLITE_DONE && ret != SQLITE_BUSY){ + PRINT_ERROR(ret, "error executing insert statement: %s", sqlite3_errmsg(db)); goto error; } - + sqlite3_finalize(stmt); - - hashtable_delete(records, hl->key); - free(record); - - return 0; - + return (ret == SQLITE_DONE) ? 0 : SQLITE_BUSY; + error: PRINT_ERROR(ret, "error recording entry"); if(stmt) sqlite3_finalize(stmt); @@ -174,61 +178,145 @@ static int record_entry(struct hashtable *records, struct hashtable_link *hl, vo static void *recording_thread(void *unused){ int ret, i, sleep_cnt = 0; char transaction_exists = 0; + char *create_savepoint_stmt = NULL; + char *release_savepoint_stmt = NULL; + char *rollback_savepoint_stmt = NULL; time_t cur_time; struct throughputd_context *ctx; - + while(!should_stop_recording){ sleep(RECORDING_THREAD_SLEEP_TIME); - + if(sleep_cnt + 1 < record_interval){ sleep_cnt++; continue; } - + sleep_cnt = 0; - + PRINT_DEBUG("interval elapsed, recording current state"); cur_time = time(NULL); - + + /* begin the transaction for the entire timestamp */ +retry_transaction_begin: ret = sqlite3_exec(db, "BEGIN;", NULL, NULL, NULL); - if(ret != SQLITE_OK){ - PRINT_ERROR(ret, "error beginning transaction"); + if(ret == SQLITE_BUSY){ + PRINT_DEBUG("transaction begin returned busy, retrying"); + goto retry_transaction_begin; + }else if(ret != SQLITE_OK){ + PRINT_ERROR(ret, "error beginning transaction: %s", sqlite3_errmsg(db)); goto error; } + transaction_exists = 1; - + for(i = 0; i < context_count; i++){ ctx = &throughputd_contexts[i]; ctx->cur_time = cur_time; - - pthread_mutex_lock(&ctx->lock); - + PRINT_DEBUG("recording current state for %s", ctx->ifaddr->ifa_name); + + /* create the strings we will need for managing transaction state */ + ret = asprintf(&create_savepoint_stmt, SQL_CREATE_SAVEPOINT_STMT, ctx->ifaddr->ifa_name); + if(ret < 0){ + create_savepoint_stmt = NULL; + ret = EFAULT; + PRINT_ERROR(ret, "error composing savepoint create statement"); + goto error; + } + + ret = asprintf(&release_savepoint_stmt, SQL_RELEASE_SAVEPOINT_STMT, ctx->ifaddr->ifa_name); + if(ret < 0){ + release_savepoint_stmt = NULL; + ret = EFAULT; + PRINT_ERROR(ret, "error composing savepoint release statement"); + goto error; + } + + ret = asprintf(&rollback_savepoint_stmt, SQL_ROLLBACK_SAVEPOINT_STMT, ctx->ifaddr->ifa_name); + if(ret < 0){ + rollback_savepoint_stmt = NULL; + ret = EFAULT; + PRINT_ERROR(ret, "error composing savepoint rollback statement"); + goto error; + } + + pthread_mutex_lock(&ctx->lock); + +retry_savepoint: + /* create savepoint for rollback purposes */ + ret = sqlite3_exec(db, create_savepoint_stmt, NULL, NULL, NULL); + if(ret == SQLITE_BUSY){ + PRINT_DEBUG("attempt to create savepoint returned busy, retrying"); + goto retry_savepoint; + }else if(ret != SQLITE_OK){ + PRINT_ERROR(ret, "error creating savepoint for new records: %s", sqlite3_errmsg(db)); + pthread_mutex_unlock(&ctx->lock); + goto error; + } + + /* push the current records to the database */ ret = hashtable_for_each_key(&ctx->records, record_entry, ctx); - if(ret) goto error; - + if(ret == SQLITE_BUSY){ + /* inserts within a transaction cannot be retried if they return busy. Rollback to savepoint */ + PRINT_DEBUG("attempt to record entry returned busy, rolling back and retrying"); + ret = sqlite3_exec(db, rollback_savepoint_stmt, NULL, NULL, NULL); + if(ret != SQLITE_OK){ + PRINT_ERROR(ret, "error rolling back to savepoint: %s", sqlite3_errmsg(db)); + pthread_mutex_unlock(&ctx->lock); + goto error; + } + goto retry_savepoint; + }else if(ret != SQLITE_OK){ + pthread_mutex_unlock(&ctx->lock); + goto error; + } + +retry_savepoint_release: + /* release the savepoint */ + ret = sqlite3_exec(db, release_savepoint_stmt, NULL, NULL, NULL); + if(ret == SQLITE_BUSY){ + /* savepoint commits can be retried directly if they returned busy */ + PRINT_DEBUG("savepoint release returned busy, retrying"); + goto retry_savepoint_release; + }else if(ret != SQLITE_OK){ + PRINT_ERROR(ret, "error releasing savepoint: %s", sqlite3_errmsg(db)); + pthread_mutex_unlock(&ctx->lock); + goto error; + } + + /* delete all entries after the transaction has succeeded */ + (void)hashtable_for_each_key(&ctx->records, free_entry, ctx); + pthread_mutex_unlock(&ctx->lock); + + free(create_savepoint_stmt); + free(release_savepoint_stmt); + free(rollback_savepoint_stmt); } - - PRINT_DEBUG("committing transaction"); -recommit_transaction: + PRINT_DEBUG("committing transaction"); +retry_transaction_commit: ret = sqlite3_exec(db, "COMMIT;", NULL, NULL, NULL); if(ret == SQLITE_BUSY){ + /* transaction commits can be retried directly if they returned busy */ PRINT_DEBUG("transaction commit returned busy, retrying"); - goto recommit_transaction; + goto retry_transaction_commit; }else if(ret != SQLITE_OK){ PRINT_ERROR(ret, "error committing transaction: %s", sqlite3_errmsg(db)); goto error; } transaction_exists = 0; } - + return NULL; - + error: PRINT_ERROR(ret, "error during the recording thread, exiting"); if(transaction_exists) sqlite3_exec(db, "ROLLBACK;", NULL, NULL, NULL); + if(create_savepoint_stmt) free(create_savepoint_stmt); + if(release_savepoint_stmt) free(release_savepoint_stmt); + if(rollback_savepoint_stmt) free(rollback_savepoint_stmt); return NULL; } @@ -237,18 +325,18 @@ static void *recording_thread(void *unused){ static int throughputd_record_alloc(char *ip, struct throughputd_record **record_out){ int ret; struct throughputd_record *record; - + PRINT_DEBUG("allocating new record"); record = malloc(sizeof(struct throughputd_record)); if(!record){ ret = ENOMEM; PRINT_ERROR(ret, "error allocating new record"); - + *record_out = NULL; return ret; } - - record->recv_total = 0; + + record->recv_total = 0; record->send_total = 0; strncpy(record->lan_ip, ip, INET6_ADDRSTRLEN); @@ -260,29 +348,29 @@ static int update_record(struct throughputd_context *ctx, char *ip, uint64_t dat int ret; struct throughputd_record *record; struct hashtable_link *hl; - + pthread_mutex_lock(&ctx->lock); - + hl = hashtable_find(&ctx->records, ip); if(!hl){ PRINT_DEBUG("existing record not found, adding new record"); ret = throughputd_record_alloc(ip, &record); if(ret) goto error; - + ret = hashtable_insert(&ctx->records, record->lan_ip, &record->link); if(ret){ PRINT_ERROR(ret, "error inserting new record into hashtable: This shouldn't happen"); goto error; } }else record = container_of(hl, struct throughputd_record, link); - + if(is_recv) record->recv_total += datalen; else record->send_total += datalen; - + pthread_mutex_unlock(&ctx->lock); - + return 0; - + error: PRINT_ERROR(ret, "error updating record for packet"); pthread_mutex_unlock(&ctx->lock); @@ -293,7 +381,7 @@ static int update_record(struct throughputd_context *ctx, char *ip, uint64_t dat static int ip_matches_nic(uint32_t *ip, uint32_t *if_addr, uint32_t *netmask, int bits){ int i; - + for(i = 0; i < bits / 8 / sizeof(uint32_t); i++){ if((if_addr[i] & netmask[i]) != (ip[i] & netmask[i])) return 0; } @@ -314,12 +402,12 @@ static void handle_ipv4_packet(struct throughputd_context *ctx, const u_char *pa inet_ntop(AF_INET, &header->dest, dest_ip, INET_ADDRSTRLEN); PRINT_DEBUG("IPv4 src = %s, dest = %s", src_ip, dest_ip); - + if(ip_matches_nic(src_addr, if_addr, if_mask, 32)){ PRINT_DEBUG("Packet is outgoing"); update_record(ctx, src_ip, len, 0); } - + if(ip_matches_nic(dest_addr, if_addr, if_mask, 32)){ PRINT_DEBUG("Packet is incoming"); update_record(ctx, dest_ip, len, 1); @@ -334,17 +422,17 @@ static void handle_ipv6_packet(struct throughputd_context *ctx, const u_char *pa uint32_t *if_mask = (uint32_t *)(((struct sockaddr_in6 *)ctx->ifaddr->ifa_netmask)->sin6_addr.s6_addr); uint32_t *src_addr = ((uint32_t *)&header->src.s6_addr); uint32_t *dest_addr = ((uint32_t *)&header->dest.s6_addr); - + inet_ntop(AF_INET6, &header->src, src_ip, INET6_ADDRSTRLEN); inet_ntop(AF_INET6, &header->dest, dest_ip, INET6_ADDRSTRLEN); PRINT_DEBUG("IPv6 src = %s, dest = %s", src_ip, dest_ip); - + if(ip_matches_nic(src_addr, if_addr, if_mask, 128)){ PRINT_DEBUG("Packet is outgoing"); update_record(ctx, src_ip, len, 0); } - + if(ip_matches_nic(dest_addr, if_addr, if_mask, 128)){ PRINT_DEBUG("Packet is incoming"); update_record(ctx, dest_ip, len, 1); @@ -354,12 +442,12 @@ static void handle_ipv6_packet(struct throughputd_context *ctx, const u_char *pa /*************************** Main processing Logic ****************************/ static void on_packet_received(u_char *data, const struct pcap_pkthdr *pkt_header, const u_char *packet){ - struct throughputd_context *ctx = (struct throughputd_context *)data; + struct throughputd_context *ctx = (struct throughputd_context *)data; struct ethernet_header *eth_header = (struct ethernet_header *)packet; int tag_offset = 0; uint8_t *eth_type_ptr; uint16_t eth_type; - + reevaluate_type: eth_type_ptr = ((uint8_t*)ð_header->type) + tag_offset; eth_type = ntohs(*((uint16_t*)eth_type_ptr)); @@ -394,7 +482,7 @@ static void *interface_listening_thread(void *data){ PRINT_ERROR(ret, "error running pcap loop"); return NULL; } - + PRINT_DEBUG("exiting processing loop"); return NULL; @@ -404,14 +492,14 @@ static int initialize_thread_context(struct throughputd_context *ctx, struct ifa int ret; char hashtable_initialized = 0, mutex_intialized = 0; char errbuf[PCAP_ERRBUF_SIZE]; - + ctx->ifaddr = interface; - + PRINT_DEBUG("initializing hashtable"); ret = hashtable_init(&ctx->records, HASHTABLE_ARRAY_SIZE); if(ret) goto error; hashtable_initialized = 1; - + PRINT_DEBUG("initializing mutex"); ret = pthread_mutex_init(&ctx->lock, NULL); if(ret){ @@ -419,7 +507,7 @@ static int initialize_thread_context(struct throughputd_context *ctx, struct ifa goto error; } mutex_intialized = 1; - + PRINT_DEBUG("opening device %s for listening", interface->ifa_name); ctx->pcap_fd = pcap_open_live(interface->ifa_name, PACKET_BUF_LEN, 0, PCAP_TIMEOUT_MS, errbuf); if(!ctx->pcap_fd){ @@ -427,14 +515,14 @@ static int initialize_thread_context(struct throughputd_context *ctx, struct ifa PRINT_ERROR(ret, "error opening pcap device %s", errbuf); goto error; } - + PRINT_DEBUG("creating thread"); ret = pthread_create(&ctx->thread, NULL, interface_listening_thread, ctx); if(ret){ PRINT_ERROR(ret, "error starting trafic monitoring thread"); goto error; } - + return 0; error: @@ -446,10 +534,10 @@ static int initialize_thread_context(struct throughputd_context *ctx, struct ifa static void signal_handler(int sig){ int i; - + PRINT_DEBUG("----------------------- SIGTERM / SIGINT caught ---------------------------"); should_stop_recording = 1; - + for(i = 0; i < context_count; i++){ pcap_breakloop(throughputd_contexts[i].pcap_fd); } @@ -462,60 +550,60 @@ static int free_record(struct hashtable *records, struct hashtable_link *hl, voi static int string_is_present(int argc, char **argv, char *str){ int i; - + for(i = 0; i < argc; i++){ if(!strcmp(str, argv[i])) return 1; } - + return 0; } static int context_already_exists(char *ifname){ int i; - + for(i = 0; i < context_count; i++){ if(!strcmp(throughputd_contexts[i].ifaddr->ifa_name, ifname)) return 1; } - + return 0; } static int interface_exists(char *ifname){ struct ifaddrs *interface; - + for(interface = ifaddrs; interface; interface = interface->ifa_next){ if(interface->ifa_addr->sa_family != AF_INET && interface->ifa_addr->sa_family != AF_INET6) continue; if(!strcmp(interface->ifa_name, ifname)) return 1; } - + return 0; } static void throughputd_cleanup(void){ int i; struct throughputd_context *ctx; - + PRINT_DEBUG("cleaning up all allocations"); - + for(i = 0; i < context_count; i++){ ctx = &throughputd_contexts[i]; - + PRINT_DEBUG("waiting for thread to stop"); pthread_join(ctx->thread, NULL); - + PRINT_DEBUG("freeing hashtable entries"); hashtable_for_each_key(&ctx->records, free_record, NULL); - + PRINT_DEBUG("freeing hashtable"); hashtable_destroy(&ctx->records); PRINT_DEBUG("destroying context lock"); pthread_mutex_destroy(&ctx->lock); - + PRINT_DEBUG("destroying context interface descriptor"); pcap_close(ctx->pcap_fd); } - + PRINT_DEBUG("destroying global variables"); if(ifaddrs) freeifaddrs(ifaddrs); if(throughputd_contexts) free(throughputd_contexts); @@ -539,7 +627,7 @@ int main(int argc, char **argv){ #ifndef DEBUG_ENABLED int daemonize = 0; #endif - + while((c = getopt(argc, argv, "df:t:a:p:")) != -1){ switch(c){ #ifndef DEBUG_ENABLED @@ -570,7 +658,7 @@ int main(int argc, char **argv){ goto error; } } - + #ifndef DEBUG_ENABLED if(daemonize){ PRINT_DEBUG("daemonizing process"); @@ -591,7 +679,7 @@ int main(int argc, char **argv){ PRINT_ERROR(ret, "error opening pid file for writing"); goto error; } - + PRINT_DEBUG("writing PID to file"); ret = fprintf(pid_fd, "%d\n", getpid()); if(ret < 0){ @@ -599,11 +687,11 @@ int main(int argc, char **argv){ PRINT_ERROR(ret, "error writing pid to pid file"); goto error; } - + PRINT_DEBUG("closing PID file"); fclose(pid_fd); } - + interface_names = &argv[optind]; if(argv[optind] && strlen(argv[optind]) == 0 && argc - optind == 1){ if_count = 0; @@ -612,27 +700,27 @@ int main(int argc, char **argv){ else{ if_count = argc - optind; } - + PRINT_DEBUG("registering signal handlers %d", signal_action.sa_flags); ret = sigaction(SIGTERM, &signal_action, NULL); if(ret){ PRINT_ERROR(ret, "error registering SIGTERM signal handler"); goto error; } - + ret = sigaction(SIGINT, &signal_action, NULL); if(ret){ PRINT_ERROR(ret, "error registering SIGINT signal handler"); goto error; } - + PRINT_DEBUG("opening sqlite database file %s", dbname); ret = sqlite3_open(dbname, &db); if(ret){ PRINT_ERROR(ret, "error opening sqlite database"); goto error; } - + PRINT_DEBUG("composing create table statement"); ret = asprintf(&create_stmt, SQL_SCHEMA_CREATE_STMT, db_table_name); if(ret < 0){ @@ -641,7 +729,7 @@ int main(int argc, char **argv){ PRINT_ERROR(ret, "error composing create table statement"); goto error; } - + PRINT_DEBUG("creating table (if it doesnt exist already)"); ret = sqlite3_exec(db, create_stmt, NULL, NULL, NULL); if(ret != SQLITE_OK){ @@ -650,7 +738,7 @@ int main(int argc, char **argv){ } free(create_stmt); create_stmt = NULL; - + PRINT_DEBUG("composing create index statement"); ret = asprintf(&index_stmt, SQL_CREATE_INDEX_STMT, db_table_name); if(ret < 0){ @@ -659,7 +747,7 @@ int main(int argc, char **argv){ PRINT_ERROR(ret, "error composing create index statement"); goto error; } - + PRINT_DEBUG("creating timestamp index"); ret = sqlite3_exec(db, index_stmt, NULL, NULL, NULL); if(ret != SQLITE_OK){ @@ -668,7 +756,7 @@ int main(int argc, char **argv){ } free(index_stmt); index_stmt = NULL; - + PRINT_DEBUG("composing insert statement"); ret = asprintf(&insert_stmt, SQL_INSERT_RECORD_STMT, db_table_name); if(ret < 0){ @@ -677,14 +765,14 @@ int main(int argc, char **argv){ PRINT_ERROR(ret, "error composing insert statement"); goto error; } - + PRINT_DEBUG("fetching list of all network interfaces"); ret = getifaddrs(&ifaddrs); if(ret){ PRINT_ERROR(ret, "error fetching list of network interfaces"); goto error; } - + if(if_count == 0){ PRINT_DEBUG("no interfaces specified, discovering all network interfaces"); for(interface = ifaddrs; interface; interface = interface->ifa_next){ @@ -702,7 +790,7 @@ int main(int argc, char **argv){ default: continue; } - + if_count++; } }else{ @@ -716,7 +804,7 @@ int main(int argc, char **argv){ } } } - + PRINT_DEBUG("allocating context array"); throughputd_contexts = malloc(if_count * sizeof(struct throughputd_context)); if(!throughputd_contexts){ @@ -724,23 +812,23 @@ int main(int argc, char **argv){ PRINT_ERROR(ret, "error allocating context array"); goto error; } - + PRINT_DEBUG("initializing context array"); for(interface = ifaddrs; interface; interface = interface->ifa_next){ if(interface->ifa_addr->sa_family != AF_INET && interface->ifa_addr->sa_family != AF_INET6) continue; - + if(interface_names && *interface_names){ if(!string_is_present(if_count, interface_names, interface->ifa_name)) continue; if(context_already_exists(interface->ifa_name)) continue; } - + PRINT_DEBUG("initialing interface %s", interface->ifa_name); ret = initialize_thread_context(&throughputd_contexts[context_count], interface); if(ret) goto error; - + context_count++; } - + PRINT_DEBUG("creating recording thread"); ret = pthread_create(&recording_pthread, NULL, recording_thread, NULL); if(ret){ @@ -754,14 +842,14 @@ int main(int argc, char **argv){ PRINT_ERROR(ret, "error joining main thread with recording thread"); goto error; } - + throughputd_cleanup(); return 0; - + error: PRINT_ERROR(ret, "error during main function"); if(ret == EINVAL) printf(USAGE, argv[0]); - + if(create_stmt) free(create_stmt); if(index_stmt) free(index_stmt); throughputd_cleanup();