<a href="https://colab.research.google.com/github/danielsaggau/IR_LDC/blob/main/model/SCOTUS/freezing_CLS.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers datasets
from datasets import load_dataset

In [None]:
# Load the "scotus" configuration of the "lex_glue" dataset.
dataset = load_dataset(path="lex_glue", name="scotus")

In [None]:
from getpass import getpass

from huggingface_hub.hf_api import HfFolder
from transformers import AutoTokenizer

# Ask for the Hub access token interactively so it is never written into the
# notebook or exposed on a shell command line. Saving it through HfFolder keeps
# `use_auth_token=True` working here and in the later cells that rely on the
# cached token.
HfFolder.save_token(getpass("Hugging Face access token: "))

tokenizer = AutoTokenizer.from_pretrained(
    "danielsaggau/longformer_simcse_scotus",
    use_auth_token=True,
    use_fast=True,
)

In [None]:
from transformers import AutoModelForSequenceClassification

# Load the checkpoint with a sequence-classification head sized for 14 labels.
# `use_auth_token=True` makes the download use the locally stored Hub token.
model_1 = AutoModelForSequenceClassification.from_pretrained(
    "danielsaggau/longformer_simcse_scotus",
    num_labels=14,
    use_auth_token=True,
)

In [5]:
import torch 
from torch import nn

In [6]:
# Adapted from transformers.models.bert.modeling_bert.BertPooler: same dense + tanh
# projection, but the input is the mean over the token axis instead of the first token.
class CustomLongformerPooler(nn.Module):
    """Mean-pooling pooler: average the token states, then apply dense + tanh.

    Args:
        config: Any object exposing ``hidden_size``; it sets both the input and
            the output width of the dense projection.
    """

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.activation = nn.Tanh()

    def forward(self, hidden_states: torch.Tensor, attention_mask=None) -> torch.Tensor:
        """Pool ``hidden_states`` over dim 1 and project the result.

        Args:
            hidden_states: Tensor whose dim 1 is the token axis and whose last
                dim is ``hidden_size``, i.e. ``(batch, seq_len, hidden_size)``.
            attention_mask: Optional ``(batch, seq_len)`` tensor. Positions with
                a value > 0 count as real tokens; all others (padding) are left
                out of the average. When omitted, every position is averaged,
                which is the behaviour callers passing a single argument get.

        Returns:
            Tensor of shape ``(batch, hidden_size)``.
        """
        if attention_mask is None:
            mean_token_tensor = hidden_states.mean(dim=1)
        else:
            token_weights = (attention_mask > 0).unsqueeze(-1).to(
                device=hidden_states.device, dtype=hidden_states.dtype
            )
            # clamp avoids a division by zero for rows that are entirely masked
            token_counts = token_weights.sum(dim=1).clamp(min=1.0)
            mean_token_tensor = (hidden_states * token_weights).sum(dim=1) / token_counts
        pooled_output = self.dense(mean_token_tensor)
        return self.activation(pooled_output)

In [None]:
# Attach a mean-pooling pooler to the encoder in place of first-token (CLS) pooling.
# NOTE(review): confirm that the sequence-classification head actually consumes
# `longformer.pooler`'s output; if the head reads the encoder's sequence output
# directly, this assignment does not change the logits -- TODO verify against
# the model implementation.
mean_pooler = CustomLongformerPooler(model_1.config)
model_1.longformer.pooler = mean_pooler

In [None]:
# Display the model's module tree (e.g. to confirm the pooler assignment).
model_1

In [20]:
# Freeze the pretrained encoder but keep the freshly initialised pooler trainable:
# its parameters also live under "longformer." and would otherwise stay frozen at
# their random initial values. Everything outside "longformer." remains trainable.
for name, param in model_1.named_parameters():
    is_encoder_weight = name.startswith("longformer.") and not name.startswith("longformer.pooler.")
    if is_encoder_weight:
        param.requires_grad = False

In [None]:
# List every parameter with its requires_grad flag to see what will be trained;
# in particular, check whether the pooler's parameters are set to be updated.
for param_name, parameter in model_1.named_parameters():
    print(param_name, parameter.requires_grad)

In [19]:
def preprocess_function(examples):
    """Tokenize the ``"text"`` field of ``examples`` with the notebook-level ``tokenizer``.

    Truncation is enabled; no padding is requested at this stage.
    """
    texts = examples["text"]
    return tokenizer(texts, truncation=True)

In [22]:
# Apply preprocess_function to the dataset in batched mode.
tokenized_data = dataset.map(function=preprocess_function, batched=True)

  0%|          | 0/5 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

In [10]:
from datasets import load_metric
import numpy as np

def compute_metrics(eval_pred):
    """Return micro- and macro-averaged F1 for one evaluation pass.

    Args:
        eval_pred: A ``(logits, labels)`` pair; the index of the largest logit
            along the last axis is taken as the predicted class.

    Returns:
        dict with the keys ``"f1-micro"`` and ``"f1-macro"``.
    """
    f1_metric = load_metric("f1")
    logits, labels = eval_pred
    predicted_classes = np.argmax(logits, axis=-1)
    scores = {}
    for average in ("micro", "macro"):
        result = f1_metric.compute(
            predictions=predicted_classes, references=labels, average=average
        )
        scores[f"f1-{average}"] = result["f1"]
    return scores

In [11]:
from transformers import DataCollatorWithPadding

# Pad batches with the tokenizer, rounding the padded length up to a multiple
# of 8 (the original note ties this choice to fp16 training).
data_collator = DataCollatorWithPadding(tokenizer=tokenizer, pad_to_multiple_of=8)

In [None]:
# Clone the project repository into the current working directory.
!git clone https://github.com/danielsaggau/IR_LDC.git

In [13]:
# Make the IR_LDC directory the notebook's working directory.
%cd IR_LDC

/content/IR_LDC


In [None]:
!pip install wandb
import wandb
wandb.login()

In [None]:
# Start a single W&B run that carries both the project and the run name.
# Two separate init() calls do not merge their arguments into one run.
wandb.init(project="IR_LDC", name="mean_linear")

In [24]:
from transformers import TrainingArguments

training_args = TrainingArguments(
    # where checkpoints go and how the run is labelled / reported
    output_dir="/scotus_mean",
    run_name="mean_frozen",
    report_to="wandb",
    push_to_hub=True,
    # optimisation settings
    learning_rate=3e-5,
    weight_decay=0.01,
    num_train_epochs=20,
    per_device_train_batch_size=6,
    per_device_eval_batch_size=6,
    fp16=True,
    # evaluate and checkpoint once per epoch; reload the checkpoint with the
    # highest "f1-micro" when training finishes
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="f1-micro",
    greater_is_better=True,
)

PyTorch: setting up devices


In [None]:
from transformers import Trainer, EarlyStoppingCallback
import torch

# Early stopping is configured with a patience of 10.
early_stopping = EarlyStoppingCallback(early_stopping_patience=10)

# Train on the "train" split and evaluate on the "validation" split.
trainer = Trainer(
    model=model_1,
    args=training_args,
    train_dataset=tokenized_data["train"],
    eval_dataset=tokenized_data["validation"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    callbacks=[early_stopping],
)
trainer.train()

In [None]:
# Score the held-out test split. The "test" prefix keeps these metrics apart from
# the validation metrics, which are reported under the default "eval" prefix.
trainer.evaluate(eval_dataset=tokenized_data["test"], metric_key_prefix="test")