In [1]:
# Re-import edited local modules automatically before each cell executes.
# NOTE(review): the imports cell repeats this with %reload_ext; one of the two is redundant.
%load_ext autoreload
%autoreload 2

# 1. 病种高套（BZGT）基本建模思路
  
* 病种高套指在 DRGs 或单病种付费制度下，医院通过高编病组（把病例编入支付标准更高的病组）套取超额医保报销的违规行为。
  
<p>住院结束后，医院需要把入院诊断、出院诊断以及每日使用明细和花费上传给医保局。医院可以轻易修改出院诊断，但每日明细（耗材、药品等）涉及药品进销存记录，不易篡改。因此，我们可以利用大量住院单据，训练从明细到出院诊断的分类模型。对于每个住院单据，我们用分类模型推断其出院诊断；如果模型输出与医院实际上报的诊断不符，且医院实际上报诊断对应的报销金额高于模型推断诊断对应的报销金额，则该单据有较大的病种高套嫌疑。</p>
  
  

 

  
  

# 2. ULMFIT 迁移学习模型

#### Universal Language Model Fine-tuning for Text Classification
https://arxiv.org/abs/1801.06146

In [2]:
%reload_ext autoreload
%autoreload 2
%matplotlib inline

# pick ICDs with enough sample
import pyspark.sql.functions as pyspark_F


from fastai import *
from fastai.text import *
from pyspark import SparkConf
from pyspark.sql import SparkSession

from pyspark.sql.functions import pandas_udf,PandasUDFType,concat_ws,collect_list
from pyspark.sql.types import *

# Local Spark session used for all data preparation below.
# Settings can be overridden through environment variables; the defaults reproduce
# the previously hard-coded values.
# NOTE: spark.driver.memory only takes effect when getOrCreate() launches a new JVM;
# run spark.stop() and restart the kernel to change it for an existing session.
import os

SPARK_DRIVER_MEMORY = os.environ.get("BZGT_SPARK_DRIVER_MEMORY", "35g")
SPARK_LOCAL_DIR = os.environ.get("BZGT_SPARK_LOCAL_DIR", "/media/data1/data/tmp")

# spark.stop()
conf = SparkConf().setAppName('BZGT').setMaster('local[*]')
spark = (SparkSession.builder
         .config(conf=conf)
         .config("spark.driver.memory", SPARK_DRIVER_MEMORY)
         .config("spark.local.dir", SPARK_LOCAL_DIR)   # shuffle/spill scratch directory
         .getOrCreate())
spark

## 4.1 在全量单据上训练语言模型（当前仅加载住院明细 inHosp，全量数据的加载代码已注释）

In [2]:
# 全量数据！！！

# DETAIL_PATH_FULL = "./data/1month/detail"

# detail_full = spark.read.parquet(DETAIL_PATH_FULL).union(spark.read.parquet(DETAIL_PATH_INHOSP))

# print(detail_full.count())


In [3]:
DETAIL_PATH_INHOSP = "./data/inHosp/detail"

detail_full = spark.read.parquet(DETAIL_PATH_INHOSP)
data_full = detail_full.select(["REGISTER_NO","MED_PROJECT_NAME"]).groupby("REGISTER_NO").agg(concat_ws(" ", collect_list("MED_PROJECT_NAME")).alias("details"))
# 小数据
# data = data

In [6]:
# Collect the aggregated documents (columns REGISTER_NO, details) to the driver as a
# pandas DataFrame, which is what fastai's TextList.from_df consumes.
data = data_full.toPandas()


使用 SentencePiece 子词分词器来处理明细项目名称不规范、不统一的问题：https://github.com/google/sentencepiece

