In [1]:
#comment this if you are not using AIT proxy...
import os
os.environ['http_proxy']  = 'http://192.41.170.23:3128'
os.environ['https_proxy'] = 'http://192.41.170.23:3128'

In [2]:
# coding: utf-8
import argparse
import time
import math
import os
import torch
import torch.nn as nn
import torch.onnx
from tqdm import tqdm
from statistics import mean
import math

In [3]:
import torch.nn as nn
import torch
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2TokenizerFast

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

#make our work comparable if restarted the kernel
SEED = 1234
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

  from .autonotebook import tqdm as notebook_tqdm


cuda:0


In [4]:
tokenizer = GPT2TokenizerFast.from_pretrained('gpt2')
PAD_TOKEN = '<pad>'
tokenizer.add_special_tokens({'pad_token': PAD_TOKEN})
tokenizer

GPT2TokenizerFast(name_or_path='gpt2', vocab_size=50257, model_max_length=1024, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<|endoftext|>', 'eos_token': '<|endoftext|>', 'unk_token': '<|endoftext|>', 'pad_token': '<pad>'}, clean_up_tokenization_spaces=True)

In [5]:
ntokens = tokenizer.vocab_size
ntokens

50257

In [6]:
PAD_TOKEN_ID = tokenizer.encode(PAD_TOKEN)[0]
BOS_TOKEN_ID = tokenizer.encode(tokenizer.bos_token)[0]
PAD_TOKEN_ID, BOS_TOKEN_ID

(50257, 50256)

In [7]:
from datasets import load_dataset
raw_dataset = load_dataset('wikitext', 'wikitext-2-raw-v1')
raw_dataset

Found cached dataset wikitext (/home/todsavadt/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126)
100%|██████████| 3/3 [00:00<00:00, 1637.55it/s]


DatasetDict({
    test: Dataset({
        features: ['text'],
        num_rows: 4358
    })
    train: Dataset({
        features: ['text'],
        num_rows: 36718
    })
    validation: Dataset({
        features: ['text'],
        num_rows: 3760
    })
})

In [8]:
def tokenize_function(example):
    outputs =  tokenizer(example['text'], truncation=True, padding='max_length')
    input_batch = []
    for input_ids in outputs["input_ids"]:
        input_batch.append(input_ids)
    return {"input_ids": input_batch}

tokenized_datasets = raw_dataset.map(
    tokenize_function, batched=True, remove_columns=raw_dataset["train"].column_names
)
tokenized_datasets

Loading cached processed dataset at /home/todsavadt/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-7aa8bc10b189e3f8.arrow
Loading cached processed dataset at /home/todsavadt/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-02db874b365b79c1.arrow
Loading cached processed dataset at /home/todsavadt/.cache/huggingface/datasets/wikitext/wikitext-2-raw-v1/1.0.0/a241db52902eaf2c6aa732210bead40c090019a499ceb13bcbfa3f8ab646a126/cache-11f4e6db4ab29da3.arrow


DatasetDict({
    test: Dataset({
        features: ['input_ids'],
        num_rows: 4358
    })
    train: Dataset({
        features: ['input_ids'],
        num_rows: 36718
    })
    validation: Dataset({
        features: ['input_ids'],
        num_rows: 3760
    })
})

In [9]:
len(tokenized_datasets['train']['input_ids'][1]) #longest token

1024

In [10]:
from transformers import AutoTokenizer, GPT2LMHeadModel, AutoConfig
context_length  = 1024
config = AutoConfig.from_pretrained(
    "gpt2",
    vocab_size=len(tokenizer),
    n_ctx=context_length,
    bos_token_id=tokenizer.bos_token_id,
    eos_token_id=tokenizer.eos_token_id,
)
model = GPT2LMHeadModel(config)

In [11]:
model_size = sum(t.numel() for t in model.parameters())
print(f"GPT-2 size: {model_size/1000**2:.1f}M parameters")

GPT-2 size: 124.4M parameters


In [12]:
from torch.utils.data.dataloader import DataLoader
batch_size = 8
tokenized_datasets.set_format("torch")
train_dataloader = DataLoader(tokenized_datasets["train"], batch_size=batch_size, shuffle=True)
eval_dataloader  = DataLoader(tokenized_datasets["validation"], batch_size=batch_size)
test_dataloader  = DataLoader(tokenized_datasets["test"], batch_size=batch_size)

In [13]:
for i in train_dataloader:
    i['input_ids']
    print(i['input_ids'].shape)
    break
for i in eval_dataloader:
    print(i['input_ids'].shape)
    break
for i in test_dataloader:
    print(i['input_ids'].shape)
    break

torch.Size([8, 1024])
torch.Size([8, 1024])
torch.Size([8, 1024])


In [14]:
from torch.optim import AdamW
optimizer = AdamW(model.parameters(), lr=5e-5)

In [15]:
from transformers import get_scheduler

num_train_epochs = 1
num_update_steps_per_epoch = len(train_dataloader)
num_training_steps = num_train_epochs * num_update_steps_per_epoch

lr_scheduler = get_scheduler(
    name="linear",
    optimizer=optimizer,
    num_warmup_steps=1_000,
    num_training_steps=num_training_steps,
)

In [16]:
from accelerate import Accelerator

accelerator = Accelerator()

model, optimizer, train_dataloader, eval_dataloader = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader
)
model = model.to(device)

