
# **Fine-Tuning with Adopted Junior High School Textbooks**

## **Installing and Importing Libraries**

In [None]:
!pip install transformers
!pip install pytorch-lightning

## **Loading the Fine-Tuning Corpus**

In [2]:
import urllib.request
txt_url = "https://raw.githubusercontent.com/m37335/kanagawa-exam/master/data/textbook.txt"
urllib.request.urlretrieve(txt_url, 'train.txt')

('train.txt', <http.client.HTTPMessage at 0x7feee858a610>)

## **Loading the Model and Tokenizer**

In [3]:
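import pytorch_lightning as pl
from argparse import Namespace
from transformers import (
    BertConfig,
    BertForMaskedLM,
    BertTokenizer,
    DataCollatorForLanguageModeling
)
import torch
import torch.nn as nn
from torch.optim import AdamW  # transformers.AdamW is deprecated; use PyTorch's implementation
from torch.utils.data import DataLoader, Dataset

# Hyperparameters shared by the cells below
args = Namespace()
args.train = "train.txt"        # corpus downloaded above
args.max_len = 128              # max tokens per line, including [CLS] and [SEP]
args.model_name = "bert-base-uncased"
args.epochs = 1
args.batch_size = 4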

### **Building the Dataset**

In [4]:
tokenizer = BertTokenizer.from_pretrained(args.model_name)

class MaskedLMDataset(Dataset):
    """One example per non-blank line of the corpus, stored as token ids."""

    def __init__(self, file, tokenizer):
        self.tokenizer = tokenizer
        self.lines = self.load_lines(file)
        self.ids = self.encode_lines(self.lines)

    def load_lines(self, file):
        # Skip empty and whitespace-only lines
        with open(file, encoding="utf-8") as f:
            lines = [
                line
                for line in f.read().splitlines()
                if (len(line) > 0 and not line.isspace())
            ]
        return lines

    def encode_lines(self, lines):
        # Add [CLS]/[SEP] and truncate each line to at most args.max_len tokens
        batch_encoding = self.tokenizer(
            lines, add_special_tokens=True, truncation=True, max_length=args.max_len
        )
        return batch_encoding["input_ids"]

    def __len__(self):
        return len(self.lines)

    def __getitem__(self, idx):
        # Variable-length 1-D tensor; padding is done later by the data collator
        return torch.tensor(self.ids[idx], dtype=torch.long)

train_dataset = MaskedLMDataset(args.train, tokenizer)
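The filtering and truncation in `MaskedLMDataset` can be checked without downloading a tokenizer. The sketch below is a toy: `load_nonblank` applies the same blank-line test as `load_lines`, while `toy_encode` is a hypothetical stand-in that splits on whitespace instead of using WordPiece. What it does reproduce is that, with `add_special_tokens=True` and `truncation=True`, `[CLS]` and `[SEP]` count against `max_length`, so a line keeps at most `max_len - 2` content tokens.

```python
def load_nonblank(text: str) -> list[str]:
    """Keep lines that are non-empty and not whitespace-only (same test as load_lines)."""
    return [line for line in text.splitlines() if len(line) > 0 and not line.isspace()]


def toy_encode(line: str, max_len: int) -> list[str]:
    """Hypothetical whitespace 'tokenizer': add [CLS]/[SEP] and truncate the body
    so the total length is at most max_len (for max_len >= 2)."""
    body = line.lower().split()[: max(max_len - 2, 0)]  # reserve 2 slots for specials
    return ["[CLS]", *body, "[SEP]"]


corpus = "Hello world.\n\n   \nThis is a much longer line of text.\n"
lines = load_nonblank(corpus)  # the empty and whitespace-only lines are dropped
encoded = [toy_encode(line, max_len=6) for line in lines]
for tokens in encoded:
    print(len(tokens), tokens)
```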

## **Training Setup: Masking and Batching**

In [11]:
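# Pad each batch and randomly mask 15% of the tokens for the MLM objective
data_collator = DataCollatorForLanguageModeling(
    tokenizer=tokenizer, mlm=True, mlm_probability=0.15
)

train_loader = DataLoader(
    train_dataset,
    batch_size=args.batch_size,  # use the batch size defined in args
    shuffle=True,                # reshuffle lines every epoch
    collate_fn=data_collator
)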
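`DataCollatorForLanguageModeling` with `mlm=True, mlm_probability=0.15` pads each batch and builds the MLM targets. Below is a pure-Python sketch of the masking rule, assuming the standard BERT scheme the collator uses by default: each non-special token is selected with probability 0.15; a selected token becomes `[MASK]` 80% of the time, a random token 10% of the time, and stays unchanged otherwise; unselected positions get label -100. `mask_tokens` is a hypothetical helper on plain lists, not the library's API, and `MASK_ID = 103` assumes the bert-base-uncased vocabulary.

```python
import random

MASK_ID = 103  # [MASK] id in the bert-base-uncased vocabulary (assumption for this sketch)
IGNORE = -100  # label value that the loss skips


def mask_tokens(ids, special, vocab_size, mlm_probability=0.15, rng=None):
    """BERT-style masking on one sequence; returns (masked_ids, labels)."""
    rng = rng or random.Random()
    masked, labels = list(ids), []
    for i, tok in enumerate(ids):
        if special[i] or rng.random() >= mlm_probability:
            labels.append(IGNORE)  # not selected (or special): no loss here
            continue
        labels.append(tok)  # selected: the model must recover the original id
        r = rng.random()
        if r < 0.8:
            masked[i] = MASK_ID  # 80%: replace with [MASK]
        elif r < 0.9:
            masked[i] = rng.randrange(vocab_size)  # 10%: random token
        # remaining 10%: keep the original token
    return masked, labels


ids = [101, 7592, 2088, 1012, 102]  # hypothetical ids, [CLS] ... [SEP]
special = [True, False, False, False, True]
masked, labels = mask_tokens(ids, special, vocab_size=30522, rng=random.Random(0))
print(masked, labels)
```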

### **LightningModule and Optimizer (AdamW)**

In [12]:
class Bert(pl.LightningModule):

    def __init__(self):
        super().__init__()
        self.bert = BertForMaskedLM.from_pretrained(args.model_name)

    def forward(self, input_ids, labels):
        return self.bert(input_ids=input_ids, labels=labels)

    def training_step(self, batch, batch_idx):
        # The collator yields masked input_ids and labels (-100 at unmasked positions)
        input_ids = batch["input_ids"]
        labels = batch["labels"]
        outputs = self(input_ids=input_ids, labels=labels)
        loss = outputs.loss  # masked-LM cross-entropy
        self.log("train_loss", loss, prog_bar=True)
        return loss

    def configure_optimizers(self):
        return torch.optim.AdamW(self.parameters(), lr=1e-5)

model = Bert()

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
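`configure_optimizers` returns AdamW with a learning rate of 1e-5. The sketch below spells out one AdamW update for a single scalar parameter, following the update rule documented for `torch.optim.AdamW` (decoupled weight decay, then bias-corrected first and second moments). `adamw_step` is a hypothetical helper, and its `weight_decay=0.01` default is an assumption matching PyTorch's default; the deprecated `transformers.AdamW` defaulted to 0.0, and the notebook does not set it.

```python
import math


def adamw_step(p, g, m, v, t, lr=1e-5, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.01):
    """One AdamW update for a scalar parameter p with gradient g at step t (1-based).

    Returns the updated (p, m, v).
    """
    p = p - lr * weight_decay * p          # decoupled weight decay on the weight itself
    m = beta1 * m + (1 - beta1) * g        # first moment: running mean of gradients
    v = beta2 * v + (1 - beta2) * g * g    # second moment: running mean of squared gradients
    m_hat = m / (1 - beta1 ** t)           # bias correction for the zero initialisation
    v_hat = v / (1 - beta2 ** t)
    p = p - lr * m_hat / (math.sqrt(v_hat) + eps)
    return p, m, v


p, m, v = adamw_step(1.0, g=0.5, m=0.0, v=0.0, t=1, lr=0.1, weight_decay=0.0)
print(p)  # close to 0.9: on the first step Adam moves by about lr * sign(g)
```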


### **Training and Fine-Tuning the Model**

In [20]:
trainer = pl.Trainer(max_epochs=args.epochs, accelerator="auto", devices=1)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores


In [None]:
trainer.fit(model, train_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type            | Params
-----------------------------------------
0 | bert | BertForMaskedLM | 109 M 
-----------------------------------------
109 M     Trainable params
0         Non-trainable params
109 M     Total params
438.057   Total estimated model params size (MB)
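The summary figures can be sanity-checked with simple arithmetic. Assuming 32-bit weights (4 bytes each) and MB meaning 10^6 bytes, 438.057 MB corresponds to roughly 109.51 M parameters, consistent with the reported 109 M. The number of optimizer steps per epoch is ceil(N / batch size) when the last partial batch is kept, which is the `DataLoader` default; the corpus size N is not shown in the notebook, so the 1,000 lines below are hypothetical.

```python
import math


def steps_per_epoch(num_examples: int, batch_size: int) -> int:
    """Optimizer steps in one epoch when the last partial batch is kept (drop_last=False)."""
    return math.ceil(num_examples / batch_size)


def params_from_size_mb(size_mb: float, bytes_per_param: int = 4) -> float:
    """Invert size_mb = params * bytes_per_param / 1e6 (float32 assumed by default)."""
    return size_mb * 1e6 / bytes_per_param


print(steps_per_epoch(1000, 4))                        # 250
print(round(params_from_size_mb(438.057) / 1e6, 2))    # 109.51
```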

### **Saving the Model**

In [None]:
from google.colab import drive
drive.mount("/content/drive/")

In [None]:
import torch
model_path = '/content/drive/My Drive/0004_【BERT】テキスト分類/bert_nlp/fintuninged_model.bin'
torch.save(model.state_dict(), model_path)

In [1]:
import torch
torch.save(model.state_dict(), 'saved.bin')

## **Evaluating the Model**

In [None]:
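import torch
import torch.nn as nn
from transformers import BertForMaskedLM

class BertPred(nn.Module):
    """Plain nn.Module with the same attribute layout as the Lightning `Bert` module,
    so the keys stored in saved.bin line up."""

    def __init__(self):
        super().__init__()
        self.bert = BertForMaskedLM.from_pretrained('bert-base-uncased')

    def forward(self, input_ids, labels=None):
        return self.bert(input_ids=input_ids, labels=labels)

new_model = BertPred()
new_model.load_state_dict(torch.load('saved.bin', map_location='cpu'))
new_model.eval()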
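`BertPred` loads `saved.bin` because it has the same attribute layout as the Lightning `Bert` module: both hold a `BertForMaskedLM` under `self.bert`, so every saved key starts with `bert.` (e.g. `bert.bert.embeddings...`, `bert.cls.predictions...`). If you would rather load the weights into a bare `BertForMaskedLM`, one option, sketched here with a hypothetical `strip_prefix` helper, is to drop that leading `bert.` from each key.

```python
def strip_prefix(state_dict: dict, prefix: str) -> dict:
    """Return a copy with `prefix` removed from every key that starts with it."""
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Shortened example keys as saved from the Lightning `Bert` module (values are placeholders)
saved = {"bert.bert.embeddings.word_embeddings.weight": "W", "bert.cls.predictions.bias": "b"}
plain = strip_prefix(saved, "bert.")
print(plain)
```

Under those assumptions, `mlm = BertForMaskedLM.from_pretrained('bert-base-uncased')` followed by `mlm.load_state_dict(strip_prefix(torch.load('saved.bin', map_location='cpu'), 'bert.'))` should restore the fine-tuned weights into a plain model.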

In [None]:
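from transformers import BertTokenizer, BertForMaskedLM
import torch

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# This loads the original pretrained weights, not the fine-tuned ones;
# use new_model.bert here to evaluate the fine-tuned model instead.
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
model.eval()

inputs = tokenizer("I'm learning about [MASK] Japanese events like Setsubun and Hinamatsuri for my speech next week.", return_tensors="pt")

labels = tokenizer("I'm learning about traditional Japanese events like Setsubun and Hinamatsuri for my speech next week.", return_tensors="pt")["input_ids"]
# Score only the [MASK] position: -100 tells the loss to ignore every other token
labels = torch.where(inputs["input_ids"] == tokenizer.mask_token_id, labels, -100)

with torch.no_grad():
    outputs = model(**inputs, labels=labels)
loss = outputs.loss
logits = outputs.logits

# Decode the highest-scoring token at the [MASK] position
mask_index = (inputs["input_ids"][0] == tokenizer.mask_token_id).nonzero(as_tuple=True)[0]
predicted_id = logits[0, mask_index].argmax(dim=-1)
print(loss.item(), tokenizer.decode(predicted_id))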
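`BertForMaskedLM` computes its loss with a cross-entropy whose `ignore_index` is -100, so positions labelled -100 contribute nothing. If only the `[MASK]` position keeps its true label, `outputs.loss` measures how well the model predicts that one word, and the argmax of `logits` at that position is its guess. Below is a pure-Python sketch of both on toy scores; `masked_lm_loss` and `predict` are hypothetical helpers, not library functions.

```python
import math

IGNORE = -100  # label value skipped by the loss


def masked_lm_loss(logits, labels):
    """Mean cross-entropy over positions whose label is not IGNORE.

    logits: one list of vocabulary scores per position; labels: one int per position.
    Returns nan when every label is IGNORE, as PyTorch's mean reduction does.
    """
    total, count = 0.0, 0
    for scores, label in zip(logits, labels):
        if label == IGNORE:
            continue  # unmasked positions contribute nothing
        peak = max(scores)  # subtract the max for a numerically stable log-sum-exp
        log_norm = peak + math.log(sum(math.exp(s - peak) for s in scores))
        total += log_norm - scores[label]  # equals -log(softmax(scores)[label])
        count += 1
    return total / count if count else float("nan")


def predict(scores):
    """Index of the highest-scoring token (the model's guess for one [MASK])."""
    return max(range(len(scores)), key=scores.__getitem__)


logits = [[0.0, 0.0, 0.0], [2.0, 0.0, 1.0]]  # toy scores: 2 positions, vocabulary of 3
labels = [IGNORE, 0]  # only position 1 is a masked token
print(masked_lm_loss(logits, labels), predict(logits[1]))
```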