<a href="https://colab.research.google.com/github/leburik-1/machine_learning/blob/main/text_preproc_jax.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import numpy as np
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import math

import collections
import re
import random
import requests
import zipfile
import hashlib

try:
  import torch
except ModuleNotFoundError:
  %pip install -qq torch # Installs torch if not found
  import torch
from torch.utils import data # Pytorch data utilities

# Make sure the output directory for figures exists before anything is saved.
if not os.path.exists("figures"):
  os.makedirs("figures") # for saving plots


# Registry of downloadable datasets: name -> (url, sha1 hex digest).
# `download` looks entries up here and compares the digest against the
# cached file.
DATA_HUB = dict()
DATA_URL = "http://d2l-data.s3-accelerate.amazonaws.com/"
DATA_HUB["time_machine"] = (DATA_URL + "timemachine.txt","090b5e7e70c295757f55df93cb0a180b9691891a")

# Required functions for downloading data

def download(name, cache_dir=os.path.join("..", "data")):
  """Download a file registered in DATA_HUB and return the local filename.

  A file already present in `cache_dir` whose SHA-1 digest matches the
  registered digest is reused without touching the network.

  Args:
    name: Key into DATA_HUB; its value is a `(url, sha1_hex_digest)` pair.
    cache_dir: Directory that holds downloaded files (created if missing).

  Returns:
    Path of the local file, `cache_dir` joined with the last URL segment.

  Raises:
    AssertionError: If `name` is not registered in DATA_HUB.
    requests.HTTPError: If the server answers with an error status.
  """
  assert name in DATA_HUB, f"{name} does not exist in {DATA_HUB}."
  url, sha1_hash = DATA_HUB[name]
  os.makedirs(cache_dir, exist_ok=True)
  fname = os.path.join(cache_dir, url.split("/")[-1])
  if os.path.exists(fname):
    sha1 = hashlib.sha1()
    with open(fname, "rb") as f:
      # Hash in 1 MiB chunks so the file is never held in memory at once.
      for chunk in iter(lambda: f.read(1048576), b""):
        sha1.update(chunk)
    if sha1.hexdigest() == sha1_hash:
      return fname   # Hit cache
  # Reached on a cache miss as well as on a digest mismatch.
  print(f"Downloading {fname} from {url}...")
  with requests.get(url, stream=True, verify=True) as r:
    # Fail before opening the target so an HTTP error page is never stored
    # as the dataset and an existing file is not truncated.
    r.raise_for_status()
    with open(fname, "wb") as f:
      for chunk in r.iter_content(chunk_size=1048576):
        f.write(chunk)
  return fname

# Fetch (or reuse the cached copy of) the "time_machine" entry and show
# where it was stored.
fname = download("time_machine")
print("Downloaded File:",fname)


def download_extract(name, folder=None):
  """Download a zip/tar archive registered in DATA_HUB and extract it.

  The archive is unpacked into the directory that contains the downloaded
  file.

  Args:
    name: Key into DATA_HUB, forwarded to `download`.
    folder: Optional sub-directory name to return instead of the default.

  Returns:
    `<download dir>/<folder>` when `folder` is given, otherwise the archive
    path with its last extension removed.

  Raises:
    AssertionError: If the file extension is not .zip, .tar or .gz.
  """
  # Local import: `tarfile` is not among the file's top-level imports.
  import tarfile

  fname = download(name)
  base_dir = os.path.dirname(fname)
  data_dir, ext = os.path.splitext(fname)

  if ext == ".zip":
    open_archive = zipfile.ZipFile
  elif ext in (".tar", ".gz"):
    open_archive = tarfile.open
  else:
    # Raised explicitly (not `assert False`) so the check survives `python -O`.
    raise AssertionError("Only zip/tar files can be extracted.")
  # NOTE(review): extractall() trusts member paths inside the archive; only
  # use this on archives from a trusted source.
  with open_archive(fname, "r") as fp:
    fp.extractall(base_dir)
  return os.path.join(base_dir, folder) if folder else data_dir

