<a href="https://colab.research.google.com/github/domschl/torch-poet/blob/master/torch_poet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import os
import shutil
import sys
import json
import random
import torch
import torch.nn as nn
from torch import Tensor

try:
    from urllib.request import urlopen  # Py3
except:
    print("This notebook requires Python 3.")
try:
    import pathlib
except:
    print("At least python 3.5 is needed.")
    
try: # Colab instance?
    from google.colab import drive
except: # Not? ignore.
    pass

from IPython.core.display import display, HTML

# 0. System configuration

This notebook can either run on a local jupyter server, or on google cloud.
If a GPU is available, it will be used for training (if `force_cpu` is not set to `True`).

By default, snapshots of the trained net are stored locally for Jupyter instances and on the user's Google Drive for Google Colab instances. The snapshots make it possible to restart training or inference at any time, e.g. after the Colab session has been terminated.

Similarly, the text corpora used for training can be cached on Google Drive or locally.

In [None]:
# force_cpu=True: use CPU for training, even if a GPU is available.
#    NOTE(review): text generation (Poet.generate) runs on the device the
#    model was created with, not necessarily on the CPU.
force_cpu=False

# Define where snapshots of training data are stored:
#    True:  on Colab, snapshots are written below the user's Google Drive.
#    False: on Colab, no snapshots are written at all.
#    Local Jupyter instances always store snapshots in a local directory.
colab_google_drive_snapshots=True

# Define if training data (the texts downloaded from internet) are cached:
colab_google_drive_data_cache=True  # In colab mode cache to google drive
local_jupyter_data_cache=True       # In local jupyter mode cache to local path

In [None]:
is_colab_notebook = 'google.colab' in sys.modules
torch_version = torch.__version__

if torch.cuda.is_available() and force_cpu is not True:
    device='cuda'
    use_cuda = True
    print(f"PyTorch {torch_version}, running on GPU")
    if is_colab_notebook:
        card = !nvidia-smi
        if len(card)>=8:
            try:
                gpu_type=card[7][6:25]
                gpu_memory=card[8][33:54]
                print(f"Colab GPU: {gpu_type}, GPU Memory: {gpu_memory}")
            except Exception as e:
                pass
else:
    device='cpu'
    use_cuda = False
    print(f"{torch_version}, running on CPU")
    if colab_notebook:
        print("Note: on Google Colab, make sure to select:")
        print("      Runtime / Change Runtime Type / Hardware accelerator: GPU")

In [None]:
# Determine root_path, the base directory for snapshots and data caches.
#
# On Colab, Google Drive is mounted when either snapshots or the data cache
# are configured to live on Drive. root_path is defined on every code path,
# because later cells join it with the snapshot and cache directories.
if is_colab_notebook:
    if not colab_google_drive_snapshots:
        print("Since google drive snapshots are not active, training data will be lost as soon as the Colab session terminates!")
        print("Set `colab_google_drive_snapshots` to `True` to make training data persistent.")
    if colab_google_drive_snapshots or colab_google_drive_data_cache:
        mountpoint='/content/drive'
        root_path='/content/drive/My Drive'
        if not os.path.exists(root_path):
            drive.mount(mountpoint)
        if not os.path.exists(root_path):
            print("Something went wrong with Google Drive access. Cannot save snapshots to GD.")
            # Without a mounted drive neither snapshots nor the data cache
            # can be stored there.
            colab_google_drive_snapshots=False
            colab_google_drive_data_cache=False
    else:
        # Nothing is stored on Google Drive; keep root_path defined anyway.
        root_path='.'
else:
    root_path='.'

# 1. Text data collection

## 1.1 Text library

`TextLibrary` class: text library for training, encoding, batch generation,
and formatted source display. It reads some books from Project Gutenberg
and supports the creation of training batches. The output functions support
highlighting, which makes it possible to compare generated texts with the actual sources
and helps to identify identical (memorized) passages of a given minimum length.

In [None]:
use_dark_mode=False  # False: light highlight colors (white background); True: dark highlight colors

