In [43]:
# 📘 T5-Large Fine-Tuning with PEFT LoRA on Risk Communication Dataset

In [44]:
# 🛠️ 1. Install Required Libraries
!pip install -q transformers datasets peft accelerate sentencepiece

In [45]:
# 📥 2. Load and Preprocess the Dataset
from datasets import load_dataset
from transformers import T5Tokenizer

In [46]:
# Upload CSV manually via Colab UI or mount Google Drive
file_path = "/content/risk_data_formatted.csv"  # <-- replace if needed

In [47]:
import pandas as pd
from datasets import Dataset, DatasetDict

# Load the CSV file into a pandas DataFrame
df = pd.read_csv(file_path)

# Convert the pandas DataFrame to a datasets Dataset
# Assuming the CSV has 'input' and 'output' columns
dataset = Dataset.from_pandas(df)

# Rename columns as in the original code
dataset = dataset.rename_columns({"input": "input_text", "output": "target_text"})

# Split the dataset
# Convert the single Dataset object into a DatasetDict with 'train' and 'test' splits
train_test_split = dataset.train_test_split(test_size=0.1)
dataset = DatasetDict({'train': train_test_split['train'], 'test': train_test_split['test']})

In [48]:
# Add prompt to each input_text
def add_prompt_prefix(example):
    example["input_text"] = f"Extract structured  risk information from the following sentence: {example['input_text']}"
    return example

dataset = dataset.map(add_prompt_prefix)


Map:   0%|          | 0/486 [00:00<?, ? examples/s]

Map:   0%|          | 0/54 [00:00<?, ? examples/s]

In [49]:
tokenizer = T5Tokenizer.from_pretrained("t5-large")

In [50]:
def preprocess(example):
    input_enc = tokenizer(example["input_text"], truncation=True, padding="max_length", max_length=512)
    target_enc = tokenizer(example["target_text"], truncation=True, padding="max_length", max_length=512)
    input_enc["labels"] = target_enc["input_ids"]
    return input_enc

In [51]:
tokenized_dataset = dataset.map(preprocess, batched=True, remove_columns=["input_text", "target_text"])

Map:   0%|          | 0/486 [00:00<?, ? examples/s]

Map:   0%|          | 0/54 [00:00<?, ? examples/s]

In [52]:
# 🔌 3. Load T5-Large and Apply LoRA
from transformers import T5ForConditionalGeneration
from peft import get_peft_model, LoraConfig, TaskType

In [53]:
model = T5ForConditionalGeneration.from_pretrained("t5-large")

In [54]:
peft_config = LoraConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    r=8,
    lora_alpha=16,
    lora_dropout=0.1,
    bias="none"
)

In [55]:
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

trainable params: 2,359,296 || all params: 740,027,392 || trainable%: 0.3188


In [56]:
# 🎯 4. Training Setup
from transformers import TrainingArguments, Trainer, DataCollatorForSeq2Seq

In [57]:
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./t5_lora_risk_output", #storage place of trained model
    per_device_train_batch_size=1,
    #gradient_accumulation_steps=4,
    per_device_eval_batch_size=1,
    num_train_epochs=5,
    learning_rate=3e-4,
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_strategy="epoch",         # ✅ Add this to log every epoch
    logging_dir="./logs",
    report_to="none",
    logging_first_step=True,          # ✅ Optional: log on first step
)


