In [1]:
from __future__ import absolute_import, division, print_function, unicode_literals

import argparse
import logging
from tqdm import trange

import torch
import torch.nn.functional as F
import numpy as np

from transformers import GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from transformers import OpenAIGPTLMHeadModel, OpenAIGPTTokenizer
from transformers import XLNetLMHeadModel, XLNetTokenizer
from transformers import TransfoXLLMHeadModel, TransfoXLTokenizer
from transformers import XLMWithLMHeadModel, XLMTokenizer

In [2]:
logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                    datefmt = '%m/%d/%Y %H:%M:%S',
                    level = logging.INFO)
logger = logging.getLogger(__name__)

In [3]:
MAX_LENGTH = int(10000)  # Hardcoded max length to avoid infinite loop
length = 20

In [4]:
temperature = 1.0
top_k = 0.0
top_p = 0.9

In [5]:
# ALL_MODELS = sum((tuple(conf.pretrained_config_archive_map.keys()) for conf in (GPT2Config, OpenAIGPTConfig, XLNetConfig, TransfoXLConfig, XLMConfig)), ())

In [30]:
MODEL_CLASSES = {
    'gpt2': (GPT2LMHeadModel, GPT2Tokenizer),
    #'ctrl': (CTRLLMHeadModel, CTRLTokenizer),
    #'openai-gpt': (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
    #'xlnet': (XLNetLMHeadModel, XLNetTokenizer),
    #'transfo-xl': (TransfoXLLMHeadModel, TransfoXLTokenizer),
    #'xlm': (XLMWithLMHeadModel, XLMTokenizer),
}

