In [1]:
import os

os.environ['CUDA_VISIBLE_DEVICES'] = ''

In [2]:
from transformers import (
    AutoTokenizer,
    T5ForConditionalGeneration,
    AutoConfig,
)
from pytorch_lightning import LightningModule
import torch

In [3]:
tokenizer = AutoTokenizer.from_pretrained('./out-small-1.1')

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [4]:
class Module(LightningModule):
    def __init__(self):
        super().__init__()
        config = AutoConfig.from_pretrained(
            './out-small-1.1'
        )
        self.model = T5ForConditionalGeneration.from_pretrained(
            './out-small-1.1',
            config=config,
        )

In [5]:
!ls -lh logs/small

total 2.7G
-rwxrwxrwx 1 ubuntu ubuntu 917M Oct 26 23:09 'model-epoch=04-step=42000.ckpt'
-rwxrwxrwx 1 ubuntu ubuntu 917M Oct 26 23:35 'model-epoch=04-step=44000.ckpt'
-rwxrwxrwx 1 ubuntu ubuntu 917M Oct 27 00:01 'model-epoch=04-step=46000.ckpt'


In [6]:
model = Module()

In [7]:
weights = model.state_dict()

In [8]:
old_weights = torch.load('logs/small/model-epoch=04-step=46000.ckpt',
                             map_location=torch.device('cpu'))['state_dict'].items()

In [9]:
for k, v in old_weights:
    new_k = k.replace('._orig_mod', '')
    print(k, new_k)
    weights[new_k] = v

model.shared.weight model.shared.weight
model.encoder.embed_tokens.weight model.encoder.embed_tokens.weight
model.encoder.block.0.layer.0.SelfAttention.q.weight model.encoder.block.0.layer.0.SelfAttention.q.weight
model.encoder.block.0.layer.0.SelfAttention.k.weight model.encoder.block.0.layer.0.SelfAttention.k.weight
model.encoder.block.0.layer.0.SelfAttention.v.weight model.encoder.block.0.layer.0.SelfAttention.v.weight
model.encoder.block.0.layer.0.SelfAttention.o.weight model.encoder.block.0.layer.0.SelfAttention.o.weight
model.encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight model.encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight
model.encoder.block.0.layer.0.layer_norm.weight model.encoder.block.0.layer.0.layer_norm.weight
model.encoder.block.0.layer.1.DenseReluDense.wi_0.weight model.encoder.block.0.layer.1.DenseReluDense.wi_0.weight
model.encoder.block.0.layer.1.DenseReluDense.wi_1.weight model.encoder.block.0.layer.1.DenseReluDense.wi_1.w

In [10]:
model.load_state_dict(weights)

<All keys matched successfully>

In [11]:
from utils.copied_utils import (
    compute_input_and_target_lengths,
    DataCollatorForT5MLM,
    tokenize_function,
    DataCollatorForNI,
)

In [12]:
before_mask_input_length, target_length = compute_input_and_target_lengths(
    inputs_length=512,
    noise_density=0.15,
    mean_noise_span_length=3.0,
)
before_mask_input_length, target_length

(568, 114)

In [13]:
from streaming.base.format.mds.encodings import Encoding, _encodings
from streaming import StreamingDataset
import torch
import numpy as np

class Int32(Encoding):
    def encode(self, obj) -> bytes:
        return obj.tobytes()

    def decode(self, data: bytes):
        return np.frombuffer(data, np.int32)


_encodings['int32'] = Int32


class DatasetFixed(torch.utils.data.Dataset):
    def __init__(self, local):
        self.dataset = StreamingDataset(local=local)

    def __getitem__(self, idx):
        data = self.dataset[idx]
        data.pop('token_type_ids', None)
        for k in data.keys():
            data[k] = data[k].astype(np.int64)
        return data

    def __len__(self):
        return len(self.dataset)

In [14]:
dataset = DatasetFixed(local='/home/ubuntu/nanot5-512')

In [15]:
data_collator = DataCollatorForT5MLM(
    tokenizer=tokenizer,
    noise_density=0.15,
    mean_noise_span_length=3.0,
    input_length=512,
    target_length=target_length,
    pad_token_id=0,
)

In [16]:
batch = [dataset[i] for i in range(5000000, 5000001, 1)]

In [17]:
padded = data_collator(batch)

In [18]:
o = model.model(**padded)

In [19]:
o.logits

tensor([[[-0.1583, -1.0797,  2.4659,  ..., -0.3984, -0.3889, -0.9530],
         [-0.9148, -3.7254,  8.1551,  ..., -1.9462, -2.2352, -2.9646],
         [-0.3222, -1.2524,  9.8471,  ..., -1.1862, -1.5322, -4.7910],
         ...,
         [-0.9169,  1.2961,  6.9794,  ..., -0.5234,  0.4043, -0.7112],
         [-1.0103, -1.0219,  8.2162,  ..., -2.5970, -2.2210, -3.6502],
         [-6.6296,  2.9095, 40.6058,  ..., -4.4431, -2.1524, -3.0650]]],
       grad_fn=<UnsafeViewBackward0>)

In [20]:
padded['labels']
tokenizer.decode(padded['labels'][0])

'<extra_id_40>. Boleh pi IG Haliza May<extra_id_59> dia tanya<extra_id_38> tak ramai<extra_id_18> dia balik<extra_id_46></s><s><extra_id_66>28 PM<extra_id_71> sket<extra_id_67>wira konon tp<extra_id_6> makcik bawang.<extra_id_53>, konfiden giler<extra_id_10></s><extra_id_61>arkan<extra_id_2> penulis yang telah<extra_id_84>rah dan menerima ketentuan Il<extra_id_8> berpisah ketika masih hidup, sakitnya<extra_id_28> akhirnya<extra_id_70> memisahkan. Banyak<extra_id_76> p<extra_id_62>atapan<extra_id_54>?? Dia pun<extra_id_12></s><s> akak kot<extra_id_0> tgu...tah br<extra_id_75> kalau die betul<extra_id_85> die<extra_id_83> Skuau at 26<extra_id_52> PM<extra_id_89>....<extra_id_77> ha...mkn la</s>'

In [21]:
tokenizer.decode(o.logits.argmax(-1).detach().cpu().numpy()[0])

'<extra_id_40>. Pas<extra_id_59> Mah<extra_id_59><extra_id_59> Mah<extra_id_59> dia<extra_id_38> dia siapa<extra_id_18><extra_id_18> dia balik<extra_id_46>.<s><extra_id_66>39 PM<extra_id_71>..rg..giira<extra_id_6><extra_id_6><extra_id_6>...<extra_id_53>..<extra_id_10>ekem<extra_id_10><extra_id_10>er<extra_id_10></s><extra_id_61>yaratkan<extra_id_2> kisah<extra_id_84><extra_id_84><extra_id_84>rah dengan kembali ketentuan Il<extra_id_8> tidak<extra_id_28> itu<extra_id_28>,<extra_id_28><extra_id_28><extra_id_28> akan<extra_id_70> berlaku. Banyak<extra_id_76> p<extra_id_62>edaapan<extra_id_54>...</s> pun<extra_id_12></s><s> h<extra_id_0><extra_id_0> tgu...<extra_id_75><extra_id_75><extra_id_75> kalau betul betul<extra_id_85> die<extra_id_83> muau at 24<extra_id_52> PM<extra_id_89>.<extra_id_77> yg...</s>cm x</s>'