<a href="https://colab.research.google.com/github/danielsaggau/IR_LDC/blob/main/model/SCOTUS/scotus_pertrain_scotus_classification.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers datasets
import torch as nn

In [2]:
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    EvalPrediction,
    HfArgumentParser,
    TrainingArguments,
    default_data_collator,
    set_seed,
    EarlyStoppingCallback,
    Trainer
)

In [3]:
from transformers import TrainerCallback 
from datasets import load_metric
import numpy as np
import torch as nn

In [None]:
from datasets import load_dataset

# Fetch the "scotus" configuration of the "lex_glue" dataset; the result is
# indexed by split name further down ("train", "validation", "test").
dataset = load_dataset(path="lex_glue", name="scotus")

In [5]:
# SECURITY: an access token used to be hardcoded in this cell. A token that has
# ever been saved in a notebook must be treated as leaked and revoked at
# https://huggingface.co/settings/tokens. The token is now taken from the
# HF_TOKEN environment variable, or prompted for without echo, so it never
# lands in the notebook source or its outputs.
import os
from getpass import getpass

from huggingface_hub.hf_api import HfFolder

HfFolder.save_token(os.environ.get("HF_TOKEN") or getpass("Hugging Face token: "))

In [None]:
# Tokenizer belonging to the checkpoint that is fine-tuned below.
# use_auth_token=True makes the download use the locally stored Hub token.
tokenizer = AutoTokenizer.from_pretrained(
    'danielsaggau/longformer_simcse_scotus',
    use_auth_token=True,
    use_fast=True,
)

# `model` is deliberately not loaded here: it is instantiated from this same
# checkpoint, with identical arguments, in the cell right before the custom
# pooler is attached, and no cell in between uses it. Loading it twice only
# costs time and memory.

In [7]:
def preprocess_function(examples):
    """Tokenize the ``"text"`` field of a batch of examples.

    Truncation is enabled without an explicit ``max_length``; no padding is
    requested here.

    Args:
        examples: Mapping that provides the raw documents under ``"text"``.

    Returns:
        Whatever the module-level ``tokenizer`` returns for those documents.
    """
    documents = examples["text"]
    encoded = tokenizer(documents, truncation=True)
    return encoded

In [8]:
# Apply the tokenizer to the whole dataset; batched=True hands
# preprocess_function batches of examples instead of one example at a time.
tokenized_data = dataset.map(
    function=preprocess_function,
    batched=True,
)

  0%|          | 0/5 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

In [16]:
# Load the metric implementations once, when this cell runs, instead of
# re-loading them on every evaluation pass.
f1_metric = load_metric("f1")
accuracy_metric = load_metric("accuracy")


def compute_metrics(eval_pred):
    """Compute micro-F1, macro-F1 and accuracy for one evaluation pass.

    Args:
        eval_pred: Pair that unpacks into ``(logits, labels)``. ``logits`` is
            arg-maxed over its last axis to obtain the predicted class ids.

    Returns:
        dict with the keys ``"f1-micro"``, ``"f1-macro"`` and ``"accuracy"``.
        ``"f1-micro"`` is the key referenced by ``metric_for_best_model``.
    """
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)

    micro_f1 = f1_metric.compute(
        predictions=predictions, references=labels, average="micro"
    )["f1"]
    macro_f1 = f1_metric.compute(
        predictions=predictions, references=labels, average="macro"
    )["f1"]
    # Distinct names for the metric object and its value (the original reused
    # `accuracy` for both).
    accuracy_value = accuracy_metric.compute(
        references=labels, predictions=predictions
    )['accuracy']

    return {"f1-micro": micro_f1, "f1-macro": macro_f1, "accuracy": accuracy_value}

In [None]:
# Fine-tuning configuration, grouped by concern. All values are unchanged.
training_args = TrainingArguments(
    # --- where results go ---
    output_dir="/scotus_experiments_MEAN_POOL",
    push_to_hub=True,
    # --- optimisation ---
    learning_rate=3e-5,
    weight_decay=0.01,
    lr_scheduler_type="cosine",
    # NOTE(review): 3e-5 equals the learning rate above and yields practically
    # no warmup; confirm whether a larger ratio was intended.
    warmup_ratio=3e-5,
    num_train_epochs=20,
    # --- batching / precision ---
    per_device_train_batch_size=6,
    per_device_eval_batch_size=6,
    gradient_accumulation_steps=1,
    fp16=True,
    # --- evaluation, checkpointing and model selection (once per epoch) ---
    evaluation_strategy="epoch",
    save_strategy="epoch",
    metric_for_best_model="f1-micro",
    greater_is_better=True,
    load_best_model_at_end=True,
)

In [None]:
from google.colab import drive

# Mount Google Drive inside the Colab runtime.
# NOTE(review): no other cell in this notebook reads from or writes to the
# mount point; confirm whether this cell is still needed.
DRIVE_MOUNT_POINT = '/content/drive'
drive.mount(DRIVE_MOUNT_POINT)

