In [119]:
import json
import os
import torch
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence
from transformers import AutoTokenizer
import numpy as np

In [33]:
class Seq2SeqDataset(Dataset):
    """
    A Simple Seq2Seq Dataset Implementation
    """
    def __init__(self, fact_filename, romantic_filename,funny_filename, tokenizer, add_bos_token=True, add_eos_token=True):
        data = []
        with open(fact_filename,'r') as f:
            line = f.readline()
            while line:
                data.append({"source":"","target":line.replace('\n',''),"style":"fact"})
                line = f.readline()

        with open(romantic_filename,'r') as f:
            line = f.readline()
            while line:
                data.append({"source":"","target":line.replace('\n',''),"style":"romantic"})
                line = f.readline()        

        with open(funny_filename,'r') as f:
            line = f.readline()
            while line:
                data.append({"source":"","target":line.replace('\n',''),"style":"funny"})
                line = f.readline()    

        self.data = data
        self.tokenizer = tokenizer
        self.add_bos_token = add_bos_token
        self.add_eos_token = add_eos_token

    def __getitem__(self, index):
        item = self.data[index]
        source_token_ids = self.tokenizer.encode(item["source"], add_special_tokens=False)
        target_token_ids = self.tokenizer.encode(item["target"], add_special_tokens=False)

        if self.add_bos_token:
            target_token_ids.insert(0, self.tokenizer.bos_token_id)

        if self.add_eos_token:
            target_token_ids.append(self.tokenizer.eos_token_id)

        item["source_token_ids"] = torch.LongTensor(source_token_ids)
        item["target_token_ids"] = torch.LongTensor(target_token_ids)
        return item

    def __len__(self):
        return len(self.data)

    def collate_fn(self, batch):
        new_batch = {}
        new_batch["source_token_ids"] = pad_sequence(
            [item["source_token_ids"] for item in batch], batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        new_batch["target_token_ids"] = pad_sequence(
            [item["target_token_ids"] for item in batch], batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        new_batch["style"] = [item["style"] for item in batch]
        return new_batch


In [26]:
data = []
fact_filename = "./fact-train.txt"
romantic_filename = "./romantic-train.txt"
funny_filename = "./funny-train.txt"

with open(fact_filename,'r') as f:
    line = f.readline()
    while line:
        data.append({"source":"","target":line.replace('\n',''),"style":"fact"})
        line = f.readline()
        
with open(romantic_filename,'r') as f:
    line = f.readline()
    while line:
        data.append({"source":"","target":line.replace('\n',''),"style":"romantic"})
        line = f.readline()        

with open(funny_filename,'r') as f:
    line = f.readline()
    while line:
        data.append({"source":"","target":line.replace('\n',''),"style":"funny"})
        line = f.readline()             
        
print(data[5999])
print(len(data))

{'source': '', 'target': 'a young girl jumping out of the water', 'style': 'fact'}
18000


In [28]:
tokenizer = AutoTokenizer.from_pretrained("roberta-base")

In [34]:
train_dataset = Seq2SeqDataset(fact_filename,romantic_filename,funny_filename , tokenizer)

In [37]:
train_dataloader = DataLoader(
            train_dataset, batch_size=32, shuffle=True, collate_fn=train_dataset.collate_fn)

In [245]:
print(batch)

{'source_token_ids': tensor([], size=(16, 0), dtype=torch.int64), 'target_token_ids': tensor([[    0,   102, 36762,   154,   313, 13905,   141,     7,  3068,    10,
          4806,    19,   129,     5,   124,  5964,    15,     5,  1255,     4,
             2],
        [    0,   260,  7497,   621,  3106,    10,  1104,   109,  1899,     8,
         32950,    49, 15401,   579,  3976,   254,   352,     4,     2,     1,
             1],
        [    0,   627,   891,    33,    10,  8728,  1151,  1937,    15,    49,
          1656,   479,     2,     1,     1,     1,     1,     1,     1,     1,
             1],
        [    0,   102,   664,  2143,    23,    10,   537,  6614,    24,   101,
            24,   128,    29,  2131,   479,     2,     1,     1,     1,     1,
             1],
        [    0,  9983,   390,    11,   657,    32,  3051,   149,     5,  2014,
            11,    10,  3826,  1692,    12,   242, 34927,  1139,   479,     2,
             1],
        [    0,   642, 12151,   918,   

In [40]:
from transformers.models.gpt2 import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer
import torch

In [253]:
model = GPT2LMHeadModel.from_pretrained('gpt2')

Downloading:   0%|          | 0.00/665 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/523M [00:00<?, ?B/s]

In [246]:
past_mask = torch.ones(len(batch["target_mask"]),1).bool()
print(past_mask.shape)
batch["target_mask"] = batch["target_token_ids"] != tokenizer.pad_token_id
print(batch["target_mask"].shape)
        # batch["source_position_ids"] = batch["source_mask"].cumsum(-1) - 1
batch["target_position_ids"] = batch["target_mask"].cumsum(-1)
joint_mask = torch.cat((past_mask,batch["target_token_ids"] != tokenizer.pad_token_id),dim=1)
print(batch["target_mask"].shape)
print(batch["target_position_ids"].shape)

torch.Size([16, 1])
torch.Size([16, 21])
torch.Size([16, 21])
torch.Size([16, 21])


In [222]:
# out = model(
#             input_ids=batch["target_token_ids"],
#             position_ids=batch['target_position_ids'],
#             attention_mask=batch["target_mask"],
#         )

In [89]:
print(len(list(out['past_key_values'][0][0][0][0])))
#(12,2, batch_size, num_head, sql_len+1, head_features)
#12 2 16 12 21 64
print(len(batch['target_position_ids']))

21
16


In [247]:
past = torch.randn(size=(12,2,16,12,1,64))
out = model(
            input_ids=batch["target_token_ids"],
            position_ids=batch['target_position_ids'],
            attention_mask=joint_mask,
            past_key_values = past
        )
print(len(out.past_key_values))
print(len(out.past_key_values[0]))
print(len(out.past_key_values[0][0]))
print(len(out.past_key_values[0][0][0]))
print(len(out.past_key_values[0][0][0][0]))
print(len(out.past_key_values[0][0][0][0][0]))


12
2
16
12
22
64


In [249]:
print(len(out.logits[0]))
print(len(out.logits[0,0,:]))
print(out.logits[0,-1,:])
print(torch.max(out.logits[0,-1,:]))

21
50257
tensor([-39.1996, -32.7912, -39.5255,  ..., -46.2507, -52.3892, -35.2706],
       grad_fn=<SliceBackward0>)
tensor(-32.1514, grad_fn=<MaxBackward1>)


In [228]:
yy = out.past_key_values[11][1]
print(yy.shape)
past0 = yy[:,:,0,:]
print(past0.shape)
past = past [11,1,:,:,0,:]
print(past.shape)
print(past0==past)

torch.Size([16, 12, 22, 64])
torch.Size([16, 12, 64])
torch.Size([16, 12, 64])
tensor([[[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         ...,
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]],

        [[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         ...,
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True]],

        [[True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         [True, True, True,  ..., True, True, True],
         ...,
         [True, True, True,  ..., True, True, True],
         [True, True, True,

In [194]:
h=torch.tensor([[True,True,False],[True,True,True]])
h = h.cumsum(-1)
print(h)
print(h[:,-1].unsqueeze(-1)+h)

tensor([[1, 2, 2],
        [1, 2, 3]])
tensor([[3, 4, 4],
        [4, 5, 6]])


In [233]:
x = torch.tentor([1,0,0])
x = torch.tile(x.unsqueeze(1).unsqueeze(2).unsqueeze(0).unsqueeze(0),(12,2,1,12,1,1))
print(x.shape)

torch.Size([12, 2, 16, 12, 1, 3])


In [234]:
past = torch.randn(size=(12,2,16,12,1,61))
torch.concat((x,past),dim=-1)

tensor([[[[[[ 1.9894e-01,  9.6250e-01,  5.1389e-01,  ..., -2.6888e-01,
             -1.7448e+00, -6.4395e-01]],

           [[ 1.9894e-01,  9.6250e-01,  5.1389e-01,  ..., -2.0361e-02,
             -8.5841e-01, -1.1557e+00]],

           [[ 1.9894e-01,  9.6250e-01,  5.1389e-01,  ...,  2.5844e+00,
              1.5617e+00, -5.4411e-01]],

           ...,

           [[ 1.9894e-01,  9.6250e-01,  5.1389e-01,  ...,  3.3670e-01,
             -7.2903e-01,  6.5739e-01]],

           [[ 1.9894e-01,  9.6250e-01,  5.1389e-01,  ...,  1.9975e+00,
              6.0292e-01,  9.3579e-01]],

           [[ 1.9894e-01,  9.6250e-01,  5.1389e-01,  ...,  6.1128e-01,
              2.9070e-01,  6.4425e-01]]],


          [[[ 8.4250e-01,  4.0164e-02,  4.3048e-01,  ..., -1.8983e+00,
             -9.2982e-01,  1.6228e+00]],

           [[ 8.4250e-01,  4.0164e-02,  4.3048e-01,  ...,  3.8651e-01,
             -9.2702e-01,  8.6966e-01]],

           [[ 8.4250e-01,  4.0164e-02,  4.3048e-01,  ..., -1.6174e+00,
      

In [252]:
x=torch.randn((7,8,9))
y=x.transpose(0,2)
print(x.shape,y.shape)

torch.Size([7, 8, 9]) torch.Size([9, 8, 7])
