In [0]:
import numpy as np
import os
import json
import random
import torch
import torch.nn as nn
from torch import Tensor
from urllib.request import urlopen  # Py3

from IPython.core.display import display, HTML

## Text library

In [0]:
# TextLibrary class: text library for training, encoding, batch generation,
# and formatted source display
class TextLibrary:
    """Character-level text corpus assembled from local files and/or URLs.

    Attributes:
        descriptors: the list of sources passed to the constructor.
        data: all successfully loaded texts, concatenated in load order.
        files: one dict per loaded source with keys "name", "data" and
            "index" (1-based, in load order).
        c2i / i2c: character -> index and index -> character tables, numbered
            by first occurrence in `data`.
        ptr: read position used by the sequential `get_slice` family.
    """

    def __init__(self, descriptors, max=100000000):
        """Load every descriptor and build the character vocabulary.

        Args:
            descriptors: iterable of strings. Entries starting with 'http'
                are downloaded and decoded as UTF-8; everything else is
                opened as a local text file.
            max: maximum number of characters read from each *local* file
                (downloads are not truncated).

        Sources that cannot be loaded are reported with `print` and skipped.
        """
        self.descriptors=descriptors
        self.data=''
        self.files=[]
        self.c2i = {}
        self.i2c = {}
        index = 1
        for descriptor in descriptors:
            fd={}
            if descriptor[:4] == 'http':
                fd["name"] = descriptor
                try:
                    # Context manager closes the HTTP response when done.
                    with urlopen(descriptor) as response:
                        dat = response.read().decode('utf-8')
                except Exception as e:
                    print(f"Can't download {descriptor}: {e}")
                    continue
            else:
                fd["name"] = os.path.splitext(os.path.basename(descriptor))[0]
                try:
                    # `with` guarantees the handle is closed even if read() fails.
                    with open(descriptor) as f:
                        dat = f.read(max)
                except Exception as e:
                    print(f"ERROR: Cannot read: {descriptor}: {e}")
                    continue
            self.data += dat
            fd["data"] = dat
            fd["index"] = index
            index += 1
            self.files.append(fd)
        # Number characters by first occurrence: iterating a set would give a
        # non-deterministic numbering between runs.
        ind = 0
        for c in self.data:
            if c not in self.c2i:
                self.c2i[c] = ind
                self.i2c[ind] = c
                ind += 1
        self.ptr = 0

    def print_colored_IPython(self, textlist, pre='', post=''):
        """Render (text, source_index) pairs as HTML in the notebook.

        Pairs with index 0 are emitted plainly; any other index gets a
        background colour and a superscript "[index]" marker. `pre` and
        `post` are inserted verbatim around the generated markup.
        """
        from html import escape  # local import: only needed for rendering

        bgcolors = ['#d4e6f1', '#d8daef', '#ebdef0', '#eadbd8', '#e2d7d5', '#edebd0',
                    '#ecf3cf', '#d4efdf', '#d0ece7', '#d6eaf8', '#d4e6f1', '#d6dbdf',
                    '#f6ddcc', '#fae5d3', '#fdebd0', '#e5e8e8', '#eaeded', '#A9CCE3']
        out = ''
        for txt, ind in textlist:
            # Escape first so '<', '>' and '&' in corpus/generated text cannot
            # break the markup, then turn newlines into explicit line breaks.
            txt = escape(txt).replace('\n','<br>')
            if ind==0:
                out += txt
            else:
                out += "<span style=\"background-color:"+bgcolors[ind%len(bgcolors)]+";\">" + txt +\
                       "</span>"+"<sup>[" + str(ind) + "]</sup>"
        display(HTML(pre+out+post))

    def _longest_quote(self, tx, minQuoteSize):
        """Return (length, file_index, file_name) of the longest prefix of `tx`
        (at least `minQuoteSize` characters) found in any library file.

        A length of 0 means no file contains a long-enough prefix.
        """
        best_len, best_index, best_name = 0, 0, ''
        for f in self.files:
            p = minQuoteSize
            if p<=len(tx) and tx[:p] in f["data"]:
                # Grow the prefix until it stops matching or tx is exhausted.
                p = minQuoteSize + 1
                while p<=len(tx) and tx[:p] in f["data"]:
                    p += 1
                if p-1>best_len:
                    best_len = p-1
                    best_index = f["index"]
                    best_name = f["name"]
        return best_len, best_index, best_name

    def source_highlight(self, txt, minQuoteSize=10):
        """Display `txt` with verbatim quotes from the library highlighted.

        The text is scanned left to right; at each position the longest
        prefix of at least `minQuoteSize` characters that occurs in a library
        file is marked with that file's index. A right-aligned "Sources:"
        line listing the quoted files follows if at least one quote was found.
        """
        tx = txt
        out = []
        qts = []
        txsrc=[("Sources: ", 0)]
        noquote = ''
        while len(tx)>0:
            mxQ, mxI, mxN = self._longest_quote(tx, minQuoteSize)
            if mxQ>0:
                # Flush pending unquoted text before the highlighted quote.
                if len(noquote)>0:
                    out.append((noquote, 0))
                    noquote = ''
                out.append((tx[:mxQ],mxI))
                tx = tx[mxQ:]
                if mxI not in qts:  # create a new reference on first occurrence
                    if len(qts)>0:
                        txsrc.append((", ", 0))
                    qts.append(mxI)
                    txsrc.append((mxN,mxI))
            else:
                noquote += tx[0]
                tx = tx[1:]
        if len(noquote)>0:
            out.append((noquote, 0))
        self.print_colored_IPython(out)
        if len(qts)>0:  # print references, if there is at least one source
            self.print_colored_IPython(txsrc, pre="<small><p style=\"text-align:right;\">",
                                     post="</p></small>")

    def get_slice(self, length):
        """Return the next `length` characters and a rewind flag.

        The read pointer wraps to 0 when fewer than `length + 1` characters
        remain; the flag is True whenever the slice starts at position 0.
        """
        if (self.ptr + length >= len(self.data)):
            self.ptr = 0
        rewind = (self.ptr == 0)
        sl = self.data[self.ptr:self.ptr+length]
        self.ptr += length
        return sl, rewind

    def _encode(self, s):
        """Map a string to an integer index array using `c2i`."""
        return np.array([self.c2i[c] for c in s],dtype=int)

    def decode(self, ar):
        """Map an iterable of character indices back to a string."""
        return ''.join(self.i2c[ic] for ic in ar)

    def get_random_slice(self, length):
        """Return `length` consecutive characters from a random position.

        Raises:
            ValueError: if `length` exceeds the size of the corpus.
        """
        last_start = len(self.data) - length
        if last_start < 0:
            raise ValueError(
                f"requested slice of {length} characters from a corpus of {len(self.data)}")
        # +1 so that the final window (ending at the last character) is reachable.
        p = random.randrange(0, last_start + 1)
        return self.data[p:p+length]

    def get_slice_array(self, length):
        """Return the next sequential slice as an array of character indices."""
        return self._encode(self.get_slice(length)[0])

    def get_sample(self, length):
        """Return (X, y, rewind): a sequential input/target pair where `y` is
        `X` shifted one character ahead, plus `get_slice`'s rewind flag."""
        s, rewind = self.get_slice(length+1)
        return (self._encode(s[:-1]), self._encode(s[1:]), rewind)

    def get_random_sample(self, length):
        """Return (X, y) like `get_sample`, taken from a random position."""
        s = self.get_random_slice(length+1)
        return (self._encode(s[:-1]), self._encode(s[1:]))

    def get_sample_batch(self, batch_size, length):
        """Return (X, y) integer arrays of shape (batch_size, length) built
        from consecutive sequential samples."""
        smpX = np.zeros((batch_size,length),dtype=int)
        smpy = np.zeros((batch_size,length),dtype=int)
        for i in range(batch_size):
            smpX[i,:], smpy[i,:], _ = self.get_sample(length)
        return smpX, smpy

    def get_random_sample_batch(self, batch_size, length):
        """Return (X, y) integer arrays of shape (batch_size, length) built
        from independent random samples."""
        smpX = np.zeros((batch_size,length),dtype=int)
        smpy = np.zeros((batch_size,length),dtype=int)
        for i in range(batch_size):
            smpX[i,:], smpy[i,:] = self.get_random_sample(length)
        return smpX, smpy

