<a href="https://colab.research.google.com/github/domschl/torch-poet/blob/master/torch_poet_using_indie_tools.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import numpy as np
import os
import shutil
from enum import Enum
import re
import time
import logging
import sys
import json
# import random
import torch
import torch.nn as nn
from torch import Tensor

In [None]:
# Run this ONLY for TPU tests (this [currently=01/2022] DOWNGRADES torch for compatibility!):
# See: https://colab.research.google.com/github/pytorch/xla/blob/master/contrib/colab/getting-started.ipynb#scrollTo=yUB12htcqU9W
# !pip install cloud-tpu-client==0.10 torch==1.9.0 https://storage.googleapis.com/tpu-pytorch/wheels/torch_xla-1.9-cp37-cp37m-linux_x86_64.whl

Please review [ml-indie-tools](https://github.com/domschl/ml-indie-tools), a collection of machine learning tools that supports writing more environment-independent code. It will access your Google Drive when used with Google Colab.

In [None]:
!pip install -U ml-indie-tools

In [None]:
from ml_indie_tools.env_tools import MLEnv
from ml_indie_tools.Gutenberg_Dataset import Gutenberg_Dataset
from ml_indie_tools.Text_Dataset import Text_Dataset

# 0. System configuration

This notebook can run on local hardware, on a Jupyter server, or on Google Colab.

This version of the notebook uses [ml-indie-tools](https://github.com/domschl/ml-indie-tools) to detect hardware, handle persistence, and access training data.

In [None]:
ml_env=MLEnv(platform='pt')  # use PyTorch
ml_env.describe()

In [None]:
project_name='women_writers'
model_name='lstm_v1'
root_path, project_path, model_path, data_path, log_path = ml_env.init_paths(project_name=project_name, model_name=model_name)

# 1. Text data collection

**Important note:** the `project_name` defined above determines the root directory for training data and model snapshots, so it should be changed whenever datasets or model configurations are changed.

## 1.1 Project Gutenberg data source

Search, filter, clean and download books from Project Gutenberg

In [None]:
logging.basicConfig(level=logging.INFO)

In [None]:
cache_dir = os.path.join(data_path, 'gutenberg_cache')
gd = Gutenberg_Dataset(cache_dir=cache_dir)

In [None]:
# Example search: restrict the Gutenberg catalogue by author and language.
search_spec= {"author": ["brontë","Jane Austen", "Virginia Woolf"], "language": ["english"]}

book_list=gd.search(search_spec)
book_cnt = len(book_list)
print(f"{book_cnt} matching books found with search {search_spec}.")
if book_cnt>=40:
    # Too many hits: do not download anything, only report the problem.
    logging.error("Please verify your book_list, a large number of books is scheduled for download. ABORTED.")
else:
    # Note: please verify that book_cnt is 'reasonable'. If you plan to use a large number of texts,
    # consider [mirroring Gutenberg](https://github.com/domschl/ml-indie-tools#working-with-a-local-mirror-of-project-gutenberg)
    book_list = gd.insert_book_texts(book_list, download_count_limit=book_cnt)

## 1.2 Text library

`TextLibrary` class: a text library for training, encoding, batch generation,
and formatted source display. It reads books from Project Gutenberg
and supports the creation of training batches. The output functions support
highlighting, which makes it possible to compare generated texts with the actual sources
and helps to identify identical (memorized) passages of a given length.

In [None]:
td = Text_Dataset(book_list)

In [None]:
class TextLibraryDataset(torch.utils.data.Dataset):
    """Sliding-window dataset over the concatenated, encoded texts of a text dataset.

    Each record is a pair ``(X, y)`` of ``sample_length`` token indices, where
    ``y`` is ``X`` shifted by one position (next-token targets). Consecutive
    records start ``text_stepping`` tokens apart.

    Args:
        text_dataset: object providing ``text_list`` (iterable of dicts; only
            entries with a ``'text'`` key are used), ``encode(str)`` and ``i2c``.
        sample_length: number of tokens per sample.
        torch_device: device the encoded text is stored on.
        text_stepping: distance in tokens between the starts of two records.

    Raises:
        ValueError: if ``text_stepping`` is smaller than 1.
    """

    def __init__(self, text_dataset, sample_length, torch_device, text_stepping=10):
        if text_stepping < 1:
            raise ValueError(f"text_stepping must be >= 1, got {text_stepping}")
        self.device = torch_device
        # Entries without a 'text' key are skipped; join avoids repeated string copies.
        full_text = "".join(text['text'] for text in text_dataset.text_list if 'text' in text)
        text_encode = text_dataset.encode(full_text)
        self.text_length = len(text_encode)
        self.vocab_size = len(text_dataset.i2c)
        self.text_stepping = text_stepping
        self.sample_length = sample_length
        # Clamp at zero so that a text shorter than one sample yields an empty dataset
        # instead of a negative length.
        self.records = max(0, (self.text_length - sample_length - 1) // text_stepping)
        self.data = torch.LongTensor(text_encode).to(self.device)

    def __len__(self):
        """Return the number of available (X, y) records."""
        return self.records

    def __getitem__(self, idx):
        """Return the ``idx``-th sample as ``(X, y)``; ``y`` is ``X`` shifted by one token.

        Negative indices count from the end. Raises ``IndexError`` when ``idx`` is
        out of range (the sequence protocol expects this rather than ``None``).
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        if idx < 0:
            idx += self.records
        if idx < 0 or idx >= self.records:
            raise IndexError(f"index out of range for dataset with {self.records} records")
        start = idx * self.text_stepping
        # self.data already lives on self.device, so the slices need no further transfer.
        X = self.data[start:start + self.sample_length]
        y = self.data[start + 1:start + self.sample_length + 1]
        return X, y

# 2. The deep LSTM model

## 2.2 The char-rnn model class

In [None]:
class Poet(nn.Module):
    """Character-level LSTM language model with a built-in text sampler.

    Token indices are turned into one-hot vectors, passed through a
    multi-layer LSTM and projected to ``output_size`` logits per time step.

    Args:
        input_size: size of the one-hot input vectors (vocabulary size).
        hidden_size: number of LSTM units per layer.
        num_layers: number of stacked LSTM layers.
        output_size: number of output logits per time step.
        device: device on which the one-hot table and hidden states are created.
    """

    def __init__(self, input_size, hidden_size, num_layers, output_size, device):
        super().__init__()

        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.output_size = output_size
        self.device = device

        # Identity matrix used as a lookup table: row i is the one-hot vector of token i.
        # Kept as a plain attribute (not a registered buffer) so the state_dict layout,
        # and therefore existing checkpoints, stay unchanged.
        self.oh = torch.eye(input_size, device=self.device)
        self.lstm = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True, dropout=0)

        self.demb = nn.Linear(hidden_size, output_size)
        self.softmax = nn.Softmax(dim=-1)

    def init_hidden(self, batch_size):
        """Reset hidden and cell state to zeros for the given batch size."""
        self.h0 = torch.zeros(self.num_layers, batch_size, self.hidden_size, device=self.device)
        self.c0 = torch.zeros(self.num_layers, batch_size, self.hidden_size, device=self.device)

    def forward(self, inputx, steps):
        """Run the LSTM on index tensor ``inputx`` and return logits.

        ``init_hidden`` must have been called before; the hidden state is carried
        over (and overwritten) between calls. The result is reshaped to
        ``(-1, steps, output_size)``.
        """
        lstm_input = self.oh[inputx]
        self.lstm.flatten_parameters()
        hn, (self.h0, self.c0) = self.lstm(lstm_input, (self.h0, self.c0))
        hnr = hn.contiguous().view(-1, self.hidden_size)
        op = self.demb(hnr)
        return op.view(-1, steps, self.output_size)

    def _next_char_probs(self, token_index, temperature):
        """Feed one token and return the temperature-scaled distribution over the next token."""
        # Build the index tensor on the model's device to avoid mixed-device indexing.
        Xt = torch.tensor([[token_index]], dtype=torch.long, device=self.device)
        logits = self.forward(Xt, 1).view(-1, self.output_size)
        if temperature > 0.0:
            logits = logits / temperature
        return self.softmax(logits)

    def generate(self, n, start=None, temperature=1.0):
        """Sample ``n`` characters from the model.

        Args:
            n: number of characters to generate.
            start: priming text fed to the model first; characters missing from the
                global ``td.c2i`` vocabulary are skipped. A single space is used when
                ``start`` is empty or contains no known character.
            temperature: logits are divided by this value when it is positive;
                values <= 0 leave the logits unscaled.

        Returns:
            The generated string (without the priming text).

        Note:
            Uses the notebook-global tokenizer ``td`` (``c2i`` / ``i2c``) and resets
            the hidden state to batch size 1.
        """
        if start is None or len(start) == 0:
            start = ' '
        # Skip characters the model has never seen instead of failing with a KeyError.
        prime = [td.c2i[c] for c in start if c in td.c2i]
        if not prime:
            prime = [td.c2i[' ']]
        inds = list(range(self.output_size))
        chars = []
        # no_grad() restores the caller's previous grad mode, even if sampling raises.
        with torch.no_grad():
            self.init_hidden(1)
            for token in prime:
                yp = self._next_char_probs(token, temperature)
            for _ in range(n):
                y_pred = yp.cpu().numpy()
                ind = np.random.choice(inds, p=y_pred.ravel())
                chars.append(td.i2c[ind])
                yp = self._next_char_probs(ind, temperature)
        return ''.join(chars)

## 2.3 Model instance

In [None]:
td.init_tokenizer('char')

In [None]:
model_params = {
    "model_name": model_name,
    "vocab_size": len(td.i2c),
    "neurons": 256,
    "layers": 2,
    "learning_rate": 1e-4,
    "steps": 64,
    "batch_size": 256
}

In [None]:
# Select the torch device from the hardware flags reported by ml_env.
# TPU is checked first, then GPU; CPU is the fallback.
if ml_env.is_tpu:
    # https://pytorch.org/xla/release/1.9/index.html
    # torch_xla is imported only on this branch.
    import torch_xla
    import torch_xla.core.xla_model as xm
    device=xm.xla_device()  # untested!
    logging.warning('Multi-core not yet implemented!')
elif ml_env.is_gpu:
    device='cuda'
else:
    device='cpu'

In [None]:
poet = Poet(model_params['vocab_size'], model_params['neurons'], model_params['layers'], model_params['vocab_size'], device).to(device)

## 2.4 Optimizer

In [None]:
criterion = nn.CrossEntropyLoss()
opti = torch.optim.Adam(poet.parameters(),lr=model_params['learning_rate']);

## 2.5 Helper Functions

These helpers save and restore the training state (model weights, optimizer state, and training history). Saving and restoring can be performed either:

* Jupyter: store/restore in a local directory,
* Colab: store/restore on google drive. The training-code (using load_checkpoint()) will display an authentication url and code input-box in order to be able to access your google drive from this notebook. This allows to continue training sessions (or inference) after the Colab session was terminated.

In [None]:
# snapshot_path=os.path.join(model_path, 'Snapshots')
# os.makedirs(snapshot_path, exist_ok=True)

In [None]:
with open(os.path.join(model_path,'model_params.json'),'w') as f:
    json.dump(model_params,f,indent=4)

In [None]:
def save_checkpoint(epoch, loss, pr, best_pr, filename='checkpoint.pth.tar'):
    """Persist model and optimizer state below ``model_path``.

    The checkpoint is first written to a temporary file and then moved into
    place, so an interrupted session cannot leave a truncated checkpoint
    behind. When ``pr >= best_pr`` the checkpoint is additionally copied to
    ``model_best.pth.tar``.

    Args:
        epoch: epoch counter stored in the checkpoint.
        loss: current loss stored in the checkpoint.
        pr: current precision stored in the checkpoint.
        best_pr: best precision seen so far; decides whether the best-model file is updated.
        filename: checkpoint file name inside ``model_path``.

    Uses the notebook globals ``poet``, ``opti``, ``model_params`` and ``model_path``.
    """
    state = {
        'epoch': epoch,
        'model_config': model_params,
        'state_dict': poet.state_dict(),
        'optimizer': opti.state_dict(),
        'precision': pr,
        'loss': loss,
    }
    save_file = os.path.join(model_path, filename)
    best_file = os.path.join(model_path, 'model_best.pth.tar')

    # Write-then-rename: os.replace swaps the file in one step on the same filesystem.
    tmp_file = save_file + '.tmp'
    torch.save(state, tmp_file)
    os.replace(tmp_file, save_file)

    if pr >= best_pr:
        tmp_best = best_file + '.tmp'
        shutil.copyfile(save_file, tmp_best)
        os.replace(tmp_best, best_file)
        print(f"Saved best precision model, prec={pr:.3f}")
    else:
        print(f"Saved last model data, prec={pr:.3f}")

def save_history(history, filename="history.json"):
    """Write the training history as JSON to ``filename`` inside ``model_path``.

    Any failure is reported on stdout and otherwise ignored, so a failed write
    does not interrupt training.
    """
    target = os.path.join(model_path, filename)
    try:
        with open(target, 'w') as out:
            json.dump(history, out)
    except Exception as err:
        print(f"Failed to write training history file {target}, {err}")

def load_history(filename="history.json"):
    """Load the training history from ``filename`` inside ``model_path``.

    Returns:
        ``(history, start_time)``. ``start_time`` is the wall-clock reference used
        by the training loop, which stores ``time.time() - start_time`` as the
        ``"timestamp"`` of each history entry (elapsed training seconds). When an
        existing history is resumed, ``start_time`` is shifted into the past by
        the last recorded elapsed time so that timestamps continue monotonically.
        For a missing, unreadable or empty history, ``start_time`` is the current time.
    """
    load_file = os.path.join(model_path, filename)
    try:
        with open(load_file, 'r') as f:
            history = json.load(f)
    except (OSError, ValueError):
        # Missing/unreadable file or invalid JSON: start over with an empty history.
        print(f"Starting new history file {load_file}")
        return [], time.time()
    if not history:
        # An existing but empty history has no last timestamp to resume from.
        return history, time.time()
    elapsed = history[-1].get("timestamp", 0.0)
    return history, time.time() - elapsed

def load_checkpoint(filename='checkpoint.pth.tar'):
    """Restore model and optimizer state from ``filename`` inside ``model_path``.

    The checkpoint is only applied when its stored model configuration matches
    the current ``model_params``; differences in ``model_name``,
    ``learning_rate`` and ``batch_size`` are tolerated and only reported.

    Returns:
        ``(epoch, loss)`` from the checkpoint, or ``(0, 0)`` when no checkpoint
        exists or its configuration is incompatible.

    Uses the notebook globals ``poet``, ``opti``, ``model_params`` and ``model_path``.
    The precision stored in the checkpoint is not part of the return value.
    """
    load_file = os.path.join(model_path, filename)
    if not os.path.exists(load_file):
        print(load_file)
        print("No saved state, starting from scratch.")
        return 0,0
    # NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
    # map_location='cpu' lets a checkpoint saved on GPU/TPU be restored on any host;
    # load_state_dict() below copies the values onto the devices of the live model
    # and optimizer.
    state = torch.load(load_file, map_location='cpu')
    mod_conf = state['model_config']
    for param in ['model_name', 'learning_rate', 'batch_size']:
        # .get(): older checkpoints may lack a key; treat that as "changed".
        if mod_conf.get(param) != model_params[param]:
            print(f"Warning: project {param} has changed from {mod_conf.get(param)} to {model_params[param]}")
            mod_conf[param] = model_params[param]
    if model_params != mod_conf:
        print(f"The saved model has an incompatible configuration than the current model: {mod_conf} vs. {model_params}")
        print("Cannot restore state, starting from scratch.")
        return 0,0
    poet.load_state_dict(state['state_dict'])
    opti.load_state_dict(state['optimizer'])
    epoch = state['epoch']
    loss = state['loss']
    print(f"Continuing from saved state epoch={epoch+1}, loss={loss:.3f}")  # Save is not necessarily on epoch boundary, so that's approx.
    return epoch,loss

# 3. Training

If saved training data already exists, this step is optional; you can continue directly with chapter 4.

## 3.1 Training helpers

In [None]:
def torch_data_loader(batch_size, sample_length, device):
    """Build a shuffling DataLoader over the global text dataset ``td``.

    The text is wrapped in a ``TextLibraryDataset`` with samples of
    ``sample_length`` tokens stored on ``device``; batches are loaded in the
    main process (``num_workers=0``).
    """
    return torch.utils.data.DataLoader(
        TextLibraryDataset(td, sample_length, device),
        batch_size=batch_size,
        shuffle=True,
        num_workers=0,
    )

# Get one sample:
# X, y = next(iter(data_loader))

def precision(y, yp):
    """Return the fraction of positions where prediction ``yp`` equals target ``y``.

    The count of matches is divided by the product of the first two dimensions
    of ``y``; the result is returned as a Python float.
    """
    target_shape = y.size()
    total = float(target_shape[0] * target_shape[1])
    hits = torch.sum(yp == y)
    return (hits / total).item()

def train(Xt, yt):
    """Run one optimization step on a batch and return ``(loss, precision)``.

    Uses the notebook globals ``poet``, ``criterion``, ``opti`` and
    ``model_params``. The hidden state is reset for every batch, so no state is
    carried over between batches.
    """
    poet.zero_grad()
    poet.init_hidden(Xt.size(0))

    logits = poet(Xt, model_params['steps'])

    # Flatten batch and time dimensions for the loss computation.
    flat_logits = logits.view(-1, model_params['vocab_size'])
    predicted = torch.max(logits, 2)[1]
    flat_targets = yt.view(-1)

    batch_precision = precision(yt, predicted)

    loss = criterion(flat_logits, flat_targets)
    loss_value = loss.item()
    loss.backward()
    opti.step()

    return loss_value, batch_precision

## 3.2 The actual training loop

In [None]:
# Main training cell: resumes from the last checkpoint (if any), then trains
# in an open-ended loop, writing snapshots and text samples at fixed
# wall-clock intervals.
use_dark_mode=False
# Running sums since the last snapshot: loss, batch count, precision.
ls=0
nr=0
prs=0
# torch.cuda.empty_cache()
epoch_start, _ = load_checkpoint()
history, start_time = load_history()
pr=0.0
# NOTE(review): best_pr starts at 0.0 on every run, also when resuming from a
# checkpoint, so the first snapshot after a resume always counts as "best".
best_pr=0.0

data_loader=torch_data_loader(model_params['batch_size'], model_params['steps'], device)

# Make a snapshot of the trained parameters every snapshot_interval_sec
snapshot_interval_sec=180
# Generate text samples every sample_interval_sec
sample_interval_sec=600

last_snapshot=time.time()
last_sample=time.time()

# Accumulated wall-clock seconds per activity, used for the profiling output.
bench_all=0
bench_data=0
bench_train=0
bench_sample=0
bench_snapshot=0
bench_output_times=3  # Give 3 benchmark outputs, then stop (it stays more or less same)
sample_train_time=0

for e in range(epoch_start,2500000):
    # t1: start of the currently measured phase; t0: start of the current iteration.
    t1=time.time()
    t0=time.time()
    for Xi,yi in data_loader:
        t2=time.time()
        # this cannot be done in data_loader, if multiprocessing is used :-/
        Xt=Xi #.to(device)
        yt=yi #.to(device)

        Xt.requires_grad_(False)
        yt.requires_grad_(False)

        # Time since t1 was last set is attributed to data loading.
        bench_data += time.time()-t1
        t1=time.time()
        l, pr = train(Xt,yt)
        bench_train += time.time()-t1

        # Running means of loss and precision since the last snapshot.
        ls=ls+l
        prs=prs+pr
        nr=nr+1
        cur_loss=ls/nr
        cur_pr=prs/nr
        if time.time()-last_snapshot > snapshot_interval_sec:
            t1=time.time()
            # Reset the running sums; cur_loss/cur_pr keep the values of this interval.
            nr=0
            ls=0
            prs=0
            if cur_pr>best_pr:
                best_pr=cur_pr
            last_snapshot=time.time()
            print(f"Epoch {e+1} Loss: {cur_loss:.3f} Precision: {cur_pr:.3f} Time/Sample: {sample_train_time:.6f} sec/sample")
            save_checkpoint(e,cur_loss,cur_pr, best_pr)
            # if use_cuda:
            #     print(f"Cuda memory allocated: {torch.cuda.memory_allocated()} max_alloc: {torch.cuda.max_memory_allocated()} cached: {torch.cuda.memory_cached()} max_cached: {torch.cuda.max_memory_cached()}")
            hist={"epoch": e+1, "loss": cur_loss, "precision": cur_pr, "timestamp": time.time()-start_time}
            history.append(hist)
            save_history(history)
            bench_snapshot+=time.time()-t1

            # Print the share of each activity in the total measured time (at most 3 times).
            if bench_all > 0 and bench_output_times>0:
                bd=bench_data/bench_all*100.0
                bt=bench_train/bench_all*100.0
                bs=bench_sample/bench_all*100.0
                bss=bench_snapshot/bench_all*100.0
                # bo (remaining share) is computed but not printed.
                bo=(bench_all-bench_data-bench_train-bench_sample-bench_snapshot)/bench_all*100.0
                print(f"Profiling: data-loading: {bd:.2f}%, training: {bt:.2f}%, sample gen: {bs:.2f}%, snapshots: {bss:.2f}%", end="")
                bench_output_times = bench_output_times - 1
                if bench_output_times == 0:
                    print(f" | Profiling finished.")
                else:
                    print()


        # Seconds per sample for this batch; includes snapshot time when one was taken.
        sample_train_time=(time.time()-t2)/len(Xt)

        # Text samples are only generated once the running loss is below 1.5.
        if time.time()-last_sample > sample_interval_sec and cur_loss<1.5:
            t1=time.time()
            last_sample=time.time()
            for temperature in [0.6, 0.7, 0.8]:
                print(f"Temperature {temperature}:")
                tgen=poet.generate(700,". ", temperature=temperature)
                td.source_highlight(tgen,min_quote_size=10,dark_mode=use_dark_mode,display_ref_anchor=False)
            bench_sample+=time.time()-t1
        # Restart the phase timer so the next data-loading measurement begins here.
        t1=time.time()
        bench_all+=time.time()-t0
        t0=time.time()


# 4. Text generation

## 4.1 Sample generation

In [None]:
load_checkpoint(filename="model_best.pth.tar")

In [None]:
def detectPlagiarism(generatedtext, textlibrary, min_quote_size=10, display_ref_anchor=True):
    """Display ``generatedtext`` with passages found in ``textlibrary`` highlighted.

    Delegates to ``textlibrary.source_highlight``; the colour scheme follows the
    notebook-global ``use_dark_mode`` flag.
    """
    highlight_options = {
        "min_quote_size": min_quote_size,
        "dark_mode": use_dark_mode,
        "display_ref_anchor": display_ref_anchor,
    }
    textlibrary.source_highlight(generatedtext, **highlight_options)

In [None]:
# Generate one 1000-character sample per temperature, each primed with a blank
# line, and display it with passages found in the source texts highlighted.
print("Sample text:")
print("")
for temperature in (0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0):
    tgen = poet.generate(1000, "\n\n", temperature=temperature)
    print(f"================Temperature: {temperature}==============")
    detectPlagiarism(tgen, td, display_ref_anchor=False)

## 4.2 Dialog with the model

In [None]:
# Do a dialog with the recurrent neural net trained above:
def doDialog():
    """Interactive console dialog with the trained model.

    Reads prompts from stdin until ``bye`` is entered. Each prompt primes
    ``poet.generate``; the first ``numSentences`` generated sentences (split
    on ". ") are shown with source highlighting. Generation is retried up to
    ``maxAttempts`` times when too few sentences are produced.

    Commands:
        bye                  end the dialog
        reset                acknowledge a context reset (every answer already
                             starts from a fresh hidden state, see Poet.generate)
        temperature=<float>  change the sampling temperature

    Uses the notebook globals ``poet``, ``td`` and ``use_dark_mode``.
    """
    # Logits are divided by the temperature before the softmax: low values (~0.1)
    # give rigid, repetitive text, values above 1.0 give free-style, chaotic text.
    temperature = 0.6
    numSentences = 3  # Try to generate numSentences terminated by ". "
    maxAttempts = 3   # Number of generation attempts per prompt

    print("Please enter some dialog.")
    print("The net will answer according to your input.")
    print("'bye' for end,")
    print("'reset' to reset the conversation context,")
    print("'temperature=<float>' [0.1(strict, frozen) - >1.0(free, chaotic)]")
    print("    to change character of the dialog.")
    print()

    while True:
        print("> ", end="")
        prompt = input()
        if prompt == 'bye':
            print("Good bye!")
            break
        if prompt == 'reset':
            # No state is kept between answers, so there is nothing to clear.
            print("Conversation context reset.")
            continue
        if prompt.find("temperature")>=0 and prompt.find("=") > prompt.find("temperature"):
            value = prompt[prompt.find('=')+1:]
            try:
                temperature = float(value)
            except ValueError:
                # Keep the previous temperature instead of aborting the dialog.
                print(f"Invalid temperature value: {value.strip()!r}")
                continue
            print(f"Temperature set to {temperature}")
            continue

        prompt += ' '
        # Reset per prompt so a failed generation can never reuse an earlier answer.
        ans = None
        for _ in range(maxAttempts):
            tgen = poet.generate(2000, prompt, temperature=temperature)
            tgi = tgen.split(". ")
            print(f"{len(tgi)} sentences")
            if len(tgi) < numSentences:
                continue
            ans = "".join(sentence + ". " for sentence in tgi[:numSentences])
            break
        if ans is None:
            print(f"No answer with {numSentences} sentences could be generated, please try again.")
            continue
        td.source_highlight(ans, min_quote_size=10, dark_mode=use_dark_mode, display_ref_anchor=False)

In [None]:
doDialog()