<a href="https://colab.research.google.com/github/davidsolow/med-abbrev-mystery/blob/kiara/MS_BERT_fine_tuning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import pandas as pd
import numpy as np
import tensorflow as tf

In [2]:
from tensorflow import keras

In [3]:
!pip install transformers==4.37.2



In [5]:
# Mount Google Drive into the Colab filesystem at /content/drive so the
# project CSVs stored on Drive can be read by later cells.
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [6]:
# Load the train and validation splits from the mounted Drive folder.
# NOTE(review): the paths are relative ("drive/MyDrive/..."), so this
# presumably relies on the working directory being /content -- confirm.
train = pd.read_csv("drive/MyDrive/266Project/train-3.csv")
# The test split is currently not loaded.
#test = pd.read_csv("drive/MyDrive/266Project/test.csv")
validation = pd.read_csv("drive/MyDrive/266Project/validation.csv")

In [7]:
# Down-sample both splits so the notebook runs in reasonable time.
# A fixed random_state makes the subset (and therefore the label map built
# from it later) identical on every Restart & Run All; without it each run
# trains on different rows and a different set of classes.
train = train.sample(frac=0.005, random_state=42)
validation = validation.sample(frac=0.01, random_state=42)
print(len(train))
print(len(validation))

15000
10000


In [4]:
from transformers import BertTokenizer, TFBertModel

# Never commit credentials to a notebook: read the Hugging Face token from the
# environment (e.g. exported from a Colab secret) instead of hard-coding it.
# NOTE(review): the token that was previously embedded here has been exposed
# in version control and should be revoked.
import os

# None means anonymous access, which is sufficient for public checkpoints.
access_token = os.environ.get("HF_TOKEN")

tokenizer = BertTokenizer.from_pretrained("NLP4H/ms_bert", token=access_token)
# NOTE(review): `model` is not referenced again in the visible notebook;
# create_bert_multiclass_model loads its own TFBertModel instance.
model = TFBertModel.from_pretrained("NLP4H/ms_bert", token=access_token)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.
Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFBertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing TFBertModel from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFBertModel from a PyTorch

In [8]:
#cleaning location and label columns
def clean_location(location):
  """Convert a bracketed number such as "[12]" (or a bare number) to an int.

  The value is stringified first, surrounding "[" / "]" characters are
  stripped, and the remainder is parsed with int().
  """
  unbracketed = str(location).strip("[]")
  return int(unbracketed)

def clean_label(label):
  """Return the label string with surrounding brackets and single quotes removed.

  For example "['VET']" becomes "VET". Only leading/trailing characters are
  stripped; characters inside the label are left untouched.
  """
  wrapper_chars = "[]'"
  return label.strip(wrapper_chars)

# Normalise the raw 'location' and 'label' columns of both splits in place.
column_cleaners = (('location', clean_location), ('label', clean_label))
for dataset in (train, validation):
  for column_name, cleaner in column_cleaners:
    dataset[column_name] = dataset[column_name].apply(cleaner)

In [9]:
#converting labels to integers
def make_label_map(labels):
  """Map each distinct label to a consecutive integer id.

  Args:
    labels: pandas Series of label values.

  Returns:
    dict mapping label -> int, numbered from 0 in order of first appearance
    (the order returned by Series.unique()).
  """
  # Compute the unique values once; the previous version re-ran unique()
  # twice per loop iteration, which is quadratic in the number of classes.
  return {label: index for index, label in enumerate(labels.unique())}

# Build the label vocabulary from the training split only.
label_map = make_label_map(train['label'])
valid_labels = label_map.keys()
# Drop validation rows whose label never occurs in training (they cannot be
# encoded). .copy() detaches the filtered frame from the original so the
# in-place column write below does not raise SettingWithCopyWarning.
validation = validation[validation['label'].isin(valid_labels)].copy()
for dataset in [train, validation]:
   dataset.loc[:, 'label'] = dataset['label'].map(label_map)

# Number of output classes for the classification head.
num_classes = len(label_map)

In [10]:
# Number of distinct labels found in the (sampled) training split.
print(num_classes)

9381


In [11]:
# Filtering by text length and abbreviation location.
# max_length is the sequence budget; max_location is the largest abbreviation
# index that is kept by clean_dataset.
# NOTE(review): 'location' is used elsewhere as an index into text.split()
# (whitespace words), while the budget presumably counts tokenizer tokens --
# confirm that a word-index cutoff is what is intended here.
max_length = 200
max_location = max_length - 3 # minus [CLS] and [SEP] tokens added and index offset

