generated from opensafely/research-template
/
transform.py
132 lines (89 loc) · 4.12 KB
/
transform.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import csv
import numpy as np
import pandas as pd
from age_bands import add_age_bands
from add_groupings import add_groupings
def run(input_path="output/input.csv", output_path="output/cohort.pickle"):
    """Read the study-definition CSV, transform it, and pickle the result.

    Args:
        input_path: CSV written by the study definition.
        output_path: destination for ``DataFrame.to_pickle`` of the
            transformed cohort.
    """
    transform(load_raw_cohort(input_path)).to_pickle(output_path)
def load_raw_cohort(input_path):
    """Load the study-definition CSV, parsing every ``*_dat`` column as a date.

    Args:
        input_path: path to the CSV written by the study definition.

    Returns:
        pandas.DataFrame in which each column whose name ends in ``_dat`` has
        been parsed by ``pd.read_csv(parse_dates=...)``.

    Raises:
        pandas.errors.EmptyDataError: if the file is empty (has no header row).
    """
    # Read just the header row to discover which columns hold dates.
    # newline="" is what the csv module requires for files it parses, and
    # utf-8 matches pd.read_csv's default encoding used below.
    with open(input_path, newline="", encoding="utf-8") as f:
        # Default of [] avoids leaking a bare StopIteration for an empty file;
        # read_csv then reports the problem as EmptyDataError instead.
        fieldnames = next(csv.reader(f), [])
    date_fieldnames = [fn for fn in fieldnames if fn.endswith("_dat")]
    return pd.read_csv(input_path, parse_dates=date_fieldnames)
def transform(cohort):
    """Clean the raw study-definition cohort and derive analysis columns.

    Row filters (sex other than F/M, age >= 120) run first and each returns a
    new frame; the column-adding steps then modify that filtered frame in
    place, in a fixed order (later steps read columns added by earlier ones,
    e.g. high-level ethnicity reads ``ethnicity``). The final frame is
    returned.
    """
    for row_filter in (drop_non_fm_sex, drop_over_120_age):
        cohort = row_filter(cohort)

    for add_columns in (
        add_ethnicity,
        add_high_level_ethnicity,
        add_missing_vacc_columns,
        add_vacc_dates,
    ):
        add_columns(cohort)

    # The PRIMIS spec defines several overlapping age bands; bands 1-12 do
    # not overlap, so those are the default. Other bands can be added here
    # when needed.
    add_age_bands(cohort, range(1, 13))
    add_groupings(cohort)
    return cohort
def drop_non_fm_sex(cohort):
    """Drop records where sex is not F or M.

    Args:
        cohort: DataFrame with a ``sex`` column.

    Returns:
        A new DataFrame containing only rows whose ``sex`` is "F" or "M".
        Original index labels are kept (so the index may have gaps), and the
        input frame is not modified.
    """
    # Explicit .copy() (the original's argument-less .reindex() did the same
    # thing less obviously) so later column assignments on the result act on
    # an independent frame and don't raise SettingWithCopyWarning.
    return cohort[cohort["sex"].isin(["F", "M"])].copy()
def drop_over_120_age(cohort):
    """Drop records where age is >= 120.

    There are a handful of patients with a recorded date of birth of
    1900-01-01.

    Args:
        cohort: DataFrame with a numeric ``age`` column.

    Returns:
        A new DataFrame with only rows where ``age < 120``. Rows with a
        missing (NaN) age also fail that comparison and are dropped. Original
        index labels are kept and the input frame is not modified.
    """
    # Explicit .copy() (the original's argument-less .reindex() did the same
    # thing less obviously) so later column assignments on the result act on
    # an independent frame and don't raise SettingWithCopyWarning.
    return cohort[cohort["age"] < 120].copy()
