<a href="https://colab.research.google.com/github/benluks/LRQW/blob/main/bin_lstm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from datetime import datetime
from operator import lt, gt
from pathlib import Path

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter

from tqdm import tqdm

In [None]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

### Utility `binarize` Function

In [None]:
def binarize(W):
  """
  Stochastically binarize a normalized weight matrix according to
  https://arxiv.org/abs/1809.11086

  Every entry `w` becomes +1 with probability `(w + 1) / 2` and -1 otherwise.

  Args:
    W: floating-point tensor of weights. Entries are expected to lie in
       [-1, 1]; values outside that range saturate to -1 / +1.

  Returns:
    A new tensor with the shape, dtype and device of `W` that holds only
    -1 and +1. `W` itself is not modified and the result carries no
    autograd history.
  """
  # probability of drawing +1 for each entry
  p_plus = (W.detach() + 1) / 2
  # sample the noise next to the weights (same device / dtype) so the function
  # also works for CUDA tensors
  noise = torch.rand_like(p_plus)
  # noise is uniform in [0, 1): a strict comparison yields exactly +1 for p = 1
  # and exactly -1 for p = 0, and never the 0 that `sign()` produces on ties
  return torch.where(noise < p_plus, torch.ones_like(p_plus), -torch.ones_like(p_plus))

## Updates Test

In [None]:
model = nn.Linear(1, 1, bias=False)

In [None]:
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
loss_fn = nn.MSELoss()

In [None]:
model.weight.data

tensor([[0.6611]])

In [None]:
model.org = model.weight.data
model.org

tensor([[0.0911]])

In [None]:
model.weight.grad = torch.tensor([11.4])
model.weight.grad

tensor([11.4000])

In [None]:
# Swap in a known weight value and reset the gradient before the new backward pass
model.weight.data = torch.tensor([6.])
optimizer.zero_grad()

# Forward a constant input, score it against the 0.3 target and backpropagate
y = model(torch.tensor([1.]))
loss = loss_fn(y, torch.tensor(0.3))
loss.backward()
# Display the gradient that backward() left on the weight
model.weight.grad

tensor([22.8000])

In [None]:
model.weight.data = model.org

In [None]:
model.weight.data

tensor([[0.0911]])

In [None]:
optimizer.step()
model.weight.data

tensor([[0.0911]])

In [None]:
model.weight

Parameter containing:
tensor([[0.0911]], requires_grad=True)

## QLSTM

In [None]:
def qlstm_cell(input, hidden, w_ih, w_hh, b_ih, b_hh, bn=nn.Identity()):
  """
  Single LSTM time-step with an optional normalisation of the pre-activations.

  Args:
    input: [B, I] input of the current time-step.
    hidden: tuple `(h, c)` of the previous hidden and cell state, each [B, H].
    w_ih, w_hh: input-to-hidden [4H, I] and hidden-to-hidden [4H, H] weights.
    b_ih, b_hh: the matching biases, each [4H].
    bn: callable applied to the stacked [B, 8, H] pre-activations
        (identity by default).

  Returns:
    Tuple `(h_next, c_next)`, each [B, H].
  """
  h_prev, c_prev = hidden
  n_batch, n_hidden = h_prev.shape

  # the two projections are kept apart so `bn` sees all 8 pre-activation groups
  from_input = torch.mm(input, w_ih.t()) + b_ih
  from_hidden = torch.mm(h_prev, w_hh.t()) + b_hh
  stacked = torch.cat((from_input, from_hidden), dim=1).view(n_batch, 8, n_hidden)

  # [B, 8, H] -> [B, 2, 4, H]; summing axis 1 adds the input and hidden parts
  summed = bn(stacked).view(n_batch, 2, 4, n_hidden).sum(1)
  i_pre, f_pre, g_pre, o_pre = summed.unbind(1)

  i_gate = torch.sigmoid(i_pre)
  f_gate = torch.sigmoid(f_pre)
  candidate = torch.tanh(g_pre)
  o_gate = torch.sigmoid(o_pre)

  c_next = (f_gate * c_prev) + (i_gate * candidate)
  h_next = o_gate * torch.tanh(c_next)

  return h_next, c_next