def add_abbreviation_col(dataset):
    """Add an 'abbreviation' column holding the word found at each row's location.

    The text of every row is split on whitespace and the word at index
    row['location'] is stored. The frame is modified in place and also
    returned for convenience.
    """
    def word_at_location(row):
        words = row['text'].split()
        return words[row['location']]

    dataset['abbreviation'] = dataset.apply(word_at_location, axis=1)
    return dataset

def clean_dataset(dataset):
    """Return a copy of the rows whose 'location' does not exceed max_location.

    The input frame is left untouched; callers must use the returned frame.
    """
    within_limit = dataset['location'] <= max_location
    return dataset.loc[within_limit].copy()

# Derive the abbreviation column, then keep only rows whose abbreviation
# location is within max_location.
# clean_dataset returns a NEW filtered frame, so its result must be assigned
# back; the previous loop discarded it and the filter was never applied.
train = clean_dataset(add_abbreviation_col(train))
validation = clean_dataset(add_abbreviation_col(validation))

In [12]:
# Preview the cleaned training rows, including the derived 'abbreviation' column.
train.head()

Unnamed: 0.1,Unnamed: 0,abstract_id,text,location,label,abbreviation
223160,223160,1004864,cardiovascular responsiveness to sympathoadren...,169,0,VET
2846266,2846266,8270621,various strategies have been studied to reduce...,10,1,ROC
2908905,2908905,3319055,a dayold boy presented to our emergency depart...,110,2,IPA
1248734,1248734,8300813,tularemia is caused by two subspecies of franc...,57,3,SSH
192559,192559,2482370,we have examined the responsiveness of the ver...,50,4,Ra


In [13]:
def get_abbrev_token_positions(text, abbrev, location=None):
    """
    Takes text and abbreviation and finds the start and end index of the
    tokenized representation of that abbreviation in the text.

    Args:
        text: the full input text.
        abbrev: the abbreviation as it appears in the text.
        location: optional whitespace-word index of the abbreviation in
            `text` (i.e. text.split()[location] is the abbreviation). When
            given, the occurrence at that word is preferred, so a text that
            contains the same token sequence earlier does not produce the
            wrong span. When omitted, or when the hint cannot be confirmed,
            the first matching occurrence is returned.

    Returns:
        (start, end): inclusive token indices into tokenizer.tokenize(text),
        without any offset for special tokens.

    Raises:
        ValueError: if the abbreviation's tokens do not occur in the text.
    """
    tokenized_text = tokenizer.tokenize(text)
    tokenized_abbrev = tokenizer.tokenize(abbrev)
    token_ids_text = tokenizer.convert_tokens_to_ids(tokenized_text)
    token_ids_abbrev = tokenizer.convert_tokens_to_ids(tokenized_abbrev)
    span_size = len(token_ids_abbrev)

    start = -1

    # Preferred path: the words before `location` are tokenized on their own
    # and their token count is taken as the candidate start. The candidate is
    # only accepted if the abbreviation's tokens really are at that offset, so
    # a wrong hint can never yield a wrong span.
    if location is not None:
        words = text.split()
        if 0 <= location <= len(words):
            prefix_text = " ".join(words[:int(location)])
            candidate = len(tokenizer.tokenize(prefix_text))
            if token_ids_text[candidate:candidate + span_size] == token_ids_abbrev:
                start = candidate

    # Fallback (and original behaviour): first occurrence of the token run.
    if start == -1:
        for i in range(len(token_ids_text) - span_size + 1):
            if token_ids_text[i:i + span_size] == token_ids_abbrev:
                start = i
                break

    if start == -1:
        raise ValueError(f"Abbreviation '{abbrev}' not found in text '{text}'")

    end = start + span_size - 1
    return start, end


def extract_abbrev_positions_from_dataset(dataset):
    """Extracts all the start and end position of each abbreviation in a dataset.

    Args:
        dataset: DataFrame with 'text' and 'abbreviation' columns. If a
            'location' column is present it is passed along as a hint so the
            labelled occurrence of the abbreviation is the one located.

    Returns:
        (start_positions, end_positions): two lists of ints, one entry per
        row, already shifted by one for the leading CLS token.
    """
    has_location = 'location' in dataset.columns
    start_positions = []
    end_positions = []
    for _, row in dataset.iterrows():
        location = row['location'] if has_location else None
        start, end = get_abbrev_token_positions(row['text'], row['abbreviation'], location)
        start_positions.append(start + 1)  # add 1 to account for CLS token at start
        end_positions.append(end + 1)

    return start_positions, end_positions


# Token-level start/end indices (already offset for the CLS token) of each
# row's abbreviation, one entry per row of the corresponding split.
train_start_positions, train_end_positions = extract_abbrev_positions_from_dataset(train)
valid_start_positions, valid_end_positions = extract_abbrev_positions_from_dataset(validation)
# The test split is currently not processed.
#test_start_positions, test_end_positions = extract_abbrev_positions_from_dataset(test_subset)

