In [0]:
from google.colab import files
import torch
from torch.utils.data.dataset import Dataset
import numpy as np
from google.colab import drive
import torch.nn as nn
drive.mount('/content/drive')
import random

import time
device = torch.device("cuda")

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [0]:
class Data(object):
    def __init__(self):
        self.char2idx = {}
        self.idx2char = []
        
        self.EOS = '\n'
        self.padding = '\0'
        self.load_data()
        self.N_TOKENS = len(self.idx2char)

    def load_data(self):
        lines = 0
        self.idx2char.append(self.padding)
        self.char2idx[self.padding] = len(self.idx2char)-1
        tokens = 0
        with open('/content/drive/My Drive/shortjokes_noquote.txt', 'r') as data:
          for line in data:
              line = line + self.EOS
              lines = lines + 1
              tokens = tokens + len(line)
              for char in line:
                  if char not in self.char2idx:
                      self.idx2char.append(char)
                      self.char2idx[char] = len(self.idx2char)-1

        self.inputs = []
        self.ids = torch.LongTensor(tokens)
        token = 0
        with open('/content/drive/My Drive/shortjokes_noquote.txt', 'r') as data:
          for line in data:
              line = line + self.EOS
              ind = 0
              input = torch.zeros([len(line)], dtype=torch.int64)
              for char in line:
                  input[ind] = self.char2idx[char]
                  self.ids[token] = self.char2idx[char]
                  token = token+1
                  ind = ind+1
              self.inputs.append(input)
        data.close()
        print("ids ", len(self.ids), " tokens ", len(self.idx2char))
        print(len(input), " tensors.")

class JokeData(Dataset):
    def __init__(self):
        self.data = Data()
        self.EOS_token = self.data.char2idx[self.data.EOS]

    def __len__(self):
        return len(self.data.inputs)

    def __getitem__(self, idx):
        sample = self.data.inputs[idx]
        return sample

In [0]:
class WordRNN(nn.Module):
    def __init__(self, ntoken, ninp=200, nhid = 128, nlayers = 1, dropout=0.1):
        super(WordRNN, self).__init__()
        self.embedding = nn.Embedding(ntoken, ninp)
        self.lstm = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)

        self.decoder = nn.Linear(nhid, ntoken)

        self.init_weights()

        self.nhid = nhid
        self.nlayers = nlayers

    def init_weights(self):
        initrange = 0.1
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, input, hidden):
        emb = self.embedding(input)
        output, hidden = self.lstm(emb, hidden)
        decoded = self.decoder(output.view(output.size(0)*output.size(1), output.size(2)))
        return decoded.view(output.size(0), output.size(1), decoded.size(1)), hidden

    def init_hidden(self, bsz):
        weight = next(self.parameters())
        return (weight.new_zeros(self.nlayers, bsz, self.nhid),
                weight.new_zeros(self.nlayers, bsz, self.nhid))

In [0]:
def batchify(data, bsz):
    nbatch = data.size(0) // bsz
    data = data.narrow(0, 0, nbatch * bsz)
    data = data.view(bsz, -1).t().contiguous()
    return data.to(device)

def get_batch(source, i):
    seq_len_tmp = min(seq_len, len(source) - 1 - i)
    data = source[i:i+seq_len_tmp]
    target = source[i+1:i+1+seq_len_tmp].view(-1)
    return data, target

trainset = JokeData()
batch_size = 64
train_data = batchify(trainset.data.ids, batch_size)
seq_len = 120

ids  22111965  tokens  99
104  tensors.


In [53]:
model = WordRNN(ntoken=len(trainset.data.idx2char)).to(device)

  "num_layers={}".format(dropout, num_layers))


In [0]:
model.load_state_dict(torch.load("/content/drive/My Drive/lolchar.weight"))
model.lstm.flatten_parameters()

Training loop


In [54]:
def repackage_hidden(h):
    """Wraps hidden states in new Tensors, to detach them from their history."""
    if isinstance(h, torch.Tensor):
        return h.detach()
    else:
        return tuple(repackage_hidden(v) for v in h)
criterion = nn.CrossEntropyLoss()
epochs = 30
lr = 20
learning_rate = 1e-4
while epochs > 0:
    optimizer = torch.optim.Adagrad(model.parameters(), lr=learning_rate)
    print("aloitetaan epoch")
    epochs = epochs-1
    model.train()
    total_loss = 0.
    start_time = time.time()
    ntokens = len(trainset.data.idx2char)
    hidden = model.init_hidden(batch_size)
    seq_len = 160+random.randint(-20, 20)
    for batch, i in enumerate(range(0, train_data.size(0) - 1, seq_len)):
        data, targets = get_batch(train_data, i)
        #print(data, " data", data.shape)
        #print(targets, "targets ", targets.shape)
        # Starting each batch, we detach the hidden state from how it was previously produced.
        # If we didn't, the model would try backpropagating all the way to start of the dataset.
        hidden = repackage_hidden(hidden)
        model.zero_grad()
        optimizer.zero_grad()
        output, hidden = model(data, hidden)
        loss = criterion(output.view(-1, ntokens), targets)
        loss.backward()

        # `clip_grad_norm` helps prevent the exploding gradient problem in RNNs / LSTMs.
        #torch.nn.utils.clip_grad_norm_(model.parameters(), 0.25)
        #for p in model.parameters():
        #    p.data.add_(-lr, p.grad.data)

        total_loss += loss.item()
        optimizer.step()
        if batch % 100 == 0 and batch > 0:
            cur_loss = total_loss / 100
            elapsed = time.time() - start_time
            print("cur loss ", cur_loss)
            total_loss = 0
            start_time = time.time()