def read_time_machine():
  """Return the time machine text as a list of cleaned lines.

  Every run of characters other than A-Z / a-z in a line is replaced by a
  single space; the line is then stripped and lower-cased.
  """
  non_letters = re.compile("[^A-Za-z]+")
  cleaned = []
  with open(download("time_machine"), "r") as f:
    for raw_line in f:
      cleaned.append(non_letters.sub(" ", raw_line).strip().lower())
  return cleaned

# Load the cleaned text; `lines` is reused by the cells below.
lines = read_time_machine()
print(f"number of lines: {len(lines)}")

# Preview the first 11 lines together with their index.
for i in range(11):
  print(i, lines[i])

# Corpus statistics: total characters and total whitespace-separated words
# over all cleaned lines.
nchars = 0
nwords = 0

for i in range(len(lines)):
  nchars += len(lines[i])
  words = lines[i].split()
  nwords += len(words)

print("total num charactes ", nchars)
print("total num words ", nwords)

def tokenize(lins, token="word"):
  """Split text lines into word or character tokens.

  Args:
    lins: Iterable of text lines (strings) to tokenize.
    token: "word" to split each line on whitespace, "char" to split each
      line into single characters.

  Returns:
    A list with one token list per input line. For an unknown `token` value
    an error message is printed and None is returned.
  """
  # Use the argument itself; the body used to read the module-level `lines`,
  # which silently ignored whatever the caller passed in.
  if token == "word":
    return [line.split() for line in lins]
  elif token == "char":
    return [list(line) for line in lins]
  else:
    print("Error: unknown token type: " + token)
    return None

# Word-level tokenization of the corpus; preview the first 11 token lists.
tokens = tokenize(lines)
for i in range(11):
  print(tokens[i])

def count_corpus(tokens):
  """Return a `collections.Counter` of token occurrences.

  `tokens` may be a flat list of tokens or a list of token lists; it is
  treated as nested when it is empty or its first element is a list.
  """
  counts = collections.Counter()
  nested = len(tokens) == 0 or isinstance(tokens[0], list)
  if nested:
    # Accumulate line by line instead of building a flattened copy first.
    for line in tokens:
      counts.update(line)
  else:
    counts.update(tokens)
  return counts


class Vocab:
  """Vocabulary for text: a two-way mapping between tokens and indices.

  Index 0 is always "<unk>", followed by the reserved tokens in the order
  given, followed by the corpus tokens sorted by descending frequency.

  Attributes:
    token_freqs: List of `(token, count)` pairs, most frequent first.
    unk: Index returned for tokens that are not in the vocabulary (0).
    idx_to_token: List mapping index -> token.
    token_to_idx: Dict mapping token -> index.
  """

  def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):
    """Build the vocabulary.

    Args:
      tokens: Flat list of tokens or list of token lists (see
        `count_corpus`); None means an empty corpus.
      min_freq: Corpus tokens with a count below this value are left out.
      reserved_tokens: Extra tokens placed right after "<unk>".
    """
    if tokens is None:
      tokens = []
    if reserved_tokens is None:
      reserved_tokens = []

    # Sort according to frequencies
    counter = count_corpus(tokens)
    self.token_freqs = sorted(counter.items(), key=lambda x: x[1], reverse=True)
    # The index for the unknown token is 0
    self.unk = 0
    # list(...) also lets callers pass the reserved tokens as a tuple.
    self.idx_to_token = ["<unk>"] + list(reserved_tokens)

    # Corpus tokens that coincide with a special token must not be added a
    # second time; a set keeps that check O(1) per token.
    special = set(self.idx_to_token)
    self.idx_to_token += [
        token for token, freq in self.token_freqs
        if freq >= min_freq and token not in special
    ]
    self.token_to_idx = {token: idx for idx, token in enumerate(self.idx_to_token)}

  def __len__(self):
    """Return the number of entries, including "<unk>" and reserved tokens."""
    return len(self.idx_to_token)

  def __getitem__(self, tokens):
    """Map a token, or a (nested) list/tuple of tokens, to indices.

    Tokens that are not in the vocabulary map to `self.unk`.
    """
    if not isinstance(tokens, (list,tuple)):
      return self.token_to_idx.get(tokens, self.unk)
    return [self.__getitem__(token) for token in tokens]

  def to_tokens(self, indices):
    """Map an index, or a list/tuple of indices, back to tokens."""
    if not isinstance(indices, (list, tuple)):
      return self.idx_to_token[indices]
    # Look up each index; returning `self.idx_to_token` itself here would
    # hand back one copy of the whole table per index.
    return [self.idx_to_token[index] for index in indices]

