This notebook is adapted from several sources. I'd like to recognize the following authors for their contributions:


It is also worth noting that, as far as I can tell, the DialoGPT model is no longer maintained and has been superseded by the GODEL model.

# Setup Steps

We start with the imports. First, we'll install transformers in case you haven't already.

In [27]:
! pip -q install transformers

Then we'll import the libraries we need.

In [28]:
# hide
# Imports


"""
Fine-tuning the library models for language modeling on a text file (GPT, GPT-2, BERT, RoBERTa).
GPT and GPT-2 are fine-tuned using a causal language modeling (CLM) loss while BERT and RoBERTa are fine-tuned
using a masked language modeling (MLM) loss.
"""

import glob
import logging
import os
import pickle
import random
import re
import shutil
from typing import Dict, List, Tuple

import pandas as pd
import numpy as np
import torch
import data_context


from datasets import load_dataset
from sklearn.model_selection import train_test_split

from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
from torch.utils.data.distributed import DistributedSampler
from tqdm.notebook import tqdm, trange

from pathlib import Path

from transformers import (
    MODEL_WITH_LM_HEAD_MAPPING,
    WEIGHTS_NAME,
    AdamW,
    AutoConfig,
    AutoModelWithLMHead,
    AutoTokenizer,
    PreTrainedModel,
    PreTrainedTokenizer,
    get_linear_schedule_with_warmup,
)


try:
    from torch.utils.tensorboard import SummaryWriter
except ImportError:
    from tensorboardX import SummaryWriter

# Configs
logger = logging.getLogger(__name__)

MODEL_CONFIG_CLASSES = list(MODEL_WITH_LM_HEAD_MAPPING.keys())
MODEL_TYPES = tuple(conf.model_type for conf in MODEL_CONFIG_CLASSES)

Next, we'll define a set of arguments in one place so that the notebook is easier to configure and run.

In [29]:
# collapse
# Args to allow for easy conversion of the Python script to a notebook
class Args():
    def __init__(self):
        self.output_dir = 'output'
        self.model_type = 'gpt2'
        self.model_name_or_path = 'microsoft/DialoGPT-medium'
        self.config_name = 'microsoft/DialoGPT-medium'
        self.tokenizer_name = 'microsoft/DialoGPT-medium'
        self.cache_dir = 'cached'
        self.block_size = 512
        self.do_train = True
        self.do_eval = True
        self.evaluate_during_training = False
        self.per_gpu_train_batch_size = 4
        self.per_gpu_eval_batch_size = 4
        self.gradient_accumulation_steps = 1
        self.learning_rate = 5e-5
        self.weight_decay = 0.0
        self.adam_epsilon = 1e-8
        self.max_grad_norm = 1.0
        self.num_train_epochs = 3
        self.max_steps = -1
        self.warmup_steps = 0
        self.logging_steps = 1000
        self.save_steps = 3500
        self.save_total_limit = None
        self.eval_all_checkpoints = False
        self.no_cuda = False
        self.overwrite_output_dir = True
        self.overwrite_cache = True
        self.should_continue = False
        self.seed = 42
        self.local_rank = -1
        self.fp16 = False
        self.fp16_opt_level = 'O1'

args = Args()
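As a rough illustration of how a few of these settings interact, the sketch below estimates how many optimizer updates a run would perform. It follows the arithmetic commonly used in training loops of this style; it is not taken from this notebook's training code. The `planned_optimizer_steps` helper, the `SimpleNamespace` stand-in for `Args`, and the figure of 140 training examples are all assumptions made for the example.

```python
import math
from types import SimpleNamespace

# Stand-in for the Args object above; only the fields used here are copied.
args = SimpleNamespace(
    per_gpu_train_batch_size=4,
    gradient_accumulation_steps=1,
    num_train_epochs=3,
    max_steps=-1,
)


def planned_optimizer_steps(args, num_examples, n_gpu=1):
    """Hypothetical helper: estimate the number of optimizer updates in a run."""
    # A positive max_steps overrides the epoch-based count.
    if args.max_steps > 0:
        return args.max_steps
    # Effective batch size grows with the number of GPUs (at least 1 device).
    batch_size = args.per_gpu_train_batch_size * max(1, n_gpu)
    batches_per_epoch = math.ceil(num_examples / batch_size)
    # One update happens every `gradient_accumulation_steps` batches.
    updates_per_epoch = batches_per_epoch // args.gradient_accumulation_steps
    return updates_per_epoch * args.num_train_epochs


total = planned_optimizer_steps(args, num_examples=140)  # 140 is a made-up size
print(total)  # 105
```

With these defaults, values such as `logging_steps = 1000` and `save_steps = 3500` would never be reached on a dataset this small, so it can be worth checking the estimate before a long run.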

# Data Preparation

In this step, we'll load a Hugging Face dataset and prepare it for training. We'll also define a function to tokenize the data. To work properly, your data should have one row per line of dialogue, with a speaker column and a text column, like this:

| speaker | text |
|---------|------|
| A       | Hi   |
| B       | Hello|
| A       | How are you?|
| B       | I'm good, how are you?|

And so on. We'll identify the speaker you want to fine-tune on and build the appropriate context for each of that speaker's lines.
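To make the idea of "context" concrete, here is a minimal sketch of one way to pair each line from the chosen speaker with the lines that came just before it. This is not the code inside `data_context`; the `build_context` function and its parameters are hypothetical. The output column names (`character`, `quote`, `context/0`) mirror the sample output shown later in this notebook. Rows are collected in a list and turned into a DataFrame once, rather than appended one at a time.

```python
import pandas as pd


def build_context(df, speaker, n_context=1, speaker_col="speaker", text_col="text"):
    """Hypothetical helper: one row per line by `speaker`, plus preceding lines."""
    rows = []
    # Start at n_context so every kept line has enough history before it.
    for i in range(n_context, len(df)):
        if df.iloc[i][speaker_col] != speaker:
            continue
        row = {"character": speaker, "quote": df.iloc[i][text_col]}
        # context/0 is the line immediately before the quote, context/1 the one before that.
        for k in range(n_context):
            row[f"context/{k}"] = df.iloc[i - 1 - k][text_col]
        rows.append(row)
    return pd.DataFrame(rows)


# The toy conversation from the table above.
dialogue = pd.DataFrame(
    {
        "speaker": ["A", "B", "A", "B"],
        "text": ["Hi", "Hello", "How are you?", "I'm good, how are you?"],
    }
)

context_df = build_context(dialogue, speaker="B", n_context=1)
print(context_df)
```

For speaker `B`, this yields two rows: "Hello" with the context "Hi", and "I'm good, how are you?" with the context "How are you?".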

Now we'll import the data and do some preprocessing to put it into the format we want. Follow the pop-up prompts. Depending on your editor, you may need to keep scrolling this cell's output to see the prompts.

In [32]:
trn_df, val_df = data_context.data_setup()


Welcome to this Hugging Face GPT-Medium model trainer. 
It will give you step-by-step prompts to prepare your data for fine-tuning. 
A couple things to note: spelling counts, and you need to use Hugging Face to accomplish this.


Using custom data configuration andrewkroening--Star-wars-scripts-dialogue-IV-VI-ade7d2bca9d9e7f4
Found cached dataset json (/home/codespace/.cache/huggingface/datasets/andrewkroening___json/andrewkroening--Star-wars-scripts-dialogue-IV-VI-ade7d2bca9d9e7f4/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab)



Here is a sample of the dataset: 
              Character                                               Line
427                LEIA  The more you tighten your grip, Tarkin, the mo...
1532               YODA                                       Concentrate!
1134  SENIOR CONTROLLER  No. Wait -- there's something very weak coming...
1167               LUKE                                          You, too.
2366              VADER                                    His lightsaber.
1345                HAN  I'm going to shut down everything but the emer...
2131               LUKE                                 Hi, Han... Chewie.
406             TROOPER                All right, men.  Load your weapons!
1151              VADER                               You found something?
183                LUKE  Yes, sir.  I think those new droids are going ...

The available columns are: Index(['Character', 'Line'], dtype='object')

The available characters are: 
LUKE        494
HAN         459
THRE

  quotes_context_df = quotes_context_df.append(


Here is a sample of the context: 
    character                                              quote  \
0       VADER     Where are those transmissions you intercepted?   
1       VADER  If this is a consular ship... where is the Amb...   
2       VADER  Commander, tear this ship apart until you've f...   
3       VADER  Don't play games with me, Your Highness.  You ...   
4       VADER  You're a part of the Rebel Alliance... and a t...   
..        ...                                                ...   
135     VADER                     You cannot hide forever, Luke.   
136     VADER  Give yourself to the dark side. It is the only...   
137     VADER  Sister! So...you have a twin sister. Your feel...   
138     VADER                  Luke, help me take this mask off.   
139     VADER  Nothing can stop that now. Just for once... le...   

                                             context/0  \
0    The Death Star plans are not in the main compu...   
1    We intercepted no transmiss

In [None]:
trn_df.sample(10)