In [58]:
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["test"],
    tokenizer=tokenizer,
    data_collator=DataCollatorForSeq2Seq(tokenizer, model=model),
)

  trainer = Trainer(
No label_names provided for model class `PeftModelForSeq2SeqLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


In [59]:
trainer.train()

Epoch,Training Loss,Validation Loss
1,1.1792,0.169034
2,0.2062,0.135126
3,0.1685,0.123696


Epoch,Training Loss,Validation Loss
1,1.1792,0.169034
2,0.2062,0.135126
3,0.1685,0.123696
4,0.1521,0.117321
5,0.1428,0.114935


TrainOutput(global_step=2430, training_loss=0.37840457727879656, metrics={'train_runtime': 1778.5561, 'train_samples_per_second': 1.366, 'train_steps_per_second': 1.366, 'total_flos': 5278676979548160.0, 'train_loss': 0.37840457727879656, 'epoch': 5.0})

In [60]:
# 🔍 5. Inference Function

In [61]:
# 🔍 5. Inference Function
# %%
def predict(text):
    prompt = f"Extract structured risk information from the following sentence:  {text}"
    # Ensure inputs are on the correct device
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, padding=True).input_ids.cuda()
    # Pass input_ids and max_length as keyword arguments
    outputs = model.generate(input_ids=inputs, max_length=512) # Changed: Pass inputs as keyword argument
    return tokenizer.decode(outputs[0], skip_special_tokens=True)





In [62]:
# ✅ 6. Test on a custom example
print("\nCustom test case:\n")
custom_text = (
    "Living in a city with high air pollution increases your risk of lung disease by 10%. "
    "While that may seem minor, it translates to an increase from 30 in 1,000 to 33 in "
    "1,000 people developing chronic respiratory issues over a decade."
)
predicted_custom_output = predict(custom_text)
print(f"Custom Input: {custom_text}")
print(f"Model Output: {predicted_custom_output}")

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.



Custom test case:

Custom Input: Living in a city with high air pollution increases your risk of lung disease by 10%. While that may seem minor, it translates to an increase from 30 in 1,000 to 33 in 1,000 people developing chronic respiratory issues over a decade.
Model Output: risk_communication: 1.0 single_case_base: 0.0 absolute_risk_base: null absolute_risk_new: null absolute_number_base: null absolute_number_new: null absolute_risk_difference: null relative_risk: null absolute_number_difference: null verbal_descriptor_base: null verbal_descriptor_new: null verbal_descriptor_change: increase population_size: null reference_class_size_base: null reference_class_size_new: null reference_class_description_base: people living in a city with high air pollution reference_class_description_new: people developing chronic respiratory issues over a decade source_base: null source_new: null topic_and_unit: risk of lung disease in %


In [63]:
import pandas as pd
import re
from tqdm import tqdm

# === Step 1: Extract the 10% test set ===
test_inputs = dataset["test"]["input_text"]
test_outputs = dataset["test"]["target_text"]

# === Step 2: Generate model predictions ===
pred_outputs = []
for text in tqdm(test_inputs, desc="Generating predictions"):
    input_ids = tokenizer(text, return_tensors="pt", truncation=True).input_ids.to(model.device)
    output_ids = model.generate(input_ids=input_ids, max_length=512)
    decoded = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    pred_outputs.append(decoded)

# === Step 3: Save predictions and ground truth (optional) ===
pd.DataFrame({"input": test_inputs, "output": test_outputs}).to_csv("ground_truth_test_only.csv", index=False)
pd.DataFrame({"input": test_inputs, "output": pred_outputs}).to_csv("model_outputs_test_only.csv", index=False)

# === Step 4: Evaluation Functionality ===

def parse_output(text):
    result = {}
    for line in str(text).strip().split("\n"):
        if ": " in line:
            key, value = line.split(": ", 1)
            result[key.strip()] = value.strip()
    return result

def normalize_value(val):
    val = val.lower().strip()
    if val in ["null", "n/a", "none", ""]:
        return "null"
    val = val.replace(",", "")
    match = re.match(r"([0-9.]+)\s*(million|thousand)?", val)
    if match:
        num, unit = match.groups()
        num = float(num)
        if unit == "million":
            num *= 1_000_000
        elif unit == "thousand":
            num *= 1_000
        return str(int(num)) if num.is_integer() else str(num)
    return val

# === Step 5: Compare predictions to ground truth ===
field_stats = {}
results = []

for inp, true_text, pred_text in zip(test_inputs, test_outputs, pred_outputs):
    true_dict = parse_output(true_text)
    pred_dict = parse_output(pred_text)
    row_result = {"input": inp}

    for field in true_dict:
        pred_val = normalize_value(pred_dict.get(field, "null"))
        true_val = normalize_value(true_dict.get(field, "null"))
        correct = (pred_val == true_val)
        row_result[field] = "✅" if correct else f"❌ (pred: {pred_val})"
        if field not in field_stats:
            field_stats[field] = {"correct": 0, "total": 0}
        field_stats[field]["total"] += 1
        if correct:
            field_stats[field]["correct"] += 1
    results.append(row_result)

# === Step 6: Show sample and summary ===
results_df = pd.DataFrame(results)
print("\n=== Sample Comparison ===")
print(results_df.head(3).to_string())

print("\n=== Field-wise Accuracy ===")
for field, stats in field_stats.items():
    acc = 100 * stats["correct"] / stats["total"]
    print(f"{field:50s} → {acc:.1f}% ({stats['correct']}/{stats['total']})")



Generating predictions: 100%|██████████| 54/54 [07:28<00:00,  8.31s/it]


=== Sample Comparison ===
                                                                                                                                                                                                                                                                                                                                                                                                                                      input risk_communication single_case_base absolute_risk_base absolute_risk_new absolute_number_base absolute_number_new absolute_risk_difference   relative_risk absolute_number_difference verbal_descriptor_base verbal_descriptor_new verbal_descriptor_change population_size reference_class_size_base reference_class_size_new reference_class_description_base reference_class_description_new     source_base      source_new  topic_and_unit
0                                                                                                    Extract stru