## Model parameters and data sources


In [0]:
# Default corpus: a single Project Gutenberg text fetched over HTTP.
libdesc = {
    "name": "TinyShakespeare",
    "description": "Shakespeare's collected works from project Gutenberg",
    "lib": [
        'http://www.mirrorservice.org/sites/ftp.ibiblio.org/pub/docs/books/gutenberg/1/0/100/100-0.txt',
    ]
}

# Set to True to force the default corpus even if data/lib.json exists.
use_shakespeare = False

if use_shakespeare or not os.path.exists('data/lib.json'):
    # Only download the default corpus when it is actually going to be used.
    textlib = TextLibrary(libdesc["lib"])
    # Default model parameters for shakespeare:
    model_params_shakespeare = {
        "model_name": "lib",
        "vocab_size": len(textlib.i2c),
        "neurons": 256,
        "layers": 2,
        "learning_rate": 1.e-3,
        "steps": 80,
        "batch_size": 128
    }
    model_params = model_params_shakespeare
else:
    # Optional json description of a custom library; it replaces `libdesc`.
    with open('data/lib.json') as data_file:
        libdesc = json.load(data_file)
    textlib = TextLibrary(libdesc["lib"])
    model_params_lib = {
        "model_name": "lib",
        "vocab_size": len(textlib.i2c),
        "neurons": 512,
        "layers": 4,
        "learning_rate": 2.e-4,
        "steps": 80,
        "batch_size": 128
    }
    model_params = model_params_lib


