# Imports

In [None]:
from transformers import DistilBertTokenizer, DistilBertModel

import torch
import torchtext 
from torchtext.datasets import IMDB

import nltk
from nltk.stem import WordNetLemmatizer
from nltk.tokenize import word_tokenize
from nltk.corpus import stopwords
from nltk.tokenize import RegexpTokenizer

import nlpaug
import nlpaug.augmenter.char as nac
import nlpaug.augmenter.word as naw
import nlpaug.augmenter.sentence as nas
import nlpaug.flow as nafc

from nlpaug.util import Action
from nlpaug.util.file.download import DownloadUtil

import re
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm, trange

# One-time resource download, left switched off: the condition is the
# constant False, so nothing below runs unless it is edited by hand.
if False:
  # Embedding files fetched through nlpaug's DownloadUtil into dest_dir '.'.
  DownloadUtil.download_word2vec(dest_dir='.')
  DownloadUtil.download_glove(model_name='glove.6B', dest_dir='.')

  # NLTK resources, downloaded in the listed order.
  for nltk_resource in ('punkt',
                        'wordnet',
                        'omw-1.4',
                        'averaged_perceptron_tagger',
                        'stopwords'):
    nltk.download(nltk_resource)
  download_switch = False

# Model

In [None]:
class PositionEncoding(torch.nn.Module):
    """Prepend fixed triangle-wave position features to every sequence step.

    A table ``E`` of shape ``(1, 2**(n-1), n)`` is precomputed once. Column
    ``i`` holds a triangle wave over the position index with values in
    ``[-1, 1]`` and period ``2 * 2**i``. ``forward`` concatenates the first
    ``L`` rows of the table in front of the features of an input whose
    second dimension has size ``L``.

    ``E`` is registered as a non-persistent buffer: ``.to()``, ``.cuda()``,
    ``.half()`` and friends move/cast it together with the module (no custom
    ``_apply`` override needed), while ``state_dict()`` does not gain an
    ``E`` entry.

    Args:
        n: number of position features. The table covers ``2**(n-1)``
            positions, which is the longest sequence ``forward`` accepts.

    Raises:
        ValueError: if ``n`` is smaller than 1.
    """

    def __init__(self, n):
        super().__init__()

        if n < 1:
            raise ValueError(f"n must be at least 1, got {n}")

        # Rows are positions, columns are wave periods 1, 2, 4, ..., 2**(n-1);
        # broadcasting evaluates every (position, period) pair in one call.
        positions = np.arange(2**(n-1))[:, None]
        periods = 2**np.arange(n)[None, :]
        table = self.periodic(positions, periods).astype(np.float32)

        self.register_buffer("E", torch.tensor(table).unsqueeze(0), persistent=False)

    def forward(self, X):
        """Return ``X`` with ``n`` position features prepended on dim 2.

        Args:
            X: tensor indexed as (batch, position, feature); it is
                concatenated with the 3-D table along dim 2.

        Returns:
            Tensor of shape ``(X.shape[0], X.shape[1], n + X.shape[2])``.

        Raises:
            ValueError: if ``X.shape[1]`` exceeds the number of positions
                in the table.
        """
        batch, length = X.shape[0], X.shape[1]
        if length > self.E.shape[1]:
            # Slicing would silently truncate and torch.cat would then fail
            # with an opaque size-mismatch error; fail early and clearly.
            raise ValueError(
                f"sequence length {length} exceeds the {self.E.shape[1]} "
                "positions covered by this PositionEncoding"
            )
        # expand() broadcasts the table over the batch without copying it;
        # torch.cat materialises the result anyway.
        position_features = self.E[:, :length, :].expand(batch, -1, -1)
        return torch.cat((position_features, X), dim=2)

    @staticmethod
    def periodic(x, n):
        """Triangle wave of ``x`` with period ``2*n``: 1 at ``x=0``, -1 at ``x=n``."""
        return 2*np.abs(np.mod(x/n, 2) - 1) - 1