In [1]:
class TextLibrary:
    """Collection of texts used for character-level training.

    The library loads a list of sources (http(s) links or local files),
    optionally caches them in a directory, concatenates them into one string
    (`data`) and builds a character vocabulary (`c2i` / `i2c`). It provides
    sequential and random (input, target) samples and can display a text
    with all passages highlighted that occur verbatim in one of the sources.

    Attributes:
        descriptors: the list of (descriptor, author, title) tuples given.
        data: concatenation of all successfully loaded texts.
        files: one dict per loaded text with the keys "title", "author",
            "data" and "index" (1-based, in loading order).
        c2i / i2c: character -> index and index -> character mappings,
            numbered in order of first occurrence in `data`.
        ptr: read position used by the sequential `get_slice` family.
    """

    def __init__(self, descriptors, text_data_cache_directory=None, max=100000000):
        """Load all sources and build the vocabulary.

        Args:
            descriptors: iterable of (descriptor, author, title); descriptor
                is an http(s) URL or a local file name.
            text_data_cache_directory: directory for cached copies of the
                texts, or None to disable caching.
            max: maximum number of characters read from a local or cached
                file (not applied to downloads).

        Sources that cannot be loaded are reported on stdout and skipped.
        """
        self.descriptors = descriptors
        self.data = ''
        self.cache_dir = text_data_cache_directory
        self.files = []
        self.c2i = {}
        self.i2c = {}
        index = 1
        for descriptor, author, title in descriptors:
            cache_name = self.get_cache_name(author, title)
            # cache_name is None when caching is disabled; os.path.exists()
            # must not be called with None.
            is_cached = cache_name is not None and os.path.exists(cache_name)
            dat = None
            if descriptor[:4] == 'http' and not is_cached:
                try:
                    print(f"Downloading {cache_name or descriptor}")
                    dat = urlopen(descriptor).read().decode('utf-8')
                    if dat.startswith('\ufeff'):  # Ignore BOM
                        dat = dat[1:]
                    # dat is a str, so the pattern must be a str as well.
                    dat = dat.replace('\r', '')  # get rid of carriage returns
                except Exception as e:
                    dat = None
                    print(f"Can't download {descriptor}: {e}")
            else:
                filename = cache_name if is_cached else descriptor
                try:
                    if is_cached:
                        print(f"Reading {cache_name} from cache")
                    with open(filename, encoding='utf-8') as f:
                        dat = f.read(max)
                except Exception as e:
                    dat = None
                    print(f"ERROR: Cannot read: {filename}: {e}")
            if dat is None:
                continue
            self.data += dat
            self.files.append({"title": title, "author": author,
                               "data": dat, "index": index})
            index += 1
            if not is_cached and self.cache_dir is not None:
                try:
                    print(f"Caching {cache_name}")
                    # UTF-8 matches the encoding used when reading the cache.
                    with open(cache_name, 'w', encoding='utf-8') as f:
                        f.write(dat)
                except Exception as e:
                    print(f"ERROR: failed to save cache {cache_name}: {e}")

        # Number characters in order of first occurrence; iterating a set
        # would not give a deterministic numbering.
        ind = 0
        for c in self.data:
            if c not in self.c2i:
                self.c2i[c] = ind
                self.i2c[ind] = c
                ind += 1
        self.ptr = 0

    def get_cache_name(self, author, title):
        """Return the cache file path for a text, or None without cache dir."""
        if self.cache_dir is None:
            return None
        cname = f"{author} - {title}.txt"
        return os.path.join(self.cache_dir, cname)

    def display_colored_html(self, textlist, dark_mode=False, pre='', post=''):
        """Display (text, index) pairs as HTML, coloring by index.

        Args:
            textlist: iterable of (text, index); index 0 is shown without
                highlight, any other index selects a background color and is
                appended as a superscript reference.
            dark_mode: False selects the light palette, anything else the
                dark palette.
            pre, post: raw HTML placed before and after the generated markup.
        """
        from html import escape  # stdlib; only needed by this method

        bgcolorsWht = ['#d4e6e1', '#d8daef', '#ebdef0', '#eadbd8', '#e2d7d5', '#edebd0',
                       '#ecf3cf', '#d4efdf', '#d0ece7', '#d6eaf8', '#d4e6f1', '#d6dbdf',
                       '#f6ddcc', '#fae5d3', '#fdebd0', '#e5e8e8', '#eaeded', '#A9CCE3']
        bgcolorsDrk = ['#342621', '#483a2f', '#3b4e20', '#2a3b48', '#324745', '#3d3b30',
                       '#3c235f', '#443f4f', '#403c37', '#463a28', '#443621', '#364b5f',
                       '#264d4c', '#2a3553', '#3d2b40', '#354838', '#3a3d4d', '#594C23']
        bgcolors = bgcolorsWht if dark_mode is False else bgcolorsDrk
        parts = []
        for txt, ind in textlist:
            # Escape first so that '<' or '&' in the text cannot break the
            # markup, then turn newlines into HTML line breaks.
            txt = escape(txt, quote=False).replace('\n', '<br>')
            if ind == 0:
                parts.append(txt)
            else:
                color = bgcolors[ind % len(bgcolors)]
                parts.append(f"<span style=\"background-color:{color};\">{txt}</span>"
                             f"<sup>[{ind}]</sup>")
        display(HTML(pre + ''.join(parts) + post))

    def source_highlight(self, txt, minQuoteSize=10, dark_mode=False):
        """Display `txt` with verbatim quotes from the library highlighted.

        At every position the longest prefix of the remaining text that is
        contained in any library text is searched. Prefixes of at least
        `minQuoteSize` characters are highlighted with the index of their
        source; everything else is shown unhighlighted. A list of the quoted
        sources is displayed below the text if at least one quote was found.
        """
        tx = txt
        out = []
        qts = []  # indices of sources already listed in the reference line
        txsrc = [("Sources: ", 0)]
        noquote = ''
        while len(tx) > 0:
            mxQ = 0
            mxI = 0
            mxN = ''
            found = False
            for f in self.files:  # find longest quote in all texts
                p = minQuoteSize
                if p <= len(tx) and tx[:p] in f["data"]:
                    p = minQuoteSize + 1
                    while p <= len(tx) and tx[:p] in f["data"]:
                        p += 1
                    if p - 1 > mxQ:
                        mxQ = p - 1
                        mxI = f["index"]
                        mxN = f"{f['author']}: {f['title']}"
                        found = True
            if found:  # save longest quote for colorizing
                if noquote:
                    out.append((noquote, 0))
                    noquote = ''
                out.append((tx[:mxQ], mxI))
                tx = tx[mxQ:]
                if mxI not in qts:  # new reference on first occurrence
                    if qts:
                        txsrc.append((", ", 0))
                    qts.append(mxI)
                    txsrc.append((mxN, mxI))
            else:
                noquote += tx[0]
                tx = tx[1:]
        if noquote:
            out.append((noquote, 0))
        self.display_colored_html(out, dark_mode=dark_mode)
        if qts:  # print references, if there is at least one source
            self.display_colored_html(txsrc, dark_mode=dark_mode,
                                      pre="<small><p style=\"text-align:right;\">",
                                      post="</p></small>")

    def get_slice(self, length):
        """Return the next `length` characters and a restart flag.

        The read position wraps to 0 when the slice would reach the end of
        the data. The flag is True when the slice starts at position 0.
        """
        if self.ptr + length >= len(self.data):
            self.ptr = 0
        rst = self.ptr == 0
        sl = self.data[self.ptr:self.ptr + length]
        self.ptr += length
        return sl, rst

    def decode(self, ar):
        """Convert an iterable of character indices back into a string."""
        return ''.join([self.i2c[ic] for ic in ar])

    def get_random_slice(self, length):
        """Return `length` characters starting at a random position.

        Raises ValueError (from random.randrange) when the data is not
        longer than `length`.
        """
        p = random.randrange(0, len(self.data) - length)
        return self.data[p:p + length]

    def get_slice_array(self, length):
        """Return the next sequential slice as a numpy array of characters."""
        return np.array(list(self.get_slice(length)[0]))

    def get_encoded_slice(self, length):
        """Return the next sequential slice as a list of character indices."""
        s, _ = self.get_slice(length)
        return [self.c2i[c] for c in s]

    def get_encoded_slice_array(self, length):
        """Return the next sequential slice as a numpy array of indices."""
        return np.array(self.get_encoded_slice(length))

    def get_sample(self, length):
        """Return (X, y, rst) for the next sequential slice.

        y is X shifted by one character (the next-character targets); rst is
        the restart flag of `get_slice`.
        """
        s, rst = self.get_slice(length + 1)
        X = [self.c2i[c] for c in s[:-1]]
        y = [self.c2i[c] for c in s[1:]]
        return (X, y, rst)

    def get_random_sample(self, length):
        """Return (X, y) for a random slice; y is X shifted by one character."""
        s = self.get_random_slice(length + 1)
        X = [self.c2i[c] for c in s[:-1]]
        y = [self.c2i[c] for c in s[1:]]
        return (X, y)

    def get_sample_batch(self, batch_size, length):
        """Return `batch_size` sequential samples as (X list, y list, rst).

        rst is the restart flag of the last sample (False for an empty
        batch).
        """
        smpX = []
        smpy = []
        rst = False  # defined even when batch_size is 0
        for _ in range(batch_size):
            Xi, yi, rst = self.get_sample(length)
            smpX.append(Xi)
            smpy.append(yi)
        return smpX, smpy, rst

    def get_random_sample_batch(self, batch_size, length):
        """Return `batch_size` random samples as two numpy arrays (X, y)."""
        smpX = []
        smpy = []
        for _ in range(batch_size):
            Xi, yi = self.get_random_sample(length)
            smpX.append(Xi)
            smpy.append(yi)
        return np.array(smpX), np.array(smpy)