In [0]:
def one_hot(p, dim):
    """One-hot encode an integer index array.

    Args:
        p: integer array (or array-like) of class indices, any rank.
        dim: number of classes, i.e. the size of the appended last axis.

    Returns:
        Integer array of shape `p.shape + (dim,)` that is 1 at position
        `p[...]` along the last axis and 0 elsewhere.

    Raises:
        IndexError: if an index lies outside the valid range for `dim`.
    """
    p = np.asarray(p)
    o = np.zeros(p.shape + (dim,), dtype=int)
    # View the output as (number of indices, dim) and set one entry per row;
    # reshape of a freshly allocated array is a view, so `o` is updated.
    rows = o.reshape(p.size, dim)
    rows[np.arange(p.size), p.ravel()] = 1
    return o

In [5]:
# Hyper-parameters shared by the data pipeline and the training code.
batch_size = model_params['batch_size']
vocab_size = model_params['vocab_size']
steps = model_params['steps']

# Set to True to stay on the CPU even when a CUDA device is available.
force_cpu=False

use_cuda = bool(torch.cuda.is_available() and force_cpu is not True)
device = 'cuda' if use_cuda else 'cpu'

if use_cuda:
    print("Running on GPU")
else:
    print("Running on CPU")
    print("Note: on Google Colab, make sure to select:")
    print("      Runtime / Change Runtime Type / Hardware accelerator: GPU")

def get_data():
    """Draw one random training batch and move it to the training device.

    Returns:
        (Xt, yt): `Xt` is the one-hot encoded input as a float32 tensor,
        `yt` holds the target character indices as an int64 tensor. Both
        live on `device`; neither requires gradients.
    """
    X, y = textlib.get_random_sample_batch(batch_size, steps)
    # from_numpy already yields a tensor that does not require grad, so no
    # extra Tensor(...) wrapper or requires_grad_(False) call is needed.
    Xt = torch.from_numpy(one_hot(X, vocab_size).astype(np.float32)).to(device)
    yt = torch.from_numpy(np.asarray(y, dtype=np.int64)).to(device)
    return Xt, yt

Running on GPU


In [0]:
def show_gpu_mem(context="all"):
    """Print CUDA memory statistics, prefixed with `context`.

    Does nothing when the notebook is running on the CPU (`use_cuda` False).
    """
    if not use_cuda:
        return
    # memory_cached / max_memory_cached are deprecated aliases of the
    # *_reserved functions; prefer the new names and fall back for older
    # PyTorch releases that do not have them yet.
    reserved = getattr(torch.cuda, "memory_reserved", None)
    if reserved is None:
        reserved = torch.cuda.memory_cached
    max_reserved = getattr(torch.cuda, "max_memory_reserved", None)
    if max_reserved is None:
        max_reserved = torch.cuda.max_memory_cached
    print("[{}] Memory allocated: {} max_alloc: {} cached: {} max_cached: {}".format(context,torch.cuda.memory_allocated(), torch.cuda.max_memory_allocated(), reserved(), max_reserved()))


## The char-rnn model (deep LSTMs)

