# Dependencies and prep env

In [0]:
!pip install mlflow transformers torch scikit-learn nltk
%restart_python

[43mNote: you may need to restart the kernel using %restart_python or dbutils.library.restartPython() to use updated packages.[0m


In [0]:
import nltk
import torch
from sklearn.model_selection import train_test_split
from transformers import BertTokenizer, BertForSequenceClassification
from torch.utils.data import DataLoader, TensorDataset
from sklearn.metrics import classification_report

# Unified Data Loader

In [0]:
import nltk
nltk.download('punkt')

def load_fallacy_data(fallacy_name: str):
    # Query using Spark SQL
    examples_df = spark.sql(f"""
        SELECT text 
        FROM logical_fallacy_data.{fallacy_name}
        WHERE label = true
    """)

    non_examples_df = spark.sql(f"""
        SELECT text 
        FROM logical_fallacy_data.{fallacy_name}
        WHERE label = false
    """)

    # Convert to Pandas (safe in Databricks serverless)
    examples = examples_df.toPandas()['text'].tolist()
    non_examples = non_examples_df.toPandas()['text'].tolist()

    print(f"[{fallacy_name}] examples: {len(examples)}")
    print(f"[{fallacy_name}] non-examples: {len(non_examples)}")

    return examples, non_examples

fallacies = [
    "red_herring",
    "straw_man",
    "slippery_slope",
    "attacking",
    "ad_hominem",
    "hasty_generalization",
    "ignorance",
    "hypocrisy",
    "stacking_deck"
]

fallacy_data = {}

for name in fallacies:
    examples, non_examples = load_fallacy_data(name)
    fallacy_data[name] = {
        "examples": examples,
        "non_examples": non_examples
    }


[nltk_data] Downloading package punkt to
[nltk_data]     /home/spark-b1faa9a0-ee82-4d76-938f-b6/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


[red_herring] examples: 36
[red_herring] non-examples: 38
[straw_man] examples: 1
[straw_man] non-examples: 7
[slippery_slope] examples: 1
[slippery_slope] non-examples: 7
[attacking] examples: 0
[attacking] non-examples: 0
[ad_hominem] examples: 28
[ad_hominem] non-examples: 60
[hasty_generalization] examples: 0
[hasty_generalization] non-examples: 0
[ignorance] examples: 42
[ignorance] non-examples: 37
[hypocrisy] examples: 22
[hypocrisy] non-examples: 43
[stacking_deck] examples: 16
[stacking_deck] non-examples: 15


# Unified Training Function

In [0]:
import torch
from transformers import BertTokenizer, BertForSequenceClassification
from torch.utils.data import TensorDataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report
import os

def train_fallacy_model(fallacy_name, examples, non_examples, output_dir="models", epochs=3, batch_size=8, lr=2e-5):
    if not examples or not non_examples:
        print(f"⚠️ Skipping {fallacy_name}: empty examples or non-examples.")
        return
    # Prepare training data
    texts = examples + non_examples
    labels = [1] * len(examples) + [0] * len(non_examples)

    # Load tokenizer and encode
    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
    encodings = tokenizer(texts, truncation=True, padding=True, return_tensors='pt')
    inputs = encodings['input_ids']
    masks = encodings['attention_mask']
    labels_tensor = torch.tensor(labels)

    # Build dataset and split
    dataset = TensorDataset(inputs, masks, labels_tensor)
    train_data, test_data = train_test_split(dataset, test_size=0.2, random_state=42)
    train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_data, batch_size=batch_size)

    # Load model
    model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    # Train loop
    model.train()
    for epoch in range(epochs):
        print(f"Epoch {epoch+1}/{epochs}")
        for batch in train_loader:
            b_input_ids, b_input_mask, b_labels = batch
            outputs = model(b_input_ids, attention_mask=b_input_mask, labels=b_labels)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

    # Evaluate
    model.eval()
    predictions, true_labels = [], []
    with torch.no_grad():
        for batch in test_loader:
            b_input_ids, b_input_mask, b_labels = batch
            outputs = model(b_input_ids, attention_mask=b_input_mask)
            logits = outputs.logits
            preds = torch.argmax(logits, dim=1)
            predictions.extend(preds.tolist())
            true_labels.extend(b_labels.tolist())

    print(f"\n[Evaluation for {fallacy_name}]")
    print(classification_report(true_labels, predictions))

    # Save model/tokenizer
    model_dir = os.path.join(output_dir, fallacy_name)
    os.makedirs(model_dir, exist_ok=True)
    model.save_pretrained(model_dir)
    tokenizer.save_pretrained(model_dir)
    print(f"✅ Model saved to {model_dir}")


# Load fallacy data and train

In [0]:
%python
# Safe temporary local storage
local_model_dir = "/local_disk0/tmp/fallacy_models"
os.makedirs(local_model_dir, exist_ok=True)

fallacies = [
    "red_herring",
    "straw_man",
    "slippery_slope",
    "attacking",
    "ad_hominem",
    "hasty_generalization",
    "ignorance",
    "hypocrisy",
    "stacking_deck"
]

for fallacy_name in fallacies:
    examples, non_examples = load_fallacy_data(fallacy_name)
    train_fallacy_model(fallacy_name, examples, non_examples)
    if not examples or not non_examples:
        print(f"❌ No data for: {fallacy_name}")

[red_herring] examples: 36
[red_herring] non-examples: 38


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch 1/3
Epoch 2/3
Epoch 3/3

[Evaluation for red_herring]
              precision    recall  f1-score   support

           0       1.00      1.00      1.00         5
           1       1.00      1.00      1.00        10

    accuracy                           1.00        15
   macro avg       1.00      1.00      1.00        15
weighted avg       1.00      1.00      1.00        15

✅ Model saved to models/red_herring
[straw_man] examples: 1
[straw_man] non-examples: 7


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch 1/3
Epoch 2/3
Epoch 3/3

[Evaluation for straw_man]
              precision    recall  f1-score   support

           0       1.00      1.00      1.00         2

    accuracy                           1.00         2
   macro avg       1.00      1.00      1.00         2
weighted avg       1.00      1.00      1.00         2

✅ Model saved to models/straw_man
[slippery_slope] examples: 1
[slippery_slope] non-examples: 7


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch 1/3
Epoch 2/3
Epoch 3/3

[Evaluation for slippery_slope]
              precision    recall  f1-score   support

           0       0.00      0.00      0.00       2.0
           1       0.00      0.00      0.00       0.0

    accuracy                           0.00       2.0
   macro avg       0.00      0.00      0.00       2.0
weighted avg       0.00      0.00      0.00       2.0



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))


