# GAN-BERT (in Pytorch and compatible with HuggingFace)

This is a Pytorch (+ **Huggingface** transformers) implementation of the GAN-BERT model from https://github.com/crux82/ganbert. While the original GAN-BERT was an extension of BERT, this implementation can be adapted to several architectures, ranging from Roberta to Albert!

**NOTE**: given that this implementation is different from the original one in Tensorflow, some results can be slighty different.


Let's GO!

Required Imports.

In [1]:
!pip install transformers==4.3.2
import torch
import io
import torch.nn.functional as F
import random
import numpy as np
import time
import math
import datetime
import torch.nn as nn
from transformers import *
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
#!pip install torch==1.7.1+cu101 torchvision==0.8.2+cu101 -f https://download.pytorch.org/whl/torch_stable.html
#!pip install sentencepiece

##Set random values
seed_val = 42
random.seed(seed_val)
np.random.seed(seed_val)
torch.manual_seed(seed_val)
if torch.cuda.is_available():
  torch.cuda.manual_seed_all(seed_val)

Collecting transformers==4.3.2
  Downloading transformers-4.3.2-py3-none-any.whl.metadata (36 kB)
Collecting sacremoses (from transformers==4.3.2)
  Downloading sacremoses-0.1.1-py3-none-any.whl.metadata (8.3 kB)
