In [1]:
import os
print('os.getcwd()', os.getcwd())
import sys
sys.path.insert(1, '../')
print(sys.version)
import time

#plotting tools
from matplotlib import pyplot as plt 
from tqdm.notebook import tqdm as tqdm

#torch libs
import torch
print('torch.__version__', torch.__version__)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
pipe_device = 0 if torch.cuda.is_available() else -1
print(pipe_device, device)

#huggingface transformers
import transformers
print('transformers.__version__',transformers.__version__)
from transformers import AutoTokenizer, pipeline
from transformers import GPT2Tokenizer, GPT2LMHeadModel, GPT2PreTrainedModel

from datasets import load_dataset

#curious
from curious.models import GPT2HeadWithValueModel
from curious.rl import PPOTrainer
from curious.utils import LengthSampler, collater, respond_to_batch, generate_text

#jupyter stuff
%load_ext autoreload
%autoreload 2
%matplotlib inline

os.getcwd() /Users/carson/projects/language_reinforce/notebooks
3.10.6 (main, Aug 30 2022, 05:09:33) [Clang 12.0.0 (clang-1200.0.32.29)]
torch.__version__ 1.12.1
-1 cpu
transformers.__version__ 4.22.2


In [16]:
# load imdb with datasets
ds = load_dataset('imdb', split='train')
ds

Found cached dataset imdb (/Users/carson/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1)


Dataset({
    features: ['text', 'label'],
    num_rows: 25000
})

In [17]:
ds[2]

{'text': "If only to avoid making this type of film in the future. This film is interesting as an experiment but tells no cogent story.<br /><br />One might feel virtuous for sitting thru it because it touches on so many IMPORTANT issues but it does so without any discernable motive. The viewer comes away with no new perspectives (unless one comes up with one while one's mind wanders, as it will invariably do during this pointless film).<br /><br />One might better spend one's time staring out a window at a tree growing.<br /><br />",
 'label': 0}

```
Found cached dataset imdb (/home/carson/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1)
Loading cached processed dataset at /home/carson/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1/cache-fa8f4f047f540716.arrow
```

In [18]:
# rename teh columns
ds = ds.rename_columns({'text': 'review', 'label': 'sentiment'})
# 
ds = ds.filter(lambda x: len(x["review"])>200, batched=False)
ds

Loading cached processed dataset at /Users/carson/.cache/huggingface/datasets/imdb/plain_text/1.0.0/2fdd8b9bcadd6e7055e742a706876ba43f19faee861df134affd7a3f60fc38a1/cache-e68378636d846987.arrow


Dataset({
    features: ['review', 'sentiment'],
    num_rows: 24895
})

In [19]:
len(ds[2]['review'])

528

In [20]:
input_size = LengthSampler(min_value = 4, max_value = 8)
output_size = LengthSampler(min_value = 4, max_value = 16)

for i in range(10):
    print(input_size())

7
4
5
5
7
7
4
4
7
4


In [21]:
gpt2_tokenizer = AutoTokenizer.from_pretrained(
    'gpt2',
    pad_token='<|endoftext|>',
    padding_side = 'left',
)

Downloading:   0%|          | 0.00/665 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.04M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/456k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

In [22]:
def map_tokenize(sample):
    
    '''
    this function is applied to the dataset and 
    only the first few tokens of review are used for "tokens"
    they are decoded and stored as query in their text form
    '''
    
    sample["tokens"] = gpt2_tokenizer.encode(sample["review"])[:input_size()]
    sample["query"] = gpt2_tokenizer.decode(sample["tokens"])
    
    return sample

ds = ds.map(map_tokenize, batched=False)

  0%|          | 0/24895 [00:00<?, ?ex/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (1168 > 1024). Running this sequence through the model will result in indexing errors


In [23]:
ds[2]

{'review': "If only to avoid making this type of film in the future. This film is interesting as an experiment but tells no cogent story.<br /><br />One might feel virtuous for sitting thru it because it touches on so many IMPORTANT issues but it does so without any discernable motive. The viewer comes away with no new perspectives (unless one comes up with one while one's mind wanders, as it will invariably do during this pointless film).<br /><br />One might better spend one's time staring out a window at a tree growing.<br /><br />",
 'sentiment': 0,
 'tokens': [1532, 691, 284, 3368, 1642],
 'query': 'If only to avoid making'}

In [27]:
def collater(data):
    return dict((key, [d[key] for d in data]) for key in data[0])

dataloader = torch.utils.data.DataLoader(ds, batch_size=2, collate_fn=collater)

type(dataloader)

torch.utils.data.dataloader.DataLoader

In [34]:
batch = next(iter(dataloader))

print(type(batch), batch.keys())
print(batch['query'])
print(batch['tokens'])

<class 'dict'> dict_keys(['review', 'sentiment', 'tokens', 'query'])
['I rented I AM C', '"I Am Curious']
[[40, 26399, 314, 3001, 327], [1, 40, 1703, 44269]]


In [35]:
query_tensors = [torch.tensor(t).long().to(device) for t in batch["tokens"]]
query_tensors

[tensor([   40, 26399,   314,  3001,   327]),
 tensor([    1,    40,  1703, 44269])]

```python
input_ids = self.data_collator([torch.cat([q, r]) for q, r in zip(query_batch, response_batch)])["input_ids"]
with torch.no_grad():
    logits, _, v = self.model(input_ids)
    ref_logits, _, _ = self.ref_model(input_ids)
logprobs = logprobs_from_logits(logits[:,:-1,:], input_ids[:,1:])
ref_logprobs = logprobs_from_logits(ref_logits[:,:-1,:], input_ids[:,1:])
```