generated from opensafely/research-template
/
transform.py
221 lines (153 loc) · 6.68 KB
/
transform.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
import csv
import numpy as np
import pandas as pd
from age_bands import add_age_bands
from add_groupings import add_groupings
from groups import at_risk_groups, groups
# Demographic breakdown columns kept in the output cohort.
demographic_cols = ["age_band", "sex", "ethnicity", "high_level_ethnicity", "imd_band"]

# Every group from ``groups`` except those whose name mentions "covax" or
# "unstatvacc".
group_cols = [
    group
    for group in groups
    if not ("covax" in group or "unstatvacc" in group)
]

# At-risk groups that are not already included in ``group_cols``.
extra_at_risk_cols = [group for group in at_risk_groups if group not in group_cols]

# Patient identifier plus columns derived by ``transform`` (vaccination dates
# and wave).
extra_cols = ["patient_id", "vacc1_dat", "vacc2_dat", "wave"]

# First- and second-dose columns for each vaccine prefix, grouped by prefix:
# mod1rx_dat, mod2rx_dat, nxd1rx_dat, nxd2rx_dat, ...
extra_vacc_cols = [
    f"{prefix}{dose_suffix}"
    for prefix in ["mo", "nx", "jn", "gs", "vl"]
    for dose_suffix in ("d1rx_dat", "d2rx_dat")
]

# Final column selection (in this order) written out by ``run``.
necessary_cols = [
    *demographic_cols,
    *group_cols,
    *extra_at_risk_cols,
    *extra_cols,
    *extra_vacc_cols,
]
def run(input_path="output/input.csv", output_path="output/cohort.pickle"):
    """Build the cohort from the study-definition CSV and pickle it.

    Args:
        input_path: CSV generated by the study definition.
        output_path: Destination for the pickled DataFrame, restricted to the
            columns listed in ``necessary_cols``.
    """
    cohort = transform(load_raw_cohort(input_path))
    selected = cohort[necessary_cols]
    selected.to_pickle(output_path)
def load_raw_cohort(input_path):
    """Read the study-definition CSV, parsing every ``*_dat`` column as a date.

    The header row is read first with the ``csv`` module so that the names of
    the date columns can be passed to ``pd.read_csv(parse_dates=...)``.

    Args:
        input_path: Path to the CSV file.

    Returns:
        The raw cohort as a DataFrame.

    Raises:
        pandas.errors.EmptyDataError: if the file is empty.
    """
    # Decode the header the same way pandas decodes the data (UTF-8 by
    # default), so non-ASCII column names match. newline="" is what the csv
    # module requires for correct handling of quoted fields.
    with open(input_path, encoding="utf-8", newline="") as f:
        # An empty file yields no header; let pandas raise its descriptive
        # EmptyDataError rather than leaking a bare StopIteration.
        fieldnames = next(csv.reader(f), [])
    date_fieldnames = [fn for fn in fieldnames if fn.endswith("_dat")]
    return pd.read_csv(input_path, parse_dates=date_fieldnames)
def transform(cohort):
    """Transform data generated by study definition.

    Calls each preparation step below, in the listed order, with ``cohort``.
    The steps are called for their side effects on ``cohort`` (their return
    values are ignored), and the same object is returned. Order matters: for
    example ``add_high_level_ethnicity`` reads the ``ethnicity`` column that
    ``add_ethnicity`` writes.
    """
    pipeline = (
        drop_non_fm_sex,
        drop_over_120_age,
        add_imd_bands,
        add_ethnicity,
        add_high_level_ethnicity,
        add_missing_vacc_columns,
        add_vacc_dates,
        # The PRIMIS spec contains a number of overlapping age bands. Bands 1
        # to 12 are non-overlapping and we use these by default. We can add
        # other age bands as required.
        lambda frame: add_age_bands(frame, range(1, 12 + 1)),
        add_groupings,
        add_waves,
        add_extra_at_risk_cols,
    )
    for step in pipeline:
        step(cohort)
    return cohort
def drop_non_fm_sex(cohort):
    """Remove, in place, every row whose ``sex`` is neither "F" nor "M".

    Missing values are not in the allowed set, so those rows are removed too.
    """
    is_f_or_m = cohort["sex"].isin(["F", "M"])
    cohort.drop(cohort.index[~is_f_or_m.to_numpy()], inplace=True)
def drop_over_120_age(cohort):
    """Remove, in place, every row whose ``age`` is 120 or more.

    There are a handful of patients with a recorded date of birth of
    1900-01-01, which yields implausible ages.
    """
    too_old = cohort["age"] >= 120
    cohort.drop(index=cohort.index[too_old.to_numpy()], inplace=True)
