In [1]:
!pip install transformers datasets

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting transformers
  Downloading transformers-4.24.0-py3-none-any.whl (5.5 MB)
[K     |████████████████████████████████| 5.5 MB 25.8 MB/s 
[?25hCollecting datasets
  Downloading datasets-2.6.1-py3-none-any.whl (441 kB)
[K     |████████████████████████████████| 441 kB 52.4 MB/s 
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.6 MB)
[K     |████████████████████████████████| 7.6 MB 59.6 MB/s 
Collecting huggingface-hub<1.0,>=0.10.0
  Downloading huggingface_hub-0.10.1-py3-none-any.whl (163 kB)
[K     |████████████████████████████████| 163 kB 71.4 MB/s 
Collecting dill<0.3.6
  Downloading dill-0.3.5.1-py2.py3-none-any.whl (95 kB)
[K     |████████████████████████████████| 95 kB 6.2 MB/s 
Collecting responses<0.19
  Downloading responses-0.18.0-py3-none-any.whl (38 kB)
Collecting multiprocess
  

## Import the decoder

In [2]:
# Colab-only step: upload the local `decoder.py` module into the runtime so
# that the `import decoder` in the next cell can find it.
from google.colab import files
uploaded = files.upload()

Saving decoder.py to decoder.py


In [3]:
import torch
from torch import nn
import torch.nn.functional as F

from torch.utils.data import dataset, DataLoader
import numpy as np
import math
import matplotlib.pyplot as plt

import decoder

## Load the tokenizer and dataset, then inspect them

In [4]:
from transformers import AutoTokenizer, DataCollatorWithPadding
from datasets import load_dataset

In [5]:
# Only the tokenizer of this pretrained checkpoint is loaded here; the model
# trained later in the notebook is built from the `decoder` module.
checkpoint = 'distilbert-base-cased'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)


Downloading:   0%|          | 0.00/29.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/411 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/213k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/436k [00:00<?, ?B/s]

In [6]:
# GLUE SST-2 sentences; the download provides train, validation and test splits.
raw_datasets = load_dataset("glue", "sst2")

Downloading builder script:   0%|          | 0.00/28.8k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/28.7k [00:00<?, ?B/s]

Downloading readme:   0%|          | 0.00/27.8k [00:00<?, ?B/s]

Downloading and preparing dataset glue/sst2 to /root/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad...


Downloading data:   0%|          | 0.00/7.44M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/67349 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/872 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1821 [00:00<?, ? examples/s]

Dataset glue downloaded and prepared to /root/.cache/huggingface/datasets/glue/sst2/1.0.0/dacbe3125aa31d7f70367a07a8a9e72a5a0bfeb5fc42e75c9db75b96da6053ad. Subsequent calls will reuse this data.


  0%|          | 0/3 [00:00<?, ?it/s]

In [8]:
# Inspect the available splits, their columns and their sizes.
raw_datasets

DatasetDict({
    train: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 67349
    })
    validation: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 872
    })
    test: Dataset({
        features: ['sentence', 'label', 'idx'],
        num_rows: 1821
    })
})

In [9]:
def tokenize_fn(batch):
    """Tokenize the 'sentence' column of a batch of examples.

    Sequences longer than the tokenizer's limit are truncated; no padding is
    applied here.
    """
    sentences = batch['sentence']
    encoded = tokenizer(sentences, truncation=True)
    return encoded

In [12]:
# Tokenize every split, passing the examples to tokenize_fn in batches.
tokenized_datasets = raw_datasets.map(tokenize_fn, batched=True)
# Padding is deferred to batch-collation time and done by this collator.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

  0%|          | 0/68 [00:00<?, ?ba/s]

  0%|          | 0/1 [00:00<?, ?ba/s]

  0%|          | 0/2 [00:00<?, ?ba/s]

In [13]:
# Show the collator's configuration (padding mode, returned tensor type).
data_collator

DataCollatorWithPadding(tokenizer=PreTrainedTokenizerFast(name_or_path='distilbert-base-cased', vocab_size=28996, model_max_len=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}), padding=True, max_length=None, pad_to_multiple_of=None, return_tensors='pt')

In [14]:
# Tokenization adds the 'input_ids' and 'attention_mask' columns to each split.
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['sentence', 'label', 'idx', 'input_ids', 'attention_mask'],
        num_rows: 67349
    })
    validation: Dataset({
        features: ['sentence', 'label', 'idx', 'input_ids', 'attention_mask'],
        num_rows: 872
    })
    test: Dataset({
        features: ['sentence', 'label', 'idx', 'input_ids', 'attention_mask'],
        num_rows: 1821
    })
})

