In [1]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from datasets import load_dataset
import torch
import os
from dotenv import load_dotenv
import numpy as np
from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
load_dotenv()


  from .autonotebook import tqdm as notebook_tqdm


cpu


True

In [2]:
train_dataset = load_dataset('go_emotions', split='train')
test_dataset = load_dataset('go_emotions', split='test')

train_dataset[0]

{'text': "My favourite food is anything I didn't have to cook myself.",
 'labels': [27],
 'id': 'eebbqej'}

In [3]:
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

def convert_to_multihot(label_list, num_labels):
    multihot = torch.zeros(num_labels)
    
    for label in label_list:
        multihot[label] = 1
    
    return multihot

def tokenize_and_encode_dataset(dataset):
    encoded_dataset = []
    
    for i in tqdm(range(len(dataset))):
        tokenized_input = tokenizer(
            dataset[i]['text'],
            padding='max_length',
            truncation=True,
            max_length=512,
            return_tensors='pt'
        )
        
        labels = convert_to_multihot(
            dataset[i]['labels'],
            num_labels=28
        )
        
        encoded_example = {
            'input_ids': tokenized_input['input_ids'].squeeze(0),
            'attention_mask': tokenized_input['attention_mask'].squeeze(0),
            'labels': labels
        }
        
        encoded_dataset.append(encoded_example)
    return encoded_dataset
        
training = tokenize_and_encode_dataset(train_dataset)
testing = tokenize_and_encode_dataset(test_dataset)

100%|██████████| 43410/43410 [00:11<00:00, 3711.55it/s]
100%|██████████| 5427/5427 [00:01<00:00, 3369.46it/s]


In [4]:
class EmotionDataset(Dataset):
    def __init__(self, encoded_dataset):
        self.data = encoded_dataset
        
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        return self.data[idx]
    
train_dataset = EmotionDataset(training)
test_dataset = EmotionDataset(testing)

In [8]:
model = AutoModelForSequenceClassification.from_pretrained(
    'bert-base-uncased', 
    num_labels=28, 
    problem_type='multi_label_classification',
    token=os.getenv('huggingface_api_key')
    )

model.gradient_checkpointing_enable()

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [9]:
training_args = TrainingArguments(
    output_dir='./results',
    eval_strategy='epoch',
    save_strategy='epoch',
    learning_rate=2e-5,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=3,
    weight_decay=.01,
    logging_dir='./logs',
    logging_steps=10,
    load_best_model_at_end=True,
    gradient_accumulation_steps=4
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset
)

In [10]:
trainer.train()

                                        
  0%|          | 0/8139 [03:51<?, ?it/s]            

{'loss': 0.6453, 'grad_norm': 5.252878189086914, 'learning_rate': 1.997542695662858e-05, 'epoch': 0.0}


                                        
  0%|          | 0/8139 [05:27<?, ?it/s]            

{'loss': 0.5408, 'grad_norm': 4.748049736022949, 'learning_rate': 1.9950853913257158e-05, 'epoch': 0.01}


                                        
  0%|          | 0/8139 [06:27<?, ?it/s]            

{'loss': 0.4458, 'grad_norm': 4.0631608963012695, 'learning_rate': 1.9926280869885738e-05, 'epoch': 0.01}


                                        
  0%|          | 0/8139 [07:15<?, ?it/s]            

{'loss': 0.3719, 'grad_norm': 3.0443949699401855, 'learning_rate': 1.9901707826514315e-05, 'epoch': 0.01}


                                        
  0%|          | 0/8139 [08:18<?, ?it/s]            

{'loss': 0.3217, 'grad_norm': 2.6003167629241943, 'learning_rate': 1.987713478314289e-05, 'epoch': 0.02}


                                        
  0%|          | 0/8139 [09:43<?, ?it/s]            

{'loss': 0.2732, 'grad_norm': 2.2467637062072754, 'learning_rate': 1.985256173977147e-05, 'epoch': 0.02}


                                        
  0%|          | 0/8139 [11:22<?, ?it/s]            

{'loss': 0.2448, 'grad_norm': 1.9526629447937012, 'learning_rate': 1.982798869640005e-05, 'epoch': 0.03}


                                        
  0%|          | 0/8139 [12:23<?, ?it/s]            

{'loss': 0.219, 'grad_norm': 1.5164965391159058, 'learning_rate': 1.9803415653028628e-05, 'epoch': 0.03}


                                        
  0%|          | 0/8139 [13:19<?, ?it/s]            

{'loss': 0.2031, 'grad_norm': 1.5184931755065918, 'learning_rate': 1.9778842609657208e-05, 'epoch': 0.03}


                                        
  0%|          | 0/8139 [14:16<?, ?it/s]             

{'loss': 0.187, 'grad_norm': 1.5094234943389893, 'learning_rate': 1.9754269566285788e-05, 'epoch': 0.04}


                                        
  0%|          | 0/8139 [15:23<?, ?it/s]             

{'loss': 0.177, 'grad_norm': 1.1334054470062256, 'learning_rate': 1.9729696522914364e-05, 'epoch': 0.04}


                                        
  0%|          | 0/8139 [16:23<?, ?it/s]             

{'loss': 0.1719, 'grad_norm': 0.8974711298942566, 'learning_rate': 1.970512347954294e-05, 'epoch': 0.04}


                                        
  0%|          | 0/8139 [17:18<?, ?it/s]             

{'loss': 0.1639, 'grad_norm': 0.9343737363815308, 'learning_rate': 1.968055043617152e-05, 'epoch': 0.05}


                                        
  0%|          | 0/8139 [18:48<?, ?it/s]             

{'loss': 0.1579, 'grad_norm': 0.9667676687240601, 'learning_rate': 1.96559773928001e-05, 'epoch': 0.05}


                                        
  0%|          | 0/8139 [19:43<?, ?it/s]             

{'loss': 0.1574, 'grad_norm': 0.9658186435699463, 'learning_rate': 1.9631404349428678e-05, 'epoch': 0.06}


                                        
  0%|          | 0/8139 [20:39<?, ?it/s]             

{'loss': 0.1577, 'grad_norm': 0.8895514607429504, 'learning_rate': 1.9606831306057258e-05, 'epoch': 0.06}


                                        
  0%|          | 0/8139 [21:37<?, ?it/s]             

{'loss': 0.1528, 'grad_norm': 0.9111794829368591, 'learning_rate': 1.9582258262685838e-05, 'epoch': 0.06}


                                        
  0%|          | 0/8139 [22:54<?, ?it/s]             

{'loss': 0.1584, 'grad_norm': 1.0175142288208008, 'learning_rate': 1.9557685219314414e-05, 'epoch': 0.07}


                                        
  0%|          | 0/8139 [23:51<?, ?it/s]             

{'loss': 0.1473, 'grad_norm': 1.0838487148284912, 'learning_rate': 1.953311217594299e-05, 'epoch': 0.07}


                                        
  0%|          | 0/8139 [24:50<?, ?it/s]             

{'loss': 0.1528, 'grad_norm': 0.879416286945343, 'learning_rate': 1.950853913257157e-05, 'epoch': 0.07}


                                        
  0%|          | 0/8139 [25:45<?, ?it/s]             

{'loss': 0.1469, 'grad_norm': 0.9143323302268982, 'learning_rate': 1.9483966089200147e-05, 'epoch': 0.08}




KeyboardInterrupt: 