In [4]:
# Fastai has a bug regarding ssp:https://forums.fast.ai/t/trouble-saving-with-sentencepiece-tokenizer/42143
def get_default_size(texts, max_vocab_sz):
    """Pick a SentencePiece vocabulary size for `texts`.

    Returns `max_vocab_sz` as soon as a quarter of the distinct whitespace-separated
    words seen so far exceeds it. Otherwise returns a quarter of the distinct-word
    count, rounded up to the next multiple of 8.
    """
    seen_words = Counter()
    for text in texts:
        seen_words.update(text.split())
        # Stop scanning early once the estimate is already above the cap.
        if len(seen_words) // 4 > max_vocab_sz:
            return max_vocab_sz
    quarter = len(seen_words) // 4
    # -quarter % 8 is the distance to the next multiple of 8 (0 if already aligned).
    return quarter + (-quarter % 8)

def train_sentencepiece_fixed(texts:Collection[str], path:PathOrStr, pre_rules: ListRules=None, post_rules:ListRules=None, 
    vocab_sz:int=None, max_vocab_sz:int=5000, model_type:str='unigram', max_sentence_len:int=20480, lang='en',
    char_coverage=None, tmp_dir='tmp', enc='utf8'):
    """Train a SentencePiece model on `texts` and save it as `path/tmp_dir/spm.model` / `spm.vocab`.

    Patched copy of fastai's `train_sentencepiece` (see the fastai forum thread on
    saving with the SentencePiece tokenizer).

    Args:
        texts: documents to train on; written one per line to a temporary file.
        path: base directory; the model files go under `path/tmp_dir`.
        pre_rules, post_rules, lang: accepted for signature compatibility with fastai's
            `train_sentencepiece`; not used here.
        vocab_sz: explicit vocabulary size; when None it is `get_default_size(texts, max_vocab_sz)`.
        max_vocab_sz: cap used only when `vocab_sz` is None.
        model_type: SentencePiece `--model_type` (e.g. 'unigram').
        max_sentence_len: SentencePiece `--max_sentence_length`.
        char_coverage: SentencePiece `--character_coverage`; 0.9998 when None.
        tmp_dir: sub-directory of `path` receiving the model files.
        enc: encoding of the temporary training text file.

    Returns:
        The directory (`Path`) containing the trained `spm.*` files.
    """
    from sentencepiece import SentencePieceTrainer
    cache_dir = Path(path)/tmp_dir
    os.makedirs(cache_dir, exist_ok=True)
    if vocab_sz is None: vocab_sz = get_default_size(texts, max_vocab_sz)
    # This used to be hard-coded to 0.9998, ignoring the argument; keep 0.9998 as the default.
    if char_coverage is None: char_coverage = 0.9998
    raw_text_path = cache_dir / 'all_text.out'
    with open(raw_text_path, 'w', encoding=enc) as f: f.write("\n".join(texts))
    # fastai's special tokens (xxbos, xxunk, ...) are registered as user-defined symbols;
    # unk_id is placed right after them.
    spec_tokens = ['\u2581'+s for s in defaults.text_spec_tok]
    try:
        SentencePieceTrainer.Train(" ".join([
            f"--input={raw_text_path} --max_sentence_length={max_sentence_len}",
            f"--character_coverage={char_coverage}",
            f"--unk_id={len(defaults.text_spec_tok)} --pad_id=-1 --bos_id=-1 --eos_id=-1",
            f"--user_defined_symbols={','.join(spec_tokens)}",
            f"--model_prefix={cache_dir/'spm'} --vocab_size={vocab_sz} --model_type={model_type}"]))
    finally:
        # Remove the (potentially large) training dump even when training fails.
        raw_text_path.unlink()
    return cache_dir

