In [1]:
%run ../ipynb_util_tars.py

## Test different ways to load the models

the outcome of this jupynb is an easy way to access the finetuned models via variables:
* `scibert_model` = SciBERT finetuned on ZO_UP
* `llama_model` = LLaMA-3 finetuned on ZO_UP with a classification head
* `unllama_model` = LLaMA-3 finetuned on ZO_UP with a classification head without causal mask

### Dataset + encodings

In [2]:
%run ../ipynb_load_data.py

{'SDG': ClassLabel(names=['1', '10', '11', '12', '13', '14', '15', '16', '17', '2', '3', '4', '5', '6', '7', '8', '9'], id=None), 'ABSTRACT': Value(dtype='string', id=None), 'id': Value(dtype='string', id=None), 'sdg_desc_short': Value(dtype='string', id=None), 'sdg_desc_long': Value(dtype='string', id=None)}
Example instance:	 {'SDG': 16, 'ABSTRACT': 'The first attempts to modernize simply replaced the single huge engine with a huge electric motor, changing little. The drive-shafts were replaced by wires, the huge steam engine by dozens of small motors. Factories spread out, there was natural light, and room to use ceiling-slung cranes. Workers had responsibility for their own machines, they needed better training and better pay. The electric motor was a wonderful invention, once we changed all the everyday details that surrounded it.', 'id': None, 'sdg_desc_short': None, 'sdg_desc_long': None}
Encoded (label2id) label:	 16
Decoded (id2label) label:	 9
9 16 16


In [3]:
sample_sentence = "Is this about poverty?"

### Evaluator

In [4]:
import pprint
import datasets
import evaluate
from evaluate import evaluator, Metric
from sklearn.metrics import accuracy_score


class MulticlassAccuracy(Metric):
    """Workaround for the default Accuracy class which doesn't support passing 'average' to the compute method."""

    def _info(self):
        return evaluate.MetricInfo(
            description="Accuracy",
            citation="",
            inputs_description="",
            features=datasets.Features(
                {
                    "predictions": datasets.Sequence(datasets.Value("int32")),
                    "references": datasets.Sequence(datasets.Value("int32")),
                }
                if self.config_name == "multilabel"
                else {
                    "predictions": datasets.Value("int32"),
                    "references": datasets.Value("int32"),
                }
            ),
            reference_urls=["https://scikit-learn.org/stable/modules/generated/sklearn.metrics.accuracy_score.html"],
        )

    def _compute(self, predictions, references, normalize=True, sample_weight=None, **kwargs):
        # take **kwargs to avoid breaking when the metric is used with a compute method that takes additional arguments
        return {
            "accuracy": float(
                accuracy_score(references, predictions, normalize=normalize, sample_weight=sample_weight)
            )
        }

task_evaluator = evaluator("text-classification")
task_evaluator.METRIC_KWARGS = {"average": "weighted"}
metrics_dict = {
    "accuracy": MulticlassAccuracy(),
    "precision": "precision",
    "recall": "recall",
    "f1": "f1",
}

## SciBERT baseline

In [5]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification

SCIBERT_PATH = CHECKPOINT_PATH + "/allenai/scibert_scivocab_uncased-ft-zo_up-lower/checkpoint-240/"

scibert_model = AutoModelForSequenceClassification.from_pretrained(
    SCIBERT_PATH,
    num_labels=len(labels),
    id2label=id2label,
    label2id=label2id
).to("cuda")
scibert_tokenizer = AutoTokenizer.from_pretrained(SCIBERT_PATH)
scibert_model.eval()

# Sample input to SciBERT
sample_input = scibert_tokenizer(sample_sentence, return_tensors="pt").to("cuda")
sample_output = scibert_model(**sample_input)
print(torch.max(torch.softmax(sample_output.logits, dim=-1), dim=-1))

torch.return_types.max(
values=tensor([0.8610], device='cuda:0', grad_fn=<MaxBackward0>),
indices=tensor([0], device='cuda:0'))