In [0]:
class Poet(nn.Module):
    """Character-level language model: stacked LSTM plus a linear read-out.

    The recurrent state is kept on the instance (`h0`, `c0`); call
    `init_hidden` before the first `forward` of every new sequence/batch.
    """

    def __init__(self, input_size, hidden_size, num_layers, output_size, device):
        """
        Args:
            input_size: width of each (one-hot) input vector.
            hidden_size: number of LSTM units per layer.
            num_layers: number of stacked LSTM layers.
            output_size: number of output classes (vocabulary size).
            device: device on which the hidden state is allocated and to
                which inputs are moved.
        """
        super().__init__()

        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.output_size = output_size
        self.device=device

        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True, dropout=0)

        self.demb = nn.Linear(hidden_size, output_size)
        self.softmax = nn.Softmax(dim=-1)  # negative dims are a recent thing (as 2018-03), remove for old vers.

    def init_hidden(self, batch_size):
        """Reset the recurrent state to zeros for a batch of `batch_size`."""
        self.h0 = torch.zeros(self.num_layers, batch_size, self.hidden_size, device=self.device)
        self.c0 = torch.zeros(self.num_layers, batch_size, self.hidden_size, device=self.device)

    def forward(self, inputx, steps):
        """Run the LSTM over `inputx` and return per-step logits.

        Args:
            inputx: batch-first input tensor (the LSTM is created with
                `batch_first=True`).
            steps: sequence length used to reshape the output.

        Returns:
            Logits reshaped to (-1, steps, output_size). The final LSTM
            state is stored back into `self.h0` / `self.c0`.
        """
        self.lstm.flatten_parameters()
        hn, (self.h0, self.c0) = self.lstm(inputx.to(self.device), (self.h0, self.c0))
        hnr = hn.contiguous().view(-1,self.hidden_size)
        op = self.demb(hnr)
        opr = op.view(-1, steps ,self.output_size)
        return opr

    def _next_char_probs(self, index):
        """Feed one character index through the net and return the softmax
        distribution over the next character (shape (1, output_size))."""
        # The one-hot width is output_size, as in the rest of the notebook
        # where the model is built with input_size == output_size.
        x = torch.zeros(1, 1, self.output_size, dtype=torch.float32, device=self.device)
        x[0, 0, index] = 1.0
        logits = self.forward(x, 1).view(-1, self.output_size)
        return self.softmax(logits)

    def generate(self, n, start=None, library=None):
        """Sample `n` characters from the model.

        Args:
            n: number of characters to generate.
            start: priming text fed through the net before sampling. None or
                an empty string is replaced by a single space. Characters
                that are not in the vocabulary are skipped; if none remain,
                a single space is used.
            library: object providing `c2i` / `i2c` lookup tables. Defaults
                to the notebook-level `textlib`.

        Returns:
            The generated string (the priming text is not included).

        Raises:
            KeyError: if the priming text has to fall back to a space and
                the vocabulary contains no space character.
        """
        lib = textlib if library is None else library
        if start is None or len(start)==0:
            start=' '
        seed = [lib.c2i[c] for c in start if c in lib.c2i]
        if not seed:
            seed = [lib.c2i[' ']]
        chars = []
        # no_grad restores the caller's grad mode on exit, also on exceptions.
        with torch.no_grad():
            self.init_hidden(1)
            for index in seed:
                probs = self._next_char_probs(index)
            for _ in range(n):
                p = probs.cpu().numpy().ravel()
                index = int(np.random.choice(self.output_size, p=p))
                chars.append(lib.i2c[index])
                probs = self._next_char_probs(index)
        return ''.join(chars)

## Create a poet

In [0]:
# Input and output widths are both the vocabulary size (one-hot in, logits out).
poet = Poet(
    input_size=vocab_size,
    hidden_size=model_params['neurons'],
    num_layers=model_params['layers'],
    output_size=vocab_size,
    device=device,
).to(device)

## Training helpers

In [0]:
# Loss and optimizer used by train().
learning_rate = model_params['learning_rate']
criterion = nn.CrossEntropyLoss()
opti = torch.optim.Adam(poet.parameters(), lr=learning_rate)

bok = 0

def train(Xt, yt, bPr=False):
    """Run one optimisation step on a single batch.

    Args:
        Xt: input tensor passed to `poet`.
        yt: tensor of target class indices.
        bPr: when True, also compute the fraction of positions whose
            arg-max prediction equals the target (reported as "Precision"
            by the training loop).

    Returns:
        (loss, precision): the batch loss as a Python float and the
        fraction described above (0.0 when `bPr` is False).
    """
    poet.zero_grad()

    poet.init_hidden(Xt.size(0))
    output = poet(Xt, steps)

    olin=output.view(-1,vocab_size)
    ytlin=yt.view(-1)

    pr=0.0
    if bPr and ytlin.numel()>0:
        # Single vectorised comparison: one device->host transfer instead of
        # one .item() call per element.
        _, ytp=torch.max(olin,1)
        pr=(ytp==ytlin).sum().item()/ytlin.numel()

    loss = criterion(olin, ytlin)
    ls = loss.item()
    loss.backward()
    opti.step()

    return ls, pr

## The actual training

