# SPECTER Fine-Tuning

In this second extension to the SPECTER paper, we study how to fine-tune the pre-trained SPECTER embedder. In what follows, we use the term *SPECTER* to refer to this embedder.

The [SPECTER paper](https://arxiv.org/abs/2004.07180) suggests that the embeddings $e_i \in \mathbb R^{d}\ \forall i \in \mathcal X$ can be classified using standard machine learning algorithms such as Support Vector Machines (SVMs). In particular, the authors' findings suggest that a linear-kernel SVM (with a tuned value of $C$) yields the following classification performance:

|         Task        | Macro F1 Score |
|:-------------------:|:--------------:|
| MeSH Classification |      87.7      |
|  MAG Classification |      79.4      |

The basic intuition behind fine-tuning a pre-trained model is that the loss incurred on a downstream task can also be used to adjust the model's weights, and hence its embeddings. In other words, the embeddings can be changed (only slightly, hence the term *fine-tuning*) to better serve a given task, such as the classification task this extension focuses on.

In general, classification consists of learning a discriminative function $\hat f:\mathcal X \mapsto \mathcal Y$ that maps data points in a given feature space $\mathcal X$ to their corresponding label (**one** of a finite number of labels in a label set $\mathcal Y$). The original intuition of SPECTER's authors is that Text Classification can be decoupled into two main subcomponents:

1. **Natural Language Embedding**, *i.e.* the task of obtaining contextualized numerical representations of textual data

2. **Embedding Classification**, *i.e.* the task of actually classifying such numerical representations using any given classification algorithm (that is, learning $\hat f$).


<p align="center">
    <img width=1000 src="https://i.ibb.co/Dp6K4wt/SPECTERclassification.png" alt="ext2-scheme" border="0">
</p>

In the process displayed above, each paper $P_i$ is first embedded through $\texttt{SPECTER}$ into the corresponding embedding $e_i$.
Then, traditional Machine Learning techniques (here represented by the scikit-learn logo) are used to learn the discriminative function $\hat f$, (hopefully) minimizing the classification error $\Vert l_i - \hat f(e_i) \Vert_{p}\ \forall i$ for some $p$-norm, where $l_i$ is the label of paper $P_i$. Here, the loss function is used only to learn $\hat f$: the embeddings themselves stay fixed. If one uses an SVM to classify the embedded papers, it is possible to reproduce the results of SPECTER, *i.e.* to correctly classify the majority of these *classification-static* paper embeddings.
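To make the decoupled setting concrete, here is a toy sketch in pure Python: the embeddings are a fixed input, and only the classifier weights change. The solver is a Pegasos-style stochastic subgradient method for a binary linear SVM without bias term, chosen only because it fits in a few lines; it is not necessarily what SPECTER's authors used, and the names `pegasos_train` and `predict` are hypothetical. Its regularization strength `lam` corresponds to $1/(nC)$ in the usual SVM formulation with $n$ training points.

```python
import random

def pegasos_train(X, y, lam=0.1, epochs=50, seed=0):
    """Toy binary linear SVM (no bias term) trained with Pegasos-style subgradient descent.

    X: list of *fixed* embedding vectors; y: labels in {-1, +1}.
    Only the weight vector w is learned: X is read but never modified.
    """
    rng = random.Random(seed)
    w = [0.0] * len(X[0])
    t = 0
    for _ in range(epochs):
        for i in rng.sample(range(len(X)), len(X)):  # one shuffled pass over the data
            t += 1
            eta = 1.0 / (lam * t)  # Pegasos step size
            margin = y[i] * sum(wj * xj for wj, xj in zip(w, X[i]))
            w = [(1 - eta * lam) * wj for wj in w]  # step on the L2 regularizer
            if margin < 1:  # hinge loss active: step towards the violating point
                w = [wj + eta * y[i] * xj for wj, xj in zip(w, X[i])]
    return w

def predict(w, x):
    return 1 if sum(wj * xj for wj, xj in zip(w, x)) >= 0 else -1

# toy 2-d "embeddings" standing in for 768-d SPECTER vectors
X = [[2.0, 1.0], [1.0, 2.0], [3.0, 2.0], [-2.0, -1.0], [-1.0, -2.0], [-3.0, -2.0]]
y = [1, 1, 1, -1, -1, -1]
X_before = [row[:] for row in X]

w = pegasos_train(X, y)
print([predict(w, x) for x in X], X == X_before)
```

Fine-tuning, discussed next, removes exactly the constraint checked by `X == X_before`: the embeddings themselves become trainable.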

<p align="center">
    <img width=1000 src="https://i.ibb.co/G7cJtLj/improvement.png" alt="ext2-scheme" border="0">
</p>

Although this approach performs well in a variety of settings, the information carried by the loss can in principle be used differently. In particular, one can backpropagate the loss to also change the way the papers themselves are embedded. This idea relies on the simple yet potentially powerful intuition that, when used for Text Classification, the embeddings produced by SPECTER might suffer from over-generalization (*i.e.*, they might not be specific enough to the task at hand).

The embeddings might be slightly sub-optimal for classification because they are produced with the goal of obtaining high-quality, citation-informed numerical representations of scientific papers in general. From this point of view, Text Classification is simply a downstream activity and does not enter the pipeline in its initial stages: it is no more relevant to the embedding procedure than other tasks, such as Citation Prediction.
This clearly limits the possibility of using SPECTER to its fullest in one specific application, since the embeddings it produces might simply not be tailored to it.

Our intuition is that one can **chain** the two steps on which Text Classification is based, thus unifying the whole process.
After the (often very extensive and data-intensive) pre-training phase, the embeddings produced by SPECTER are fed into a **Classification Head** (CH) based on a Multi-Layer Perceptron architecture. This allows gradient information to flow from the classification output not only to the CH parameters, but also back to the SPECTER embedding model itself.

Theoretically, this flow of information can be used to tweak (or, more precisely, **fine-tune**) SPECTER's parameters specifically for classification (or really any downstream task).

This is justified by the fact that, given a **labelled** dataset $\tau$ defined as:

$$
\begin{equation}
\tau = \{ (\mathcal P_i, l_i) \}_{i = 1, \dots, \vert \mathcal X \vert}
\end{equation}
$$

then, as in the bottom part of the diagram, the classification function $\hat f: \mathcal X \mapsto \mathcal Y$ is applied to any given paper $\mathcal P$ as follows:

$$
\begin{equation}
g_{\text{bottom}}(\mathcal P) = \hat f(\texttt{SPECTER}(\mathcal P))
\end{equation}
$$
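Writing $e = \texttt{SPECTER}(\mathcal P)$ and $\hat y = \hat f(e)$, the chain rule makes the dependence of the loss on the encoder weights explicit:

$$
\begin{equation}
\frac{\partial L}{\partial w_{\texttt{SPECTER}}} = \frac{\partial L}{\partial \hat y}\,\frac{\partial \hat f}{\partial e}\,\frac{\partial e}{\partial w_{\texttt{SPECTER}}}
\end{equation}
$$

The product is generically non-zero when $L$ is differentiable in $\hat y$. With the plain 0-1 misclassification error, instead, the first factor is zero almost everywhere, since small weight changes do not flip the predicted labels; this is why a differentiable surrogate such as the cross-entropy is used in practice.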

It follows that, if one uses as loss function a differentiable surrogate of the misclassification error, $L(y, \hat y) \in \mathbb R^+$ (such as the cross-entropy loss used below), then in general:

$$
\begin{equation}
\frac{\partial L}{\partial w_{\texttt{SPECTER}}} \neq 0
\end{equation}
$$
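The same point can be checked numerically. The sketch below uses a hypothetical toy model, a 2×2 linear "encoder" standing in for SPECTER followed by a linear head, and estimates the derivative of the loss with respect to one encoder weight by central finite differences. With the cross-entropy loss the estimate is non-zero; with the 0-1 misclassification error it is exactly zero, because the small perturbation does not change the predicted class.

```python
import math

def encode(W, x):
    """Toy linear 'encoder' e = W x, standing in for SPECTER."""
    return [sum(w * xj for w, xj in zip(row, x)) for row in W]

def head(V, e):
    """Toy linear classification head: logits = V e."""
    return [sum(v * ej for v, ej in zip(row, e)) for row in V]

def cross_entropy(logits, label):
    m = max(logits)  # subtract the max for numerical stability
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_z - logits[label]

def zero_one(logits, label):
    predicted = max(range(len(logits)), key=lambda k: logits[k])
    return 0.0 if predicted == label else 1.0

W = [[0.5, -0.2], [0.1, 0.8]]   # encoder weights
V = [[1.0, -1.0], [-0.5, 1.5]]  # head weights
x, label = [1.0, 2.0], 1

def loss_at(w00, loss_fn):
    """Loss as a function of the single encoder weight W[0][0]."""
    W_new = [[w00, W[0][1]], W[1][:]]
    return loss_fn(head(V, encode(W_new, x)), label)

h = 1e-5
grads = {
    fn.__name__: (loss_at(W[0][0] + h, fn) - loss_at(W[0][0] - h, fn)) / (2 * h)
    for fn in (cross_entropy, zero_one)
}
print(grads)
```

In the real model, the same signal reaches all of SPECTER's weights through backpropagation rather than through finite differences.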

Of course, one cannot expect the extensively trained weights of SPECTER to change significantly for one specific task: as an encoder, SPECTER's job is, at the end of the day, to turn text into meaningful, contextual, *general-purpose* numerical representations.
Indeed, most of the information in $\tau$ is used to train the CH on top of SPECTER.
Nevertheless, the embeddings are updated too, so that classification is performed in a dynamic feature space, whose geometry is shaped by the actual task rather than being independent of it.

In [5]:
# use wandb to track experiments and training runs
!wandb login

In [2]:
# seed this notebook
from commons.utils import *
from transformers import AutoTokenizer, AutoModel

seedEverything(seed=321)  # specter's seed

# load SPECTER pre-trained model and tokenizer
tokenizer = AutoTokenizer.from_pretrained('allenai/specter')
model = AutoModel.from_pretrained('allenai/specter')

The data needed for this task are stored in the `data` folder. However, they cannot be used straightforwardly for the task presented here, because the textual data are stored separately from the corresponding labels.

Should the `data` folder be empty (or missing) in your copy, you can create it by running:
```bash
$ bash commons/getdata.sh
```

This script downloads the `data` folder, which is required by the following steps.

Inside the `data` folder, one can find:

1. `paper_metadata_mag_mesh.json`, which contains various fields, such as:

   `pid`: paper ID, string (*e.g.*, `00021eeee2bf4e06fec98941206f97083c38b54d`);

   `abstract`: paper abstract, text;

   `cited_by`: list in which each element is the `pid` of another paper citing this one;

   `references`: list in which each element is the `pid` of another paper cited by this one;

   `title`: paper title, text.

2. `mag/{train/val/test}.csv`, which is organized as:

   `pid`: paper ID, string (*e.g.*, `00021eeee2bf4e06fec98941206f97083c38b54d`);

   `label`: int in the 0-18 range, representing one of the [MAG classes](https://github.com/allenai/scidocs/blob/ebf239d30d70062b4111f9e3a8efe2b3d3f3d303/README.md?plain=1#L121-L139).

3. `mesh/{train/val/test}.csv`, which is organized as:

   `pid`: paper ID, string (*e.g.*, `00021eeee2bf4e06fec98941206f97083c38b54d`);

   `label`: int in the 0-10 range, representing one of the [MeSH classes](https://github.com/allenai/scidocs/blob/ebf239d30d70062b4111f9e3a8efe2b3d3f3d303/README.md?plain=1#L106-L116).
    

This clearly indicates that the data need to be preprocessed to make them usable by the considered model.

In particular, the data undergo:
1. A **cleaning** step, in which invalid papers are removed from the pool. Considering our limited computational resources, we discard papers that lack either a title or an abstract, as well as papers that are not in English. This reduces the original dataset size by ~23%.
2. A step in which they are **joined with the labelled data** (stored in `mesh` and `mag`).
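The two steps can be sketched in plain Python as follows. This is an illustration, not the code in `commons.data_utils`: `clean_and_join` and `looks_english` are hypothetical names, and `looks_english` is a deliberately naive ASCII-based stand-in for a real language detector. As a sanity check on the ~23% figure, the log below reports 48,473 papers before cleaning and 37,227 after, i.e. a reduction of (48473 − 37227) / 48473 ≈ 23.2%.

```python
def clean_and_join(metadata, labels, is_english):
    """Drop papers without title/abstract or not in English, then inner-join with labels.

    metadata:   dict pid -> {"title": str, "abstract": str, ...}
    labels:     dict pid -> int (e.g. the rows of mag/train.csv)
    is_english: callable text -> bool (any language detector)
    """
    cleaned = {
        pid: paper
        for pid, paper in metadata.items()
        if paper.get("title") and paper.get("abstract")
        and is_english(paper["title"] + " " + paper["abstract"])
    }
    # inner join: keep only papers that survived cleaning *and* have a label
    return {pid: {**cleaned[pid], "label": label}
            for pid, label in labels.items() if pid in cleaned}

def looks_english(text):
    """Naive hypothetical stand-in for a real language detector."""
    return all(ord(ch) < 128 for ch in text)

metadata = {
    "a": {"title": "Deep nets", "abstract": "We study..."},
    "b": {"title": "No abstract", "abstract": None},
    "c": {"title": "Réseaux profonds", "abstract": "Nous étudions..."},
    "d": {"title": "Graphs", "abstract": "Citation graphs..."},
}
labels = {"a": 3, "b": 7, "c": 1}

print(clean_and_join(metadata, labels, looks_english))
print(f"{(48473 - 37227) / 48473:.1%}")
```

Paper `d` is dropped by the inner join (no label), `b` by the title/abstract check and `c` by the language check.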

In [3]:
from commons.data_utils import *
# perform data pre-processing
scidocs = load_metadata()
mag = load_dataset(dataset="mag").join(scidocs, how="inner")
mesh = load_dataset(dataset="mesh").join(scidocs, how="inner")

del scidocs

Retrieving non-english papers: 100%|██████████| 37556/37556 [00:16<00:00, 2286.01it/s]


Total number of papers in SciDocs: 48473
Total number of papers after data removing abstract/title lacking papers: 37556
Total number of papers after data removing non english papers: 37227


Once the different data sources have been merged, some effort is needed to interface pandas with the deep learning framework used here, PyTorch.

Here, we use Hugging Face Datasets as a middle ground to obtain our final result.

Several steps are needed to turn a `pd.DataFrame` into something a PyTorch model can be trained on. These are:

1. Rename the label column (`class_label`) to `labels`, the name Hugging Face `transformers` models expect for labels.
2. Concatenate title and abstract as `title + tokenizer.sep_token + abstract`.
3. Obtain the numerical representation of the concatenated text using the tokenizer.
4. Remove all columns that are not needed for training.

All four steps are carried out by the custom functions `to_hf_dataset` and `tokenize_hf`, defined in `commons`.
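A minimal pure-Python sketch of the four steps, assuming each row is a dict with `title`, `abstract` and `class_label` keys; `to_model_inputs` and `toy_tokenize` are hypothetical names, and `tokenize` can be any callable returning a dict of token features. With the Hugging Face tokenizer loaded above, something like `lambda text: tokenizer(text, truncation=True, max_length=512)` could play that role, with `sep_token=tokenizer.sep_token`.

```python
def to_model_inputs(rows, tokenize, sep_token="[SEP]"):
    """Turn raw rows into model-ready examples (token features + 'labels')."""
    examples = []
    for row in rows:
        text = row["title"] + sep_token + row["abstract"]  # step 2: concatenate
        example = dict(tokenize(text))                      # step 3: tokenize
        example["labels"] = row["class_label"]              # step 1: rename the label
        examples.append(example)                            # step 4: other columns are dropped
    return examples

def toy_tokenize(text):
    """Hypothetical whitespace tokenizer; token positions are used as toy ids."""
    tokens = text.replace("[SEP]", " [SEP] ").split()
    return {"input_ids": list(range(len(tokens))), "attention_mask": [1] * len(tokens)}

rows = [{"pid": "p1", "title": "Graph nets", "abstract": "We embed papers.", "class_label": 4}]
print(to_model_inputs(rows, toy_tokenize))
```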

In [4]:
# tokenize both datasets (the list order must match the unpacking order)
mesh_hf, mag_hf = [
    tokenize_hf(
        hf=to_hf_dataset(dataset=dataset),
        tokenizer=tokenizer
    )
    for dataset in [mesh, mag]
]


In [9]:
from commons.model_utils import embed_data

In [13]:
embed_data(model=model, data=mesh_hf.remove_columns("labels"))


MeSH and MAG embeddings through `SPECTER` model

In [None]:
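# deliberately interrupt "Run All" at this point
raise SystemExit("Execution deliberately stopped here.")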

In [None]:
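do_embed = False
if do_embed:
    import torch
    from rich.progress import track
    from torch.utils.data import DataLoader

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device).eval()

    def embed_papers(dataset, description, batch_size=16):
        """Returns one [CLS] embedding per paper, stacked into an (n_papers, 768) tensor."""
        # batch_size is an arbitrary choice; lower it if GPU memory is tight
        loader = DataLoader(dataset.remove_columns("labels"), batch_size=batch_size)
        embeddings = []
        with torch.no_grad():
            for batch in track(loader, description=description):
                batch = {k: v.to(device) for k, v in batch.items()}
                outputs = model(**batch)
                # SPECTER embeds a paper as the final hidden state of its [CLS] token
                embeddings.append(outputs.last_hidden_state[:, 0, :].cpu())
                del batch
                torch.cuda.empty_cache()
        return torch.cat(embeddings)

    mesh_embeddings = embed_papers(mesh_hf, "MeSH embeddings...")
    mag_embeddings = embed_papers(mag_hf, "MAG embeddings...")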

Having verified that the embeddings can be produced from the data in this format, we now build a model that combines the pre-trained `SPECTER` encoder with a classification head for the input papers.

In [None]:
import torch
import torch.nn as nn

class SPECTERClassifier(nn.Module):
    def __init__(
        self, 
        base_model, 
        n_labels:int, 
        n_layers:int=2, 
        n_units:int=512,
        activation_function:str="relu", 
        use_gpu:bool=True,
        use_dropout:bool=False,
        use_batchnorm:bool=False): 
        """
        Args:
            base_model (transformers.AutoModel): Pre-trained model from the transformers API.
            n_labels (int): Number of labels to be predicted.
            n_layers (int, optional): Number of hidden (n_units x n_units) layers in the
                                      classification head. Defaults to 2.
            n_units (int, optional): Number of units in each hidden layer of the
                                     classification head. Defaults to 512.
            activation_function (str, optional): Activation function used in the classification
                                                 head: "relu", "tanh" or "sigmoid". Defaults to "relu".
            use_gpu (bool, optional): Whether or not to use an available GPU. Defaults to True.
            use_dropout (bool, optional): Whether or not to use dropout (p=0.5) to regularize the
                                          classification head. Defaults to False.
            use_batchnorm (bool, optional): Whether or not to use 1d batch normalization in the
                                            classification head. Defaults to False.
        """
        super().__init__()

        # SPECTER model
        self.model = base_model

        # output dimension of base_model (768 for SPECTER)
        hidden_size = self.model.config.hidden_size

        transformer_output = nn.Linear(hidden_size, n_units)
        hidden_layers = [nn.Linear(n_units, n_units) for _ in range(n_layers)]
        logits_layer = nn.Linear(n_units, n_labels)
        
        if activation_function.lower()=="relu": 
            self.act_func = nn.ReLU
        elif activation_function.lower()=="tanh":
            self.act_func = nn.Tanh
        elif activation_function.lower()=="sigmoid":
            self.act_func = nn.Sigmoid
        else:
            raise NotImplementedError(
                f"Activation function '{activation_function}' not implemented: use 'relu', 'tanh' or 'sigmoid'."
            )
        
        # classification head: (Linear -> [BatchNorm1d] -> activation -> [Dropout]) repeated,
        # followed by the final logits layer (no activation on the logits)
        layers = [transformer_output, *hidden_layers, logits_layer]

        clf_head = []
        for i, layer in enumerate(layers):
            clf_head.append(layer)
            if i < len(layers) - 1:
                if use_batchnorm:
                    clf_head.append(nn.BatchNorm1d(num_features=n_units))
                clf_head.append(self.act_func())
                if use_dropout:
                    clf_head.append(nn.Dropout(p=0.5))

        self.classification_head = nn.Sequential(*clf_head)
        self.device = "cuda" if torch.cuda.is_available() and use_gpu else "cpu"

    def set_classification_head(self, clf_head:nn.Sequential):
        """Sets a classification head for the model considered"""
        if isinstance(clf_head, nn.Sequential):
            self.classification_head = clf_head
        elif isinstance(clf_head, list): 
            self.classification_head = nn.Sequential(*clf_head)
        else: 
            raise ValueError("Classification head must be an nn.Sequential or a list of modules!")

    def forward(self, x:dict)->torch.Tensor:
        """Forward pass"""
        device = self.device
        SPECTER_input = {
            key: x[key].to(device) for key in ["input_ids", "token_type_ids", "attention_mask"]
        }
        SPECTER_model = self.model.to(device)
        SPECTER_output = SPECTER_model(**SPECTER_input)
        # remove un-necessary input from device memory
        del SPECTER_input
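        # classification is applied on SPECTER's pooled output, i.e. tanh(dense(.)) of the
        # final [CLS] hidden state (SPECTER's own embedding is last_hidden_state[:, 0, :])
        classifier_input = SPECTER_output.pooler_output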
        classifier_model = self.classification_head.to(device)
        return classifier_model(classifier_input)
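To make the structure of the head explicit, the hypothetical helper below lists, in pure Python, the modules of a head built as Linear → [BatchNorm1d] → activation → [Dropout], repeated and closed by a logits layer without activation. Placing batch normalization before the activation and dropout after it is one common convention, assumed here. With the default `n_layers=2`, the head has four linear layers and three activations.

```python
def head_layout(n_layers, activation="ReLU", use_batchnorm=False, use_dropout=False):
    """Ordered module names of a classification head with n_layers hidden layers."""
    linears = (["Linear(768, n_units)"]
               + ["Linear(n_units, n_units)"] * n_layers
               + ["Linear(n_units, n_labels)"])
    layout = []
    for i, layer in enumerate(linears):
        layout.append(layer)
        if i < len(linears) - 1:  # the logits layer is not followed by anything
            if use_batchnorm:
                layout.append("BatchNorm1d(n_units)")
            layout.append(activation)
            if use_dropout:
                layout.append("Dropout(p=0.5)")
    return layout

for name in head_layout(n_layers=2):
    print(name)
```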


In [None]:
import torch
import wandb
from torch.optim import Optimizer
from datasets import DatasetDict
from torch.nn.modules.loss import _WeightedLoss
from torch.optim.lr_scheduler import _LRScheduler
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.metrics import f1_score

class Trainer: 
    def __init__(
        self,
        model:torch.nn.Module, 
        splits:DatasetDict, 
        optimizer:Optimizer, 
        loss_function:_WeightedLoss, 
        batch_size:int=4,
        use_gpu:bool=True, 
        use_scheduler:bool=False,
        scheduler:_LRScheduler=None,
        steps_up:int=200):

        self.train_loader = DataLoader(
            splits["train"], 
            batch_size=batch_size, 
            shuffle=True
            )

        # no need to shuffle evaluation data
        self.test_loader = DataLoader(
            splits["test"],
            batch_size=batch_size,
            shuffle=False
            )
        
        self.optimizer = optimizer
        self.use_scheduler = use_scheduler
        if self.use_scheduler:
            self.scheduler = scheduler

        self.loss = loss_function

        self.model = model
        self.batch_size = batch_size
        
        self.device = "cuda" if torch.cuda.is_available() and use_gpu else "cpu"
    
    def do_train(self, n_epochs:int=5, log_every:int=10): 
        """
        Performs training.

        Args:
            n_epochs (int, optional): Training epochs. Defaults to 5.
            log_every (int, optional): How often (number of batches) to log
                                       current training loss.
        """
        device = self.device
        
        # set training mode
        self.model.train()
        step = 0
        for epoch in (training_bar:=tqdm(range(n_epochs))): 
            # loop over training data
            for batch in self.train_loader: 
                labels = batch["labels"].to(device)
                # zerograd optimizer
                self.optimizer.zero_grad()
                # forward pass
                outputs = self.model(batch)
                # loss computation
                loss_value = self.loss(outputs, labels)
                # backward pass
                loss_value.backward()
                # step parameters
                self.optimizer.step()
                
                if self.use_scheduler:
                    self.scheduler.step()

                training_bar.set_description("Training Loss: {:.4f}".format(loss_value.item()))
                step += 1

                # log to wandb only if a run has been initialized with wandb.init()
                if step % log_every == 0 and wandb.run is not None:
                    wandb.log({
                        "CrossEntropy-Loss": loss_value.item(),
                    })

    def do_test(self)->float: 
        """Tests the considered model, computing the macro F1 score of each test batch.

        Returns:
            float: Average of the per-batch macro F1 scores over the test set.
        """
        
        # set testing mode
        self.model.eval()

        batches_f1 = torch.zeros(len(self.test_loader))

        with torch.no_grad():
            idx = 0
            for batch in tqdm(self.test_loader):
                output = self.model(batch).cpu()
                y_true = batch["labels"]
                # max(axis=1) gives the maximal value and maximal index too
                _, y_batch = torch.max(output, 1)
                # store this batch f1 score
                batches_f1[idx] = f1_score(
                    y_true=y_true, 
                    y_pred=y_batch.numpy(),
                    average="macro")
                
                idx += 1
        return batches_f1.mean().item()
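Note that `do_test` averages the macro F1 of each test batch, which is in general *not* the macro F1 of the whole test set: each batch contains only a few examples of a few classes, and scikit-learn's `f1_score` only averages over the labels present in that batch's `y_true` or `y_pred`. The pure-Python sketch below (a hypothetical `macro_f1` that mirrors that behavior) shows the two numbers disagreeing on a toy example, so the F1 scores reported below are not directly comparable with a test-set-level macro F1. Collecting all predictions and calling `f1_score(y_true, y_pred, average="macro")` once would give the latter.

```python
def macro_f1(y_true, y_pred):
    """Macro F1 over the labels present in y_true or y_pred
    (the label set scikit-learn's f1_score uses when `labels` is not given)."""
    scores = []
    for c in sorted(set(y_true) | set(y_pred)):
        tp = sum(t == c and p == c for t, p in zip(y_true, y_pred))
        fp = sum(t != c and p == c for t, p in zip(y_true, y_pred))
        fn = sum(t == c and p != c for t, p in zip(y_true, y_pred))
        scores.append(2 * tp / (2 * tp + fp + fn))  # per-class F1
    return sum(scores) / len(scores)

y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]
batch_size = 2

per_batch = [macro_f1(y_true[i:i + batch_size], y_pred[i:i + batch_size])
             for i in range(0, len(y_true), batch_size)]
avg_of_batches = sum(per_batch) / len(per_batch)  # what do_test computes
global_f1 = macro_f1(y_true, y_pred)              # macro F1 of the whole test set
print(f"{avg_of_batches:.3f} vs {global_f1:.3f}")
```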


Having defined a model that predicts classes from SPECTER embeddings, together with a `Trainer` to train and test it, we now specialize our work to the MeSH and MAG classification tasks.

# MeSH Classification

In [None]:
experiment=False
if experiment: 
    # track experiments
    wandb.init(
        # set the wandb project where this run will be logged
        project="Extension2-MeSH classification",
        # track hyperparameters and run metadata
        config={
        "learning_rate": 5e-5,
        "architecture": "1-layer only",
        "hidden_layers": 1,
        "epochs": 5, 
        "dropout":False,
        "batchnorm":False
        }
    )

In [None]:
# import pre-trained SPECTER
model = AutoModel.from_pretrained("allenai/specter")

sc_mesh = SPECTERClassifier(
    base_model=model,
    n_labels=11,
    n_layers=0,
    n_units=0)
# replace the default head with a single linear layer (768 -> 11 MeSH classes)
sc_mesh.set_classification_head([nn.Linear(768, 11)])

total_params = sum(
    param.numel() for param in sc_mesh.parameters()
)
print("Number of parameters (MeSH model): {:.4e}".format(total_params))

Number of parameters (MeSH model): 1.0995e+08
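For scale: an `nn.Linear(in_features, out_features)` layer with bias has `in_features * out_features + out_features` parameters. The small sketch below (hypothetical helper `linear_params`) applies this to the MeSH head and to the two-layer MAG head used later, showing that the MeSH head accounts for under 0.01% of the roughly 1.1e8 parameters printed above: almost everything that fine-tuning updates belongs to SPECTER itself.

```python
def linear_params(in_features, out_features, bias=True):
    """Trainable parameters of an nn.Linear layer."""
    return in_features * out_features + (out_features if bias else 0)

mesh_head = linear_params(768, 11)                       # nn.Linear(768, 11)
mag_head = linear_params(768, 8) + linear_params(8, 19)  # Linear(768, 8) -> ReLU -> Linear(8, 19)
total_mesh = 1.0995e8                                    # rounded total printed above

print(mesh_head, mag_head, f"{mesh_head / total_mesh:.4%}")
```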




In [None]:
from torch.optim import AdamW

# splitting mesh data into training and test data
mesh_splits = mesh_hf.train_test_split(test_size=0.1)

# instantiate an optimizer
optimizer = AdamW(sc_mesh.parameters(), lr=5e-5)

# define a loss function
loss_function = nn.CrossEntropyLoss()

# instantiate a trainer object
trainer = Trainer(
    model=sc_mesh, 
    splits=mesh_splits, 
    optimizer=optimizer, 
    loss_function=loss_function,
    batch_size=8
    )

In [None]:
model_path = "/content/drive/MyDrive/trained-models/mesh"

perform_training=False
if perform_training:
    trainer.do_train(5, log_every=5)
    PATH = f"{model_path}/run4.pth"
    torch.save(sc_mesh.state_dict(), PATH)
    wandb.finish()

load_run=True
if load_run:
    sc_mesh = SPECTERClassifier(
        base_model=model,
        n_labels=11,
        n_layers=0,
        n_units=0)
    sc_mesh.set_classification_head([nn.Linear(768, 11)])

    # map_location makes GPU-trained checkpoints loadable on CPU-only machines
    sc_mesh.load_state_dict(torch.load(f"{model_path}/run4.pth", map_location="cpu"))

In [None]:
test=True
if test: 
    trainer = Trainer(
        model=sc_mesh,
        splits=mesh_splits,
        optimizer=optimizer,
        loss_function=loss_function,
        batch_size=12
    )

    avg_f1 = trainer.do_test()
    print("\nAverage F1-Score {:.4f}".format(avg_f1))

100%|██████████| 193/193 [01:21<00:00,  2.36it/s]


Average F1-Score 0.9486





# MAG Classification

In [None]:
experiment=False
if experiment:
    # track experiments
    wandb.init(
        # set the wandb project where this run will be logged
        project="Extension2-MAG classification",
        
        # track hyperparameters and run metadata
        config={
        "learning_rate": 1e-3,
        "architecture": "Structured Clf Head",
        "hidden_layers": 2,
        "n_units":8,
        "epochs": 10,
        "batchsize":16
        }
    )

In [None]:
# reimporting regular SPECTER
model = AutoModel.from_pretrained("allenai/specter")
sc_mag = SPECTERClassifier(
    base_model=model,
    n_labels=19
)
sc_mag.set_classification_head([
    nn.Linear(768, 8),
    nn.Tanh(),
    nn.Linear(8, 8),
    nn.BatchNorm1d(num_features=8),
    nn.Linear(8, 19)]
)

total_params = sum(
    param.numel() for param in sc_mag.parameters()
)
print("Number of parameters (MAG model): {:.4e}".format(total_params))

Number of parameters (MAG model): 1.0994e+08


In [None]:
from torch.optim import AdamW, SGD
from torch.optim.lr_scheduler import CyclicLR

# splitting MAG data into training and test data
mag_splits = mag_hf.train_test_split(test_size=0.1)

# instantiate an optimizer
adam = True
if adam:
    optimizer = AdamW(sc_mag.parameters(), lr=1e-3)
else:
    optimizer = SGD(sc_mag.parameters(), lr=5e-4, momentum=0.9, weight_decay=0.01)

# define learning rate scheduler
scheduler = CyclicLR(
    optimizer, 
    base_lr=5e-6, 
    max_lr=5e-3, 
    step_size_up=50, 
    cycle_momentum=not adam)

# define a loss function
loss_function = nn.CrossEntropyLoss()

# instantiate a trainer object
trainer = Trainer(
    model=sc_mag, 
    splits=mag_splits,
    optimizer=optimizer, 
    loss_function=loss_function,
    batch_size=16
    )
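With `mode="triangular"` (the `CyclicLR` default) and `step_size_down` left unset, so that it equals `step_size_up`, the learning rate rises linearly from `base_lr=5e-6` to `max_lr=5e-3` over 50 scheduler steps and falls back over the next 50. The function below is a pure-Python sketch of that triangular policy (hypothetical name `triangular_lr`), not PyTorch's implementation. Two caveats on the cell above: as written, `scheduler` is not passed to `Trainer` (which would need `use_scheduler=True, scheduler=scheduler`), so `scheduler.step()` is never called; and, to our understanding of the PyTorch implementation, constructing `CyclicLR` already resets the optimizer's learning rate to `base_lr`, so AdamW would not actually run at `lr=1e-3`.

```python
import math

def triangular_lr(step, base_lr=5e-6, max_lr=5e-3, step_size_up=50):
    """Learning rate after `step` scheduler steps under a symmetric triangular cyclic policy."""
    cycle = math.floor(1 + step / (2 * step_size_up))
    x = abs(step / step_size_up - 2 * cycle + 1)  # position inside the current cycle, in [0, 1]
    return base_lr + (max_lr - base_lr) * max(0.0, 1 - x)

for step in (0, 25, 50, 75, 100):
    print(step, f"{triangular_lr(step):.2e}")
```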

In [None]:
model_path = "/content/drive/MyDrive/trained-models/mag"

perform_training=False
if perform_training:
    # perform training until stopped, then save trained model
    try:
        trainer.do_train(10, log_every=3)
    except KeyboardInterrupt:
        pass

    PATH = f"{model_path}/run8.pth"
    torch.save(sc_mag.state_dict(), PATH)
    wandb.finish()

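load_pretrained=True
if load_pretrained:
    # load a previously trained run (run7), whose classification head differs from the one defined above
    sc_mag = SPECTERClassifier(base_model=model, n_labels=19)
    sc_mag.set_classification_head([nn.Linear(768, 8), nn.ReLU(), nn.Linear(8, 19)])
    sc_mag.load_state_dict(torch.load(f"{model_path}/run7.pth", map_location="cpu"))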

In [None]:
# instantiate a trainer object
del trainer
trainer = Trainer(
    model=sc_mag, 
    splits=mag_splits, 
    optimizer=optimizer, 
    loss_function=loss_function,
    batch_size=64
    )

do_test = True
if do_test:
    avg_f1 = trainer.do_test()
    print("\nAverage F1-Score {:.4f}".format(avg_f1))

100%|██████████| 23/23 [00:46<00:00,  2.02s/it]


Average F1-Score 0.9726



