In [6]:
from transformers import GPT2LMHeadModel
import torch

In [7]:
# Report whether CUDA is available; the device name is only queried on the
# branch where a CUDA device exists.
if not torch.cuda.is_available():
    print("PyTorch is not using GPU.")
else:
    print(f"PyTorch is using GPU: {torch.cuda.get_device_name(0)}")

PyTorch is using GPU: NVIDIA GeForce RTX 3070


In [8]:
# Pull the pretrained checkpoint named "gpt2" (the 124M model).
model_hf = GPT2LMHeadModel.from_pretrained("gpt2")

# state_dict() exposes the raw tensors keyed by name.
sd_hf = model_hf.state_dict()

# List every entry together with its shape.
for name, tensor in sd_hf.items():
    print(name, tensor.shape)

transformer.wte.weight torch.Size([50257, 768])
transformer.wpe.weight torch.Size([1024, 768])
transformer.h.0.ln_1.weight torch.Size([768])
transformer.h.0.ln_1.bias torch.Size([768])
transformer.h.0.attn.c_attn.weight torch.Size([768, 2304])
transformer.h.0.attn.c_attn.bias torch.Size([2304])
transformer.h.0.attn.c_proj.weight torch.Size([768, 768])
transformer.h.0.attn.c_proj.bias torch.Size([768])
transformer.h.0.ln_2.weight torch.Size([768])
transformer.h.0.ln_2.bias torch.Size([768])
transformer.h.0.mlp.c_fc.weight torch.Size([768, 3072])
transformer.h.0.mlp.c_fc.bias torch.Size([3072])
transformer.h.0.mlp.c_proj.weight torch.Size([3072, 768])
transformer.h.0.mlp.c_proj.bias torch.Size([768])
transformer.h.1.ln_1.weight torch.Size([768])
transformer.h.1.ln_1.bias torch.Size([768])
transformer.h.1.attn.c_attn.weight torch.Size([768, 2304])
transformer.h.1.attn.c_attn.bias torch.Size([2304])
transformer.h.1.attn.c_proj.weight torch.Size([768, 768])
transformer.h.1.attn.c_proj.bias 

Layer 1
transformer.wte.weight torch.Size([50257, 768])
--> token embedding table: a 50257-token vocabulary, each token mapped to a 768-dimensional embedding vector

Layer 2
transformer.wpe.weight torch.Size([1024, 768])
--> position embedding table: 1024 positions (the context length), each mapped to a 768-dimensional embedding vector

In [10]:
from transformers import pipeline, set_seed

# Build a text-generation pipeline backed by the "gpt2" checkpoint.
generator = pipeline('text-generation', model='gpt2')

# Generation samples tokens, so fix the seed before generating; otherwise the
# text shown below changes on every run and cannot be reproduced.
set_seed(42)

# Last expression of the cell, so the generated text is displayed as output.
generator("Hello ...")

Device set to use cuda:0
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


[{'generated_text': "Hello ... It is not necessary to do this... but we have to ensure that the system is properly functioning. If the system is broken, the user can only log in by the following methods. This procedure is performed so that the user can log in in the future.\n\nMethod 1...\n\nMethod 2...\n\n...\n\nAnd finally...\n\nmethod 3...\n\n...\n\n...\n\nThe only way to change the user's password is to use the following methods.\n\nMethod 4...\n\n...\n\n...\n\n...\n\nMethod 5...\n\n...\n\n...\n\nMethod 6...\n\n...\n\n...\n\n...\n\nMethod 7...\n\n...\n\n...\n\n...\n\n...\n\n...\n\n...\n\n...\n\n...\n\n...\n\n...\n\n...\n\n...\n\n...\n\n...\n\n...\n\nNote that you can change the password by using the following methods.\n\nMethod 1...\n\n...\n\n...\n\n...\n\n...\n\n...\n\n...\n\n...\n\n...\n\n...\n\n...\n\n...\n\n...\n\n...\n\n...\n\n...\n\n..."}]

In [12]:
# Read the whole corpus file. The encoding is passed explicitly so decoding
# does not depend on the platform's default locale encoding.
with open("tinyshakespeare.txt", "r", encoding="utf-8") as f:
    text = f.read()

# Keep only the first 1000 characters as a small working sample.
data = text[:1000]

In [16]:
import tiktoken
# Fetch the "gpt2" encoding from tiktoken.
enc = tiktoken.get_encoding("gpt2")

# Encode the text sample into token ids.
tokens = enc.encode(data)

# Show the first 24 ids.
print(tokens[0:24])

[5962, 22307, 25, 198, 8421, 356, 5120, 597, 2252, 11, 3285, 502, 2740, 13, 198, 198, 3237, 25, 198, 5248, 461, 11, 2740, 13]


In [None]:
import torch

# Take 24 tokens plus one extra, so that the last input token also has a
# next-token label available.
buf = torch.tensor(tokens[:24 + 1])

# buf holds 25 elements, which cannot be viewed as (4, 6) directly (view
# raises a RuntimeError on the size mismatch). Drop the extra token so the
# 24 inputs lay out as 4 rows of 6.
x = buf[:-1].view(4, 6)  # input to the transformer
print(x)

tensor([[ 5962, 22307,    25,   198,  8421,   356],
        [ 5120,   597,  2252,    11,  3285,   502],
        [ 2740,    13,   198,   198,  3237,    25],
        [  198,  5248,   461,    11,  2740,    13]])


In [None]:

# 24 tokens for the inputs plus one extra so the final input has a label.
buf = torch.tensor(tokens[:25])

# Inputs are the first 24 tokens; labels are the same window shifted by one.
# Both are laid out as 4 rows of 6.
x, y = buf[:-1].view(4, 6), buf[1:].view(4, 6)
print(x, y, sep="\n")

tensor([[ 5962, 22307,    25,   198,  8421,   356],
        [ 5120,   597,  2252,    11,  3285,   502],
        [ 2740,    13,   198,   198,  3237,    25],
        [  198,  5248,   461,    11,  2740,    13]])
tensor([[22307,    25,   198,  8421,   356,  5120],
        [  597,  2252,    11,  3285,   502,  2740],
        [   13,   198,   198,  3237,    25,   198],
        [ 5248,   461,    11,  2740,    13,   198]])