## 1.2 Data sources

Data sources can either be files from local filesystem, or for colab notebooks from google drive, or http(s) links.

The `name` given will be used as the directory name for both snapshots and model data caches.

Each entry in the `lib` array consists of:

1. a local filename or http(s) link,
2. an Author's name
3. a title


In [None]:
# Description of the text library used for training.
#   name:        used as directory name for snapshots and data caches
#   description: free-text description of the collection
#   lib:         list of (descriptor, author, title); descriptor is a local
#                file name or an http(s) link
# Author and title also form the cache file name ("<author> - <title>.txt").
libdesc = {
    "name": "Women-Writers",
    "description": "A collection of works of Woolf, Austen and Brontë",
    "lib": [
        # ('data/tiny-shakespeare.txt', 'William Shakespeare', 'Some parts'),   # local file example
        # ('http://www.mirrorservice.org/sites/ftp.ibiblio.org/pub/docs/books/gutenberg/1/0/100/100-0.txt', 'Shakespeare', 'Collected Works'),
        ('http://www.mirrorservice.org/sites/ftp.ibiblio.org/pub/docs/books/gutenberg/3/7/4/3/37431/37431.txt', 'Jane Austen', 'Pride and Prejudice'),
        ('http://www.mirrorservice.org/sites/ftp.ibiblio.org/pub/docs/books/gutenberg/7/6/768/768.txt', 'Emily Brontë', 'Wuthering Heights'),
        # Author name corrected from the misspelling 'Virginia Wolf'.
        ('http://www.mirrorservice.org/sites/ftp.ibiblio.org/pub/docs/books/gutenberg/1/4/144/144.txt', 'Virginia Woolf', 'Voyage out'),
        ('http://www.mirrorservice.org/sites/ftp.ibiblio.org/pub/docs/books/gutenberg/1/5/158/158.txt', 'Jane Austen', 'Emma')
    ]
}

