In [None]:
! pip install datasets transformers[torch] evaluate wandb
import evaluate
from datasets import load_dataset, load_metric
import pandas as pd
from transformers import AutoTokenizer
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
from transformers import AutoConfig
from glob import glob
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

import wandb
wandb.login(key="...")
%env WANDB_LOG_MODEL=true
%env WANDB_PROJECT=...


In [2]:
import torch
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
from transformers import BertPreTrainedModel, BertModel
from transformers.modeling_outputs import SequenceClassifierOutput
from typing import List, Optional, Tuple, Union
from dataclasses import dataclass

class RobertaFCClassificationHead(torch.nn.Module):
    """Classification head that mean-pools its input over dim 1.

    Pipeline: mean over dim 1 -> dropout -> dense -> tanh -> dropout -> out_proj.
    """

    def __init__(self, config):
        """Create the layers from `config`.

        Reads `hidden_size`, `num_labels` and `classifier_dropout`; when
        `classifier_dropout` is None, `hidden_dropout_prob` is used instead.
        """
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        drop_p = config.classifier_dropout
        if drop_p is None:
            drop_p = config.hidden_dropout_prob
        self.dropout = nn.Dropout(drop_p)
        self.out_proj = nn.Linear(width, config.num_labels)

    def forward(self, features, **kwargs):
        """Return the output of `out_proj` for the pooled `features`.

        Extra keyword arguments are accepted and ignored.
        """
        # Averages over dim 1 rather than selecting a single position.
        # NOTE(review): assuming dim 1 is the token axis, padded positions
        # are included in the average -- confirm this is intended.
        pooled = torch.mean(features, 1)
        activated = torch.tanh(self.dense(self.dropout(pooled)))
        return self.out_proj(self.dropout(activated))

In [11]:
GLUE_TASKS = ["cola", "mnli", "mnli-mm", "mrpc", "qnli", "qqp", "rte", "sst2", "stsb", "wnli"]

# --- Experiment configuration -------------------------------------------
task = "sst2"
fc_head = False
model_checkpoint = "roberta-base"
batch_size = 16
shard = 1

# Parameter groups to freeze, given as keys of `freeze_map`,
# e.g. ["head", "encoder8", "encoder9", "encoder10"].
freeze = []

# Group name -> parameter-name prefix.
freeze_map = {"head": "classifier", "all_encoder": "roberta.encoder"}
freeze_map.update(
    {"encoder" + str(i): "roberta.encoder.layer." + str(i) for i in range(14)}
)

if fc_head:
    head_str = "FC"
else:
    head_str = "CLS"
print(f"CHECKPOINT: {model_checkpoint}\nTASK: {task}\nFREEZE: {freeze}\nHEAD: {head_str}\nSHARD: {shard}")

CHECKPOINT: roberta-base
TASK: sst2
FREEZE: []
HEAD: CLS
SHARD: 1


In [12]:
import evaluate
# "mnli-mm" is not a separate GLUE config here: it loads the "mnli" config.
actual_task = "mnli" if task == "mnli-mm" else task
dataset = load_dataset("glue", actual_task)
# `datasets.load_metric` is deprecated and absent from newer `datasets`
# releases; `evaluate.load` is its replacement.
metric = evaluate.load('glue', actual_task)

Downloading data:   0%|          | 0.00/7.44M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/67349 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/872 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1821 [00:00<?, ? examples/s]

In [13]:
# Task name -> (first text column, second text column or None for
# single-sentence tasks).
task_to_keys = {
    name: (first_column, second_column)
    for name, first_column, second_column in [
        ("cola", "sentence", None),
        ("mnli", "premise", "hypothesis"),
        ("mnli-mm", "premise", "hypothesis"),
        ("mrpc", "sentence1", "sentence2"),
        ("qnli", "question", "sentence"),
        ("qqp", "question1", "question2"),
        ("rte", "sentence1", "sentence2"),
        ("sst2", "sentence", None),
        ("stsb", "sentence1", "sentence2"),
        ("wnli", "sentence1", "sentence2"),
        ("imdb", "text", None),
        ("yahoo_answers_topics", "question_title", "question_content"),
    ]
}

In [14]:
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)