✅ Model saved to models/slippery_slope
[attacking] examples: 0
[attacking] non-examples: 0
⚠️ Skipping attacking: empty examples or non-examples.
❌ No data for: attacking
[ad_hominem] examples: 28
[ad_hominem] non-examples: 60


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch 1/3
Epoch 2/3
Epoch 3/3

[Evaluation for ad_hominem]
              precision    recall  f1-score   support

           0       1.00      1.00      1.00        11
           1       1.00      1.00      1.00         7

    accuracy                           1.00        18
   macro avg       1.00      1.00      1.00        18
weighted avg       1.00      1.00      1.00        18

✅ Model saved to models/ad_hominem
[hasty_generalization] examples: 0
[hasty_generalization] non-examples: 0
⚠️ Skipping hasty_generalization: empty examples or non-examples.
❌ No data for: hasty_generalization
[ignorance] examples: 42
[ignorance] non-examples: 37


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch 1/3
Epoch 2/3
Epoch 3/3

[Evaluation for ignorance]
              precision    recall  f1-score   support

           0       0.83      1.00      0.91         5
           1       1.00      0.91      0.95        11

    accuracy                           0.94        16
   macro avg       0.92      0.95      0.93        16
weighted avg       0.95      0.94      0.94        16

✅ Model saved to models/ignorance
[hypocrisy] examples: 22
[hypocrisy] non-examples: 43


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch 1/3
Epoch 2/3
Epoch 3/3

[Evaluation for hypocrisy]
              precision    recall  f1-score   support

           0       0.80      1.00      0.89         8
           1       1.00      0.60      0.75         5

    accuracy                           0.85        13
   macro avg       0.90      0.80      0.82        13
weighted avg       0.88      0.85      0.84        13

✅ Model saved to models/hypocrisy
[stacking_deck] examples: 16
[stacking_deck] non-examples: 15


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch 1/3
Epoch 2/3
Epoch 3/3

[Evaluation for stacking_deck]
              precision    recall  f1-score   support

           0       1.00      0.25      0.40         4
           1       0.50      1.00      0.67         3

    accuracy                           0.57         7
   macro avg       0.75      0.62      0.53         7