In [None]:
# Select the directory used to cache the training texts (None disables the
# cache). Colab caches below "Colab Notebooks" on the drive root, local
# Jupyter instances below the local root path.
if is_colab_notebook:
    data_cache_path = (
        os.path.join(root_path, f"Colab Notebooks/{libdesc['name']}/Data")
        if colab_google_drive_data_cache is True
        else None
    )
else:
    data_cache_path = (
        os.path.join(root_path, f"{libdesc['name']}/Data")
        if local_jupyter_data_cache is True
        else None
    )

# Create the cache directory and store the library description next to the
# cached texts.
if data_cache_path is not None:
    pathlib.Path(data_cache_path).mkdir(parents=True, exist_ok=True)
    if os.path.exists(data_cache_path):
        with open(os.path.join(data_cache_path, 'libdesc.json'), 'w') as f:
            json.dump(libdesc, f, indent=4)
    else:
        print("ERROR, the cache directory does not exist. This will fail.")

In [None]:
# Load all texts listed in libdesc["lib"] (from cache, download or local file)
# and build the character vocabulary used by the model.
textlib = TextLibrary(libdesc["lib"], text_data_cache_directory=data_cache_path)

# 2. The deep LSTM model

## 2.1 Model configuration parameters

In [None]:
# Hyperparameters of the model and the training. The dict is stored with
# every snapshot and compared on restore, so changing a value starts a new
# training from scratch.
model_params = {
    "model_name": libdesc['name'],       # also used as snapshot directory name
    "vocab_size": len(textlib.i2c),      # number of distinct characters in the library
    "neurons": 256,                      # hidden size of each LSTM layer
    "layers": 2,                         # number of stacked LSTM layers
    "learning_rate": 1.e-3,              # learning rate of the Adam optimizer
    "steps": 80,                         # characters per training sample (sequence length)
    "batch_size": 128                    # samples per training batch
}

