In [1]:
# modify the storage of json 
# https://huggingface.co/docs/datasets/loading#local-and-remote-files


# RUN THIS CELL ONLY FOR THE FIRST TIME
import json

with open('dataset.json') as f:
    d = json.load(f)

total_samples = list(d.keys())

In [2]:
len(total_samples)

20148

In [5]:
import random 

def get_train_valid_keys(total_keys):
    k = int(len(total_keys)*0.8)
    
    train_keys = random.choices(total_keys, k = k)
    valid_keys = []

    for key in total_keys:
        if key not in train_keys:
            valid_keys.append(key)

    return train_keys, valid_keys

train_keys, valid_keys = get_train_valid_keys(total_samples)

In [6]:
print(len(train_keys), len(valid_keys))

16118 8950


In [7]:
with open("train_dataset.json", "w") as outfile:
    for key in train_keys:
        json.dump(d[key], outfile)

with open("valid_dataset.json", "w") as outfile:
    for key in valid_keys:
        json.dump(d[key], outfile)

In [1]:
from datasets import load_dataset

data_files = {"train" : "train_dataset.json", "validation" : "valid_dataset.json"}

dataset = load_dataset("json", data_files=data_files)

In [2]:
dataset

DatasetDict({
    train: Dataset({
        features: ['post_id', 'annotators', 'rationales', 'post_tokens'],
        num_rows: 16118
    })
    validation: Dataset({
        features: ['post_id', 'annotators', 'rationales', 'post_tokens'],
        num_rows: 8950
    })
})

In [3]:
dataset['train'].num_rows

16118

# Sample Dataset

In [4]:
dataset_sample = dataset['train'].shuffle(seed = 42).select(range(1000))

In [5]:
dataset_sample[0]

{'post_id': '1178522691382456321_twitter',
 'annotators': [{'label': 'offensive',
   'annotator_id': 1,
   'target': ['Homosexual']},
  {'label': 'normal', 'annotator_id': 84, 'target': ['None']},
  {'label': 'offensive', 'annotator_id': 4, 'target': ['Homosexual']}],
 'rationales': [[0, 1, 1, 1, 0, 0, 0, 0, 0, 0],
  [0, 1, 0, 0, 0, 0, 0, 0, 0, 0]],
 'post_tokens': ['this',
  'faggots',
  'came',
  'back',
  'like',
  'go',
  'home',
  'get',
  'a',
  'life']}

In [6]:
dataset_sample[0]['post_tokens']

['this', 'faggots', 'came', 'back', 'like', 'go', 'home', 'get', 'a', 'life']

In [7]:
dataset_sample[0]['rationales']

[[0, 1, 1, 1, 0, 0, 0, 0, 0, 0], [0, 1, 0, 0, 0, 0, 0, 0, 0, 0]]

In [8]:
hate_speech_feature = dataset_sample.features['rationales']

In [9]:
hate_speech_feature

Sequence(feature=Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None), length=-1, id=None)

In [10]:
# add columns 
# final rational 
# final label 

In [11]:
# add new column -- final label
def get_label(sample):
    #get annotator 
    annotators = sample['annotators']
    
    #get labels 
    label_dict = {}
    for annotator in annotators:
        if annotator['label'] not in label_dict.keys():
            label_dict[annotator['label']] = 1
        else:
            label_dict[annotator['label']] += 1
            
            
    #get final label         
    max_label = max(label_dict.values())
    res = list(filter(lambda x: label_dict[x] == max_label, label_dict))
    
    return {"final_label": res[0]}

In [12]:
dataset_sample = dataset_sample.map(get_label)

In [13]:
dataset_sample[1]

{'post_id': '1164583620323028993_twitter',
 'annotators': [{'label': 'normal', 'annotator_id': 235, 'target': ['None']},
  {'label': 'offensive', 'annotator_id': 216, 'target': ['Islam', 'Refugee']},
  {'label': 'offensive', 'annotator_id': 209, 'target': ['Islam']}],
 'rationales': [[0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   1,
   0],
  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0]],
 'post_tokens': ['<user>',
  'lol',
  'not',
  'surprised',
  'at',
  'all',
  'he',
  'will',
  'have',
  'to',
  'wait',
  'until',
  'after',
  'the',
  'election',
  'if',
  'he',
  'wins',
  'to',
  'bring',
  'in',
  'jihadi',
  'jack'],
 'final_label': 'offensive'}