In [None]:
class QLSTM(nn.LSTM):
  """
  `nn.LSTM` whose forward pass is unrolled in Python with `qlstm_cell`, with
  optional quantization of the weights and biases before every forward pass.

  Args:
    quant: `'bin'` to quantize with `binarize`, any callable mapping a tensor
           to its quantized version, or a falsy value (e.g. `None`) to run in
           full precision. When quantization is active one `BatchNorm1d(8)`
           per layer (`bn_l{layer}`, bias frozen) is registered and handed to
           the cell.
    *args, **kwargs: forwarded unchanged to `nn.LSTM`.

  Raises:
    NotImplementedError: for `batch_first=True`, `bidirectional=True`,
      `bias=False` or `proj_size > 0`, none of which this forward supports.

  Notes:
    * `forward` does not apply the `dropout` configured on `nn.LSTM`.
    * With quantization active, `forward` leaves the quantized values in the
      parameters; the full-precision copies are kept as `org_<param name>`
      attributes and are put back by `restore_params()`.
  """

  def __init__(self, quant, *args, **kwargs):
    super().__init__(*args, **kwargs)

    # option name -> whether the caller switched it on
    unsupported = {
        'batch_first': self.batch_first,
        'bidirectional': self.bidirectional,
    }
    for option, enabled in unsupported.items():
      if enabled:
        raise NotImplementedError(f"Support for {option} is not yet implemented. Please initialize QLSTM with `{option}=False`")
    if not self.bias:
      raise NotImplementedError("Support for bias=False is not yet implemented. Please initialize QLSTM with `bias=True`")
    if getattr(self, 'proj_size', 0):
      raise NotImplementedError("Support for proj_size is not yet implemented. Please initialize QLSTM with `proj_size=0`")

    self.quant = binarize if quant == 'bin' else quant
    if self.quant:
      for layer in range(self.num_layers):
        bn = nn.BatchNorm1d(8)
        bn.bias.requires_grad_(False)
        self.add_module(f'bn_l{layer}', bn)


  def _save_and_quantize_params(self):
    """
    Save the full-precision data of every LSTM parameter (weight or bias, not
    bn) as `org_<name>` and replace the parameter data with its quantized
    version.
    """
    for name, par in self.named_parameters():
      if name.startswith('bn'):
        continue
      setattr(self, f'org_{name}', par.data)
      par.data = self.quant(par.data)


  def restore_params(self):
    """
    Put the full-precision data saved by `_save_and_quantize_params` back into
    the parameters and drop the saved copies. Parameters without a saved copy
    are left untouched, so calling this twice is harmless.
    """
    for name, par in self.named_parameters():
      key = f'org_{name}'
      saved = getattr(self, key, None)
      if saved is not None:
        par.data = saved
        delattr(self, key)


  def forward(self, input, h_0=None):
    """
    Args:
      input: [T, B, input_size] sequence (time-major).
      h_0: optional tuple `(h, c)` of initial states. Tensors of shape
           [num_layers, B, H] (the `nn.LSTM` layout) are sliced per layer;
           tensors of shape [B, H] are used as the initial state of every
           layer. Defaults to zeros.

    Returns:
      `(outputs, (h_n, c_n))` with outputs [T, B, H] from the last layer and
      h_n / c_n of shape [num_layers, B, H].
    """
    if self.quant:
      self._save_and_quantize_params()

    seq_len, batch_size = input.size(0), input.size(1)

    # final (h, c) pair of each layer
    final_states = []

    for layer in range(self.num_layers):

      # fetch by full name: matching on the last character of the name breaks
      # as soon as there are 10 or more layers
      layer_params = [getattr(self, f'{kind}_l{layer}')
                      for kind in ('weight_ih', 'weight_hh', 'bias_ih', 'bias_hh')]
      if self.quant:
        layer_params.append(getattr(self, f'bn_l{layer}'))

      if h_0 is None:
        # allocate next to the input so this also works off-CPU
        zeros = input.new_zeros(batch_size, self.hidden_size)
        hidden = (zeros, zeros)
      elif h_0[0].dim() == 3:
        hidden = (h_0[0][layer], h_0[1][layer])
      else:
        hidden = (h_0[0], h_0[1])

      outputs = []
      for t in range(seq_len):
        hidden = qlstm_cell(input[t], hidden, *layer_params)
        outputs.append(hidden[0])

      final_states.append(hidden)
      outputs = torch.stack(outputs, 0)
      # hidden states of this layer are the next layer's input
      input = outputs

    # [(h, c), (h, c), ...] -> (h_0, h_1, ...), (c_0, c_1, ...)
    h_n, c_n = zip(*final_states)

    return outputs, (torch.stack(h_n, 0), torch.stack(c_n, 0))

