Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 60 additions & 9 deletions pyard/data_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
# List of expression characters
expression_chars = ['N', 'Q', 'L', 'S']

ars_mapping_tables = ['dup_g', 'dup_lg', 'dup_lgx', 'g_group', 'lg_group', 'lgx_group']
ars_mapping_tables = ['dup_g', 'dup_lg', 'dup_lgx', 'g_group', 'lg_group', 'lgx_group', 'exon_group']
ARSMapping = namedtuple("ARSMapping", ars_mapping_tables)


Expand All @@ -54,7 +54,10 @@ def get_n_field_allele(allele: str, n: int) -> str:
last_char = allele[-1]
fields = allele.split(':')
if last_char in expression_chars and len(fields) > n:
return ':'.join(fields[0:n]) + last_char

# don't actually do this; it makes things like A*02:01:01L which is invalid
#return ':'.join(fields[0:n]) + last_char
return ':'.join(fields[0:n])
else:
return ':'.join(fields[0:n])

Expand All @@ -75,8 +78,8 @@ def generate_ars_mapping(db_connection: sqlite3.Connection, imgt_version):
g_group = db.load_dict(db_connection, table_name='g_group', columns=('allele', 'g'))
lg_group = db.load_dict(db_connection, table_name='lg_group', columns=('allele', 'lg'))
lgx_group = db.load_dict(db_connection, table_name='lgx_group', columns=('allele', 'lgx'))
return ARSMapping(dup_g=dup_g, dup_lg=dup_lg, dup_lgx=dup_lgx,
g_group=g_group, lg_group=lg_group, lgx_group=lgx_group)
exon_group = db.load_dict(db_connection, table_name='exon_group', columns=('allele', 'exon'))
return ARSMapping(dup_g=dup_g, dup_lg=dup_lg, dup_lgx=dup_lgx, g_group=g_group, lg_group=lg_group, lgx_group=lgx_group, exon_group=exon_group)

ars_url = f'{IMGT_HLA_URL}{imgt_version}/wmda/hla_nom_g.txt'
df = pd.read_csv(ars_url, skiprows=6, names=["Locus", "A", "G"], sep=";").dropna()
Expand Down Expand Up @@ -143,18 +146,21 @@ def generate_ars_mapping(db_connection: sqlite3.Connection, imgt_version):
])
lgx_group = df_lgx.set_index('A')['lgx'].to_dict()

df_exon = pd.concat([ df[['A', '3d']].rename(columns={'3d': 'exon'}), ])
exon_group = df_exon.set_index('A')['exon'].to_dict()

db.save_dict(db_connection, table_name='dup_g', dictionary=dup_g, columns=('allele', 'g_group'))
db.save_dict(db_connection, table_name='dup_lg', dictionary=dup_lg, columns=('allele', 'lg_group'))
db.save_dict(db_connection, table_name='dup_lgx', dictionary=dup_lgx, columns=('allele', 'lgx_group'))
db.save_dict(db_connection, table_name='g_group', dictionary=g_group, columns=('allele', 'g'))
db.save_dict(db_connection, table_name='lg_group', dictionary=lg_group, columns=('allele', 'lg'))
db.save_dict(db_connection, table_name='lgx_group', dictionary=lgx_group, columns=('allele', 'lgx'))
db.save_dict(db_connection, table_name='exon_group', dictionary=exon_group, columns=('allele', 'exon'))

return ARSMapping(dup_g=dup_g, dup_lg=dup_lg, dup_lgx=dup_lgx,
g_group=g_group, lg_group=lg_group, lgx_group=lgx_group)
return ARSMapping(dup_g=dup_g, dup_lg=dup_lg, dup_lgx=dup_lgx, g_group=g_group, lg_group=lg_group, lgx_group=lgx_group, exon_group=exon_group)


