In [5]:
!pip install torch

Collecting torch
  Downloading torch-2.5.0-cp310-none-macosx_11_0_arm64.whl.metadata (28 kB)
Collecting networkx (from torch)
  Using cached networkx-3.4.2-py3-none-any.whl.metadata (6.3 kB)
Collecting jinja2 (from torch)
  Using cached jinja2-3.1.4-py3-none-any.whl.metadata (2.6 kB)
Collecting sympy==1.13.1 (from torch)
  Using cached sympy-1.13.1-py3-none-any.whl.metadata (12 kB)
Collecting mpmath<1.4,>=1.1.0 (from sympy==1.13.1->torch)
  Using cached mpmath-1.3.0-py3-none-any.whl.metadata (8.6 kB)
Collecting MarkupSafe>=2.0 (from jinja2->torch)
  Downloading MarkupSafe-3.0.2-cp310-cp310-macosx_11_0_arm64.whl.metadata (4.0 kB)
Downloading torch-2.5.0-cp310-none-macosx_11_0_arm64.whl (64.3 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m64.3/64.3 MB[0m [31m16.8 MB/s[0m eta [36m0:00:00[0ma [36m0:00:01[0m
[?25hUsing cached sympy-1.13.1-py3-none-any.whl (6.2 MB)
Using cached jinja2-3.1.4-py3-none-any.whl (133 kB)
Using cached networkx-3.4.2-py3-none-any.whl (1.7

In [6]:
import pandas as pd
import json
import ast
from torch.utils.data import Dataset
from transformers import BertTokenizer
import torch
import random

# Load the clinical notes dataset
df = pd.read_csv('../data/release_train_patients.csv')  # Replace with your actual file path

# Load the evidence mapping dataset.
# The encoding is pinned to UTF-8 so parsing does not depend on the platform's default
# text encoding (the mapping carries non-English text such as 'question_fr').
with open('../data/release_evidences.json', 'r', encoding='utf-8') as f:
    evidence_map = json.load(f)


  from .autonotebook import tqdm as notebook_tqdm


In [7]:
def generate_note_for_patient(evidences_str, evidence_map):
    """Render a patient's evidence codes as a question/answer text note.

    Args:
        evidences_str: String holding a Python-literal list of evidence codes,
            e.g. "['E_55_@_V_89', 'E_12']". A code is either a bare evidence id
            or ``<evidence id>_@_<value code>``.
        evidence_map: Dict keyed by evidence id. Each entry may provide
            'question_en', 'value_meaning' (value code -> {'en': text}) and
            'default_value'.

    Returns:
        The "Q: <question> A: <answer>" pairs joined with '. ', or '' when the
        input is empty/missing, cannot be parsed, or does not parse to a list.
    """
    if not evidences_str or pd.isna(evidences_str):
        return ''

    # Only the parsing step is guarded: a problem with one evidence entry must
    # not silently discard the note for the whole patient.
    try:
        evidences = ast.literal_eval(evidences_str)
    except (ValueError, SyntaxError):
        return ''
    if not isinstance(evidences, list):
        return ''

    note_parts = []
    for evidence in evidences:
        if not isinstance(evidence, str):
            continue  # Malformed entry; nothing to look up.

        # Codes may look like 'E_55_@_V_89'. partition() splits on the first
        # separator only, so an unexpected extra '_@_' cannot raise.
        e_code, _, v_code = evidence.partition('_@_')

        mapping = evidence_map.get(e_code, {})
        question = mapping.get('question_en', '').strip()
        if not question:
            continue  # Skip if question is not available

        value_meaning = mapping.get('value_meaning') or {}
        if v_code and v_code in value_meaning:
            answer_key = v_code
        else:
            # Fall back to the evidence's default value when no (known) value
            # code was supplied.
            # NOTE(review): an evidence whose default_value is falsy (e.g. 0) or
            # absent from value_meaning contributes nothing to the note —
            # confirm this is intended.
            default_value = mapping.get('default_value', '')
            answer_key = default_value if default_value and default_value in value_meaning else None

        if answer_key is None:
            continue  # Missing answer
        answer = (value_meaning[answer_key].get('en') or '').strip()
        if answer:
            note_parts.append(f"Q: {question} A: {answer}")

    # Combine all question-answer pairs to form the note
    return '. '.join(note_parts)


In [8]:
def extract_chest_pain(evidences_str, chest_pain_evidences):
    """Return 1 if any evidence code of the patient is a chest-pain code, else 0.

    Args:
        evidences_str: String holding a Python-literal list of evidence codes.
        chest_pain_evidences: Container of codes that indicate chest pain.

    Returns:
        1 when at least one parsed code is found in ``chest_pain_evidences``;
        0 otherwise, including for empty/missing input, unparsable input, or
        input that does not parse to a list.
    """
    if not evidences_str or pd.isna(evidences_str):
        return 0
    try:
        parsed = ast.literal_eval(evidences_str)
        if isinstance(parsed, list):
            # int(True) == 1 and int(False) == 0
            return int(any(code in chest_pain_evidences for code in parsed))
        return 0
    except (ValueError, SyntaxError):
        return 0


In [9]:
# Render every patient's evidence list as a question/answer note;
# evidence_map is forwarded to the function as its second positional argument.
df['note'] = df['EVIDENCES'].apply(generate_note_for_patient, args=(evidence_map,))

# Keep the row index as an explicit identifier column.
df['unique_id'] = df.index


In [12]:
# Total number of cells (rows x columns) in the frame, not the number of rows.
print(len(df.index) * len(df.columns))

9230418


In [11]:
def get_chest_pain_evidences(evidence_map):
    """Collect the evidence/value codes that describe pain located in the chest.

    An evidence qualifies when its English or French question text contains a
    pain keyword; each of its values whose English meaning contains a chest
    keyword is then recorded. Matching is case-insensitive substring matching.

    Args:
        evidence_map: Dict keyed by evidence id; entries may provide
            'question_en', 'question_fr' and 'value_meaning'
            (value code -> {'en': text}).

    Returns:
        Set of strings of the form '<evidence id>_@_<value code>'.
    """
    chest_terms = ('chest', 'sternum', 'thorax', 'breast', 'pectoral', 'rib', 'precordial')
    pain_terms = ('pain', 'douleur')

    def mentions(text, terms):
        # True when any of the terms occurs as a substring of text.
        return any(term in text for term in terms)

    matches = set()
    for code, entry in evidence_map.items():
        english = entry.get('question_en', '').lower()
        french = entry.get('question_fr', '').lower()

        # Only questions about pain are of interest.
        if not (mentions(english, pain_terms) or mentions(french, pain_terms)):
            continue

        # Keep the values whose English meaning names the chest area.
        for value_code, meaning in entry.get('value_meaning', {}).items():
            if mentions(meaning.get('en', '').lower(), chest_terms):
                matches.add(f"{code}_@_{value_code}")
    return matches

# Build the set of '<evidence id>_@_<value code>' strings that indicate chest pain.
chest_pain_evidences = get_chest_pain_evidences(evidence_map)


def extract_chest_pain(evidences_str, chest_pain_evidences):
    """Flag whether a patient's evidence list contains a chest-pain code.

    NOTE(review): an identical function with this name is already defined in an
    earlier cell of this notebook; this definition replaces it when the cell
    runs. One of the two can be removed.

    Args:
        evidences_str: String holding a Python-literal list of evidence codes.
        chest_pain_evidences: Container of codes that indicate chest pain.

    Returns:
        1 if any parsed code is in ``chest_pain_evidences``, otherwise 0
        (also for empty/missing, unparsable, or non-list input).
    """
    is_missing = (not evidences_str) or pd.isna(evidences_str)
    if is_missing:
        return 0
    try:
        candidates = ast.literal_eval(evidences_str)
        found = isinstance(candidates, list) and any(
            candidate in chest_pain_evidences for candidate in candidates
        )
    except (ValueError, SyntaxError):
        found = False
    return 1 if found else 0

# Derive the binary 'chest_pain' column; the set of chest-pain codes is forwarded
# to extract_chest_pain as its second positional argument.
df['chest_pain'] = df['EVIDENCES'].apply(extract_chest_pain, args=(chest_pain_evidences,))




In [13]:
# Persist the augmented frame. index=False leaves the row index out of the file
# (a copy of it is kept in the 'unique_id' column).
# Note: rows whose note is '' are written as empty fields, which pandas.read_csv
# loads back as NaN under its default settings.
df.to_csv('../data/clinical_notes_with_chest_pain.csv', index=False)

In [15]:
class ClinicalNotesCausalDataset(Dataset):
    """Dataset of clinical notes prepared for masked-language-model training.

    Each item holds the tokenized (and randomly masked) note, the MLM labels,
    the treated-concept label read from the 'chest_pain' column, and an optional
    controlled-concept label read from the 'controlled_concept' column.
    """

    def __init__(self, csv_file, tokenizer, max_length, controlled_concepts=None, mask_prob=0.15):
        """
        Args:
            csv_file (str): Path to the csv file with clinical notes. The file
                must provide the 'note' and 'chest_pain' columns.
            tokenizer (BertTokenizer): Tokenizer for BERT.
            max_length (int): Maximum sequence length.
            controlled_concepts (list): List of controlled concept labels (optional).
            mask_prob (float): Probability of masking tokens for MLM.
        """
        self.data = pd.read_csv(csv_file)
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.controlled_concepts = controlled_concepts
        self.mask_prob = mask_prob

        # Map each controlled concept to its position in the provided list.
        if self.controlled_concepts:
            self.controlled_concept_map = {
                concept: idx for idx, concept in enumerate(self.controlled_concepts)
            }
        else:
            self.controlled_concept_map = {}

    def __len__(self):
        """Return the number of rows in the loaded csv."""
        return len(self.data)

    def __getitem__(self, idx):
        """Build the training example for row ``idx``.

        Returns:
            dict with 1-D tensors 'input_ids', 'attention_mask' and 'labels'
            (each of length ``max_length``) and scalar long tensors
            'chest_pain' and 'controlled_concept'. 'labels' is -100 everywhere
            except at masked positions, where it holds the original token id.
        """
        row = self.data.iloc[idx]
        chest_pain = row['chest_pain']
        controlled_concept = row.get('controlled_concept', None)  # Optional column

        # Empty notes are stored as empty CSV fields and come back from
        # read_csv as NaN; the tokenizer needs a string.
        note = row['note']
        if not isinstance(note, str):
            note = '' if pd.isna(note) else str(note)

        # Tokenize
        encoding = self.tokenizer.encode_plus(
            note,
            add_special_tokens=True,
            max_length=self.max_length,
            truncation=True,
            padding='max_length',
            return_tensors='pt'
        )
        # Drop only the batch dimension; a bare squeeze() would also collapse
        # the sequence dimension when max_length == 1.
        input_ids = encoding['input_ids'].squeeze(0)
        attention_mask = encoding['attention_mask'].squeeze(0)

        # Prepare MLM labels
        labels = input_ids.clone()
        probability_matrix = torch.full(labels.shape, float(self.mask_prob))
        # The special-token mask is built 1-D so it has the same shape as
        # probability_matrix; special tokens are never selected for masking.
        special_tokens_mask = torch.tensor(
            self.tokenizer.get_special_tokens_mask(labels.tolist(), already_has_special_tokens=True),
            dtype=torch.bool,
        )
        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()
        labels[~masked_indices] = -100  # We only compute loss on masked tokens
        # Replace masked input tokens with [MASK] token
        input_ids[masked_indices] = self.tokenizer.convert_tokens_to_ids(self.tokenizer.mask_token)

        # Treated Concept Label
        chest_pain_label = torch.tensor(chest_pain, dtype=torch.long)

        # Controlled Concept Label (if any).
        # NOTE(review): unknown or missing concepts get label 0, which is also
        # the label of the first entry in controlled_concepts — confirm this
        # collision is acceptable.
        if self.controlled_concepts and controlled_concept:
            cc_label = torch.tensor(self.controlled_concept_map.get(controlled_concept, 0), dtype=torch.long)
        else:
            cc_label = torch.tensor(0, dtype=torch.long)  # Default to 0 if no controlled concept

        return {
            'input_ids': input_ids,
            'attention_mask': attention_mask,
            'labels': labels,
            'chest_pain': chest_pain_label,
            'controlled_concept': cc_label
        }
