In [1]:
!git clone https://github.com/karpathy/minGPT.git

Cloning into 'minGPT'...
remote: Enumerating objects: 175, done.[K
remote: Total 175 (delta 0), reused 0 (delta 0), pack-reused 175[K
Receiving objects: 100% (175/175), 1.37 MiB | 7.70 MiB/s, done.
Resolving deltas: 100% (101/101), done.


In [1]:
%cd minGPT/



/content/minGPT


In [2]:
import logging
logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
)

In [3]:
from mingpt.utils import set_seed
set_seed(42)

import numpy as np
import torch
import torch.nn as nn
from torch.nn import functional as F

In [4]:
from torch.utils.data import Dataset

class CharDataset(Dataset):

    def __init__(self, data, block_size):
        chars = sorted(list(set(data)))
        data_size, vocab_size = len(data), len(chars)
        print('data has %d characters, %d unique.' % (data_size, vocab_size))
        
        self.stoi = { ch:i for i,ch in enumerate(chars) }
        self.itos = { i:ch for i,ch in enumerate(chars) }
        self.block_size = block_size
        self.vocab_size = vocab_size
        self.data = data
    
    def __len__(self):
        return len(self.data) - self.block_size

    def __getitem__(self, idx):
        # grab a chunk of (block_size + 1) characters from the data
        chunk = self.data[idx:idx + self.block_size + 1]
        # encode every character to an integer
        dix = [self.stoi[s] for s in chunk]
        x = torch.tensor(dix[:-1], dtype=torch.long)
        y = torch.tensor(dix[1:], dtype=torch.long)
        return x, y

In [None]:
from google.colab import drive
drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [5]:
block_size = 6
with open('/content/poke.txt', 'rb') as f:
  text=f.read() 
train_dataset = CharDataset(text, block_size)

data has 7815 characters, 32 unique.


In [6]:
from mingpt.model import GPT, GPTConfig
mconf = GPTConfig(train_dataset.vocab_size, train_dataset.block_size,
                  n_layer=6, n_head=6, n_embd=174)
model = GPT(mconf)

04/03/2021 10:07:50 - INFO - mingpt.model -   number of parameters: 2.205972e+06


In [7]:

from mingpt.trainer import Trainer, TrainerConfig

# initialize a trainer instance and kick off training
tconf = TrainerConfig(max_epochs=1, batch_size=6, learning_rate=3e-4,
                      lr_decay=True, warmup_tokens=20, final_tokens=len(train_dataset),
                      num_workers=4)
trainer = Trainer(model, train_dataset, None, tconf)
trainer.train()

  cpuset_checked))
epoch 1 iter 1301: train loss 2.39866. lr 2.998780e-04: 100%|██████████| 1302/1302 [00:55<00:00, 23.42it/s]


In [37]:
from mingpt.utils import sample

context = " "
x = torch.tensor([train_dataset.stoi[10]], dtype=torch.long)[None,...].to(trainer.device)
y = sample(model, x, 4000, temperature=1.0, sample=True, top_k=10)[0]
completion = ''.join([chr(train_dataset.itos[int(i)]) for i in y])
pokemons=[i for i in completion.replace('\r','').split('\n') if len(i)>=5]


mulad
magale
pimate
garenige
popet
tearam
ciche
misitis
stostr
belematun
gelia
sshetenigos
reluroshea
panat
belale
guleraro
scolicar
crisar
ctera
rictar
cimape
galeolunite
deoshan
bratanc
conione
tionaske
byruchas
caschir
daschessom
perate
selenl
moreetic
piomeoot
pargartinick
tlosa
ctapitin
cosater
terulan
ristincon
bucorete
lianc
dameronortitoan
betralitet
letunglinaret
gumitar
dinchee
ceospul
cerga
blrutasat
mioncat
ririme
gumala
raritint
lomon
buorl
alrealine
shicto
teshol
reondust
derrlos
shiler
garcas
seranc
marilee
titreton
sorgorad
dargicea
pigusculime
denal
lomim
byameal
runarin
silic
guglos
brileior
dataniga
tioncite
grotianch
lacyas
ardict
blere
tarereoool
cereron
sinimtar
reamaral
gorinitit
shesa
mashice
lerlecares
colit
rasilcanein
diliotor
petrit
selettinea
torer
raril
patolol
cichy
domin
porliaros
mutlet
lesheonos
bascem
raregu
stuole
mergaaron
shoshint
pancat
pencos
moopy
bligon
coneros
gerargl
lelan
rorigel
peonia
rlaron
puarmtlo
camergame
meralina
mochye
bemem
lecolot

In [42]:
for i in pokemons:print(i)

mulad

magale

pimate

garenige

popet

tearam

ciche

misitis

stostr

belematun

gelia

sshetenigos

reluroshea

panat

belale

guleraro

scolicar

crisar

ctera

rictar

cimape

galeolunite

deoshan

bratanc

conione

tionaske

byruchas

caschir

daschessom

perate

selenl

moreetic

piomeoot

pargartinick

tlosa

ctapitin

cosater

terulan

ristincon

bucorete

lianc

dameronortitoan

betralitet

letunglinaret

gumitar

dinchee

ceospul

cerga

blrutasat

mioncat

ririme

gumala

raritint

lomon

buorl

alrealine

shicto

teshol

reondust

derrlos

shiler

garcas

seranc

marilee

titreton

sorgorad

dargicea

pigusculime

denal

lomim

byameal

runarin

silic

guglos

brileior

dataniga

tioncite

grotianch

lacyas

ardict

blere

tarereoool

cereron

sinimtar

reamaral

gorinitit

shesa

mashice

lerlecares

colit

rasilcanein

diliotor

petrit

selettinea

torer

raril

patolol

cichy

domin

porliaros

mutlet

lesheonos

bascem

raregu

stuole

mergaaron

shoshint

pancat

penc

In [41]:
for i in range(len(pokemons)):
  pokemons[i]=pokemons[i]+'\n'

In [45]:
with open('pokemon.txt','w') as f:
  for i in pokemons:f.write(i)