class Encoder(torch.nn.Module):
  """Encode a sequence of embedded tokens into one pooled latent vector.

  The pretrained DistilBERT model has its embedding layer replaced by
  ``Identity``, so whatever is passed as the model's first argument reaches
  the transformer layers unchanged. Its 768-wide hidden states go through a
  two-layer ReLU MLP (width 2048) and two linear heads of width 4096: one
  for the latent expectation and one for the log standard deviation.
  """

  def __init__(self):
    super(Encoder, self).__init__()

    self.model = DistilBertModel.from_pretrained("distilbert-base-uncased")
    # Inputs are expected to be embedded already, so skip the lookup.
    self.model.embeddings = torch.nn.Identity()
    self.expand = torch.nn.Sequential(
        torch.nn.Linear(768, 1024*2),
        torch.nn.ReLU(),
        torch.nn.Linear(1024*2, 1024*2),
        torch.nn.ReLU()
    )
    self.output_expect = torch.nn.Linear(1024*2, 1024*4)
    self.output_logstd = torch.nn.Linear(1024*2, 1024*4)

  def forward(self, X, M, keepdim=False, return_std=False):
    """Return the pooled latent expectation, optionally with its std.

    Args:
      X: embedded input passed to DistilBERT as its first argument
        (presumably (batch, seq, 768) -- TODO confirm against caller).
      M: attention mask passed to DistilBERT as its second argument.
      keepdim: keep the pooled dimension 1 with size 1 in the outputs.
      return_std: also return the standard deviation.

    Returns:
      ``E`` when ``return_std`` is false, otherwise the tuple ``(E, S)``.
      Both are means over dimension 1 of the per-position head outputs;
      ``S`` averages ``exp`` of the log-std head.
    """
    X = self.model(X, M).last_hidden_state
    X = self.expand(X)
    # NOTE: the mean runs over every position of dim 1; M is only used
    # inside DistilBERT and does not weight this pooling.
    E = self.output_expect(X).mean(1, keepdim=keepdim)
    if return_std:
      # Fixed: the head is registered as ``output_logstd``; the previous
      # ``self.logstd`` raised AttributeError whenever return_std=True.
      S = self.output_logstd(X).exp().mean(1, keepdim=keepdim)
      return E, S
    else:
      return E

class Decoder(torch.nn.Module):
  """Map position-tagged latent vectors back to 768-wide hidden states.

  The input is first extended with ``max_log2len`` position features, then
  projected from ``4096 + max_log2len`` down to 768 by an MLP, passed
  through one more 768->768 linear layer, and finally through a pretrained
  DistilBERT whose embedding layer has been replaced by ``Identity``.
  """

  def __init__(self, max_log2len=8):
    super().__init__()

    latent_width = 1024*4
    hidden_width = 768

    self.posenc = PositionEncoding(max_log2len)

    # Submodules are created and registered in a fixed order so that
    # parameter ordering stays stable.
    projection_layers = [
        torch.nn.Linear(latent_width + max_log2len, 1024),
        torch.nn.ReLU(),
        torch.nn.Linear(1024, hidden_width),
        torch.nn.ReLU(),
        torch.nn.Linear(hidden_width, hidden_width),
    ]
    self.input = torch.nn.Sequential(*projection_layers)

    transformer = DistilBertModel.from_pretrained("distilbert-base-uncased")
    transformer.embeddings = torch.nn.Identity()
    self.model = transformer
    self.output = torch.nn.Linear(hidden_width, hidden_width)

  def forward(self, X, M):
    """Decode ``X`` under attention mask ``M``.

    ``X`` goes through ``posenc``, ``input`` and ``output`` in that order;
    the result and ``M`` are handed to DistilBERT, whose
    ``last_hidden_state`` is returned.
    """
    with_positions = self.posenc(X)
    projected = self.output(self.input(with_positions))
    return self.model(projected, M).last_hidden_state