In [14]:
# add column -- final rational 
def get_combined_rational(sample):
    #get_rationals
    rationales = sample['rationales']
    sentence_tokens = sample['post_tokens']
    try:
        combined_rational = [0]*len(rationales[0])

        for rational in rationales:
            combined_rational = [a | b for a, b in zip(rational, combined_rational)]

        return {'final_rational' : combined_rational}
    except:
        combined_rational = [0] * len(sentence_tokens)
        return {'final_rational' : combined_rational}

In [15]:
dataset_sample = dataset_sample.map(get_combined_rational)

In [16]:
dataset_sample[1]

{'post_id': '1164583620323028993_twitter',
 'annotators': [{'label': 'normal', 'annotator_id': 235, 'target': ['None']},
  {'label': 'offensive', 'annotator_id': 216, 'target': ['Islam', 'Refugee']},
  {'label': 'offensive', 'annotator_id': 209, 'target': ['Islam']}],
 'rationales': [[0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   1,
   0],
  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0]],
 'post_tokens': ['<user>',
  'lol',
  'not',
  'surprised',
  'at',
  'all',
  'he',
  'will',
  'have',
  'to',
  'wait',
  'until',
  'after',
  'the',
  'election',
  'if',
  'he',
  'wins',
  'to',
  'bring',
  'in',
  'jihadi',
  'jack'],
 'final_label': 'offensive',
 'final_rational': [0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  1,
  0]}

In [17]:
hate_speech_rationales = dataset_sample.features['final_rational']
hate_speech_rationales

Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None)

In [18]:
# add name tags
# used in seqeval

def create_name_tags(sample):
    rationales = sample['final_rational']
    ner_tags = []
    for rational in rationales:
        if rational == 0:
            ner_tags.append('O')
        else:
            ner_tags.append('B-HATE')
            
    return {'ner_tags' : ner_tags}

In [19]:
dataset_sample = dataset_sample.map(create_name_tags)

In [20]:
dataset_sample[1]

{'post_id': '1164583620323028993_twitter',
 'annotators': [{'label': 'normal', 'annotator_id': 235, 'target': ['None']},
  {'label': 'offensive', 'annotator_id': 216, 'target': ['Islam', 'Refugee']},
  {'label': 'offensive', 'annotator_id': 209, 'target': ['Islam']}],
 'rationales': [[0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   0,
   1,
   0],
  [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0]],
 'post_tokens': ['<user>',
  'lol',
  'not',
  'surprised',
  'at',
  'all',
  'he',
  'will',
  'have',
  'to',
  'wait',
  'until',
  'after',
  'the',
  'election',
  'if',
  'he',
  'wins',
  'to',
  'bring',
  'in',
  'jihadi',
  'jack'],
 'final_label': 'offensive',
 'final_rational': [0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  0,
  1,
  0],
 'ner_tags': ['O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
  'O',
 

In [21]:
#feature attribute of dataset
ner_features = dataset_sample.features["ner_tags"]

In [22]:
dataset_sample

Dataset({
    features: ['post_id', 'annotators', 'rationales', 'post_tokens', 'final_label', 'final_rational', 'ner_tags'],
    num_rows: 1000
})

##  Adding Final Label, Final Rational, NER_tags to train and valid set

In [23]:
dataset = dataset.map(get_label)
dataset = dataset.map(get_combined_rational)
dataset = dataset.map(create_name_tags)

# Tokenize

In [24]:
# import the model 
from transformers import AutoTokenizer

model_checkpoint = "bert-base-cased"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

In [25]:
tokenizer.is_fast

True

In [26]:
inputs = tokenizer(dataset_sample[0]["post_tokens"], is_split_into_words=True)
inputs.tokens()

['[CLS]',
 'this',
 'f',
 '##ag',
 '##got',
 '##s',
 'came',
 'back',
 'like',
 'go',
 'home',
 'get',
 'a',
 'life',
 '[SEP]']

In [27]:
dataset_sample[0]["post_tokens"]

['this', 'faggots', 'came', 'back', 'like', 'go', 'home', 'get', 'a', 'life']

In [28]:
inputs.word_ids()

[None, 0, 1, 1, 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, None]

In [29]:
# -100 is an index that is ignored in the loss function we will use (cross entropy)

def align_labels_with_tokens(ner_tags, word_ids):
    #create a dict 
    w_ids = list(set(word_ids))
    w_ids.remove(None)

    new_ids = []

    for id__ in word_ids:
        if id__ == None:
            new_ids.append(-100)
        else:
            index = w_ids.index(id__)
            new_ids.append(ner_tags[index])

    return new_ids

In [30]:
ner_tags = dataset_sample[1]["final_rational"]
inputs = tokenizer(dataset_sample[1]["post_tokens"], is_split_into_words=True)
word_ids = inputs.word_ids()
print(ner_tags)
print(word_ids)
print(align_labels_with_tokens(ner_tags, word_ids))

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0]
[None, 0, 0, 0, 1, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 21, 21, 22, None]
[-100, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, -100]


In [31]:
def tokenize_and_align_labels(sample):
    tokenized_inputs = tokenizer(sample["post_tokens"], is_split_into_words=True)
    
    word_ids = tokenized_inputs.word_ids()
    ner_tags = sample["final_rational"]
    
    new_tags = align_labels_with_tokens(ner_tags, word_ids)

    tokenized_inputs["labels"] = new_tags
    
    return tokenized_inputs

In [32]:
dataset_sample_tokenized = dataset_sample.map(tokenize_and_align_labels, 
                                    remove_columns =dataset_sample.column_names,)

In [33]:
dataset_sample_tokenized

Dataset({
    features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
    num_rows: 1000
})

## Tokenize datasets

In [34]:
dataset_tokenized = dataset.map(tokenize_and_align_labels, remove_columns = dataset['train'].column_names)

In [35]:
dataset_tokenized

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 16118
    })
    validation: Dataset({
        features: ['input_ids', 'token_type_ids', 'attention_mask', 'labels'],
        num_rows: 8950
    })
})