In [14]:
MAX_SEQUENCE_LENGTH = 200

# Tokenizer settings shared by both splits: pad/truncate every text to
# MAX_SEQUENCE_LENGTH tokens and return TensorFlow tensors.
tokenizer_options = dict(
    max_length=MAX_SEQUENCE_LENGTH,
    truncation=True,
    padding='max_length',
    return_tensors='tf',
)

# --- training split ---
train_list = train.text.tolist()
train_tokenized = tokenizer(train_list, **tokenizer_options)
train_inputs = [
    train_tokenized.input_ids,
    train_tokenized.token_type_ids,
    train_tokenized.attention_mask,
]
train_labels = np.array(train.label)

# --- validation split ---
validation_list = validation.text.tolist()
validation_tokenized = tokenizer(validation_list, **tokenizer_options)
validation_inputs = [
    validation_tokenized.input_ids,
    validation_tokenized.token_type_ids,
    validation_tokenized.attention_mask,
]
validation_labels = np.array(validation.label)

# Word-level abbreviation locations as int32 tensors.
train_locations = tf.convert_to_tensor(np.array(train.location), dtype=tf.int32)
validation_locations = tf.convert_to_tensor(np.array(validation.location), dtype=tf.int32)

In [15]:
# Seed intended for reproducibility.
# NOTE(review): this constant is not referenced anywhere else in the visible
# notebook (create_bert_multiclass_model hard-codes 42) -- wire it in or drop it.
RANDOM_SEED = 42

In [21]:
class ExtractAbbreviationHiddenStates(tf.keras.layers.Layer):
    """
    Custom layer that extracts abbreviation embeddings from BERT
    hidden layer state and position inputs.

    Inputs (a list of three tensors):
        last_hidden_state: rank-3 tensor (batch, sequence, hidden).
        start_abbrev_token_positions: integer tensor (batch, 1), first token
            index of the abbreviation span.
        end_abbrev_token_positions: integer tensor (batch, 1), last token
            index of the span (inclusive).

    Output:
        Tensor with the same shape and dtype as last_hidden_state. For every
        example the hidden states of its span are moved to the front of the
        sequence axis and every remaining position is zero.

    Positions are clipped to [0, sequence_length - 1]. If a clipped end
    precedes the clipped start, the row is all zeros.
    """
    def call(self, inputs):
        last_hidden_state, start_abbrev_token_positions, end_abbrev_token_positions = inputs
        max_length = tf.shape(last_hidden_state)[1]

        # (batch,) clipped span boundaries.
        start_pos = tf.cast(start_abbrev_token_positions[:, 0], tf.int32)
        end_pos = tf.cast(end_abbrev_token_positions[:, 0], tf.int32)
        start_pos = tf.clip_by_value(start_pos, 0, max_length - 1)
        end_pos = tf.clip_by_value(end_pos, 0, max_length - 1)
        span_length = end_pos - start_pos + 1

        # Vectorised over the batch instead of a per-example loop:
        # output position j of example b reads input position start_b + j.
        # Indices are capped at the last position only to keep the gather in
        # range; everything past the span is zeroed below.
        offsets = tf.range(max_length)
        gather_positions = tf.minimum(
            start_pos[:, tf.newaxis] + offsets[tf.newaxis, :], max_length - 1)
        shifted_hidden_states = tf.gather(last_hidden_state, gather_positions, batch_dims=1)

        # True for the first span_length positions of each example.
        in_span = offsets[tf.newaxis, :] < span_length[:, tf.newaxis]
        return tf.where(in_span[:, :, tf.newaxis],
                        shifted_hidden_states,
                        tf.zeros_like(shifted_hidden_states))