class Autoencoder(torch.nn.Module):
  """Encoder/decoder pair that squeezes a sequence through one latent vector.

  Args:
    max_log2len: position-encoding size forwarded to the ``Decoder``.
    variational: when true, the encoder is asked for a standard deviation
      as well and the latent is sampled as ``mean + std * noise``.
  """

  def __init__(self, max_log2len=8, variational=False):
    super().__init__()
    self.variational = variational
    self.encoder = Encoder()
    self.decoder = Decoder(max_log2len=max_log2len)

  def forward(self, input_ids, attention_mask):
    """Encode the input, tile the latent over dim 1 and decode it.

    The latent (kept with a size-1 dim 1 by the encoder call) is repeated
    ``input_ids.shape[1]`` times along dim 1 before being decoded together
    with ``attention_mask``.
    """
    seq_len = input_ids.shape[1]
    encoded = self.encoder(input_ids, attention_mask,
                           keepdim=True, return_std=self.variational)

    if self.variational:
      # The encoder returned (expectation, std); draw one noisy sample.
      mean, std = encoded[0], encoded[1]
      latent = mean + std*torch.randn_like(std)
    else:
      latent = encoded

    tiled = latent.repeat(1, seq_len, 1)
    return self.decoder(tiled, attention_mask)


# Dataset