## 2.2 The char-rnn model class

In [None]:
class Poet(nn.Module):
    """Character-level language model: stacked LSTM plus linear output layer.

    The LSTM (batch_first) consumes one-hot encoded characters; the linear
    layer `demb` maps every hidden state to `output_size` scores. The hidden
    state is kept in `self.h0` / `self.c0` between calls of `forward`, so
    `init_hidden` must be called before the first `forward` of a sequence.
    """

    def __init__(self, input_size, hidden_size, num_layers, output_size, device):
        """Create the layers.

        Args:
            input_size: size of one input vector (one-hot character).
            hidden_size: number of units per LSTM layer.
            num_layers: number of stacked LSTM layers.
            output_size: number of output scores per step.
            device: torch device used for hidden state and inputs.
        """
        super().__init__()

        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.output_size = output_size
        self.device = device

        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size,
                            num_layers=num_layers, batch_first=True, dropout=0)

        self.demb = nn.Linear(hidden_size, output_size)
        self.softmax = nn.Softmax(dim=-1)

    def init_hidden(self, batch_size):
        """Reset hidden and cell state to zeros for the given batch size."""
        self.h0 = torch.zeros(self.num_layers, batch_size, self.hidden_size, device=self.device)
        self.c0 = torch.zeros(self.num_layers, batch_size, self.hidden_size, device=self.device)

    def forward(self, inputx, steps):
        """Run the net on `inputx` and return scores shaped (-1, steps, output_size).

        The final LSTM state replaces `self.h0` / `self.c0`, so consecutive
        calls continue the same sequence.
        """
        self.lstm.flatten_parameters()
        hn, (self.h0, self.c0) = self.lstm(inputx.to(self.device), (self.h0, self.c0))
        hnr = hn.contiguous().view(-1, self.hidden_size)
        op = self.demb(hnr)
        return op.view(-1, steps, self.output_size)

    def _next_probs(self, index):
        """Feed one character index and return the next-character probabilities."""
        xt = torch.zeros(1, 1, self.output_size, dtype=torch.float32, device=self.device)
        xt[0, 0, index] = 1.0  # one-hot encoding of the character
        scores = self.forward(xt, 1)
        return self.softmax(scores.view(-1, self.output_size))

    def generate(self, n, start=None):
        """Generate `n` characters, primed with the text `start`.

        The vocabulary is taken from the global `textlib`. Characters of
        `start` that are not part of the vocabulary are skipped; if nothing
        usable remains, a single space (or, should the vocabulary not
        contain one, the first vocabulary character) is used as primer.
        Characters are sampled with `np.random.choice` from the predicted
        distribution.

        Returns:
            The generated string (without the primer).
        """
        primer = [c for c in (start or '') if c in textlib.c2i]
        if not primer:
            primer = [' ' if ' ' in textlib.c2i else textlib.i2c[0]]
        chars = []
        # no_grad() restores the previous grad mode on exit, even when an
        # exception is raised; the global mode is never forced to "enabled".
        with torch.no_grad():
            self.init_hidden(1)
            for c in primer:
                yp = self._next_probs(textlib.c2i[c])
            for _ in range(n):
                probs = yp.cpu().numpy().ravel()
                ind = np.random.choice(self.output_size, p=probs)
                chars.append(textlib.i2c[ind])
                yp = self._next_probs(ind)
        return ''.join(chars)

