# Setup Environment

In [1]:
# External Imports
import time, os, json
import numpy as np
import matplotlib.pyplot as plt
import datasets
import torch
import random
from transformers import AutoTokenizer
from sentence_transformers import SentenceTransformer
from datasets import Dataset
from pprint import pprint


# Internal Imports
import mltoolkit as mltk
from mltoolkit import (
    cfg_reader,
    models,
)

%load_ext autoreload
%autoreload 2


# Load Config, Tokenizer and Model Checkpoint

In [2]:
cfg, keywords = cfg_reader.load('/data/john/projects/mltoolkit/cfg/nlp/sentence_decoding/dev_config.yaml')

tokenizer = AutoTokenizer.from_pretrained(cfg.data['tokenizer_name'])

cfg.model['vocab_size'] = len(tokenizer)
cfg.model['pad_token_id'] = tokenizer.pad_token_id
cfg.model['devices'] = ['cuda:3']

model = models.TransformerAutoencoder(cfg)
model.load_state_dict(torch.load('/data/john/projects/mltoolkit/checkpoints/20230926-220851-sentence-decoding/best_model.pt'))
model.eval()

print(f'embeddings device: {model.embeddings.weight.device}')
#print(f'decoder device: {model.decoder[0].weight.device}')
print(f'decoder device: {next(iter(model.decoder[0].state_dict().values())).device}')
print()
print(model)

embeddings device: cpu
decoder device: cuda:3

TransformerAutoencoder(
  (encoder): SentenceTransformer(
    (0): Transformer({'max_seq_length': 512, 'do_lower_case': False}) with Transformer model: MPNetModel 
    (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
    (2): Normalize()
  )
  (embeddings): Embedding(30527, 768, padding_idx=1)
  (pos): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (up_conv): Sequential(
    (0): Sequential(
      (0): Conv1d(1, 16, kernel_size=(33,), stride=(1,), padding=(16,))
      (1): ReLU()
    )
    (1): Linear(in_features=768, out_features=768, bias=True)
  )
  (decoder): ModuleList(
    (0): TransformerDecoder(
      (layers): ModuleList(
        (0-2): 3 x TransformerDecoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_

# Load Text Data

In [3]:
#ds = Dataset.from_csv('/data/john/projects/mltoolkit/data/synthetic_overlap/synthetic_overlap_data.csv')
ds = datasets.load_dataset('ptb_text_only')['test']
pprint(ds)

Found cached dataset ptb_text_only (/home/john/.cache/huggingface/datasets/ptb_text_only/penn_treebank/1.1.0/8d1b97746fb9765d140e569ec5ddd35e20af4d37761f5e1bf357ea0b081f2c1f)


  0%|          | 0/3 [00:00<?, ?it/s]

Dataset({
    features: ['sentence'],
    num_rows: 3761
})


# Encode Some Text

In [4]:
sample_size = 10
sample_indices = np.random.choice(
    np.arange(len(ds)), 
    (sample_size,)
)

sentences = ds[sample_indices]['sentence']
enc = model.encode(sentences)

print(enc)

tensor([[ 0.0069,  0.0688, -0.0115,  ...,  0.0389, -0.0072, -0.0203],
        [-0.0231,  0.0562,  0.0215,  ...,  0.0155,  0.0522, -0.0109],
        [-0.0185,  0.0336, -0.0222,  ...,  0.0140, -0.0336, -0.0086],
        ...,
        [ 0.0206,  0.0225,  0.0060,  ..., -0.0089,  0.0008, -0.0031],
        [ 0.0301,  0.0010,  0.0153,  ...,  0.0086, -0.0213, -0.0104],
        [ 0.0197, -0.0128, -0.0121,  ..., -0.0257, -0.0330, -0.0528]],
       device='cuda:3')


# Get Raw Decodings

In [5]:
dec = model.decode(enc)

print(dec)



['<s> in the - - - - - - - - - - - - - - - - - - - - - -...........................................................................................................................................................................................................................................................................................................................................................................................................................................................................................,,,,,,,,,,,,', '<s> in north </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> </s> 

# Clean Up Decodings and Compare With Ground Truth

In [6]:
for i, s in enumerate(dec):
    print(f'\nx:\t{sentences[i]}\nx_hat:\t{s[4:s.find("</s>")]}')


x:	in response to your overly optimistic <unk> piece on how long unemployment lasts people patterns sept. N i am in the communications field above entry level
x_hat:	in the - - - - - - - - - - - - - - - - - - - - - -...........................................................................................................................................................................................................................................................................................................................................................................................................................................................................................,,,,,,,,,,,

x:	between N and N north <unk> was the only state in the midwest to lose population a loss of N people
x_hat:	in north 

x:	but in the process the fed risks <unk> inflation
x_hat:	

x:	employees have the right to trade stock among themselves and the company will establish an internal clearing h

# Test

In [52]:
max_length = 512

# load model
stx = SentenceTransformer(cfg.model['encoder'])

# tokenize sentences
sent_tokens = tokenizer(
    sentences,
    max_length=max_length,
    padding='max_length',
    truncation=True,
    return_tensors='pt',
)

#print(sent_tokens)

stx(sent_tokens)

{'input_ids': tensor([[    0,  2049,  1009,  ...,     1,     1,     1],
        [    0,  2891, 15772,  ...,     1,     1,     1],
        [    0,  2000,  2073,  ...,     1,     1,     1],
        ...,
        [    0,  6993,  5527,  ...,     1,     1,     1],
        [    0,  4295,  4294,  ...,     1,     1,     1],
        [    0,  2064,  2354,  ...,     1,     1,     1]]), 'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        ...,
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]]), 'token_embeddings': tensor([[[ 1.7046e-02,  1.1460e-01, -2.1556e-02,  ...,  7.2599e-02,
          -1.0680e-01, -1.3060e-01],
         [-7.0347e-03,  4.4170e-02, -3.6990e-02,  ...,  8.6892e-02,
          -1.1156e-01, -6.8378e-02],
         [-2.0015e-02,  1.5945e-02, -1.3913e-02,  ...,  9.2661e-02,
          -8.7501e-02, -9.9539e-02],
         ...,
         [ 3.3413e-02,  2.0131e-01, -