In [None]:
class AutoencodingDataset(torch.utils.data.Dataset):
  """Dataset of text chunks, with a collate function that embeds them.

  Every input sentence is cut into consecutive pieces of at most
  ``max_len`` items (characters, assuming the sentences are ``str``) and
  the pieces are stored in ``self.sentences``. Items are the raw pieces;
  tokenisation and embedding happen per batch in ``prepare_batch``, which
  is meant to be passed to a ``DataLoader`` as ``collate_fn``.

  Args:
    sentences: iterable of sliceable sequences (strings in this notebook).
    max_len: maximum length of a stored piece; must be at least 1.

  Raises:
    ValueError: if ``max_len`` is smaller than 1.
  """

  def __init__(self, sentences, max_len=16*5):
    super().__init__()

    if max_len < 1:
      # The previous chunking loop never terminated for a negative max_len.
      raise ValueError(f"max_len must be at least 1, got {max_len}")
    self.max_len = max_len

    self.sentences = []
    for s in tqdm(sentences):
      self.sentences.extend(self._split_evenly(s, max_len))

    self.tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
    self.embedding = DistilBertModel.from_pretrained("distilbert-base-uncased").embeddings

  @staticmethod
  def _split_evenly(text, max_len):
    """Cut ``text`` into the fewest near-equal pieces of length <= max_len.

    The pieces concatenate back to ``text`` exactly; the earlier
    implementation used a fixed piece length of ``len(text)//i`` and
    silently dropped the remainder at the end. An empty ``text`` yields a
    single empty piece.
    """
    total = len(text)
    n_pieces = max(1, -(-total // max_len))  # ceil(total / max_len), min 1
    # Integer boundaries k*total//n_pieces spread the remainder over the
    # pieces, so lengths differ by at most one and never exceed max_len.
    bounds = [k * total // n_pieces for k in range(n_pieces + 1)]
    return [text[start:stop] for start, stop in zip(bounds, bounds[1:])]

  def __len__(self):
    """Number of stored pieces."""
    return len(self.sentences)

  def __getitem__(self, idx):
    """Return the raw (untokenised) piece at ``idx``."""
    return self.sentences[idx]

  def prepare_batch(self, sentences):
    """Tokenise a list of pieces and replace token ids by embeddings.

    Returns the tokenizer output (padded, PyTorch tensors) in which
    ``'input_ids'`` holds the output of the DistilBERT embedding layer
    instead of integer ids; all other entries are left as produced.
    """
    X = self.tokenizer(sentences, return_tensors="pt", padding=True)
    # The embedding is only a fixed feature extractor here. Without
    # no_grad the result is a non-leaf tensor that requires grad, which
    # drags the embedding weights into every backward pass and cannot be
    # handed over by DataLoader worker processes.
    with torch.no_grad():
      X['input_ids'] = self.embedding(X['input_ids'])
    return X

In [None]:
def test_dataset():
  """Smoke-test chunking: two sentences at max_len=12 must give 6 items.

  Also indexes the first two items to check that item access works.
  Returns True when every check passed.
  """
  samples = [" cat Hello , my dog is cute", "the can sat on a lonely rock."]
  dataset = AutoencodingDataset(samples, max_len=4*3)
  assert len(dataset) == 6
  for index in (0, 1):
    dataset[index]
  return True

In [None]:
def test_dataloader():
  """Smoke-test batching through ``prepare_batch`` used as collate_fn.

  Every batch must expose 'input_ids' and 'attention_mask', both with a
  leading dimension of 2 and matching sizes along dimension 1.
  Returns True when every check passed.
  """
  samples = [" cat Hello , my dog is cute", "the can sat on a lonely rock."]
  dataset = AutoencodingDataset(samples)
  loader = torch.utils.data.DataLoader(
      dataset, batch_size=2, collate_fn=dataset.prepare_batch)

  for batch in loader:
    for key in ('input_ids', 'attention_mask'):
      assert key in batch

    ids_shape = batch['input_ids'].shape
    mask_shape = batch['attention_mask'].shape
    assert ids_shape[0] == 2
    assert mask_shape[0] == 2
    assert ids_shape[1] == mask_shape[1]

  return True

In [None]:
def IMDB_preparation(max_len=16*5):
  """Load the IMDB train split and wrap its texts in an AutoencodingDataset.

  Element 1 of every IMDB record is taken as the sentence; ``max_len`` is
  forwarded unchanged to ``AutoencodingDataset``. Progress is reported on
  stdout and through tqdm bars.
  """
  print("Loading IMDB...")
  records = list(IMDB(split="train"))

  print("Taking sentences...")
  texts = []
  for record in tqdm(records):
    texts.append(record[1])

  print("Preparing dataset...")
  return AutoencodingDataset(texts, max_len=max_len)

# Training

## Dataset Preparation

In [None]:
# Build the training dataset from the IMDB train split with the default
# chunk length.
ds = IMDB_preparation()

In [None]:
# Shuffled batches of 4 dataset items; ds.prepare_batch is the collate
# function, so each batch arrives already tokenised and embedded.
dl = torch.utils.data.DataLoader(ds, batch_size=4, collate_fn=ds.prepare_batch, shuffle=True)

## Model Preparation

Create model: 

In [None]:
# Prefer the GPU, but fall back to the CPU so this cell does not fail on
# machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
m = Autoencoder().to(device)
# Loss value of every optimisation step, filled in by the training loop.
Ls = []

Number of parameters:

In [None]:
# Total number of scalar parameters in the model (shown as the cell output).
sum(p.numel() for p in m.parameters())

## Training Loop

In [None]:
# Adamax over all model parameters, learning rate 1e-4, weight decay set to 0.
opt = torch.optim.Adamax(m.parameters(), lr=0.0001, weight_decay=0.0000)

In [None]:
# Train the autoencoder to reproduce its own (embedded) input: the loss is
# the mean squared difference between the input embeddings and the model
# output, averaged over every element of the batch.
EPOCHS = 1
for e in range(EPOCHS):
  print(f"Epoch: {e}")
  pbar = tqdm(dl)
  for xm in pbar:
    opt.zero_grad()
    # The input doubles as the regression target and must be a constant:
    # detach it so backward() stops at the batch and never propagates into
    # (or accumulates gradients in) whatever produced the embeddings.
    X = xm['input_ids'].detach().to(device)
    M = xm['attention_mask'].to(device)

    P = m(X, M)

    L = (X - P).pow(2).mean()
    L.backward()
    opt.step()

    # Read the scalar once; it is used for both the progress bar and the log.
    loss_value = L.item()
    pbar.set_description(f"L: {loss_value}")
    Ls.append(loss_value)

In [None]:
# Training curve: natural log of the recorded loss, one point per step.
plt.plot(np.log(Ls))