### Initial Multitask Model

In [2]:
import socket
print(f"Running on: {socket.gethostname()}")

Running on: landonia01.inf.ed.ac.uk


In [3]:
# check CUDA available
import torch
print(f"CUDA available: {torch.cuda.is_available()}")
print(f"CUDA device count: {torch.cuda.device_count()}")
if torch.cuda.is_available():
    print(f"Current device: {torch.cuda.current_device()}")
    print(f"Device name: {torch.cuda.get_device_name(0)}")

CUDA available: True
CUDA device count: 1
Current device: 0
Device name: NVIDIA GeForce GTX TITAN X


### Import Necessary Libraries

In [1]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
from transformers import T5Tokenizer, T5ForConditionalGeneration # huggingface T5 model + tokenizer
from sklearn.metrics import accuracy_score, f1_score, recall_score, roc_auc_score, confusion_matrix, classification_report
import torchmetrics
from model import MultitaskModel, MultitaskDataset, create_weighted_sampler, get_parameter_groups, train, evaluate, evaluate_with_sampling, train_and_validate_model, setup_training_pipeline
from sklearn.model_selection import train_test_split
from bert_score import score as bert_score
from nltk.translate.meteor_score import meteor_score
import nltk
import random

### Import Dataset

In [4]:
mwr_df_simple = pd.read_csv('mwr_simple.csv')

### Extract Features

In [5]:
print(mwr_df_simple.columns)

Index(['Examination ID', 'Conclusion', 'r:Th', 'Weight', 'Height',
       'Ambient temperature', 'r:AgeInYears', 'Mammary diameter', 'Cycle',
       'Day from the first day', 'Hormonal medications',
       'Cancer family history', 'Breast operations', 'Num of pregnancies',
       'R1 int', 'L1 int', 'R2 int', 'L2 int', 'R3 int', 'L3 int', 'R4 int',
       'L4 int', 'R5 int', 'L5 int', 'R6 int', 'L6 int', 'R7 int', 'L7 int',
       'R8 int', 'L8 int', 'R9 int', 'L9 int', 'T1 int', 'T2 int', 'R0 int',
       'L0 int', 'R1 sk', 'L1 sk', 'R2 sk', 'L2 sk', 'R3 sk', 'L3 sk', 'R4 sk',
       'L4 sk', 'R5 sk', 'L5 sk', 'R6 sk', 'L6 sk', 'R7 sk', 'L7 sk', 'R8 sk',
       'L8 sk', 'R9 sk', 'L9 sk', 'T1 sk', 'T2 sk', 'R0 sk', 'L0 sk',
       'Conclusion (Tr)', 'Synthetic_Conclusion', 'y_binary'],
      dtype='object')


In [6]:
# Select feature columns (temperature readings)
feature_cols = [col for col in mwr_df_simple.columns if col.endswith('int') or col.endswith('sk')]
mwr_df_simple['features'] = mwr_df_simple[feature_cols].values.tolist()

# Prepare labels and text targets (binary classification)
mwr_df_simple['class_label'] = mwr_df_simple['y_binary'].astype(int)
mwr_df_simple['synthetic_description'] = mwr_df_simple['Synthetic_Conclusion']

In [7]:
# check class distribution
print(mwr_df_simple['class_label'].value_counts(normalize=True))

class_label
1    0.75883
0    0.24117
Name: proportion, dtype: float64


### Split Data into Test and Train

In [8]:
# 70% training, 15% validation, 15% testing
train_simple_df, temp_simple_df = train_test_split(mwr_df_simple, test_size=0.3, random_state=42)
val_simple_df, test_simple_df = train_test_split(temp_simple_df, test_size=0.5, random_state=42)

### Model Initialisation

In [9]:
tokenizer, train_loader, val_loader, test_loader, model, optimizer, device = setup_training_pipeline(
    df_train=train_simple_df,
    df_val=val_simple_df,
    df_test=test_simple_df,
    multitask_model_class=MultitaskModel,
    multitask_dataset_class=MultitaskDataset
)

Loading tokenizer...


You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


Creating datasets...
Creating dataloaders...
Setting up device...
Initializing model...
Initial weights - CLF: 1.0, GEN: 0.001
Moving model to device...
Setting up optimizer...
After .to(device) - CLF: 1.0, GEN: 0.001
✓ Setup complete!