# Data Collator

In [36]:
from transformers import DataCollatorForTokenClassification

data_collator = DataCollatorForTokenClassification(tokenizer=tokenizer)

2024-07-13 09:52:42.949295: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:479] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-07-13 09:52:42.973775: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:10575] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-07-13 09:52:42.973820: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1442] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-07-13 09:52:42.988585: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [37]:
batch = data_collator([dataset_sample_tokenized[i] for i in range(2)])
batch["labels"]

tensor([[-100,    0,    1,    1,    1,    1,    1,    1,    0,    0,    0,    0,
            0,    0, -100, -100, -100, -100, -100, -100, -100, -100, -100, -100,
         -100, -100, -100, -100, -100, -100],
        [-100,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,    0,
            0,    1,    1,    1,    0, -100]])

In [38]:
import evaluate

metric = evaluate.load("seqeval")

In [39]:
labels = dataset_sample[0]["ner_tags"]
labels = [i for i in labels]
labels

['O', 'B-HATE', 'B-HATE', 'B-HATE', 'O', 'O', 'O', 'O', 'O', 'O']

In [40]:
predictions = labels.copy()
predictions[2] = "O"
metric.compute(predictions=[predictions], references=[labels])

{'HATE': {'precision': 1.0,
  'recall': 0.6666666666666666,
  'f1': 0.8,
  'number': 3},
 'overall_precision': 1.0,
 'overall_recall': 0.6666666666666666,
 'overall_f1': 0.8,
 'overall_accuracy': 0.9}

In [41]:
import numpy as np

