In [1]:
# Check which GPU, if any, is attached to this runtime before loading the model.
!nvidia-smi

Wed Jul 29 19:04:17 2020       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 450.51.05    Driver Version: 418.67       CUDA Version: 10.1     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   38C    P0    29W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                 ERR! |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
!pip install transformers
!pip install gdown

Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/27/3c/91ed8f5c4e7ef3227b4119200fc0ed4b4fd965b1f0172021c25701087825/transformers-3.0.2-py3-none-any.whl (769kB)
[K     |████████████████████████████████| 778kB 2.8MB/s 
[?25hCollecting sacremoses
[?25l  Downloading https://files.pythonhosted.org/packages/7d/34/09d19aff26edcc8eb2a01bed8e98f13a1537005d31e95233fd48216eed10/sacremoses-0.0.43.tar.gz (883kB)
[K     |████████████████████████████████| 890kB 15.8MB/s 
Collecting sentencepiece!=0.1.92
[?25l  Downloading https://files.pythonhosted.org/packages/d4/a4/d0a884c4300004a78cca907a6ff9a5e9fe4f090f5d95ab341c53d28cbc58/sentencepiece-0.1.91-cp36-cp36m-manylinux1_x86_64.whl (1.1MB)
[K     |████████████████████████████████| 1.1MB 18.3MB/s 
Collecting tokenizers==0.8.1.rc1
[?25l  Downloading https://files.pythonhosted.org/packages/40/d0/30d5f8d221a0ed981a186c8eb986ce1c94e3a6e87f994eae9f4aa5250217/tokenizers-0.8.1rc1-cp36-cp36m-manylinux1_x86_64.whl (3.0MB

In [4]:
import os
import time
import numpy as np
import pandas as pd
import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.nn.utils.rnn import pad_sequence

# import huggingface transformers
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config, AdamW, get_linear_schedule_with_warmup

In [5]:
def top_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """Filter a distribution of logits using top-k and/or top-p (nucleus) filtering.

    Args:
        logits: float tensor whose LAST dimension indexes the vocabulary, e.g. a
            single distribution of shape (vocab,) or a batch of shape (batch, vocab).
        top_k: <= 0: no filtering; > 0: keep only the top_k highest-scoring tokens
            (tokens tied with the k-th score are kept as well). A value larger than
            the vocabulary size keeps every token.
        top_p: <= 0.0: no filtering; > 0.0: keep the smallest set of highest-probability
            tokens whose cumulative probability exceeds top_p. The most likely token is
            always kept.
        filter_value: value written into every filtered-out position (used by both filters).

    Returns:
        A new tensor with the same shape as `logits`; the input tensor is not modified.
    """
    if top_k > 0:
        # torch.topk raises if k exceeds the dimension size, so clamp it.
        top_k = min(top_k, logits.size(-1))
        # k-th largest score per distribution, kept as a trailing dim so it broadcasts.
        kth_largest = torch.topk(logits, top_k, dim=-1).values[..., -1:]
        logits = logits.masked_fill(logits < kth_largest, filter_value)

    if top_p > 0.0:
        # Compute cumulative probabilities of sorted tokens
        sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=True)
        cumulative_probabilities = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probabilities > top_p
        # Shift right so the first token that crosses the threshold is kept too.
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = False

        sorted_logits = sorted_logits.masked_fill(sorted_indices_to_remove, filter_value)
        # Undo the sort: put each (possibly filtered) score back at its vocabulary index.
        logits = torch.zeros_like(logits).scatter(-1, sorted_indices, sorted_logits)

    return logits

In [7]:
# Seed NumPy, PyTorch's default generator and the current CUDA device's generator.
SEED = 1234
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

In [8]:
# GPT-2 architectures paired with the three fine-tuned checkpoints: "small" uses the
# GPT2Config defaults, the larger sizes override depth, head count and width.
gpt2_small_config = GPT2Config()
gpt2_medium_config = GPT2Config(n_layer=24, n_head=16, n_embd=1024, n_ctx=1024)
gpt2_large_config = GPT2Config(n_layer=36, n_head=20, n_embd=1280, n_ctx=1024)

In [9]:
# Load the tokenizer from the pretrained "gpt2" vocabulary; this single tokenizer
# is used whichever model size is selected below.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1042301.0, style=ProgressStyle(descript…




HBox(children=(FloatProgress(value=0.0, description='Downloading', max=456318.0, style=ProgressStyle(descripti…




In [10]:
# download all model weights
!wget https://convaisharables.blob.core.windows.net/lsp/multiref/small_ft.pkl
!wget https://convaisharables.blob.core.windows.net/lsp/multiref/medium_ft.pkl
!wget https://convaisharables.blob.core.windows.net/lsp/multiref/large_ft.pkl

--2020-07-29 19:21:32--  https://convaisharables.blob.core.windows.net/lsp/multiref/small_ft.pkl
Resolving convaisharables.blob.core.windows.net (convaisharables.blob.core.windows.net)... 13.77.184.64
Connecting to convaisharables.blob.core.windows.net (convaisharables.blob.core.windows.net)|13.77.184.64|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 351265273 (335M) [application/octet-stream]
Saving to: ‘small_ft.pkl’


2020-07-29 19:22:23 (6.70 MB/s) - ‘small_ft.pkl’ saved [351265273/351265273]

--2020-07-29 19:22:24--  https://convaisharables.blob.core.windows.net/lsp/multiref/medium_ft.pkl
Resolving convaisharables.blob.core.windows.net (convaisharables.blob.core.windows.net)... 13.77.184.64
Connecting to convaisharables.blob.core.windows.net (convaisharables.blob.core.windows.net)|13.77.184.64|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 862954531 (823M) [application/octet-stream]
Saving to: ‘medium_ft.pkl’


2020-07-29 19:23

In [11]:
# load the model
model_size = "medium"  # one of "small", "medium", "large"

# Each size pairs an architecture config with its downloaded "<size>_ft.pkl" checkpoint.
model_configs = {
    "small": gpt2_small_config,
    "medium": gpt2_medium_config,
    "large": gpt2_large_config,
}
if model_size not in model_configs:
    # Fail fast on a typo instead of leaving `model` undefined for later cells.
    raise ValueError(f"model_size must be one of {sorted(model_configs)}, got {model_size!r}")

model = GPT2LMHeadModel(model_configs[model_size])

# SECURITY: torch.load unpickles the file and can execute arbitrary code;
# only load checkpoints from a trusted source.
# map_location="cpu" keeps the raw weights off the GPU until the model itself is moved.
state_dict = torch.load(f"{model_size}_ft.pkl", map_location="cpu")
# strict=False tolerates key-name mismatches between checkpoint and model;
# report what was skipped instead of ignoring it silently.
load_result = model.load_state_dict(state_dict, strict=False)
print("missing keys:", load_result.missing_keys)
print("unexpected keys:", load_result.unexpected_keys)
del state_dict  # the weights now live in `model`; free the CPU copy

# Prefer the GPU, but still run (slowly) on a CPU-only runtime.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

In [12]:
# Make the LM head share the token-embedding tensor. The accessor methods return
# the same modules as `model.lm_head` / `model.transformer.wte` without hard-coding
# those attribute names.
model.get_output_embeddings().weight.data = model.get_input_embeddings().weight.data

In [14]:
# Id of the "<|endoftext|>" token, wrapped in a list so torch.LongTensor([eos])
# yields a (1, 1) batch ready to feed to the model.
eos = [tokenizer.convert_tokens_to_ids("<|endoftext|>")]

In [15]:
# Interactive chat: sample a reply to each user line while keeping the conversation
# in the model's cached state (`past`). Type "quit" to stop.
past = None
temperature = 0.9
top_k = -1  # <= 0 disables top-k filtering
top_p = 0.9
max_reply_tokens = 500  # upper bound on tokens fed/sampled for one bot reply

# The model has a fixed number of positions, but `past` grows every turn. When the
# next turn could overflow it, drop the history instead of crashing mid-reply.
max_context = model.config.n_positions
context_len = 0  # number of tokens currently held in `past`

model.eval()
prev_input = None

while True:
    with torch.no_grad():
        # read the user's turn
        user = input("User:")

        if user.strip().lower() == "quit":
            # stop talking!
            break

        user_ids = tokenizer.encode(user)
        if not user_ids:
            continue  # empty input: nothing to feed the model

        # A turn feeds the user tokens, at most max_reply_tokens in the sampling
        # loop, and one closing eos; keep the tail of over-long input so it fits.
        reply_budget = max_reply_tokens + 1
        user_ids = user_ids[-max(1, max_context - reply_budget):]
        if context_len + len(user_ids) + reply_budget > max_context:
            past = None
            context_len = 0

        prev_input = torch.LongTensor(user_ids).unsqueeze(0).to(device)
        _, past = model(prev_input, past=past)
        context_len += len(user_ids)

        # The reply starts by feeding eos after the user's turn.
        prev_input = torch.LongTensor([eos]).to(device)

        sent = []
        for _ in range(max_reply_tokens):
            logits, past = model(prev_input, past=past)
            context_len += 1
            logits = logits[:, -1, :] / temperature
            logits = top_filtering(logits, top_k=top_k, top_p=top_p)

            probs = torch.softmax(logits, dim=-1)

            prev_input = torch.multinomial(probs, num_samples=1)
            prev_word = prev_input.item()

            if prev_word == eos[0]:
                break
            sent.append(prev_word)

        print("Bot:", tokenizer.decode(sent))
        # Close the bot's turn with eos so it is separated from the next user turn.
        prev_input = torch.LongTensor([eos]).to(device)
        _, past = model(prev_input, past=past)
        context_len += 1

User:hello
Bot: Y u do dis?
User:can you speak english?
Bot: nah, imaa ask the one who speaks it.
User:what day is it today?
Bot: this
User:Hello?
Bot: hello bro
User:How are you doing?
Bot: good, you?
User:I'm fine, thanks
Bot: No problem man, I have a bad case of flu. How are you?
User:Sorry to hear that! Hope you will get better soon.
Bot: Thank you :D
User:quite
Bot: Thanks. i hope
User:quit