In [0]:
# Running loss sum and step count since the last report.
ls = 0
nrls = 0
# Report (and sample text) less often on the GPU, where steps are fast.
intv = 250 if use_cuda else 10
for e in range(2500000):
    Xt, yt = get_data()
    report = (e + 1) % intv == 0
    # Precision is only computed on reporting steps.
    l, pr = train(Xt, yt, report)
    ls += l
    nrls += 1
    if not report:
        continue
    print("Loss: {} Precision: {}".format(ls/nrls, pr))
    nrls = 0
    ls = 0
    # Show a text sample with passages copied from the corpus highlighted.
    tgen = poet.generate(500, "\n\n")
    textlib.source_highlight(tgen, 10)

Loss: 3.255704989433289 Precision: 0.2521484375


Loss: 2.4498188133239744 Precision: 0.350390625


Loss: 2.169398317337036 Precision: 0.3958984375


Loss: 2.0434416241645814 Precision: 0.435546875


Loss: 1.9594793286323546 Precision: 0.4525390625


Loss: 1.8975659966468812 Precision: 0.46064453125


Loss: 1.8447654180526734 Precision: 0.447265625


Loss: 1.8022566781044007 Precision: 0.48359375


Loss: 1.7632980027198792 Precision: 0.4869140625


Loss: 1.7289632000923156 Precision: 0.48662109375


Loss: 1.6965673122406006 Precision: 0.51201171875


Loss: 1.6696258521080016 Precision: 0.510546875


Loss: 1.6417124342918397 Precision: 0.51005859375


Loss: 1.616397126197815 Precision: 0.5298828125


Loss: 1.5954541869163512 Precision: 0.521484375


Loss: 1.5743368620872498 Precision: 0.53310546875


Loss: 1.5554784007072449 Precision: 0.54033203125


Loss: 1.5383869190216064 Precision: 0.5416015625


Loss: 1.5220750260353089 Precision: 0.5625


Loss: 1.505407793045044 Precision: 0.5607421875


Loss: 1.4919890918731689 Precision: 0.559375


Loss: 1.477024534702301 Precision: 0.557421875


Loss: 1.467521065235138 Precision: 0.55849609375


Loss: 1.453620198249817 Precision: 0.5693359375


Loss: 1.4470531501770019 Precision: 0.5681640625


Loss: 1.4349612398147582 Precision: 0.5705078125


Loss: 1.425355254650116 Precision: 0.56484375


Loss: 1.416916063785553 Precision: 0.58564453125


Loss: 1.406871392726898 Precision: 0.57861328125


Loss: 1.3994286775588989 Precision: 0.5890625


Loss: 1.3925196762084961 Precision: 0.5833984375


Loss: 1.3843580837249756 Precision: 0.5806640625


Loss: 1.3785300645828247 Precision: 0.5990234375


Loss: 1.369513699054718 Precision: 0.5833984375


Loss: 1.3637665090560913 Precision: 0.58779296875


Loss: 1.3557679538726806 Precision: 0.5849609375


Loss: 1.3499893941879273 Precision: 0.58974609375


Loss: 1.3453693375587463 Precision: 0.585546875


Loss: 1.339822769165039 Precision: 0.5923828125


Loss: 1.3339936695098877 Precision: 0.6041015625


Loss: 1.3298990368843078 Precision: 0.5994140625


Loss: 1.32836759185791 Precision: 0.59951171875


Loss: 1.3209148054122926 Precision: 0.60673828125


Loss: 1.3171037216186523 Precision: 0.60302734375


Loss: 1.3125245246887207 Precision: 0.60712890625


Loss: 1.3071202368736268 Precision: 0.59912109375


Loss: 1.3077496972084046 Precision: 0.605859375


Loss: 1.2998870558738709 Precision: 0.59638671875


Loss: 1.2987980780601502 Precision: 0.59794921875


Loss: 1.293959973335266 Precision: 0.61064453125


Loss: 1.2906631383895875 Precision: 0.60693359375


Loss: 1.2875648727416993 Precision: 0.6037109375


Loss: 1.2853244881629944 Precision: 0.61064453125


Loss: 1.279796021938324 Precision: 0.612109375


Loss: 1.2811073250770568 Precision: 0.60849609375


Loss: 1.277353618144989 Precision: 0.61162109375


Loss: 1.2742029604911804 Precision: 0.6134765625


Loss: 1.268563380241394 Precision: 0.60634765625


Loss: 1.2704424510002137 Precision: 0.60908203125


Loss: 1.2676051445007324 Precision: 0.60205078125


Loss: 1.264794282913208 Precision: 0.61513671875


Loss: 1.2612366361618041 Precision: 0.617578125


Loss: 1.2609120292663574 Precision: 0.62705078125


Loss: 1.257060541152954 Precision: 0.61005859375


Loss: 1.2551992144584656 Precision: 0.6251953125


Loss: 1.2536875505447387 Precision: 0.62431640625