def add_imd_bands(cohort, num_bands=5, max_imd_rank=32844):
    """Add IMD band from 1 (most deprived) to 5 (least deprived), or 0 if missing.

    The ``imd`` column is scaled onto ``(0, num_bands]`` and band ``b`` covers
    the interval ``(b - 1, b]``. Rows whose ``imd`` is missing (NaN), ``<= 0``,
    or greater than ``max_imd_rank`` get band 0.

    Args:
        cohort: DataFrame with an ``imd`` column. It is modified in place by
            adding an integer ``imd_band`` column.
        num_bands: Number of equal-width bands (the default of 5 gives
            quintiles).
        max_imd_rank: Highest possible ``imd`` value. 32844 is presumably the
            number of ranked LSOAs in England; confirm against the study
            definition.
    """
    scaled = cohort["imd"] * num_bands / max_imd_rank
    imd_band = pd.Series(0, index=cohort.index)
    for band in range(1, num_bands + 1):
        # The upper bound is inclusive so that the maximum rank, which scales
        # to exactly num_bands, lands in the top band instead of being
        # reported as missing. The lower bound stays exclusive so that an imd
        # of 0 maps to band 0. NaN fails both comparisons and stays 0.
        in_band = (band - 1 < scaled) & (scaled <= band)
        imd_band = imd_band.mask(in_band, band)
    # Assign a new column rather than mutating a column Series in place: such
    # in-place updates are not propagated back under pandas Copy-on-Write.
    cohort["imd_band"] = imd_band
def add_ethnicity(cohort):
    """Add an int8 ``ethnicity`` column using the bandings from the PRIMIS spec.

    ``eth2001`` already holds bands 1-16. Rows where it is missing fall
    through, in order, to:

    - 17 if ``non_eth2001_dat`` is recorded (any other ethnicity code);
    - 18 if ``eth_notgiptref_dat`` is recorded (not given - patient refused);
    - 19 if ``eth_notstated_dat`` is recorded (not stated);
    - 20 otherwise (not recorded).
    """
    fallback_bands = (
        ("non_eth2001_dat", 17),
        ("eth_notgiptref_dat", 18),
        ("eth_notstated_dat", 19),
    )
    ethnicity = cohort["eth2001"].copy()
    for date_col, band in fallback_bands:
        # Only rows still without a band can take this one.
        ethnicity = ethnicity.mask(ethnicity.isna() & cohort[date_col].notna(), band)
    ethnicity = ethnicity.where(ethnicity.notna(), 20)
    cohort["ethnicity"] = ethnicity.astype("int8")
def add_high_level_ethnicity(
    cohort, codelist_path="codelists/primis-covid19-vacc-uptake-eth2001.csv"
):
    """Add high-level ethnicity categories, based on bandings from PRIMIS spec.

    Maps each ``ethnicity`` value through the codelist's ``grouping_16_id`` to
    ``grouping_6_id`` correspondence and stores the result in a new int8
    ``high_level_ethnicity`` column. Ethnicity values that have no entry in
    the codelist (such as bands 17-20) become 6 (unknown).

    Args:
        cohort: DataFrame with an ``ethnicity`` column; modified in place.
        codelist_path: CSV codelist with ``grouping_16_id`` and
            ``grouping_6_id`` columns. Defaults to the PRIMIS eth2001 codelist.

    Raises:
        ValueError: if the codelist maps one ``grouping_16_id`` to more than
            one ``grouping_6_id``.
    """
    # Build the mapping from category (grouping_16_id) to high-level category
    # (grouping_6_id). The codelist has one row per code, so each category
    # appears many times and every occurrence must agree.
    category_to_high_level_category = {}
    with open(codelist_path, newline="") as f:
        for record in csv.DictReader(f):
            category = int(record["grouping_16_id"])
            high_level_category = int(record["grouping_6_id"])
            known = category_to_high_level_category.setdefault(
                category, high_level_category
            )
            # Explicit raise rather than assert, which `python -O` strips.
            if known != high_level_category:
                raise ValueError(
                    f"Codelist {codelist_path} maps grouping_16_id {category} "
                    f"to both grouping_6_id {known} and {high_level_category}"
                )

    # Vectorised lookup; unmapped ethnicities become 6 (unknown). The result
    # is assigned as a new column rather than by mutating a column Series in
    # place, which is not propagated back under pandas Copy-on-Write.
    cohort["high_level_ethnicity"] = (
        cohort["ethnicity"]
        .map(category_to_high_level_category)
        .fillna(6)
        .astype("int8")
    )
