tpp_backend.py
import datetime
import enum
import hashlib
import math
import os
import re
import uuid
from functools import cached_property
from urllib import parse
import pandas
import structlog
from .codelistlib import expand_dmd_codelist
from .csv_utils import is_csv_filename, write_rows_to_csv
from .date_expressions import MSSQLDateFormatter
from .expressions import format_expression
from .log_utils import LoggingDatabaseConnection, log_execution_time, log_stats
from .mssql_utils import (
mssql_connection_params_from_url,
mssql_dbapi_connection_from_url,
mssql_fetch_table,
)
from .pandas_utils import dataframe_from_rows, dataframe_to_file
from .process_covariate_definitions import ISARIC_COLUMN_MAPPINGS
from .therapeutics_utils import ALLOWED_RISK_GROUPS
logger = structlog.get_logger()
# The batch size was chosen through a bit of unscientific trial-and-error and some
# guesswork. It may well need changing in future.
BATCH_SIZE = 32000
# Retry six times over ~90 minutes.
RETRIES = 6
SLEEP = 4
BACKOFF_FACTOR = 4
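# Illustrative arithmetic (not part of the original file), assuming mssql_fetch_table
# waits SLEEP * BACKOFF_FACTOR ** attempt seconds before each retry:
#   4 + 16 + 64 + 256 + 1024 + 4096 seconds ≈ 5,460s ≈ 91 minutes
# which is where the "~90 minutes" figure above comes from.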
# This table has its name for historical reasons, and reads slightly oddly: it should be
# interpreted as "allowed patients with regard to type one dissents"
ALLOWED_PATIENTS_TABLE = "AllowedPatientsWithTypeOneDissent"
class TPPBackend:
_db_connection = None
_current_column_name = None
include_t1oo = False
def __init__(
self,
database_url,
covariate_definitions,
temporary_database=None,
dummy_data=False,
):
if database_url is not None:
# set self.include_t1oo from the database url
self.database_url = self.modify_dsn(database_url)
else:
self.database_url = database_url
self.covariate_definitions = covariate_definitions
self.temporary_database = temporary_database
self.dummy_data = dummy_data
self.next_temp_table_id = 1
self._therapeutics_table_name = None
self.truncate_sql_logs = False
if self.covariate_definitions:
self.queries = self.get_queries(self.covariate_definitions)
else:
self.queries = []
def modify_dsn(self, dsn):
"""
Removes the `opensafely_include_t1oo` parameter if present and uses it to set
the `include_t1oo` attribute accordingly
"""
parts = parse.urlparse(dsn)
params = parse.parse_qs(parts.query, keep_blank_values=True)
include_t1oo_values = params.pop("opensafely_include_t1oo", [])
if len(include_t1oo_values) == 1:
self.include_t1oo = include_t1oo_values[0].lower() == "true"
elif len(include_t1oo_values) != 0:
raise ValueError(
"`opensafely_include_t1oo` parameter must not be supplied more than once"
)
new_query = parse.urlencode(params, doseq=True)
new_parts = parts._replace(query=new_query)
return parse.urlunparse(new_parts)
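# Illustrative example (not part of the original file): given a DSN such as
#     mssql://user:pass@host/db?opensafely_include_t1oo=true
# modify_dsn() strips the parameter and returns
#     mssql://user:pass@host/db
# while setting self.include_t1oo = True as a side effect. Supplying the
# parameter more than once raises a ValueError.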
def to_file(self, filename):
queries = list(self.queries)
# If we have a temporary database available we write results to a table
# there, download them, and then delete the table. This allows us to
# resume in the case of a failed download without rerunning the whole
# query
if self.temporary_database:
output_table = self.save_results_to_temporary_db(queries)
else:
output_table = "#final_output"
queries[-1] = (
f"-- Writing results into {output_table}\n"
f"SELECT * INTO {output_table} FROM ({queries[-1]}) t"
)
queries.append(f"CREATE INDEX ix_patient_id ON {output_table} (patient_id)")
self.execute_queries(queries)
results = mssql_fetch_table(
get_cursor=self._get_cursor,
table=output_table,
key_column="patient_id",
batch_size=BATCH_SIZE,
retries=RETRIES,
sleep=SLEEP,
backoff_factor=BACKOFF_FACTOR,
)
# Wrap the results stream in a function which captures unique IDs and
# logs progress
unique_ids = set()
total_rows = 0
def check_ids_and_log(results):
risk_group_variables = self.get_therapeutic_risk_groups()
nonlocal total_rows
headers = next(results)
id_column_index = headers.index("patient_id")
yield headers
for row in results:
if risk_group_variables:
row = self._clean_risk_groups(row, headers, risk_group_variables)
unique_ids.add(row[id_column_index])
total_rows += 1
if total_rows % 1000000 == 0:
logger.info(f"Downloaded {total_rows} results")
yield row
logger.info(f"Downloaded {total_rows} results")
results = check_ids_and_log(results)
# Special handling for CSV as we can stream this directly to disk
# without building a dataframe in memory
if is_csv_filename(filename):
with log_execution_time(
logger, description=f"write_rows_to_csv {filename}"
):
write_rows_to_csv(results, filename)
else:
with log_execution_time(
logger, description=f"Create df and write dataframe_to_file {filename}"
):
df = dataframe_from_rows(self.covariate_definitions, results)
dataframe_to_file(df, filename)
self.execute_queries(
[f"-- Deleting '{output_table}'\nDROP TABLE {output_table}"]
)
duplicates = total_rows - len(unique_ids)
if duplicates != 0:
raise RuntimeError(f"Duplicate IDs found ({duplicates} rows)")
def _get_cursor(self):
# If we've written results to a temporary database then we make a
# new connection each time we retry, which can fix issues with
# dropped connections. If we haven't then we're relying on
# session-scoped temporary tables which means we can't reconnect
# without losing everything
force_reconnect = bool(self.temporary_database)
return self.get_db_connection(force_reconnect=force_reconnect).cursor()
def _clean_risk_groups(self, row, keys, risk_group_variables):
"""
Check that risk group variables only contain verified allowed risk
groups, and remove duplicates
"""
indices = [keys.index(variable_name) for variable_name in risk_group_variables]
row = list(row)
for index in indices:
risk_groups = row[index]
cleaned_groups = ",".join(
{
group if group.lower() in ALLOWED_RISK_GROUPS else "other"
for group in risk_groups.split(",")
}
)
row[index] = cleaned_groups
return tuple(row)
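# Illustrative example (not part of the original file): assuming "imid" is
# listed in ALLOWED_RISK_GROUPS and "something_unexpected" is not, a cell
# containing "IMID,something_unexpected" would be rewritten as "IMID,other"
# (element order may vary because a set comprehension is used to drop
# duplicates).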
def to_dicts(self, convert_to_strings=True):
result = self.execute_queries(self.queries)
keys = [x[0] for x in result.description]
# This checks any risk group variables and replaces disallowed values
risk_group_variables = self.get_therapeutic_risk_groups()
output = []
for row in result:
if risk_group_variables:
row = self._clean_risk_groups(row, keys, risk_group_variables)
if convert_to_strings:
# Convert all values to str as that's what will end up in a CSV
row = map(str, row)
output.append(dict(zip(keys, row)))
unique_ids = set(item["patient_id"] for item in output)
duplicates = len(output) - len(unique_ids)
if duplicates != 0:
raise RuntimeError(f"Duplicate IDs found ({duplicates} rows)")
return output
def to_sql(self):
"""
Generate a single SQL string.
Useful for debugging, optimising, etc.
"""
return "\nGO\n\n".join(self.queries)
def save_results_to_temporary_db(self, queries):
"""
Sometimes there are glitches (network issues?) which occur when
downloading large result sets. To avoid having to recompute these each
time we can write the results to a table which we only delete once
we've fully downloaded its contents. If the download is interrupted
then subsequent runs will pick up the table and download it without
having to re-run all the queries.
"""
assert self.temporary_database
# We're using the hash of all the queries and the database name as a
# cache key. Obviously this doesn't take into account the fact that the
# data itself may change, but for our purposes this doesn't matter:
# this is designed to be a very short-lived cache which is deleted as
# soon as the data is successfully downloaded. We need to include the
# database name because a single server may contain multiple databases
# (e.g. full data and sample data) which share a single temporary
# database.
hash_elements = queries + [
mssql_connection_params_from_url(self.database_url)["database"]
]
query_hash = hashlib.sha1("\n".join(hash_elements).encode("utf8")).hexdigest()
output_table = f"{self.temporary_database}..DataExtract_{query_hash}"
logger.info(f"Checking for existing results in '{output_table}'")
if not self.table_exists(output_table):
logger.info(
"No existing results found, running queries to generate new results"
)
# We want to raise an error now, rather than waiting until we've
# run all the other queries
self.assert_database_exists_and_is_writable(self.temporary_database)
queries = list(queries)
final_query = queries.pop()
self.execute_queries(queries)
# We need to run the final query in a transaction so that we don't end up
# with an empty output table in the event that the query fails. See:
# https://docs.microsoft.com/en-us/sql/t-sql/queries/select-into-clause-transact-sql?view=sql-server-ver15#remarks
conn = self.get_db_connection()
logger.info(f"Writing results into temporary table '{output_table}'")
# pymssql's autocommit implementation is different to CTDS and MS
# pyodbc. Firstly, it's a method rather than a property. Secondly,
# when autocommit is off, it implicitly starts a transaction from
# the client, so we do not need to do so ourselves. CTDS does not seem to do
# this, which means we would have to begin the transaction manually.
previous_autocommit = conn.autocommit_state
conn.autocommit(False)
cursor = conn.cursor()
cursor.execute(f"SELECT * INTO {output_table} FROM ({final_query}) t")
cursor.execute(f"CREATE INDEX ix_patient_id ON {output_table} (patient_id)")
conn.commit()
conn.autocommit(previous_autocommit)
logger.info(f"Downloading results from '{output_table}'")
else:
logger.info(f"Downloading results from previous run in '{output_table}'")
return output_table
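# Illustrative example (not part of the original file): the cache table ends up
# named something like <temporary_database>..DataExtract_<40-character SHA-1 hex>,
# so re-running the same queries against the same database reuses the existing
# table, while any change to the queries (or database name) produces a new one.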
def table_exists(self, table_name):
# We don't have access to sys.tables so this seems like the simplest
# way of testing for table existence
cursor = self.get_db_connection().cursor()
try:
cursor.execute(f"SELECT 1 FROM {table_name}")
list(cursor)
return True
# Because we don't want to depend on a specific database driver we
# can't catch a specific exception class here
except Exception as e:
if "Invalid object name" in str(e):
return False
else:
raise
def assert_database_exists_and_is_writable(self, db_name):
# As above, there's probably a better way of doing this
test_table = f"{db_name}..test_{uuid.uuid4().hex}"
cursor = self.get_db_connection().cursor()
try:
cursor.execute(f"SELECT * INTO {test_table} FROM (select 1 as foo) t")
cursor.execute(f"DROP TABLE {test_table}")
# Because we don't want to depend on a specific database driver we
# can't catch a specific exception class here
except Exception as e:
if "Database does not exist" in str(e):
raise RuntimeError(f"Temporary database '{db_name}' does not exist")
else:
raise
def get_db_connection(self, force_reconnect=False):
if self._db_connection:
if not force_reconnect:
return self._db_connection
else:
self._db_connection.close()
self._db_connection = LoggingDatabaseConnection(
logger,
mssql_dbapi_connection_from_url(self.database_url),
truncate=self.truncate_sql_logs,
time_stats=True,
)
return self._db_connection
def close(self):
if self._db_connection:
self._db_connection.close()
self._db_connection = None
def get_therapeutic_risk_groups(self):
therapeutics_risk_group_variables = set()
for name, (query_type, query_args) in self.covariate_definitions.items():
if (
query_type == "with_covid_therapeutics"
and query_args["returning"] == "risk_group"
):
therapeutics_risk_group_variables.add(name)
return therapeutics_risk_group_variables
def get_queries(self, covariate_definitions):
output_columns = {}
table_queries = {}
for name, (query_type, query_args) in covariate_definitions.items():
# So we can safely mutate these below
query_args = query_args.copy()
# These arguments are not used in generating column data and the
# corresponding functions do not accept them
query_args.pop("return_expectations", None)
is_hidden = query_args.pop("hidden", False)
# Fixed values are the simplest case
if query_type == "fixed_value":
output_columns[name] = self.get_fixed_value_expression(**query_args)
# `categorised_as` columns don't generate their own table query,
# they're just a CASE expression over columns generated by other
# queries
elif query_type == "categorised_as":
output_columns[name] = self.get_case_expression(
output_columns, **query_args
)
# `value_from` columns also don't generate a table, they just take
# a value from another table
elif query_type == "value_from":
assert query_args["source"] in table_queries
output_columns[name] = self.get_column_expression(**query_args)
# As do `aggregate_of` columns
elif query_type == "aggregate_of":
output_columns[name] = self.get_aggregate_expression(
output_columns, **query_args
)
else:
column_args = pop_keys_from_dict(
query_args, ["column_type", "date_format"]
)
sql_list = self.get_queries_for_column(
name, query_type, query_args, output_columns
)
# Wrap the final SELECT query so that it writes its results
# into the appropriate temporary table
sql_list[-1] = (
f"-- Query for {name}\n"
f"SELECT * INTO #{name} FROM ({sql_list[-1]}) t"
)
# Add the index query
sql_list.append(
f"CREATE CLUSTERED INDEX patient_id_ix ON #{name} (patient_id)"
)
table_queries[name] = sql_list
# The first column should always be patient_id so we can join on it
output_columns[name] = self.get_column_expression(
source=name,
returning=query_args.get("returning", "value"),
**column_args,
)
output_columns[name].is_hidden = is_hidden
# If the population query defines its own temporary table then we use
# that as the primary table to query against and left join everything
# else against that. Otherwise, we use the `Patient` table.
if "population" in table_queries:
primary_table = "#population"
patient_id_expr = ColumnExpression("#population.patient_id")
else:
primary_table = "Patient"
patient_id_expr = ColumnExpression("Patient.Patient_ID")
# Insert `patient_id` as the first column
output_columns = dict(patient_id=patient_id_expr, **output_columns)
output_columns_str = ",\n ".join(
f"{expr} AS [{name}]"
for (name, expr) in output_columns.items()
if not expr.is_hidden and name != "population"
)
joins = [
f"LEFT JOIN #{name} ON #{name}.patient_id = {patient_id_expr}"
for name in table_queries
if name != "population"
]
wheres = [f'{output_columns["population"]} = 1']
if not self.include_t1oo:
# If this query has not been explicitly flagged as including T1OO patients
# then we add an extra JOIN on the allowed patients table to ensure that we
# only include patients which exist in that table
#
# PLEASE NOTE: This logic is referenced in our public documentation, so if
# we make any changes here we should ensure that the documentation is kept
# up-to-date:
# https://github.com/opensafely/documentation/blob/ea2e1645/docs/type-one-opt-outs.md
#
# From Cohort Extractor's point of view, the construction of the "allowed
# patients" table is opaque. For discussion of the approach currently used
# to populate this see:
# https://docs.google.com/document/d/1nBAwDucDCeoNeC5IF58lHk6LT-RJg6YZRp5RRkI7HI8/
joins.append(
f"JOIN {ALLOWED_PATIENTS_TABLE} ON {ALLOWED_PATIENTS_TABLE}.Patient_ID = {patient_id_expr}",
)
# This condition is theoretically redundant given the JOIN above, but
# it feels safer to be explicit here
wheres.append(
f"{ALLOWED_PATIENTS_TABLE}.Patient_ID IS NOT NULL",
)
joins_str = "\n ".join(joins)
where_str = " AND ".join(wheres)
joined_output_query = f"""
-- Join all columns for final output
SELECT
{output_columns_str}
FROM
{primary_table}
{joins_str}
WHERE {where_str}
"""
all_queries = []
for sql_list in table_queries.values():
all_queries.extend(sql_list)
all_queries.append(joined_output_query)
log_stats(
logger,
output_column_count=len(output_columns),
table_count=len(table_queries),
table_joins_count=len(joins),
)
return all_queries
def get_column_expression(self, column_type, source, returning, date_format=None):
default_value = self.get_default_value_for_type(column_type)
# Zero is a legitimate IMD return value so we can't use it to indicate NULL.
# Instead we use -1, which matches the value used in the database when there's
# an address record with an unknown IMD. Obviously implementing this as a
# special case here is terrible, but there's no other way of doing it without
# serious refactoring elsewhere.
if returning == "index_of_multiple_deprivation":
default_value = -1
column_expr = f"#{source}.{escape_identifer(returning)}"
if column_type == "date":
column_expr = truncate_date(column_expr, date_format)
return ColumnExpression(
f"ISNULL({column_expr}, {quote(default_value)})",
type=column_type,
default_value=default_value,
source_tables=[f"#{source}"],
date_format=date_format,
)
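# Illustrative example (not part of the original file): for a float column
# sourced from the temporary table #bmi with returning="value", this yields an
# expression along the lines of
#     ISNULL(#bmi.value, 0.0)
# Date columns are additionally passed through truncate_date() to apply the
# requested date_format.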
def get_fixed_value_expression(self, value, column_type, date_format=None):
return ColumnExpression(
quote(value, reformat_dates=False),
type=column_type,
default_value=self.get_default_value_for_type(column_type),
source_tables=[],
date_format=date_format,
)
def get_default_value_for_type(self, column_type):
if column_type == "date":
return ""
elif column_type == "str":
return ""
elif column_type == "bool":
return 0
elif column_type == "int":
return 0
elif column_type == "float":
return 0.0
elif column_type == "bytes":
return 0
else:
raise ValueError(f"Unhandled column type: {column_type}")
def get_case_expression(
self, other_columns, column_type, category_definitions, date_format=None
):
category_definitions = category_definitions.copy()
defaults = [k for (k, v) in category_definitions.items() if v == "DEFAULT"]
if len(defaults) > 1:
raise ValueError("At most one default category can be defined")
if len(defaults) == 1:
default_value = defaults[0]
category_definitions.pop(default_value)
else:
raise ValueError(
"At least one category must be given the definition 'DEFAULT'"
)
# We pass the `reformat_dates=False` option to the quoting function
# here to preserve date-like strings in the user-supplied ISO format,
# as opposed to converting to MSSQL-friendly format which we otherwise
# do by default. The distinction here is between dates used as inputs
# and dates used as outputs. We always need to supply input dates in a
# format which will be consistently parsed by MSSQL independent of locale
# settings (amazingly, this does not include ISO format). But we always
# want the dates we output to be in ISO format. In almost all cases
# user-supplied dates are functioning as inputs, which is why
# reformatting is our default behaviour; but in this context they
# function as outputs so we need to do something different. (Note that
# if these outputs later go on to be supplied as inputs to other
# functions then the `date_expressions` code will ensure they are
# parsed correctly.)
def quote_category(value):
return quote(value, reformat_dates=False)
# For each column already defined, determine its corresponding "empty"
# value (i.e. the default value for that column's type). This allows us
# to support implicit boolean conversion because we know what the
# "falsey" value for each column should be.
empty_value_map = {
name: column.default_value for name, column in other_columns.items()
}
clauses = []
tables_used = set()
for category, expression in category_definitions.items():
# The column references in the supplied expression need to be
# rewritten to ensure they refer to the correct CTE. The formatting
# function also ensures that the expression matches the very
# limited subset of SQL we support here.
formatted_expression, names_used = format_expression(
expression, other_columns, empty_value_map=empty_value_map
)
clauses.append(
f"WHEN ({formatted_expression}) THEN {quote_category(category)}"
)
# Record all the source tables used in evaluating the expression
for name in names_used:
tables_used.update(other_columns[name].source_tables)
return ColumnExpression(
f"CASE {' '.join(clauses)} ELSE {quote_category(default_value)} END",
type=column_type,
# Note, confusingly, this is not the same as the `default_value`
# used above. Above it refers to the value the case-expression will
# default to in case of no match. Below it refers to the "empty"
# value for the column type which is almost always the empty string
# apart from bools and ints where it's zero.
default_value=self.get_default_value_for_type(column_type),
source_tables=list(tables_used),
date_format="YYYY-MM-DD" if column_type == "date" else None,
)
def get_aggregate_expression(
self, other_columns, column_type, column_names, aggregate_function
):
assert aggregate_function in ("MIN", "MAX")
default_value = self.get_default_value_for_type(column_type)
# In other databases we could use GREATEST/LEAST to aggregate over
# columns, but for MSSQL we need to use this Table Value Constructor
# trick: https://stackoverflow.com/a/6871572
components = ", ".join(f"({other_columns[name]})" for name in column_names)
aggregate_expression = (
f"SELECT {aggregate_function}(value)"
f" FROM (VALUES {components}) AS _table(value)"
# This is slightly awkward: MIN and MAX ignore NULL values which is
# the behaviour we want here. For instance, given a column like:
#
# minimum_of("covid_test_date", "hospital_admission_date")
#
# We want this value to be equal to "covid_test_date" for patients
# which have not been admitted to hospital (i.e. patients for which
# "hospital_admission_date" is NULL).
#
# However, the values we have here have already been passed through
# the `ISNULL` function to replace NULLs with a default value
# (which for dates is the empty string). This means that if we just
# took the minimum over these values in the example above, we'd get
# the empty string from "hospital_admission_date" rather than the
# value in "covid_test_date".
#
# To work around this we add a WHERE clause to filter out any
# default values (essentially treating them as if they were NULL).
# This gives us the result we want but it does mean we can't
# distinguish e.g. a recorded value of 0.0 from a missing value.
# This, however, is a general problem with the way we handle NULLs
# in our system, and so we're not introducing any new difficulty
# here. (It's also unlikely to be a problem in practice.)
f" WHERE value != {quote(default_value)}"
)
# Keep track of all source tables used in evaluating this aggregate
tables_used = set()
for name in column_names:
tables_used.update(other_columns[name].source_tables)
return ColumnExpression(
f"ISNULL(({aggregate_expression}), {quote(default_value)})",
type=column_type,
default_value=default_value,
source_tables=list(tables_used),
# It's already been checked that date_format is consistent across
# the source columns, so we just grab the first one and use the
# date_format from that
date_format=other_columns[column_names[0]].date_format,
)
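# Illustrative example (not part of the original file): for an aggregation such
# as maximum_of("colA", "colB") this builds SQL roughly of the form
#     ISNULL((SELECT MAX(value)
#             FROM (VALUES (<colA expr>), (<colB expr>)) AS _table(value)
#             WHERE value != ''), '')
# i.e. the Table Value Constructor trick described above, with default values
# excluded so they behave like NULLs.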
def execute_queries(self, queries):
cursor = self.get_db_connection().cursor()
for query in queries:
comment_match = re.match(r"^\s*\-\-\s*(.+)\n", query)
if comment_match:
event_name = comment_match.group(1)
logger.info(f"Running: {event_name}")
else:
event_name = None
cursor.execute(query, log_desc=event_name)
return cursor
def get_queries_for_column(
self, column_name, query_type, query_args, output_columns
):
method_name = f"patients_{query_type}"
method = getattr(self, method_name)
# We need to make available the SQL expression for each column in the
# output so far in case date expressions in the current column need to
# refer to these
self.output_columns = output_columns
# Keep track of the current column name for debugging purposes
self._current_column_name = column_name
return_value = method(**query_args)
self._current_column_name = None
# We want to allow the query methods to return just a single SQL string
# which we automatically wrap in a list
if isinstance(return_value, str):
return_value = [return_value]
return return_value
def create_codelist_table(self, codelist, case_sensitive=True):
table_name = self.get_temp_table_name("codelist")
if codelist.has_categories:
values = list(set(codelist))
else:
values = list({(code, "") for code in codelist})
# Depending on the case-sensitivity of the code system the columns in question
# use different collations and we need to use a matching one here
collation = "Latin1_General_BIN" if case_sensitive else "Latin1_General_CI_AS"
max_code_len = max(len(code) for (code, category) in values)
queries = [
f"""
-- Uploading codelist for {self._current_column_name}
CREATE TABLE {table_name} (
code VARCHAR({max_code_len}) COLLATE {collation} NOT NULL PRIMARY KEY,
category VARCHAR(MAX)
)
"""
]
queries += make_batches_of_insert_statements(
table_name, ("code", "category"), values
)
return table_name, queries
def get_temp_table_name(self, suffix):
# The hash prefix indicates a temporary table
table_name = f"#tmp{self.next_temp_table_id}_"
self.next_temp_table_id += 1
# We include the current column name if available for ease of debugging
if self._current_column_name:
table_name += f"{self._current_column_name}_"
table_name += suffix
return table_name
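# Illustrative example (not part of the original file): with
# _current_column_name set to "bmi" and suffix "codelist", the first such table
# would be named "#tmp1_bmi_codelist".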
def get_date_condition(self, table, date_expr, between):
"""
Takes a table name, an SQL expression representing a date (which can
just be a column name on the table, or something more complicated) and
a date interval.
Returns two fragments of SQL: a "condition" and a "join"
The condition is SQL which evaluates true when `date_expr` is in the
supplied period.
The join provides the (possibly empty) JOINs which need to be appended
to "table" in order to evaluate the condition.
"""
if between is None:
between = (None, None)
min_date, max_date = between
min_date_expr, join_tables1 = self.date_ref_to_sql_expr(min_date)
max_date_expr, join_tables2 = self.date_ref_to_sql_expr(max_date)
date_expr = MSSQLDateFormatter.cast_as_date(date_expr)
min_date_expr = MSSQLDateFormatter.cast_as_date(min_date_expr)
max_date_expr = MSSQLDateFormatter.cast_as_date(max_date_expr)
joins = [
f"LEFT JOIN {join_table}\n"
f"ON {join_table}.patient_id = {table}.patient_id"
for join_table in set(join_tables1 + join_tables2)
]
join_str = "\n".join(joins)
if min_date_expr is not None and max_date_expr is not None:
return (
f"{date_expr} BETWEEN {min_date_expr} AND {max_date_expr}",
join_str,
)
elif min_date_expr is not None:
return f"{date_expr} >= {min_date_expr}", join_str
elif max_date_expr is not None:
return f"{date_expr} <= {max_date_expr}", join_str
else:
return "1=1", join_str
def get_date_sql(self, table, *date_expressions):
"""
Given a table name and one or more date expressions return the
corresponding SQL expressions followed by a fragment of SQL supplying
any necessary JOINs
"""
all_join_tables = set()
sql_expressions = []
for date_expression in date_expressions:
assert date_expression is not None
sql_expression, join_tables = self.date_ref_to_sql_expr(date_expression)
sql_expressions.append(sql_expression)
all_join_tables.update(join_tables)
joins = [
f"LEFT JOIN {join_table}\n"
f"ON {join_table}.patient_id = {table}.patient_id"
for join_table in all_join_tables
]
join_str = "\n".join(joins)
return (*sql_expressions, join_str)
def date_ref_to_sql_expr(self, date):
"""
Given a date reference return its corresponding SQL expression,
together with a list of any tables to which this expression refers
"""
if date is None:
return None, []
# Simple date literals
if is_iso_date(date):
return quote(date), []
# More complicated date expressions which reference other tables
formatter = MSSQLDateFormatter(self.output_columns)
date_expr, column_name = formatter(date)
tables = self.output_columns[column_name].source_tables
return date_expr, tables
def patients_age_as_of(self, reference_date):
date_expr, date_joins = self.get_date_sql("Patient", reference_date)
return f"""
SELECT
Patient.Patient_ID AS patient_id,
CASE WHEN
dateadd(year, datediff (year, DateOfBirth, {date_expr}), DateOfBirth) > {date_expr}
THEN
datediff(year, DateOfBirth, {date_expr}) - 1
ELSE
datediff(year, DateOfBirth, {date_expr})
END AS value
FROM Patient
{date_joins}
"""
def patients_date_of_birth(self):
return """
SELECT Patient_ID AS patient_id, DateOfBirth AS value FROM Patient
"""
def patients_sex(self):
return """
SELECT
Patient_ID AS patient_id,
Sex AS value
FROM Patient
"""
def patients_all(self):
"""
All patients
"""
return """
SELECT Patient_ID AS patient_id, 1 AS value
FROM Patient
"""
def patients_random_sample(self, percent):
"""
A random sample of approximately `percent` patients
"""
# See
# https://docs.microsoft.com/en-us/previous-versions/software-testing/cc441928(v=msdn.10)?redirectedfrom=MSDN
# A TABLESAMPLE clause is more efficient, but its
# approximations don't work with small numbers, and we might
# want to use this method for small numbers (and certainly do
# in the tests!)
assert percent, "Must specify a percentage greater than zero"
return f"""
SELECT Patient_ID, 1 AS value
FROM Patient
WHERE (ABS(CAST(
(BINARY_CHECKSUM(*) *
RAND()) as int)) % 100) < {quote(percent)}
"""
def patients_most_recent_bmi(
self,
# Set date limits
between=None,
minimum_age_at_measurement=16,
# Add an additional column indicating when measurement was taken
include_date_of_match=False,
):
"""
Return patients' most recent BMI (in the defined period) either
computed from weight and height measurements or, where they are not
available, from recorded BMI values. Measurements taken when a patient
was below the minimum age are ignored. The height measurement can be
taken before (but not after) the defined period as long as the patient
was over the minimum age at the time.
Optionally returns an additional column with the date of the
measurement. If the BMI is computed from weight and height then we use
the date of the weight measurement for this.
"""
# From https://github.com/ebmdatalab/tpp-sql-notebook/issues/10:
#
# 1) BMI calculated from last recorded height and weight
#
# 2) If height and weight are not available, then take the latest
# recorded BMI. Both values must be recorded when the patient
# is >=16, weight must be within the last 10 years
date_condition, date_joins = self.get_date_condition(
"CodedEvent", "ConsultationDate", between
)
bmi_code = "22K.."
# XXX these two sets of codes need validating. The final code in
# each list is the canonical version according to TPP
weight_codes = [
"X76C7", # Concept containing "body weight" terms:
"22A..", # O/E weight
]
height_codes = [
"XM01E", # Concept containing height/length/stature/growth terms:
"229..", # O/E height
]
bmi_cte = f"""
SELECT t.Patient_ID, t.BMI, t.ConsultationDate
FROM (
SELECT CodedEvent.Patient_ID, NumericValue AS BMI, ConsultationDate,
ROW_NUMBER() OVER (
PARTITION BY CodedEvent.Patient_ID ORDER BY ConsultationDate DESC, CodedEvent_ID
) AS rownum
FROM CodedEvent
{date_joins}
WHERE CTV3Code = {quote(bmi_code)} AND {date_condition}
) t
WHERE t.rownum = 1
"""
patients_cte = """
SELECT Patient_ID, DateOfBirth
FROM Patient
"""
weight_codes_sql = codelist_to_sql(weight_codes)
weights_cte = f"""
SELECT t.Patient_ID, t.weight, t.ConsultationDate
FROM (
SELECT CodedEvent.Patient_ID, NumericValue AS weight, ConsultationDate,
ROW_NUMBER() OVER (
PARTITION BY CodedEvent.Patient_ID ORDER BY ConsultationDate DESC, CodedEvent_ID
) AS rownum
FROM CodedEvent
{date_joins}
WHERE CTV3Code IN ({weight_codes_sql}) AND {date_condition}
) t
WHERE t.rownum = 1
"""
height_codes_sql = codelist_to_sql(height_codes)
# The height date restriction is different from the others. We don't
# mind using old values as long as the patient was old enough when they
# were taken.
height_date_condition, height_date_joins = self.get_date_condition(
"CodedEvent",
"ConsultationDate",
remove_lower_date_bound(between),
)
heights_cte = f"""
SELECT t.Patient_ID, t.height, t.ConsultationDate
FROM (
SELECT CodedEvent.Patient_ID, NumericValue AS height, ConsultationDate,
ROW_NUMBER() OVER (
PARTITION BY CodedEvent.Patient_ID ORDER BY ConsultationDate DESC, CodedEvent_ID
) AS rownum
FROM CodedEvent
{height_date_joins}
WHERE CTV3Code IN ({height_codes_sql}) AND {height_date_condition}
) t
WHERE t.rownum = 1
"""
min_age = int(minimum_age_at_measurement)
return f"""
SELECT
patients.Patient_ID AS patient_id,
ROUND(COALESCE(weight/SQUARE(NULLIF(height, 0)), bmis.BMI), 1) AS value,
CASE
WHEN weight IS NULL OR height IS NULL THEN bmis.ConsultationDate
ELSE weights.ConsultationDate
END AS date
FROM ({patients_cte}) AS patients
LEFT JOIN ({weights_cte}) AS weights
ON weights.Patient_ID = patients.Patient_ID AND DATEDIFF(YEAR, patients.DateOfBirth, weights.ConsultationDate) >= {min_age}
LEFT JOIN ({heights_cte}) AS heights
ON heights.Patient_ID = patients.Patient_ID AND DATEDIFF(YEAR, patients.DateOfBirth, heights.ConsultationDate) >= {min_age}
LEFT JOIN ({bmi_cte}) AS bmis
ON bmis.Patient_ID = patients.Patient_ID AND DATEDIFF(YEAR, patients.DateOfBirth, bmis.ConsultationDate) >= {min_age}
-- XXX maybe add a "WHERE NULL..." here
"""
def _summarised_recorded_value(
self,
codelist,
on_most_recent_day_of_measurement,
between,
include_date_of_match,
summary_function,
):
coded_event_table, coded_event_column = coded_event_table_column(codelist)
date_condition, date_joins = self.get_date_condition(
coded_event_table, "ConsultationDate", between
)
codelist_sql = codelist_to_sql(codelist)
if on_most_recent_day_of_measurement:
# The first query finds, for each patient, the most recent day on which
# they've had a measurement. The final query selects, for each patient, the
# aggregated value on that day.
# Note, there's a CAST in the JOIN condition but apparently SQL Server can still
# use an index for this. See: https://stackoverflow.com/a/25564539
latest_date_table = self.get_temp_table_name("latest_date")
return [
f"""
SELECT {coded_event_table}.Patient_ID, CAST(MAX(ConsultationDate) AS date) AS day
INTO {latest_date_table}
FROM {coded_event_table}
{date_joins}
WHERE {coded_event_column} IN ({codelist_sql}) AND {date_condition}
GROUP BY {coded_event_table}.Patient_ID
""",
f"""
CREATE CLUSTERED INDEX ix ON {latest_date_table} (Patient_ID, day)
""",
f"""
SELECT
{latest_date_table}.Patient_ID AS patient_id,
{summary_function}({coded_event_table}.NumericValue) AS value,
{latest_date_table}.day AS date
FROM {latest_date_table}
LEFT JOIN {coded_event_table}
ON (
{coded_event_table}.Patient_ID = {latest_date_table}.Patient_ID
AND CAST({coded_event_table}.ConsultationDate AS date) = {latest_date_table}.day
)
WHERE
{coded_event_table}.{coded_event_column} IN ({codelist_sql})
GROUP BY {latest_date_table}.Patient_ID, {latest_date_table}.day
""",
]