/
combine_csvs.py
72 lines (56 loc) · 1.75 KB
/
combine_csvs.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
#! /usr/bin/env python
"""
Combine CSVs and sort by given field.

Usage:
    combine_csvs.py [--sort-by FIELD] [--reverse] [--fields A,B,...] CSV [CSV ...]

Rows from every CSV given on the command line are concatenated, sorted by
FIELD (default: the first output column; --reverse for descending order),
and written as a single CSV to stdout. The header of the first CSV defines
the output columns unless --fields (a comma-separated list) selects and
orders them explicitly. Progress and error messages are printed to stderr.
"""
import sys
import argparse
import csv
def _fail(message):
    """Print *message* to stderr and exit the process with status -1."""
    print(message, file=sys.stderr)
    sys.exit(-1)


def _read_csv(path):
    """Read the CSV file at *path*, which is expected to start with a header.

    Returns:
        (rows, fieldnames): the data rows as dicts, and the header's column
        names (``None`` when the file is completely empty).
    """
    # newline='' lets the csv module itself handle quoted fields that
    # contain line breaks, as the csv documentation requires.
    with open(path, 'rt', newline='') as fp:
        reader = csv.DictReader(fp)
        rows = list(reader)
        return rows, reader.fieldnames


def _missing_fields(required, present):
    """Return, sorted, the names in *required* that are absent from *present*.

    *present* may be ``None`` (a file with no header), in which case every
    required name is reported as missing.
    """
    return sorted(set(required) - set(present or ()))


def _make_sort_key(rows, field):
    """Build a sort key over *field* for *rows*.

    The key is numeric (``float``) only when every row's value parses as a
    float; otherwise values are compared as strings, with a missing value
    (``None``, from a short row) treated as the empty string.
    """
    try:
        for row in rows:
            float(row[field])
    except (TypeError, ValueError):
        return lambda row: row[field] or ''
    return lambda row: float(row[field])


def main():
    """Combine the CSV files named on the command line and sort their rows.

    Command-line arguments:
        csvs: one or more CSV files, each with a header row. Every file must
            contain the output columns and the sort column.
        --sort-by FIELD: column to sort on (default: first output column).
            It does not have to be one of the output columns.
        --reverse: sort in descending order.
        --fields A,B,...: comma-separated columns to output, in order
            (default: all columns of the first file).

    The combined, sorted CSV is written to stdout; progress and error
    messages go to stderr.

    Returns:
        0 on success. On bad input (empty first file, missing columns) it
        prints an error and exits via ``sys.exit(-1)``.
    """
    p = argparse.ArgumentParser()
    p.add_argument('--sort-by', default=None)
    p.add_argument('--reverse', action='store_true')
    p.add_argument('--fields', type=str)
    p.add_argument('csvs', nargs='+')
    args = p.parse_args()

    first_csv, *csvs = args.csvs
    rows, first_fields = _read_csv(first_csv)
    if not rows:
        _fail(f"error: file {first_csv} is empty!?")

    fieldnames = args.fields.split(',') if args.fields else first_fields
    sort_by = args.sort_by or fieldnames[0]
    # Every file must supply both the output columns and the sort column.
    required = [*fieldnames, sort_by]

    missing = _missing_fields(required, first_fields)
    if missing:
        _fail(f"error! {first_csv} is missing required fields: {missing}")

    for csvfile in csvs:
        new_rows, new_fields = _read_csv(csvfile)
        missing = _missing_fields(required, new_fields)
        if missing:
            _fail(f"error! {csvfile} is missing required fields: {missing}")
        rows.extend(new_rows)

    print(f'loaded {len(rows)} total. now sorting!', file=sys.stderr)
    # The numeric-vs-string decision is made over ALL rows, so one
    # non-numeric value anywhere falls back to string order instead of
    # crashing mid-sort.
    rows.sort(key=_make_sort_key(rows, sort_by), reverse=args.reverse)

    # extrasaction='ignore' drops any columns not selected for output.
    writer = csv.DictWriter(sys.stdout, fieldnames, extrasaction='ignore')
    writer.writeheader()
    writer.writerows(rows)
    return 0
if __name__ == '__main__':
    # Run the CLI; main()'s return value becomes the process exit status.
    raise SystemExit(main())