# Introduction

The aim of this project is to build a model that classifies clinical trial records (protocol text, inclusion/exclusion criteria) to predict whether a trial was completed, terminated, or withdrawn.

[This paper]() states that there are four main factors that determine whether a trial is successful:
- quality of clinical trials
- speed of clinical trials - "the shorter the clinical trial period, the lower the clinical trial cost, thereby reducing the financial burden on the company."
- relationship type - "diversity of collaboration" with many partners (especially high-quality partners with proven track records) is a good indicator of success
- communication - having an infrastructure for exchanging information between collaborators

The paper goes on to state that: "the most important features for predicting success of drug approval ... are trial outcomes, trial status, trial accrual rates, duration, prior approval for another indication, and the sponsor’s track record."

# Getting and Cleaning the Data

The main data source for detailed info about clinical trials is [ClinicalTrials.gov]()'s API. [Info about the contents of that is found here](https://clinicaltrials.gov/data-api/about-api/study-data-structure).

However, for initial investigation, it will be sufficient to use data from Kaggle datasets:
- [Clinical Trials, the Devastator](https://www.kaggle.com/datasets/thedevastator/a-quick-overview-of-clinical-trials) - contains many potentially informative features for each trial, including (crucially) titles and text summaries, enrollment numbers, start dates, etc. However, the only target variable is `Status`, whose values (Completed, Terminated, Active, Recruiting, etc.) we will use as a proxy for the success of the trial.

In [1]:
# Clear CUDA cache between runs to avoid OutOfMemory errors
import torch, gc
gc.collect()
torch.cuda.empty_cache()

In [2]:
import numpy as np
import pandas as pd
import os
from pathlib import Path

# ! pip install -q datasets

## Getting the Data

In [3]:
trial_info = pd.read_csv("../input/a-quick-overview-of-clinical-trials/AERO-BirdsEye-Data.csv")

## Exploring the Data

In [4]:
trial_info

Unnamed: 0,index,NCT,Sponsor,Title,Summary,Start_Year,Start_Month,Phase,Enrollment,Status,Condition
0,0,NCT00003305,Sanofi,A Phase II Trial of Aminopterin in Adults and ...,RATIONALE: Drugs used in chemotherapy use diff...,1997,7,Phase 2,75,Completed,Leukemia
1,1,NCT00003821,Sanofi,Phase II Trial of Aminopterin in Patients With...,RATIONALE: Drugs used in chemotherapy use diff...,1998,1,Phase 2,0,Withdrawn,Endometrial Neoplasms
2,2,NCT00004025,Sanofi,"Phase I/II Trial of the Safety, Immunogenicity...",RATIONALE: Vaccines made from a person's white...,1999,3,Phase 1/Phase 2,36,Unknown status,Melanoma
3,3,NCT00005645,Sanofi,Phase II Trial of ILX295501 Administered Orall...,RATIONALE: Drugs used in chemotherapy use diff...,1999,5,Phase 2,0,Withdrawn,Ovarian Neoplasms
4,4,NCT00008281,Sanofi,"A Multicenter, Open-Label, Randomized, Three-A...",RATIONALE: Drugs used in chemotherapy use diff...,2000,10,Phase 3,0,Unknown status,Colorectal Neoplasms
...,...,...,...,...,...,...,...,...,...,...,...
13743,13743,NCT03726879,Roche,"A Phase III, Randomized, Double-Blind, Placebo...",This study (also known as IMpassion050) will e...,2018,12,Phase 3,224,Recruiting,Breast Neoplasms
13744,13744,NCT03735121,Roche,A Two-Part Phase Ib/II Study to Investigate th...,"This study will evaluate the pharmacokinetics,...",2018,12,Phase 1,245,Recruiting,Lung Neoplasms
13745,13745,NCT03761849,Roche,"A Randomized, Multicenter, Double-Blind, Place...","This study will evaluate the efficacy, safety,...",2018,12,Phase 3,660,Not yet recruiting,Huntington Disease
13746,13746,NCT03762681,Roche,"A Randomized, Placebo-controlled,Observer-blin...","This study is designed to assess the safety, t...",2018,12,Phase 1,75,Not yet recruiting,Hepatitis


In [5]:
trial_info.describe(include='all')



Unnamed: 0,index,NCT,Sponsor,Title,Summary,Start_Year,Start_Month,Phase,Enrollment,Status,Condition
count,13748.0,13748,13748,13604,13748,13748.0,13748.0,13485,13748.0,13748,13748
unique,,13748,10,13434,13565,,,7,,9,867
top,,NCT03779334,GSK,Human Photoallergy Test,#NAME?,,,Phase 3,,Completed,"Diabetes Mellitus, Type 2"
freq,,1,2473,7,11,,,4887,,10568,536
mean,6873.5,,,,,2009.155586,6.691155,,440.783678,,
std,3968.850085,,,,,4.797615,3.486359,,1944.530768,,
min,0.0,,,,,1984.0,1.0,,0.0,,
25%,3436.75,,,,,2006.0,4.0,,40.0,,
50%,6873.5,,,,,2009.0,7.0,,124.0,,
75%,10310.25,,,,,2013.0,10.0,,365.0,,


In [6]:
def get_sorted_counts(df, group_by):
    return df[group_by].value_counts(ascending=False).reset_index()
    
print(f"Distribution of Sponsors:\n{get_sorted_counts(trial_info, 'Sponsor')}")
print(f"\nDistribution of Phases:\n{get_sorted_counts(trial_info, 'Phase')}")
print(f"\nDistribution of Statuses:\n{get_sorted_counts(trial_info, 'Status')}")
print(f"\nDistribution of Conditions Targeted By Trial:\n{get_sorted_counts(trial_info, 'Condition')}")

Distribution of Sponsors:
    Sponsor  count
0       GSK   2473
1  Novartis   2320
2    Pfizer   1970
3     Merck   1770
4    Sanofi   1524
5       JNJ   1143
6     Roche   1095
7     Bayer    619
8    AbbVie    417
9    Gilead    417

Distribution of Phases:
             Phase  count
0          Phase 3   4887
1          Phase 2   3596
2          Phase 1   2516
3          Phase 4   2015
4  Phase 1/Phase 2    322
5  Phase 2/Phase 3    139
6    Early Phase 1     10

Distribution of Statuses:
                    Status  count
0                Completed  10568
1               Terminated   1285
2               Recruiting    800
3   Active, not recruiting    646
4                Withdrawn    291
5       Not yet recruiting    108
6           Unknown status     19
7                Suspended     16
8  Enrolling by invitation     15

Distribution of Conditions Targeted By Trial:
                                  Condition  count
0                 Diabetes Mellitus, Type 2    536
1               

So the target variable here is `Status`. Its classes are very unbalanced: the majority of examples (77%) have a `Completed` status, which will represent the 'Successful' class of our target. The 'Unsuccessful' class will be represented by the `Terminated`, `Withdrawn` and `Suspended` examples. All of the rest will be discarded, under the assumption that they are ongoing (e.g. still recruiting) or of unknown status, so their success or failure is unknown at this time.

## Creating the Input and Target Features

In [7]:
statuses = ['Completed', 'Terminated', 'Withdrawn', 'Suspended']
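# .copy() gives an independent DataFrame, avoiding SettingWithCopyWarning when columns are added later
trial_info = trial_info[trial_info['Status'].isin(statuses)].copy()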
trial_info.head()

Unnamed: 0,index,NCT,Sponsor,Title,Summary,Start_Year,Start_Month,Phase,Enrollment,Status,Condition
0,0,NCT00003305,Sanofi,A Phase II Trial of Aminopterin in Adults and ...,RATIONALE: Drugs used in chemotherapy use diff...,1997,7,Phase 2,75,Completed,Leukemia
1,1,NCT00003821,Sanofi,Phase II Trial of Aminopterin in Patients With...,RATIONALE: Drugs used in chemotherapy use diff...,1998,1,Phase 2,0,Withdrawn,Endometrial Neoplasms
3,3,NCT00005645,Sanofi,Phase II Trial of ILX295501 Administered Orall...,RATIONALE: Drugs used in chemotherapy use diff...,1999,5,Phase 2,0,Withdrawn,Ovarian Neoplasms
6,6,NCT00012389,Sanofi,"A Multicenter, Open-Label, Randomized, Two-Arm...",RATIONALE: Drugs used in chemotherapy use diff...,2000,12,Phase 3,0,Completed,Colorectal Neoplasms
7,7,NCT00017459,Sanofi,The International Tirazone Triple Trial (i3T):...,RATIONALE: Drugs used in chemotherapy use diff...,2000,7,Phase 3,0,Completed,"Carcinoma, Non-Small-Cell Lung"


Okay, so now we have filtered the dataset to only contain records that have the statuses that we can interpret in terms of success. Next, let's create a new column, `Success`, that is a binary value representing whether the trial was successful (1) i.e. it completed, or unsuccessful (0) i.e. it did not complete. 

In [8]:
def status_to_success(status):
    return 1 if status == 'Completed' else 0 

trial_info['Success'] = trial_info['Status'].apply(status_to_success)
trial_info['Success'] = trial_info['Success'].astype(float)
trial_info.head()



Unnamed: 0,index,NCT,Sponsor,Title,Summary,Start_Year,Start_Month,Phase,Enrollment,Status,Condition,Success
0,0,NCT00003305,Sanofi,A Phase II Trial of Aminopterin in Adults and ...,RATIONALE: Drugs used in chemotherapy use diff...,1997,7,Phase 2,75,Completed,Leukemia,1.0
1,1,NCT00003821,Sanofi,Phase II Trial of Aminopterin in Patients With...,RATIONALE: Drugs used in chemotherapy use diff...,1998,1,Phase 2,0,Withdrawn,Endometrial Neoplasms,0.0
3,3,NCT00005645,Sanofi,Phase II Trial of ILX295501 Administered Orall...,RATIONALE: Drugs used in chemotherapy use diff...,1999,5,Phase 2,0,Withdrawn,Ovarian Neoplasms,0.0
6,6,NCT00012389,Sanofi,"A Multicenter, Open-Label, Randomized, Two-Arm...",RATIONALE: Drugs used in chemotherapy use diff...,2000,12,Phase 3,0,Completed,Colorectal Neoplasms,1.0
7,7,NCT00017459,Sanofi,The International Tirazone Triple Trial (i3T):...,RATIONALE: Drugs used in chemotherapy use diff...,2000,7,Phase 3,0,Completed,"Carcinoma, Non-Small-Cell Lung",1.0


In [9]:
print(f"Total number of examples remaining:\n{len(trial_info)}")
print(f"\nDistribution of Success/Failure:\n{get_sorted_counts(trial_info, 'Success')}")

Total number of examples remaining:
12160

Distribution of Success/Failure:
   Success  count
0      1.0  10568
1      0.0   1592


Note that there is still a significant class imbalance between successful (1; 87%) and unsuccessful (0; 13%) trials. We will need to address this during training and when splitting the data into training, validation and test sets.
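
One common way to address such an imbalance during training is to weight the loss by inverse class frequency. The sketch below computes such weights from the counts printed above. The "balanced" formula `n_samples / (n_classes * count)` and the helper name `balanced_class_weights` are illustrative assumptions, not something this notebook applies later.

```python
def balanced_class_weights(counts):
    """Return inverse-frequency weights: n_samples / (n_classes * class_count)."""
    n_samples = sum(counts.values())
    n_classes = len(counts)
    return {label: n_samples / (n_classes * count) for label, count in counts.items()}

# Counts taken from the Success/Failure distribution above
counts = {1: 10568, 0: 1592}
weights = balanced_class_weights(counts)

for label in sorted(weights):
    share = counts[label] / sum(counts.values())
    print(f"class {label}: share={share:.1%}, weight={weights[label]:.3f}")
```

The rarer class receives the larger weight, so errors on unsuccessful trials would cost more in a weighted loss.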

In [10]:
trial_info['input'] = 'TITLE: ' + trial_info.Title + '; SUMMARY: ' + trial_info.Summary + '; PHASE: ' + trial_info.Phase



In [11]:
for text in trial_info.input.sample(10):
    print(text, end="\n\n")

TITLE: Safety and Immunogenicity of Monovalent H5N1 Vaccine GSK1557484A in Children 6 Months to < 18 Years of Age; SUMMARY: This study will assess safety and immunogenicity of GSK Biologicals' H5N1 flu candidate vaccine GSK1557484A in children 6 months to < 18 years of age.; PHASE: Phase 3

TITLE: A Single Centre, Open Label Study to Characterize the PK-PDE10A Enzyme Occupancy Relationship of RO5545965 After a Single Dose in Healthy Male Volunteers Using [11C]IMA107 Positron Emission Tomography; SUMMARY: This single center, open-label study will evaluate the PK-PDE10A Enzyme Occupancy Relationship of RO5545965 after a single dose in healthy male volunteers by positron emission tomography (PET).; PHASE: Phase 1

TITLE: An Open-label Repeat Dosing Study of Eltrombopag Olamine (SB-497115-GR) in Adult Subjects, With Chronic Idiopathic Thrombocytopenic Purpura (ITP); SUMMARY: This open-label, repeat dosing study, TRA108057, will evaluate the efficacy, safety and tolerability of eltrombopag,

In [12]:
# Remove any values from 'input' feature that are nan
data = trial_info.dropna(subset=["input"])
len(data)

11802

In [13]:
data.describe()

Unnamed: 0,index,Start_Year,Start_Month,Enrollment,Success
count,11802.0,11802.0,11802.0,11802.0,11802.0
mean,6669.111845,2008.257753,6.648365,429.082867,0.869259
std,3918.216037,4.269221,3.488122,1872.177639,0.337131
min,0.0,1984.0,1.0,0.0,0.0
25%,3199.25,2005.0,4.0,37.0,1.0
50%,6729.0,2008.0,7.0,120.0,1.0
75%,9953.75,2011.0,10.0,360.0,1.0
max,13740.0,2019.0,12.0,84496.0,1.0


## Tokenisation and Numericalisation

Hugging Face's `datasets` library provides a `Dataset` object for storing a... well, a dataset, of course! We can create one like so:

In [14]:
from datasets import Dataset,DatasetDict

In [15]:
ds = Dataset.from_pandas(data)
ds

Dataset({
    features: ['index', 'NCT', 'Sponsor', 'Title', 'Summary', 'Start_Year', 'Start_Month', 'Phase', 'Enrollment', 'Status', 'Condition', 'Success', 'input', '__index_level_0__'],
    num_rows: 11802
})

But we can't pass the texts directly into a model. A deep learning model expects numbers as inputs, not English sentences! So we need to do two things:

- *Tokenization*: Split each text up into words (or actually, as we'll see, into tokens)
- *Numericalization*: Convert each word (or token) into a number.
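
To make the two steps concrete, here is a toy sketch with a hypothetical five-entry vocabulary and a naive whitespace tokenizer. Real tokenizers use learned sub-word vocabularies, so the names `toy_vocab`, `tokenize` and `numericalize`, and all IDs here, are purely illustrative.

```python
# Hypothetical vocabulary: maps each known token string to an integer ID
toy_vocab = {"[UNK]": 0, "phase": 1, "2": 2, "trial": 3, "of": 4}

def tokenize(text):
    """Naive tokenization: lowercase, then split on whitespace."""
    return text.lower().split()

def numericalize(tokens, vocab):
    """Look up each token's ID, falling back to the unknown-token ID."""
    return [vocab.get(tok, vocab["[UNK]"]) for tok in tokens]

tokens = tokenize("Phase 2 trial of aminopterin")
ids = numericalize(tokens, toy_vocab)
print(tokens)
print(ids)
```

Words missing from the vocabulary collapse to a single unknown ID in this sketch, which is one reason real tokenizers split rare words into sub-word pieces instead.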

The details about how this is done actually depend on the particular model we use. So first we'll need to pick a model. There are thousands of models available. A reasonable starting point for nearly any NLP problem is `microsoft/deberta-v3-small` (replace "small" with "large" for a slower but more accurate model, once you've finished exploring); here we use the lighter `distilbert-base-uncased` instead:

In [16]:
# model_nm = 'microsoft/deberta-v3-small'
model_nm = 'distilbert-base-uncased'

`AutoTokenizer` will create a tokenizer appropriate for a given model; here we load the matching `DistilBertTokenizer` directly:

In [17]:
from transformers import AutoModelForSequenceClassification,AutoTokenizer,DistilBertTokenizer, DistilBertForSequenceClassification,DataCollatorWithPadding
# tokz = AutoTokenizer.from_pretrained(model_nm)
tokz = DistilBertTokenizer.from_pretrained(model_nm)


Here's an example of how the tokenizer splits a text into "tokens" (which are like words, but can be sub-word pieces, as you see below):

In [18]:
tokz.tokenize("Hello World! This is your pilot speaking.")

['hello', 'world', '!', 'this', 'is', 'your', 'pilot', 'speaking', '.']

Here's a simple function which tokenizes our inputs:

In [19]:
def tok_func(x): 
    return tokz(
        x["input"],
        truncation=True,
        padding="max_length",
        max_length=128,
        return_tensors='pt'
    )

To run this quickly in parallel on every row in our dataset, use `map`:

In [20]:
tok_ds = ds.map(tok_func, batched=True)


This adds two new items to our dataset, `input_ids` and `attention_mask`. For instance, here are the input and IDs for the first row of our data:

In [21]:
row = tok_ds[0]
row['input'], row['input_ids']

('TITLE: A Phase II Trial of Aminopterin in Adults and Children With Refractory Acute Leukemia Grant Application Title: A Phase II Trial of Aminopterin in Acute Leukemia; SUMMARY: RATIONALE: Drugs used in chemotherapy use different ways to stop cancer cells from dividing so they stop growing or die. PURPOSE: Phase II trial to study the effectiveness of aminopterin in treating patients who have refractory leukemia.; PHASE: Phase 2',
 [101,
  2516,
  1024,
  1037,
  4403,
  2462,
  3979,
  1997,
  13096,
  13876,
  23282,
  1999,
  6001,
  1998,
  2336,
  2007,
  25416,
  22648,
  7062,
  11325,
  25468,
  3946,
  4646,
  2516,
  1024,
  1037,
  4403,
  2462,
  3979,
  1997,
  13096,
  13876,
  23282,
  1999,
  11325,
  25468,
  1025,
  12654,
  1024,
  11581,
  2063,
  1024,
  5850,
  2109,
  1999,
  27144,
  2224,
  2367,
  3971,
  2000,
  2644,
  4456,
  4442,
  2013,
  16023,
  2061,
  2027,
  2644,
  3652,
  2030,
  3280,
  1012,
  3800,
  1024,
  4403,
  2462,
  3979,
  2000,
  281

So, what are those IDs and where do they come from? The secret is that there's a mapping called `vocab` in the tokenizer which contains a unique integer for every possible token string. We can look them up like this, for instance to find the ID for the token "phase":

In [22]:
tokz.vocab['phase']

4403

Looking above at our input IDs, we do indeed see that 4403 appears as expected.

Finally, we need to prepare our labels. Transformers expects the labels in a column named `label` (which the data collator renames to `labels` when building batches), but in our dataset it's currently `Success`. Therefore, we need to rename it:

In [23]:
tok_ds = tok_ds.rename_columns({'Success':'label'})

Now that we've prepared our tokens and labels, we need to create our validation set.

## Preparing the Test and Validation Sets

How do we recognise whether our models are under-fit, over-fit, or "just right"? We use a validation set. This is a set of data that we "hold out" from training -- we don't let our model see it at all. The validation set is only ever used to see how we're doing. It's never used as inputs to training the model. 

Metrics (measurements of the accuracy of a model) are calculated using the validation set, to compare different models during model evaluation (by trying different models, training methods, data processing) and select the best one. But be careful! Over time, model selection can over-fit to the validation set! This means that, when comparing many different models, there is some chance that a model will perform particularly well on the validation set by chance. This is the reason that, at the end of model selection and training, the model must be evaluated again on the test set that is held out and unseen during this entire process.

`train_test_split` always calls the held-out part `test`, whether it serves as a validation set or a test set. To avoid confusion, we'll store the three splits in a `DatasetDict` under the names `train`, `validation` and `test`.

The `datasets` library uses a `DatasetDict` for holding your training and validation sets. To create one, you can use `train_test_split`; applying it twice with `test_size=0.2`, as we do below, leaves roughly 64% of the data for training, 16% for validation and 20% for testing. BUT! A random sampling may not be the best method for splitting the dataset in some real-world cases; [this article (How (and why) to create a good validation set)](https://www.fast.ai/2017/11/13/validation-sets/) explains why. For time series, in particular, it is best to use different periods within the data for the validation and test sets, in order to measure the model's ability to predict future events, not just interpolate missing points from within the time series.

For now, we'll use a random split, stratified by class.
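
For comparison, a time-based split could look like the sketch below: trials are grouped by start year and the most recent ones are held out. The cut-off years and the helper `temporal_split` are hypothetical; only the `Start_Year` column name comes from the dataset, and the toy records are made up for illustration.

```python
def temporal_split(records, val_from_year, test_from_year):
    """Split records by Start_Year: oldest for training, newest for testing."""
    train = [r for r in records if r["Start_Year"] < val_from_year]
    val = [r for r in records if val_from_year <= r["Start_Year"] < test_from_year]
    test = [r for r in records if r["Start_Year"] >= test_from_year]
    return train, val, test

# Made-up records standing in for rows of the trial table
toy_years = [1999, 2004, 2008, 2010, 2012, 2014, 2016]
toy_records = [{"NCT": f"toy-{i}", "Start_Year": year} for i, year in enumerate(toy_years)]

train, val, test = temporal_split(toy_records, val_from_year=2011, test_from_year=2014)
print(len(train), len(val), len(test))
```

With such a split, every validation and test trial starts after all training trials, which mimics predicting the outcome of future trials.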

### Stratified Splitting

Recall from earlier that we noticed a class imbalance in the target `label` column, with far more 1s (successful trials) than 0s (unsuccessful trials). We need to account for this when splitting the dataset, making sure that each split (train, validation, test) has roughly the same proportion of each class (0 and 1) as the full dataset. Otherwise, the model might overfit or fail to generalize. To do this, we will use *stratified sampling*, which preserves the class distribution across splits.
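
The idea can be sketched in plain Python: split each class separately, then merge the parts, so that every split inherits the overall class ratio. This is an illustrative stand-in (the helper `stratified_split` is hypothetical), not how the `datasets` library implements `stratify_by_column` internally.

```python
import random

def stratified_split(labels, test_size, seed=42):
    """Return (train_idx, test_idx) with the class ratio preserved in both."""
    rng = random.Random(seed)
    by_class = {}
    for idx, label in enumerate(labels):
        by_class.setdefault(label, []).append(idx)

    train_idx, test_idx = [], []
    for indices in by_class.values():
        rng.shuffle(indices)                      # shuffle within each class
        n_test = round(len(indices) * test_size)  # same fraction taken from every class
        test_idx.extend(indices[:n_test])
        train_idx.extend(indices[n_test:])
    return sorted(train_idx), sorted(test_idx)

# Toy labels with roughly the same 87/13 imbalance as our data
labels = [1] * 870 + [0] * 130
train_idx, test_idx = stratified_split(labels, test_size=0.2)

def positive_share(indices):
    return sum(labels[i] for i in indices) / len(indices)

print(f"train: {positive_share(train_idx):.3f}, test: {positive_share(test_idx):.3f}")
```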

In [24]:
tok_ds = tok_ds.class_encode_column("label")


In [25]:
from datasets import load_dataset

# Split entire data set into training and testing data sets (80%/20%)
# Use stratification to ensure roughly equal proportion of each target class in each split
train_test = tok_ds.train_test_split(
    test_size=0.2,
    stratify_by_column="label",
    seed=42
)
# Split the remaining 80% into train/val (64%/16% of the full dataset)
val_train = train_test["train"].train_test_split(
    test_size=0.2,
    stratify_by_column="label",
    seed=42
)

# Combine into a DatasetDict
dds = DatasetDict({
    "train": val_train["train"],
    "validation": val_train["test"],
    "test": train_test["test"]
})

In [26]:
dds

DatasetDict({
    train: Dataset({
        features: ['index', 'NCT', 'Sponsor', 'Title', 'Summary', 'Start_Year', 'Start_Month', 'Phase', 'Enrollment', 'Status', 'Condition', 'label', 'input', '__index_level_0__', 'input_ids', 'attention_mask'],
        num_rows: 7552
    })
    validation: Dataset({
        features: ['index', 'NCT', 'Sponsor', 'Title', 'Summary', 'Start_Year', 'Start_Month', 'Phase', 'Enrollment', 'Status', 'Condition', 'label', 'input', '__index_level_0__', 'input_ids', 'attention_mask'],
        num_rows: 1889
    })
    test: Dataset({
        features: ['index', 'NCT', 'Sponsor', 'Title', 'Summary', 'Start_Year', 'Start_Month', 'Phase', 'Enrollment', 'Status', 'Condition', 'label', 'input', '__index_level_0__', 'input_ids', 'attention_mask'],
        num_rows: 2361
    })
})

# Model Training



## Defining the Metrics

Metrics are used to evaluate the model's performance on the validation (and test) set, and compare models to each other to select the best one. It is important to select the right metrics for the problem at hand: in this case, an imbalanced binary classification problem.

We will use:
- F1-macro score: the F1 score balances precision (how many predicted positives were correct) and recall (how many actual positives were found); the macro average computes it per class and takes the unweighted mean, so the minority class counts as much as the majority class.
- ROC-AUC (Area Under the ROC Curve): measures ranking quality (how well the model separates 0s and 1s). It is insensitive to class imbalance BUT can be over-optimistic when the imbalance is extreme (in that case, Precision-Recall AUC (PR-AUC) is better).
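
To see why macro-averaging matters here, consider a degenerate classifier that always predicts the majority class. The pure-Python sketch below uses toy labels with 87% positives; the helper `macro_prf` is illustrative (in practice, `precision_recall_fscore_support` with `average='macro'` does this job).

```python
def macro_prf(y_true, y_pred, classes=(0, 1)):
    """Unweighted mean of per-class precision, recall and F1 (0 when undefined)."""
    precisions, recalls, f1s = [], [], []
    for c in classes:
        tp = sum(1 for t, p in zip(y_true, y_pred) if t == c and p == c)
        fp = sum(1 for t, p in zip(y_true, y_pred) if t != c and p == c)
        fn = sum(1 for t, p in zip(y_true, y_pred) if t == c and p != c)
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
        precisions.append(precision)
        recalls.append(recall)
        f1s.append(f1)
    n = len(classes)
    return sum(precisions) / n, sum(recalls) / n, sum(f1s) / n

# Toy data: 87 successful and 13 unsuccessful trials, all predicted "successful"
y_true = [1] * 87 + [0] * 13
y_pred = [1] * 100

accuracy = sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)
precision, recall, f1 = macro_prf(y_true, y_pred)
print(f"accuracy={accuracy:.2f}, macro precision={precision:.3f}, "
      f"macro recall={recall:.3f}, macro F1={f1:.3f}")
```

Accuracy looks high for this classifier, while macro-F1 stays below 0.5 because the minority class is never predicted.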

In [27]:
from sklearn.metrics import f1_score, precision_recall_fscore_support, roc_auc_score, average_precision_score
import numpy as np
# from datasets import load_metric

In [28]:
def corr(x,y): return np.corrcoef(x,y)[0][1]
def corr_d(eval_pred): return {'pearson': corr(*eval_pred)}

In [29]:
# def compute_metrics(eval_pred):
#     logits, labels = eval_pred
#     preds = np.argmax(logits, axis=-1)
#     probs = np.exp(logits) / np.exp(logits).sum(-1, keepdims=True)  # softmax for probabilities

#     precision, recall, f1_macro, _ = precision_recall_fscore_support(
#         labels, preds, average='macro'
#     )

#     roc_auc = roc_auc_score(labels, probs[:, 1])
#     pr_auc = average_precision_score(labels, probs[:, 1])

#     return {
#         'precision_macro': precision,
#         'recall_macro': recall,
#         'f1_macro': f1_macro,
#         'roc_auc': roc_auc,
#         'pr_auc': pr_auc
#     }

def compute_metrics(eval_pred):
    logits, labels = eval_pred

    # Convert logits to predicted class IDs
    preds = np.argmax(logits, axis=-1)

    # Precision, recall, F1 (macro)
    precision, recall, f1_macro, _ = precision_recall_fscore_support(
        labels, preds, average='macro'
    )

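    # ROC-AUC requires probabilities for the positive class
    # Subtract the row-wise max before exponentiating for a numerically stable softmax
    shifted = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted) / np.exp(shifted).sum(-1, keepdims=True)
    roc_auc = roc_auc_score(labels, probs[:, 1])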

    return {
        'precision_macro': precision,
        'recall_macro': recall,
        'f1_macro': f1_macro,
        'roc_auc': roc_auc
    }

In [30]:
data_collator = DataCollatorWithPadding(tokz)

In [31]:
# Remove all cols except inputs and labels. Reformat lists as Tensors
def reformat_dds(df):
    d = df.remove_columns([c for c in df.column_names if c not in ["input_ids", "attention_mask", "label"]])
    d.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
    return d

dds["train"] = reformat_dds(dds["train"])
dds["validation"] = reformat_dds(dds["validation"])
dds["test"] = reformat_dds(dds["test"])

In [32]:
# Check everything is formatted correctly
def check_format(df):
    if set(df.column_names) != {'label', 'input_ids', 'attention_mask'}:
        return False
    correct_format = {'input_ids': torch.Tensor, 'attention_mask': torch.Tensor, 'label': torch.Tensor}
    for i in range(3):
        row = df[i]
        for k in correct_format:
            if not isinstance(row[k], torch.Tensor):
                return False
    return True


print(f"Training set passed checks: {check_format(dds['train'])}")
print(f"Validation set passed checks: {check_format(dds['validation'])}")
print(f"Test set passed checks: {check_format(dds['test'])}")

Training set passed checks: True
Validation set passed checks: True
Test set passed checks: True


## Training the Model

In [33]:
from transformers import TrainingArguments,Trainer,EarlyStoppingCallback

In terms of the hyperparameters, we need to pick a batch size that fits our GPU, and a small number of epochs so we can run experiments quickly:

In [34]:
bs = 16
epochs = 6

The most important hyperparameter is the learning rate. fastai provides a learning rate finder to help you figure this out, but Transformers doesn't, so we'll just have to use trial and error. The idea is to find the largest value you can that doesn't cause oscillation or a worsening of the model's accuracy over time, so that model training isn't slower than it needs to be.

In [35]:
lr = 1e-5

Transformers uses the `TrainingArguments` class to set up arguments. Don't worry too much about the values we're using here -- they should generally work fine in most cases. It's just the 3 parameters above that you may need to change for different models.

In [36]:
args = TrainingArguments(
    # 'outputs', 
    output_dir = "/kaggle/temp/results", 
    learning_rate=lr, 
    warmup_ratio=0.1, 
    lr_scheduler_type='cosine', 
    fp16=True,
    load_best_model_at_end=True,
    metric_for_best_model="f1_macro",  # choose best checkpoint by F1-macro
    greater_is_better=True,
    eval_strategy="epoch",
    save_strategy="epoch",
    per_device_train_batch_size=bs, 
    per_device_eval_batch_size=bs*2,
    eval_accumulation_steps=1,   # avoids keeping all eval tensors in memory
    gradient_accumulation_steps=2,
    num_train_epochs=epochs, 
    weight_decay=0.01, 
    report_to='none'
)
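
To make `warmup_ratio=0.1` and `lr_scheduler_type='cosine'` more tangible, the sketch below computes a learning rate that rises linearly during warmup and then follows half a cosine wave down to zero. It is a simplified stand-alone approximation of that schedule shape, not the library's code. The step count is derived from the training-set size (7552 rows), the batch size, gradient accumulation and the number of epochs, assuming a single GPU.

```python
import math

def lr_at_step(step, total_steps, warmup_steps, peak_lr):
    """Linear warmup to peak_lr, then cosine decay to zero."""
    if step < warmup_steps:
        return peak_lr * step / max(1, warmup_steps)
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))

effective_batch = 16 * 2                          # per-device batch size x gradient accumulation steps
steps_per_epoch = math.ceil(7552 / effective_batch)
total_steps = steps_per_epoch * 6                 # 6 epochs
warmup_steps = math.ceil(0.1 * total_steps)       # warmup_ratio = 0.1

for step in (0, warmup_steps, total_steps // 2, total_steps):
    lr_value = lr_at_step(step, total_steps, warmup_steps, peak_lr=1e-5)
    print(f"step {step:>4}: lr = {lr_value:.2e}")
```

The learning rate peaks at `lr` right after warmup and shrinks smoothly afterwards, which tends to make the early, unstable phase of fine-tuning gentler.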

We can now create our model and `Trainer`, a class that combines the data and model together (just like `Learner` in fastai):

In [41]:
# model = AutoModelForSequenceClassification.from_pretrained(model_nm, num_labels=2)
# model = DistilBertForSequenceClassification.from_pretrained(model_nm, num_labels=2, output_attentions=True)
model = DistilBertForSequenceClassification.from_pretrained(model_nm, num_labels=2)
model.to("cuda")

trainer = Trainer(
    model, 
    args, 
    train_dataset=dds['train'], 
    eval_dataset=dds['validation'],
    tokenizer=tokz, 
    # compute_metrics=corr_d
    data_collator=data_collator,
    compute_metrics=compute_metrics
    #callbacks=[EarlyStoppingCallback(early_stopping_patience=2)]
)

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Now that we have the `model` and the `Trainer`, it's time to train our model!

In [38]:
import datasets
dds = dds.cast_column("label", datasets.Value("int64"))
print(dds["train"].features)


{'label': Value('int64'), 'input_ids': List(Value('int32')), 'attention_mask': List(Value('int8'))}


In [42]:
trainer.train()

Epoch,Training Loss,Validation Loss,Precision Macro,Recall Macro,F1 Macro,Roc Auc
1,No log,0.381116,0.434621,0.5,0.465024,0.654696
2,No log,0.367058,0.434621,0.5,0.465024,0.686513
3,0.404400,0.379132,0.434621,0.5,0.465024,0.696777
4,0.404400,0.384929,0.659956,0.542583,0.548389,0.692326
5,0.314700,0.395258,0.661054,0.571712,0.588536,0.689655
6,0.314700,0.399974,0.663403,0.584861,0.603683,0.686326




TrainOutput(global_step=1416, training_loss=0.33200575402901, metrics={'train_runtime': 307.9972, 'train_samples_per_second': 147.118, 'train_steps_per_second': 4.597, 'total_flos': 1500590691975168.0, 'train_loss': 0.33200575402901, 'epoch': 6.0})

# Model Evaluation

Let's now use the trained model to predict the labels for the test set, and then compare these to the true labels to evaluate the model's performance:

In [43]:
results = trainer.evaluate(dds["test"])
print(results)

{'eval_loss': 0.3946785628795624, 'eval_precision_macro': 0.603800306064918, 'eval_recall_macro': 0.5609753212589186, 'eval_f1_macro': 0.5713391625750284, 'eval_roc_auc': 0.7109174410315613, 'eval_runtime': 4.4758, 'eval_samples_per_second': 527.509, 'eval_steps_per_second': 16.534, 'epoch': 6.0}


## Interpreting the Model Performance

| Metric | Value | Interpretation |
|---|---|---|
| Eval Loss | 0.3947 | Cross-entropy on the test set; lower is better. Only meaningful relative to a baseline, such as always predicting the class prior. |
| Precision Macro | 0.604 | Unweighted mean of per-class precision: on average, ~60% of the predictions made for each class are correct. |
| Recall Macro | 0.561 | Unweighted mean of per-class recall: on average, the model finds ~56% of the true instances of each class. This is only slightly above the 0.5 obtained by always predicting a single class. |
| F1 Macro | 0.571 | Balanced measure of precision and recall. Better than the ~0.465 seen in the first validation epochs, but still modest. |
| ROC-AUC | 0.711 | A randomly chosen successful trial is ranked above a randomly chosen unsuccessful one ~71% of the time (0.5 is chance level). |
| Eval speed | ~527 samples/sec | Fast enough that inference speed is unlikely to be a bottleneck. |
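
A useful reference point for the loss is a constant baseline that ignores the input and always predicts the class prior. The short calculation below uses the class counts from the filtered dataset; the baseline itself is an addition for context, not something this notebook trains.

```python
import math

# Share of Completed trials after filtering (10568 of 12160)
p_success = 10568 / 12160

# Expected cross-entropy when every trial is assigned probability p_success of success
baseline_loss = -(
    p_success * math.log(p_success)
    + (1 - p_success) * math.log(1 - p_success)
)
print(f"Cross-entropy of always predicting the prior: {baseline_loss:.3f}")
```

This comes to roughly 0.39, close to the reported eval loss of 0.3947, so the loss alone suggests limited improvement over the prior; macro-F1 and ROC-AUC are more informative here.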

# Model Interpretability

## Gradient-based attribution (saliency maps)

Techniques like Integrated Gradients, Saliency Maps, or Gradient SHAP from the Captum library can tell you which tokens contributed most to the predicted class.

Below, we use Integrated Gradients. The attribution of each word in the input text to the model's prediction is shown by the colour and depth of the word's highlighting. Deeper colours mean a stronger attribution. Red means that the word made a positive contribution (i.e. pushed the prediction closer to the target class 1; successful) while blue means that the word made a negative contribution (pushed the prediction away from the target class). 
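
Before applying this to the transformer, the mechanics can be shown on a tiny logistic model. The sketch approximates the Integrated Gradients path integral with a Riemann sum from an all-zeros baseline. The weights, inputs and function names are made up for illustration and are unrelated to Captum's implementation.

```python
import math

WEIGHTS = [2.0, -1.0, 0.5]  # hypothetical model weights

def model(x):
    """Toy logistic model: sigmoid of a weighted sum."""
    z = sum(w * xi for w, xi in zip(WEIGHTS, x))
    return 1.0 / (1.0 + math.exp(-z))

def gradient(x):
    """Analytic gradient of the sigmoid output with respect to each input."""
    out = model(x)
    return [out * (1.0 - out) * w for w in WEIGHTS]

def integrated_gradients(x, baseline, steps=200):
    """Average the gradient along the straight path baseline -> x, scaled by (x - baseline)."""
    totals = [0.0] * len(x)
    for k in range(steps):
        alpha = (k + 0.5) / steps  # midpoint rule
        point = [b + alpha * (xi - b) for xi, b in zip(x, baseline)]
        for i, g in enumerate(gradient(point)):
            totals[i] += g
    return [(xi - b) * t / steps for xi, b, t in zip(x, baseline, totals)]

x = [1.0, 1.0, 2.0]
baseline = [0.0, 0.0, 0.0]
attributions = integrated_gradients(x, baseline)

print("attributions:", [round(a, 4) for a in attributions])
print("sum of attributions:", round(sum(attributions), 4))
print("model(x) - model(baseline):", round(model(x) - model(baseline), 4))
```

Inputs with a positive attribution push the prediction up (shown in red in the visualisation), those with a negative attribution push it down (blue), and the attributions sum approximately to the difference between the prediction and the baseline prediction.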

In [44]:
# ! pip install captum
from captum.attr import IntegratedGradients
from matplotlib import cm
from IPython.display import HTML, display

example_inputs = [
    "TITLE: Effects of Two Doses of a Common Cold Treatment on Cognitive Function; SUMMARY: This study will investigate any improvement in alertness and performance based on cognitive function and mood assessment in subjects suffering the common cold, when taking a novel paracetamol and caffeine combination verses paracetamol alone.; PHASE: Phase 3",
    "TITLE: A Single Dose Study to Assess the Pharmacokinetics of SCH 900800 Administered as Oral Tablets in L-DOPA-treated Subjects With Parkinson's Disease; SUMMARY: This study is being done to assess the pharmacokinetics of SCH 900800 in participants with moderate to severe Parkinson's Disease (PD) being treated with L-DOPA.; PHASE: Phase 1",
    "TITLE: A Registration Study of the Safety, Tolerability, and Immunogenicity of V441 in Healthy Infants in Taiwan; SUMMARY: The purpose of this study is to evaluate the safety, tolerability, and immune response of an investigational vaccine being evaluated to reduce the incidence of diphtheria, pertussis, tetanus, hepatitis B, poliomyelitis, and Haemophilus influenza type b.; PHASE: Phase 3",
    "TITLE: A Phase IIIB, Multi-Center, Open Label Study For Postmenopausal Women With Estrogen Receptor Positive Locally Advanced or Metastatic Breast Cancer Treated With Everolimus (RAD001) in Combination With Exemestane: 4EVER - Efficacy, Safety, Health Economics, Translational Research; SUMMARY: The present multi-center, open-label, single-arm study aims to evaluate the efficacy and safety, quality of life and health resources utilization in postmenopausal women with hormone receptor positive breast cancer progressing following prior therapy with non-steroidal aromatase inhibitors (NSAI) treated with the combination of Everolimus and Exemestane.; PHASE: Phase 3",
    "TITLE: A Multicenter, Randomized, Double-Blind, 'Crossover' Design Study to Evaluate the Lipid-Altering Efficacy and Safety of MK-0524B Combination Tablet Compared to MK-0524A + Simvastatin Coadministration in Patients With Primary Hypercholesterolemia and Mixed Dyslipidemia; SUMMARY: This is a 20-week clinical trial in participants with primary hypercholesterolemia or mixed dyslipidemia to demonstrate the effect of MK-0524B compared to MK-0524A + Simvastatin on lipid values.; PHASE: Phase 3",
    "TITLE: A Multicenter, Randomized, Double-blind, Phase III Trial to Evaluate the Safety, Immunogenicity, and Efficacy of MSB11022 Compared With Humira® in Patients With Moderately to Severely Active Rheumatoid Arthritis; SUMMARY: The purpose of this study is to compare the efficacy, safety and immunogenicity of MSB11022 and Humira® in adult subjects with rheumatoid arthritis.; PHASE: Phase 3",
    "TITLE: A Multicenter, Open-label Extension, Multiple Dose, Parallel Group Study To Investigate The Long-term Safety And Tolerability Of Aab-003 (Pf-05236812) Administered Intravenously In Subjects With Mild To Moderate Alzheimer's Disease Previously Treated With Aab-003 Or Placebo In Protocol B2601001; SUMMARY: This is a study to evaluate the safety and tolerability of multiple doses of AAB-003 (PF-05236812) in patients with mild to moderate Alzheimer's Disease. Patients who complete study B2601001 may participate in this trial and receive AAB-003 (PF-05236812). Each patient's participation will last approximately 52 weeks.; PHASE: Phase 1",
    "TITLE: A Phase III Randomized, Double-blind, Controlled Study Comparing Clofarabine and Cytarabine Versus Cytarabine Alone in Adult Patients 55 Years and Older With Acute Myelogenous Leukemia (AML) Who Have Relapsed or Are Refractory After Receiving up to Two Prior Induction Regimens; SUMMARY: Clofarabine (injection) is approved by the Food and Drug Administration (FDA) for the treatment of pediatric patients 1 to 21 years old with relapsed acute or refractory lymphoblastic leukemia (ALL) who have had at least 2 prior treatment regimens. There is no recommended standard treatment for relapsed or refractory acute myelogenous leukemia in older patients. Cytarabine is the most commonly used drug to treat these patients. This study will determine if there is benefit by combining clofarabine with cytarabine. Patients will be randomized to receive up to 3 cycles of treatment with either placebo in combination with cytarabine or clofarabine in combination with cytarabine. Randomization was stratified by remission status following the first induction regimen (no remission [i.e., CR1 = refractory] or remission <6 months vs CR1 = remission ≥6 months). CR1 is defined as remission after first pre-study induction regimen. The safety and tolerability of clofarabine in combination with cytarabine and cytarabine alone will be monitored throughout the study.; PHASE: Phase 3",
    "TITLE: A Phase 2 Multi-Centre, Randomised, Double-Blind, Placebo Controlled, Dose Ranging Study Of TBC3711 In Subjects With Resistant Hypertension; SUMMARY: The study was to determine the safe and effective dose of TBC3711 in patients with uncontrolled high blood pressure while already taking blood pressure medications.; PHASE: Phase 2",
    "TITLE: A Multicenter, Open Label, Phase 1B Study of Escalating Doses of RO5045337 Administered Orally, With Cytarabine Administered A) Subcutaneously, or B) Intravenously, in Patients With Acute Myelogenous Leukemia (AML); SUMMARY: This multi-center, open-label, Phase 1b study will evaluate the safety, pharmacokinetics and efficacy of RO5045337 in combination with cytarabine in patients with acute myelogenous leukemia. In Arm A, cohorts of previously untreated patients deemed unsuitable for standard induction therapy will receive escalating oral doses of RO5045377 and cytarabine 20 mg/m2 subcutaneously daily for Days 1 to 10 of each 28-day cycle. In Arm B, cohorts of patients who have relapsed or are refractory after at least one cytarabine/anthracycline containing regimen will receive escalating oral doses of RO5045377 on Days 1 to 5 and cytarabine 1 gm/m2 intravenously on Days 1 to 6 of each 28-day cycle. Patients will receive up to 4 cycles of therapy, patients in Arm A who achieve hematologic response may continue additional cycles until disease progression.; PHASE: Phase 1"
]

def gradient_based_attribution(text):
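    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()  # disable dropout so repeated runs give the same attributions
    # Tokenize input, truncating to the model's maximum sequence length
    inputs = tokz(text, return_tensors="pt", truncation=True)
    input_ids = inputs["input_ids"].to(device)
    attention_mask = inputs["attention_mask"].to(device)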
    
    # Get embeddings for the tokens
    embeddings = model.distilbert.embeddings(input_ids)  # FloatTensor [1, seq_len, hidden_size]
    embeddings.requires_grad_()  # Important for gradients

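    # Define forward function taking embeddings as input.
    # Calling the full model (instead of applying model.classifier directly to
    # the [CLS] hidden state) keeps every layer of the classification head in
    # the path, so the attributed logits are the ones the model predicts with.
    # Assumes `model` is a Hugging Face sequence-classification model.
    def forward_embeds(embeds, attention_mask):
        outputs = model(inputs_embeds=embeds, attention_mask=attention_mask)
        return outputs.logits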
    
    # Initialize IntegratedGradients
    ig = IntegratedGradients(forward_embeds)
    
    # Compute attributions
    attr, delta = ig.attribute(
        embeddings,  # Float tensor, requires_grad=True
        additional_forward_args=(attention_mask,),
        target=1,  # class index
        return_convergence_delta=True
    )

    # Summarize token importance
    token_importance = attr.sum(dim=-1).squeeze(0).detach().cpu().numpy()
    tokens = tokz.convert_ids_to_tokens(input_ids.squeeze(0).cpu())
    
    # Normalize attributions for color mapping
    max_val = np.max(np.abs(token_importance)) + 1e-10  # avoid div by zero
    norm_attr = token_importance / max_val
    
    # # Sort tokens by absolute attribution
    # sorted_indices = np.argsort(np.abs(token_importance))[::-1]  # descending by strength
    # print(f"{'Token':15} {'Attribution':>12}")
    # print("-" * 28)
    # for idx in sorted_indices:
    #     print(f"{tokens[idx]:15} {token_importance[idx]:12.4f}")

    # Generate HTML color-coded text
    def colorize_token(token, score):
        """
        Red = positive contribution, Blue = negative contribution
        """
        cmap = cm.bwr  # blue-white-red colormap
        color = cmap(0.5 + 0.5 * score)  # map [-1,1] to [0,1]
        color_hex = "#{:02x}{:02x}{:02x}".format(
            int(color[0]*255), int(color[1]*255), int(color[2]*255)
        )
        return f"<span style='background-color:{color_hex}'>{token} </span>"
    
    colored_tokens = [colorize_token(t, s) for t, s in zip(tokens, norm_attr)]
    html_str = " ".join(colored_tokens)
    
    display(HTML(html_str))

for text in example_inputs:
    gradient_based_attribution(text)
    print("\n")

## Feature importance via attention

Transformers have attention layers, which determine how much each token “attends” to other tokens. We can extract attention weights from the trained model:
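
Before running this on the real model, the aggregation itself can be sketched on a tiny, made-up attention tensor. All numbers and the helper names `mean_matrices` and `attention_received` below are hypothetical. The nested list is indexed as `attentions[layer][head][query][key]`, and each query row sums to 1, as softmax attention weights do. Averaging over heads and layers and then summing each column gives the total attention a token receives.

```python
def mean_matrices(matrices):
    """Element-wise mean of equally sized 2-D lists."""
    n = len(matrices)
    rows, cols = len(matrices[0]), len(matrices[0][0])
    return [[sum(m[r][c] for m in matrices) / n for c in range(cols)]
            for r in range(rows)]


def attention_received(attentions):
    """Average over heads and layers, then sum each key column."""
    per_layer = [mean_matrices(heads) for heads in attentions]
    avg = mean_matrices(per_layer)
    return [sum(row[k] for row in avg) for k in range(len(avg[0]))]


tokens = ["[CLS]", "trial", "[SEP]"]
attentions = [  # 2 layers x 2 heads x 3 queries x 3 keys (made-up values)
    [[[0.6, 0.2, 0.2], [0.4, 0.4, 0.2], [0.5, 0.1, 0.4]],
     [[0.8, 0.1, 0.1], [0.6, 0.2, 0.2], [0.7, 0.2, 0.1]]],
    [[[0.4, 0.3, 0.3], [0.2, 0.6, 0.2], [0.3, 0.3, 0.4]],
     [[0.6, 0.2, 0.2], [0.4, 0.4, 0.2], [0.5, 0.2, 0.3]]],
]

scores = attention_received(attentions)
for token, score in zip(tokens, scores):
    print(f"{token:8} {score:.3f}")
```

Since every query row sums to 1, the scores always add up to the sequence length, so they are best read as relative shares of attention rather than as absolute importance values.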

In [45]:
import matplotlib.pyplot as plt
import seaborn as sns

# -----------------------------
# 1️⃣ Setup
# -----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model.to(device)
model.eval()

# -----------------------------
# 2️⃣ Tokenize input
# -----------------------------
inputs = tokz(example_inputs[0], return_tensors="pt")
inputs = {k: v.to(device) for k, v in inputs.items()}

# -----------------------------
# 3️⃣ Forward pass to get attentions
# -----------------------------
with torch.no_grad():
    # Attention weights are only returned when explicitly requested;
    # without output_attentions=True, outputs.attentions is None
    outputs = model(**inputs, output_attentions=True)

attentions = outputs.attentions  # one tensor per layer, each [batch, heads, seq_len, seq_len]
num_layers = len(attentions)
num_heads = attentions[0].shape[1]
seq_len = attentions[0].shape[2]

# -----------------------------
# 4️⃣ Aggregate attention across heads and layers
# -----------------------------
# Average across layers, then across heads
attn_matrix = torch.stack(attentions).mean(0)       # average layers -> [batch, heads, seq_len, seq_len]
attn_matrix = attn_matrix[0].mean(0)               # average heads -> [seq_len, seq_len]

# -----------------------------
# 5️⃣ Compute per-token attention scores
# -----------------------------
# Sum the attention each token receives from all tokens (column sums)
token_importance = attn_matrix.sum(dim=0).cpu().numpy()
tokens = tokz.convert_ids_to_tokens(inputs["input_ids"][0].cpu())

# -----------------------------
# 6️⃣ Select top N tokens for clarity
# -----------------------------
top_n = 8
top_indices = np.argsort(token_importance)[-top_n:][::-1]  # descending
top_tokens = [tokens[i] for i in top_indices]
top_scores = token_importance[top_indices]

# -----------------------------
# 7️⃣ Plot clean bar chart
# -----------------------------
plt.figure(figsize=(8,4))
plt.bar(top_tokens, top_scores, color='skyblue')
plt.title("Top Tokens by Aggregated Attention")
plt.ylabel("Attention Score")
plt.xticks(rotation=45)
plt.show()