aloitetaan epoch
cur loss  4.627683458328247
cur loss  4.564017457962036
cur loss  4.549938397407532
cur loss  4.535290231704712
cur loss  4.519381713867188
cur loss  4.499804530143738
cur loss  4.474087252616882
cur loss  4.434834461212159
cur loss  4.34435293674469
cur loss  3.8905296182632445
cur loss  3.6699304294586184
cur loss  3.5860158014297485
cur loss  3.5305615401268007
cur loss  3.4950987339019775
cur loss  3.4615744280815126
cur loss  3.438730595111847
cur loss  3.421661925315857
cur loss  3.4047255730628967
cur loss  3.390179512500763
cur loss  3.3789824271202087
cur loss  3.3701874589920044
cur loss  3.359983150959015
cur loss  3.354071419239044
aloitetaan epoch
cur loss  3.3481018948554992
cur loss  3.290704746246338
cur loss  3.286099274158478
cur loss  3.27209401845932
cur loss  3.2740192866325377
cur loss  3.2655307245254517
cur loss  3.261631441116333
cur loss  3.2582183933258055
cur loss  3.2574877023696898
cur loss  3.2581403398513795
cur loss  3.2522103548049928


KeyboardInterrupt: ignored

In [48]:
torch.save(model.state_dict(), "/content/drive/My Drive/lolchar4.weight")


with open("mallichar.lol", 'wb') as f:
                torch.save(model, f)

  "type " + obj.__name__ + ". It won't be checked "


Normal sampling and top-k sampling

In [55]:
words = 1000
temp = 1.0
hidden = model.init_hidden(1)
model.eval()
input = torch.randint(ntokens, (1, 1), dtype=torch.long).to(device)
with torch.no_grad():  # no tracking history
    for i in range(words):
        output, hidden = model(input, hidden)
        word_weights = output.squeeze().div(temp).exp().cpu()
        word_idx = torch.multinomial(word_weights, 1)[0]
        input.fill_(word_idx)
        word = trainset.data.idx2char[word_idx]

        print(word, end = '')

/~rrh7H oRrRe qh: jj'es gr>at iso . thnrad an we rmehog hcpe oso btr
,e wsiteiK
gI insts to.V aanroa binnu tepse  teinsr iatn ougwer meg tellytr n
Hut ca!un a orssid ta'oh cdroMa'in vee isin aodt lea tet? jSWs wah tohsl haodr de aLeo..s ss twege en Maepsug l epe tua son?ve
oirdes
u
Dy ta des bd gesn asat EoL

Icahs, rsa aaa.n toaut  Le  onf boob  tb".MR.rio'T fce d"rherd teic nho t"re ikdliiamy  ?ovcu iy doigxma aanne
 yre.W."e dbfo too etii.o pok 
hks Fd'
 ae woe tign bmootit Aisr yn Fih toIhio ?

WhpteT ere heiuhm.a Ocbsus ye 'in hem hsy argip,e' k iba conat weha a mmeepdt" OsGN !.Nts mienJt
.haes atka the iWr hhkeir rb ne utk un dtt tiamilFxp1MWfy
"cety d ey , narstA an t
eunnnr geo sor/ af teo wlad stre wed tcar foulnt at yosmsnog
 yti Xe. fatf.Th. Ihhe Pnr  omar llysR d ralliucg ma etce aierhevg sthwu leirs  aaanou!s/olg ocosn IrSss whr wcrnM,?hy fha Aa lee,o'ev" Wede wol
natv ofd tsores.
?':pEe t aw cre htnjo tiweia'g Y
ins. du ifna tsot iwmTn a-ied, I o!u ho boaod "oi ssr ko<yt 

In [56]:
words = 1000
temp = 1.0
model.eval()
hidden = model.init_hidden(1)
topk = 40
input = torch.randint(ntokens, (1, 1), dtype=torch.long).to(device)
with torch.no_grad():  # no tracking history
    for i in range(words):
        output, hidden = model(input, hidden)
        weights, ind = torch.topk(output.squeeze(), topk)
        word_weights = weights.div(temp).exp().cpu()
        word_idx = ind[torch.multinomial(word_weights, 1)[0]]
        input.fill_(word_idx)
        word = trainset.data.idx2char[word_idx]

        print(word, end = '')

Gazvof

 al" sl.
SIhy. df lcwea"

 ohfI a icbrebcum pl aoute dred aare sityd tk thte sdrle nre hne.edt me  aanest wet wl nord a,lesf tnte.
"fvete'ntnn yon tocbocersd i'ite s hoe tjt ucathre elne iltlviobpeunmd bhaaigi ns ahhg 
ocis booy dri lhede ie de as s!ydwfd ."Mudtdnr thebe cdao pent tetme? neft go bvpect nandyvI 
onox whe awrded wechas achutc fruite taeni? lre hecs anocte sH
t"l yothees peiro tholbIrd ?fTanguy dwancif tp a
ki idd acky piam aalyr!
Wse" cafin mofo s faerea dooho ea  er aa eam?: 
oi. whgie chanisg hol s agh ton tlhouure
tr sW
 wih dl doye "ho mahiteH:i sida we auah,? rhnf'amertd' Ih?r dogt
icBha ato oooregot.Meiu ao Iwo tbwe cosold rf
s lo arfis'nsd si s meadpe rhtpimise riglt pacnag orthehtibs cleg nnt'. Io nsH mri.f ren, a
cweaat su ts wheeIr 
ycs aeshiws Iaet degae y.g ,is e lfils kt lemhe w ncerrn bonnist hhu nmaf"  wshia
Tt onice" udi thbllooAA ohs aetl.l?r"hk nao ?
lcoutnImem iats.h yd e tannoilh
o'ts eis di ie ovtc,s miru oo ?eh hru s.uw hns afos alr Is go ns