In [6]:
# Evaluate SciBERT
eval_results = task_evaluator.compute(
    scibert_model,
    input_column="ABSTRACT",
    label_column="SDG",
    tokenizer=scibert_tokenizer,
    data=dataset["test"],
    label_mapping=label2id,
    metric=evaluate.combine(metrics_dict)
)
pprint.pprint(eval_results)

Asking to truncate to max_length but no maximum length is provided and the model has no predefined maximum length. Default to no truncation.


{'accuracy': 0.7269372693726938,
 'f1': 0.718275735363878,
 'latency_in_seconds': 0.00714278105831949,
 'precision': 0.7210623353133084,
 'recall': 0.7269372693726938,
 'samples_per_second': 140.0014912728228,
 'total_time_in_seconds': 1.9356936668045819}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [7]:
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

# Evaluate SciBERT (manual metrics calculation sanity check)
scibert_tokenized_dataset = dataset.map(
    preprocess_data(scibert_tokenizer, include_labels=False), batched=True, remove_columns=dataset["train"].column_names
)
scibert_tokenized_dataset.set_format("torch")
scibert_out_logits = torch.tensor([])
for batch in scibert_tokenized_dataset["test"]:
    scibert_out_logits = torch.cat(
        (
            scibert_out_logits,
            scibert_model(
                input_ids=batch["input_ids"].to("cuda").unsqueeze(0),
                attention_mask=batch["attention_mask"].to("cuda").unsqueeze(0)
            ).logits.detach().cpu()
        )
    )

#scibert_accuracy = accuracy_score(y_true=dataset["test"]["SDG"][:64], y_pred=preds_scibert.cpu())
_, scibert_preds = torch.max(torch.softmax(scibert_out_logits, dim=-1), dim=-1)
scibert_accuracy = accuracy_score(y_true=dataset["test"]["SDG"], y_pred=scibert_preds.cpu())

scibert_f1 = f1_score(y_true=dataset["test"]["SDG"], y_pred=scibert_preds.cpu(), average="weighted")
scibert_precision = precision_score(y_true=dataset["test"]["SDG"], y_pred=scibert_preds.cpu(), average="weighted")
scibert_recall = recall_score(y_true=dataset["test"]["SDG"], y_pred=scibert_preds.cpu(), average="weighted")

print(dataset["test"]["SDG"][:32])
print(scibert_preds[:32].tolist())

pprint.pprint({
    "accuracy": scibert_accuracy,
    "precision": scibert_precision,
    "recall": scibert_recall,
    "f1": scibert_f1
})

  StockPickler.save(self, obj, save_persistent_id)
  StockPickler.save(self, obj, save_persistent_id)


[3, 9, 4, 9, 14, 15, 1, 5, 5, 10, 11, 12, 15, 14, 14, 10, 0, 14, 4, 1, 7, 1, 6, 11, 7, 10, 1, 1, 9, 9, 2, 14]
[3, 9, 4, 9, 14, 15, 16, 5, 5, 12, 11, 12, 1, 4, 14, 10, 0, 13, 4, 1, 7, 15, 6, 11, 7, 10, 1, 15, 9, 9, 2, 10]
{'accuracy': 0.7269372693726938,
 'f1': 0.718275735363878,
 'precision': 0.7210623353133084,
 'recall': 0.7269372693726938}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


## LLaMA-3

In [8]:
import torch.nn as nn
import torch
from transformers.models.llama.modeling_llama import LlamaForSequenceClassification, LlamaDecoderLayer, LlamaConfig, LlamaRMSNorm, LlamaPreTrainedModel, LlamaModel, LLAMA_INPUTS_DOCSTRING, add_start_docstrings_to_model_forward, SequenceClassifierOutputWithPast, BaseModelOutputWithPast, BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers.cache_utils import Cache, DynamicCache
from typing import Optional, List, Union, Tuple


