In [None]:
from longformer.longformer import Longformer, LongformerConfig

In [None]:
# Encoder configuration: 12 entries per list (one per layer), each with a
# 512-token sliding attention window and no dilation.
cfg = LongformerConfig(
    attention_window=[512] * 12,
    attention_dilation=[1] * 12,
    autoregressive=False,
    attention_mode='sliding_chunks',
)
# RoBERTa-style position ids start at padding_idx + 1 (= 2), so an input of
# 4096 real tokens needs position indices up to 4097. Reserve the two extra
# slots (the same 4096 + 2 layout used by the published longformer-base-4096
# configuration); with only 4096 slots a full-length sequence would index
# past the end of the position-embedding table.
cfg.max_position_embeddings = 4096 + 2

In [None]:
# Build a Longformer encoder from `cfg`, passing add_pooling_layer=False.
# NOTE(review): the name `model` is rebound to the classification wrapper in a
# later cell, so this instance looks like a smoke test of the config — confirm
# it is still needed.
model = Longformer(cfg, add_pooling_layer=False)

In [None]:
import torch
from torch import nn

from typing import Optional, Union, Tuple
from transformers.models.longformer.modeling_longformer import LongformerPreTrainedModel, LongformerSequenceClassifierOutput

class LongformerForSequenceClassification(LongformerPreTrainedModel):
    """Sequence-classification model built around an existing Longformer encoder.

    The encoder is passed in already constructed; this class adds a
    `LongformerClassificationHead` on top of its first output and computes
    the loss when labels are supplied.

    Args:
        config: Model configuration shared by the encoder and the head.
        num_labels: Number of target classes; written to `config.num_labels`.
        id2label: Mapping from class index to label name; written to the config.
        label2id: Mapping from label name to class index; written to the config.
        model: The encoder whose first output feeds the classification head.
    """

    def __init__(self, config, num_labels, id2label, label2id, model: Longformer):
        # Record the label setup on the config before the head is built so
        # that the head's output size follows the caller's arguments instead
        # of whatever the config happened to default to.
        if num_labels is not None:
            config.num_labels = num_labels
        if id2label is not None:
            config.id2label = id2label
        if label2id is not None:
            config.label2id = label2id

        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.longformer = model
        self.classifier = LongformerClassificationHead(config)

        # Used to emit the "default global attention" warning only once.
        self._warned_default_global_attention = False

        # NOTE(review): post_init() runs weight initialisation over all
        # submodules; confirm it does not reset weights if `model` was loaded
        # from a pretrained checkpoint.
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        global_attention_mask: Optional[torch.Tensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        token_type_ids: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, LongformerSequenceClassifierOutput]:
        r"""
        Run the encoder and the classification head, optionally computing a loss.

        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).

        Returns a `LongformerSequenceClassifierOutput` when `return_dict` is true, otherwise a tuple of
        `(loss?, logits, *encoder_outputs[2:])`.

        Raises `ValueError` when a default global attention mask is needed but neither `input_ids` nor
        `inputs_embeds` was given.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if global_attention_mask is None:
            if input_ids is not None:
                global_attention_mask = torch.zeros_like(input_ids)
            elif inputs_embeds is not None:
                # Same (batch, sequence) layout as the embeddings, minus the feature axis.
                global_attention_mask = torch.zeros(
                    inputs_embeds.shape[:-1], dtype=torch.long, device=inputs_embeds.device
                )
            else:
                raise ValueError("You have to specify either input_ids or inputs_embeds")

            if not getattr(self, "_warned_default_global_attention", False):
                import logging

                logging.getLogger(__name__).warning("Initializing global attention on CLS token...")
                self._warned_default_global_attention = True

            # global attention on cls token
            global_attention_mask[:, 0] = 1

        # The wrapped encoder is called without a `global_attention_mask`
        # keyword, so fold it into `attention_mask` instead. This follows the
        # convention documented for the allenai Longformer implementation:
        # 0 = padding (no attention), 1 = local attention, 2 = global attention.
        if attention_mask is not None:
            attention_mask = attention_mask * (global_attention_mask + 1)
        else:
            attention_mask = global_attention_mask + 1

        outputs = self.longformer(
            input_ids,
            attention_mask=attention_mask,
            head_mask=head_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)

            # Infer the problem type once from the label count and dtype.
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            # Loss classes are taken from `torch.nn`; the bare names were
            # never imported in this file.
            if self.config.problem_type == "regression":
                loss_fct = nn.MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = nn.CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = nn.BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        # The encoder's output object may not expose every field (for example
        # `global_attentions`), so read them defensively.
        return LongformerSequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=getattr(outputs, "hidden_states", None),
            attentions=getattr(outputs, "attentions", None),
            global_attentions=getattr(outputs, "global_attentions", None),
        )


class LongformerClassificationHead(nn.Module):
    """Classification head applied to the first token of a sequence.

    Pipeline: dropout -> linear (hidden -> hidden) -> tanh -> dropout ->
    linear (hidden -> num_labels).
    """

    def __init__(self, config):
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.out_proj = nn.Linear(width, config.num_labels)

    def forward(self, hidden_states, **kwargs):
        """Return logits computed from position 0 of the second axis.

        `hidden_states` is indexed as a rank-3 tensor; extra keyword
        arguments are accepted and ignored.
        """
        # take <s> token (equiv. to [CLS])
        first_token = hidden_states[:, 0, :]
        projected = self.dense(self.dropout(first_token))
        activated = self.dropout(torch.tanh(projected))
        return self.out_proj(activated)

------------------------------------------------------------------------------------------------------

In [None]:
import json
import os

import evaluate
import numpy as np
import pandas as pd
import torch
from datasets import load_dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
)

In [None]:
# Load the run settings from config.json; later cells read the "model_name",
# "batch_size" and "num_epochs" keys. The encoding is given explicitly so the
# file is decoded the same way regardless of the platform's default.
with open("config.json", encoding="utf-8") as f:
    config = json.load(f)

In [None]:
# Echo the loaded settings so they show up as the notebook cell's output.
config

In [None]:
import shutil

# Release cached GPU memory left over from earlier cells.
torch.cuda.empty_cache()

# Run settings read from the loaded JSON config.
model_name = config["model_name"]
BATCH_SIZE = config["batch_size"]
NUM_EPOCHS = config["num_epochs"]

# Output directory name: last path component of the model name with dashes
# turned into underscores, followed by fixed task/optimizer suffixes.
_model_basename = model_name.split("/")[-1].replace("-", "_")
model_path = _model_basename + "_text_classification_imdb" + "_adamw"

# NOTE(review): the cache directory is the literal string "model_path", not
# the `model_path` variable above — confirm this is intentional. It is kept
# unchanged because pointing rmtree at the variable would delete trained
# outputs on every re-run.
_dataset_cache_dir = 'model_path'
if os.path.exists(_dataset_cache_dir):
    shutil.rmtree(_dataset_cache_dir)

imdb = load_dataset("imdb", cache_dir=_dataset_cache_dir)

tokenizer = AutoTokenizer.from_pretrained(model_name)

def preprocess_function(examples):
    """Tokenize the "text" field, truncating and padding every example to 4096 tokens."""
    tokenizer_options = {
        "truncation": True,
        "padding": 'max_length',
        "max_length": 4096,
    }
    return tokenizer(examples["text"], **tokenizer_options)

# Tokenize in batches and drop the raw text column once it has been encoded.
tokenized_imdb = imdb.map(
    preprocess_function,
    batched=True,
    remove_columns=['text'],
)

In [None]:
# Accuracy metric object; `compute_metrics` below delegates to it.
accuracy = evaluate.load("accuracy")


def compute_metrics(eval_pred):
    """Turn a `(scores, labels)` pair into an accuracy result.

    The predicted class is the arg-max of the scores along axis 1.
    """
    scores, gold_labels = eval_pred
    predicted_ids = np.argmax(scores, axis=1)
    return accuracy.compute(predictions=predicted_ids, references=gold_labels)


# Two-class label vocabulary and its inverse.
label2id = {"NEGATIVE": 0, "POSITIVE": 1}
id2label = {index: name for name, index in label2id.items()}

In [None]:
# Fresh encoder built from `cfg`, wrapped with the two-class classification model.
longformer = Longformer(cfg, add_pooling_layer=False)
model = LongformerForSequenceClassification(
    config=cfg,
    num_labels=2,
    id2label=id2label,
    label2id=label2id,
    model=longformer,
)

# Collator that pads the examples of a batch using the tokenizer's settings.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [None]:
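# The manual training loop in the cells below is adapted from the Hugging Face
# "run_glue_no_trainer.py" example script:
# https://github.com/huggingface/transformers/blob/main/examples/pytorch/text-classification/run_glue_no_trainer.py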

In [None]:
from torch.utils.data import DataLoader

# Options shared by both loaders.
# NOTE(review): the batch size is hard-coded to 32 here; `BATCH_SIZE` read
# from config.json is not used — confirm which value is intended.
_loader_options = {"collate_fn": data_collator, "batch_size": 32}

# Only the training split is shuffled.
train_dataloader = DataLoader(
    tokenized_imdb['train'].with_format('torch'),
    shuffle=True,
    **_loader_options,
)
eval_dataloader = DataLoader(
    tokenized_imdb['test'].with_format('torch'),
    **_loader_options,
)

In [None]:
# Sanity check: pull a single training batch and show its contents and keys.
_batch_iterator = iter(train_dataloader)
batch = next(_batch_iterator, None)
del _batch_iterator  # nothing beyond the first batch is needed
if batch is not None:
    print(batch)
    print(batch.keys())

In [None]:
# Learning rate; 0.001 is also the default value of torch.optim.AdamW.
lr = 0.001

# A single parameter group covering every model parameter.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=lr)

In [None]:
import math

from transformers import get_scheduler
from accelerate import Accelerator


# Training-length settings. `max_train_steps` is left as None so that it is
# derived from the epoch count and the dataloader length below.
gradient_accumulation_steps = 1
num_train_epochs = 3
max_train_steps = None

# Scheduler and math around the number of training steps.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
overrode_max_train_steps = max_train_steps is None
if overrode_max_train_steps:
    max_train_steps = num_train_epochs * num_update_steps_per_epoch

# Scheduler name passed to `get_scheduler`. Values listed in the transformers
# SchedulerType enum: "linear", "cosine", "cosine_with_restarts",
# "polynomial", "constant", "constant_with_warmup", "inverse_sqrt",
# "reduce_lr_on_plateau", "cosine_with_min_lr", "warmup_stable_decay".
lr_scheduler_type = 'linear'

# No warm-up phase (this was the default).
num_warmup_steps = 0

lr_scheduler = get_scheduler(
    name=lr_scheduler_type,
    optimizer=optimizer,
    num_training_steps=max_train_steps,
    num_warmup_steps=num_warmup_steps,
)

accelerator = Accelerator()

# Hand every training object to the accelerator and rebind the names to the
# prepared versions it returns.
(
    model,
    optimizer,
    train_dataloader,
    eval_dataloader,
    lr_scheduler,
) = accelerator.prepare(
    model,
    optimizer,
    train_dataloader,
    eval_dataloader,
    lr_scheduler,
)

In [None]:
# The dataloader length may have changed after `accelerator.prepare`, so the
# per-epoch update count is computed again.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)

if overrode_max_train_steps:
    max_train_steps = num_train_epochs * num_update_steps_per_epoch

# Afterwards we recalculate our number of training epochs
num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)

# How often the Accelerator state is saved: the string "epoch" (save after
# every epoch), a number of optimizer steps given either as an int or as a
# digit string such as "500", or None to disable checkpointing.
checkpointing_steps = 'epoch'
# Only strings have `.isdigit()`; checking the type first lets an int value be
# used directly instead of raising AttributeError.
if isinstance(checkpointing_steps, str) and checkpointing_steps.isdigit():
    checkpointing_steps = int(checkpointing_steps)

In [None]:
# Train!
# Effective batch size per optimizer update: the per-loader batch size (the
# literal 32 must match the value the dataloaders were built with), times the
# number of processes, times the accumulation steps.
_per_loader_batch_size = 32
total_batch_size = _per_loader_batch_size * gradient_accumulation_steps * accelerator.num_processes

In [None]:
import logging

from accelerate.logging import get_logger

# Root logging configuration: timestamped records at INFO level.
_logging_options = {
    "format": "%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    "datefmt": "%m/%d/%Y %H:%M:%S",
    "level": logging.INFO,
}
logging.basicConfig(**_logging_options)

# Logger obtained through accelerate's `get_logger`.
logger = get_logger('logger')

# Log the accelerator state; main_process_only=False asks for the record to be
# emitted by every process rather than the main one only.
logger.info(accelerator.state, main_process_only=False)

In [None]:
from tqdm.auto import tqdm

# Manual train / evaluate / checkpoint loop driven by the Accelerator.
progress_bar = tqdm(range(max_train_steps), disable=not accelerator.is_local_main_process)
completed_steps = 0
starting_epoch = 0

# NOTE: the upstream example script can also resume from a checkpoint and push
# intermediate models to the Hub. Those sections depended on a command-line
# `args` namespace that this notebook does not define, so they are not ported.

output_dir = model_path

# Defined up front so the final report still works if no epoch is executed.
eval_metric = {}

for epoch in range(starting_epoch, num_train_epochs):
    model.train()
    total_loss = 0
    active_dataloader = train_dataloader

    for step, batch in enumerate(active_dataloader):
        outputs = model(**batch)
        loss = outputs.loss

        # We keep track of the loss at each epoch
        total_loss += loss.detach().float()
        loss = loss / gradient_accumulation_steps
        accelerator.backward(loss)

        # Update once `gradient_accumulation_steps` batches have been
        # accumulated (steps are 0-based, hence `step + 1`), and always on the
        # last batch so leftover gradients are not dropped.
        is_last_batch = step == len(train_dataloader) - 1
        if (step + 1) % gradient_accumulation_steps == 0 or is_last_batch:
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            progress_bar.update(1)
            completed_steps += 1

            # Step-based checkpointing is checked only right after an update
            # so the same step is not saved several times while gradients are
            # still being accumulated.
            if isinstance(checkpointing_steps, int) and completed_steps % checkpointing_steps == 0:
                # Separate name: `output_dir` must keep pointing at the run's
                # output directory for the epoch checkpoints and final save.
                step_dir = f"step_{completed_steps}"
                if output_dir is not None:
                    step_dir = os.path.join(output_dir, step_dir)
                accelerator.save_state(step_dir)

        if completed_steps >= max_train_steps:
            break

    model.eval()
    samples_seen = 0
    for step, batch in enumerate(eval_dataloader):
        with torch.no_grad():
            outputs = model(**batch)
        # The model is built with two labels, so the prediction is the
        # arg-max class; the regression branch of the upstream script relied
        # on an undefined `is_regression` flag.
        predictions = outputs.logits.argmax(dim=-1)
        predictions, references = accelerator.gather((predictions, batch["labels"]))
        # If we are in a multiprocess environment, the last batch has duplicates
        if accelerator.num_processes > 1:
            if step == len(eval_dataloader) - 1:
                predictions = predictions[: len(eval_dataloader.dataset) - samples_seen]
                references = references[: len(eval_dataloader.dataset) - samples_seen]
            else:
                samples_seen += references.shape[0]
        # `accuracy` is the metric object loaded earlier in the notebook.
        accuracy.add_batch(
            predictions=predictions,
            references=references,
        )

    eval_metric = accuracy.compute()
    logger.info(f"epoch {epoch}: {eval_metric}")

    accelerator.log(
        {
            "accuracy": eval_metric,
            # float() handles both the accumulated tensor and the initial int 0.
            "train_loss": float(total_loss) / len(train_dataloader),
            "epoch": epoch,
            "step": completed_steps,
        },
        step=completed_steps,
    )

    if checkpointing_steps == "epoch":
        output_dir_epoch = f"epoch_{epoch}"
        if output_dir is not None:
            output_dir_epoch = os.path.join(output_dir, output_dir_epoch)
        accelerator.save_state(output_dir_epoch)


accelerator.end_training()

# Save the final model (unwrapped from the accelerator) and the tokenizer.
if output_dir is not None:
    accelerator.wait_for_everyone()
    unwrapped_model = accelerator.unwrap_model(model)
    unwrapped_model.save_pretrained(
        output_dir, is_main_process=accelerator.is_main_process, save_function=accelerator.save
    )
    if accelerator.is_main_process:
        tokenizer.save_pretrained(output_dir)

# Write the last evaluation result next to the saved model.
if output_dir is not None:
    all_results = {f"eval_{k}": v for k, v in eval_metric.items()}
    with open(os.path.join(output_dir, "all_results.json"), "w", encoding="utf-8") as f:
        json.dump(all_results, f)