class FixSSP(SPProcessor):
    "`SPProcessor` that trains its SentencePiece model with `train_sentencepiece_fixed`."
    def __init__(self, pre_rules: ListRules=None, post_rules:ListRules=None, vocab_sz:int=None,
                 max_vocab_sz:int=10000, model_type:str='unigram', max_sentence_len:int=25480, lang='cn',
                 char_coverage=None, tmp_dir='tmp', mark_fields:bool=False, include_bos:bool=True,
                 include_eos:bool=False,  enc='utf8', **kwargs):
        """Create the processor.

        Training options are bound into `self.train_func`. The text-processing options
        (`pre_rules`, `post_rules`, `mark_fields`, `include_bos`, `include_eos`, `enc`)
        are stored on the instance so they take effect. The parent `__init__` is only
        given `**kwargs` (presumably `sp_model`/`sp_vocab` when built via
        `FixSSP.load`), so otherwise it would keep its own defaults for them.
        """
        super().__init__(**kwargs)
        # Overwrite the parent's defaults with the caller's values.
        # NOTE(review): attribute names assumed to match fastai v1 SPProcessor; confirm
        # for the installed version.
        self.pre_rules, self.post_rules, self.enc = pre_rules, post_rules, enc
        self.mark_fields, self.include_bos, self.include_eos = mark_fields, include_bos, include_eos
        self.train_func = partial(train_sentencepiece_fixed, pre_rules=pre_rules, post_rules=post_rules, vocab_sz=vocab_sz,
            max_vocab_sz=max_vocab_sz, model_type=model_type, max_sentence_len=max_sentence_len, lang=lang,
            char_coverage=char_coverage, tmp_dir=tmp_dir, enc=enc)
            
            

In [14]:
# from sentencepiece import SentencePieceProcessor
# tok = SentencePieceProcessor()
# tok.load("./tmp/spm.model")

# tok.encode_as_pieces("穿刺组织活断,尿素氮测定（急诊）,山药（配方颗粒）,静脉穿刺置管术,视黄醇结合蛋白")

['▁',
 '穿',
 '刺',
 '组',
 '织',
 '活',
 '断',
 ',',
 '尿',
 '素',
 '氮',
 '测定',
 '(',
 '急诊',
 ')',
 ',',
 '山',
 '药',
 '(',
 '配',
 '方',
 '颗粒',
 ')',
 ',',
 '静',
 '脉',
 '穿',
 '刺',
 '置',
 '管',
 '术',
 ',',
 '视',
 '黄',
 '醇',
 '结',
 '合',
 '蛋',
 '白']

## 创建语言模型

In [None]:

# Language-model DataBunch over every stay in `data`. FixSSP trains a SentencePiece
# model on these texts (vocab capped at 8000); later cells reload it with FixSSP.load("./").
bs=128
databunch = (TextList.from_df(data,cols="details", processor=[FixSSP(max_vocab_sz=8000)])
             # 10% random validation split, reproducible via the fixed seed
             .split_by_rand_pct(0.1, seed=42)
             .label_for_lm()
             .databunch(bs=bs,num_workers=1))


In [17]:
databunch.vocab.itos