Collecting tokenizers<0.11,>=0.10.1 (from transformers==4.3.2)
  Downloading tokenizers-0.10.3.tar.gz (212 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m212.7/212.7 kB[0m [31m9.7 MB/s[0m eta [36m0:00:00[0m
[?25h  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
Downloading transformers-4.3.2-py3-none-any.whl (1.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m48.3 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading sacremoses-0.1.1-py3-none-any.whl (897 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m897.5/897.5 kB[0m [31m56.6 MB/s[0m eta [36m0:00:00[0m
[?25hBuilding wheels

GroupViT models are not usable since `tensorflow_probability` can't be loaded. It seems you have `tensorflow_probability` installed with the wrong tensorflow version.Please try to reinstall it following the instructions here: https://github.com/tensorflow/probability.
TAPAS models are not usable since `tensorflow_probability` can't be loaded. It seems you have `tensorflow_probability` installed with the wrong tensorflow version. Please try to reinstall it following the instructions here: https://github.com/tensorflow/probability.


In [2]:
# If there's a GPU available...
if torch.cuda.is_available():
    # Tell PyTorch to use the GPU.
    device = torch.device("cuda")
    print('There are %d GPU(s) available.' % torch.cuda.device_count())
    print('We will use the GPU:', torch.cuda.get_device_name(0))
# If not...
else:
    print('No GPU available, using the CPU instead.')
    device = torch.device("cpu")

There are 1 GPU(s) available.
We will use the GPU: Tesla T4


### Input Parameters


In [3]:
#--------------------------------
#  Transformer parameters
#--------------------------------
max_seq_length = 64
batch_size = 64

#--------------------------------
#  GAN-BERT specific parameters
#--------------------------------
# number of hidden layers in the generator,
# each of the size of the output space
num_hidden_layers_g = 1;
# number of hidden layers in the discriminator,
# each of the size of the input space
num_hidden_layers_d = 1;
# size of the generator's input noisy vectors
noise_size = 100
# dropout to be applied to discriminator's input vectors
out_dropout_rate = 0.2

# Replicate labeled data to balance poorly represented datasets,
# e.g., less than 1% of labeled material
apply_balance = True

#--------------------------------
#  Optimization parameters
#--------------------------------
learning_rate_discriminator = 5e-5
learning_rate_generator = 5e-5
epsilon = 1e-8
num_train_epochs = 10
multi_gpu = True
# Scheduler
apply_scheduler = False
warmup_proportion = 0.1
# Print
print_each_n_step = 10

#--------------------------------
#  Adopted Tranformer model
#--------------------------------
# Since this version is compatible with Huggingface transformers, you can uncomment
# (or add) transformer models compatible with GAN

model_name = "bert-base-cased"
#model_name = "bert-base-uncased"
#model_name = "roberta-base"
#model_name = "albert-base-v2"
#model_name = "xlm-roberta-base"
#model_name = "amazon/bort"

#--------------------------------
#  Retrieve the TREC QC Dataset
#--------------------------------
! git clone https://github.com/mauriciokonrath/ganbert.git

#  NOTE: in this setting 50 classes are involved
labeled_file = "./ganbert/data/standardized_labeled_monsanto_withoutSub.tsv"
unlabeled_file = "./ganbert/data/unlabeled_enron_5000.tsv"
test_filename = "./ganbert/data/standardized_test_monsanto_withoutSub.tsv"

#categorias de rótulos que o modelo deve aprender a classificar.
label_list = ["UNK_UNK","GHOST_ghost", "TOXIC_toxic",
              "CHEMI_chemi", "REGUL_regul"]

Cloning into 'ganbert'...
remote: Enumerating objects: 173, done.[K
remote: Counting objects: 100% (173/173), done.[K
remote: Compressing objects: 100% (172/172), done.[K
remote: Total 173 (delta 104), reused 0 (delta 0), pack-reused 0 (from 0)[K
Receiving objects: 100% (173/173), 9.22 MiB | 5.10 MiB/s, done.
Resolving deltas: 100% (104/104), done.


Load the Tranformer Model

In [4]:
transformer = AutoModel.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

loading configuration file config.json from cache at /root/.cache/huggingface/hub/models--bert-base-cased/snapshots/cd5ef92a9fb2f889e972770a36d4ed042daf221e/config.json
Model config BertConfig {
  "_name_or_path": "bert-base-cased",
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.47.1",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 28996
}



model.safetensors:   0%|          | 0.00/436M [00:00<?, ?B/s]

loading weights file model.safetensors from cache at /root/.cache/huggingface/hub/models--bert-base-cased/snapshots/cd5ef92a9fb2f889e972770a36d4ed042daf221e/model.safetensors
A pretrained model of type `BertModel` contains parameters that have been renamed internally (a few are listed below but more are present in the model):
* `bert.embeddings.LayerNorm.gamma` -> `bert.embeddings.LayerNorm.weight`
* `bert.encoder.layer.0.attention.output.LayerNorm.gamma` -> `{'bert.embeddings.LayerNorm.gamma': 'bert.embeddings.LayerNorm.weight', 'bert.encoder.layer.0.attention.output.LayerNorm.gamma': {...}, 'bert.encoder.layer.0.output.LayerNorm.gamma': {...}, 'bert.encoder.layer.1.attention.output.LayerNorm.gamma': {...}, 'bert.encoder.layer.1.output.LayerNorm.gamma': {...}, 'bert.encoder.layer.10.attention.output.LayerNorm.gamma': {...}, 'bert.encoder.layer.10.output.LayerNorm.gamma': {...}, 'bert.encoder.layer.11.attention.output.LayerNorm.gamma': {...}, 'bert.encoder.layer.11.output.LayerNorm.gam

tokenizer_config.json:   0%|          | 0.00/49.0 [00:00<?, ?B/s]

loading configuration file config.json from cache at /root/.cache/huggingface/hub/models--bert-base-cased/snapshots/cd5ef92a9fb2f889e972770a36d4ed042daf221e/config.json
Model config BertConfig {
  "_name_or_path": "bert-base-cased",
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.47.1",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 28996
}



vocab.txt:   0%|          | 0.00/213k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/436k [00:00<?, ?B/s]

loading file vocab.txt from cache at /root/.cache/huggingface/hub/models--bert-base-cased/snapshots/cd5ef92a9fb2f889e972770a36d4ed042daf221e/vocab.txt
loading file tokenizer.json from cache at /root/.cache/huggingface/hub/models--bert-base-cased/snapshots/cd5ef92a9fb2f889e972770a36d4ed042daf221e/tokenizer.json
loading file added_tokens.json from cache at None
loading file special_tokens_map.json from cache at None
loading file tokenizer_config.json from cache at /root/.cache/huggingface/hub/models--bert-base-cased/snapshots/cd5ef92a9fb2f889e972770a36d4ed042daf221e/tokenizer_config.json
loading file chat_template.jinja from cache at None
loading configuration file config.json from cache at /root/.cache/huggingface/hub/models--bert-base-cased/snapshots/cd5ef92a9fb2f889e972770a36d4ed042daf221e/config.json
Model config BertConfig {
  "_name_or_path": "bert-base-cased",
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gr

Function required to load the dataset

In [5]:
def get_qc_examples(input_file):
  """Creates examples for the training and dev sets."""
  examples = []

  with open(input_file, 'r') as f:
      contents = f.read()
      file_as_list = contents.splitlines()
      for line in file_as_list[1:]:
          split = line.split(" ")
          question = ' '.join(split[1:])

          text_a = question
          inn_split = split[0].split(":")
          label = inn_split[0] + "_" + inn_split[1]
          examples.append((text_a, label))
      f.close()

  return examples

**Load** the input QC dataset (fine-grained)

In [6]:
#Load the examples
labeled_examples = get_qc_examples(labeled_file)
#unlabeled_examples = get_qc_examples(unlabeled_file)
#test_examples = get_qc_examples(test_filename)

Functions required to convert examples into Dataloader

In [7]:
def generate_data_loader(input_examples, label_masks, label_map, do_shuffle = False, balance_label_examples = False):
  '''
  Generate a Dataloader given the input examples, eventually masked if they are
  to be considered NOT labeled.
  '''
  examples = []

  # Count the percentage of labeled examples
  num_labeled_examples = 0
  for label_mask in label_masks:
    if label_mask:
      num_labeled_examples += 1
  label_mask_rate = num_labeled_examples/len(input_examples)

  # if required it applies the balance
  for index, ex in enumerate(input_examples):
    if label_mask_rate == 1 or not balance_label_examples:
      examples.append((ex, label_masks[index]))
    else:
      # IT SIMULATE A LABELED EXAMPLE
      if label_masks[index]:
        balance = int(1/label_mask_rate)
        balance = int(math.log(balance,2))
        if balance < 1:
          balance = 1
        for b in range(0, int(balance)):
          examples.append((ex, label_masks[index]))
      else:
        examples.append((ex, label_masks[index]))

  #-----------------------------------------------
  # Generate input examples to the Transformer
  #-----------------------------------------------
  input_ids = []
  input_mask_array = []
  label_mask_array = []
  label_id_array = []

  # Tokenization
  for (text, label_mask) in examples:
    encoded_sent = tokenizer.encode(text[0], add_special_tokens=True, max_length=max_seq_length, padding="max_length", truncation=True)
    input_ids.append(encoded_sent)
    label_id_array.append(label_map[text[1]])
    label_mask_array.append(label_mask)

  # Attention to token (to ignore padded input wordpieces)
  for sent in input_ids:
    att_mask = [int(token_id > 0) for token_id in sent]
    input_mask_array.append(att_mask)
  # Convertion to Tensor
  input_ids = torch.tensor(input_ids)
  input_mask_array = torch.tensor(input_mask_array)
  label_id_array = torch.tensor(label_id_array, dtype=torch.long)
  label_mask_array = torch.tensor(label_mask_array)

  # Building the TensorDataset
  dataset = TensorDataset(input_ids, input_mask_array, label_id_array, label_mask_array)

  if do_shuffle:
    sampler = RandomSampler
  else:
    sampler = SequentialSampler

  # Building the DataLoader
  return DataLoader(
              dataset,  # The training samples.
              sampler = sampler(dataset),
              batch_size = batch_size) # Trains with this batch size.

def format_time(elapsed):
    '''
    Takes a time in seconds and returns a string hh:mm:ss
    '''
    # Round to the nearest second.
    elapsed_rounded = int(round((elapsed)))
    # Format as hh:mm:ss
    return str(datetime.timedelta(seconds=elapsed_rounded))

Convert the input examples into DataLoader

In [11]:
label_map = {}
for (i, label) in enumerate(label_list):
  label_map[label] = i
#------------------------------
#   Load the train dataset
#------------------------------
train_examples = labeled_examples
# Normalizar os rótulos em train_examples
train_examples = [(text[0], ' '.join(text[1].split()).replace("\t", "_")) for text in train_examples]

# Atualizar label_map dinamicamente
dataset_labels = set(example[1] for example in train_examples)
for label in dataset_labels:
    if label not in label_map:
        label_map[label] = len(label_map)

print("Updated label_map:", label_map)

# Gerar o dataloader novamente
train_dataloader = generate_data_loader(train_examples, train_label_masks, label_map, do_shuffle=True, balance_label_examples=apply_balance)

#------------------------------
#   Load the test dataset
#------------------------------
#The labeled (test) dataset is assigned with a mask set to True
'''test_label_masks = np.ones(len(test_examples), dtype=bool)

test_dataloader = generate_data_loader(test_examples, test_label_masks, label_map, do_shuffle = False, balance_label_examples = False)'''

Updated label_map: {'UNK_UNK': 0, 'GHOST_ghost': 1, 'TOXIC_toxic': 2, 'CHEMI_chemi': 3, 'REGUL_regul': 4, 'TOXIC_toxic This': 5, 'GHOST_ghost In': 6, 'TOXIC_toxic In': 7, 'CHEMI_chemi These': 8, 'CHEMI_chemi This': 9, 'REGUL_regul “In': 10, 'REGUL_regul This': 11, 'REGUL_regul In': 12, 'GHOST_ghost This': 13}


  label_mask_array = torch.tensor(label_mask_array)


'test_label_masks = np.ones(len(test_examples), dtype=bool)\n\ntest_dataloader = generate_data_loader(test_examples, test_label_masks, label_map, do_shuffle = False, balance_label_examples = False)'

We define the Generator and Discriminator as discussed in https://www.aclweb.org/anthology/2020.acl-main.191/

In [None]:
'''#------------------------------
#   The Generator as in
#   https://www.aclweb.org/anthology/2020.acl-main.191/
#   https://github.com/crux82/ganbert
#------------------------------
class Generator(nn.Module):
    def __init__(self, noise_size=100, output_size=512, hidden_sizes=[512], dropout_rate=0.1):
        super(Generator, self).__init__()
        layers = []
        hidden_sizes = [noise_size] + hidden_sizes
        for i in range(len(hidden_sizes)-1):
            layers.extend([nn.Linear(hidden_sizes[i], hidden_sizes[i+1]), nn.LeakyReLU(0.2, inplace=True), nn.Dropout(dropout_rate)])

        layers.append(nn.Linear(hidden_sizes[-1],output_size))
        self.layers = nn.Sequential(*layers)

    def forward(self, noise):
        output_rep = self.layers(noise)
        return output_rep

#------------------------------
#   The Discriminator
#   https://www.aclweb.org/anthology/2020.acl-main.191/
#   https://github.com/crux82/ganbert
#------------------------------
class Discriminator(nn.Module):
    def __init__(self, input_size=512, hidden_sizes=[512], num_labels=2, dropout_rate=0.1):
        super(Discriminator, self).__init__()
        self.input_dropout = nn.Dropout(p=dropout_rate)
        layers = []
        hidden_sizes = [input_size] + hidden_sizes
        for i in range(len(hidden_sizes)-1):
            layers.extend([nn.Linear(hidden_sizes[i], hidden_sizes[i+1]), nn.LeakyReLU(0.2, inplace=True), nn.Dropout(dropout_rate)])

        self.layers = nn.Sequential(*layers) #per il flatten
        self.logit = nn.Linear(hidden_sizes[-1],num_labels+1) # +1 for the probability of this sample being fake/real.
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, input_rep):
        input_rep = self.input_dropout(input_rep)
        last_rep = self.layers(input_rep)
        logits = self.logit(last_rep)
        probs = self.softmax(logits)
        return last_rep, logits, probs'''

We instantiate the Discriminator and Generator

In [None]:
'''# The config file is required to get the dimension of the vector produced by
# the underlying transformer
config = AutoConfig.from_pretrained(model_name)
hidden_size = int(config.hidden_size)
# Define the number and width of hidden layers
hidden_levels_g = [hidden_size for i in range(0, num_hidden_layers_g)]
hidden_levels_d = [hidden_size for i in range(0, num_hidden_layers_d)]

#-------------------------------------------------
#   Instantiate the Generator and Discriminator
#-------------------------------------------------
generator = Generator(noise_size=noise_size, output_size=hidden_size, hidden_sizes=hidden_levels_g, dropout_rate=out_dropout_rate)
discriminator = Discriminator(input_size=hidden_size, hidden_sizes=hidden_levels_d, num_labels=len(label_list), dropout_rate=out_dropout_rate)

# Put everything in the GPU if available
if torch.cuda.is_available():
  generator.cuda()
  discriminator.cuda()
  transformer.cuda()
  if multi_gpu:
    transformer = torch.nn.DataParallel(transformer)

# print(config)'''

Let's go with the training procedure

In [12]:
import time
import torch
from transformers import *

# Configuração do dispositivo
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Inicialização do modelo
transformer = transformer.to(device)
transformer.train()

# Otimizador
optimizer = torch.optim.AdamW(transformer.parameters(), lr=5e-5)

# Loop de treinamento
training_stats = []
total_t0 = time.time()

for epoch_i in range(0, num_train_epochs):
    print(f"======== Epoch {epoch_i + 1} / {num_train_epochs} ========")
    print("Training...")

    t0 = time.time()
    tr_loss = 0
    correct_predictions = 0
    total_predictions = 0

    # Inicializar listas para previsões e rótulos (se necessário)
    all_preds = []
    all_labels_ids = []

    for step, batch in enumerate(train_dataloader):
        b_input_ids = batch[0].to(device)
        b_input_mask = batch[1].to(device)
        b_labels = batch[2].to(device)

        optimizer.zero_grad()

        outputs = transformer(b_input_ids, attention_mask=b_input_mask)
        logits = outputs.last_hidden_state[:, 0, :]
        loss_fn = torch.nn.CrossEntropyLoss()
        loss = loss_fn(logits, b_labels)

        loss.backward()
        optimizer.step()

        tr_loss += loss.item()

        # Cálculo da acurácia
        preds = torch.argmax(logits, dim=1)
        correct_predictions += (preds == b_labels).sum().item()
        total_predictions += b_labels.size(0)

        # Adicionar previsões e rótulos para validação posterior (se necessário)
        all_preds.extend(preds.detach().cpu())
        all_labels_ids.extend(b_labels.detach().cpu())

    # Cálculo do loss médio e acurácia
    avg_train_loss = tr_loss / len(train_dataloader)
    train_accuracy = correct_predictions / total_predictions

    all_preds = torch.stack(all_preds).numpy()
    all_labels_ids = torch.stack(all_labels_ids).numpy()
    test_accuracy = np.sum(all_preds == all_labels_ids) / len(all_preds)
    print(f"  Validation Accuracy: {test_accuracy:.3f}")
    print(f"  Average Training Loss: {avg_train_loss:.3f}")

    # Salvar estatísticas
    training_stats.append({
        'epoch': epoch_i + 1,
        'Training Loss': avg_train_loss,
        'Training Accuracy': train_accuracy
    })

# Salvando o modelo treinado
model_save_path = "bert_model.pth"
torch.save(transformer.state_dict(), model_save_path)
print(f"Model saved to {model_save_path}")


Using device: cuda
Training...
  Validation Accuracy: 0.000
  Average Training Loss: 7.409
Training...
  Validation Accuracy: 0.000
  Average Training Loss: 6.486
Training...
  Validation Accuracy: 0.125
  Average Training Loss: 6.103
Training...
  Validation Accuracy: 0.350
  Average Training Loss: 5.740
Training...
  Validation Accuracy: 0.325
  Average Training Loss: 5.324
Training...
  Validation Accuracy: 0.375
  Average Training Loss: 4.697
Training...
  Validation Accuracy: 0.475
  Average Training Loss: 3.904
Training...
  Validation Accuracy: 0.375
  Average Training Loss: 3.266
Training...
  Validation Accuracy: 0.375
  Average Training Loss: 2.895
Training...
  Validation Accuracy: 0.450
  Average Training Loss: 2.449
Model saved to bert_model.pth


In [13]:
for stat in training_stats:
  print(stat)

print("\nTraining complete!")

print("Total training took {:} (h:mm:ss)".format(format_time(time.time()-total_t0)))

{'epoch': 1, 'Training Loss': 7.409079551696777, 'Training Accuracy': 0.0}
{'epoch': 2, 'Training Loss': 6.485732078552246, 'Training Accuracy': 0.0}
{'epoch': 3, 'Training Loss': 6.102871417999268, 'Training Accuracy': 0.125}
{'epoch': 4, 'Training Loss': 5.7401814460754395, 'Training Accuracy': 0.35}
{'epoch': 5, 'Training Loss': 5.3244123458862305, 'Training Accuracy': 0.325}
{'epoch': 6, 'Training Loss': 4.69679069519043, 'Training Accuracy': 0.375}
{'epoch': 7, 'Training Loss': 3.9039218425750732, 'Training Accuracy': 0.475}
{'epoch': 8, 'Training Loss': 3.26641583442688, 'Training Accuracy': 0.375}
{'epoch': 9, 'Training Loss': 2.8946800231933594, 'Training Accuracy': 0.375}
{'epoch': 10, 'Training Loss': 2.4487547874450684, 'Training Accuracy': 0.45}

Training complete!
Total training took 0:00:06 (h:mm:ss)


In [14]:
# Salvar o modelo transformer
torch.save(transformer.state_dict(), 'transformer.pth')
print("Transformer model saved.")

# Carregar o modelo transformer
transformer.load_state_dict(torch.load('transformer.pth'))
transformer.eval()
print("Transformer model loaded and set to evaluation mode.")


Transformer model saved.


  transformer.load_state_dict(torch.load('transformer.pth'))


Transformer model loaded and set to evaluation mode.


In [15]:
def preprocess(text, max_seq_length=64):
    tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
    encoded_sent = tokenizer.encode(
        text,
        add_special_tokens=True,
        max_length=max_seq_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt"
    )
    attention_mask = (encoded_sent > 0).long()
    return encoded_sent, attention_mask


In [16]:
def classify_question(text, label_list):
    transformer.eval()

    # Pré-processar o texto
    input_ids, attention_mask = preprocess(text)
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)

    with torch.no_grad():
        # Obter os logits do transformer
        outputs = transformer(input_ids, attention_mask=attention_mask)
        logits = outputs.last_hidden_state[:, 0, :]  # Usar representação do token [CLS]

        # Predição
        pred = torch.argmax(logits, dim=1).item()

    return label_list[pred]