In [14]:
# Pads each batch on the fly; padded lengths are rounded up to a multiple of 8,
# chosen with fp16 training in mind.
data_collator = DataCollatorWithPadding(
    tokenizer=tokenizer,
    pad_to_multiple_of=8,
)

In [None]:
# Instantiate the classifier that gets fine-tuned: the
# 'danielsaggau/longformer_simcse_scotus' checkpoint with a 14-label head.
# use_auth_token=True makes the download use the locally stored Hub token.
model = AutoModelForSequenceClassification.from_pretrained(
    'danielsaggau/longformer_simcse_scotus',
    num_labels=14,
    use_auth_token=True,
)

In [41]:
# `torch` must be bound before the class body runs, because the `torch.Tensor`
# annotations on `forward` are evaluated at definition time.
import torch
from torch import nn


class CustomLongformerPooler(nn.Module):
    """Mean-pooling pooler: average all token states, then Linear + Tanh.

    Args:
        config: Object exposing ``hidden_size``; it sets both the input and
            output width of the dense projection.
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Pool a sequence of hidden states into one vector per sequence.

        Args:
            hidden_states: Tensor whose dim 1 is averaged away and whose last
                dim is ``hidden_size``, i.e. ``(batch, seq_len, hidden_size)``.

        Returns:
            Tensor of shape ``(batch, hidden_size)`` with values in ``[-1, 1]``.
        """
        # We "pool" the model by averaging the hidden states over the sequence
        # dimension (not by taking the first token). Every position counts
        # equally, padding included, because no attention mask is available here.
        mean_token_tensor = hidden_states.mean(dim=1)
        pooled_output = self.dense(mean_token_tensor)
        pooled_output = self.activation(pooled_output)
        return pooled_output

In [42]:
# Attach the mean-pooling module as the encoder's pooler.
# NOTE(review): confirm that the sequence-classification head actually consumes
# the pooler output; if the head reads the first-token hidden state directly,
# this assignment does not change the logits.
mean_pooler = CustomLongformerPooler(model.config)
model.longformer.pooler = mean_pooler

In [44]:
import torch

# Ask PyTorch to release the cached GPU memory it is not currently using
# before the training run starts.
torch.cuda.empty_cache()

In [45]:
# Per-epoch evaluation drives early stopping and best-checkpoint selection, so
# it must not see the test split. Evaluate on "validation" during training and
# reserve tokenized_data["test"] for a single final
# trainer.evaluate(eval_dataset=tokenized_data["test"]) once training is done.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_data["train"],
    eval_dataset=tokenized_data["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    # Stop once the metric selected by metric_for_best_model has failed to
    # improve for 5 consecutive evaluations.
    callbacks=[EarlyStoppingCallback(early_stopping_patience=5)],
)
trainer.train()

/scotus_experiments_MEAN_POOL is already a clone of https://huggingface.co/danielsaggau/scotus_experiments_MEAN_POOL. Make sure you pull the latest changes with `repo.git_pull()`.
Using cuda_amp half precision backend
The following columns in the training set don't have a corresponding argument in `LongformerForSequenceClassification.forward` and have been ignored: text. If text are not expected by `LongformerForSequenceClassification.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 5000
  Num Epochs = 20
  Instantaneous batch size per device = 6
  Total train batch size (w. parallel, distributed & accumulation) = 6
  Gradient Accumulation steps = 1
  Total optimization steps = 16680
  Number of trainable parameters = 41902094
Initializing global attention on CLS token...


RuntimeError: ignored

In [31]:
# Score the current model on the validation split; the metrics dict returned
# by evaluate() is displayed as the cell output.
eval_dataset = tokenized_data['validation']
trainer.evaluate(
    eval_dataset=eval_dataset,
)

The following columns in the evaluation set don't have a corresponding argument in `LongformerForSequenceClassification.forward` and have been ignored: text. If text are not expected by `LongformerForSequenceClassification.forward`,  you can safely ignore this message.
***** Running Evaluation *****
  Num examples = 1400
  Batch size = 6
Initializing global attention on CLS token...


Initializing global attention on CLS token...
Initializing global attention on CLS token...
Initializing global attention on CLS token...
Initializing global attention on CLS token...
Initializing global attention on CLS token...
Initializing global attention on CLS token...
Initializing global attention on CLS token...
Initializing global attention on CLS token...
Initializing global attention on CLS token...
Initializing global attention on CLS token...
Initializing global attention on CLS token...
Initializing global attention on CLS token...
Initializing global attention on CLS token...
Initializing global attention on CLS token...
Initializing global attention on CLS token...
Initializing global attention on CLS token...
Initializing global attention on CLS token...
Initializing global attention on CLS token...
Initializing global attention on CLS token...
Initializing global attention on CLS token...
Initializing global attention on CLS token...
Initializing global attention on C

{'eval_loss': 1.3238890171051025,
 'eval_f1-micro': 0.7707142857142857,
 'eval_f1-macro': 0.6967766365477588,
 'eval_accuracy': 0.7707142857142857,
 'eval_runtime': 64.0517,
 'eval_samples_per_second': 21.857,
 'eval_steps_per_second': 3.653,
 'epoch': 10.0}