Loss: 1.250656361103058 Precision: 0.61201171875


Loss: 1.2514708976745605 Precision: 0.6126953125


Loss: 1.2468826050758361 Precision: 0.61455078125


Loss: 1.2434534072875976 Precision: 0.6275390625


Loss: 1.245025978088379 Precision: 0.6115234375


Loss: 1.2434160089492798 Precision: 0.60830078125


Loss: 1.2413553609848023 Precision: 0.6119140625


Loss: 1.236811535358429 Precision: 0.62236328125


Loss: 1.237205258846283 Precision: 0.61689453125


Loss: 1.2365548753738402 Precision: 0.6328125


Loss: 1.2324037365913392 Precision: 0.63427734375


Loss: 1.2327389097213746 Precision: 0.61875


Loss: 1.2309629244804383 Precision: 0.60986328125


Loss: 1.2271036033630371 Precision: 0.626953125


Loss: 1.2306803307533265 Precision: 0.612890625


Loss: 1.224728513240814 Precision: 0.6173828125


Loss: 1.2218732109069825 Precision: 0.6185546875


Loss: 1.2214697794914247 Precision: 0.62294921875


Loss: 1.2230288333892823 Precision: 0.62900390625


Loss: 1.219960021018982 Precision: 0.63017578125


Loss: 1.2190384187698364 Precision: 0.61396484375


Loss: 1.2183401141166688 Precision: 0.61591796875


Loss: 1.216660062789917 Precision: 0.62958984375


Loss: 1.2168455276489258 Precision: 0.625390625


Loss: 1.2131704607009888 Precision: 0.634765625


Loss: 1.211466220855713 Precision: 0.615625


Loss: 1.210780891418457 Precision: 0.6240234375


Loss: 1.2114899725914001 Precision: 0.63037109375


Loss: 1.208860357284546 Precision: 0.62275390625


Loss: 1.2094458994865418 Precision: 0.62529296875


Loss: 1.2079046516418457 Precision: 0.626953125


Loss: 1.2039598898887633 Precision: 0.63271484375


Loss: 1.2075602850914002 Precision: 0.640234375


Loss: 1.2042218565940856 Precision: 0.62060546875


Loss: 1.2036804928779603 Precision: 0.63232421875


Loss: 1.2002779693603516 Precision: 0.63857421875


Loss: 1.1999722185134887 Precision: 0.6234375


Loss: 1.2006292839050292 Precision: 0.634765625


Loss: 1.2001281580924987 Precision: 0.6212890625


Loss: 1.1979000740051269 Precision: 0.6341796875


Loss: 1.1956218242645265 Precision: 0.62470703125


Loss: 1.1981484413146972 Precision: 0.63603515625


Loss: 1.1940227813720703 Precision: 0.61884765625


Loss: 1.1938928060531617 Precision: 0.6353515625


Loss: 1.191798403263092 Precision: 0.63544921875


Loss: 1.1906272459030152 Precision: 0.6408203125


Loss: 1.191455096244812 Precision: 0.6228515625


Loss: 1.1883676109313965 Precision: 0.62626953125


Loss: 1.1902928609848022 Precision: 0.6244140625


Loss: 1.1916629514694215 Precision: 0.6294921875


Loss: 1.1875650119781493 Precision: 0.636328125


Loss: 1.1859152102470398 Precision: 0.62890625


Loss: 1.187756332397461 Precision: 0.63642578125


Loss: 1.1847564449310304 Precision: 0.64130859375


Loss: 1.1835960869789124 Precision: 0.63369140625


Loss: 1.1831025733947753 Precision: 0.63388671875


Loss: 1.1823793902397155 Precision: 0.6421875


Loss: 1.182653230190277 Precision: 0.62841796875


Loss: 1.1813934631347656 Precision: 0.6431640625


Loss: 1.18020378780365 Precision: 0.63466796875


Loss: 1.1768303337097168 Precision: 0.64404296875


Loss: 1.1803958902359009 Precision: 0.62470703125


Loss: 1.1805668187141418 Precision: 0.63037109375


Loss: 1.1762813992500305 Precision: 0.6455078125


Loss: 1.1765143780708314 Precision: 0.62705078125


Loss: 1.1763078570365906 Precision: 0.634375


Loss: 1.1763301029205322 Precision: 0.637109375


Loss: 1.1765989856719972 Precision: 0.6388671875


Loss: 1.1728280606269836 Precision: 0.64716796875


Loss: 1.175449453830719 Precision: 0.64150390625


Loss: 1.172529100894928 Precision: 0.63740234375


Loss: 1.173108594894409 Precision: 0.63779296875