In [15]:
# Keep only the columns the model consumes ('input_ids' and 'attention_mask').
# The sentiment label is dropped as well: the language-modelling targets are
# built from 'input_ids' inside the training loop.
# Only columns that are still present are removed, so re-running this cell is
# harmless instead of failing because the columns are already gone.
_unused_columns = [
    name for name in ("sentence", "idx", "label")
    if name in tokenized_datasets["train"].column_names
]
if _unused_columns:
    tokenized_datasets = tokenized_datasets.remove_columns(_unused_columns)


In [16]:
# Confirm that only 'input_ids' and 'attention_mask' remain in every split.
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 67349
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 872
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 1821
    })
})

In [17]:
# Training batches are reshuffled every epoch; each batch of 32 examples is
# padded by the collator when it is assembled.
train_loader = DataLoader(
    tokenized_datasets['train'],
    shuffle=True,
    batch_size=32,
    collate_fn=data_collator
)
# Validation batches keep the dataset order (no shuffling).
valid_loader = DataLoader(
    tokenized_datasets['validation'],
    batch_size=32,
    collate_fn=data_collator
)

In [18]:
# Sanity check: print the tensor shapes of the first training batch only.
for batch in train_loader:
  for k, v in batch.items():
    print("k:", k, "v.shapes: ", v.shape)
  break

You're using a DistilBertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


k: input_ids v.shapes:  torch.Size([32, 36])
k: attention_mask v.shapes:  torch.Size([32, 36])


In [19]:
# Vocabulary size; passed to the decoder as `vocab_size` below.
tokenizer.vocab_size

28996

In [21]:
# Id of the padding token; the loss is configured to ignore this id.
tokenizer.pad_token_id

0

In [20]:
# Maximum input length per checkpoint; the entry for `checkpoint` is passed to
# the decoder as `max_len` below.
tokenizer.max_model_input_sizes

{'distilbert-base-uncased': 512,
 'distilbert-base-uncased-distilled-squad': 512,
 'distilbert-base-cased': 512,
 'distilbert-base-cased-distilled-squad': 512,
 'distilbert-base-german-cased': 512,
 'distilbert-base-multilingual-cased': 512}

## Train the decoder model

In [65]:

# Small decoder-only language model from the uploaded `decoder` module.
# Vocabulary size and maximum sequence length are taken from the tokenizer.
# NOTE(review): d_k * n_heads == d_model (16 * 4 == 64) — presumably d_k is the
# per-head dimension; confirm against decoder.py.
model = decoder.Decoder(
    vocab_size=tokenizer.vocab_size,
    max_len=tokenizer.max_model_input_sizes[checkpoint],
    d_k=16,
    d_model=64,
    n_heads=4,
    n_layers=2,
    dropout_prob=0.1,
)

# Use the first GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

