In [11]:
# Follow tutorial from here: https://huggingface.co/docs/transformers/v4.17.0/en/tasks/language_modeling
%load_ext autoreload
%autoreload 2
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import numpy as np
from sklearn.model_selection import train_test_split
import json
import torch
import datasets
from transformers import AutoTokenizer, DataCollatorWithPadding, AutoModelForCausalLM, TrainingArguments, Trainer, DataCollatorForLanguageModeling
import data
import os
from os.path import join
from tqdm import tqdm

# ---- hyperparameters ----
# Base model whose fine-tuned checkpoints are sampled further down.
# Alternatives used at other times: "EleutherAI/gpt-j-6B", 'facebook/opt-2.7b'.
checkpoint = "EleutherAI/gpt-neo-2.7B"

# Tokenizer for the base model; its EOS token is reused as the padding token.
# NOTE(review): presumably because this tokenizer defines no pad token of its own — confirm.
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
tokenizer.pad_token = tokenizer.eos_token

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [12]:
# Fine-tuned checkpoints for the base model live in ./title-<org>__<model>,
# e.g. ./title-EleutherAI__gpt-neo-2.7B for checkpoint = "EleutherAI/gpt-neo-2.7B".
checkpoint_dir = f'./title-{checkpoint.replace("/", "__")}'


def _checkpoint_sort_key(name):
    """Natural-order sort key for checkpoint directory names.

    A name ending in '-<decimal digits>' (e.g. 'checkpoint-500') sorts by its
    prefix and then by that integer, so 'checkpoint-1000' follows 'checkpoint-500'
    instead of preceding it lexicographically. Any other name sorts by the plain
    name with step -1.
    """
    prefix, _, step = name.rpartition('-')
    if step.isdecimal():  # isdecimal (not isdigit) guarantees int() succeeds
        return (prefix, int(step))
    return (name, -1)


# os.listdir returns entries in arbitrary order; sort them so checkpoints are
# processed (and reported by tqdm) in a deterministic, step-increasing order
# (presumably Trainer-style `checkpoint-<step>` directories — confirm).
# Only directories can be passed to from_pretrained, so stray files are skipped.
checkpoints = sorted(
    (
        name
        for name in os.listdir(checkpoint_dir)
        if os.path.isdir(join(checkpoint_dir, name))
    ),
    key=_checkpoint_sort_key,
)

In [14]:
# Generate samples from every fine-tuned checkpoint and save one text file per checkpoint.
# The loop variable is deliberately NOT named `checkpoint`: reusing that name would
# overwrite the base-model id set in the config cell, so re-running the
# `checkpoint_dir` cell afterwards would silently point at the wrong directory.
prompt_year = '2021'
prompt = f'{prompt_year}\n\n'
samples_dir = f'../samples/gptneo/{prompt_year}'
os.makedirs(samples_dir, exist_ok=True)  # open(..., 'w') below fails if the directory is missing

for ckpt_name in tqdm(checkpoints):
    # Drop our reference to the previously loaded model before loading the next one,
    # so two large models are not held at the same time
    # (assumes nothing else keeps a reference to it).
    mod = None
    # Memory-saving options tried previously (currently disabled):
    #   revision="float16", torch_dtype=torch.float16, low_cpu_mem_usage=True, or mod.half()
    mod = AutoModelForCausalLM.from_pretrained(join(checkpoint_dir, ckpt_name))

    samples = data.generate_samples(
        mod,
        tokenizer,
        prompt=prompt,
        num_return_sequences=40,
    )
    # Cut the leading prompt off each sample and store each one as a Python repr
    # literal, separated by blank lines.
    # NOTE(review): assumes generate_samples returns strings that begin with the prompt — confirm.
    save_str = '\n\n'.join(repr(x[len(prompt):].strip()) for x in samples)
    with open(join(samples_dir, f'{ckpt_name}.txt'), 'w', encoding='utf-8') as fh:
        fh.write(save_str)

  0%|          | 0/9 [00:00<?, ?it/s]Setting `pad_token_id` to `eos_token_id`:198 for open-end generation.
 11%|█         | 1/9 [00:30<04:05, 30.73s/it]Setting `pad_token_id` to `eos_token_id`:198 for open-end generation.
 22%|██▏       | 2/9 [00:58<03:22, 28.88s/it]Setting `pad_token_id` to `eos_token_id`:198 for open-end generation.
 33%|███▎      | 3/9 [01:52<04:01, 40.23s/it]Setting `pad_token_id` to `eos_token_id`:198 for open-end generation.
 44%|████▍     | 4/9 [03:06<04:27, 53.57s/it]Setting `pad_token_id` to `eos_token_id`:198 for open-end generation.
 56%|█████▌    | 5/9 [08:09<09:34, 143.73s/it]Setting `pad_token_id` to `eos_token_id`:198 for open-end generation.
 67%|██████▋   | 6/9 [09:24<06:00, 120.20s/it]Setting `pad_token_id` to `eos_token_id`:198 for open-end generation.
 78%|███████▊  | 7/9 [15:08<06:27, 193.50s/it]Setting `pad_token_id` to `eos_token_id`:198 for open-end generation.
 89%|████████▉ | 8/9 [16:21<02:35, 155.20s/it]Setting `pad_token_id` to `eos_token_id

Setting `pad_token_id` to `eos_token_id`:198 for open-end generation.


['2020\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n',
 '2020\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n',
 '2020\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n',
 '2020\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n',
 '2020\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n',
 '2020\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n',
 '2020\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n',
 '2020\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n',
 '2020\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n',
 '2020\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n',
 '2020\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n',
 '2020\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n',
 '2020\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n',
 '2020\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n',
 '2020\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n',
 '2020\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n\n',
 '2020\n\n  "the Court finds that defendants have met the first two prongs of the\n\n',
 '2020\n\n                  ',
 '2020\n\n\n\n\n\n\n\n\n\n\n\xa0\n\xa0\n\xa0\n\xa0\n\xa0\n\xa0\n\n\n\n\n\n\