def generate_alleles_and_xx_codes(db_connection: sqlite3.Connection, imgt_version):
def generate_alleles_and_xx_codes_and_who(db_connection: sqlite3.Connection, imgt_version):
"""
Checks to see if there's already an allele list file for the `imgt_version`
in the `data_dir` directory. If not, will download the file and create
Expand Down Expand Up @@ -186,10 +192,17 @@ def generate_alleles_and_xx_codes(db_connection: sqlite3.Connection, imgt_versio

if db.table_exists(db_connection, 'alleles'):
valid_alleles = db.load_set(db_connection, 'alleles')
who_alleles = db.load_set(db_connection, 'who_alleles')

who_codes = db.load_dict(db_connection, 'who_group',
('who', 'allele_list'))
who_codes = {k: v.split('/') for k, v in who_codes.items()}

xx_codes = db.load_dict(db_connection, 'xx_codes',
('allele_1d', 'allele_list'))
xx_codes = {k: v.split('/') for k, v in xx_codes.items()}
return valid_alleles, xx_codes

return valid_alleles, who_alleles, xx_codes, who_codes

# Create a Pandas DataFrame from the mac_code list file
# Skip the header (first 6 lines) and use only the Allele column
Expand All @@ -203,6 +216,8 @@ def generate_alleles_and_xx_codes(db_connection: sqlite3.Connection, imgt_versio
# All 2-field, 3-field and the original Alleles are considered valid alleles
allele_df['2d'] = allele_df['Allele'].apply(get_2field_allele)
allele_df['3d'] = allele_df['Allele'].apply(get_3field_allele)
# this says all 3rd and 2nd field versions of longer alleles are valid
who_alleles = set(allele_df['Allele'])
valid_alleles = set(allele_df['Allele']). \
union(set(allele_df['2d'])). \
union(set(allele_df['3d']))
Expand All @@ -216,6 +231,9 @@ def generate_alleles_and_xx_codes(db_connection: sqlite3.Connection, imgt_versio
.apply(lambda x: list(x['Allele'])) \
.to_dict()

# Save this version of the who
#db.save_set(db_connection, 'who', valid_alleles, 'allele')

# Update xx codes with broads and splits
for broad, splits in broad_splits_dna_mapping.items():
for split in splits:
Expand All @@ -226,13 +244,46 @@ def generate_alleles_and_xx_codes(db_connection: sqlite3.Connection, imgt_versio

# Save this version of the valid alleles
db.save_set(db_connection, 'alleles', valid_alleles, 'allele')
# Save this version of the who alleles
db.save_set(db_connection, 'who_alleles', who_alleles, 'allele')
# Save this version of xx codes
flat_xx_codes = {k: '/'.join(sorted(v, key=functools.cmp_to_key(smart_sort_comparator)))
for k, v in xx_codes.items()}
db.save_dict(db_connection, 'xx_codes', flat_xx_codes,
('allele_1d', 'allele_list'))

return valid_alleles, xx_codes
# W H O

# Create who mapping from the unique alleles in the 2-field column
who_df1 = pd.DataFrame(allele_df['Allele'].unique(), columns=['Allele'])
who_df1['1d'] = allele_df['Allele'].apply(lambda x: x.split(":")[0])
who_df2 = pd.DataFrame(allele_df['Allele'].unique(), columns=['Allele'])
who_df2['2d'] = allele_df['Allele'].apply(get_2field_allele)
who_df3 = pd.DataFrame(allele_df['Allele'].unique(), columns=['Allele'])
who_df3['3d'] = allele_df['Allele'].apply(get_3field_allele)

# make one df
who_df1.rename(columns = {'1d':'input'}, inplace = True)
who_df2.rename(columns = {'2d':'input'}, inplace = True)
who_df3.rename(columns = {'3d':'input'}, inplace = True)
who_codes = pd.concat([who_df1, who_df2, who_df3])

# remove valid alleles from who_codes to avoid recursion
# there is a more pythonic way to do this for sure
for k in who_alleles:
if k in who_codes['input']:
who_codes.drop(labels=k, axis='index')

# who_codes maps a first field name to its 2 field expansion
who_group = who_codes.groupby(['input']).apply(lambda x: list(x['Allele'])).to_dict()

# dictionary
flat_who_group= {k: '/'.join(sorted(v, key=functools.cmp_to_key(smart_sort_comparator)))
for k, v in who_group.items()}
db.save_dict(db_connection, table_name='who_group', dictionary=flat_who_group, columns=('who', 'allele_list'))


return valid_alleles, who_alleles, xx_codes, who_codes


def generate_mac_codes(db_connection: sqlite3.Connection, refresh_mac: bool):
Expand Down
45 changes: 34 additions & 11 deletions pyard/pyard.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# -*- coding: utf-8 -*-
#
# py-ard
# Copyright (c) 2020 Be The Match operated by National Marrow Donor Program. All Rights Reserved.
# Copyright (c) 2020 Be The Match operated by National Marrow Donor Program.
# All Rights Reserved.
#
# This library is free software; you can redistribute it and/or modify it
# under the terms of the GNU Lesser General Public License as published
Expand All @@ -27,7 +28,7 @@

from . import db
from .data_repository import generate_ars_mapping, \
generate_mac_codes, generate_alleles_and_xx_codes, \
generate_mac_codes, generate_alleles_and_xx_codes_and_who, \
generate_serology_mapping, generate_v2_to_v3_mapping
from .db import is_valid_mac_code, mac_code_to_alleles, v2_to_v3_allele
from .smart_sort import smart_sort_comparator
Expand Down Expand Up @@ -62,7 +63,7 @@ def __init__(self, imgt_version: str = 'Latest',
# Load MAC codes
generate_mac_codes(self.db_connection, refresh_mac)
# Load Alleles and XX Codes
self.valid_alleles, self.xx_codes = generate_alleles_and_xx_codes(self.db_connection, imgt_version)
self.valid_alleles, self.who_alleles, self.xx_codes, self.who_group = generate_alleles_and_xx_codes_and_who(self.db_connection, imgt_version)
# Load ARS mappings
self.ars_mappings = generate_ars_mapping(self.db_connection, imgt_version)
# Load Serology mappings
Expand All @@ -73,9 +74,6 @@ def __init__(self, imgt_version: str = 'Latest',
# Close the current read-write db connection
self.db_connection.close()

# reference data is read-only and can be frozen
gc.freeze()

# Re-open the connection in read-only mode as we're not updating it anymore
self.db_connection = db.create_db_connection(data_dir, imgt_version, ro=True)

Expand All @@ -86,7 +84,7 @@ def __del__(self):
"""
self.db_connection.close()

