In [1]:
import os

# Blank out CUDA_VISIBLE_DEVICES so this process sees no GPU. This is set
# before torch is imported further down, so the notebook runs on CPU.
os.environ['CUDA_VISIBLE_DEVICES'] = ''

In [2]:
from transformers import (
    AutoTokenizer,
    T5ForConditionalGeneration,
    AutoConfig,
)
from pytorch_lightning import LightningModule
import torch

In [3]:
# Load the tokenizer from the same local directory the model is built from.
tokenizer = AutoTokenizer.from_pretrained('./out-base-1.1')

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [4]:
class Module(LightningModule):
    """Minimal LightningModule wrapper around a T5 conditional-generation model.

    The network is stored on ``self.model``, so every parameter name in this
    module's ``state_dict`` starts with ``model.``.

    Args:
        pretrained_path: Location handed to ``from_pretrained`` for both the
            config and the weights. Defaults to ``'./out-base-1.1'``, the
            directory this notebook originally hard-coded.
    """

    def __init__(self, pretrained_path='./out-base-1.1'):
        super().__init__()
        # Config and weights are resolved from the same location so they
        # cannot drift apart.
        config = AutoConfig.from_pretrained(pretrained_path)
        self.model = T5ForConditionalGeneration.from_pretrained(
            pretrained_path,
            config=config,
        )

In [5]:
# List the checkpoint files under logs/base together with their sizes.
!ls -lh logs/base

total 20G
-rwxrwxrwx 1 ubuntu ubuntu 2.8G Oct 25 03:43 'model-epoch=00-step=2800.ckpt'
-rwxrwxrwx 1 ubuntu ubuntu 2.8G Oct 25 03:48 'model-epoch=00-step=3000.ckpt'
-rwxrwxrwx 1 ubuntu ubuntu 2.8G Oct 25 03:54 'model-epoch=00-step=3200.ckpt'
-rwxrwxrwx 1 ubuntu ubuntu 2.8G Oct 25 16:38 'model-epoch=01-step=26000.ckpt'
-rwxrwxrwx 1 ubuntu ubuntu 2.8G Oct 26 00:21 'model-epoch=02-step=44000.ckpt'
-rwxrwxrwx 1 ubuntu ubuntu 2.8G Oct 26 00:45 'model-epoch=03-step=45000.ckpt'
-rwxrwxrwx 1 ubuntu ubuntu 2.8G Oct 26 01:09 'model-epoch=03-step=46000.ckpt'


In [6]:
# Instantiate the wrapper; its constructor loads the pretrained T5 weights.
model = Module()

In [7]:
# State dict of the freshly built model; its entries are overwritten with the
# checkpoint tensors in the cells that follow.
weights = model.state_dict()

In [8]:
# Read the epoch-03 / step-46000 checkpoint onto the CPU and keep an items()
# view over its 'state_dict' entry (parameter name -> tensor pairs).
# NOTE: torch.load unpickles the file, so only open checkpoints you trust.
old_weights = torch.load(
    'logs/base/model-epoch=03-step=46000.ckpt',
    map_location=torch.device('cpu'),
)['state_dict'].items()

In [9]:
# Copy every checkpoint tensor into `weights`, stripping any '._orig_mod'
# segment from its key (presumably left by a compiled-module wrapper at
# training time -- confirm against the training script).
for ckpt_key, tensor in old_weights:
    model_key = ckpt_key.replace('._orig_mod', '')
    # Show the original and the cleaned key side by side.
    print(ckpt_key, model_key)
    weights[model_key] = tensor

