## Understanding the GPT-2 model

- BERT：分类模型（encoder-only，常用于文本分类等理解类任务）；
- GPT-2：生成模型（decoder-only，自回归文本生成）；参见 [Hugging Face Transformers GPT-2 documentation](https://huggingface.co/docs/transformers/en/model_doc/gpt2)

In [None]:
# Load model to local
from transformers import AutoModelForCausalLM,AutoTokenizer

# Alternative checkpoint (disabled); uncomment to fetch the lyric model instead.
# model_name = "uer/gpt2-chinese-lyric"
model_name = "uer/gpt2-chinese-poem"
# Directory passed to `from_pretrained` as `cache_dir`; the hard-coded
# snapshot paths used elsewhere in this file point inside this directory.
cache_dir = "../local_models"
# The return values are discarded: these two calls are made only so that the
# model and the tokenizer for `model_name` are loaded with `cache_dir` set
# (i.e. "load model to local").
AutoModelForCausalLM.from_pretrained(model_name, cache_dir=cache_dir)
AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)

In [None]:
# Try the GPT2 model
from transformers import GPT2LMHeadModel, BertTokenizer, TextGenerationPipeline

# Alternative local snapshots (disabled); swap one in to try another checkpoint.
# model_path="../local_models/models--uer--gpt2-chinese-cluecorpussmall/snapshots/c2c0249d8a2731f269414cc3b22dff021f8e07a3"
# model_path="../local_models/models--uer--gpt2-chinese-lyric/snapshots/4a42fd76daab07d9d7ff95c816160cfb7c21684f"
model_path="../local_models/models--uer--gpt2-chinese-poem/snapshots/6335c88ef6a3362dcdf2e988577b7bafeda6052b"
# The checkpoint is loaded as a GPT-2 LM-head model but paired with
# BertTokenizer (presumably because it ships a BERT-style vocabulary file;
# confirm against the model card).
model = GPT2LMHeadModel.from_pretrained(model_path)
tokenizer = BertTokenizer.from_pretrained(model_path)
# Wrap model + tokenizer in a text-generation pipeline that runs on the CPU.
text_generator = TextGenerationPipeline(model=model, tokenizer=tokenizer,device="cpu")

prompt = "中文GPT2大规模预训练模型"
# do_sample=True requests sampling, so the generated text is expected to
# differ from run to run; max_length=100 bounds the generated sequence length.
output = text_generator(prompt, max_length=100, do_sample=True)

# Print the model object first, then the raw pipeline result.
print(model)
print(output)

## Fine-tune a GPT-2 base model into a poem model

### Step 1: Load the dataset

In [None]:
from torch.utils.data import Dataset
from transformers import AutoTokenizer
import torch

class PoemDataset(Dataset):
    """Line-oriented text dataset: every line of the file is one sample.

    Each line is stripped of leading/trailing whitespace. Blank lines are kept
    and become empty-string samples. The stripped lines are stored in
    ``self.text``.
    """

    def __init__(self, file_path):
        """Read ``file_path`` as UTF-8 and keep its stripped lines in memory.

        Custom pre-processing of the raw lines can be added here.
        """
        with open(file_path, encoding="utf-8") as handle:
            # Iterating the handle yields the same lines as ``readlines()``.
            self.text = [raw_line.strip() for raw_line in handle]

    def __len__(self):
        """Return the number of stored lines."""
        return len(self.text)

    def __getitem__(self, item):
        """Return the line(s) selected by ``item``.

        ``item`` is forwarded to the underlying list, so both integer indices
        and slices are accepted.
        """
        return self.text[item]

# Build the training dataset: one sample per line of the poem text file.
dataset_train = PoemDataset(file_path="../local_datasets/Poem/chinese_poems.txt")
# Preview the first five samples. Slicing works because
# PoemDataset.__getitem__ passes the index straight to its internal list.
for data in dataset_train[:5]:
    print(data)

# Rebind `model_path` / `tokenizer` to the cluecorpussmall snapshot. The
# module-level `tokenizer` defined here is the one `collate_fn` reads.
model_path="../local_models/models--uer--gpt2-chinese-cluecorpussmall/snapshots/c2c0249d8a2731f269414cc3b22dff021f8e07a3"
tokenizer = AutoTokenizer.from_pretrained(model_path)
def collate_fn(data):
    """Tokenize a batch of raw text samples into padded tensors for training.

    Uses the module-level ``tokenizer``.

    Args:
        data: The list of text samples that make up one batch.

    Returns:
        The tokenizer's batch encoding as PyTorch tensors, with an extra
        ``labels`` entry that is an independent copy of ``input_ids``.
    """
    # Call the tokenizer directly: `batch_encode_plus` is documented as
    # deprecated in favour of `__call__`, which takes the same arguments.
    batch = tokenizer(
        data,
        return_tensors="pt",
        padding=True,      # pad every sample to the longest one in the batch
        truncation=True,   # cut samples longer than `max_length`
        max_length=512,
    )
    # Labels start as a copy of the inputs; the training loop compares the
    # prediction at position t with the label at position t + 1.
    batch["labels"] = batch["input_ids"].clone()
    return batch

# Batch the dataset for training: 4 samples per batch, reshuffled by the
# loader, with a trailing incomplete batch dropped (drop_last=True), and
# `collate_fn` turning each list of text samples into tensors.
dataloader = torch.utils.data.DataLoader(
    dataset_train,
    batch_size=4,
    shuffle=True,
    drop_last=True,
    collate_fn=collate_fn
)
# Number of samples (lines) in the dataset, not the number of batches.
print(f"Dataset length: {len(dataset_train)}")

### Step 2: Training the model

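- BERT: incremental training — typically only a small task-specific head is trained on top of the pretrained model;
- GPT-2: full training — all of the model's parameters are updated (the optimizer below receives `model.parameters()`);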

In [None]:
from transformers import AutoModelForCausalLM, AdamW
from transformers.optimization import get_scheduler
import torch

# Start from the cluecorpussmall snapshot (the same path the training
# tokenizer is loaded from). The module-level `model` bound here is the
# object that `run_train` trains.
model_path="../local_models/models--uer--gpt2-chinese-cluecorpussmall/snapshots/c2c0249d8a2731f269414cc3b22dff021f8e07a3"
model = AutoModelForCausalLM.from_pretrained(model_path)

def run_train(epochs=30000, lr=2e-5, save_path="params/model.pt"):
    """Fine-tune the module-level ``model`` on batches from ``dataloader``.

    All model parameters are optimized with AdamW under a linear
    learning-rate decay that spans the whole run. Progress is printed every
    100 iterations and the model's ``state_dict`` is written to ``save_path``
    at the end of every epoch, overwriting the previous file.

    Args:
        epochs: Number of passes over ``dataloader``.
        lr: Initial learning rate (2e-5 to 5e-5 is the range noted for this
            setup).
        save_path: File the ``state_dict`` is saved to after each epoch. Its
            parent directory is created if it does not exist.

    Returns:
        None.
    """
    import os  # local import: only needed to create the checkpoint directory

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)  # moves the module in place, so no `global` rebinding is needed

    # NOTE(review): `AdamW` is the name this file imports from `transformers`;
    # newer transformers releases point users to `torch.optim.AdamW` instead.
    # Confirm against the installed version.
    optimizer = AdamW(model.parameters(), lr=lr)

    # The schedule must cover every optimizer step of the run. Sizing it to a
    # single epoch would decay the learning rate to 0 after the first epoch
    # and leave all later epochs with nothing to learn from.
    steps_per_epoch = len(dataloader)
    scheduler = get_scheduler(
        "linear",
        optimizer=optimizer,
        num_warmup_steps=0,
        num_training_steps=steps_per_epoch * epochs,
    )

    # Make sure the checkpoint directory exists before the first save.
    save_dir = os.path.dirname(save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    model.train()
    for epoch in range(epochs):
        for i, data in enumerate(dataloader):
            batch = {key: value.to(device) for key, value in data.items()}

            # Exclude padding from the loss: labels set to -100 are ignored by
            # the language-modeling loss. Without this, every padded position
            # is trained as a "predict <PAD>" target.
            attention_mask = batch.get("attention_mask")
            if attention_mask is not None:
                batch["labels"] = batch["labels"].masked_fill(attention_mask == 0, -100)

            outputs = model(**batch)
            loss = outputs.loss
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            scheduler.step()

            # Reset gradients. The optimizer holds all model parameters, so a
            # separate model.zero_grad() would be redundant.
            optimizer.zero_grad()

            if i % 100 == 0:
                # The prediction at position t is scored against the token at t + 1.
                targets = batch["labels"][:, 1:]
                predictions = outputs["logits"].argmax(dim=2)[:, :-1]
                # Keep only real tokens: skip ignored (-100) and <PAD> (id 0) targets.
                keep = (targets != -100) & (targets != 0)
                kept_targets = targets[keep]
                n_tokens = kept_targets.numel()
                n_correct = (kept_targets == predictions[keep]).sum().item()
                # Guard against a batch with no scorable tokens.
                accuracy = n_correct / n_tokens if n_tokens else 0.0
                lr = optimizer.param_groups[0]["lr"]

                print(f"Epoch: {epoch}, Iteration: {i}, Loss: {loss.item()}, lr: {lr}, Accuracy: {accuracy}")

        torch.save(model.state_dict(), save_path)
        print("Model saved!")

# Trigger the training
# run_train()