In [1]:
import sys
import json
import os

# Make the fast-coref sources (inference.*, model.*) importable. The location can be
# overridden with the FAST_COREF_SRC environment variable; the membership check keeps
# repeated executions of this cell from stacking duplicate sys.path entries.
FAST_COREF_SRC = os.environ.get("FAST_COREF_SRC", "/home/yuxiang/liao/workspace/fast-coref/src")
if FAST_COREF_SRC not in sys.path:
    sys.path.append(FAST_COREF_SRC)

In [3]:
from os import path

import torch
from inference.tokenize_doc import basic_tokenize_doc, tokenize_and_segment_doc
from model.entity_ranking_model import EntityRankingModel
from model.utils import action_sequences_to_clusters
from omegaconf import OmegaConf
from transformers import AutoModel, AutoTokenizer


class Inference:
    """Inference wrapper around a trained fast-coref ``EntityRankingModel``.

    Loads ``model.pth`` from a checkpoint directory and exposes
    :meth:`perform_coreference`. That method accepts an already tokenized list
    (as in the example cell, a list of per-sentence token lists), a raw string,
    or an already tokenized/segmented dict.
    """

    def __init__(self, model_path, encoder_name=None):
        """Load the checkpoint, build the model and set up tokenization.

        Args:
            model_path: directory containing ``model.pth``. When the encoder was
                fine-tuned and ``encoder_name`` is None, it must also contain the
                ``config.paths.doc_encoder_dirname`` sub-directory.
            encoder_name: optional model name/path written into
                ``config.model.doc_encoder.transformer.model_str`` before the
                model is built.
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        # Load model
        checkpoint = torch.load(path.join(model_path, "model.pth"), map_location=self.device)
        self.config = OmegaConf.create(checkpoint["config"])
        if encoder_name is not None:
            # Override before EntityRankingModel is constructed; presumably the model
            # builds its document encoder from this model_str (confirm in model code).
            self.config.model.doc_encoder.transformer.model_str = encoder_name
        self.model = EntityRankingModel(self.config.model, self.config.trainer)
        self._load_model(checkpoint, model_path, encoder_name=encoder_name)

        self.max_segment_len = self.config.model.doc_encoder.transformer.max_segment_len
        # Read the tokenizer only after _load_model, which may replace it.
        self.tokenizer = self.model.mention_proposer.doc_encoder.tokenizer
        # spaCy pipeline for raw-string input; loaded lazily on first use.
        self._basic_tokenizer = None

    def _load_model(self, checkpoint, model_path, encoder_name=None):
        """Restore checkpoint weights, optionally reload the fine-tuned encoder,
        move the model to ``self.device`` and switch it to eval mode."""
        # strict=False: missing/unexpected keys in the checkpoint are tolerated.
        self.model.load_state_dict(checkpoint["model"], strict=False)

        if self.config.model.doc_encoder.finetune and encoder_name is None:
            # Fine-tuned encoder params (and tokenizer) are loaded from a directory
            # under model_path. When encoder_name is given, the encoder configured
            # in __init__ is kept as-is.
            doc_encoder_dir = path.join(model_path, self.config.paths.doc_encoder_dirname)
            doc_encoder = self.model.mention_proposer.doc_encoder
            doc_encoder.lm_encoder = AutoModel.from_pretrained(pretrained_model_name_or_path=doc_encoder_dir)
            doc_encoder.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=doc_encoder_dir)

        # Place the whole model on self.device regardless of the finetune flag.
        # Previously this only happened for fine-tuned encoders, which left other
        # models on CPU while self.device was CUDA.
        self.model.to(self.device)
        self.model.eval()

    def _get_basic_tokenizer(self):
        """Return the spaCy pipeline used to pre-tokenize raw strings, loading it once."""
        # getattr keeps this working for instances created without __init__.
        if getattr(self, "_basic_tokenizer", None) is None:
            import spacy  # local import: only needed for raw-string input

            self._basic_tokenizer = spacy.load("en_core_web_sm")
        return self._basic_tokenizer

    @torch.no_grad()
    def perform_coreference(self, document):
        """Run coreference resolution on one document.

        Args:
            document: a list (passed directly to ``tokenize_and_segment_doc``),
                a raw ``str`` (basic-tokenized with spaCy first), or a ``dict``
                that is already tokenized/segmented and used as-is.

        Returns:
            dict with keys ``tokenized_doc``, ``clusters`` (per cluster, a list of
            ``((start, end), mention_text)`` pairs, where start/end are inclusive
            subtoken indices), ``subtoken_idx_clusters``, ``actions`` and
            ``mentions``.

        Raises:
            ValueError: if ``document`` is not a list, str or dict.
        """
        if isinstance(document, list):
            # Document is already tokenized
            tokenized_doc = tokenize_and_segment_doc(document, self.tokenizer, max_segment_len=self.max_segment_len)
        elif isinstance(document, str):
            # Raw document string: basic tokenization first, then subword segmentation.
            basic_tokenized_doc = basic_tokenize_doc(document, self._get_basic_tokenizer())
            tokenized_doc = tokenize_and_segment_doc(
                basic_tokenized_doc,
                self.tokenizer,
                max_segment_len=self.max_segment_len,
            )
        elif isinstance(document, dict):
            tokenized_doc = document
        else:
            raise ValueError(
                f"Unsupported document type {type(document).__name__!r}; expected list, str or dict."
            )

        # Passed through to the model; its contents are not used here.
        extra_output_dict = {}
        pred_mentions, _, _, pred_actions = self.model(tokenized_doc, extra_output=extra_output_dict)
        idx_clusters = action_sequences_to_clusters(pred_actions, pred_mentions)

        # subtoken_map maps a subtoken index to its original-token index; mention
        # ends are inclusive, hence the +1 when slicing orig_tokens.
        subtoken_map = tokenized_doc["subtoken_map"]
        orig_tokens = tokenized_doc["orig_tokens"]
        clusters = [
            [
                (
                    (ment_start, ment_end),
                    " ".join(orig_tokens[subtoken_map[ment_start] : subtoken_map[ment_end] + 1]),
                )
                for ment_start, ment_end in idx_cluster
            ]
            for idx_cluster in idx_clusters
        ]

        return {
            "tokenized_doc": tokenized_doc,
            "clusters": clusters,
            "subtoken_idx_clusters": idx_clusters,
            "actions": pred_actions,
            "mentions": pred_mentions,
        }

In [24]:
# Checkpoint directory holding model.pth, plus the encoder that overrides the
# checkpoint's configured transformer.
model_str = "/home/yuxiang/liao/resources/downloaded_models/coref_model_9b02_25_4/best"  # exp11
encoder_name = "/home/yuxiang/liao/resources/downloaded_models/longformer_coreference_joint"
model = Inference(model_str, encoder_name=encoder_name)

# Example report, already tokenized: one inner list of tokens per sentence.
doc = [
    ["A", "calcific", "density", "is", "seen", "projecting", "at", "the", "left", "lung", "base", "laterally", "."],
    ["The", "calcific", "density", "may", "reflect", "a", "granuloma", "."],
    ["The", "calcific", "density", "may", "reflect", "a", "sclerotic", "finding", "within", "the", "rib", "."],
    ["The", "calcific", "density", "may", "reflect", "an", "object", "external", "to", "the", "patient", "."],
]
output_dict = model.perform_coreference(doc)
print(output_dict["clusters"])

[[((0, 3), 'A calcific density'), ((15, 18), 'The calcific density'), ((26, 29), 'The calcific density'), ((41, 44), 'The calcific density')]]


In [36]:
# Flatten the per-sentence example document into one token list.
d = [token for sentence in doc for token in sentence]
# Sentence index per position of the tokenized document (presumably one entry per subtoken).
output_dict["tokenized_doc"]["sentence_map"].tolist()

[0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 0,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 1,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 2,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3,
 3]

In [7]:
input_dir = "/home/yuxiang/liao/workspace/arrg_sentgen/outputs/interpret_cxr"

# raw.json is JSON Lines: decode each line into one record.
with open(os.path.join(input_dir, "raw.json"), "r") as f:
    raw_docs = list(map(json.loads, f))

In [19]:
# Inspect one raw record: its doc_key, valid_key and tokenized sentences.
raw_docs[1]

{'doc_key': 'train#0#impression',
 'valid_key': 'CheXpert#data/chexpert-public/train/patient32815/study11/view1_frontal.jpg',
 'sentences': [['1.DECREASED',
   'BIBASILAR',
   'PARENCHYMAL',
   'OPACITIES',
   ',',
   'NOW',
   'MINIMAL',
   '.'],
  ['STABLE', 'SMALL', 'LEFT', 'PLEURAL', 'EFFUSION', '.'],
  ['2',
   '.',
   'FEEDING',
   'TUBE',
   'AND',
   'STERNAL',
   'PLATES',
   'AGAIN',
   'SEEN',
   '.']]}

In [10]:
# LLM sentence splits (first of three shards), one JSON record per line, read from input_dir.
with open(os.path.join(input_dir, "llm_split_sents_1_of_3.json"), "r") as f:
    split_docs = list(map(json.loads, f))

In [17]:
# Inspect one split record: the original sentence and the split sentences produced for it.
split_docs[9]

{'doc_key': 'train#0#impression',
 'sent_idx': 1,
 'original_sent': 'STABLE SMALL LEFT PLEURAL EFFUSION .',
 'sent_splits': ['Stable small left pleural effusion.']}

In [7]:
from collections import defaultdict
import json
import os

# Sentence-level records whose doc_key has the form
# "<split>#<row_idx>#<section>#<sent_idx>@<split_idx>". Group them by the
# document part "<split>#<row_idx>#<section>".
# A dedicated name (instead of reusing `input_dir`) keeps the directory that the
# earlier cells read from intact, so those cells can still be re-run.
raw_sents_path = "/home/yuxiang/liao/workspace/arrg_data_processing/outputs/interpret_sents/raw.json"

docs_dict = defaultdict(list)
with open(raw_sents_path, "r") as f:
    for line in f:
        if not line.strip():
            continue  # tolerate blank lines in the JSONL file
        # `record`, not `doc`, so the example document defined earlier is not overwritten.
        record = json.loads(line)
        data_split_name, data_row_idx, section_name, doc_sent_id = record["doc_key"].split("#")
        doc_sent_idx, split_sent_idx = doc_sent_id.split("@")
        doc_key = f"{data_split_name}#{data_row_idx}#{section_name}"
        # Indices are stored as the parsed strings.
        record["doc_sent_idx"] = doc_sent_idx
        record["split_sent_idx"] = split_sent_idx
        docs_dict[doc_key].append(record)

In [8]:
# Number of distinct documents (doc_key with the "#<sent_idx>@<split_idx>" suffix removed).
len(docs_dict)

734635

In [12]:
# docs_dict is a defaultdict(list): plain indexing would silently insert an empty
# entry for a missing key (and change len(docs_dict)), so inspect it with .get().
docs_dict.get("train#0#impression", [])

[{'doc_key': 'train#0#impression#0@0',
  'sentences': [['Decreased',
    'bibasilar',
    'parenchymal',
    'opacities',
    'are',
    'seen',
    '.']],
  'tok_indices': [[[0, 9],
    [10, 19],
    [20, 31],
    [32, 41],
    [42, 45],
    [46, 50],
    [50, 51]]],
  'raw_sent': 'Decreased bibasilar parenchymal opacities are seen.',
  'doc_sent_idx': '0',
  'split_sent_idx': '0'},
 {'doc_key': 'train#0#impression#0@1',
  'sentences': [['The',
    'bibasilar',
    'parenchymal',
    'opacities',
    'are',
    'now',
    'minimal',
    '.']],
  'tok_indices': [[[0, 3],
    [4, 13],
    [14, 25],
    [26, 35],
    [36, 39],
    [40, 43],
    [44, 51],
    [51, 52]]],
  'raw_sent': 'The bibasilar parenchymal opacities are now minimal.',
  'doc_sent_idx': '0',
  'split_sent_idx': '1'},
 {'doc_key': 'train#0#impression#1@0',
  'sentences': [['Stable', 'small', 'left', 'pleural', 'effusion', '.']],
  'tok_indices': [[[0, 6], [7, 12], [13, 17], [18, 25], [26, 34], [34, 35]]],
  'raw_sent':

In [15]:
# Merge all split sentences of one document. For each split sentence, record its
# (original sentence index, split index) pair.
doc_key = "train#0#impression"
out_dict = {"doc_key": doc_key, "split_sents": [], "sent_idx_split_idx": []}
# Look up by doc_key (not a repeated literal) and via .get() so a missing key does not
# insert an empty entry into the defaultdict. Using `record` avoids overwriting the global `doc`.
for record in docs_dict.get(doc_key, []):
    out_dict["split_sents"].extend(record["sentences"])
    out_dict["sent_idx_split_idx"].append((int(record["doc_sent_idx"]), int(record["split_sent_idx"])))

print(out_dict)

{'doc_key': 'train#0#impression', 'split_sents': [['Decreased', 'bibasilar', 'parenchymal', 'opacities', 'are', 'seen', '.'], ['The', 'bibasilar', 'parenchymal', 'opacities', 'are', 'now', 'minimal', '.'], ['Stable', 'small', 'left', 'pleural', 'effusion', '.'], ['Feeding', 'tube', 'is', 'again', 'seen', '.'], ['Sternal', 'plates', 'are', 'again', 'seen', '.']], 'sent_idx_split_idx': [(0, 0), (0, 1), (1, 0), (2, 0), (2, 1)]}