model.shared.weight model.shared.weight
model.encoder.embed_tokens.weight model.encoder.embed_tokens.weight
model.encoder.block.0.layer.0.SelfAttention.q.weight model.encoder.block.0.layer.0.SelfAttention.q.weight
model.encoder.block.0.layer.0.SelfAttention.k.weight model.encoder.block.0.layer.0.SelfAttention.k.weight
model.encoder.block.0.layer.0.SelfAttention.v.weight model.encoder.block.0.layer.0.SelfAttention.v.weight
model.encoder.block.0.layer.0.SelfAttention.o.weight model.encoder.block.0.layer.0.SelfAttention.o.weight
model.encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight model.encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight
model.encoder.block.0.layer.0.layer_norm.weight model.encoder.block.0.layer.0.layer_norm.weight
model.encoder.block.0.layer.1.DenseReluDense.wi_0.weight model.encoder.block.0.layer.1.DenseReluDense.wi_0.weight
model.encoder.block.0.layer.1.DenseReluDense.wi_1.weight model.encoder.block.0.layer.1.DenseReluDense.wi_1.w

In [10]:
# Copy the merged weights into the model; the displayed return value reports
# how the keys matched.
model.load_state_dict(weights)

<All keys matched successfully>

In [11]:
from utils.copied_utils import (
    compute_input_and_target_lengths,
    DataCollatorForT5MLM,
    tokenize_function,
    DataCollatorForNI,
)

In [12]:
# Derive the two lengths used for span corruption from a 512-token input,
# noise density 0.15 and mean noise span length 3.0.
# NOTE(review): presumably `before_mask_input_length` is the raw token count
# taken before masking and `target_length` the resulting label length --
# confirm against utils.copied_utils.
before_mask_input_length, target_length = compute_input_and_target_lengths(
    inputs_length=512,
    noise_density=0.15,
    mean_noise_span_length=3.0,
)
# Last expression of the cell: display both values.
before_mask_input_length, target_length

(568, 114)

In [13]:
from streaming.base.format.mds.encodings import Encoding, _encodings
from streaming import StreamingDataset
import torch
import numpy as np

class Int32(Encoding):
    """Codec that stores an integer array as raw int32 bytes.

    Only the flat element data is serialized; the shape is not recorded, so
    ``decode`` always yields a one-dimensional array.
    """

    def encode(self, obj) -> bytes:
        """Serialize ``obj`` as native-order int32 bytes.

        The input is coerced to int32 first. Without this, an array of another
        dtype (e.g. int64) would be written with its own element width and
        then misread by ``decode``, which always interprets the bytes as int32.

        Args:
            obj: Array-like of integers that fit in int32.

        Returns:
            The raw bytes of the int32 representation of ``obj``.
        """
        return np.asarray(obj, dtype=np.int32).tobytes()

    def decode(self, data: bytes):
        """Interpret ``data`` as a flat array of int32 values.

        Args:
            data: Bytes previously produced by ``encode``.

        Returns:
            A one-dimensional ``np.int32`` array backed by ``data``.
        """
        return np.frombuffer(data, np.int32)


# Register the codec under the name 'int32' in streaming's (private) encoding
# table. NOTE(review): presumably required so columns declared as 'int32' in
# the dataset index can be decoded -- confirm against the dataset writer.
_encodings['int32'] = Int32


class DatasetFixed(torch.utils.data.Dataset):
    """Map-style view over a local ``StreamingDataset``.

    Each fetched sample has its ``token_type_ids`` entry removed (when present)
    and every remaining field cast to ``np.int64``.
    """

    def __init__(self, local):
        # `local` is forwarded unchanged to StreamingDataset.
        self.dataset = StreamingDataset(local=local)

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        sample = self.dataset[idx]
        if 'token_type_ids' in sample:
            del sample['token_type_ids']
        # Snapshot the items so values can be reassigned while looping; the
        # sample dict itself is updated in place and returned.
        for field, values in list(sample.items()):
            sample[field] = values.astype(np.int64)
        return sample

In [14]:
# Open the locally stored dataset.
# NOTE(review): absolute, machine-specific path -- adjust when running elsewhere.
dataset = DatasetFixed(local='/home/ubuntu/nanot5-512')