### Train Model

In [11]:
best_acc, best_f1, best_auc = train_and_validate_model(
    model, train_loader, val_loader, optimizer, device, num_epochs=30
)


Epoch 1/30


Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Training loss: 0.6162
Validation loss: 0.5168
Validation metrics:
  Accuracy     : 0.7466
  F1-Score     : 0.7162
  Sensitivity  : 0.9090
  Specificity  : 0.2435
  AUC-ROC      : 0.7046
  Confusion Matrix: [[216, 671], [250, 2498]]
✓ New best accuracy: 0.7466
✓ New best F1: 0.7162
✓ New best AUC-ROC: 0.7046

Epoch 2/30
Training loss: 0.5520
Validation loss: 0.5146
Validation metrics:
  Accuracy     : 0.7422
  F1-Score     : 0.7361
  Sensitivity  : 0.8501
  Specificity  : 0.4081
  AUC-ROC      : 0.7302
  Confusion Matrix: [[362, 525], [412, 2336]]
✓ New best F1: 0.7361
✓ New best AUC-ROC: 0.7302

Epoch 3/30
Training loss: 0.5318
Validation loss: 0.5295
Validation metrics:
  Accuracy     : 0.7326
  F1-Score     : 0.7360
  Sensitivity  : 0.8104
  Specificity  : 0.4915
  AUC-ROC      : 0.7383
  Confusion Matrix: [[436, 451], [521, 2227]]
✓ New best AUC-ROC: 0.7383
Current loss weights - CLF: 1.000, GEN: 0.001

Epoch 4/30
Training loss: 0.5196
Validation loss: 0.5296
Validation metrics:
  A

Training loss: 0.4458
Validation loss: 0.4598
Validation metrics:
  Accuracy     : 0.7684
  F1-Score     : 0.7713
  Sensitivity  : 0.8341
  Specificity  : 0.5648
  AUC-ROC      : 0.8072
  Confusion Matrix: [[501, 386], [456, 2292]]

Epoch 30/30
Training loss: 0.4404
Validation loss: 0.4787
Validation metrics:
  Accuracy     : 0.7554
  F1-Score     : 0.7637
  Sensitivity  : 0.7991
  Specificity  : 0.6201
  AUC-ROC      : 0.8048
  Confusion Matrix: [[550, 337], [552, 2196]]
Current loss weights - CLF: 1.000, GEN: 0.001


In [17]:
# test evaluation, gen=0.001
print(f"\nTraining completed!")
print(f"Best results from validation:")
print(f"  Best Accuracy: {best_acc:.4f}")
print(f"  Best F1: {best_f1:.4f}")
print(f"  Best AUC-ROC: {best_auc:.4f}")

print("\n" + "="*50)
print("FINAL TEST SET EVALUATION")
print("="*50)

# Quick classification metrics on full test set
print("\n1. Full Classification Performance:")
test_loss, test_metrics = evaluate(model, test_loader, device)
print(f"Test loss: {test_loss:.4f}")
print(f"Classification metrics (all {len(test_simple_df)} samples):")
print(f"  Accuracy     : {test_metrics['accuracy']:.4f}")
print(f"  F1-Score     : {test_metrics['f1_score']:.4f}")
print(f"  Sensitivity  : {test_metrics['sensitivity']:.4f}")
print(f"  Specificity  : {test_metrics['specificity']:.4f}")
print(f"  AUC-ROC      : {test_metrics['auc_roc']:.4f}" if test_metrics['auc_roc'] is not None else "  AUC-ROC      : N/A")

# Sampled text generation metrics
print("\n2. Sampled Text Generation Performance:")
# Use 500 samples for text metrics (adjust as needed)
test_loss_sampled, test_metrics_sampled = evaluate_with_sampling(
    model, test_loader, device, tokenizer, text_sample_size=500
)

print(f"Text generation metrics ({test_metrics_sampled['text_samples_used']} samples):")
if test_metrics_sampled['avg_bertscore_f1'] is not None:
    print(f"  BERTScore F1 : {test_metrics_sampled['avg_bertscore_f1']:.4f}")
else:
    print(f"  BERTScore F1 : N/A")

if test_metrics_sampled['avg_meteor'] is not None:
    print(f"  METEOR       : {test_metrics_sampled['avg_meteor']:.4f}")