In [17]:
def train():
    progress_bar = tqdm(range(num_training_steps))
    gradient_accumulation_steps = 8
    model.to(device)
    eval_steps = 2
    model.train()
    completed_steps = 0
    for epoch in range(num_train_epochs):
        total_loss = 0
        for step, batch in enumerate(train_dataloader):
            inputs = batch['input_ids'].to(device)
            labels = batch['input_ids'].to(device)
    
            optimizer.zero_grad()
            outputs = model(inputs, labels=labels)
            logits = outputs.logits      
            loss = outputs.loss
            loss = loss / gradient_accumulation_steps
            # print(loss)
            accelerator.backward(loss) #instance of optimize.backward()
    
            if step % gradient_accumulation_steps == 0:
                accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
                completed_steps += 1
            
            if (step % (eval_steps * gradient_accumulation_steps)) == 0:
                eval_loss, perplexity = evaluate()
                accelerator.print({"loss/eval": eval_loss, "perplexity": perplexity})
                model.train()
                
            progress_bar.update(1)
            total_loss += loss.item()
    
        avg_loss = total_loss / len(train_loader)
        print(f'Epoch {epoch + 1}/{epochs} - Average Loss: {avg_loss:.4f}')

In [18]:
from tqdm.auto import tqdm
num_update_steps_per_epoch = len(eval_dataloader)

def evaluate():
    model.eval()
    losses = []
    progress_bar = tqdm(range(num_update_steps_per_epoch))
    for step, batch in enumerate(eval_dataloader):
        with torch.no_grad():
            inputs = batch["input_ids"].to(device)
            labels = batch["input_ids"].to(device)
            outputs = model(inputs, labels=labels)
            outputs.loss = outputs.loss.reshape(1)
            progress_bar.update(1)
        losses.append(accelerator.gather(outputs.loss))        
    loss = torch.mean(torch.cat(losses))
    try:
        perplexity = torch.exp(loss)
    except OverflowError:
        perplexity = float("inf")
    return loss.item(), perplexity.item()

In [19]:
evaluate()

100%|██████████| 470/470 [01:42<00:00,  4.59it/s]


(9.97258472442627, 21430.806640625)