weighted avg       0.79      0.57      0.51         7

✅ Model saved to models/stacking_deck


# Move raw inference models

In [0]:
import os
import shutil

# Safe temporary local storage
local_model_dir = "/local_disk0/tmp/fallacy_models"
os.makedirs(local_model_dir, exist_ok=True)

# Copy only trained models into safe path
for fallacy in fallacies:
    src = f"models/{fallacy}"
    dst = f"{local_model_dir}/{fallacy}"
    if os.path.exists(src):
        shutil.copytree(src, dst, dirs_exist_ok=True)
    else:
        print(f"Model for {fallacy} does not exist at {src}")

Model for attacking does not exist at models/attacking
Model for hasty_generalization does not exist at models/hasty_generalization


# Serve via MLflow

In [0]:
%python
import os
import shutil
import mlflow.pyfunc
from transformers import BertTokenizer, BertForSequenceClassification

class FallacyEnsembleModel(mlflow.pyfunc.PythonModel):
    def load_context(self, context):
        self.fallacy_types = [
            "straw_man", "red_herring", "ad_hominem",
            "hasty_generalization", "appeal_to_ignorance",
            "hypocrisy", "stacking_deck"
        ]
        self.model_dir = context.artifacts["model_dir"]
        self.fallacy_models = {}

        for fallacy in self.fallacy_types:
            model_path = os.path.join(self.model_dir, fallacy)
            if not os.path.exists(model_path):
                print(f"Skipping {fallacy}: model not found at {model_path}")
                continue
            try:
                model = BertForSequenceClassification.from_pretrained(model_path)
                tokenizer = BertTokenizer.from_pretrained(model_path)
                model.eval()
                self.fallacy_models[fallacy] = (tokenizer, model)
            except Exception as e:
                print(f"Error loading {fallacy}: {e}")

    def predict(self, context, model_input):
        texts = model_input["text"].tolist()
        results = []
        for text in texts:
            fallacy_scores = []
            for fallacy, (tokenizer, model) in self.fallacy_models.items():
                inputs = tokenizer(text, return_tensors='pt', truncation=True, padding=True)
                with torch.no_grad():
                    logits = model(**inputs).logits
                    probs = torch.softmax(logits, dim=1)
                    fallacy_prob = probs[0][1].item()
                fallacy_scores.append((fallacy, fallacy_prob))
            fallacy_scores.sort(key=lambda x: x[1], reverse=True)
            results.append(fallacy_scores)
        return results

# Path to save the MLflow model
save_path = "/local_disk0/tmp/fallacy_ensemble_model_v1"
shutil.rmtree(save_path, ignore_errors=True)

# Save the model using MLflow
mlflow.pyfunc.save_model(
    path=save_path,
    python_model=FallacyEnsembleModel(),     # Class instance
    artifacts={"model_dir": "/local_disk0/tmp/fallacy_models"}
)



Skipping hasty_generalization: model not found at /local_disk0/tmp/fallacy_models/hasty_generalization
Skipping appeal_to_ignorance: model not found at /local_disk0/tmp/fallacy_models/appeal_to_ignorance


Downloading artifacts:   0%|          | 0/35 [00:00<?, ?it/s]

# Register model to unity catalog
weird databricks specific thing with respect to mlflow model registries

In [0]:
%python
# Register it in Unity Catalog or default registry
registered_model_name = "fallacy_ensemble_uc"
model_artifact_path = f"<valid_run_id>/artifacts/fallacy_ensemble_model_v1"
try:
    mlflow.register_model(f"runs:/{model_artifact_path}", registered_model_name)
except:
    print("")




Registered model 'fallacy_ensemble_uc' already exists. Creating a new version of this model...


# Example Usage

In [0]:
model = mlflow.pyfunc.load_model(save_path)

import pandas as pd
df = pd.DataFrame({"text": ["You're just a student."]})
predictions = model.predict(df)
print(predictions)


Skipping hasty_generalization: model not found at /local_disk0/tmp/fallacy_ensemble_model_v1/artifacts/fallacy_models/hasty_generalization
Skipping appeal_to_ignorance: model not found at /local_disk0/tmp/fallacy_ensemble_model_v1/artifacts/fallacy_models/appeal_to_ignorance
[[('ad_hominem', 0.9208627939224243), ('stacking_deck', 0.683536946773529), ('straw_man', 0.527265191078186), ('red_herring', 0.37025919556617737), ('hypocrisy', 0.1460869014263153)]]