# Build a word-level vocabulary and show its first 10 (token, index) pairs.
vocab = Vocab(tokens)
print(list(vocab.token_to_idx.items())[:10])

# Show a few token lines next to their index encoding.
for i in [0,8,10]:
  print("words: ", tokens[i])
  print("indices:", vocab[tokens[i]])

def load_corpus_time_machine(max_tokens=-1):
  """Load the time machine text as character indices plus its vocabulary.

  Returns a `(corpus, vocab)` pair: `corpus` is one flat list of indices over
  all lines, cut to the first `max_tokens` entries when `max_tokens` is
  positive.
  """
  char_lines = tokenize(read_time_machine(), "char")
  vocab = Vocab(char_lines)
  # Lines are not sentences or paragraphs, so they are concatenated into a
  # single index sequence.
  corpus = []
  for chars in char_lines:
    corpus.extend(vocab[ch] for ch in chars)
  if max_tokens <= 0:
    return corpus, vocab
  return corpus[:max_tokens], vocab

# Character-level corpus and vocabulary (rebinds the word-level `vocab`).
corpus, vocab = load_corpus_time_machine()
# Bare expression: it is shown only when it is the last statement of a
# notebook cell and produces no output when run as a script.
len(corpus), len(vocab)

print(f'corpus : {corpus[:20]}')
print(f'list_vocab_token_to_idx_items : {list(vocab.token_to_idx.items())[:10]}')
print(f'vocab_idx_to_token : {[vocab.idx_to_token[i] for i in corpus[:20]]}')

# One-hot encode the first three indices: one row per index, `len(vocab)`
# columns.
x = jnp.array(corpus[:3])
print(x)
X = jax.nn.one_hot(x, len(vocab))
print(X.shape)
print(X)

