In [54]:
import logging
import math
import os
from pathlib import Path
import json


import torch
import pandas as pd
import hydra
from omegaconf import OmegaConf
from rich.pretty import pprint

from src.data.data_pipeline import data_pipeline
from src.factories import (
    get_callbacks,
    get_dataloaders,
    get_datasets,
    get_lookups,
    get_lr_scheduler,
    get_metric_collections,
    get_model,
    get_optimizer,
    get_text_encoder,
    get_transform,
)
from src.trainer.trainer import Trainer
from src.utils.seed import set_seed


In [3]:
### Parameters for inference ###
# Name of the model sub-directory to run inference with. It is joined onto the
# models directory below, and the config, weights and label mapping are all
# read from that folder.
model_to_load = 'mimic_axa_cpt_hierarchical_short'

# Load all necessary components

## Load config

In [4]:
# Locate the folder that holds the artefacts of the selected model
dir_all_models = Path('files')
model_checkpoints = dir_all_models.joinpath(model_to_load)

# Read the configuration that was saved next to the checkpoint
config_path = model_checkpoints.joinpath('config.yaml')
cfg = OmegaConf.load(config_path)

In [5]:
# Restrict the GPUs visible to CUDA according to the config, unless the user
# already exported CUDA_VISIBLE_DEVICES before starting the kernel.
if "CUDA_VISIBLE_DEVICES" not in os.environ:
    gpu = cfg.gpu
    if gpu != -1 and gpu is not None and gpu != "":
        # OmegaConf loads YAML sequences as ListConfig, which is not a `list`
        # subclass. Checking only `isinstance(..., list)` would stringify a
        # multi-GPU setting as "[0, 1]" instead of the "0,1" CUDA expects.
        if isinstance(gpu, (list, tuple)) or OmegaConf.is_list(gpu):
            visible_devices = ",".join(str(gpu_id) for gpu_id in gpu)
        else:
            visible_devices = str(gpu)
    else:
        # -1 / None / "" mean "no GPU": hide every device.
        visible_devices = ""
    os.environ["CUDA_VISIBLE_DEVICES"] = visible_devices

# Pick the device after the environment variable is settled so that
# torch.cuda.is_available() reflects the restricted device list.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pprint(f"Device: {device}")
pprint(f"CUDA_VISIBLE_DEVICES: {os.environ['CUDA_VISIBLE_DEVICES']}")

In [6]:
# Apply the seed stored in the loaded config through the project's set_seed
# helper (presumably for reproducibility; which random number generators it
# covers is defined in src.utils.seed — TODO confirm).
set_seed(cfg.seed)

## Load model

In [7]:
# Rebuild the objects the model constructor depends on (data, text encoder,
# transforms, lookups), then instantiate the model and restore its weights.
# The order matters: `data` is truncated/transformed in place before the
# lookups are derived from it.
data = data_pipeline(config=cfg.data)

# NOTE(review): the trailing comments in this cell originally read
# "not needed", but `text_encoder` and `label_transform` are both passed to
# get_model / get_lookups below — confirm before removing any of these steps.
text_encoder = get_text_encoder(
    config=cfg.text_encoder, data_dir=cfg.data.dir, texts=data.get_train_documents
) # NOTE(review): marked "not needed", yet used by text_transform and get_model
# NOTE(review): both transforms are restored from `cfg.load_model`, whereas the
# weights below are read from `model_checkpoints` — confirm they refer to the
# same training run.
label_transform = get_transform(
    config=cfg.label_transform,
    targets=data.all_targets,
    load_transform_path=cfg.load_model,
) # NOTE(review): marked "not needed", yet used by get_lookups and get_model
text_transform = get_transform(
    config=cfg.text_transform,
    texts=data.get_train_documents,
    text_encoder=text_encoder,
    load_transform_path=cfg.load_model,
)
data.truncate_text(cfg.data.max_length) # NOTE(review): marked "not needed" for inference — verify get_lookups does not rely on it
data.transform_text(text_transform.batch_transform) # NOTE(review): marked "not needed" for inference — verify get_lookups does not rely on it

