<a href="https://colab.research.google.com/github/ekonishi8524/my-colab-notebooks/blob/main/soseki_gpt_transformer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.nn import TransformerDecoderLayer, TransformerDecoder
from torch.utils.data import DataLoader, TensorDataset
from transformers import BertTokenizer, BertJapaneseTokenizer

In [None]:
import re
import requests

def load_aozora_text(url, timeout=30):
    """Download an Aozora Bunko page and reduce it to one string of Japanese text.

    The response is decoded as Shift_JIS, the body segment is selected,
    annotations and markup are stripped, all whitespace is removed, and every
    character outside the kept Unicode ranges is dropped.

    Args:
        url: Address of the page to download.
        timeout: Seconds to wait for the server before giving up.

    Returns:
        The cleaned text as a single string without line breaks or spaces.

    Raises:
        requests.RequestException: If the download fails, times out, or the
            server answers with an HTTP error status.
    """
    response = requests.get(url, timeout=timeout)
    # Fail loudly on 4xx/5xx instead of "cleaning" an error page into training data.
    response.raise_for_status()
    response.encoding = 'shift_jis'
    text = response.text

    # Split on rules of five or more dashes; when at least three segments
    # exist, the one following the second rule is used as the body.
    parts = re.split(r'\-{5,}', text)
    main_content = parts[2] if len(parts) > 2 else text

    # Drop everything from the colophon marker onwards.
    # NOTE(review): this pattern only matches an ASCII " : " with surrounding
    # spaces -- confirm it matches the marker actually used by the pages.
    main_content = re.split(r'底本 : ', main_content)[0]
    cleaned_text = main_content

    # NOTE(review): the leading space is part of the pattern, so only
    # annotations preceded by a space are removed -- confirm this is intended.
    cleaned_text = re.sub(r' 《.*?》', '', cleaned_text)
    cleaned_text = re.sub(r'［.*?］', '', cleaned_text)
    cleaned_text = re.sub(r'<[^>]*?>', '', cleaned_text)
    cleaned_text = re.sub(r'\r\n|\r|\n', '', cleaned_text)
    cleaned_text = re.sub(r'　', '', cleaned_text)
    cleaned_text = re.sub(r' ', '', cleaned_text)

    # Keep only: hiragana/katakana, CJK Extension A, CJK Unified Ideographs,
    # CJK Compatibility Ideographs, halfwidth/fullwidth forms, CJK punctuation,
    # and the katakana middle dot / prolonged sound mark.
    # The pattern is assembled from adjacent raw strings: a backslash-newline
    # inside a raw string is NOT a continuation, and previously injected a
    # newline, spaces and an accidental " -\u9FFF" range into the class, which
    # kept ASCII and most other non-Japanese characters.
    non_japanese = (
        r'[^\u3040-\u30ff\u3400-\u4DBF\u4E00-\u9FFF'
        r'\uF900-\uFAFF\uFF00-\uFFEF\u3000-\u300F\u30FB-\u30FC]'
    )
    cleaned_text = re.sub(non_japanese, '', cleaned_text)

    # Fullwidth parentheses survive the filter above, so parenthesised
    # readings are removed last.
    cleaned_text = re.sub(r'（.*?）', '', cleaned_text)

    print(f"[DEBUG] Length of cleaned text data: {len(cleaned_text)} characters")

    return cleaned_text

In [None]:
import math