## 2.3 Model instance

In [None]:
# Instantiate the model on the selected device; input and output size are
# both the vocabulary size (one-hot characters in, one score per character out).
poet = Poet(model_params['vocab_size'], model_params['neurons'], model_params['layers'], model_params['vocab_size'], device).to(device)

## 2.4 Optimizer

In [None]:
# Adam optimizer over all model parameters, trained with cross-entropy loss
# on the next-character targets.
learning_rate = model_params['learning_rate']
criterion = nn.CrossEntropyLoss()

opti = torch.optim.Adam(poet.parameters(), lr=learning_rate)

## 2.5 Helper Functions

These functions save and restore the training state. Saving and restoring can be performed in one of two ways:

* Jupyter: store/restore in a local directory,
* Colab: store/restore on Google Drive. Mounting Google Drive displays an authentication URL and a code input box so that this notebook can access your Google Drive. This makes it possible to continue training sessions (or inference) after the Colab session has been terminated.

In [None]:
# Base directory for model snapshots; None means that no snapshots are
# written (Colab without Google Drive snapshots).
if not is_colab_notebook:
    snapshot_path = os.path.join(root_path, f"{model_params['model_name']}/Snapshots")
elif colab_google_drive_snapshots is True:
    snapshot_path = os.path.join(root_path, f"Colab Notebooks/{model_params['model_name']}/Snapshots")
else:
    snapshot_path = None

In [None]:
def get_project_path():
    """Return the snapshot directory of the current model configuration.

    The directory name encodes vocabulary size, steps, layers and neurons.
    Returns None when snapshots are disabled.
    """
    if snapshot_path is not None:
        project_path_ext = f"model-{model_params['vocab_size']}x{model_params['steps']}x{model_params['layers']}x{model_params['neurons']}"
        return os.path.join(snapshot_path, project_path_ext)
    return None

def create_project_path():
    """Create the snapshot directory of the current model, if snapshots are enabled.

    Missing parent directories are created; an existing directory is not an
    error. Always returns None.
    """
    if snapshot_path is not None:
        pathlib.Path(get_project_path()).mkdir(parents=True, exist_ok=True)
    return None

In [None]:
# Make sure the snapshot directories exist and store the model configuration
# as JSON next to the snapshots.
if snapshot_path is not None:
    pathlib.Path(snapshot_path).mkdir(parents=True, exist_ok=True)
    create_project_path()
    _params_file = os.path.join(get_project_path(), 'model_params.json')
    with open(_params_file, 'w') as f:
        json.dump(model_params, f, indent=4)

In [None]:

# Highest precision reported to save_checkpoint() so far; a checkpoint with a
# higher precision is additionally copied to 'model_best.pth.tar'.
best_pr=0.0

def save_checkpoint(epoch, loss, pr, filename='checkpoint.pth.tar'):
    """Write model and optimizer state to the project snapshot directory.

    The checkpoint stores epoch, model configuration, model and optimizer
    state, precision and loss. If `pr` exceeds the best precision seen so
    far, the file is also copied to 'model_best.pth.tar'. Does nothing when
    snapshots are disabled.
    """
    global best_pr
    if snapshot_path is None:
        return
    target_dir = get_project_path()
    checkpoint_file = os.path.join(target_dir, filename)
    torch.save(
        {
            'epoch': epoch,
            'model_config': model_params,
            'state_dict': poet.state_dict(),
            'optimizer': opti.state_dict(),
            'precision': pr,
            'loss': loss,
        },
        checkpoint_file,
    )
    if not pr > best_pr:
        print(f"saved last model data, prec={pr}")
        return
    best_pr = pr
    shutil.copyfile(checkpoint_file, os.path.join(target_dir, 'model_best.pth.tar'))
    print(f"Saved best precision model, prec={pr}")