def add_missing_vacc_columns(cohort):
    """Add columns for vaccines that are not yet available but which are referenced by
    the spec.

    For each vaccine prefix ("mo", "nx", "jn", "gs", "vl"), adds all-NaN
    ``<prefix>d1rx_dat`` and ``<prefix>d2rx_dat`` columns, in that order.

    Raises:
        ValueError: if any of these columns already exists in ``cohort``
            (presumably because the study definition now supplies it), so
            that real data is never overwritten. ``cohort`` is left unchanged
            in that case.
    """
    vacc_cols = [
        f"{prefix}{dose_suffix}"
        for prefix in ["mo", "nx", "jn", "gs", "vl"]
        for dose_suffix in ("d1rx_dat", "d2rx_dat")
    ]
    # Check everything before adding anything. Use an explicit raise, not
    # assert, because `python -O` strips asserts and would silently
    # overwrite existing columns with NaN.
    present = [col for col in vacc_cols if col in cohort]
    if present:
        raise ValueError(f"Vaccine columns already present in cohort: {present}")
    for col in vacc_cols:
        cohort[col] = np.nan
def add_vacc_dates(cohort):
    """Add ``vacc1_dat`` and ``vacc2_dat``: the earliest first/second-dose dates.

    Each is the row-wise minimum of the matching ``covadm*_dat`` and
    ``covrx*_dat`` columns. Missing values are skipped, so a patient with only
    one of the two recorded still gets that date, and a patient with neither
    gets a missing value.
    """
    sources_by_target = {
        "vacc1_dat": ["covadm1_dat", "covrx1_dat"],
        "vacc2_dat": ["covadm2_dat", "covrx2_dat"],
    }
    for target, sources in sources_by_target.items():
        cohort[target] = cohort[sources].min(axis="columns")
def add_waves(cohort):
    """Add a ``wave`` column: the first wave (1-9) whose rule a patient meets, else 0.

    Rules are tried in wave order and a patient keeps the first wave they
    qualify for. For example, a care-home resident aged 85 is in wave 1, not
    wave 2.

    Reads the ``longres_dat``, ``age``, ``shield_group`` and ``atrisk_group``
    columns. The two ``*_group`` columns are presumably added by
    ``add_groupings``; confirm.
    """
    age = cohort["age"]
    rules = [
        # Wave 1: Residents in Care Homes
        # (The spec includes staff in care homes, but occupation codes are not
        # well recorded)
        (1, cohort["longres_dat"].notnull()),
        # Wave 2: Age 80 or over
        # (This spec includes frontline H&SC workers, but see above.)
        (2, age >= 80),
        # Wave 3: Age 75 - 79
        (3, age >= 75),
        # Wave 4: Clinically Extremely Vulnerable or age 70 - 74
        (4, cohort["shield_group"] | (age >= 70)),
        # Wave 5: Age 65 - 69
        (5, age >= 65),
        # Wave 6: Age 16-64 in a defined At Risk group
        (6, (age >= 16) & cohort["atrisk_group"]),
        # Wave 7: Age 60 - 64
        (7, age >= 60),
        # Wave 8: Age 55 - 59
        (8, age >= 55),
        # Wave 9: Age 50 - 54
        (9, age >= 50),
    ]
    wave = pd.Series(0, index=cohort.index)
    for number, eligible in rules:
        # Only patients not yet assigned (wave 0) can take this wave.
        wave = wave.mask((wave == 0) & eligible, number)
    # Assign a new column rather than mutating ``cohort["wave"]`` in place
    # through a column Series: such in-place updates are not propagated back
    # to the DataFrame under pandas Copy-on-Write.
    cohort["wave"] = wave
def add_extra_at_risk_cols(cohort):
    """Add boolean flags for the at-risk groups listed in ``extra_at_risk_cols``.

    Each group column is derived from the column obtained by replacing
    ``_group`` with ``_dat`` in its name. The flag is True where that column
    holds a value. Groups without a matching column are skipped.
    """
    for group_col in extra_at_risk_cols:
        source_col = group_col.replace("_group", "_dat")
        if source_col not in cohort.columns:
            # NOTE(review): a skipped group is still listed in necessary_cols,
            # so run() would raise KeyError when selecting it unless another
            # step adds the column; confirm this cannot happen.
            continue
        cohort[group_col] = cohort[source_col].notna()
if __name__ == "__main__":
    import sys

    # Usage: <script> <input_csv>
    # The first command-line argument is the input CSV. The output goes to
    # run()'s default output_path.
    input_csv = sys.argv[1]
    run(input_path=input_csv)