def seq_data_iter_random(corpus, batch_size, num_steps):
  """Yield minibatches of subsequences drawn in random order.

  Each item is an `(X, Y)` pair of arrays with `batch_size` rows and
  `num_steps` columns, where `Y` holds the elements one position after the
  ones in `X`.
  """
  # Drop a random prefix of 0 .. num_steps - 1 elements so the partition
  # boundaries differ between calls.
  start = random.randint(0, num_steps - 1)
  corpus = corpus[start:]
  # One element is held back because every window needs a shifted label.
  num_subseqs = (len(corpus) - 1) // num_steps
  window_starts = [k * num_steps for k in range(num_subseqs)]
  # After shuffling, windows that sit next to each other in a batch are
  # generally not neighbours in the original sequence.
  random.shuffle(window_starts)

  usable = (num_subseqs // batch_size) * batch_size
  for lo in range(0, usable, batch_size):
    batch_starts = window_starts[lo : lo + batch_size]
    X = [corpus[p : p + num_steps] for p in batch_starts]
    Y = [corpus[p + 1 : p + 1 + num_steps] for p in batch_starts]
    yield jnp.array(X), jnp.array(Y)

# Toy sequence 0..34 to make the sampling visible; `b` counts the batches.
my_seq = list(range(35))
b = 0

for X, Y in seq_data_iter_random(my_seq, batch_size=2, num_steps=5):
  print("batch: ", b)
  print("X: ",X, "\nY:", Y)
  b += 1

def seq_data_iter_sequential(corpus, batch_size, num_steps):
    """Yield minibatches whose rows continue from one batch to the next.

    The corpus (after a random offset) is reshaped into `batch_size` rows;
    consecutive `(X, Y)` pairs are consecutive `num_steps`-wide column
    slices of those rows, with `Y` shifted one position ahead of `X`.
    """
    # Random offset in 0 .. num_steps (both ends included).
    start = random.randint(0, num_steps)
    # Largest multiple of batch_size that still leaves one element for the
    # shifted labels.
    usable = ((len(corpus) - start - 1) // batch_size) * batch_size
    stop = start + usable
    inputs = jnp.array(corpus[start:stop])
    labels = jnp.array(corpus[start + 1 : stop + 1])
    inputs = inputs.reshape(batch_size, -1)
    labels = labels.reshape(batch_size, -1)
    for k in range(inputs.shape[1] // num_steps):
        lo = k * num_steps
        hi = lo + num_steps
        yield inputs[:, lo:hi], labels[:, lo:hi]

# Same toy sequence, now with sequential partitioning.
for X, Y in seq_data_iter_sequential(my_seq, batch_size=2, num_steps=5):
  print("X: ", X, "\nY:", Y)

class SeqDataLoader:
  """Iterable over minibatches of the time machine character corpus.

  Loads the corpus and vocabulary once on construction; each iteration
  starts a fresh pass with the chosen sampling function.
  """

  def __init__(self, batch_size, num_steps, use_random_iter, max_tokens):
    # Pick the sampling strategy once; `__iter__` only has to call it.
    self.data_iter_fn = (
        seq_data_iter_random if use_random_iter else seq_data_iter_sequential
    )
    self.corpus, self.vocab = load_corpus_time_machine(max_tokens)
    self.batch_size = batch_size
    self.num_steps = num_steps

  def __iter__(self):
    make_batches = self.data_iter_fn
    return make_batches(self.corpus, self.batch_size, self.num_steps)

def load_data_time_machine(batch_size, num_steps, use_random_iter=False, max_tokens=10000):
  """Build a `SeqDataLoader` and return it together with its vocabulary."""
  loader = SeqDataLoader(
      batch_size=batch_size,
      num_steps=num_steps,
      use_random_iter=use_random_iter,
      max_tokens=max_tokens,
  )
  return loader, loader.vocab

# Loader with batch_size=2 and num_steps=5 (sequential sampling and
# max_tokens=10000 come from the defaults).
data_iter, vocab = load_data_time_machine(2,5)
print(list(vocab.token_to_idx.items())[:10])

# Print only the first three batches.
b = 0
for X, Y in data_iter:
  print("batch: ", b)
  print("X: ", X, "\nY: ", Y)
  b += 1
  if b > 2:
    break

# Register the English-French archive as (url, sha1 hex digest).
DATA_HUB["fra-eng"] = (DATA_URL + "fra-eng.zip", "94646ad1522d915e7b0f9296181140edcf86a4f5")

def read_data_nmt():
    """Load the English-French dataset and return it as a single string.

    Downloads and extracts the "fra-eng" archive if necessary, then reads
    `fra.txt` from the extracted directory.
    """
    data_dir = download_extract("fra-eng")
    # Pin the encoding so decoding does not depend on the platform's locale
    # default. Assumes fra.txt is UTF-8 encoded; confirm if the source
    # archive changes.
    with open(os.path.join(data_dir, "fra.txt"), "r", encoding="utf-8") as f:
        return f.read()

# Load the raw translation text and preview its first 100 characters.
raw_text = read_data_nmt()
print(raw_text[:100])

# NOTE(review): this repeats the two statements above (the archive is
# extracted and read a second time) — looks like a duplicated notebook cell.
raw_text = read_data_nmt()
print(raw_text[:100])

def preprocess_nmt(text):
    """Normalize raw English-French text.

    Replaces the characters U+202F and U+00A0 with a plain space,
    lower-cases everything, and puts a space in front of each of `,.!?`
    that is not the first character and does not already follow a space.
    """
    punctuation = frozenset(",.!?")
    text = text.replace("\u202f", " ").replace("\xa0", " ").lower()

    pieces = []
    previous = None  # None marks "no previous character yet"
    for char in text:
        if previous is not None and previous != " " and char in punctuation:
            pieces.append(" ")
        pieces.append(char)
        previous = char
    return "".join(pieces)

# Normalize the raw text and preview the first 110 characters.
text = preprocess_nmt(raw_text)
print(text[:110])

def tokenize_nmt(text, num_examples=None):
  """Tokenize the English-French dataset.

  Each line is expected to hold a source and a target sentence separated by
  a tab; lines that do not split into exactly two parts are skipped.

  Args:
    text: Preprocessed text, one sentence pair per line.
    num_examples: If truthy, only the first `num_examples` lines are read.
      None (or 0) reads every line.

  Returns:
    `(source, target)`, two parallel lists of token lists obtained by
    splitting each sentence on single spaces.
  """
  source, target = [], []
  for i, line in enumerate(text.split("\n")):
    # `i` is zero-based, so stop once `num_examples` lines were consumed
    # (`i > num_examples` let one extra line through).
    if num_examples and i >= num_examples:
      break
    parts = line.split("\t")
    if len(parts) == 2:
      source.append(parts[0].split(" "))
      target.append(parts[1].split(" "))

  return source, target

# Tokenize every sentence pair of the preprocessed text.
source, target = tokenize_nmt(text)
# Bare expression: it is shown only when it is the last statement of a
# notebook cell.
source[:10], target[:10]

# Separate vocabularies per language; tokens seen fewer than 2 times are
# left out and map to "<unk>".
src_vocab = Vocab(source, min_freq=2, reserved_tokens=["<pad>", "<bos>", "<eos>"])
print(f" SRV-VOCAB {len(src_vocab)}")

# French has more high frequency words than English
target_vocab = Vocab(target, min_freq=2, reserved_tokens=["<pad>", "<bos>", "<eos>"])
print(f" TARGET_VOCB {len(target_vocab)}")

def truncate_pad(line, num_steps, padding_token):
  """Return `line` cut or right-padded to exactly `num_steps` items.

  Longer lists are sliced to their first `num_steps` items; shorter ones get
  `padding_token` appended until the length matches.
  """
  shortfall = num_steps - len(line)
  if shortfall < 0:
    return line[:num_steps]
  return line + [padding_token] * shortfall

# Pad the first source sentence to length 10: first on raw tokens with the
# literal string "pad", then on vocabulary indices with the "<pad>" index.
print(f'Truncate Pad - {truncate_pad(source[0], 10, "pad")}')
print(f'Truncate Pad(SRC) - {truncate_pad(src_vocab[source[0]], 10, src_vocab["<pad>"])}')

def build_array_nmt(lines, vocab, num_steps):
  """Encode token lists as a fixed-width index tensor plus valid lengths.

  Every line is mapped to indices, terminated with "<eos>", and then cut or
  padded with "<pad>" to `num_steps` entries. Returns `(array, valid_len)`
  where `valid_len` counts the non-"<pad>" entries of each row.
  """
  pad_id = vocab["<pad>"]
  eos_id = vocab["<eos>"]
  rows = [
      truncate_pad(vocab[tokens] + [eos_id], num_steps, pad_id)
      for tokens in lines
  ]
  array = torch.tensor(rows)
  is_real_token = array != pad_id
  valid_len = is_real_token.type(torch.int32).sum(1)
  return array, valid_len

# Encode all source sentences with a fixed width of 10.
num_steps = 10
src_array, src_valid_len = build_array_nmt(source, src_vocab, num_steps)
print(f'SRC-array shape {jnp.array(src_array).shape}')
print(f'SRC-valid len {jnp.array(src_valid_len).shape}')

# First encoded sentence, then its valid length (the second label also reads
# "SRC ARRAY" although the value printed is the length).
print(f'SRC ARRAY {jnp.array(src_array[0, :])}')
print(f'SRC ARRAY {jnp.array(src_valid_len[0])}')

def load_array(data_arrays, batch_size, is_train=True):
  """Wrap tensors in a PyTorch `DataLoader`.

  `data_arrays` is unpacked into a `TensorDataset`; batches are shuffled
  only when `is_train` is true.
  """
  tensor_dataset = data.TensorDataset(*data_arrays)
  loader = data.DataLoader(tensor_dataset, batch_size=batch_size, shuffle=is_train)
  return loader

def load_data_nmt(batch_size, num_steps, num_examples=600):
  """Build the translation data iterator and both vocabularies.

  Returns `(data_iter, src_vocab, tgt_vocab)`; each batch from `data_iter`
  is `(src_array, src_valid_len, tgt_array, tgt_valid_len)`.
  """
  reserved = ["<pad>", "<bos>", "<eos>"]
  cleaned = preprocess_nmt(read_data_nmt())
  source, target = tokenize_nmt(cleaned, num_examples)

  # Each language gets its own vocabulary; tokens seen fewer than 2 times
  # are left out.
  src_vocab = Vocab(source, min_freq=2, reserved_tokens=list(reserved))
  tgt_vocab = Vocab(target, min_freq=2, reserved_tokens=list(reserved))

  src_array, src_valid_len = build_array_nmt(source, src_vocab, num_steps)
  tgt_array, tgt_valid_len = build_array_nmt(target, tgt_vocab, num_steps)

  tensors = (src_array, src_valid_len, tgt_array, tgt_valid_len)
  return load_array(tensors, batch_size), src_vocab, tgt_vocab


# Build the translation loader and show the first minibatch only.
train_iter, src_vocab, tgt_vocab = load_data_nmt(batch_size=2, num_steps=8)
for X, X_valid_len, Y, Y_valid_len in train_iter:
  # The loader yields torch tensors; convert to JAX arrays for display.
  print("X:", jnp.array(X).astype(jnp.int32))
  print("valid lengths for X:", jnp.array(X_valid_len))
  print("Y:", jnp.array(Y).astype(jnp.int32))
  print("valid lengths for Y:", jnp.array(Y_valid_len))
  break



Downloaded File: ../data/timemachine.txt
number of lines: 3221
0 the time machine by h g wells
1 
2 
3 
4 
5 i
6 
7 
8 the time traveller for so it will be convenient to speak of him
9 was expounding a recondite matter to us his grey eyes shone and
10 twinkled and his usually pale face was flushed and animated the
total num charactes  170580
total num words  32775
['the', 'time', 'machine', 'by', 'h', 'g', 'wells']
[]
[]
[]
[]
['i']
[]
[]
['the', 'time', 'traveller', 'for', 'so', 'it', 'will', 'be', 'convenient', 'to', 'speak', 'of', 'him']
['was', 'expounding', 'a', 'recondite', 'matter', 'to', 'us', 'his', 'grey', 'eyes', 'shone', 'and']
['twinkled', 'and', 'his', 'usually', 'pale', 'face', 'was', 'flushed', 'and', 'animated', 'the']
[('<unk>', 0), ('the', 1), ('i', 2), ('and', 3), ('of', 4), ('a', 5), ('to', 6), ('was', 7), ('in', 8), ('that', 9)]
words:  ['the', 'time', 'machine', 'by', 'h', 'g', 'wells']
indices: [1, 19, 50, 40, 2183, 2184, 400]
words:  ['the', 'time', 'traveller'