def add_ethnicity(cohort):
    """Add an int8 ``ethnicity`` column using the PRIMIS spec bandings.

    Bands 1-16 come straight from ``eth2001``. Patients without one fall
    through, in priority order, to:
      17 - another ethnicity code is recorded (``non_eth2001_dat``),
      18 - ethnicity not given, patient refused (``eth_notgiptref_dat``),
      19 - ethnicity not stated (``eth_notstated_dat``),
      20 - ethnicity not recorded (none of the above).
    ``cohort`` is modified in place; ``eth2001`` itself is left unchanged.
    """
    fallback_bands = (
        ("non_eth2001_dat", 17),
        ("eth_notgiptref_dat", 18),
        ("eth_notstated_dat", 19),
    )
    band = cohort["eth2001"].copy()
    for date_column, band_number in fallback_bands:
        # Only fill patients still without a band, so earlier rules win.
        band = band.mask(band.isna() & cohort[date_column].notna(), band_number)
    cohort["ethnicity"] = band.mask(band.isna(), 20).astype("int8")
def add_high_level_ethnicity(
    cohort, codelist_path="codelists/primis-covid19-vacc-uptake-eth2001.csv"
):
    """Add high-level ethnicity categories, based on bandings from PRIMIS spec.

    Builds a mapping from ethnicity category (codelist column
    ``grouping_16_id``) to high-level category (``grouping_6_id``), then sets
    ``cohort["high_level_ethnicity"]`` (int8) by mapping
    ``cohort["ethnicity"]``. Ethnicity values absent from the codelist
    (e.g. bands 17-20 assigned by ``add_ethnicity``) get high-level
    category 6 (unknown).

    Args:
        cohort: DataFrame with an integer ``ethnicity`` column; modified in
            place.
        codelist_path: CSV codelist with ``grouping_16_id`` and
            ``grouping_6_id`` columns. Defaults to the PRIMIS ethnicity
            codelist.

    Raises:
        ValueError: if the codelist maps one category to two different
            high-level categories.
    """
    category_to_high_level_category = {}
    # newline="" is what the csv module expects for files it parses.
    with open(codelist_path, newline="", encoding="utf-8") as f:
        for record in csv.DictReader(f):
            category = int(record["grouping_16_id"])
            high_level_category = int(record["grouping_6_id"])
            # A category may appear on several rows; every row must agree on
            # its high-level category. (Raised explicitly rather than via
            # assert, which is stripped under `python -O`.)
            known = category_to_high_level_category.setdefault(
                category, high_level_category
            )
            if known != high_level_category:
                raise ValueError(
                    f"{codelist_path}: category {category} maps to both "
                    f"{known} and {high_level_category}"
                )

    # map() gives NaN for ethnicity values not in the codelist; those become
    # the "unknown" high-level category 6.
    high_level = cohort["ethnicity"].map(category_to_high_level_category)
    cohort["high_level_ethnicity"] = high_level.fillna(6).astype("int8")
def add_missing_vacc_columns(cohort):
    """Add all-NaN dose-1/dose-2 date columns for vaccines the spec references
    but that are not yet available.

    For each prefix in mo, nx, jn, gs, vl this adds ``<prefix>d1rx_dat`` and
    ``<prefix>d2rx_dat`` to ``cohort`` in place. Both names for a prefix are
    checked (via assert) to be absent before either is added.
    """
    for prefix in ("mo", "nx", "jn", "gs", "vl"):
        placeholders = (f"{prefix}d1rx_dat", f"{prefix}d2rx_dat")
        for column in placeholders:
            assert column not in cohort
        for column in placeholders:
            cohort[column] = np.nan
def add_vacc_dates(cohort):
    """Set ``vacc1_dat`` / ``vacc2_dat`` to the earliest recorded date of
    the first / second vaccination.

    Each is the row-wise minimum of the matching ``covadm`` and ``covrx``
    columns, skipping missing values. A patient may have only one of the two
    dates. ``cohort`` is modified in place.
    """
    sources_by_target = {
        "vacc1_dat": ["covadm1_dat", "covrx1_dat"],
        "vacc2_dat": ["covadm2_dat", "covrx2_dat"],
    }
    for target, sources in sources_by_target.items():
        cohort[target] = cohort[sources].min(axis=1)
if __name__ == "__main__":
    import sys

    # Usage: python transform.py [INPUT_CSV [OUTPUT_PICKLE]]
    # Any argument that is left out falls back to run()'s default path.
    # Previously a missing INPUT_CSV crashed with IndexError on sys.argv[1].
    if len(sys.argv) > 3:
        sys.exit(f"usage: {sys.argv[0]} [INPUT_CSV [OUTPUT_PICKLE]]")
    run(*sys.argv[1:])