def load_checkpoint(filename='checkpoint.pth.tar'):
    """Restore model and optimizer state from a snapshot file.

    Args:
        filename: name of the snapshot file inside the project directory.

    Returns:
        (epoch, loss) of the snapshot, or (0, 0) when snapshots are
        disabled, the file does not exist, or the stored model configuration
        differs from the current one (a changed model name alone is
        tolerated).

    On success the module-level `best_pr` is set to the stored precision, so
    that later checkpoints only replace the best model when they improve
    on it.
    """
    global best_pr
    if snapshot_path is None:
        return 0,0
    project_path=get_project_path()
    load_file=os.path.join(project_path,filename)
    if not os.path.exists(load_file):
        print(load_file)
        print("No saved state, starting from scratch.")
        return 0,0
    # map_location allows restoring a snapshot written on a GPU machine on a
    # CPU-only machine (and vice versa).
    state=torch.load(load_file, map_location=device)
    mod_conf = state['model_config']
    if (mod_conf['model_name']!=model_params['model_name']):
        print(f"Warning: project has been renamed from {mod_conf['model_name']} to {model_params['model_name']}")
        mod_conf['model_name']=model_params['model_name']
    if model_params!=mod_conf:
        print(f"The saved model has a different configuration than the current model: {mod_conf} vs. {model_params}")
        print("Cannot restore state, starting from scratch.")
        return 0,0
    poet.load_state_dict(state['state_dict'])
    opti.load_state_dict(state['optimizer'])
    epoch = state['epoch']
    loss = state['loss']
    best_pr = state['precision']
    print(f"Continuing from saved state epoch={epoch}, loss={loss}")  # Save is not necessarily on epoch boundary, so that's approx.
    return epoch,loss

def one_hot(p, dim):
    """Return the one-hot encoding of an integer index array.

    Args:
        p: array-like of integer indices in [0, dim), of any shape.
        dim: size of the one-hot axis.

    Returns:
        Integer array of shape `p.shape + (dim,)` containing a 1 at the
        position given by each index and 0 elsewhere.
    """
    p = np.asarray(p)
    o = np.zeros(p.shape + (dim,), dtype=int)
    # Vectorized scatter along the last axis instead of a Python loop over
    # every element.
    np.put_along_axis(o, p[..., np.newaxis], 1, axis=-1)
    return o

# 3. Training

If saved training data already exists, this step is optional; you can continue directly with chapter 4.

## 3.1 Training helpers

In [None]:
def get_data():
    """Return one random training batch as tensors on the training device.

    Returns:
        (Xt, yt): Xt is the float32 one-hot encoding of the input
        characters, yt holds the int64 indices of the next characters.
        The batch size and sequence length are taken from `model_params`.
    """
    X, y = textlib.get_random_sample_batch(model_params['batch_size'], model_params['steps'])
    Xo = one_hot(X, model_params['vocab_size'])

    # torch.from_numpy already yields tensors of the wanted dtype that do
    # not require grad; no legacy Tensor()/LongTensor() wrapper is needed.
    Xt = torch.from_numpy(np.asarray(Xo, dtype=np.float32)).to(device)
    yt = torch.from_numpy(np.asarray(y, dtype=np.int64)).to(device)
    return Xt, yt