class TransformerDecoderOnlyModel(nn.Module):
    """Token-level language model built from a stack of ``nn.TransformerDecoderLayer``.

    Token ids are embedded, scaled by ``sqrt(d_model)``, combined with
    positional encodings, passed through the decoder stack and projected to
    vocabulary logits. The decoder receives the embedded sequence both as its
    target and as its memory, with the same mask applied to each.

    Args:
        vocab_size: Number of entries in the embedding table and output logits.
        d_model: Embedding / hidden width.
        num_heads: Attention heads per decoder layer.
        num_layers: Number of stacked decoder layers.
        max_seq_len: Maximum length handed to the positional encoding.
        dropout: Dropout probability for the positional encoding and layers.
    """

    def __init__(self, vocab_size, d_model, num_heads, num_layers,
                 max_seq_len=64, dropout=0.1):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoder = PositionalEncoding(d_model, dropout, max_seq_len)

        # Feed-forward width is fixed at four times the model width.
        layer = nn.TransformerDecoderLayer(
            d_model,
            num_heads,
            dim_feedforward=d_model * 4,
            dropout=dropout,
            batch_first=True,
        )
        self.transformer_decoder = nn.TransformerDecoder(layer, num_layers)
        self.fc_out = nn.Linear(d_model, vocab_size)

    def forward(self, tgt, tgt_mask):
        """Return vocabulary logits for every position of ``tgt``.

        Args:
            tgt: Token ids; assumed (batch, seq) since the layers are built
                with ``batch_first=True``.
            tgt_mask: Attention mask, used for both the target and the memory.

        Returns:
            Logits with ``vocab_size`` entries in the last dimension.
        """
        scaled = self.embedding(tgt) * math.sqrt(self.d_model)
        encoded = self.pos_encoder(scaled)

        # The same tensor serves as target and memory, under the same mask.
        hidden = self.transformer_decoder(
            encoded,
            memory=encoded,
            tgt_mask=tgt_mask,
            memory_mask=tgt_mask,
        )
        return self.fc_out(hidden)