In [15]:
def set_seed(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

In [16]:
def top_k_top_p_filtering(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    """ Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
        Args:
            logits: logits distribution shape (batch size x vocabulary size)
            top_k > 0: keep only top k tokens with highest probability (top-k filtering).
            top_p > 0.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
                Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    top_k = min(top_k, logits.size(-1))  # Safety check
    if top_k > 0:
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        logits[indices_to_remove] = filter_value

    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold
        sorted_indices_to_remove = cumulative_probs > top_p
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # scatter sorted tensors to original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    return logits

In [17]:
def sample_sequence(model, tokenizer, length, context, num_samples=1, temperature=1, top_k=0, top_p=0.0, repetition_penalty=1.0,
                    is_xlnet=False, is_xlm_mlm=False, xlm_mask_token=None, xlm_lang=None, device='cpu'):
    context = torch.tensor(context, dtype=torch.long, device=device)
    context = context.unsqueeze(0).repeat(num_samples, 1)
    logger.info('** context: {}'.format(context.shape))
    generated = context
    with torch.no_grad():
        for idx in range(length):

            inputs = {'input_ids': generated}
            if is_xlnet: 
                # XLNet is a direct (predict same token, not next token) and bi-directional model by default
                # => need one additional dummy token in the input (will be masked), attention mask and target mapping (see model docstring)
                input_ids = torch.cat((generated, torch.zeros((1, 1), dtype=torch.long, device=device)), dim=1)
                perm_mask = torch.zeros((1, input_ids.shape[1], input_ids.shape[1]), dtype=torch.float, device=device)
                perm_mask[:, :, -1] = 1.0  # Previous tokens don't see last token
                target_mapping = torch.zeros((1, 1, input_ids.shape[1]), dtype=torch.float, device=device)
                target_mapping[0, 0, -1] = 1.0  # predict last token
                inputs = {'input_ids': input_ids, 'perm_mask': perm_mask, 'target_mapping': target_mapping}

            if is_xlm_mlm and xlm_mask_token:
                # XLM MLM models are direct models (predict same token, not next token)
                # => need one additional dummy token in the input (will be masked and guessed)
                input_ids = torch.cat((generated, torch.full((1, 1), xlm_mask_token, dtype=torch.long, device=device)), dim=1)
                inputs = {'input_ids': input_ids}

            if xlm_lang is not None:
                inputs["langs"] = torch.tensor([xlm_lang] * inputs["input_ids"].shape[1], device=device).view(1, -1)

            outputs = model(**inputs)  # Note: we could also use 'past' with GPT-2/Transfo-XL/XLNet/CTRL (cached hidden-states)
            next_token_logits = outputs[0][:, -1, :] / (temperature if temperature > 0 else 1.)
            print('** {}th: input_ids, next_token_logits: {} {}'.format(idx, generated.shape, next_token_logits.shape))
            print('** {}th: input_ids -> {}:'.format(idx, tokenizer.decode(generated[0].tolist())))

            # repetition penalty from CTRL (https://arxiv.org/abs/1909.05858)
            for i in range(num_samples):
                for _ in set(generated[i].tolist()):
                    next_token_logits[i, _] /= repetition_penalty
                
            filtered_logits = top_k_top_p_filtering(next_token_logits, top_k=top_k, top_p=top_p)
            if temperature == 0: # greedy sampling:
                next_token = torch.argmax(filtered_logits, dim=-1).unsqueeze(-1)
            else:
                next_token = torch.multinomial(F.softmax(filtered_logits, dim=-1), num_samples=1)
            generated = torch.cat((generated, next_token), dim=1)
    return generated


In [18]:
def generate(model, tokenizer, raw_text):
    context_tokens = tokenizer.encode(raw_text)
    logger.info('context_tokens: {}'.format(context_tokens))
    out = sample_sequence(
        model=model,
        tokenizer=tokenizer,
        context=context_tokens,
        length=length,
        temperature=temperature,
        top_k=top_k,
        top_p=top_p,
        device=device,
        is_xlnet=False,
    )
    logger.info('out length: {}'.format(len(out)))
    logger.info('out[0]: {}'.format(out[0].shape))
    logger.info('out[0, context_length:]: {}'.format(out[0, len(context_tokens):].shape))

    out = out[0, len(context_tokens):].tolist()
    text = tokenizer.decode(out, clean_up_tokenization_spaces=True)
    logger.info('text: {}'.format(text))

    return text

In [19]:
device = torch.device('cpu')

In [20]:
# tokenizer = GPT2Tokenizer.from_pretrained('gpt2-large')
# model = GPT2LMHeadModel.from_pretrained('gpt2-large')
# model.to(device)
# model.eval()

In [21]:
model_type = 'gpt2'
model_class, tokenizer_class = MODEL_CLASSES[model_type]

In [22]:
tokenizer = tokenizer_class.from_pretrained('gpt2-large')
model = model_class.from_pretrained('gpt2-large')
model.to(device)
model.eval()

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 1280)
    (wpe): Embedding(1024, 1280)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0): Block(
        (ln_1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
      (1): Block(
        (ln_1): LayerNorm((1280,), eps=1e-05, elementwise_affine=True)
        (attn): Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2):

In [23]:
if length < 0 and model.config.max_position_embeddings > 0:
    length = model.config.max_position_embeddings
elif 0 < model.config.max_position_embeddings < length:
    length = model.config.max_position_embeddings  # No generation bigger than model size 
elif length < 0:
    length = MAX_LENGTH  # avoid infinite loop

In [24]:
english = []
with open('top_1000_zeroshot.txt', 'r') as f:
    for line in f:
        english.append(line.lower().replace('\n',''))

In [25]:
english

['paris is relaxing during december , but it is usually chilly in july .']

In [26]:
! cat top_1000_zeroshot.txt

paris is relaxing during december , but it is usually chilly in july .


In [27]:
french = []
for en in english:
    prompt = f"Translate English to French: new jersey is sometimes quiet during autumn , and it is snowy in april . => new jersey est parfois calme pendant l' automne , et il est neigeux en avril . \n the united states is usually chilly during july , and it is usually freezing in november . => les états-unis est généralement froid en juillet , et il gèle habituellement en novembre . \n california is usually quiet during march , and it is usually hot in june . => california est généralement calme en mars , et il est généralement chaud en juin . \n the united states is sometimes mild during june , and it is cold in september . => les états-unis est parfois légère en juin , et il fait froid en septembre . \n your least liked fruit is the grape , but my least liked is the apple . => votre moins aimé fruit est le raisin , mais mon moins aimé est la pomme . \n his favorite fruit is the orange , but my favorite is the grape . => son fruit préféré est l'orange , mais mon préféré est le raisin . \n {en} =>"
    fr = generate(model, tokenizer, prompt)
    fr = fr.split(".")[0]
    french.append(fr)

10/22/2021 07:42:55 - INFO - __main__ -   context_tokens: [8291, 17660, 3594, 284, 4141, 25, 649, 22383, 318, 3360, 5897, 1141, 23608, 837, 290, 340, 318, 46742, 287, 46593, 346, 764, 5218, 649, 22383, 1556, 1582, 6513, 271, 2386, 1326, 279, 23048, 300, 6, 3557, 710, 837, 2123, 4229, 1556, 497, 10045, 2821, 551, 1196, 22379, 764, 220, 198, 262, 16503, 2585, 318, 3221, 49018, 1141, 474, 2062, 837, 290, 340, 318, 3221, 20884, 287, 645, 303, 1916, 764, 5218, 10287, 220, 25125, 1381, 12, 403, 271, 1556, 308, 35942, 2634, 1373, 972, 277, 3882, 551, 7544, 32512, 837, 2123, 4229, 308, 14064, 293, 7947, 2731, 1732, 551, 645, 303, 2022, 260, 764, 220, 198, 2386, 361, 3317, 318, 3221, 5897, 1141, 9960, 837, 290, 340, 318, 3221, 3024, 287, 474, 1726, 764, 5218, 2386, 361, 3317, 1556, 308, 35942, 2634, 1373, 972, 2386, 1326, 551, 48962, 837, 2123, 4229, 1556, 308, 35942, 2634, 1373, 972, 442, 3885, 551, 7544, 259, 764, 220, 198, 262, 16503, 2585, 318, 3360, 11607, 1141, 474, 1726, 837, 290, 340, 3

** 0th: input_ids, next_token_logits: torch.Size([1, 312]) torch.Size([1, 50257])
** 0th: input_ids -> Translate English to French: new jersey is sometimes quiet during autumn, and it is snowy in april. => new jersey est parfois calme pendant l' automne, et il est neigeux en avril. 
 the united states is usually chilly during july, and it is usually freezing in november. => les états-unis est généralement froid en juillet, et il gèle habituellement en novembre. 
 california is usually quiet during march, and it is usually hot in june. => california est généralement calme en mars, et il est généralement chaud en juin. 
 the united states is sometimes mild during june, and it is cold in september. => les états-unis est parfois légère en juin, et il fait froid en septembre. 
 your least liked fruit is the grape, but my least liked is the apple. => votre moins aimé fruit est le raisin, mais mon moins aimé est la pomme. 
 his favorite fruit is the orange, but my favorite is the grape. => so

** 8th: input_ids, next_token_logits: torch.Size([1, 320]) torch.Size([1, 50257])
** 8th: input_ids -> Translate English to French: new jersey is sometimes quiet during autumn, and it is snowy in april. => new jersey est parfois calme pendant l' automne, et il est neigeux en avril. 
 the united states is usually chilly during july, and it is usually freezing in november. => les états-unis est généralement froid en juillet, et il gèle habituellement en novembre. 
 california is usually quiet during march, and it is usually hot in june. => california est généralement calme en mars, et il est généralement chaud en juin. 
 the united states is sometimes mild during june, and it is cold in september. => les états-unis est parfois légère en juin, et il fait froid en septembre. 
 your least liked fruit is the grape, but my least liked is the apple. => votre moins aimé fruit est le raisin, mais mon moins aimé est la pomme. 
 his favorite fruit is the orange, but my favorite is the grape. => so

** 15th: input_ids, next_token_logits: torch.Size([1, 327]) torch.Size([1, 50257])
** 15th: input_ids -> Translate English to French: new jersey is sometimes quiet during autumn, and it is snowy in april. => new jersey est parfois calme pendant l' automne, et il est neigeux en avril. 
 the united states is usually chilly during july, and it is usually freezing in november. => les états-unis est généralement froid en juillet, et il gèle habituellement en novembre. 
 california is usually quiet during march, and it is usually hot in june. => california est généralement calme en mars, et il est généralement chaud en juin. 
 the united states is sometimes mild during june, and it is cold in september. => les états-unis est parfois légère en juin, et il fait froid en septembre. 
 your least liked fruit is the grape, but my least liked is the apple. => votre moins aimé fruit est le raisin, mais mon moins aimé est la pomme. 
 his favorite fruit is the orange, but my favorite is the grape. => 

10/22/2021 07:43:28 - INFO - __main__ -   out length: 1
10/22/2021 07:43:28 - INFO - __main__ -   out[0]: torch.Size([332])
10/22/2021 07:43:28 - INFO - __main__ -   out[0, context_length:]: torch.Size([20])
10/22/2021 07:43:28 - INFO - __main__ -   text:  paris est durée tranquille durant d'est de cette yeuver, et


** 19th: input_ids, next_token_logits: torch.Size([1, 331]) torch.Size([1, 50257])
** 19th: input_ids -> Translate English to French: new jersey is sometimes quiet during autumn, and it is snowy in april. => new jersey est parfois calme pendant l' automne, et il est neigeux en avril. 
 the united states is usually chilly during july, and it is usually freezing in november. => les états-unis est généralement froid en juillet, et il gèle habituellement en novembre. 
 california is usually quiet during march, and it is usually hot in june. => california est généralement calme en mars, et il est généralement chaud en juin. 
 the united states is sometimes mild during june, and it is cold in september. => les états-unis est parfois légère en juin, et il fait froid en septembre. 
 your least liked fruit is the grape, but my least liked is the apple. => votre moins aimé fruit est le raisin, mais mon moins aimé est la pomme. 
 his favorite fruit is the orange, but my favorite is the grape. => 

In [28]:
french

[" paris est durée tranquille durant d'est de cette yeuver, et"]

In [43]:
en

'paris is relaxing during december , but it is usually chilly in july .'

In [47]:
prompt = f'''Translate English to French: 
new jersey is sometimes quiet during autumn , and it is snowy in april . => new jersey est parfois calme pendant l' automne , et il est neigeux en avril . 
the united states is usually chilly during july , and it is usually freezing in november . => les états-unis est généralement froid en juillet , et il gèle habituellement en novembre . 
california is usually quiet during march , and it is usually hot in june . => california est généralement calme en mars , et il est généralement chaud en juin . 
the united states is sometimes mild during june , and it is cold in september . => les états-unis est parfois légère en juin , et il fait froid en septembre .  
your least liked fruit is the grape , but my least liked is the apple . => votre moins aimé fruit est le raisin , mais mon moins aimé est la pomme . 
his favorite fruit is the orange , but my favorite is the grape . => son fruit préféré est l'orange , mais mon préféré est le raisin . 
{en} => '''