In [None]:
import argparse
import ast
import os

import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from transformers import BertTokenizer

# Define the classifier model (same as before)
class BertCFCausalClassifier(nn.Module):
    """Multi-label head on top of a BERT-CF encoder.

    The encoder's pooled output is passed through one linear layer and a
    sigmoid, producing one independent probability per label.
    """

    def __init__(self, bert_cf_model, num_labels):
        """Store the encoder and build the classification head.

        Args:
            bert_cf_model: Model exposing ``config.hidden_size`` and a
                ``bert`` sub-model that is called in ``forward``.
            num_labels: Number of output labels (width of the linear head).
        """
        super().__init__()
        self.bert_cf = bert_cf_model
        hidden_size = bert_cf_model.config.hidden_size
        self.classifier = nn.Linear(hidden_size, num_labels)
        self.sigmoid = nn.Sigmoid()

    def forward(self, input_ids, attention_mask):
        """Return per-label probabilities for a batch of tokenized inputs.

        Args:
            input_ids: Token ids forwarded to the encoder.
            attention_mask: Attention mask forwarded to the encoder.

        Returns:
            Sigmoid of the linear head applied to the encoder's
            ``pooler_output``.
        """
        # The encoder is evaluated without gradient tracking; only the
        # linear head sits outside the no_grad region.
        with torch.no_grad():
            encoded = self.bert_cf.bert(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )
            pooled = encoded.pooler_output  # (batch_size, hidden_size)
        return self.sigmoid(self.classifier(pooled))

def load_classifier(model_path, bert_cf_model, num_labels, device):
    """Rebuild a ``BertCFCausalClassifier`` from a saved state dict.

    Args:
        model_path: Path of the checkpoint passed to ``torch.load``.
        bert_cf_model: Encoder handed to the classifier constructor.
        num_labels: Number of output labels of the classification head.
        device: Device used both as ``map_location`` for loading and as the
            destination of the returned model.

    Returns:
        The classifier with weights loaded, moved to ``device`` and switched
        to evaluation mode.
    """
    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # from trusted sources.
    state_dict = torch.load(model_path, map_location=device)
    model = BertCFCausalClassifier(bert_cf_model, num_labels)
    model.load_state_dict(state_dict)
    # Both .to() and .eval() return the module itself.
    return model.to(device).eval()

def prepare_test_dataset(csv_file, tokenizer, max_length, batch_size=32):
    """Build a non-shuffled ``DataLoader`` over the ``note`` column of a CSV.

    Args:
        csv_file: Path of the CSV file; it must contain a ``note`` column.
        tokenizer: Object providing ``encode_plus`` (called with
            ``return_tensors='pt'``).
        max_length: Length every note is truncated/padded to.
        batch_size: Number of notes per batch. Defaults to 32.

    Returns:
        A ``DataLoader`` yielding dicts with ``'input_ids'`` and
        ``'attention_mask'`` entries, in file order.

    Raises:
        KeyError: If the CSV has no ``note`` column.
    """
    class DiseaseTestDataset(Dataset):
        """Tokenizes one note per item, on demand."""

        def __init__(self, csv_file, tokenizer, max_length):
            self.data = pd.read_csv(csv_file)
            # Fail at construction time rather than on the first batch.
            if 'note' not in self.data.columns:
                raise KeyError(f"'note' column not found in {csv_file}")
            self.tokenizer = tokenizer
            self.max_length = max_length

        def __len__(self):
            return len(self.data)

        def __getitem__(self, idx):
            note = self.data['note'].iloc[idx]
            # Empty cells are read as NaN (a float) and non-text cells keep
            # their numeric type; the tokenizer needs a string either way.
            note = '' if pd.isna(note) else str(note)

            encoding = self.tokenizer.encode_plus(
                note,
                add_special_tokens=True,
                max_length=self.max_length,
                truncation=True,
                padding='max_length',
                return_tensors='pt'
            )
            # Drop only the leading batch axis added by return_tensors='pt';
            # a bare squeeze() would also drop the sequence axis when
            # max_length == 1.
            return {
                'input_ids': encoding['input_ids'].squeeze(0),
                'attention_mask': encoding['attention_mask'].squeeze(0)
            }

    dataset = DiseaseTestDataset(csv_file, tokenizer, max_length)
    return DataLoader(dataset, batch_size=batch_size, shuffle=False)

def _collect_predictions(classifier, dataloader, device, desc):
    """Run ``classifier`` over every batch and stack the outputs row-wise."""
    collected = []
    for batch in tqdm(dataloader, desc=desc):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        with torch.no_grad():
            probs = classifier(input_ids=input_ids, attention_mask=attention_mask)
        collected.append(probs.cpu().numpy())
    if not collected:
        # np.concatenate would raise an unhelpful error on an empty list.
        raise ValueError("Test dataset is empty; cannot compute TREaTE.")
    return np.concatenate(collected, axis=0)  # Shape: (num_samples, num_labels)


