In [1]:
import os

os.environ['CUDA_VISIBLE_DEVICES'] = ''

In [2]:
from transformers import (
    AutoTokenizer,
    T5ForConditionalGeneration,
    AutoConfig,
)
from pytorch_lightning import LightningModule
import torch

In [3]:
tokenizer = AutoTokenizer.from_pretrained('./out-large-1.1')

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [4]:
class Module(LightningModule):
    def __init__(self):
        super().__init__()
        config = AutoConfig.from_pretrained(
            './out-large-1.1'
        )
        self.model = T5ForConditionalGeneration.from_pretrained(
            './out-large-1.1',
            config=config,
        )

In [5]:
!ls -lh logs/large

total 44G
-rwxrwxrwx 1 ubuntu ubuntu 8.8G Oct 26 02:34 'model-epoch=00-step=1000.ckpt'
-rwxrwxrwx 1 ubuntu ubuntu 8.8G Oct 26 03:59 'model-epoch=00-step=2000.ckpt'
-rw-r--r-- 1 ubuntu ubuntu 8.8G Oct 26 04:52 'model-epoch=00-step=2400.ckpt'
-rw-r--r-- 1 ubuntu ubuntu 8.8G Oct 26 05:22 'model-epoch=00-step=2800.ckpt'
-rw-r--r-- 1 ubuntu ubuntu 8.8G Oct 26 05:52 'model-epoch=00-step=3200.ckpt'


In [6]:
model = Module()

In [7]:
weights = model.state_dict()

In [11]:
old_weights = torch.load('logs/large/model-epoch=00-step=3200.ckpt',
                             map_location=torch.device('cpu'))['state_dict'].items()

In [12]:
for k, v in old_weights:
    new_k = k.replace('._orig_mod', '')
    print(k, new_k)
    weights[new_k] = v

model.shared.weight model.shared.weight
model.encoder.embed_tokens.weight model.encoder.embed_tokens.weight
model.encoder.block.0.layer.0.SelfAttention.q.weight model.encoder.block.0.layer.0.SelfAttention.q.weight
model.encoder.block.0.layer.0.SelfAttention.k.weight model.encoder.block.0.layer.0.SelfAttention.k.weight
model.encoder.block.0.layer.0.SelfAttention.v.weight model.encoder.block.0.layer.0.SelfAttention.v.weight
model.encoder.block.0.layer.0.SelfAttention.o.weight model.encoder.block.0.layer.0.SelfAttention.o.weight
model.encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight model.encoder.block.0.layer.0.SelfAttention.relative_attention_bias.weight
model.encoder.block.0.layer.0.layer_norm.weight model.encoder.block.0.layer.0.layer_norm.weight
model.encoder.block.0.layer.1.DenseReluDense.wi_0.weight model.encoder.block.0.layer.1.DenseReluDense.wi_0.weight
model.encoder.block.0.layer.1.DenseReluDense.wi_1.weight model.encoder.block.0.layer.1.DenseReluDense.wi_1.w

In [13]:
model.load_state_dict(weights)

<All keys matched successfully>

In [14]:
from utils.copied_utils import (
    compute_input_and_target_lengths,
    DataCollatorForT5MLM,
    tokenize_function,
    DataCollatorForNI,
)

In [15]:
before_mask_input_length, target_length = compute_input_and_target_lengths(
    inputs_length=512,
    noise_density=0.15,
    mean_noise_span_length=3.0,
)
before_mask_input_length, target_length

(568, 114)

In [16]:
from streaming.base.format.mds.encodings import Encoding, _encodings
from streaming import StreamingDataset
import torch
import numpy as np

class Int32(Encoding):
    def encode(self, obj) -> bytes:
        return obj.tobytes()

    def decode(self, data: bytes):
        return np.frombuffer(data, np.int32)


_encodings['int32'] = Int32