def compute_metrics(eval_preds):
    logits, labels = eval_preds
    predictions = np.argmax(logits, axis=-1)

    true_labels = [[label_names[l] for l in label if l != -100] for label in labels]
    true_predictions = [
        [label_names[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    
    all_metrics = metric.compute(predictions=true_predictions, references=true_labels)
    
    return {
        "precision": all_metrics["overall_precision"],
        "recall": all_metrics["overall_recall"],
        "f1": all_metrics["overall_f1"],
        "accuracy": all_metrics["overall_accuracy"],
    }


In [42]:
label_names = ['O', 'B-HATE']

In [43]:
id2label = {i: label for i, label in enumerate(label_names)}
label2id = {v: k for k, v in id2label.items()}

In [44]:
from transformers import AutoModelForTokenClassification

model = AutoModelForTokenClassification.from_pretrained(
    model_checkpoint,
    id2label=id2label,
    label2id=label2id,
)

Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [45]:
model.config.num_labels

2

In [46]:
from transformers import TrainingArguments

args = TrainingArguments(
    "bert-finetuned-ner",
    eval_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    num_train_epochs=3,
    weight_decay=0.01,
)

In [47]:
from transformers import Trainer

trainer = Trainer(
    model = model,
    args = args,
    train_dataset = dataset_tokenized["train"],
    eval_dataset = dataset_tokenized["validation"],
    data_collator = data_collator,
    compute_metrics = compute_metrics,
    tokenizer = tokenizer,
)
trainer.train()

Epoch,Training Loss,Validation Loss,Precision,Recall,F1,Accuracy
1,0.3792,0.440273,0.57392,0.57971,0.5768,0.793655
2,0.2557,0.522676,0.620534,0.502138,0.555093,0.80475
3,0.1653,0.612112,0.594352,0.528291,0.559378,0.798118


TrainOutput(global_step=6045, training_loss=0.28299589819706994, metrics={'train_runtime': 11145.2076, 'train_samples_per_second': 4.339, 'train_steps_per_second': 0.542, 'total_flos': 1467397675449312.0, 'train_loss': 0.28299589819706994, 'epoch': 3.0})

In [72]:
trainer.save_model()

In [92]:
sentence = "Hi ! I am Mridul and I will let you die"
sentence_token = sentence.split()
sentence_token_inp = tokenizer(sentence_token, is_split_into_words=True)

In [93]:
sentence_token_inp

{'input_ids': [101, 8790, 106, 146, 1821, 1828, 2386, 4654, 1105, 146, 1209, 1519, 1128, 2939, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}

In [101]:
predictions = trainer.predict([sentence_token_inp])

In [105]:
predictions[0][0]

array([[ 0.89419085, -0.40716982],
       [ 1.3067052 , -0.6443935 ],
       [ 1.2271513 , -0.79165065],
       [ 0.70659477, -0.29191267],
       [ 0.8627339 , -0.47308615],
       [-0.10650423,  0.77537507],
       [-0.0018034 , -0.05933079],
       [-0.0118507 , -0.04285377],
       [ 0.40884638, -0.04362368],
       [-0.19079329,  0.5887715 ],
       [-0.7469316 ,  0.92949456],
       [-0.97256947,  0.94243336],
       [-1.1810716 ,  1.2315508 ],
       [-0.7058423 ,  0.9502713 ],
       [ 0.46939856,  0.4232003 ]], dtype=float32)

In [106]:
predictions = np.argmax(predictions[0][0], axis=1)

In [114]:
list(predictions)

[0, 0, 0, 0, 0, 1, 0, 0, 0, 1, 1, 1, 1, 1, 0]

In [115]:
sentence_token_inp.word_ids()

[None, 0, 1, 2, 3, 4, 4, 4, 5, 6, 7, 8, 9, 10, None]

In [120]:
def indetify_hate_speech(word_ids, pred_tags):

    id_hate_dict = {}
    id_occurence_dict = {}

    for id__ in word_ids:
        if id__ != None:
            index = word_ids.index(id__)
            ner_tag = pred_tags[index]
            
            if id__ in id_occurence_dict.keys():
                id_occurence_dict[id__] += 1
                if ner_tag == 1:
                    id_hate_dict[id__] += 1
            else:
                id_occurence_dict[id__] = 1
                if ner_tag == 1:
                    id_hate_dict[id__] = 1
                else:
                    id_hate_dict[id__] = 0

    hate_words = [0] * (len(set(word_ids)) - 1)
    for id__ in id_hate_dict.keys():
        # more than half hates
        if id_hate_dict[id__] > 0.5 * id_occurence_dict[id__]:
            hate_words[id__] = 1

    return hate_words
            

In [121]:
indetify_hate_speech(sentence_token_inp.word_ids(), predictions)

[0, 0, 0, 0, 1, 0, 1, 1, 1, 1, 1]

In [126]:
sentence = "Hi nice to meet you stupid fuck"
sentence_token = sentence.split()
sentence_token_inp = tokenizer(sentence_token, is_split_into_words=True)
predictions = trainer.predict([sentence_token_inp])
predictions = np.argmax(predictions[0][0], axis=1)
indetify_hate_speech(sentence_token_inp.word_ids(), predictions)

[0, 0, 0, 0, 0, 1, 1]

In [127]:
trainer.push_to_hub("bert-hate-speech")

events.out.tfevents.1720816258.ip-10-192-10-107.6549.0:   0%|          | 0.00/4.91k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/431M [00:00<?, ?B/s]

events.out.tfevents.1720859568.ip-10-192-10-32.1554.0:   0%|          | 0.00/6.46k [00:00<?, ?B/s]

events.out.tfevents.1720816784.ip-10-192-10-107.6549.1:   0%|          | 0.00/4.93k [00:00<?, ?B/s]

events.out.tfevents.1720864367.ip-10-192-10-32.18353.0:   0%|          | 0.00/9.23k [00:00<?, ?B/s]

Upload 6 LFS files:   0%|          | 0/6 [00:00<?, ?it/s]

training_args.bin:   0%|          | 0.00/5.11k [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/noobiebuilder/bert-finetuned-ner/commit/d019e253683d6addce4b7ca06e6b5882505ca0ff', commit_message='bert-hate-speech', commit_description='', oid='d019e253683d6addce4b7ca06e6b5882505ca0ff', pr_url=None, pr_revision=None, pr_num=None)

# Custom Training with Accelerate

In [56]:
from torch.utils.data import DataLoader

train_dataloader = DataLoader(
    dataset_tokenized["train"],
    shuffle=True,
    collate_fn = data_collator,
    batch_size=8,
)

eval_dataloader = DataLoader(
    dataset_tokenized["train"], 
    collate_fn=data_collator, 
    batch_size=8
)

In [57]:
# fresh model

model = AutoModelForTokenClassification.from_pretrained(
    model_checkpoint,
    id2label=id2label,
    label2id=label2id,
)

Some weights of BertForTokenClassification were not initialized from the model checkpoint at bert-base-cased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [58]:
from torch.optim import AdamW

optimizer = AdamW(model.parameters(), lr=2e-5)

In [59]:
from accelerate import Accelerator

accelerator = Accelerator()

model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader)

# accelerator.prepare = use its length to compute the number of training steps.

In [60]:
from transformers import get_scheduler

num_train_epochs = 3
num_update_steps_per_epoch = len(train_dataloader)
num_training_steps = num_train_epochs * num_update_steps_per_epoch

lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

In [62]:
from huggingface_hub import notebook_login

notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [63]:
# token = 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
from huggingface_hub import Repository, get_full_repo_name

model_name = "bert-finetuned-hate-speech-accelerate"
repo_name = get_full_repo_name(model_name)
repo_name

'noobiebuilder/bert-finetuned-hate-speech-accelerate'

In [69]:
output_dir = "bert-finetuned-hate-speech-accelerate"

In [66]:
def postprocess(predictions, labels):
    predictions = predictions.detach().cpu().clone().numpy()
    labels = labels.detach().cpu().clone().numpy()

    # Remove ignored index (special tokens) and convert to labels
    true_labels = [[label_names[l] for l in label if l != -100] for label in labels]
    true_predictions = [
        [label_names[p] for (p, l) in zip(prediction, label) if l != -100]
        for prediction, label in zip(predictions, labels)
    ]
    return true_labels, true_predictions

In [70]:
from tqdm.auto import tqdm
import torch

progress_bar = tqdm(range(num_training_steps))

for epoch in range(num_train_epochs):
    # Training
    model.train()
    for batch in train_dataloader:
        outputs = model(**batch)
        loss = outputs.loss
        accelerator.backward(loss)

        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        progress_bar.update(1)

    # Evaluation
    model.eval()
    for batch in eval_dataloader:
        with torch.no_grad():
            outputs = model(**batch)

        predictions = outputs.logits.argmax(dim=-1)
        labels = batch["labels"]

        # Necessary to pad predictions and labels for being gathered
        predictions = accelerator.pad_across_processes(predictions, dim=1, pad_index=-100)
        labels = accelerator.pad_across_processes(labels, dim=1, pad_index=-100)

        predictions_gathered = accelerator.gather(predictions)
        labels_gathered = accelerator.gather(labels)

        true_predictions, true_labels = postprocess(predictions_gathered, labels_gathered)
        metric.add_batch(predictions=true_predictions, references=true_labels)

    results = metric.compute()
    print(
        f"epoch {epoch}:",
        {
            key: results[f"overall_{key}"]
            for key in ["precision", "recall", "f1", "accuracy"]
        },
    )

    # Save and upload
    accelerator.wait_for_everyone()
    unwrapped_model = accelerator.unwrap_model(model)
    
    unwrapped_model.save_pretrained(output_dir, save_function=accelerator.save)
    
    if accelerator.is_main_process:
        tokenizer.save_pretrained(output_dir)

  0%|          | 0/6045 [00:00<?, ?it/s]

KeyboardInterrupt: 