## Test if implementation matches stock PyTorch 

In [None]:
# Dimensions of the toy problem used to compare QLSTM with nn.LSTM
BATCH = 1
T = 6
H = 3
# deliberately not called `F`: that name is the `torch.nn.functional` alias
# imported at the top of the notebook, and rebinding it to an int breaks every
# later `F.one_hot(...)` call
N_FEATURES = 2
L = 4

q = QLSTM(None, N_FEATURES, H, num_layers=L)
real = nn.LSTM(N_FEATURES, H, num_layers=L)

# give the stock LSTM the very same parameter values as the QLSTM
for n, p in q.named_parameters():
  real.get_parameter(n).data = p.data

x = torch.zeros(T, BATCH, N_FEATURES)

In [None]:
mine = q(x)
theirs = real(x)

In [None]:
def same(t):
  """Return the Python scalar of `torch.all(t)`: whether every element of `t` is truthy."""
  return torch.all(t).item()

# PTBChar Experiment

## Data

In [None]:
!pip install datasets
from datasets import load_dataset
dataset = load_dataset("ptb_text_only")

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting datasets
  Downloading datasets-2.2.2-py3-none-any.whl (346 kB)
[K     |████████████████████████████████| 346 kB 6.6 MB/s 
Collecting huggingface-hub<1.0.0,>=0.1.0
  Downloading huggingface_hub-0.7.0-py3-none-any.whl (86 kB)
[K     |████████████████████████████████| 86 kB 1.6 MB/s 
Collecting fsspec[http]>=2021.05.0
  Downloading fsspec-2022.5.0-py3-none-any.whl (140 kB)
[K     |████████████████████████████████| 140 kB 74.7 MB/s 
[?25hCollecting dill<0.3.5
  Downloading dill-0.3.4-py2.py3-none-any.whl (86 kB)
[K     |████████████████████████████████| 86 kB 7.7 MB/s 
Collecting aiohttp
  Downloading aiohttp-3.8.1-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (1.1 MB)
[K     |████████████████████████████████| 1.1 MB 59.7 MB/s 
Collecting responses<0.19
  Downloading responses-0.18.0-py3-none-any.whl (38 kB)
Collecting xxhas

Downloading builder script:   0%|          | 0.00/2.69k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/1.06k [00:00<?, ?B/s]

Downloading and preparing dataset ptb_text_only/penn_treebank (download: 5.68 MiB, generated: 5.72 MiB, post-processed: Unknown size, total: 11.40 MiB) to /root/.cache/huggingface/datasets/ptb_text_only/penn_treebank/1.1.0/8d1b97746fb9765d140e569ec5ddd35e20af4d37761f5e1bf357ea0b081f2c1f...


Downloading data files:   0%|          | 0/3 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/1.70M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/135k [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/150k [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/3 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/42068 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/3761 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/3370 [00:00<?, ? examples/s]

Dataset ptb_text_only downloaded and prepared to /root/.cache/huggingface/datasets/ptb_text_only/penn_treebank/1.1.0/8d1b97746fb9765d140e569ec5ddd35e20af4d37761f5e1bf357ea0b081f2c1f. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

In [None]:
def raw_text(split):
  """Join all entries of `dataset[split]['sentence']` into one string, separated by '. '."""
  sentences = dataset[split]['sentence']
  return '. '.join(sentences)

### Tokenizer

In [None]:
class PTBCharTokenizer:
  """
  Character-level tokenizer whose vocabulary is the set of characters found in
  `raw_text`, sorted, with the empty string reserved at index 0.

  Attributes:
    raw_text: the text the vocabulary was built from.
    tokens: list mapping index -> character (`tokens[0] == ''`).
    char2idx: dict mapping character -> index.
  """

  def __init__(self, raw_text):
    self.raw_text = raw_text
    self.tokens = [''] + sorted(set(raw_text))
    # enumerate instead of `list.index` per token (which is quadratic)
    self.char2idx = {token: idx for idx, token in enumerate(self.tokens)}

  def __len__(self): return len(self.char2idx)

  def encode(self, sent):
    """
    Map every character of `sent` to its index.

    Returns a 1-D int64 tensor, also for an empty `sent` (without the explicit
    dtype an empty list would produce a float tensor).
    Raises KeyError for a character that is not in the vocabulary.
    """
    return torch.tensor([self.char2idx[char] for char in sent], dtype=torch.long)

  def decode(self, inds):
    """Concatenate the characters stored at the indices in `inds`."""
    return ''.join([self.tokens[ind] for ind in inds])

### Dataset Object

In [None]:
class PTBData(torch.utils.data.Dataset):
  """
  Cuts `raw_text` into consecutive, non-overlapping chunks of `seq_len`
  characters and returns them encoded by the tokenizer. A trailing remainder
  shorter than `seq_len` is dropped.

  Args:
    raw_text: the text to serve.
    tokenizer: either a callable (e.g. a tokenizer class) that is invoked as
               `tokenizer(raw_text)` to build the tokenizer, or an already
               built tokenizer object (anything non-callable with an
               `encode` method), which is used as-is so several datasets can
               share one vocabulary.
    seq_len: number of characters per item.
  """

  def __init__(self, raw_text, tokenizer, seq_len):
    self.raw_text = raw_text
    self.tokenizer = tokenizer(raw_text) if callable(tokenizer) else tokenizer
    self.seq_len = seq_len

  def __getitem__(self, index):
    """
    Return the encoded `index`-th chunk. Negative indices count from the end.

    Raises:
      IndexError: when `index` is outside `[-len(self), len(self))`.
    """
    size = len(self)
    if index < 0:
      index += size
    # checked after normalisation so that too-negative indices are rejected
    # instead of silently producing a wrong (or empty) slice
    if not 0 <= index < size:
      raise IndexError

    start = index * self.seq_len
    return self.tokenizer.encode(self.raw_text[start:start + self.seq_len])

  def __len__(self): return len(self.raw_text) // self.seq_len

## Model

In [None]:
class PTBCharModel(nn.Module):
  """
  Character-level language model: LSTM -> linear projection -> softmax.

  Keyword Args:
    vocab_size: size of the one-hot input and of the output distribution.
    hidden_size: LSTM hidden size.
    num_layers: number of stacked LSTM layers (optional, default 1).

  `forward` returns probabilities (softmax already applied), not logits.
  """

  def __init__(self, **kwargs):
    super().__init__()
    self.vocab_size = kwargs['vocab_size']
    self.hidden_size = kwargs['hidden_size']
    # honour `num_layers` when given instead of silently dropping it
    self.num_layers = kwargs.get('num_layers', 1)

    self.lstm = nn.LSTM(self.vocab_size, self.hidden_size, num_layers=self.num_layers)
    self.linear_proj = nn.Linear(self.hidden_size, self.vocab_size)
    self.softmax = nn.Softmax(dim=-1)


  def _process_input(self, x):
    """Turn a 2-D tensor of indices into float one-hot vectors; pass anything else through."""
    if x.dim() == 2:
      # `nn.functional` is spelled out so this does not depend on the
      # notebook-level alias `F` still pointing at the module
      x = nn.functional.one_hot(x, self.lstm.input_size).float()

    return x


  def compute_gates(self, input, hidden=None):
    """
    Recompute the activated gates of the first LSTM layer for every time-step.

    Args:
      input: [B, T] indices or [B, T, V] one-hot vectors.
      hidden: optional `(h, c)` initial state, each [B, H]; zeros by default.

    Returns:
      Tensor [4, B, T, H] holding, in this order, the input, forget, cell and
      output gate activations. Only layer 0 (`*_l0` parameters) is inspected.
    """
    input = self._process_input(input)
    B, T = input.size()[:2]

    if hidden is None:
      hx = input.new_zeros(B, self.hidden_size)
      cx = input.new_zeros(B, self.hidden_size)
    else:
      hx, cx = hidden

    _result = input.new_empty(4, B, T, self.hidden_size)

    with torch.no_grad():
      bias = self.lstm.bias_ih_l0 + self.lstm.bias_hh_l0
      # input: [B, T, V]
      for t in range(T):
        gates = torch.mm(input[:, t], self.lstm.weight_ih_l0.t()) + \
                torch.mm(hx, self.lstm.weight_hh_l0.t()) + bias

        # gates: [B, 4*H] => 4*([B, H])
        ingate, forgetgate, cellgate, outgate = gates.chunk(4, 1)
        ingate = torch.sigmoid(ingate)
        forgetgate = torch.sigmoid(forgetgate)
        cellgate = torch.tanh(cellgate)
        outgate = torch.sigmoid(outgate)

        _result[:, :, t] = torch.stack([ingate, forgetgate, cellgate, outgate])

        # advance the recurrent state; without this every step would be
        # computed from the initial state
        cx = forgetgate * cx + ingate * cellgate
        hx = outgate * torch.tanh(cx)

    return _result


  def forward(self, x, hidden=None):
    """
    Args:
      x: [B, T] indices or [B, T, V] one-hot vectors.
      hidden: optional `(h, c)` state passed straight to `nn.LSTM`.

    Returns:
      `(probs, hidden)` with probs [B, T, V] (softmax over the last axis) and
      the LSTM's final `(h, c)` state.
    """
    x = self._process_input(x)

    # x: [B, T, N] => [T, B, N]
    x = x.permute(1, 0, 2)
    x, hidden = self.lstm(x, hidden)
    # x: [T, B, H] => [B, T, H]
    x = x.permute(1, 0, 2)
    x = self.linear_proj(x)
    return self.softmax(x), hidden

## Training

### Trainer Class

In [None]:
class Trainer:
  """
  Trains a `PTBCharModel` on the PTB character data and tracks loss,
  bits-per-character (BPC) and validation accuracy.

  Args:
    params: configuration dict with the keys
      'hparams'   -> {'num_epochs', 'lr', 'batch_size', optional 'BPC'}
      'model'     -> keyword arguments for `PTBCharModel`
      'optimizer' -> name of a class in `torch.optim`
      'seq_len'   -> characters per training sequence
      'output_dir' (optional) -> where logs and checkpoints go; defaults to a
                                 directory named after the current time.

  Usage: `trainer = Trainer(config); trainer.build(); trainer()`.
  """

  # ln(2): divides a cross-entropy in nats to express it in bits
  _LN2 = 0.6931471805599453

  def __init__(self, params):

    # hparams
    self.params = params
    hparams = self.params['hparams']

    self.num_epochs = hparams['num_epochs']
    self.lr = hparams['lr']
    self.batch_size = hparams['batch_size']

    self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    # logging and saving
    self.output_dir = Path(params['output_dir']) if 'output_dir' in params \
                      else Path('.') / datetime.now().strftime("%Y-%m-%d-%H_%M")
    self.writer = SummaryWriter(log_dir=self.output_dir / 'log/')
    self.checkpoint_dir = self.output_dir / 'checkpoint'

    # whether to use BPC as metric
    self.BPC = hparams.get('BPC', False)

  def build(self):
    """
    Dedicated function to execute more storage- and resource-intensive
    initializations: model, datasets, loaders, optimizer and loss.
    """

    # model
    self.model = PTBCharModel(**self.params['model']).to(self.device)

    # data: one tokenizer, built from the training text, is shared by both
    # splits so that an index means the same character everywhere. A character
    # that only occurs in the validation text raises KeyError when encoded.
    seq_len = self.params['seq_len']
    train_text = raw_text('train')
    self.tokenizer = PTBCharTokenizer(train_text)
    shared_tokenizer = lambda _text: self.tokenizer

    self.train_set = PTBData(train_text, shared_tokenizer, seq_len)
    self.valid_set = PTBData(raw_text('validation'), shared_tokenizer, seq_len)

    # pinning host memory only makes sense when batches are copied to a GPU
    pin = self.device.type == 'cuda'
    self.train_loader = torch.utils.data.DataLoader(self.train_set,
                                                    batch_size=self.batch_size,
                                                    pin_memory=pin, shuffle=True)
    self.valid_loader = torch.utils.data.DataLoader(self.valid_set,
                                                    batch_size=self.batch_size,
                                                    pin_memory=pin, shuffle=True)

    # optimization
    self.optimizer = getattr(torch.optim, self.params['optimizer'])(self.model.parameters(), lr=self.lr)
    self.criterion = nn.CrossEntropyLoss(ignore_index=0, reduction='none' if self.BPC else 'mean')

    self.checkpoint_dir.mkdir(parents=True, exist_ok=True)


  def write_progress(self, epoch):
    """Print the latest metrics of `epoch`, plus a summary after the last epoch."""

    final_log = f"[{epoch}/{self.num_epochs}]:"

    for metric in self.latest:
      for dset in self.latest[metric]:
        value = self.latest[metric][dset]
        # metrics that were not computed are stored as None
        if value is not None:
          value = f"{round(value * 100, 3)}%" if metric == 'accuracy' else round(value, 3)
          final_log += f" {f'{dset} {metric}'.capitalize()}: {value}"
          final_log += " |"
    final_log += "|"

    print(final_log)

    if epoch == self.num_epochs:
      print(f"""
      Completed training after {self.num_epochs} epochs with:
        
        Best loss = {round(self.best['loss']['value'], 3)} in epoch {self.best['loss']['epoch']}, and
        Best accuracy = {round(self.best['accuracy']['value']*100, 3)}% in epoch {self.best['accuracy']['epoch']}.
        """)


  def step(self, batch):
    """
    Run the model on one batch of index sequences [B, T].

    The first T-1 tokens are the input, the last T-1 tokens the labels.

    Returns:
      `(loss, bpc, output, label)`: mean cross-entropy in nats, the same
      quantity in bits (loss / ln 2), the model's probabilities [B, T-1, V]
      and the labels [B, T-1].
    """
    input, label = batch[:, :-1].to(self.device), batch[:, 1:].to(self.device)
    input = nn.functional.one_hot(input, self.model.vocab_size).float()
    # for training, hidden and context aren't necessary
    output, _ = self.model(input)

    # the model returns probabilities while CrossEntropyLoss expects unnormalised
    # log-scores and applies log-softmax itself; feeding log-probabilities makes
    # the loss the true negative log-likelihood (the clamp guards log(0))
    log_probs = output.clamp_min(1e-12).log()
    # reshape instead of view: the label slice is not contiguous
    losses = self.criterion(log_probs.reshape(-1, self.model.vocab_size), label.reshape(-1))

    loss = losses.mean()
    bpc = loss / self._LN2

    return loss, bpc, output, label


  def validate(self):
    """Return `(loss, bpc, accuracy)` averaged over the validation loader."""

    running_loss = 0
    running_bpc = 0
    total_correct = 0

    was_training = self.model.training
    self.model.eval()
    with torch.no_grad():
      for batch in tqdm(self.valid_loader, position=0):

        loss, bpc, output, label = self.step(batch)

        running_loss += loss.item()
        running_bpc += bpc.item()

        # compute accuracy
        predictions = output.topk(1, dim=-1).indices.squeeze(-1)
        total_correct += (predictions == label).sum().item()
    self.model.train(was_training)

    loss = running_loss / len(self.valid_loader)
    bpc = running_bpc / len(self.valid_loader)
    accuracy = total_correct / (len(self.valid_set)*(self.params['seq_len'] - 1))

    return loss, bpc, accuracy


  def __call__(self):
    self.train()


  def save_model(self, epoch):
    """Replace the files in the checkpoint directory with the state of `epoch`."""

    # delete old checkpoints
    for old in self.checkpoint_dir.glob('*'):
      if old.is_file():
        old.unlink()

    torch.save(self.model.state_dict(), self.checkpoint_dir / f"e{epoch}.pth")
    torch.save(self.optimizer.state_dict(), self.checkpoint_dir / f"opt_e{epoch}.pth")

    print("Saved new best model")


  def log_progress(self, epoch):
    """Write the latest metrics to the summary writer, one curve per data split."""
    for metric in self.latest:
      self.writer.add_scalars(
          metric,
          # keyed by split ('train' / 'valid') so the curves do not overwrite each other
          {dset: value for dset, value in self.latest[metric].items() if value is not None},
          epoch)


  def update_milestone(self, epoch):
    """
    Update `self.best` from the latest validation metrics and save a checkpoint
    when any of them improved. Accuracy improves upwards, every other metric
    downwards.
    """

    is_best = False

    for metric in self.latest:
      op = gt if metric == 'accuracy' else lt
      latest = self.latest[metric]['valid']
      if latest is not None and op(latest, self.best[metric]['value']):
        self.best[metric]['value'] = latest
        self.best[metric]['epoch'] = epoch
        is_best = True

    if is_best:
      self.save_model(epoch)

  def train(self):
    """Run `num_epochs` epochs of training, validating and logging after each."""

    self.best = {
        # computed on validation data
        'loss': {'epoch': -1, 'value': float('INF')},
        'accuracy': {'epoch': -1, 'value': 0},
        'bpc': {'epoch': -1, 'value': float('INF')}
    }
    self.latest = {
        'loss': {'train': None, 'valid': None},
        'bpc': {'train': None, 'valid': None},
        'accuracy': {'train': None, 'valid': None}
    }

    for epoch in range(self.num_epochs):

      running_loss = 0
      running_bpc = 0

      self.model.train()
      for batch in tqdm(self.train_loader, position=0):

        # setting grads to None speeds up training compared to zeroing them. Read more here:
        # https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#use-parameter-grad-none-instead-of-model-zero-grad-or-optimizer-zero-grad
        self.optimizer.zero_grad(set_to_none=True)

        loss, bpc, _, _ = self.step(batch)
        running_loss += loss.item()
        running_bpc += bpc.item()

        # backpropagate and optimize
        loss.backward()
        self.optimizer.step()

      train_loss = running_loss / len(self.train_loader)
      train_bpc = running_bpc / len(self.train_loader)

      valid_loss, valid_bpc, accuracy = self.validate()

      self.latest['loss']['train'] = train_loss
      self.latest['loss']['valid'] = valid_loss
      self.latest['bpc']['train'] = train_bpc
      self.latest['bpc']['valid'] = valid_bpc
      self.latest['accuracy']['valid'] = accuracy

      self.log_progress(epoch+1)
      self.update_milestone(epoch+1)
      self.write_progress(epoch+1)

### Params

In [None]:
# Configuration consumed by `Trainer`
config = {

  # read in Trainer.__init__
  'hparams': {
    'num_epochs': 50,
    'lr': 0.002,
    'batch_size': 64,
    'BPC': True 
    },
      
  # forwarded as keyword arguments to PTBCharModel in Trainer.build
  'model': {
    'vocab_size': 50,
    'hidden_size': 1000,
    'num_layers': 2
      },

  'optimizer': 'Adam',   # as of now, string must match torch.optim class verbatim (ie. Adam, not `adam`, `ADAM`, etc...)
  # characters per sequence handed to PTBData
  'seq_len': 100
}


## Main

In [None]:
trainer = Trainer(config)

In [None]:
trainer.build()

In [None]:
%reload_ext tensorboard
%tensorboard --logdir $trainer.output_dir/log

In [None]:
trainer()

## Testing

In [None]:
def generate_from_prompt(tokenizer, prompt, model, length=100, device=device):
  """
  Sample `length` token indices from `model`, conditioned on `prompt`.

  The whole prompt is fed first; afterwards each sampled token is fed back as
  the next input together with the recurrent state returned by the model.

  Args:
    tokenizer: object whose `encode(prompt)` returns a 1-D tensor of indices.
    prompt: text to condition on.
    model: callable as `model(indices[1, T], hidden)` returning
           `(probabilities[1, T, V], hidden)`.
    length: number of tokens to sample.
    device: kept for backward compatibility; the body does not read it. The
            prompt is moved to the device of the model's parameters instead.

  Returns:
    List of `length` sampled token indices (Python ints).
  """
  input = tokenizer.encode(prompt)
  first_param = next(model.parameters(), None)
  if first_param is not None:
    input = input.to(first_param.device)

  generated = []
  hidden = None

  # no gradients are needed for sampling; without this the autograd graph keeps
  # growing through the hidden state carried from step to step
  with torch.no_grad():
    for _ in range(length):

      output, hidden = model(input.unsqueeze(0), hidden)
      # output is dist of shape [1, T, V]; sample from the last position
      pred = torch.multinomial(output.squeeze(0)[-1], 1)
      generated.append(pred.item())

      input = pred

  return generated

In [None]:
model = PTBCharModel(vocab_size=50, hidden_size=1000)
model.load_state_dict(torch.load('/content/drive/MyDrive/model_state/ptb_char50.pth', map_location=device))

tokenizer = PTBCharTokenizer(raw_text('train'))

prompt = 'once upon a time'

In [None]:
generated = generate_from_prompt(tokenizer, prompt, model, 1000)
tokenizer.decode(generated)

' and the state of the senior of the senior of the senior and an analyst and the state of the senior the senior the senior the senior the senior of the senior the senior the senior of the senior the senior the senior the senior of the senior of the senior of the senior of the senior the senior the senior the senior of the senior of the senior of the senior atternation of the senior the senior of the senior the senior of the senior of the senior telecoming and the senior of the senior the senior the senior of the senior of the senior the senior the senior of the senior the senior of the senior of the senior of the senior of the senior the senior telecoming and the state of the senior of the senior of the senior and an analyst and the stock and the state of the senior of the senior of the senior the senior the senior telecoming and the state of the senior the senior of the senior of the senior of the senior and an analyst and the stock and the state of the senior of the senior the senior

In [None]:
trainer.model

PTBCharModel(
  (lstm): LSTM(50, 1000)
  (linear_proj): Linear(in_features=1000, out_features=50, bias=True)
  (softmax): Softmax(dim=-1)
)

In [None]:
def gates_historgram(writer, gates):
  """
  Log one histogram per gate and time-step to `writer`.

  `gates` is squeezed on axis 1 and unpacked along axis 0 into the four gates,
  which are logged under the tags 'i', 'f', 'g' and 'o' with the time-step
  (index along axis 2 of `gates`) as the global step.
  """
  i, f, g, o = gates.squeeze(1)
  # i, f, g, o: 4*([T, H],)
  tagged = (('i', i), ('f', f), ('g', g), ('o', o))

  n_steps = gates.size(2)
  for step in range(n_steps):
    for tag, activations in tagged:
      writer.add_histogram(tag, activations[step], step)

In [None]:
test_set = PTBData(raw_text('test'), PTBCharTokenizer, 100)

In [None]:
model = PTBCharModel(vocab_size=50, hidden_size=1000).to(device)
model.load_state_dict(torch.load('/content/drive/MyDrive/model_state/2022-06-01-17_28/checkpoint/e49.pth', map_location=device))

torch.Size([4, 1, 100, 1000])

In [None]:
writer = SummaryWriter()

In [None]:
prompt = test_set[57].to(device)
gates = model.compute_gates(prompt.unsqueeze(0))
gates.shape

torch.Size([4, 1, 100, 1000])

In [None]:
gates_historgram(writer, gates)

In [None]:
PTBCharTokenizer(test_set.raw_text).decode(prompt)

'> up while stocks in new york kept falling sharply. big board chairman john j. phelan said yesterday'

### Tensorboard support