else:
    print(f"  METEOR       : N/A")

print(f"\nSample efficiency: {test_metrics_sampled['text_samples_used']}/{test_metrics_sampled['total_samples']} samples used for text metrics")


Training completed!
Best results from validation:
  Best Accuracy: 0.7780
  Best F1: 0.7762
  Best AUC-ROC: 0.8117

FINAL TEST SET EVALUATION

1. Full Classification Performance:


The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Test loss: 0.4615
Classification metrics (all 3636 samples):
  Accuracy     : 0.7731
  F1-Score     : 0.7807
  Sensitivity  : 0.8185
  Specificity  : 0.6213
  AUC-ROC      : 0.8193

2. Sampled Text Generation Performance:
Sampling 500 out of 3636 samples for text generation metrics


The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not va

The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not va

The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not va

The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not va

The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not va

The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not va

The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not va

The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not va

The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not va

Computing text metrics on 500 samples...
Error calculating BERT Score: '/home/s2080063/MWR-to-Text/models/roberta-large'
Text generation metrics (500 samples):
  BERTScore F1 : N/A
  METEOR       : 0.9963

Sample efficiency: 500/3636 samples used for text metrics


In [10]:
# Training loop, gen=0
num_epochs = 30
best_accuracy = 0
best_f1 = 0
best_auc = 0

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch + 1}/{num_epochs}")
    
    # Train
    train_loss = train(model, train_loader, optimizer, device)
    print(f"Training loss: {train_loss:.4f}")
    
    # Validate
    val_loss, val_metrics = evaluate(model, val_loader, device)
    print(f"Validation loss: {val_loss:.4f}")
    print(f"Validation metrics:")
    print(f"  Accuracy     : {val_metrics['accuracy']:.4f}")
    print(f"  F1-Score     : {val_metrics['f1_score']:.4f}")
    print(f"  Sensitivity  : {val_metrics['sensitivity']:.4f}")
    print(f"  Specificity  : {val_metrics['specificity']:.4f}")
    print(f"  AUC-ROC      : {val_metrics['auc_roc']:.4f}" if val_metrics['auc_roc'] is not None else "  AUC-ROC      : N/A")
    print(f"  Confusion Matrix: {val_metrics['confusion_matrix']}")

    # Track best accuracy (or optionally best F1/AUC too)
    if val_metrics['accuracy'] > best_accuracy:
        best_accuracy = val_metrics['accuracy']
        print(f"✓ New best accuracy: {best_accuracy:.4f}")
    
    # Track best AUC and F1 if relevant
    if val_metrics['f1_score'] > best_f1:
        best_f1 = val_metrics['f1_score']
        print(f"✓ New best F1: {best_f1:.4f}")

    if val_metrics['auc_roc'] is not None and val_metrics['auc_roc'] > best_auc:
        best_auc = val_metrics['auc_roc']
        print(f"✓ New best AUC-ROC: {best_auc:.4f}")
    
    # Print the learned loss weights every few epochs
    if (epoch + 1) % 3 == 0:
        print(f"Current loss weights - CLF: {model.classification_weight:.3f}, GEN: {model.generation_weight:.3f}")


Epoch 1/30


Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Training loss: 0.6356
Validation loss: 0.5216
Validation metrics:
  Accuracy     : 0.7494
  F1-Score     : 0.7314
  Sensitivity  : 0.8865
  Specificity  : 0.3247
  AUC-ROC      : 0.7126
  Confusion Matrix: [[288, 599], [312, 2436]]
✓ New best accuracy: 0.7494
✓ New best F1: 0.7314
✓ New best AUC-ROC: 0.7126

Epoch 2/30
Training loss: 0.5556
Validation loss: 0.5306
Validation metrics:
  Accuracy     : 0.7365
  F1-Score     : 0.7339
  Sensitivity  : 0.8344
  Specificity  : 0.4329
  AUC-ROC      : 0.7294
  Confusion Matrix: [[384, 503], [455, 2293]]
✓ New best F1: 0.7339
✓ New best AUC-ROC: 0.7294

Epoch 3/30
Training loss: 0.5283
Validation loss: 0.5307
Validation metrics:
  Accuracy     : 0.7227
  F1-Score     : 0.7285
  Sensitivity  : 0.7944
  Specificity  : 0.5006
  AUC-ROC      : 0.7429
  Confusion Matrix: [[444, 443], [565, 2183]]