def create_bert_multiclass_model(checkpoint="NLP4H/ms_bert", num_classes=10, learning_rate=0.00005,
                                 max_length=200, seed=42):
    """
    Build a BERT classifier that predicts a class from the hidden states of
    the abbreviation's token span (not from BERT's pooler output).

    Args:
        checkpoint: Hugging Face checkpoint name passed to TFBertModel.from_pretrained.
        num_classes: size of the softmax output layer.
        learning_rate: Adam learning rate.
        max_length: padded sequence length expected for the three BERT inputs.
        seed: value passed to tf.random.set_seed.

    Returns:
        A compiled tf.keras.Model taking
        [input_ids, token_type_ids, attention_mask, start_positions, end_positions]
        and returning class probabilities of shape (batch, num_classes).
    """
    tf.keras.backend.clear_session()
    tf.random.set_seed(seed)

    input_ids = tf.keras.layers.Input(shape=(max_length,), dtype=tf.int32, name='input_ids_layer')
    token_type_ids = tf.keras.layers.Input(shape=(max_length,), dtype=tf.int32, name='token_type_ids_layer')
    attention_mask = tf.keras.layers.Input(shape=(max_length,), dtype=tf.int32, name='attention_mask_layer')
    start_abbrev_token_positions = tf.keras.layers.Input(shape=(1,), dtype=tf.int32, name='start_abbreviation_token_positions_layer')
    end_abbrev_token_positions = tf.keras.layers.Input(shape=(1,), dtype=tf.int32, name='end_abbreviation_token_positions_layer')

    bert_model = TFBertModel.from_pretrained(checkpoint)

    bert_out = bert_model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
    last_hidden_state = bert_out.last_hidden_state

    # Span hidden states, front-aligned and zero-padded to the sequence length.
    span_hidden_states = ExtractAbbreviationHiddenStates()([last_hidden_state, start_abbrev_token_positions, end_abbrev_token_positions])

    # Average over the real span tokens only. A plain reduce_mean over the
    # padded tensor would divide by the full sequence length and shrink the
    # embedding by span_length / max_length.
    clipped_start = tf.clip_by_value(start_abbrev_token_positions, 0, max_length - 1)
    clipped_end = tf.clip_by_value(end_abbrev_token_positions, 0, max_length - 1)
    span_length = tf.maximum(clipped_end - clipped_start + 1, 1)  # guard against division by zero
    span_length = tf.cast(span_length, span_hidden_states.dtype)  # (batch, 1), broadcasts over hidden
    pooled_output = tf.reduce_sum(span_hidden_states, axis=1) / span_length

    classification = tf.keras.layers.Dense(num_classes, activation='softmax', name='classification_layer')(pooled_output)

    classification_model = tf.keras.Model(
        inputs=[input_ids, token_type_ids, attention_mask, start_abbrev_token_positions, end_abbrev_token_positions],
        outputs=[classification],
    )

    # The output layer already applies softmax, hence from_logits=False.
    classification_model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
                                 loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
                                 metrics=['accuracy'])

    return classification_model


In [22]:
# Size the classification head from the label map. The function's default of
# 10 classes is far smaller than the number of labels in the data, which
# would put most integer labels outside the softmax range during training.
abbreviation_model = create_bert_multiclass_model(num_classes=num_classes)
abbreviation_model.summary()

Some weights of the PyTorch model were not used when initializing the TF 2.0 model TFBertModel: ['cls.predictions.transform.dense.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight']
- This IS expected if you are initializing TFBertModel from a PyTorch model trained on another task or with another architecture (e.g. initializing a TFBertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing TFBertModel from a PyTorch model that you expect to be exactly identical (e.g. initializing a TFBertForSequenceClassification model from a BertForSequenceClassification model).
All the weights of TFBertModel were initialized from the PyTorch model.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertModel for predictions without further training.


Model: "model"
__________________________________________________________________________________________________
 Layer (type)                Output Shape                 Param #   Connected to                  
 input_ids_layer (InputLaye  [(None, 200)]                0         []                            
 r)                                                                                               
                                                                                                  
 attention_mask_layer (Inpu  [(None, 200)]                0         []                            
 tLayer)                                                                                          
                                                                                                  
 token_type_ids_layer (Inpu  [(None, 200)]                0         []                            
 tLayer)                                                                                      

In [23]:
def _as_int32_model_inputs(tokenized, start_positions, end_positions):
    """Return the five model inputs for one split as int32 NumPy arrays.

    Order: input_ids, token_type_ids, attention_mask, start positions,
    end positions.
    """
    components = (
        tokenized.input_ids,
        tokenized.token_type_ids,
        tokenized.attention_mask,
        start_positions,
        end_positions,
    )
    return [np.array(component, dtype=np.int32) for component in components]


train_inputs = _as_int32_model_inputs(train_tokenized, train_start_positions, train_end_positions)
valid_inputs = _as_int32_model_inputs(validation_tokenized, valid_start_positions, valid_end_positions)

# Integer class ids for the sparse categorical cross-entropy loss.
train_labels = np.array(train_labels, dtype=np.int32)
validation_labels = np.array(validation_labels, dtype=np.int32)

# Single pass over the sampled training data, evaluated on the validation split.
history = abbreviation_model.fit(
    x=train_inputs,
    y=train_labels,
    validation_data=(valid_inputs, validation_labels),
    epochs=1,
    batch_size=16,
    shuffle=True,
    verbose=1,
)