In [15]:
# Build the T5 masked-LM collator with the same noise settings (0.15 density,
# mean span 3.0) and 512-token input length used when computing
# `target_length` above.
data_collator = DataCollatorForT5MLM(
    tokenizer=tokenizer,
    noise_density=0.15,
    mean_noise_span_length=3.0,
    input_length=512,
    target_length=target_length,
    # NOTE(review): assumes token id 0 is the tokenizer's pad token -- confirm.
    pad_token_id=0,
)

In [16]:
# One-example batch: the half-open range covers only row 5,000,000.
batch = [dataset[row] for row in range(5_000_000, 5_000_001)]

In [17]:
# Run the collator on the batch; the result is unpacked as keyword arguments
# for the model call below and contains a 'labels' entry.
padded = data_collator(batch)

In [18]:
# Inference only: disable autograd so the forward pass does not build (and
# keep alive) a computation graph that is never back-propagated through.
with torch.no_grad():
    o = model.model(**padded)

In [19]:
# Display the raw logits returned by the forward pass.
o.logits

tensor([[[-1.4642,  1.9711,  8.1804,  ..., -1.3669, -1.2231, -1.7239],
         [-2.6246, -0.4478, 13.4775,  ..., -0.4373, -1.4065, -1.1230],
         [-2.0779,  2.4580, 14.7889,  ..., -1.0686, -1.1595, -1.2585],
         ...,
         [-1.3984,  2.0701,  8.1561,  ..., -1.2381, -2.6082, -1.7596],
         [-2.2733, -0.1507,  8.6637,  ..., -2.3320, -1.3610, -2.8000],
         [-0.5552,  6.5587, 32.7664,  ..., -0.7088, -1.3464, -2.2160]]],
       grad_fn=<UnsafeViewBackward0>)

In [20]:
# Decode the label ids of the first example in the batch back to text, for
# comparison with the model's predictions.
tokenizer.decode(padded['labels'][0])

'<extra_id_40> dll)</s><extra_id_59> Bawal Exclusive pi kilang S<extra_id_38> husband reseller<extra_id_18> tanya<extra_id_46> Bini<extra_id_66>..<extra_id_71> la sk<extra_id_67>t<extra_id_6> hubungan<extra_id_53>wah pers<extra_id_10> tiada jawapan ditemui, kerana yang<extra_id_61>anya. Bagaikan tersedar dari lena<extra_id_2> penulis mengakhiri<extra_id_84>ang dada. Mener<extra_id_8>.<extra_id_28> pasti<extra_id_70></s><extra_id_76><s> Ad<extra_id_62>2<extra_id_54> dah t<extra_id_12>osaur<extra_id_0> Best x??<extra_id_75> lama lagi<extra_id_85> duit kt die tuk<extra_id_83>agaknyalahhhh</s><s> Edited by<extra_id_52>antang<extra_id_89>....<extra_id_77>...ni ha...mkn la</s>'

In [21]:
# Take the highest-scoring token id at every position of the first example
# and decode the resulting id sequence back to text.
tokenizer.decode(o.logits.argmax(dim=-1)[0].detach().cpu().numpy())

'<extra_id_40></s></s></s><extra_id_59> Sarac S<extra_id_38> S S S<extra_id_38> kawan<extra_id_18>ign<extra_id_18> kata kat makinodoh dia.<extra_id_71> sk nak<extra_id_67>.<extra_id_6> hubungan<extra_id_53>wah pers<extra_id_10> penulis siapa yang. penulis si<extra_id_61>anya. Namunimbangkan tidakedar<extra_id_2> segalaiku<extra_id_2> penulis menerimahiri<extra_id_84>ang dada, Mener<extra_id_8>.<extra_id_28> pasti<extra_id_70></s><extra_id_76><s> Ad<extra_id_62>2<extra_id_54> t t<extra_id_12>our<extra_id_0></s> ke?<extra_id_75> lama<extra_id_85><extra_id_85> die utk die utk<extra_id_83></s>aknya</s></s></s></s><s> Edited by<extra_id_52>n<extra_id_89>.<extra_id_77> tu</s> mestiuk</s>cmn</s>'