In [None]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position signals to a batch-first tensor, then apply dropout.

    The table has shape ``(1, max_len, d_model)``: even feature indices hold
    sines and odd indices hold cosines of ``position * div_term``.

    Args:
        d_model: Feature width of the inputs (odd widths are supported).
        dropout: Dropout probability applied after the addition.
        max_len: Number of positions precomputed in the table.
    """

    def __init__(self, d_model, dropout=0.1, max_len=512):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        angles = position * div_term  # (max_len, ceil(d_model / 2))

        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(angles)
        # With an odd d_model there is one fewer cosine column than sine
        # columns, so trim the angles to fit.
        pe[0, :, 1::2] = torch.cos(angles[:, : d_model // 2])

        # Registered as a buffer so the table follows the module through
        # .to()/.cuda()/.half(); non-persistent so state_dict keys are
        # unchanged and existing checkpoints still load.
        self.register_buffer('pe', pe, persistent=False)

    def forward(self, x):
        """Return ``x`` plus the encodings for its first ``x.size(1)`` positions, with dropout.

        Args:
            x: Tensor of shape (batch, seq, d_model).
        """
        # .to() is a no-op once the module lives on the same device as x.
        x = x + self.pe[:, : x.size(1), :].to(x.device)
        return self.dropout(x)

In [None]:
def generate_sample_text(model, tokenizer, prompt_text, device,
                         max_length_gen=30, temperature=0.7, top_k=50):
    """Sample a continuation of ``prompt_text`` from ``model`` and return the decoded text.

    Tokens are generated one at a time: the whole sequence so far is fed to
    the model with a causal mask, the logits of the last position are
    temperature-scaled and top-k filtered, and the next token is drawn from
    the resulting distribution.

    Args:
        model: Module called as ``model(token_ids, mask)`` returning logits
            of shape (batch, seq, vocab).
        tokenizer: Callable returning a mapping with ``'input_ids'`` that
            supports ``.to(device)``, and offering ``decode``.
        prompt_text: Text that seeds the generation.
        device: Device the prompt and masks are moved to.
        max_length_gen: Number of tokens to append to the prompt.
        temperature: Logits are divided by this value when it is positive;
            non-positive values leave the logits unscaled.
        top_k: When positive, only the ``top_k`` highest logits stay
            eligible for sampling.

    Returns:
        The decoded prompt plus generated tokens, with special tokens skipped.
    """
    # Remember the caller's mode so it is restored afterwards (even on error)
    # rather than unconditionally forcing the model back into training mode.
    was_training = model.training
    model.eval()
    try:
        inputs = tokenizer(prompt_text, add_special_tokens=False,
                           return_tensors="pt").to(device)
        generated_ids = inputs['input_ids']

        with torch.no_grad():
            for _ in range(max_length_gen):
                current_length = generated_ids.size(1)
                tgt_mask = nn.Transformer.generate_square_subsequent_mask(
                    current_length
                ).to(device)

                outputs = model(generated_ids, tgt_mask)
                next_token_logits = outputs[:, -1, :]

                if temperature > 0:
                    next_token_logits = next_token_logits / temperature

                if top_k > 0:
                    top_k_value = min(top_k, next_token_logits.size(-1))
                    # Smallest logit that is still inside the top-k set.
                    kth_logit = torch.topk(
                        next_token_logits, top_k_value
                    )[0][..., -1, None]
                    next_token_logits = next_token_logits.masked_fill(
                        next_token_logits < kth_logit, -float('Inf')
                    )

                probabilities = torch.softmax(next_token_logits, dim=-1)
                next_token_id = torch.multinomial(probabilities, num_samples=1)

                generated_ids = torch.cat(
                    [generated_ids, next_token_id], dim=-1
                )

        output_text = tokenizer.decode(generated_ids.squeeze(0),
                                       skip_special_tokens=True)
    finally:
        model.train(was_training)

    return output_text

In [None]:
from tqdm import tqdm

if __name__ == '__main__':
  print("Loading data...")

  try:
      import fugashi
  except ImportError:
      !pip install fugashi
      import fugashi

  try:
      import unidic_lite
  except ImportError:
      !pip install unidic-lite
      import unidic_lite

  aozora_urls = [
      "https://www.aozora.gr.jp/cards/000148/files/773_14560.html", # こころ
      "https://www.aozora.gr.jp/cards/000148/files/789_14547.html", # 吾輩は猫である
      "https://www.aozora.gr.jp/cards/000148/files/794_14946.html", # 三四郎
      "https://www.aozora.gr.jp/cards/000148/files/56143_50921.html", # それから
      "https://www.aozora.gr.jp/cards/000148/files/785_14971.html" # 門
  ]

  tokenizer = BertJapaneseTokenizer.from_pretrained\
   ('cl-tohoku/bert-base-japanese-whole-word-masking')
  MAX_LENGTH = 64
  STRIDE = 64

  all_input_chunks = []

  print("Loading and processing data from multiple URLs...")
  for url in aozora_urls:
      try:
          text = load_aozora_text(url)
          print(f"Processing {len(text)} characters from {url}")

          full_encoded = tokenizer(text, return_tensors='pt',\
                                   truncation=False, padding=False)
          input_ids_file = full_encoded['input_ids'].squeeze(0)

          for i in range(0, len(input_ids_file) - MAX_LENGTH + 1, STRIDE):
              chunk = input_ids_file[i:i + MAX_LENGTH]
              all_input_chunks.append(chunk)

      except Exception as e:
          print(f"Error processing {url}: {e}")

  input_ids_tensor = torch.stack(all_input_chunks)

  input_ids = input_ids_tensor[:, :-1]
  labels = input_ids_tensor[:, 1:]

  train_dataset = TensorDataset(input_ids, labels)
  train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

  print(f"[DEBUG] Total number of training samples (dataset size):\
   {len(train_dataset)}")
  print(f"[DEBUG] Number of batches in DataLoader: {len(train_loader)}")

  VOCAB_SIZE = tokenizer.vocab_size
  D_MODEL = 256
  NUM_HEADS = 8
  NUM_LAYERS = 4
  MAX_LENGTH = 64

  model = TransformerDecoderOnlyModel(VOCAB_SIZE, D_MODEL, NUM_HEADS,\
                                      NUM_LAYERS, max_seq_len=MAX_LENGTH)

  criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
  optimizer = optim.Adam(model.parameters(), lr=0.0001)

  print("Starting training...")
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  model.to(device)

  NUM_EPOCHS = 200
  for epoch in range(NUM_EPOCHS):
      model.train()
      total_loss = 0
      for batch_idx, (inputs, targets) in tqdm(enumerate(train_loader),\
                                               desc=f"Epoch\
                                                {epoch+1}/{NUM_EPOCHS}"):
          inputs, targets = inputs.to(device), targets.to(device)

          optimizer.zero_grad()

          tgt_mask = nn.Transformer.generate_square_subsequent_mask\
           (inputs.size(1)).to(device)
          outputs = model(inputs, tgt_mask)
          outputs = outputs.permute(0, 2, 1)

          loss = criterion(outputs, targets)
          loss.backward()
          optimizer.step()
          total_loss += loss.item()

      avg_loss = total_loss / len(train_loader)
      print(f"Epoch [{epoch+1}/{NUM_EPOCHS}] Loss: {avg_loss:.4f}")
      sample_prompt = "先生は私に"
      generated_sample = generate_sample_text(model, tokenizer,\
                                              sample_prompt, device)
      print(f"Output: {generated_sample}\n")
  print("Training finished.")