@functools.lru_cache(maxsize=1000)
@functools.lru_cache(maxsize=1000000)
def redux(self, allele: str, ars_type: str) -> str:
"""
Does ARS reduction with allele and ARS type
Expand All @@ -98,7 +96,6 @@ def redux(self, allele: str, ars_type: str) -> str:
:return: ARS reduced allele
:rtype: str
"""

# deal with leading 'HLA-'
if HLA_regex.search(allele):
hla, allele_name = allele.split("-")
Expand Down Expand Up @@ -135,6 +132,20 @@ def redux(self, allele: str, ars_type: str) -> str:
# for 'lgx' when allele is not in G group,
# return allele with only first 2 field
return ':'.join(allele.split(':')[0:2])
elif ars_type == "W":
# new ars_type which is full WHO expansion
if self._is_who_allele(allele):
return allele
if allele in self.who_group:
return self.redux_gl("/".join(self.who_group[allele]), ars_type)
else:
return allele
elif ars_type == "exon":
if allele in self.ars_mappings.exon_group:
return self.ars_mappings.exon_group[allele]
else:
# for 'exon' return allele with only first 3 fields
return ':'.join(allele.split(':')[0:3])
else:
if self._remove_invalid:
if self._is_valid_allele(allele):
Expand All @@ -144,7 +155,7 @@ def redux(self, allele: str, ars_type: str) -> str:
else:
return allele

@functools.lru_cache(maxsize=1000)
@functools.lru_cache(maxsize=1000000)
def redux_gl(self, glstring: str, redux_type: str) -> str:
"""
Does ARS reduction with gl string and ARS type
Expand Down Expand Up @@ -265,6 +276,14 @@ def is_v2(allele: str) -> bool:
"""
return '*' in allele and ':' not in allele

def _is_who_allele(self, allele: str) -> bool:
    """
    Test if allele is a WHO allele in the current imgt database.

    A WHO allele is one that appears verbatim in the IMGT/HLA allele list
    (as opposed to a derived 2-field/3-field reduction, which is only in
    the broader valid-allele set).

    :param allele: Allele to test
    :return: bool to indicate if allele is a WHO allele
    """
    return allele in self.who_alleles

def _is_valid_allele(self, allele):
"""
Test if allele is valid in the current imgt database
Expand Down Expand Up @@ -365,9 +384,13 @@ def isvalid(self, allele: str) -> bool:
"""
if allele == '':
return False

# removed the test for is_v2()
# this leads to an infinite recursion if the input matches these patterns
# but is not ultimately valid e.g. DRB3*NNNN

if not self.is_mac(allele) and \
not self.is_serology(allele) and \
not self.is_v2(allele):
not self.is_serology(allele):
# Alleles ending with P or G are valid_alleles
if allele.endswith(('P', 'G')):
# remove the last character
Expand Down
Loading