Decoder(
  (embedding): Embedding(28996, 64)
  (pos_encoding): PositionalEncoding(
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (transformer_blocks): Sequential(
    (0): TransformerBlock(
      (k_embed): Linear(in_features=64, out_features=64, bias=True)
      (q_embed): Linear(in_features=64, out_features=64, bias=True)
      (v_embed): Linear(in_features=64, out_features=64, bias=True)
      (ln1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (ln2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
      (mha): CausalAttention(
        (key): Linear(in_features=64, out_features=64, bias=True)
        (query): Linear(in_features=64, out_features=64, bias=True)
        (value): Linear(in_features=64, out_features=64, bias=True)
        (fc): Linear(in_features=64, out_features=64, bias=True)
      )
      (ann): Sequential(
        (0): Linear(in_features=64, out_features=256, bias=True)
        (1): GELU(approximate=none)
        (2): Linear(in_features=256, o

In [66]:
# Positions whose target is the padding token are excluded from the loss.
loss_func = nn.CrossEntropyLoss(ignore_index=tokenizer.pad_token_id)
# Adam with its default hyperparameters (only the parameters are passed).
optimizer = torch.optim.Adam(model.parameters())

In [67]:
from datetime import datetime

def _next_token_targets(input_ids, pad_token_id):
  """Return next-token targets for a batch of token ids of shape N x T.

  Position t of the result holds input token t+1. The last position has no
  successor and is filled with `pad_token_id`.
  """
  targets = torch.full_like(input_ids, pad_token_id)
  targets[:, :-1] = input_ids[:, 1:]
  return targets


def _lm_loss(model, criterion, batch, device, pad_token_id):
  """Move one collated batch to `device` and return its next-token loss."""
  batch = {k: v.to(device) for k, v in batch.items()}
  targets = _next_token_targets(batch['input_ids'], pad_token_id)
  outputs = model(batch['input_ids'], batch['attention_mask'])
  # outputs N*T*V -> N*V*T, targets: N*T, similar as for V-classification, where V is the vocab size
  return criterion(outputs.transpose(2, 1), targets)


def train(model, criterion, optimizer, train_loader, valid_loader, epochs,
          device=None, pad_token_id=None):
  """Train `model` as a next-token predictor and evaluate it after every epoch.

  Args:
    model: module called as `model(input_ids, attention_mask)` and returning
      logits of shape N x T x V.
    criterion: loss called with N x V x T logits and N x T integer targets.
    optimizer: optimizer over the parameters of `model`.
    train_loader: iterable of dict batches containing 'input_ids' and
      'attention_mask' tensors.
    valid_loader: iterable of batches in the same format, used for evaluation.
    epochs: number of passes over `train_loader`.
    device: device the batches are moved to. Defaults to the device of the
      model's first parameter (CPU if the model has no parameters).
    pad_token_id: id written into the last target position of every sequence.
      Defaults to `tokenizer.pad_token_id` of the notebook-level tokenizer.

  Returns:
    (train_losses, test_losses): two NumPy arrays of length `epochs` with the
    mean per-batch training and validation loss of every epoch.
  """
  if device is None:
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')
  if pad_token_id is None:
    pad_token_id = tokenizer.pad_token_id

  train_losses = np.zeros(epochs)
  test_losses = np.zeros(epochs)

  for it in range(epochs):
    t0 = datetime.now()

    model.train()
    train_loss = []
    for batch in train_loader:
      optimizer.zero_grad()
      loss = _lm_loss(model, criterion, batch, device, pad_token_id)
      loss.backward()
      optimizer.step()
      train_loss.append(loss.item())
    train_loss = np.mean(train_loss)

    model.eval()
    test_loss = []
    # No gradients are needed for evaluation; skipping the autograd graph
    # saves memory and time.
    with torch.no_grad():
      for batch in valid_loader:
        loss = _lm_loss(model, criterion, batch, device, pad_token_id)
        test_loss.append(loss.item())
    test_loss = np.mean(test_loss)

    train_losses[it] = train_loss
    test_losses[it] = test_loss

    # The duration covers both the training and the validation pass.
    dt = datetime.now() - t0
    print(f'Epoch {it+1}/{epochs}, Train Loss: {train_loss:.4f}, '
          f'    Test Loss: {test_loss:.4f}, duration: {dt}')

  return train_losses, test_losses

In [68]:
# Train for 12 epochs; the returned arrays hold one mean loss per epoch.
# NOTE(review): in the recorded run the validation ("Test") loss is lowest at
# epoch 2 and rises afterwards while the training loss keeps falling —
# consider fewer epochs or early stopping.
train_losses, test_losses = train(model, loss_func, optimizer, train_loader, valid_loader, epochs = 12)

Epoch 1/12, Train Loss: 5.9729,     Test Loss: 5.7452, duration: 0:01:06.656090
Epoch 2/12, Train Loss: 5.0200,     Test Loss: 5.7354, duration: 0:01:06.217065
Epoch 3/12, Train Loss: 4.6948,     Test Loss: 5.8015, duration: 0:01:06.269206
Epoch 4/12, Train Loss: 4.5148,     Test Loss: 5.8325, duration: 0:01:06.181389
Epoch 5/12, Train Loss: 4.3870,     Test Loss: 5.8681, duration: 0:01:06.251982
Epoch 6/12, Train Loss: 4.2835,     Test Loss: 5.9020, duration: 0:01:06.570188
Epoch 7/12, Train Loss: 4.1972,     Test Loss: 5.9549, duration: 0:01:06.379913
Epoch 8/12, Train Loss: 4.1225,     Test Loss: 5.9875, duration: 0:01:05.627260
Epoch 9/12, Train Loss: 4.0543,     Test Loss: 6.0439, duration: 0:01:05.976331
Epoch 10/12, Train Loss: 3.9944,     Test Loss: 6.0713, duration: 0:01:06.334182
Epoch 11/12, Train Loss: 3.9410,     Test Loss: 6.0805, duration: 0:01:05.532704
Epoch 12/12, Train Loss: 3.8886,     Test Loss: 6.1199, duration: 0:01:06.085955


## A Brief Evaluation

In [70]:
# One sentence per batch, in dataset order, so single examples can be inspected.
test_loader = DataLoader(
    tokenized_datasets['test'],
    batch_size=1,
    collate_fn=data_collator
)

In [71]:
# Switch to evaluation mode before running inference.
model.eval()

# Run the model on the first test batch only. Gradients are not needed for
# inference, so autograd is disabled for the forward pass.
with torch.no_grad():
    for batch in test_loader:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(batch['input_ids'], batch['attention_mask'])
        break

outputs.shape

torch.Size([1, 11, 28996])

In [72]:
# Highest-scoring vocabulary id at every position of the sequence.
torch.argmax(outputs, axis = -1)

tensor([[  170,  2742,  9995, 21575,   102,  1103,   102,   170,   102,   102,
           118]], device='cuda:0')

In [73]:
# Keep the per-position predictions for decoding in the following cells.
prediction_ids = torch.argmax(outputs, axis=-1)
prediction_ids.shape

torch.Size([1, 11])

In [74]:
# Decode the predicted token at every position of the first sequence.
# Iterating over the tensor itself (instead of a hard-coded length) keeps this
# cell valid for a sentence of any length.
for i, token_id in enumerate(prediction_ids[0]):
    print(i, tokenizer.decode(token_id))

0 a
1 marriage
2 ##san
3 ##mash
4 [SEP]
5 the
6 [SEP]
7 a
8 [SEP]
9 [SEP]
10 -


In [75]:
# Decode the first five input tokens followed by the model's prediction at
# position 4, i.e. its guess for the token that comes after them.
tokenizer.decode(torch.concat((batch['input_ids'][0,:5], prediction_ids[:, 4])))

'[CLS] uneasy mishmash [SEP]'

## Text generation

In [76]:
prompt = "It's"

# return_tensors="pt" yields PyTorch tensors with a leading batch dimension of 1.
tokenized_prompt = tokenizer(prompt, return_tensors="pt")
tokenized_prompt

{'input_ids': tensor([[ 101, 1135,  112,  188,  102]]), 'attention_mask': tensor([[1, 1, 1, 1, 1]])}

In [77]:
# Drop the final token ([SEP]) so the prompt is left open for continuation.
tokenizer.decode(tokenized_prompt['input_ids'][0][:-1])

"[CLS] It's"

In [78]:
# Score the prompt without its trailing [SEP] token. Gradients are not needed
# for inference, so autograd is disabled for the forward pass.
with torch.no_grad():
    outputs = model(
        tokenized_prompt['input_ids'][:,:-1].to(device),
        tokenized_prompt['attention_mask'][:,:-1].to(device)
    )
outputs.shape

torch.Size([1, 4, 28996])

In [79]:
# Greedy choice: the highest-scoring vocabulary id at the last position.
prediction_ids = torch.argmax(outputs[:, -1, :], axis=-1)
prediction_ids

tensor([170], device='cuda:0')

In [80]:
# Convert the predicted id back into its token text.
tokenizer.decode(prediction_ids)

'a'

In [81]:
prompt = "It's"
max_new_tokens = 20  # upper bound on the number of generated tokens

tokenized_prompt = tokenizer(prompt, return_tensors="pt")

# Drop the trailing [SEP] token so the model continues the prompt.
input_ids = tokenized_prompt['input_ids'][:,:-1].to(device)
mask = tokenized_prompt['attention_mask'][:,:-1].to(device)

# Greedy decoding: repeatedly append the most likely next token until the
# model emits [SEP] or the token budget is used up.
# Evaluation mode and disabled autograd are set here explicitly so the cell
# does not depend on an earlier cell having called model.eval().
model.eval()
with torch.no_grad():
  for _ in range(max_new_tokens):
    outputs = model(input_ids, mask)
    prediction_id = torch.argmax(outputs[:, -1, :], dim=-1)

    input_ids = torch.hstack((input_ids, prediction_id.view(1,1)))
    # Every position is a real token (nothing is padded), so the mask is all ones.
    mask = torch.ones_like(input_ids)

    if prediction_id.item() == tokenizer.sep_token_id:
      break


In [82]:
# Total length of the sequence: prompt tokens plus generated tokens.
input_ids.shape

torch.Size([1, 8])

In [83]:
# Decode the prompt together with its generated continuation.
tokenizer.decode(input_ids[0])

"[CLS] It's a good time [SEP]"