# Lookups are derived from the data and both transforms; only
# `lookups.data_info` is consumed below.
lookups = get_lookups(
    config=cfg.lookup,
    data=data,
    label_transform=label_transform,
    text_transform=text_transform,
)

model = get_model(
        config=cfg.model, data_info=lookups.data_info, text_encoder=text_encoder, label_transform = label_transform
    )
model.to(device)
# torch.load unpickles the file: only load checkpoints from a trusted source.
# The checkpoint is a mapping whose 'model' entry holds the state dict.
model_weights = torch.load(model_checkpoints/"best_model.pt", map_location=device)
model.load_state_dict(model_weights['model'])

# Inference

In [42]:
# Switch the model to evaluation mode before running inference (assumes
# `model` is a torch.nn.Module, for which eval() disables training-only
# behaviour such as dropout — TODO confirm).
model.eval()

def prepare_inputs(tokenized_text, text_transform, chunk_size):
    """Turn one tokenized text into chunked model inputs.

    Args:
        tokenized_text: Mapping with exactly two values, in the order
            (token ids, attention mask). The values are unpacked
            positionally, so the mapping's insertion order matters.
        text_transform: Object exposing ``seq2batch(sequence, chunk_size=...)``.
        chunk_size: Chunk size forwarded unchanged to ``seq2batch``.

    Returns:
        A ``(token_id_batch, attention_mask_batch)`` tuple holding the
        ``seq2batch`` result for each of the two sequences.
    """
    ids, mask = tokenized_text.values()
    return (
        text_transform.seq2batch(ids, chunk_size=chunk_size),
        text_transform.seq2batch(mask, chunk_size=chunk_size),
    )

In [43]:
# Example clinical note to classify.
# NOTE(review): "expect" looks like a typo for "except". It is left untouched
# because the string is the model input and editing it would change the
# prediction recorded at the end of the notebook.
text = "The patient underwent surgery on his right eye. Carcinoma of circumference 5mm found by biopsy. No other trauma detected expect an ankle profound wound. "
# Run the text transform, then convert its token ids and attention mask into
# batches using the chunk size from the dataset config.
tokenized_text = text_transform.transform(text)
input_ids, attention_mask = prepare_inputs(tokenized_text, text_transform, cfg.dataset.configs.chunk_size)

In [50]:
# Forward pass without gradient tracking (inference only).
with torch.no_grad():
    # The model was moved to `device` when it was loaded; the inputs must live
    # on the same device, otherwise the forward pass fails with a device
    # mismatch whenever CUDA is available. `.to` is a no-op if they already do.
    # Assumes prepare_inputs returned torch tensors — TODO confirm.
    logits = model(input_ids.to(device), attention_mask.to(device))
    # After the sigmoid, `logits` holds per-label probabilities in [0, 1].
    # The name is kept because later cells read `logits`.
    logits = torch.sigmoid(logits)

## Analyse results

In [75]:
# Probabilities for the first row of the model output
tensor_probs = logits[0]

# Read the label -> output-index mapping stored next to the checkpoint
target2index_path = model_checkpoints.joinpath('target2index.json')
with target2index_path.open('r') as json_file:
    target2index = json.load(json_file)

# Start from an empty frame that has one column per label
df = pd.DataFrame(columns=target2index)

In [76]:
# Build a single-row frame mapping every label to its predicted probability.
#
# The frame is constructed directly from the record instead of being
# concatenated onto the existing `df`:
#   * re-running this cell no longer appends a duplicate row each time;
#   * concatenating onto an empty frame is deprecated in recent pandas and can
#     leave object-dtype columns, which breaks the idxmax() call further down.
# Column order follows `target2index`, exactly as before.
data_to_append = {
    str(target): tensor_probs[index].item()
    for target, index in target2index.items()
}
new_record = pd.DataFrame([data_to_append])

# Copy so that later edits to `df` cannot alter `new_record`.
df = new_record.copy()

In [88]:
# Report the label with the highest probability for the selected row
row_index = 0
row_probs = df.iloc[row_index]
max_column, confidence = row_probs.idxmax(), row_probs.max()
print(f'CPT predicted: {max_column} Confidence {confidence}')

CPT predicted: 99291 Confidence 0.218035027384758