def calculate_treate(args):
    """Compute TREaTE between the original and counterfactual classifiers.

    TREaTE is computed here as the mean absolute difference between the two
    classifiers' predicted probabilities over all test samples and labels.
    The value is printed and written to ``<output_dir>/treate.txt``.

    Args:
        args: Namespace providing ``bert_model``, ``train_csv``, ``test_csv``,
            ``original_classifier_path``, ``counterfactual_classifier_path``,
            ``output_dir``, ``max_length`` and ``device``. ``bert_cf_path`` is
            optional; when present and non-empty its weights are loaded into
            the BERT-CF encoder before the counterfactual classifier
            checkpoint is applied.

    Returns:
        The TREaTE value as a float.

    Raises:
        ValueError: If the test CSV contains no rows.
    """
    # Initialize tokenizer
    tokenizer = BertTokenizer.from_pretrained(args.bert_model)

    # The label space is the sorted set of diseases seen in the training CSV.
    all_diseases = load_all_diseases(args.train_csv)
    num_labels = len(all_diseases)

    # Original classifier: head checkpoint on top of the pretrained encoder.
    original_classifier = load_classifier(
        model_path=args.original_classifier_path,
        bert_cf_model=BertCFModel.from_pretrained(args.bert_model),
        num_labels=num_labels,
        device=args.device
    )

    # Counterfactual classifier: same architecture on the BERT-CF encoder.
    bert_cf_model = BertCFModel.from_pretrained(args.bert_model)
    # Read the attribute defensively: a namespace that does not define
    # bert_cf_path would otherwise raise AttributeError here.
    # NOTE(review): load_classifier applies the classifier checkpoint with
    # load_state_dict's default key matching, so that checkpoint presumably
    # carries the encoder weights as well - confirm.
    bert_cf_path = getattr(args, 'bert_cf_path', None)
    if bert_cf_path:
        bert_cf_model.load_state_dict(torch.load(bert_cf_path, map_location=args.device))
    bert_cf_model.to(args.device)
    bert_cf_model.eval()

    counterfactual_classifier = load_classifier(
        model_path=args.counterfactual_classifier_path,
        bert_cf_model=bert_cf_model,
        num_labels=num_labels,
        device=args.device
    )

    # Prepare test dataset
    test_dataloader = prepare_test_dataset(args.test_csv, tokenizer, args.max_length)

    print("Predicting with Original Classifier...")
    original_preds = _collect_predictions(
        original_classifier, test_dataloader, args.device, "Original Predictions"
    )

    print("Predicting with Counterfactual Classifier...")
    counterfactual_preds = _collect_predictions(
        counterfactual_classifier, test_dataloader, args.device, "Counterfactual Predictions"
    )

    # Calculate TREaTE
    treate = float(np.mean(np.abs(original_preds - counterfactual_preds)))
    print(f"TREaTE: {treate:.4f}")

    # Save TREaTE; create the directory so the function also works when it
    # is called programmatically rather than from the CLI.
    os.makedirs(args.output_dir, exist_ok=True)
    with open(f"{args.output_dir}/treate.txt", 'w', encoding='utf-8') as f:
        f.write(f"TREaTE: {treate:.4f}\n")
    print(f"TREaTE saved to {args.output_dir}/treate.txt")

    return treate

def load_all_diseases(csv_file):
    """Return the sorted list of distinct diseases found in a CSV file.

    Each non-empty ``differential_diagnosis`` cell is parsed with
    ``ast.literal_eval`` and must evaluate to an iterable of two-item
    entries whose first item is the disease name.

    Args:
        csv_file: Path of the CSV file containing a
            ``differential_diagnosis`` column.

    Returns:
        Sorted list of unique disease names; empty if the column has no
        usable rows.
    """
    df = pd.read_csv(csv_file, usecols=['differential_diagnosis'])
    all_diseases = set()
    # Empty cells are read as NaN, which literal_eval rejects; they carry no
    # diseases, so drop them instead of crashing.
    for cell in df['differential_diagnosis'].dropna():
        all_diseases.update(disease for disease, _ in ast.literal_eval(cell))
    return sorted(all_diseases)

if __name__ == "__main__":
    # Command-line entry point: parse the paths/settings and run the metric.
    parser = argparse.ArgumentParser(description="Calculate TREaTE")
    parser.add_argument('--test_csv', type=str, required=True, help="Path to the test CSV file.")
    parser.add_argument('--train_csv', type=str, required=True, help="Path to the training CSV file for disease list.")
    parser.add_argument('--original_classifier_path', type=str, required=True, help="Path to the original classifier model.")
    parser.add_argument('--counterfactual_classifier_path', type=str, required=True, help="Path to the counterfactual classifier model.")
    # calculate_treate reads args.bert_cf_path, so the CLI has to define it;
    # without this argument the run fails with AttributeError.
    parser.add_argument('--bert_cf_path', type=str, required=True, help="Path to the BERT-CF model weights.")
    parser.add_argument('--output_dir', type=str, required=True, help="Directory to save TREaTE results.")
    parser.add_argument('--bert_model', type=str, default='bert-base-uncased', help="Pre-trained BERT model.")
    parser.add_argument('--max_length', type=int, default=512, help="Maximum sequence length.")
    parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu', help="Device to run predictions on.")

    args = parser.parse_args()

    # Make sure the results directory exists before anything is written.
    os.makedirs(args.output_dir, exist_ok=True)

    calculate_treate(args)