def train(Xt, yt, bPr=False):
    """Run one optimization step on a batch.

    Args:
        Xt: input batch (one-hot encoded characters).
        yt: target character indices.
        bPr: if True, also compute the precision, i.e. the fraction of
            positions where the highest-scoring character equals the target.

    Returns:
        (loss, precision); precision is 0.0 when `bPr` is False.
    """
    poet.zero_grad()

    poet.init_hidden(Xt.size(0))
    output = poet(Xt, model_params['steps'])

    olin = output.view(-1, model_params['vocab_size'])
    ytlin = yt.view(-1)

    pr = 0.0
    if bPr:  # Calculate precision
        total = ytlin.numel()
        if total > 0:
            _, ytp = torch.max(olin, 1)
            # A single comparison on the device replaces one .item() call
            # per element.
            pr = (ytp == ytlin).sum().item() / total

    loss = criterion(olin, ytlin)
    ls = loss.item()
    loss.backward()
    opti.step()

    return ls, pr

## 3.2 The actual training loop

In [None]:
# Training loop: trains on random batches and, every `intv` iterations,
# prints mean loss and precision, writes a checkpoint and displays a
# generated sample with highlighted source quotes.
ls=0      # loss accumulated since the last report
nrls=0    # number of iterations since the last report
intv = 250 if use_cuda else 10  # report interval (iterations)

create_project_path()
epoch_start, _ = load_checkpoint()

for e in range(epoch_start,2500000):
    Xt, yt = get_data()
    report = (e+1)%intv==0  # precision is only calculated on report iterations
    l,pr=train(Xt,yt,report)
    ls=ls+l
    nrls=nrls+1
    if report:
        print("Epoch {} Loss: {} Precision: {}".format(e+1,ls/nrls, pr))
        save_checkpoint(e,ls/nrls,pr)
        if use_cuda:
            # memory_reserved()/max_memory_reserved() replace the deprecated
            # memory_cached()/max_memory_cached().
            print("Memory allocated: {} max_alloc: {} cached: {} max_cached: {}".format(torch.cuda.memory_allocated(), torch.cuda.max_memory_allocated(), torch.cuda.memory_reserved(), torch.cuda.max_memory_reserved()))
        nrls=0
        ls=0
        tgen=poet.generate(500,"\n\n")
        textlib.source_highlight(tgen,minQuoteSize=10,dark_mode=use_dark_mode)

# 4. Text generation

## 4.1 Helpers

In [None]:
def detectPlagiarism(generatedtext, textlibrary, minQuoteLength=10):
    """Display `generatedtext` with passages copied from the library highlighted.

    Delegates to `textlibrary.source_highlight`; passages of at least
    `minQuoteLength` characters are marked. The color scheme follows the
    global `use_dark_mode`.
    """
    highlight = textlibrary.source_highlight
    highlight(generatedtext, dark_mode=use_dark_mode, minQuoteSize=minQuoteLength)

## 4.2 Dialog with the model

In [None]:
# Do a dialog with the recursive neural net trained above:
# Do a dialog with the recurrent neural net trained above:
def doDialog():
    """Interactive loop: read a prompt, display 1000 generated characters.

    Every prompt is used as primer for `poet.generate`; the generated text
    is displayed with highlighted source quotes. Entering 'bye' (or end of
    input) ends the dialog.
    """
    print("Please enter some dialog.")
    print("The net will answer according to your input.")
    # Only 'bye' is advertised: generate() re-initializes the hidden state
    # for every prompt, so there is no conversation context to reset, and no
    # temperature setting is implemented.
    print("'bye' for end.")
    print()

    while True:
        print("> ", end="")
        try:
            prompt = input()
        except EOFError:
            # Input stream closed: end the dialog instead of crashing.
            print("Good bye!")
            break
        if prompt == 'bye':
            print("Good bye!")
            break
        tgen=poet.generate(1000,prompt)
        textlib.source_highlight(tgen, minQuoteSize=10,dark_mode=use_dark_mode)
    return

In [None]:
# Restore the snapshot with the best precision; if it does not exist or does
# not match the current configuration, the model keeps its current weights.
load_checkpoint(filename="model_best.pth.tar")

In [None]:
# Generate 1000 characters primed with two newlines and display them with
# all passages highlighted that occur verbatim in the training texts.
print("Sample text:")
print("")
tgen=poet.generate(1000,"\n\n")
detectPlagiarism(tgen, textlib)

In [None]:
# Start the interactive dialog; enter 'bye' to leave it.
doDialog()