class UnmaskingLlamaModel(LlamaModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]

    Args:
        config: LlamaConfig
    """

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        return_legacy_cache = False
        if use_cache and not isinstance(past_key_values, Cache):  # kept for BC (non `Cache` `past_key_values` inputs)
            return_legacy_cache = True
            past_key_values = DynamicCache.from_legacy_cache(past_key_values)

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )
        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(
            attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions
        )
        if causal_mask is not None:
            #print("b4", input_ids.shape, causal_mask.shape, causal_mask)
            # Assuming causal_mask is a tensor with shape (batch_size, 1, seq_length, hidden_size)
            causal_mask_last_row = causal_mask[:, :, -1, :].unsqueeze(2)
            causal_mask = causal_mask_last_row.expand_as(causal_mask)
            # causal_mask = torch.zeros_like(causal_mask, device=inputs_embeds.device)

            #print("after", causal_mask.shape, causal_mask)
        else:
            pass
            #print("kek it's none", causal_mask, input_ids)

        # embed positions
        hidden_states = inputs_embeds

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if return_legacy_cache:
            next_cache = next_cache.to_legacy_cache()

        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

class UnmaskingLlamaForSequenceClassification(LlamaForSequenceClassification):

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = UnmaskingLlamaModel(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

In [9]:
import torch
from transformers import AutoTokenizer, LlamaForSequenceClassification

#LLAMA_PATH = "meta-llama/Meta-Llama-3-8B"
#LLAMA_PATH = f"{CHECKPOINT_PATH}/meta-llama/Meta-Llama-3-8B-ft-zo_up/checkpoint-2200/"
LLAMA_PATH = f"{CHECKPOINT_PATH}/meta-llama/Meta-Llama-3-8B-ft-zo_up-unmasked/checkpoint-1850/"
llama_tokenizer = AutoTokenizer.from_pretrained(LLAMA_PATH)
# llama_tokenizer.pad_token = llama_tokenizer.eos_token
llama_model = UnmaskingLlamaForSequenceClassification.from_pretrained(
    LLAMA_PATH,
    num_labels=17,
    device_map="auto",
    torch_dtype=torch.bfloat16
)
llama_model.eval()
llama_model.config.pad_token_id = llama_tokenizer.pad_token_id

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

Some weights of UnmaskingLlamaForSequenceClassification were not initialized from the model checkpoint at meta-llama/Meta-Llama-3-8B and are newly initialized: ['score.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
tokenized_sample = llama_tokenizer(sample_sentence, return_tensors="pt")
token_ids = tokenized_sample["input_ids"]

llama_out = llama_model(token_ids)
print(torch.max(torch.softmax(llama_out.logits, dim=-1), dim=-1))

torch.return_types.max(
values=tensor([0.8867], dtype=torch.bfloat16),
indices=tensor([12]))


In [11]:
# Evaluate LLaMA - can't use task evaluator because it doesn't support accelerate which is required for inference larger models
# https://github.com/huggingface/evaluate/issues/487

# tokenize the dataset first
llama_tokenized_dataset = dataset.map(
    preprocess_data(llama_tokenizer, padding="longest", max_length=1024, include_labels=False), batched=True, remove_columns=dataset["train"].column_names
)
llama_tokenized_dataset.set_format("torch")

# llama_out_logits = torch.tensor([])
# for batch in llama_tokenized_dataset["test"]:
#     llama_out_logits = torch.cat(
#         (
#             llama_out_logits,
#             llama_model(
#                 input_ids=batch["input_ids"].to("cuda").unsqueeze(0),
#                 attention_mask=batch["attention_mask"].to("cuda").unsqueeze(0)
#             ).logits.detach().cpu()
#         )
#     )

# need to split the input_ids tensor into two tensors to avoid CUDA out of memory error
# out = llama_model(**llama_tokenized_dataset["test"][:128])
# out2 = llama_model(**llama_tokenized_dataset["test"][128:])
# llama_out_logits = torch.cat((out.logits, out2.logits), dim=0)

# Batch size 32 to avoid CUDA out of memory error
llama_out_logits = torch.tensor([])
batch_size = 64
for i in range(0, len(llama_tokenized_dataset["test"]), batch_size):
    batch = llama_tokenized_dataset["test"][i:i+batch_size]
    out = llama_model(**batch)
    llama_out_logits = torch.cat((llama_out_logits, out.logits), dim=0)

  StockPickler.save(self, obj, save_persistent_id)
  StockPickler.save(self, obj, save_persistent_id)


In [12]:
llama_pred_probs, llama_preds = torch.max(torch.softmax(llama_out_logits, dim=-1), dim=-1)

llama_accuracy = accuracy_score(y_true=dataset["test"]["SDG"], y_pred=llama_preds)
llama_f1 = f1_score(y_true=dataset["test"]["SDG"], y_pred=llama_preds, average="weighted")
llama_precision = precision_score(y_true=dataset["test"]["SDG"], y_pred=llama_preds, average="weighted")
llama_recall = recall_score(y_true=dataset["test"]["SDG"], y_pred=llama_preds, average="weighted")

pprint.pprint({
    "accuracy": llama_accuracy,
    "precision": llama_precision,
    "recall": llama_recall,
    "f1": llama_f1
})

{'accuracy': 0.7564575645756457,
 'f1': 0.7488491461489248,
 'precision': 0.758700999587592,
 'recall': 0.7564575645756457}


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [13]:
from sklearn.metrics import classification_report

print(classification_report(
    y_true=dataset["test"]["SDG"],
    y_pred=llama_preds,
    target_names=[f"SDG {id2label[i]}" for i in range(len(labels))]
))

              precision    recall  f1-score   support

       SDG 1       0.71      0.71      0.71        17
      SDG 10       0.50      0.41      0.45        17
      SDG 11       0.87      0.76      0.81        17
      SDG 12       0.53      0.59      0.56        17
      SDG 13       0.83      0.88      0.86        17
      SDG 14       1.00      1.00      1.00        17
      SDG 15       0.84      0.94      0.89        17
      SDG 16       0.64      0.53      0.58        17
      SDG 17       0.00      0.00      0.00         1
       SDG 2       0.72      0.81      0.76        16
       SDG 3       0.75      0.88      0.81        17
       SDG 4       0.84      0.94      0.89        17
       SDG 5       0.74      1.00      0.85        17
       SDG 6       0.85      1.00      0.92        17
       SDG 7       1.00      0.59      0.74        17
       SDG 8       0.53      0.56      0.55        16
       SDG 9       0.82      0.53      0.64        17

    accuracy              

  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


In [14]:
# llama-3
{'accuracy': 0.7269372693726938,
 'f1': 0.7303868168770821,
 'precision': 0.7581306326098612,
 'recall': 0.7269372693726938}
{'accuracy': 0.7158671586715867,
 'f1': 0.7183098469704836,
 'precision': 0.7475784117481534,
 'recall': 0.7158671586715867}
# eval + 128batch:
{'accuracy': 0.7195571955719557,
 'f1': 0.7218231123589773,
 'precision': 0.7505385512396582,
 'recall': 0.7195571955719557}
# no eval + 128batch:
{'accuracy': 0.7269372693726938,
 'f1': 0.7303868168770821,
 'precision': 0.7581306326098612,
 'recall': 0.7269372693726938}

# llama-3 unmasked
{'accuracy': 0.7564575645756457,
 'f1': 0.7485577477676632,
 'precision': 0.7555064756637431,
 'recall': 0.7564575645756457}

{'accuracy': 0.7564575645756457,
 'f1': 0.7485577477676632,
 'precision': 0.7555064756637431,
 'recall': 0.7564575645756457}