class DatasetFixed(torch.utils.data.Dataset):
    def __init__(self, local):
        self.dataset = StreamingDataset(local=local)

    def __getitem__(self, idx):
        data = self.dataset[idx]
        data.pop('token_type_ids', None)
        for k in data.keys():
            data[k] = data[k].astype(np.int64)
        return data

    def __len__(self):
        return len(self.dataset)

In [17]:
dataset = DatasetFixed(local='/home/ubuntu/nanot5-512')

In [18]:
data_collator = DataCollatorForT5MLM(
    tokenizer=tokenizer,
    noise_density=0.15,
    mean_noise_span_length=3.0,
    input_length=512,
    target_length=target_length,
    pad_token_id=0,
)

In [19]:
batch = [dataset[i] for i in range(5000000, 5000001, 1)]

In [20]:
padded = data_collator(batch)

In [21]:
o = model.model(**padded)

In [22]:
o.logits

tensor([[[-2.7402,  2.5614,  6.1874,  ..., -1.3941, -3.1445, -4.2745],
         [-3.2996,  1.7034,  7.6094,  ..., -0.9311, -2.4005, -2.0222],
         [-2.5788,  1.4604,  7.6094,  ..., -0.8557, -1.9514, -1.8257],
         ...,
         [-1.7656,  2.4665,  6.3454,  ..., -1.3973, -0.4826, -1.8745],
         [-1.3759,  2.9568,  7.9547,  ..., -1.9363, -2.1684, -3.5167],
         [-1.0695,  3.1586, 28.7467,  ..., -2.2941, -2.2790, -4.0657]]],
       grad_fn=<UnsafeViewBackward0>)

In [23]:
padded['labels']
tokenizer.decode(padded['labels'][0])

'<extra_id_40> kilang Swarovs<extra_id_59> seorang husband reseller tu,<extra_id_38> roma<extra_id_18> Bini makin<extra_id_46>-2022 06:28 PM Batu api sungg<extra_id_66>iden gil<extra_id_71> seseorang yang amat bermakna dalam<extra_id_67> menerima<extra_id_6> hidup,<extra_id_53> penulis<extra_id_10>wah persahabatan atau<extra_id_61> agar<extra_id_2>tah<extra_id_84> tiada jawapan<extra_id_8> Membuka pintu<extra_id_28> yang<extra_id_70> Bloom<extra_id_76>z?<extra_id_62>...tah br<extra_id_54> duit...huhuh...<extra_id_12> banyak kali ktorang putus...<extra_id_0>s<extra_id_75> nk tgu.....<extra_id_85> tapi duit<extra_id_83>gu<extra_id_52> duit<extra_id_89>....pac<extra_id_77>...mkn la</s>'

In [24]:
tokenizer.decode(o.logits.argmax(-1).detach().cpu().numpy()[0])

'<extra_id_40> B<extra_id_59><extra_id_59><extra_id_59><extra_id_59><extra_id_59>. lelaki.it..<extra_id_38>.atu<extra_id_18> akutw<extra_id_46><extra_id_46>-2022 10:41 PM S C tu<extra_id_66><extra_id_66>em<extra_id_71><extra_id_71><extra_id_71> dan<extra_id_67> tidak<extra_id_67><extra_id_67><extra_id_67> tidak<extra_id_6> tidak<extra_id_53><extra_id_53>nya<extra_id_10>awan.is<extra_id_61><extra_id_61>.<extra_id_2>ak<extra_id_84> tidak<extra_id_8> yang Berikan hati<extra_id_28> yang<extra_id_70>,twberg.<extra_id_62><extra_id_62>...<extra_id_54>2<extra_id_54>......<extra_id_12>uh<extra_id_12><extra_id_12><extra_id_12><extra_id_0>2... die......<extra_id_0>s<extra_id_75> b t<extra_id_85>...<extra_id_85>...<extra_id_83><extra_id_83>gk<extra_id_52> aku<extra_id_89>k</s><extra_id_77><extra_id_77>...</s>cm aku</s>'