In [2]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Not connected to a GPU')
else:
  print(gpu_info)

Thu Jun 30 01:45:20 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   36C    P0    26W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [None]:
!pip install einops
!pip install torchtyping
!pip install transformers
!pip install datasets
!pip install GPUtil
!pip install tqdm
!pip install jsonlines

In [4]:
import os
os.environ['CUDA_LAUNCH_BLOCKING'] = "1"

In [5]:
from gpt2 import GPT2
import torch
from torch import nn
from torch import optim
import transformers
from datasets import load_dataset
from torch.utils.data import Dataset
from torch.utils.data.dataloader import DataLoader
from tqdm import tqdm
import time
import random
from torch.nn import functional as F
import math

In [6]:
seed = 42
random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

In [7]:
# init TokenDataset class for Dataloader
class TokenDataset(Dataset):
    def __init__(self, data, block_size):
        self.block_size = block_size
        self.data = data

    def __len__(self):
        return len(self.data) - self.block_size

    def __getitem__(self, idx):
        dix = self.data[idx : idx + self.block_size + 1]
        x = torch.tensor(dix[:-1], dtype=torch.long)
        y = torch.tensor(dix[1:], dtype=torch.long)
        return x, y

In [8]:
block_size = 512

In [9]:
# load opentext dataset
dataset = load_dataset("stas/openwebtext-10k", split="train")
print("dataset before tokenization:", dataset)

tokenizer = transformers.GPT2Tokenizer.from_pretrained("gpt2")

def encode(example):
    return tokenizer(example["text"])

dataset = dataset.map(encode, batched=True)
print("dataset after tokenization:", dataset)
print(dataset['input_ids'][:1])

Downloading builder script:   0%|          | 0.00/3.08k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/1.33k [00:00<?, ?B/s]

Downloading and preparing dataset openwebtext-10k/plain_text (download: 14.04 MiB, generated: 47.37 MiB, post-processed: Unknown size, total: 61.41 MiB) to /root/.cache/huggingface/datasets/stas___openwebtext-10k/plain_text/1.0.0/3a8df094c671b4cb63ed0b41f40fb3bd855e9ce2e3765e5df50abcdfb5ec144b...


Downloading data:   0%|          | 0.00/14.7M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/10000 [00:00<?, ? examples/s]

Dataset openwebtext-10k downloaded and prepared to /root/.cache/huggingface/datasets/stas___openwebtext-10k/plain_text/1.0.0/3a8df094c671b4cb63ed0b41f40fb3bd855e9ce2e3765e5df50abcdfb5ec144b. Subsequent calls will reuse this data.
dataset before tokenization: Dataset({
    features: ['text'],
    num_rows: 10000
})


