In [4]:
import torch
from torch import nn
from transformers import AutoModel, AutoTokenizer

In [5]:
class PhoBertForTextGeneration(nn.Module):
    def __init__(self, phobert_model):
        super(PhoBertForTextGeneration, self).__init__()
        self.phobert = phobert_model
        self.config = phobert_model.config  # Lưu cấu hình từ mô hình gốc
        self.linear = nn.Linear(self.config.hidden_size, self.config.vocab_size)
    
    def forward(self, input_ids, attention_mask=None):
        outputs = self.phobert(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = outputs[0]
        prediction_scores = self.linear(sequence_output)
        return prediction_scores

In [7]:
model_path = 'phobert_text_generation_model.pth'  # Cập nhật đường dẫn của mô hình
phobert_model = AutoModel.from_pretrained('vinai/phobert-base')
tokenizer = AutoTokenizer.from_pretrained('vinai/phobert-base')
model = PhoBertForTextGeneration(phobert_model)
model.load_state_dict(torch.load(model_path))
model.eval()

Some weights of the model checkpoint at vinai/phobert-base were not used when initializing RobertaModel: ['lm_head.dense.weight', 'lm_head.bias', 'lm_head.layer_norm.weight', 'lm_head.decoder.weight', 'lm_head.decoder.bias', 'lm_head.dense.bias', 'lm_head.layer_norm.bias']
- This IS expected if you are initializing RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


PhoBertForTextGeneration(
  (phobert): RobertaModel(
    (embeddings): RobertaEmbeddings(
      (word_embeddings): Embedding(64001, 768, padding_idx=1)
      (position_embeddings): Embedding(258, 768, padding_idx=1)
      (token_type_embeddings): Embedding(1, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): RobertaEncoder(
      (layer): ModuleList(
        (0-11): 12 x RobertaLayer(
          (attention): RobertaAttention(
            (self): RobertaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): RobertaSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerN

In [8]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

In [9]:
def generate_text(model, start_text, max_length=50):
    model.eval()
    with torch.no_grad():
        # Mã hóa chuỗi đầu vào
        input_ids = tokenizer.encode(start_text, return_tensors='pt').to(device)

        # Sinh văn bản
        for _ in range(max_length):
            outputs = model(input_ids=input_ids)
            predictions = outputs[0]

            # Chọn từ tiếp theo (ví dụ: từ có xác suất cao nhất)
            predicted_id = torch.argmax(predictions[0, -1, :]).unsqueeze(0)

            # Thêm từ vừa dự đoán vào chuỗi
            input_ids = torch.cat([input_ids, predicted_id.unsqueeze(0)], dim=1)

            # Kiểm tra xem đã kết thúc chuỗi chưa
            if predicted_id.item() == tokenizer.eos_token_id:
                break

        # Chuyển đổi ID token thành văn bản
        generated_text = tokenizer.decode(input_ids[0], skip_special_tokens=True)
        return generated_text


In [10]:
start_text = "Đây là một ví dụ về"
generated_text = generate_text(model, start_text)
print(generated_text)

IndexError: too many indices for tensor of dimension 2