# Text column names for the selected task; the second is None for
# single-sentence tasks.
sentence1_key, sentence2_key = task_to_keys[task]
def preprocess_function(examples):
    """Tokenize the task's text column(s) of `examples` with truncation.

    Passes one text argument to the tokenizer when `sentence2_key` is None,
    otherwise passes both columns as a pair.
    """
    text_columns = [examples[sentence1_key]]
    if sentence2_key is not None:
        text_columns.append(examples[sentence2_key])
    return tokenizer(*text_columns, truncation=True)

# Fixed seed: the shuffled order decides which examples survive the `shard`
# sub-sampling below, so an unseeded shuffle made runs non-reproducible.
encoded_dataset = dataset.map(preprocess_function, batched=True).shuffle(seed=42)

# Name of the split used for evaluation; any task not listed uses "validation".
validation_key = {
    "mnli-mm": "validation_mismatched",
    "mnli": "validation_matched",
    "imdb": "unsupervised",
}.get(task, "validation")

# Keep the first 1/shard of each (shuffled) split; shard == 1 keeps everything.
train = encoded_dataset["train"].select(range(len(encoded_dataset["train"]) // shard))
validation = encoded_dataset[validation_key].select(range(len(encoded_dataset[validation_key]) // shard))

Map:   0%|          | 0/67349 [00:00<?, ? examples/s]

Map:   0%|          | 0/872 [00:00<?, ? examples/s]

Map:   0%|          | 0/1821 [00:00<?, ? examples/s]

In [19]:
# Output size: 3 for the MNLI variants, 1 for STS-B, 2 otherwise.
num_labels = 3 if task.startswith("mnli") else 1 if task=="stsb" else 2
# output_hidden_states=True makes the model return per-layer hidden states,
# which the NCC metrics consume alongside the logits.
config = AutoConfig.from_pretrained(model_checkpoint, num_labels=num_labels, output_hidden_states=True)

if "roberta" in model_checkpoint or not fc_head:
  model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, config=config, ignore_mismatched_sizes=True)
elif "bert" in model_checkpoint:
  # Reaching this branch already implies fc_head is True.
  # NOTE(review): BertForSequenceClassification is not imported or defined in
  # the cells visible here -- confirm where it comes from.
  model = BertForSequenceClassification.from_pretrained(model_checkpoint, config=config, ignore_mismatched_sizes=True)
else:
  # Previously `model` was silently left undefined here and the failure
  # surfaced later as an unrelated NameError.
  raise ValueError("fc_head=True is not supported for checkpoint %r" % model_checkpoint)

if fc_head and "roberta" in model_checkpoint:
  # Swap the stock classification head for the mean-pooling FC head.
  model.classifier = RobertaFCClassificationHead(config)
  # NOTE(review): post_init() is invoked on the whole model, not only the new
  # head -- confirm it leaves the loaded encoder weights untouched in the
  # installed transformers version.
  model.post_init()

# Disable gradients for every parameter that belongs to a group listed in
# `freeze`. Matching is done on whole dotted path components: a bare
# startswith() would make "roberta.encoder.layer.1" also match layers 10, 11, ...
for name, param in model.named_parameters():
  for freeze_param in freeze:
    prefix = freeze_map[freeze_param]
    if name == prefix or name.startswith(prefix + "."):
      param.requires_grad = False
      print("%s frozen." % name)
      # One match is enough; stop checking the remaining groups.
      break

model

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.out_proj.bias', 'classifier.dense.weight', 'classifier.out_proj.weight', 'classifier.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


RobertaForSequenceClassification(
  (roberta): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(50265, 768, padding_idx=1)
      (position_embeddings): Embedding(514, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0-11): 12 x RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
             

In [16]:
# Metric used to pick the best checkpoint.
if task == "stsb":
    metric_name = "pearson"
elif task == "cola":
    metric_name = "matthews_correlation"
else:
    metric_name = "accuracy"

# Run name: <checkpoint basename>-<head>[-frozen-<groups>]-<task>-<index>.
model_name = model_checkpoint.split("/")[-1]
head_name = "fc-head" if fc_head else "cls-head"
if len(freeze) > 0:
    head_name = head_name + ("-frozen-"  + "_".join(freeze))
run_name = f"{model_name}-{head_name}-{task}"
basedir = "models/%s" % run_name
# Next index = 1 + the largest numeric suffix among existing "<basedir>-*" paths.
idx = 1 + max(
    (int(path.split('-')[-1]) for path in glob(basedir + "-*")),
    default=0,
)
output_dir = basedir + "-%i" % idx
run_name = f"{run_name}-{idx}"

print("RUN NAME: %s" % run_name)

args = TrainingArguments(
    # Where results go.
    output_dir = output_dir,
    run_name = run_name,
    report_to = 'wandb',
    # Evaluate and checkpoint once per epoch.
    # NOTE(review): eval_steps/save_steps are documented for the "steps"
    # strategy; with "epoch" they look unused -- confirm before removing.
    evaluation_strategy = "epoch",
    save_strategy = "epoch",
    eval_steps=200,
    save_steps=200,
    # Optimisation.
    num_train_epochs=4,
    learning_rate=2e-5,
    warmup_steps=200,
    weight_decay=0.01,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # Checkpoint selection.
    load_best_model_at_end=True,
    metric_for_best_model=metric_name,
    # Training-loss logging.
    logging_strategy="steps",
    logging_steps=200,
)

def compute_ncc(preds, hidden_states, labels, mode="full"):
  """Nearest-class-center statistics for each layer and for `preds`.

  Every array in `[*hidden_states, preds]` is reduced to one feature vector
  per sample; class means are computed from `labels`, and each sample is
  assigned to the class whose mean is closest in Euclidean distance.

  Args:
    preds: array whose first axis indexes samples; processed last.
    hidden_states: iterable of arrays whose first axis indexes samples.
    labels: 1-D array of class labels, one per sample.
    mode: how 3-D arrays are reduced:
      "cls"  - keep only index 0 along axis 1,
      "mean" - average over axis 1,
      anything else ("full") - flatten all trailing axes.
      Arrays that are not 3-D are always flattened.

  Returns:
    A list with one tuple per processed array:
      ncc                          - fraction of samples whose nearest class
                                     mean is NOT their own label (a mismatch
                                     rate, so lower is better),
      class_distance               - distance between the means of the first
                                     two classes (NaN when fewer than two
                                     classes occur in `labels`),
      class_variability            - per class, the median over ALL samples of
                                     the distance to that class mean,
      class_distance_normalized    - class_distance / norm of each class mean,
      class_variability_normalized - class_variability / norm of each class mean.
  """
  classes = np.unique(labels)
  nccs = []
  for layer in [*hidden_states, preds]:
    if mode == "cls" and len(layer.shape) == 3:
      layer_flattened = layer[:,:1,:].reshape(layer.shape[0], -1)
    elif mode == "mean" and len(layer.shape) == 3:
      layer_flattened = layer.mean(axis=1)
    else:
      layer_flattened = layer.reshape(layer.shape[0], -1)

    class_means = np.array([layer_flattened[labels == c].mean(axis=0) for c in classes])
    # One class at a time: avoids building the full
    # (samples, features, classes) difference tensor at once.
    distances = np.array(
        [np.linalg.norm(layer_flattened - class_mean, axis=1) for class_mean in class_means]
    ).T
    nearest_class = classes[distances.argmin(axis=1)]

    ncc = (nearest_class != labels).sum() / layer_flattened.shape[0]
    if len(classes) >= 2:
      class_distance = np.linalg.norm(class_means[0]-class_means[1])
    else:
      # A between-class distance is undefined with a single class.
      class_distance = np.nan
    class_variability = np.median(distances, axis=0)
    class_mean_norms = np.linalg.norm(class_means, axis=1)
    class_distance_normalized = class_distance / class_mean_norms
    class_variability_normalized = class_variability / class_mean_norms
    nccs.append((ncc, class_distance, class_variability, class_distance_normalized, class_variability_normalized))

  return nccs

def compute_metrics(eval_pred, specific_nccs=(), debug=False):
    """Task metric plus per-layer nearest-class-center metrics.

    Args:
      eval_pred: pair `(predictions, labels)`. `predictions` must have exactly
        two entries: the model outputs used for the task metric, and the
        sequence of per-layer activations.
      specific_nccs: extra `compute_ncc` modes to report besides "full".
        Any iterable of mode names; a tuple default avoids a shared mutable
        default argument.
      debug: when True, also report per-label center distance, variability
        and separability.

    Returns:
      The dict returned by `metric.compute`, extended with
      "layer_<i>_ncc_<mode>" entries (i starts at 1; the last index refers to
      the model outputs themselves) and, if `debug`, per-label entries.
    """
    predictions, labels = eval_pred
    assert len(predictions) == 2, "expected predictions to be (outputs, activations)"
    last_pred = predictions[0]
    activations = predictions[1]

    if task != "stsb":
        # Classification: the predicted class is the arg-max output.
        last_pred = np.argmax(last_pred, axis=1)
    else:
        # STS-B: use the single output column as the predicted value.
        last_pred = last_pred[:, 0]
    metrics = metric.compute(predictions=last_pred, references=labels)
    nccs = {}
    nccs["full"] = compute_ncc(predictions[0], activations, labels, mode="full")
    for mode in specific_nccs:
      # Skip modes already computed (e.g. a repeated name or "full").
      if mode not in nccs:
        nccs[mode] = compute_ncc(predictions[0], activations, labels, mode=mode)

    for t in ["full", *specific_nccs]:
      for i, row in enumerate(nccs[t]):
        ncc, center_dist, var, dist_normalized, var_normalized = row
        k = "layer_%i_" % (i+1)
        metrics[k + "ncc_" + t] = ncc
        if debug:
          # NOTE(review): these keys carry no mode tag, so with several modes
          # the last one processed overwrites the others. Indexing by `label`
          # also assumes labels are the integers 0..C-1 -- confirm.
          for label in np.unique(labels):
            suffix = "_label_%s" % label
            metrics[k + "centers_dist" + suffix] = dist_normalized[label]
            metrics[k + "var" + suffix] = var_normalized[label]
            metrics[k + "separability" + suffix] = (center_dist / var)[label]

    return metrics

RUN NAME: roberta-base-cls-head-sst2-1


In [20]:
def metrics(pred):
    """Metric callback for the Trainer: `compute_metrics` with the extra "cls" NCC mode."""
    return compute_metrics(pred, specific_nccs=["cls"])

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train,
    eval_dataset=validation,
    tokenizer=tokenizer,
    compute_metrics=metrics,
)

In [21]:
# Evaluate the model held by `trainer` on the evaluation set.
# NOTE(review): execution counts in the saved notebook are out of order;
# restart and run all cells to refresh the outputs.
trainer.evaluate()

{'eval_loss': 0.7033823132514954,
 'eval_accuracy': 0.5091743119266054,
 'eval_layer_1_ncc_full': 0.4908256880733945,
 'eval_layer_2_ncc_full': 0.4908256880733945,
 'eval_layer_3_ncc_full': 0.4908256880733945,
 'eval_layer_4_ncc_full': 0.4908256880733945,
 'eval_layer_5_ncc_full': 0.4908256880733945,
 'eval_layer_6_ncc_full': 0.4908256880733945,
 'eval_layer_7_ncc_full': 0.4908256880733945,
 'eval_layer_8_ncc_full': 0.4908256880733945,
 'eval_layer_9_ncc_full': 0.4908256880733945,
 'eval_layer_10_ncc_full': 0.4908256880733945,
 'eval_layer_11_ncc_full': 0.4908256880733945,
 'eval_layer_12_ncc_full': 0.4908256880733945,
 'eval_layer_13_ncc_full': 0.4908256880733945,
 'eval_layer_14_ncc_full': 0.46674311926605505,
 'eval_layer_1_ncc_cls': 0.5091743119266054,
 'eval_layer_2_ncc_cls': 0.34288990825688076,
 'eval_layer_3_ncc_cls': 0.36353211009174313,
 'eval_layer_4_ncc_cls': 0.36123853211009177,
 'eval_layer_5_ncc_cls': 0.3841743119266055,
 'eval_layer_6_ncc_cls': 0.2775229357798165,
 'eva

In [17]:
# Run training with the arguments configured in `args`.
trainer.train()

Epoch,Training Loss,Validation Loss,Accuracy,Layer 1 Ncc Full,Layer 2 Ncc Full,Layer 3 Ncc Full,Layer 4 Ncc Full,Layer 5 Ncc Full,Layer 6 Ncc Full,Layer 7 Ncc Full,Layer 8 Ncc Full,Layer 9 Ncc Full,Layer 10 Ncc Full,Layer 11 Ncc Full,Layer 12 Ncc Full,Layer 13 Ncc Full,Layer 14 Ncc Full,Layer 1 Ncc Cls,Layer 2 Ncc Cls,Layer 3 Ncc Cls,Layer 4 Ncc Cls,Layer 5 Ncc Cls,Layer 6 Ncc Cls,Layer 7 Ncc Cls,Layer 8 Ncc Cls,Layer 9 Ncc Cls,Layer 10 Ncc Cls,Layer 11 Ncc Cls,Layer 12 Ncc Cls,Layer 13 Ncc Cls,Layer 14 Ncc Cls
1,No log,0.695193,0.472924,0.425993,0.425993,0.425993,0.425993,0.425993,0.425993,0.425993,0.425993,0.425993,0.425993,0.425993,0.425993,0.425993,0.440433,0.527076,0.34296,0.34296,0.33213,0.32852,0.34657,0.368231,0.397112,0.415162,0.436823,0.422383,0.393502,0.306859,0.440433
2,0.698600,0.606147,0.68231,0.425993,0.425993,0.425993,0.425993,0.425993,0.425993,0.425993,0.425993,0.425993,0.425993,0.425993,0.425993,0.425993,0.299639,0.472924,0.32852,0.375451,0.33574,0.382671,0.32491,0.350181,0.256318,0.299639,0.33213,0.3213,0.3213,0.303249,0.299639
3,0.604600,0.632271,0.685921,0.425993,0.425993,0.425993,0.425993,0.425993,0.425993,0.425993,0.425993,0.425993,0.425993,0.425993,0.425993,0.425993,0.277978,0.527076,0.33574,0.361011,0.33213,0.32491,0.31769,0.33213,0.32491,0.32852,0.299639,0.288809,0.277978,0.270758,0.277978
4,0.380400,0.701772,0.718412,0.425993,0.425993,0.425993,0.425993,0.425993,0.425993,0.425993,0.425993,0.425993,0.425993,0.425993,0.425993,0.425993,0.259928,0.472924,0.33574,0.357401,0.33574,0.33935,0.32852,0.34296,0.33935,0.350181,0.267148,0.256318,0.267148,0.259928,0.259928


TrainOutput(global_step=624, training_loss=0.5513087847293952, metrics={'train_runtime': 362.0362, 'train_samples_per_second': 27.511, 'train_steps_per_second': 1.724, 'total_flos': 839868988705200.0, 'train_loss': 0.5513087847293952, 'epoch': 4.0})

In [12]:
# Close the active W&B run.
wandb.finish()

VBox(children=(Label(value='478.771 MB of 478.771 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0,…

0,1
eval/accuracy,▁████
eval/layer_10_ncc_cls,█▃▂▁▁
eval/layer_10_ncc_full,▁▁▁▁▁
eval/layer_11_ncc_cls,█▃▂▁▁
eval/layer_11_ncc_full,▁▁▁▁▁
eval/layer_12_ncc_cls,█▂▂▁▁
eval/layer_12_ncc_full,▁▁▁▁▁
eval/layer_13_ncc_cls,█▃▃▁▁
eval/layer_13_ncc_full,▁▁▁▁▁
eval/layer_14_ncc_cls,█▁▁▁▁

0,1
eval/accuracy,0.94828
eval/layer_10_ncc_cls,0.05517
eval/layer_10_ncc_full,0.48276
eval/layer_11_ncc_cls,0.04483
eval/layer_11_ncc_full,0.48276
eval/layer_12_ncc_cls,0.05172
eval/layer_12_ncc_full,0.48276
eval/layer_13_ncc_cls,0.05172
eval/layer_13_ncc_full,0.48276
eval/layer_14_ncc_cls,0.05172