Downloading:   0%|          | 0.00/0.99M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/665 [00:00<?, ?B/s]



  0%|          | 0/10 [00:00<?, ?ba/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (1094 > 1024). Running this sequence through the model will result in indexing errors


dataset after tokenization: Dataset({
    features: ['text', 'input_ids', 'attention_mask'],
    num_rows: 10000
})
[[32, 7093, 10327, 351, 281, 2939, 286, 35333, 11908, 290, 262, 3670, 705, 464, 791, 46155, 4897, 6, 318, 17915, 287, 11307, 13, 1400, 1099, 20075, 564, 250, 5308, 259, 45327, 69, 447, 251, 287, 4486, 11, 475, 262, 1230, 286, 37313, 10312, 11, 6622, 262, 6634, 290, 10942, 340, 11354, 420, 6819, 13, 357, 22405, 5613, 14, 2200, 14974, 8, 198, 198, 464, 1748, 326, 373, 262, 3641, 286, 35333, 11908, 447, 247, 82, 13735, 318, 41664, 351, 40687, 286, 262, 12267, 1613, 11, 422, 262, 10492, 10421, 326, 6028, 262, 29324, 286, 867, 6832, 284, 262, 289, 377, 3364, 6026, 701, 86, 21223, 10043, 326, 783, 2156, 262, 15007, 9475, 13, 198, 198, 2061, 340, 1595, 447, 247, 83, 423, 11, 4249, 468, 340, 1201, 15761, 11, 389, 9088, 286, 11908, 447, 247, 82, 41795, 290, 1964, 29960, 11, 564, 250, 5308, 259, 45327, 69, 11, 447, 251, 287, 663, 1492, 43409, 13, 383, 3452, 2230, 284, 7715, 39146, 

In [10]:
# preprocess into one big tensor of tokens
def flatten(t):
    return [item for sublist in t for item in sublist]

tokens = torch.Tensor(flatten(dataset["input_ids"])).long()
print("tokenized train_data shape:", tokens.shape)
print(tokens[:10])

train_dataset = TokenDataset(tokens, block_size)

tokenized train_data shape: torch.Size([11243054])
tensor([   32,  7093, 10327,   351,   281,  2939,   286, 35333, 11908,   290])


In [11]:
# init trainloader for training loop
batch_size = 8
train_loader = DataLoader(
    train_dataset, shuffle=True, pin_memory=True, batch_size=batch_size
)
print("train loader:", train_loader)

train loader: <torch.utils.data.dataloader.DataLoader object at 0x7f2c0767bd50>


In [12]:
# training loop
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("device:", device)

model = GPT2(
    num_layers=12,
    num_heads=12,
    vocab_size=50257,
    hidden_size=768,
    max_position_embeddings=block_size,
    dropout=0.1,
    layer_norm_epsilon=1e-5,
).to(device).train()

# path to save weights during training
ckpt_path = './drive/MyDrive/gpt2_model.pt'

# load if resuming training runs
checkpoint = torch.load(ckpt_path)
model.load_state_dict(checkpoint)

loss_fn = nn.CrossEntropyLoss()

learning_rate = 2.5e-4
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

device: cuda
number of parameters: 124046592


In [None]:
max_epochs = 1

#counter used for lr decay
num_tokens = 0
#warmup_tokens = 2000 * batch_size * block_size
final_tokens = 100000 * batch_size * block_size

#use this for resuming old runs
start_iter = 518000 

for epoch in range(max_epochs):
    pbar = tqdm(enumerate(train_loader, start=start_iter), total=len(train_loader))
    for it, (x, y) in pbar:
        x = x.to(device)
        y = y.to(device)

        optimizer.zero_grad()
        gpt_output = model(x)
        loss = loss_fn(gpt_output.logits.view(-1, gpt_output.logits.size(-1)), y.view(-1))
        loss.backward()
        optimizer.step()

        pbar.set_description(f"epoch {epoch+1} iter {it}: train loss {loss.item():.5f}")
        if it % 1000 == 0:
            print(f"iter {it}. train loss {loss.item():.5f}")
            torch.save(model.state_dict(), ckpt_path)


In [14]:
# helper functions to eval model
def top_k_logits(logits, k):
    v, ix = torch.topk(logits, k)
    out = logits.clone()
    out[out < v[:, [-1]]] = -float('Inf')
    return out

@torch.no_grad()
def sample(model, x, steps, temperature=1.0, sample=False, top_k=None):
    """
    take a conditioning sequence of indices in x (of shape (b,t)) and predict the next token in
    the sequence, feeding the predictions back into the model each time. Clearly the sampling
    has quadratic complexity unlike an RNN that is only linear, and has a finite context window
    of block_size, unlike an RNN that has an infinite context window.
    """
    block_size = model.get_block_size()
    model.eval()
    for k in range(steps):
        x_cond = x if x.size(1) <= block_size else x[:, -block_size:] # crop context if needed
        logits = model(x_cond).logits
        # pluck the logits at the final step and scale by temperature
        logits = logits[:, -1, :] / temperature
        # optionally crop probabilities to only the top k options
        if top_k is not None:
            logits = top_k_logits(logits, top_k)
        # apply softmax to convert to probabilities
        probs = F.softmax(logits, dim=-1)
        # sample from the distribution or take the most likely
        if sample:
            ix = torch.multinomial(probs, num_samples=1)
        else:
            _, ix = torch.topk(probs, k=1, dim=-1)
        # append to the sequence and continue
        x = torch.cat((x, ix), dim=1)

    return x

In [32]:
context = "Miley Cyrus was caught shoplifting from Abercrombie and Fitch on Hollywood Boulevard today."
x = torch.tensor(tokenizer.encode(context), dtype=torch.long).unsqueeze(0).to(device)
y = sample(model, x, 100, temperature=1.0, sample=True, top_k=10)[0]
completion = tokenizer.decode(y)
print(completion)

Miley Cyrus was caught shoplifting from Abercrombie and Fitch on Hollywood Boulevard today.

According to police, a woman called 911 reported that when she contacted her then asked her if she kept the tickets from the cab driver, she found her number. She found the number two times.

 paramedic: She is not wearing a helmet, is you is.

She is just throwing away her calf in a bare dress.

Shina, Justina said she was trying to collect her driver when she heard the entire squad numbers of the party members.




In [17]:
# lambada acc eval
def preprocess(text):
    text = text.replace("“", '"')
    text = text.replace("”", '"')
    text = text.replace("''", '"')
    text = text.replace("``", '"')
    return "\n" + text.strip()


stopwords = {
    "ourselves",
    "hers",
    "between",
    "yourself",
    "but",
    "again",
    "there",
    "about",
    "once",
    "during",
    "out",
    "very",
    "having",
    "with",
    "they",
    "own",
    "an",
    "be",
    "some",
    "for",
    "do",
    "its",
    "yours",
    "such",
    "into",
    "of",
    "most",
    "itself",
    "other",
    "off",
    "is",
    "s",
    "am",
    "or",
    "who",
    "as",
    "from",
    "him",
    "each",
    "the",
    "themselves",
    "until",
    "below",
    "are",
    "we",
    "these",
    "your",
    "his",
    "through",
    "don",
    "nor",
    "me",
    "were",
    "her",
    "more",
    "himself",
    "this",
    "down",
    "should",
    "our",
    "their",
    "while",
    "above",
    "both",
    "up",
    "to",
    "ours",
    "had",
    "she",
    "all",
    "no",
    "when",
    "at",
    "any",
    "before",
    "them",
    "same",
    "and",
    "been",
    "have",
    "in",
    "will",
    "on",
    "does",
    "yourselves",
    "then",
    "that",
    "because",
    "what",
    "over",
    "why",
    "so",
    "can",
    "did",
    "not",
    "now",
    "under",
    "he",
    "you",
    "herself",
    "has",
    "just",
    "where",
    "too",
    "only",
    "myself",
    "which",
    "those",
    "i",
    "after",
    "few",
    "whom",
    "t",
    "being",
    "if",
    "theirs",
    "my",
    "against",
    "a",
    "by",
    "doing",
    "it",
    "how",
    "further",
    "was",
    "here",
    "than",
}

In [21]:
import jsonlines

correct, total = 0, 0
with jsonlines.open("lambada_test.jsonl") as reader:
    for obj in reader:
        text = preprocess(obj["text"])
        tokens = torch.tensor(tokenizer.encode(text)[:-1], dtype=torch.long).unsqueeze(0).to(device)
        final_token = tokenizer.encode(text)[-1] 
        gpt_output = model(tokens)
        _, line_encoded_candidates = torch.topk(gpt_output.logits[:,-1,:], k=20, dim=-1)
        line_encoded_candidates = list(line_encoded_candidates[0])
        predicted = None
        for candidate in line_encoded_candidates:
            if not (tokenizer.decode(candidate).strip() in stopwords):
                predicted = candidate
                break

        assert predicted is not None
        total += 1
        if predicted.item() == final_token:
            correct += 1
        if total % 1000 == 0:
            print(f"examples {total}, accuracy: {correct / total}")

print(f"final accuracy: {correct / total}")

examples 1000, accuracy: 0.094
examples 2000, accuracy: 0.0885
examples 3000, accuracy: 0.08633333333333333
examples 4000, accuracy: 0.085
examples 5000, accuracy: 0.0838
final accuracy: 0.0836405977100718