Loss: 1.1702784023284911 Precision: 0.63271484375


Loss: 1.1711036620140076 Precision: 0.63837890625


Loss: 1.1701803660392762 Precision: 0.63349609375


Loss: 1.1718324432373046 Precision: 0.64150390625


Loss: 1.168165397167206 Precision: 0.64091796875


Loss: 1.1656170907020569 Precision: 0.63369140625


Loss: 1.1676516904830934 Precision: 0.634375


Loss: 1.1677251882553101 Precision: 0.63828125


Loss: 1.16567578458786 Precision: 0.6439453125


Loss: 1.1672618045806884 Precision: 0.6369140625


Loss: 1.1645196571350098 Precision: 0.63310546875


Loss: 1.1645893607139588 Precision: 0.637890625


Loss: 1.1630619463920593 Precision: 0.6427734375


Loss: 1.1646311755180359 Precision: 0.64150390625


Loss: 1.161884404182434 Precision: 0.6291015625


Loss: 1.1630174775123596 Precision: 0.637890625


Loss: 1.1618769283294679 Precision: 0.64384765625


Loss: 1.1611861901283265 Precision: 0.64287109375


Loss: 1.1632365322113036 Precision: 0.63642578125


Loss: 1.1579099631309508 Precision: 0.6435546875


Loss: 1.1593650283813477 Precision: 0.64287109375


Loss: 1.1596139211654664 Precision: 0.633984375


Loss: 1.1598324265480042 Precision: 0.634765625


Loss: 1.1581641135215759 Precision: 0.6302734375


Loss: 1.155652060031891 Precision: 0.632421875


Loss: 1.1565477452278137 Precision: 0.65068359375


Loss: 1.1569456086158751 Precision: 0.64541015625


Loss: 1.1544312963485717 Precision: 0.6541015625


Loss: 1.154347008705139 Precision: 0.65146484375


Loss: 1.1548327412605286 Precision: 0.63603515625


Loss: 1.155311399459839 Precision: 0.6537109375


Loss: 1.153081754207611 Precision: 0.6337890625


Loss: 1.1530686926841736 Precision: 0.64970703125


Loss: 1.1543822207450867 Precision: 0.63017578125


Loss: 1.151426335811615 Precision: 0.645703125


Loss: 1.1509396104812621 Precision: 0.633984375


Loss: 1.1520266103744508 Precision: 0.62841796875


Loss: 1.149930950641632 Precision: 0.63759765625


Loss: 1.1505857667922974 Precision: 0.64638671875


Loss: 1.148957104206085 Precision: 0.6369140625


Loss: 1.1511017994880677 Precision: 0.6404296875


Loss: 1.146714647769928 Precision: 0.64619140625


Loss: 1.1492949962615966 Precision: 0.65419921875


Loss: 1.149353605747223 Precision: 0.64169921875


Loss: 1.148136679649353 Precision: 0.64072265625


Loss: 1.1480453968048097 Precision: 0.63681640625


Loss: 1.145222580909729 Precision: 0.65234375


Loss: 1.1473090105056762 Precision: 0.64150390625


Loss: 1.1446104683876037 Precision: 0.6412109375


Loss: 1.1455671920776367 Precision: 0.6408203125


Loss: 1.143882631778717 Precision: 0.655078125


Loss: 1.1440187606811523 Precision: 0.6455078125


Loss: 1.1449364705085754 Precision: 0.6376953125


Loss: 1.143679964542389 Precision: 0.65244140625


Loss: 1.1432036862373351 Precision: 0.6412109375


Loss: 1.1414262166023255 Precision: 0.640234375


Loss: 1.1441850423812867 Precision: 0.641796875


Loss: 1.1415747175216675 Precision: 0.647265625


Loss: 1.1391691493988036 Precision: 0.64521484375


Loss: 1.1407215247154237 Precision: 0.64130859375


Loss: 1.1395437927246095 Precision: 0.6513671875


Loss: 1.1434314045906067 Precision: 0.65087890625


Loss: 1.1399984936714171 Precision: 0.64619140625


Loss: 1.139115396976471 Precision: 0.64189453125


Loss: 1.139301691532135 Precision: 0.65087890625


Loss: 1.1394615168571471 Precision: 0.6416015625


Loss: 1.1388295049667359 Precision: 0.6482421875


Loss: 1.1377213006019593 Precision: 0.640625


Loss: 1.135403344631195 Precision: 0.65029296875


Loss: 1.1376732864379884 Precision: 0.64912109375


Loss: 1.13677170753479 Precision: 0.65009765625


Loss: 1.1355699400901795 Precision: 0.645703125