In [19]:
train()

  0%|          | 0/4590 [00:00<?, ?it/s]
  0%|          | 0/470 [00:00<?, ?it/s][A
  0%|          | 1/470 [00:00<01:21,  5.79it/s][A
  0%|          | 2/470 [00:00<01:32,  5.04it/s][A
  1%|          | 3/470 [00:00<01:36,  4.86it/s][A
  1%|          | 4/470 [00:00<01:37,  4.77it/s][A
  1%|          | 5/470 [00:01<01:38,  4.72it/s][A
  1%|▏         | 6/470 [00:01<01:38,  4.70it/s][A
  1%|▏         | 7/470 [00:01<01:39,  4.67it/s][A
  2%|▏         | 8/470 [00:01<01:39,  4.66it/s][A
  2%|▏         | 9/470 [00:01<01:39,  4.65it/s][A
  2%|▏         | 10/470 [00:02<01:39,  4.65it/s][A
  2%|▏         | 11/470 [00:02<01:38,  4.65it/s][A
  3%|▎         | 12/470 [00:02<01:38,  4.64it/s][A
  3%|▎         | 13/470 [00:02<01:38,  4.64it/s][A
  3%|▎         | 14/470 [00:02<01:38,  4.64it/s][A
  3%|▎         | 15/470 [00:03<01:38,  4.64it/s][A
  3%|▎         | 16/470 [00:03<01:37,  4.64it/s][A
  4%|▎         | 17/470 [00:03<01:37,  4.64it/s][A
  4%|▍         | 18/470 [00:03<01:37,  4.

{'loss/eval': 9.97258472442627, 'perplexity': 21430.806640625}


  0%|          | 16/4590 [01:54<1:01:55,  1.23it/s]  
  0%|          | 0/470 [00:00<?, ?it/s][A
  0%|          | 1/470 [00:00<01:22,  5.67it/s][A
  0%|          | 2/470 [00:00<01:34,  4.93it/s][A
  1%|          | 3/470 [00:00<01:38,  4.74it/s][A
  1%|          | 4/470 [00:00<01:40,  4.65it/s][A
  1%|          | 5/470 [00:01<01:40,  4.60it/s][A
  1%|▏         | 6/470 [00:01<01:41,  4.57it/s][A
  1%|▏         | 7/470 [00:01<01:41,  4.56it/s][A
  2%|▏         | 8/470 [00:01<01:41,  4.55it/s][A
  2%|▏         | 9/470 [00:01<01:41,  4.54it/s][A
  2%|▏         | 10/470 [00:02<01:41,  4.53it/s][A
  2%|▏         | 11/470 [00:02<01:41,  4.53it/s][A
  3%|▎         | 12/470 [00:02<01:41,  4.53it/s][A
  3%|▎         | 13/470 [00:02<01:40,  4.53it/s][A
  3%|▎         | 14/470 [00:03<01:40,  4.53it/s][A
  3%|▎         | 15/470 [00:03<01:40,  4.53it/s][A
  3%|▎         | 16/470 [00:03<01:40,  4.52it/s][A
  4%|▎         | 17/470 [00:03<01:40,  4.52it/s][A
  4%|▍         | 18/470 [00:

{'loss/eval': 9.842355728149414, 'perplexity': 18813.984375}


  1%|          | 32/4590 [03:48<1:01:42,  1.23it/s] 
  0%|          | 0/470 [00:00<?, ?it/s][A
  0%|          | 1/470 [00:00<01:22,  5.67it/s][A
  0%|          | 2/470 [00:00<01:34,  4.93it/s][A
  1%|          | 3/470 [00:00<01:38,  4.73it/s][A
  1%|          | 4/470 [00:00<01:40,  4.65it/s][A
  1%|          | 5/470 [00:01<01:41,  4.60it/s][A
  1%|▏         | 6/470 [00:01<01:41,  4.58it/s][A
  1%|▏         | 7/470 [00:01<01:41,  4.56it/s][A
  2%|▏         | 8/470 [00:01<01:41,  4.55it/s][A
  2%|▏         | 9/470 [00:01<01:41,  4.55it/s][A
  2%|▏         | 10/470 [00:02<01:41,  4.54it/s][A
  2%|▏         | 11/470 [00:02<01:41,  4.54it/s][A
  3%|▎         | 12/470 [00:02<01:40,  4.54it/s][A
  3%|▎         | 13/470 [00:02<01:40,  4.53it/s][A
  3%|▎         | 14/470 [00:03<01:40,  4.53it/s][A
  3%|▎         | 15/470 [00:03<01:40,  4.53it/s][A
  3%|▎         | 16/470 [00:03<01:40,  4.53it/s][A
  4%|▎         | 17/470 [00:03<01:40,  4.53it/s][A
  4%|▍         | 18/470 [00:0

{'loss/eval': 9.540742874145508, 'perplexity': 13915.28125}


  1%|          | 48/4590 [05:43<1:01:25,  1.23it/s] 
  0%|          | 0/470 [00:00<?, ?it/s][A
  0%|          | 1/470 [00:00<01:22,  5.68it/s][A
  0%|          | 2/470 [00:00<01:34,  4.94it/s][A
  1%|          | 3/470 [00:00<01:38,  4.74it/s][A
  1%|          | 4/470 [00:00<01:40,  4.66it/s][A
  1%|          | 5/470 [00:01<01:40,  4.61it/s][A
  1%|▏         | 6/470 [00:01<01:41,  4.59it/s][A
  1%|▏         | 7/470 [00:01<01:41,  4.57it/s][A
  2%|▏         | 8/470 [00:01<01:41,  4.56it/s][A
  2%|▏         | 9/470 [00:01<01:41,  4.55it/s][A
  2%|▏         | 10/470 [00:02<01:41,  4.55it/s][A
  2%|▏         | 11/470 [00:02<01:41,  4.54it/s][A
  3%|▎         | 12/470 [00:02<01:40,  4.54it/s][A
  3%|▎         | 13/470 [00:02<01:40,  4.54it/s][A
  3%|▎         | 14/470 [00:03<01:40,  4.54it/s][A
  3%|▎         | 15/470 [00:03<01:40,  4.54it/s][A
  3%|▎         | 16/470 [00:03<01:39,  4.54it/s][A
  4%|▎         | 17/470 [00:03<01:39,  4.54it/s][A
  4%|▍         | 18/470 [00:0

{'loss/eval': 9.071001052856445, 'perplexity': 8699.328125}


  1%|▏         | 64/4590 [07:37<1:01:12,  1.23it/s] 
  0%|          | 0/470 [00:00<?, ?it/s][A
  0%|          | 1/470 [00:00<01:22,  5.69it/s][A
  0%|          | 2/470 [00:00<01:34,  4.94it/s][A
  1%|          | 3/470 [00:00<01:38,  4.74it/s][A
  1%|          | 4/470 [00:00<01:40,  4.66it/s][A
  1%|          | 5/470 [00:01<01:40,  4.61it/s][A
  1%|▏         | 6/470 [00:01<01:41,  4.59it/s][A
  1%|▏         | 7/470 [00:01<01:41,  4.57it/s][A
  2%|▏         | 8/470 [00:01<01:41,  4.56it/s][A
  2%|▏         | 9/470 [00:01<01:41,  4.55it/s][A
  2%|▏         | 10/470 [00:02<01:41,  4.54it/s][A
  2%|▏         | 11/470 [00:02<01:41,  4.54it/s][A
  3%|▎         | 12/470 [00:02<01:40,  4.54it/s][A
  3%|▎         | 13/470 [00:02<01:40,  4.53it/s][A
  3%|▎         | 14/470 [00:03<01:40,  4.53it/s][A
  3%|▎         | 15/470 [00:03<01:40,  4.53it/s][A
  3%|▎         | 16/470 [00:03<01:40,  4.53it/s][A
  4%|▎         | 17/470 [00:03<01:39,  4.53it/s][A
  4%|▍         | 18/470 [00:0

{'loss/eval': 8.443418502807617, 'perplexity': 4644.40478515625}


  2%|▏         | 80/4590 [09:31<1:01:00,  1.23it/s] 
  0%|          | 0/470 [00:00<?, ?it/s][A
  0%|          | 1/470 [00:00<01:22,  5.70it/s][A
  0%|          | 2/470 [00:00<01:34,  4.94it/s][A
  1%|          | 3/470 [00:00<01:38,  4.75it/s][A
  1%|          | 4/470 [00:00<01:39,  4.67it/s][A
  1%|          | 5/470 [00:01<01:40,  4.62it/s][A
  1%|▏         | 6/470 [00:01<01:41,  4.59it/s][A
  1%|▏         | 7/470 [00:01<01:41,  4.57it/s][A
  2%|▏         | 8/470 [00:01<01:41,  4.56it/s][A
  2%|▏         | 9/470 [00:01<01:41,  4.55it/s][A
  2%|▏         | 10/470 [00:02<01:41,  4.55it/s][A
  2%|▏         | 11/470 [00:02<01:41,  4.54it/s][A
  3%|▎         | 12/470 [00:02<01:40,  4.54it/s][A
  3%|▎         | 13/470 [00:02<01:40,  4.54it/s][A
  3%|▎         | 14/470 [00:03<01:40,  4.54it/s][A
  3%|▎         | 15/470 [00:03<01:40,  4.54it/s][A
  3%|▎         | 16/470 [00:03<01:40,  4.53it/s][A
  4%|▎         | 17/470 [00:03<01:39,  4.53it/s][A
  4%|▍         | 18/470 [00:0