['▁xxunk',
 '▁xxpad',
 '▁xxbos',
 '▁xxeos',
 '▁xxfld',
 '▁xxmaj',
 '▁xxup',
 '▁xxrep',
 '▁xxwrep',
 '<unk>',
 '▁',
 '(',
 ')',
 '▁氯化钠',
 '医院',
 '住院',
 '测定',
 '▁鼻导管吸氧',
 '三',
 '、',
 '-',
 '▁血氧饱和度监测',
 '级',
 '▁等',
 '级护理',
 '查',
 '诊',
 '费',
 '▁普通病房床位费',
 '▁一次性真空采血器',
 '▁心电监护',
 '▁自费诊疗及服务项目',
 '▁/',
 '输液',
 '▁静脉',
 '含输液器',
 '▁一次性注射器',
 'c',
 '二级',
 '▁省',
 '▁特级护理',
 'd',
 '▁葡萄糖测定',
 '▁一般专项护理',
 '▁血清',
 '▁静脉置管冲洗',
 '门诊',
 '市离休',
 '颗粒',
 '▁葡萄糖',
 't',
 '▁血清载脂蛋白',
 'b',
 '配',
 '方',
 '▁静脉注射',
 '部位',
 '▁氯化钾',
 'α',
 '▁自费费用',
 '▁计算机图文报告',
 'o',
 '住院可使用的诊疗项目',
 '急诊',
 'ca',
 '▁贴敷疗法',
 '抗原',
 '▁维生素',
 '抗体测定',
 '▁静脉采血',
 '▁糖类',
 'a',
 '▁穴位贴敷治疗',
 'r',
 '五',
 '▁血浆',
 '▁钙测定',
 '▁血',
 '常规',
 '分类',
 '▁氨溴索',
 'e',
 '▁抗',
 'i',
 '抗原测定',
 '一个',
 'aa',
 '▁超',
 '治疗',
 '二',
 '▁氧气雾化吸入',
 'ct',
 '3)',
 '敏',
 '时间测定',
 '▁静脉穿刺置管术',
 '▁一次性留置针',
 '18',
 's',
 '一级',
 '反应蛋白测定',
 '体测定',
 '针',
 '▁镁测定',
 '▁布地奈德',
 '▁皮下注射',
 '▁钾测定',
 '▁钠测定',
 '聚',
 '▁腺苷脱氨酶测定',
 '▁血清胆碱脂酶测定',
 'im',
 'er',
 '▁氯测定',
 'l',
 '▁无机磷测定',
 'igm',


In [13]:
# AWD-LSTM language model trained from scratch (pretrained=False), in mixed precision.
learn = language_model_learner(databunch, AWD_LSTM, drop_mult=0.1, wd=0.1, pretrained=False).to_fp16()
lr = 3e-3
lr *= bs/48  # Scale learning rate by batch size (linearly, relative to a reference bs of 48)

In [14]:
import warnings
# NOTE(review): this silences *all* warnings for the rest of the kernel session;
# a targeted filter (category/module) would be safer.
warnings.filterwarnings('ignore')

# NOTE(review): with pretrained=False nothing should be frozen, so unfreeze() is presumably a no-op here.
learn.unfreeze()
# One one-cycle epoch; momentum is cycled between 0.8 and 0.7.
learn.fit_one_cycle(1, lr, moms=(0.8,0.7))

epoch,train_loss,valid_loss,accuracy,time
0,1.219202,1.284387,0.704334,43:22


IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



KeyboardInterrupt: 

In [18]:
# Save the pretrained LM weights and vocab for the domain fine-tuning step, which passes
# `pretrained_fnames=lm_fns`. fastai v1 resolves those names as
# <learn.path>/<learn.model_dir>/{lm_fns[0]}.pth and .../{lm_fns[1]}.pkl (confirm for your
# version). Learner.save already writes to that directory, but Vocab.save writes to the
# literal path it is given, so the vocab file needs that full location.
lm_fns = ['./lm_wt', './lm_wt_vocab']
learn.to_fp32().save(lm_fns[0], with_opt=False)
learn.data.vocab.save(learn.path/learn.model_dir/f'{lm_fns[1]}.pkl')

## 4.2 在（DRGs）住院单据上精调语言模型 — Domain Fine-Tuning

In [7]:
REGISTER_PATH = "./data/inHosp/register"
DETAIL_PATH = "./data/inHosp/detail"
COMBINED_PATH = "./data/inHosp/combined"

#从医保报销的单据表中抽取住院单据以及其对应明细清单
register = spark.read.parquet(REGISTER_PATH)
detail = spark.read.parquet(DETAIL_PATH)
combine = spark.read.parquet(COMBINED_PATH)
print(register.count())
print(detail.count())
print(combine.count())



### Keep only diseases that have enough samples

min_sample = 200

# (discharge diagnosis name, ICD4) pairs with at least `min_sample` registrations.
# ICD4 is the first 5 characters of the discharge diagnosis code (e.g. "K29.500" -> "K29.5").
# NOTE(review): the threshold applies per (name, ICD4) pair, not per ICD4 total.
icd_diagnosis = (
    register
    .select(['out_hosp_diagnosisname', 'out_hosp_diagnosiscode'])
    .dropna()
    .filter(pyspark_F.length("out_hosp_diagnosiscode") >= 5)
    .withColumn("ICD4", register.out_hosp_diagnosiscode.substr(0, 5))
    .groupby(['out_hosp_diagnosisname', 'ICD4'])
    .count()
    .filter(f"count >= {min_sample}")
    .sort(["out_hosp_diagnosisname", ])
    .collect()
)

# ICD4 -> diagnosis name. If several names share one ICD4, the last one in name order wins.
icd_diagnosis_dict = {row.ICD4: row.out_hosp_diagnosisname for row in icd_diagnosis}
len(icd_diagnosis_dict)

### Inspect the data
# Stays from the combined table whose 5-character ICD prefix ("ICD4") is one of the
# diagnoses kept by the min_sample filter above, collected to pandas for fastai.
# Spark's substr is 1-based; start 0 behaves like 1 here, so substr(0,5) yields the first
# 5 characters (e.g. "K29.500" -> "K29.5" in the output below).
# NOTE(review): non-diagnosis codes such as "syjtbx" (maternity-allowance reimbursement)
# pass this filter but have empty `details` (row 4 of the output); consider excluding them.

data_inhosp = \
combine.filter(pyspark_F.length("out_hosp_diagnosiscode")>=5).\
withColumn("ICD4", combine.out_hosp_diagnosiscode.substr(0,5)).\
filter(pyspark_F.col('ICD4').isin(list(icd_diagnosis_dict.keys()))).toPandas()
data_inhosp

139229
44195071
136118


Unnamed: 0,register_no,per_no,age,sex,per_type,zone_name,in_hosp_diagnosiscode,in_hosp_diagnosisname,out_hosp_diagnosisname,out_hosp_diagnosiscode,...,has_surgery,sugery_fee,herb_fee,western_fee,cn_med_fee,bed_fee,material_fee,care_fee,treat_fee,ICD4
0,219762348-1-330522012,f91c34b573,36,女,城镇职工,c7cdeaff4b,K29.500,慢性胃炎,慢性胃炎,K29.500,...,1,450.00,0.00,224.85,172.00,70.0,122.40,0.0,59.30,K29.5
1,219762549-1-330522012,fe114e1d78,56,男,城镇职工,c7cdeaff4b,M13.900,关节炎,半月板损伤,M23.308,...,1,2607.01,0.00,2426.48,0.00,150.0,1537.99,0.0,1900.33,M23.3
2,219786156-1-330522012,946df375b9,56,男,城镇职工,c7cdeaff4b,R42.x00,头晕和眩晕,头晕和眩晕,R42.x00,...,0,0.00,0.00,264.78,0.00,13.0,47.15,0.0,87.00,R42.x
3,219851040-1-330522012,5a605afc4b,64,男,城镇居民,c7cdeaff4b,N40.x00,前列腺增生,前列腺增生,N40.x00,...,0,0.00,0.00,348.48,0.00,46.0,8.75,0.0,223.30,N40.x
4,219851061-1-330522012,79af746d7c,27,女,城镇居民,c7cdeaff4b,syjtbx,生育津贴报销,生育津贴报销,syjtbx,...,0,0.00,0.00,0.00,0.00,0.0,0.00,0.0,0.00,syjtb
5,219851488-1-330522012,255f7572c7,76,男,城镇居民,c7cdeaff4b,K80.200-01,胆囊结石,胆囊结石,K80.200-01,...,0,0.00,0.00,1849.39,38.88,100.0,32.00,0.0,91.00,K80.2
6,219908562-1-330522012,f40f42aefa,86,男,城镇居民,c7cdeaff4b,J44.100,慢性阻塞性肺病伴有急性加重,慢性阻塞性肺病伴有急性加重,J44.100,...,1,0.01,885.72,1988.88,0.00,320.0,125.96,0.0,2945.00,J44.1
7,219908631-1-330522012,1df8f0475b,79,男,城镇职工,c7cdeaff4b,R42.x00,头晕和眩晕,头晕和眩晕,R42.x00,...,0,0.00,680.12,1417.21,101.86,180.0,140.40,0.0,739.00,R42.x
8,219935082-1-330522012,ccb2835bd1,72,男,城镇职工,c7cdeaff4b,J44.100,慢性阻塞性肺病伴有急性加重,慢性阻塞性肺病伴有急性加重,J44.100,...,0,0.00,0.00,2245.67,0.00,180.0,101.46,0.0,1485.00,J44.1
9,219986664-1-330522012,61bf28584e,68,男,城镇职工,c7cdeaff4b,I10.x00,特发性(原发性)高血压,特发性(原发性)高血压,I10.x00,...,0,0.00,0.00,706.32,303.96,205.0,76.36,0.0,297.00,I10.x


In [20]:
bs=64

# LM DataBunch for domain fine-tuning on the filtered inpatient stays. It reloads the
# SentencePiece model trained in 4.1, so token ids line up with the pretrained LM weights.
# NOTE: same seed and row count as the classifier split in 4.3, so both presumably hold
# out the same validation rows.
data_fine_tuned = (TextList.from_df(data_inhosp,cols="details", processor=[FixSSP.load("./")])
             .split_by_rand_pct(0.1, seed=42)
             .label_for_lm()
             # bptt=50: backprop-through-time window length
             .databunch(bs=bs,bptt=50,num_workers=1))

In [None]:
# Weight / vocab file names saved in 4.1.
lm_fns = ['./lm_wt', './lm_wt_vocab']

# Initialise the fine-tuning LM from those files.
# NOTE(review): fastai v1 looks for <learn.path>/models/{name}.pth and {name}.pkl; make sure
# the vocab was saved there.
learn_lm = language_model_learner(data_fine_tuned, AWD_LSTM, pretrained_fnames=lm_fns, drop_mult=1.0, wd=0.1)

In [32]:
lr = 1e-3
lr *= bs/48  # scale learning rate linearly with batch size

# First epoch at 10x lr. fastai v1 presumably freezes all but the last layer group after
# loading pretrained weights, so mainly the head adapts here (confirm).
learn_lm.fit_one_cycle(1, lr*10, moms=(0.8,0.7))

epoch,train_loss,valid_loss,accuracy,time
0,0.947487,0.852845,0.80807,11:40


In [None]:
# Unfreeze everything and fine-tune with discriminative learning rates:
# slice(lr/10, lr*10) spreads the rates from the first to the last layer group.
learn_lm.unfreeze()
learn_lm.fit_one_cycle(3, slice(lr/10,lr*10), moms=(0.8,0.7))

In [34]:
# Save the full fine-tuned LM, plus the encoder on its own; the classifier in 4.3
# loads the latter with load_encoder('lm_fine_tuned_enc').
learn_lm.save('lm_fine_tuned')
learn_lm.save_encoder('lm_fine_tuned_enc')

In [37]:
def generate(prefix, predict_func, n_words=500, temp=1):
    """Generate a continuation of `prefix` and return it with widened spacing.

    `predict_func` is called as `predict_func(prefix, n_words, temperature=temp)` and
    must return a string. Every single space in that string becomes three spaces, so the
    generated pieces are easier to tell apart.
    """
    generated_text = predict_func(prefix, n_words, temperature=temp)
    pieces = generated_text.split(" ")
    return "   ".join(pieces)

TEXT = "妇科检查"
predict_func = functools.partial(learn_lm.predict)
# AWD_learn.beam_search
generate(TEXT,predict_func=predict_func)

'妇科检查   ,   一   次   性   真   空   采   血   器   ,   妇   科   常规检查   ,   贴   敷   疗   法   ,   一   次   性   鼻   导   管   ,   自   费   费   用   ,   一   次   性   鼻   导   管   ,   即   毁   式   一   次   性   阴   道   扩   张   器   ,   静脉   置   管   冲   洗   ,   阴   道   灌   洗   上   药   ,   自   费   诊   疗   及   服务   项   目   ,   麻   醉   中   监测   (   <4   小时   )   ,   氯化钠   ,   静脉   穿   刺   置   管   术   ,   计   算   机   图   文   报   告   ,   无   机   磷   测定   ,   静脉   穿   刺   置   管   胺   ,   计   算   机   图   文   报   告   ,   鼻   导   管   吸   氧   ,   无   机   磷   测定   ,   心   肌   酶   谱   常规检查   ,   普   通   病   房   床   位   费   ,   乙   肝   e   抗体测定   ,   数字化   摄影   (   d   r   )   ,   左   卡   尼   汀   ,   静脉   穿   刺   置   管   术   ,   住院   诊查费   (   三   级   医院   )   ,   自   费   费   用   ,   静脉   穿   刺   置   管   术   ,   阴   道   分   泌   物   白   细胞   酯   酶   检查   ,   麻   醉   中   监测   (   <4   小时   )   ,   支   原   体   培   养   及   药   敏   ,   氯   化   钾   ,   计   算   机   图   文   报   告   ,   血   氧   饱   和   度   监测   ,   血   清   载   脂   蛋

## 4.3 将语言模型的RNN编码器迁移到下游分类

In [8]:
# Classifier training frame: item-name text (`details`) -> 5-character ICD prefix label (`ICD4`).
# NOTE(review): includes rows with empty `details` (e.g. row 4, label "syjtb").
train_df = data_inhosp[['details','ICD4']]
train_df

Unnamed: 0,details,ICD4
0,"血清人绒毛膜促性腺激素测定,淀粉酶测定,无机磷测定,葡萄糖,兰索拉唑,生化筛查常规检查,糖类...",K29.5
1,"自费诊疗及服务项目,灸法(艾条灸),主肺动脉窗修补术,血清载脂蛋白B测定,钾测定,普通病房床...",M23.3
2,"常规心电图检查,一次性真空采血器,血清载脂蛋白α测定,抗Sm抗体测定,血清总胆汁酸测定,淀粉...",R42.x
3,"钙测定,16层及以上多排螺旋CT扫描加收,住院诊查费（二级医院),乙肝e抗体测定,碳[14C...",N40.x
4,,syjtb
5,"粪便隐血试验(OB),血清胆碱脂酶测定,计算机图文报告,左卡尼汀,普通病房床位费,住院诊查费...",K80.2
6,"多索茶碱,一般专项护理,计算机图文报告,自费诊疗及服务项目,抗核提取物抗体测定(抗ENA抗体...",J44.1
7,"普通病房床位费,总前列腺特异性抗原测定(TPSA),MRI扫描增加各项功能加收,血清载脂蛋白...",R42.x
8,"氨溴索,结核菌涂片检查,自费诊疗及服务项目,血清α-L-岩藻糖苷酶测定,贴敷疗法,计算机图文...",J44.1
9,"真菌涂片检查,普通病房床位费,一次性真空采血器,氯化钾,等级护理（三级医院）,凝血功能常规检...",I10.x


In [12]:
# import catboost
# catboost.__version__

'0.23.2'

In [18]:
# Classifier batch size (also used to scale the learning rate below).
bs=64

In [13]:
# Classification DataBunch. It reuses the SentencePiece model from 4.1 so its token ids
# match the fine-tuned encoder; split seed/fraction match the LM fine-tuning DataBunch.
data_clas = (TextList.from_df(train_df, cols='details', processor=FixSSP.load("./"))
    .split_by_rand_pct(0.1, seed=42)
    .label_from_df(cols='ICD4')
    .databunch(bs=bs, num_workers=1))

In [19]:
# Text classifier whose encoder is initialised from the domain fine-tuned LM encoder.
learn_c = text_classifier_learner(data_clas, AWD_LSTM, drop_mult=0.5, pretrained=False, wd=0.1).to_fp16()
learn_c.load_encoder(f'lm_fine_tuned_enc')
# Train only the last layer group (the classification head) first.
learn_c.freeze()

lr=2e-2
lr *= bs/48  # scale learning rate linearly with batch size

learn_c.fit_one_cycle(2, lr, moms=(0.8,0.7))

epoch,train_loss,valid_loss,accuracy,time
0,2.173878,2.169117,0.428504,03:41
1,1.793146,1.685692,0.544192,04:18


In [20]:
# NOTE(review): this repeats the frozen-head schedule at the same peak lr (train loss jumps
# back up in the output, and the run was interrupted). ULMFiT-style training would usually
# unfreeze gradually here instead (e.g. learn_c.freeze_to(-2) with a lower lr).
learn_c.fit_one_cycle(2, lr, moms=(0.8,0.7))

epoch,train_loss,valid_loss,accuracy,time
0,2.177176,2.096181,0.463693,03:42


KeyboardInterrupt: 

SequentialRNN(
  (0): MultiBatchEncoder(
    (module): AWD_LSTM(
      (encoder): Embedding(1816, 400, padding_idx=1)
      (encoder_dp): EmbeddingDropout(
        (emb): Embedding(1816, 400, padding_idx=1)
      )
      (rnns): ModuleList(
        (0): WeightDropout(
          (module): LSTM(400, 1152, batch_first=True)
        )
        (1): WeightDropout(
          (module): LSTM(1152, 1152, batch_first=True)
        )
        (2): WeightDropout(
          (module): LSTM(1152, 400, batch_first=True)
        )
      )
      (input_dp): RNNDropout()
      (hidden_dps): ModuleList(
        (0): RNNDropout()
        (1): RNNDropout()
        (2): RNNDropout()
      )
    )
  )
  (1): PoolingLinearClassifier(
    (layers): Sequential(
      (0): BatchNorm1d(1200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (1): Dropout(p=0.2, inplace=False)
      (2): Linear(in_features=1200, out_features=50, bias=True)
      (3): ReLU(inplace=True)
      (4): BatchNorm1d(50, ep

In [27]:
# Serialise the classifier learner for inference (fastai writes export.pkl under learn_c.path by default).
learn_c.export()

## 4.4 数据可视化

In [None]:
def get_emcoding(details,
                model=learn_c):
    """Encode each detail text with the classifier's RNN encoder.

    Args:
        details: sequence of raw detail strings, presumably in the same format as the
            `details` column used to build `data_clas` (TODO confirm). Iterated by value,
            so a pandas Series with any index works.
        model: fastai text-classifier Learner; `model.model[0]` is used as the encoder.
            Defaults to `learn_c` as bound when this cell runs.

    Returns:
        A tensor stacking one vector per item along dim 0.

    Raises:
        ValueError: if `details` is empty (torch.stack needs at least one tensor).
    """
    if len(details) == 0:
        raise ValueError("details must contain at least one item")
    # Put the encoder in eval mode once (disables dropout) instead of on every item.
    encoder = model.model[0].eval()
    out = []
    for item in tqdm(details, position=0, leave=True):
        # Clear the recurrent hidden state so each item is encoded independently.
        encoder.reset()
        xb, _ = model.data.one_item(item)
        with torch.no_grad():
            encoded = encoder(xb)
        # Presumably encoded[0][-1] is the last layer's raw output, [0] the single batch
        # item and [-1] the final time step (confirm against the fastai encoder output).
        out.append(encoded[0][-1][0][-1].detach())
    return torch.stack(out, dim=0)

# Correctly spelled alias; `get_emcoding` is kept for existing callers.
get_encoding = get_emcoding

# 5. 结合迁移学习与小样本度量学习（Few-shot Metric Learning），实现样本极少病种诊断的高精度分类

# 6. 使用模型