✓ New best AUC-ROC: 0.7429
Current loss weights - CLF: 1.000, GEN: 0.000

Epoch 4/30
Training loss: 0.5217
Validation loss: 0.5108
Validation metrics:
  A

Training loss: 0.4465
Validation loss: 0.4997
Validation metrics:
  Accuracy     : 0.7367
  F1-Score     : 0.7508
  Sensitivity  : 0.7515
  Specificity  : 0.6911
  AUC-ROC      : 0.8020
  Confusion Matrix: [[613, 274], [683, 2065]]

Epoch 30/30
Training loss: 0.4485
Validation loss: 0.4987
Validation metrics:
  Accuracy     : 0.7400
  F1-Score     : 0.7524
  Sensitivity  : 0.7656
  Specificity  : 0.6607
  AUC-ROC      : 0.7962
  Confusion Matrix: [[586, 301], [644, 2104]]
Current loss weights - CLF: 1.000, GEN: 0.000


In [10]:
# Training loop, 0.1, gen weight=0.001
num_epochs = 30
best_accuracy = 0
best_f1 = 0
best_auc = 0

for epoch in range(num_epochs):
    print(f"\nEpoch {epoch + 1}/{num_epochs}")
    
    # Train
    train_loss = train(model, train_loader, optimizer, device)
    print(f"Training loss: {train_loss:.4f}")
    
    # Validate
    val_loss, val_metrics = evaluate(model, val_loader, device)
    print(f"Validation loss: {val_loss:.4f}")
    print(f"Validation metrics:")
    print(f"  Accuracy     : {val_metrics['accuracy']:.4f}")
    print(f"  F1-Score     : {val_metrics['f1_score']:.4f}")
    print(f"  Sensitivity  : {val_metrics['sensitivity']:.4f}")
    print(f"  Specificity  : {val_metrics['specificity']:.4f}")
    print(f"  AUC-ROC      : {val_metrics['auc_roc']:.4f}" if val_metrics['auc_roc'] is not None else "  AUC-ROC      : N/A")
    print(f"  Confusion Matrix: {val_metrics['confusion_matrix']}")

    # Track best accuracy (or optionally best F1/AUC too)
    if val_metrics['accuracy'] > best_accuracy:
        best_accuracy = val_metrics['accuracy']
        print(f"✓ New best accuracy: {best_accuracy:.4f}")
    
    # Track best AUC and F1 if relevant
    if val_metrics['f1_score'] > best_f1:
        best_f1 = val_metrics['f1_score']
        print(f"✓ New best F1: {best_f1:.4f}")

    if val_metrics['auc_roc'] is not None and val_metrics['auc_roc'] > best_auc:
        best_auc = val_metrics['auc_roc']
        print(f"✓ New best AUC-ROC: {best_auc:.4f}")
    
    # Print the learned loss weights every few epochs
    if (epoch + 1) % 3 == 0:
        print(f"Current loss weights - CLF: {model.classification_weight:.3f}, GEN: {model.generation_weight:.3f}")


Epoch 1/30


Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.48.0. You should pass an instance of `EncoderDecoderCache` instead, e.g. `past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`.


Training loss: 0.7206
Validation loss: 0.5232
Validation metrics:
  Accuracy     : 0.7486
  F1-Score     : 0.7113
  Sensitivity  : 0.9221
  Specificity  : 0.2108
  AUC-ROC      : 0.6963
  Confusion Matrix: [[187, 700], [214, 2534]]
✓ New best accuracy: 0.7486
✓ New best F1: 0.7113
✓ New best AUC-ROC: 0.6963

Epoch 2/30
Training loss: 0.5599
Validation loss: 0.5088
Validation metrics:
  Accuracy     : 0.7629
  F1-Score     : 0.7282
  Sensitivity  : 0.9309
  Specificity  : 0.2424
  AUC-ROC      : 0.7228
  Confusion Matrix: [[215, 672], [190, 2558]]
✓ New best accuracy: 0.7629
✓ New best F1: 0.7282
✓ New best AUC-ROC: 0.7228