Loss: 1.1355364990234376 Precision: 0.65078125


Loss: 1.1361627850532532 Precision: 0.63564453125


Loss: 1.1327106919288634 Precision: 0.65546875


Loss: 1.1356502957344055 Precision: 0.64248046875


Loss: 1.1343560853004455 Precision: 0.64921875


Loss: 1.1330626640319825 Precision: 0.64765625


Loss: 1.1313126850128175 Precision: 0.6509765625


Loss: 1.1315498871803285 Precision: 0.64677734375


Loss: 1.1313760385513305 Precision: 0.640625


Loss: 1.1307169346809387 Precision: 0.63759765625


Loss: 1.1311377673149108 Precision: 0.6423828125


Loss: 1.131107626914978 Precision: 0.65224609375


Loss: 1.1307828822135926 Precision: 0.64072265625


Loss: 1.1302334842681885 Precision: 0.649609375


Loss: 1.1295686268806457 Precision: 0.63388671875


Loss: 1.1298975796699524 Precision: 0.643359375


Loss: 1.130827606678009 Precision: 0.650390625


Loss: 1.129229344367981 Precision: 0.65546875


Loss: 1.1299602818489074 Precision: 0.64140625


Loss: 1.1259336709976195 Precision: 0.65146484375


Loss: 1.129051558494568 Precision: 0.6515625


Loss: 1.1284474830627442 Precision: 0.65


Loss: 1.1281633534431457 Precision: 0.64013671875


Loss: 1.1264382047653199 Precision: 0.659375


Loss: 1.1266458435058593 Precision: 0.6494140625


Loss: 1.1265461778640746 Precision: 0.65048828125


Loss: 1.1270105171203613 Precision: 0.65439453125


Loss: 1.1267406611442565 Precision: 0.64990234375


Loss: 1.1261787209510803 Precision: 0.64365234375


Loss: 1.1272695956230163 Precision: 0.6435546875


Loss: 1.1243583984375 Precision: 0.64560546875


## Generate text

In [0]:
def detectPlagiarism(generatedtext, textlibrary, minQuoteLength=10):
    """Display `generatedtext` with passages found in `textlibrary` marked.

    Delegates to `textlibrary.source_highlight`, passing `minQuoteLength`
    as the minimum quote size. Returns None.
    """
    highlight = textlibrary.source_highlight
    highlight(generatedtext, minQuoteLength)
    
# Sample 1000 characters primed with a blank line and mark copied passages.
tgen = poet.generate(n=1000, start="\n\n")
detectPlagiarism(tgen, textlib)

## Dialog

In [0]:
# Do a dialog with the recurrent neural net trained above:
def doDialog():
    """Interactive prompt loop: each input line primes the net for an answer.

    Commands:
        'bye'   ends the dialog (end-of-input or an interrupt does the same).
        'reset' is acknowledged without generating text; every answer is
                produced by a separate `poet.generate` call, so no state is
                carried over by this loop that would need clearing.

    Any other input is passed to `poet.generate` as priming text and the
    1000-character answer is shown with `textlib.source_highlight`.
    """
    print("Please enter some dialog.")
    print("The net will answer according to your input.")
    print("'bye' for end,")
    print("'reset' to reset the conversation context.")
    print()

    while True:
        print("> ", end="")
        try:
            prompt = input()
        except (EOFError, KeyboardInterrupt):
            # No more input available: leave the dialog instead of crashing.
            prompt = 'bye'
        if prompt == 'bye':
            print("Good bye!")
            break
        if prompt == 'reset':
            print("Conversation context reset.")
            continue
        tgen=poet.generate(1000,prompt)
        textlib.source_highlight(tgen, 10)
    return

In [0]:
# Start the interactive dialog; this blocks on input() until 'bye' is entered.
doDialog()

In [0]:
def save_checkpoint(state, filename='checkpoint.pth.tar', is_best=False):
    """Serialise `state` to `filename` with `torch.save`.

    Args:
        state: object to save (the notebook passes a dict of model and
            optimizer state).
        filename: destination path of the checkpoint.
        is_best: when True, the written file is additionally copied to
            'model_best.pth.tar' in the current working directory.
    """
    import shutil  # local import: the notebook's import cell does not provide it

    torch.save(state, filename)
    if is_best:
        shutil.copyfile(filename, 'model_best.pth.tar')

best_prec1 = 64.4

# Persist model and optimizer state; `e` is the step counter left behind by
# the training loop.
save_checkpoint(dict(
    epoch=e,
    arch="poet8",
    state_dict=poet.state_dict(),
    best_prec1=best_prec1,
    optimizer=opti.state_dict(),
))