Epoch 3/30
Training loss: 0.5391
Validation loss: 0.5003
Validation metrics:
  Accuracy     : 0.7554
  F1-Score     : 0.7336
  Sensitivity  : 0.8999
  Specificity  : 0.3078
  AUC-ROC      : 0.7387
  Confusion Matrix: [[273, 614], [275, 2473]]
✓ New best F1: 0.7336
✓ New best AUC-ROC: 0.7387
Current loss weights - CLF: 1.000, GEN: 0.001

Epoch 4/30
Training loss: 0.52

Training loss: 0.4408
Validation loss: 0.4697
Validation metrics:
  Accuracy     : 0.7618
  F1-Score     : 0.7671
  Sensitivity  : 0.8184
  Specificity  : 0.5862
  AUC-ROC      : 0.8041
  Confusion Matrix: [[520, 367], [499, 2249]]

Epoch 30/30
Training loss: 0.4410
Validation loss: 0.4848
Validation metrics:
  Accuracy     : 0.7538
  F1-Score     : 0.7614
  Sensitivity  : 0.8020
  Specificity  : 0.6043
  AUC-ROC      : 0.7976
  Confusion Matrix: [[536, 351], [544, 2204]]
Current loss weights - CLF: 1.000, GEN: 0.001


In [11]:
# test evaluation, gen=0.001
print(f"\nTraining completed!")
print(f"Best results from validation:")
print(f"  Best Accuracy: {best_accuracy:.4f}")
print(f"  Best F1: {best_f1:.4f}")
print(f"  Best AUC-ROC: {best_auc:.4f}")

print("\n" + "="*50)
print("FINAL TEST SET EVALUATION")
print("="*50)

# Quick classification metrics on full test set
print("\n1. Full Classification Performance:")
test_loss, test_metrics = evaluate(model, test_loader, device)
print(f"Test loss: {test_loss:.4f}")
print(f"Classification metrics (all {len(test_dataset)} samples):")
print(f"  Accuracy     : {test_metrics['accuracy']:.4f}")
print(f"  F1-Score     : {test_metrics['f1_score']:.4f}")
print(f"  Sensitivity  : {test_metrics['sensitivity']:.4f}")
print(f"  Specificity  : {test_metrics['specificity']:.4f}")
print(f"  AUC-ROC      : {test_metrics['auc_roc']:.4f}" if test_metrics['auc_roc'] is not None else "  AUC-ROC      : N/A")

# Sampled text generation metrics
print("\n2. Sampled Text Generation Performance:")
# Use 500 samples for text metrics (adjust as needed)
test_loss_sampled, test_metrics_sampled = evaluate_with_sampling(
    model, test_loader, device, tokenizer, text_sample_size=500
)

print(f"Text generation metrics ({test_metrics_sampled['text_samples_used']} samples):")
if test_metrics_sampled['avg_bertscore_f1'] is not None:
    print(f"  BERTScore F1 : {test_metrics_sampled['avg_bertscore_f1']:.4f}")
else:
    print(f"  BERTScore F1 : N/A")

if test_metrics_sampled['avg_meteor'] is not None:
    print(f"  METEOR       : {test_metrics_sampled['avg_meteor']:.4f}")
else:
    print(f"  METEOR       : N/A")

print(f"\nSample efficiency: {test_metrics_sampled['text_samples_used']}/{test_metrics_sampled['total_samples']} samples used for text metrics")


Training completed!
Best results from validation:
  Best Accuracy: 0.7846
  Best F1: 0.7750
  Best AUC-ROC: 0.8084

FINAL TEST SET EVALUATION

1. Full Classification Performance:


The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.


Test loss: 0.4716
Classification metrics (all 3636 samples):
  Accuracy     : 0.7701
  F1-Score     : 0.7773
  Sensitivity  : 0.8192
  Specificity  : 0.6057
  AUC-ROC      : 0.8068

2. Sampled Text Generation Performance:
Sampling 500 out of 3636 samples for text generation metrics


The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not va

The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not va

The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not va

The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not va

The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not va

The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not va

The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not va

The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not va

The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.
The following generation flags are not va

Computing text metrics on 500 samples...
Error calculating BERT Score: '/home/s2080063/MWR-to-Text/models/roberta-large'
Text generation metrics (500 samples):
  BERTScore F1 : N/A
  METEOR       : 0.9961

Sample efficiency: 500/3636 samples used for text metrics
