In [1]:
import os

# An empty CUDA_VISIBLE_DEVICES hides every GPU from CUDA, so the rest of the
# notebook runs on CPU. This is set before torch is imported in the next cell.
os.environ['CUDA_VISIBLE_DEVICES'] = ''

In [2]:
from typing import List, Optional, Tuple, Union
from torch.nn import CrossEntropyLoss
from torch import nn
import torch

class BartClassificationHead(nn.Module):
    """Classification head: dropout -> linear -> tanh -> dropout -> linear.

    The linear layers act on the last dimension of the input, so the head
    yields one vector of ``num_classes`` scores per input vector.
    """

    def __init__(
        self,
        input_dim: int,
        inner_dim: int,
        num_classes: int,
        pooler_dropout: float,
    ):
        """Create the layers.

        Args:
            input_dim: size of the last dimension of the incoming features.
            inner_dim: width of the hidden (tanh) layer.
            num_classes: number of output scores per input vector.
            pooler_dropout: dropout probability, shared by both dropout steps.
        """
        super().__init__()
        self.dense = nn.Linear(input_dim, inner_dim)
        self.dropout = nn.Dropout(p=pooler_dropout)
        self.out_proj = nn.Linear(inner_dim, num_classes)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return unnormalised class scores for ``hidden_states``."""
        # The same dropout module is applied before each linear layer.
        projected = self.dense(self.dropout(hidden_states))
        activated = self.dropout(torch.tanh(projected))
        return self.out_proj(activated)

In [3]:
from transformers import T5Tokenizer, T5ForConditionalGeneration, T5Config

# Load the tokenizer used for every encode/decode call in this notebook.
# NOTE(review): this is the 't5-small' tokenizer repo, while the checkpoints loaded
# further down live under 'finetune-t5-tiny-...'. Presumably both sizes share the
# same vocabulary; confirm.
tokenizer = T5Tokenizer.from_pretrained('mesolitica/t5-small-standard-bahasa-cased')

In [4]:
class T5Tagging(T5ForConditionalGeneration):
    """T5 seq2seq model with an extra per-position tagging head on the decoder.

    The tagging head (``classification_head``) is applied to the last entry of
    ``decoder_hidden_states``. Its logits are only used to compute an additional
    cross-entropy loss; they are not added to the returned output object, so
    callers that need tag predictions must run ``classification_head`` on the
    decoder hidden states themselves.
    """

    def __init__(self, config: T5Config, **kwargs):
        """Build the base T5 model and attach the tagging head.

        Args:
            config: T5 configuration; ``d_model``, ``num_labels`` and
                ``dropout_rate`` size the tagging head.
            **kwargs: forwarded unchanged to the parent constructor.
        """
        super().__init__(config, **kwargs)
        self.classification_head = BartClassificationHead(
            config.d_model,
            config.d_model,
            config.num_labels,
            config.dropout_rate,
        )
        # Initialise the new linear layers with the model's own init scheme.
        self._init_weights(self.classification_head.dense)
        self._init_weights(self.classification_head.out_proj)

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.BoolTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        decoder_head_mask: Optional[torch.FloatTensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        labels_tag = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Run the base T5 forward pass and optionally add the tagging loss.

        All arguments except ``labels_tag`` are passed straight to the parent
        ``forward``. ``output_hidden_states`` and ``return_dict`` are accepted
        for signature compatibility but always forced to ``True``, because the
        tagging head needs the decoder hidden states.

        Args:
            labels_tag: tag ids for the tagging loss. It is flattened with
                ``view(-1)`` and compared against the head logits flattened to
                ``(-1, num_labels)``. Presumably one id per decoder position;
                confirm against the training data collator.

        Returns:
            The parent model's output object. When ``labels_tag`` is given,
            ``loss`` is the seq2seq loss plus the tagging loss, or the tagging
            loss alone when ``labels`` was not supplied.
        """
        outputs = super().forward(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            encoder_outputs=encoder_outputs,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            labels=labels,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=True,
            return_dict=True,
        )

        if labels_tag is not None:
            # The head output is only needed for the loss, so it is skipped on
            # plain inference / generate() steps where no tag labels are given.
            last_layer = outputs.decoder_hidden_states[-1]
            logits = self.classification_head(last_layer)
            loss_fct = CrossEntropyLoss()
            tag_loss = loss_fct(
                logits.view(-1, self.config.num_labels), labels_tag.view(-1)
            )
            if outputs.loss is None:
                # No seq2seq labels were given, so there is no base loss to add
                # to. Item assignment registers the new 'loss' key on the
                # output object (None fields are not stored as keys).
                outputs['loss'] = tag_loss
            else:
                outputs.loss = outputs.loss + tag_loss

        return outputs

In [5]:
from glob import glob

def checkpoint_step(path):
    """Sort key for 'checkpoint-<step>' paths: order by numeric step, not by text.

    Plain string sorting would put 'checkpoint-1000000' before
    'checkpoint-610000'. Paths whose suffix is not a number sort first (step -1),
    so they can never be picked up as the latest checkpoint.
    """
    suffix = path.rsplit('-', 1)[-1]
    step = int(suffix) if suffix.isdigit() else -1
    return (step, path)


# Checkpoint directories in ascending training-step order; checkpoints[-1] is the latest.
checkpoints = sorted(
    glob('finetune-t5-tiny-standard-bahasa-cased/checkpoint-*'),
    key=checkpoint_step,
)
checkpoints

['finetune-t5-tiny-standard-bahasa-cased/checkpoint-610000',
 'finetune-t5-tiny-standard-bahasa-cased/checkpoint-620000',
 'finetune-t5-tiny-standard-bahasa-cased/checkpoint-630000']

In [6]:
# Instantiate the tagging model from the last path in `checkpoints`.
model = T5Tagging.from_pretrained(checkpoints[-1])

In [7]:
import json

# Read the test set: one JSON document per line of test.jsonl.
data = []
# JSON text is UTF-8; passing the encoding explicitly keeps the non-ASCII content
# readable on platforms whose default encoding is something else.
with open('test.jsonl', encoding='utf-8') as fopen:
    for l in fopen:
        # Skip blank lines (e.g. a trailing newline at end of file), which
        # json.loads would reject.
        if not l.strip():
            continue
        data.append(json.loads(l))

In [8]:
len(data)

4994

In [9]:
# Build one sample input sentence: join the first field of every item in
# data[1][1] with single spaces.
x = ' '.join([d[0] for d in data[1][1]])
x

'Cognocoli-Monticchi ialah komun bak jabatan Corse-du-Sud di kepulauan Corsica ; Perancis ?'

In [10]:
# Encode the sample sentence behind the task prefix, pad it as a batch of one,
# then sample three candidate outputs while keeping attentions, hidden states
# and scores so the tagging head can be run on the decoder states afterwards.
input_ids = [
    {'input_ids': tokenizer.encode(f'kesalahan tatabahasa:{x}', return_tensors='pt')[0]}
]
padded = tokenizer.pad(input_ids, padding='longest')
outputs = model.generate(
    **padded,
    do_sample=True,
    max_length=50,
    top_k=0,
    num_return_sequences=3,
    output_attentions=True,
    output_hidden_states=True,
    output_scores=True,
    return_dict_in_generate=True,
)

In [11]:
outputs.keys()

odict_keys(['sequences', 'scores', 'encoder_attentions', 'encoder_hidden_states', 'decoder_attentions', 'cross_attentions', 'decoder_hidden_states'])

In [12]:
# Decode every sampled sequence to text, dropping the first position of each
# (presumably the decoder start token; confirm).
tokenizer.batch_decode(outputs.sequences[:,1:], )

['Cognocoli-Monticchi ialah komun di jabatan Corse-du-Sud di kepulauan Corsica, Perancis.</s>',
 'Cognocoli-Monticchi ialah komun di jabatan Corse-du-Sud di kepulauan Corsica, Perancis.</s>',
 'Cognocoli-Monticchi ialah komun di jabatan Corse-du-Sud di kepulauan Corsica, Perancis.</s>']

In [13]:
outputs.sequences.shape

torch.Size([3, 31])

In [14]:
len(outputs.decoder_hidden_states)

30

In [15]:
import torch

In [16]:
# outputs.decoder_hidden_states holds one entry per generated step (30 here).
# o[-1] takes the last element of each step's entry (presumably the final
# decoder layer; confirm). Stacking puts the steps on the first axis, and
# [:, :, 0] then selects index 0 along the third axis.
last_layer = torch.stack([o[-1] for o in outputs.decoder_hidden_states])[:,:,0]
# Swap the first two axes so the result is (sequences, steps, hidden): (3, 30, 384) here.
last_layer = last_layer.transpose(0, 1)
last_layer.shape

torch.Size([3, 30, 384])

In [17]:
# Run the tagging head on the stacked decoder states and take the argmax over
# the last axis: one predicted tag id per (sequence, step).
t = model.classification_head(last_layer).detach().numpy().argmax(axis = -1)
t.shape

(3, 30)

In [18]:
from malaya.text.bpe import merge_sentencepiece_tokens_tagging

  self.tok = re.compile(r'({})'.format('|'.join(pipeline)))
  self.tok = re.compile(r'({})'.format('|'.join(pipeline)))


In [19]:
len(outputs.sequences)

3

In [20]:
# Take the first sampled sequence without its leading position (so its length
# matches the 30 decoder steps) and map the ids back to token strings.
s = outputs.sequences[:,1:][0].detach().numpy()
s = tokenizer.convert_ids_to_tokens(s)

In [21]:
len(s), len(t)

(30, 3)

In [22]:
from malaya.text.bpe import SPECIAL_TOKENS

In [23]:
tokenizer.tokenize('saya suka')

['▁saya', '▁suka']

In [24]:
def merge_sentencepiece_tokens_tagging(x, y, model='bert', rejected = None, **kwargs):
    """Merge SentencePiece pieces back into words and carry one label per word.

    A piece that does not start with '▁' and is not in ``rejected`` is treated
    as the continuation of the previous entry and glued onto it. The merged
    word keeps the label of its first piece.

    Args:
        x: sequence of tokens (``str`` or ``bytes``).
        y: sequence of labels aligned with ``x`` by index.
        model: key into ``SPECIAL_TOKENS``; only used when ``rejected`` is None.
        rejected: tokens that never merge into a word and are dropped from the
            result. Defaults to ``SPECIAL_TOKENS[model].values()``.
        **kwargs: accepted and ignored.

    Returns:
        ``(words, labels)``: two lists of equal length, with the '▁' markers
        removed from the words.
    """
    if rejected is None:
        rejected = list(SPECIAL_TOKENS[model].values())

    def as_text(token):
        # Tokens may arrive as bytes; every comparison below is done on str.
        return token.decode() if isinstance(token, bytes) else token

    def is_continuation(token):
        return not token.startswith('▁') and token not in rejected

    new_paired_tokens = []
    n_tokens = len(x)
    i = 0

    while i < n_tokens:
        current_token, current_label = as_text(x[i]), y[i]

        # The very first token has nothing to attach to, so it is kept as-is.
        if i > 0 and is_continuation(current_token):
            merged_token, merged_label = new_paired_tokens.pop()
            # Absorb the whole run of continuation pieces. The bounds check lets
            # a sequence end on a continuation piece (e.g. output truncated at
            # max_length, with no closing special token) without running past
            # the end of x.
            while i < n_tokens:
                current_token = as_text(x[i])
                if not is_continuation(current_token):
                    break
                merged_token = merged_token + current_token.replace('▁', '')
                i = i + 1
            new_paired_tokens.append((merged_token, merged_label))
        else:
            new_paired_tokens.append((current_token, current_label))
            i = i + 1

    kept = [pair for pair in new_paired_tokens if pair[0] not in rejected]
    words = [token.replace('▁', '') for token, _ in kept]
    labels = [label for _, label in kept]
    return words, labels

In [25]:
# Merge the sub-word pieces of the first sampled sequence into words, pairing each
# word with a predicted tag id; the tokenizer's special tokens are dropped.
merged = merge_sentencepiece_tokens_tagging(
    s, t[0], rejected = tokenizer.all_special_tokens
)
# Show (word, tag id) pairs.
list(zip(merged[0], merged[1]))

[('Cognocoli-Monticchi', 2),
 ('ialah', 2),
 ('komun', 2),
 ('di', 9),
 ('jabatan', 2),
 ('Corse-du-Sud', 2),
 ('di', 2),
 ('kepulauan', 2),
 ('Corsica', 2),
 (',', 14),
 ('Perancis', 2),
 ('.', 14)]

In [27]:
# Side effect: uploads the model weights to the Hugging Face Hub under the
# 'mesolitica' organization. Re-running this cell pushes again.
model.push_to_hub('finetune-tatabahasa-t5-tiny-standard-bahasa-cased', organization='mesolitica')

Cloning https://huggingface.co/mesolitica/finetune-tatabahasa-t5-tiny-standard-bahasa-cased into local empty directory.


Upload file pytorch_model.bin:   0%|          | 4.00k/133M [00:00<?, ?B/s]

remote: Scanning LFS files for validity, may be slow...        
remote: LFS file scan complete.        
To https://huggingface.co/mesolitica/finetune-tatabahasa-t5-tiny-standard-bahasa-cased
   09cbd5e..2b53521  main -> main



'https://huggingface.co/mesolitica/finetune-tatabahasa-t5-tiny-standard-bahasa-cased/commit/2b53521b65ca1b07e531f5b109d39c56cb35d00b'

In [28]:
# Side effect: uploads the tokenizer files to the same Hub repository as the model.
tokenizer.push_to_hub('finetune-tatabahasa-t5-tiny-standard-bahasa-cased', organization='mesolitica')

Upload file spiece.model:   1%|          | 4.00k/784k [00:00<?, ?B/s]

remote: Scanning LFS files for validity, may be slow...        
remote: LFS file scan complete.        
To https://huggingface.co/mesolitica/finetune-tatabahasa-t5-tiny-standard-bahasa-cased
   2b53521..632337e  main -> main



'https://huggingface.co/mesolitica/finetune-tatabahasa-t5-tiny-standard-bahasa-cased/commit/632337e0eed28996c5aa1e30ed4fff1fd423a9ff'

In [52]:
import collections

def compute_exact(a_gold, a_pred):
    """Return 1 when the prediction equals the reference, otherwise 0."""
    is_match = a_gold == a_pred
    return int(is_match)


def compute_f1(a_gold, a_pred):
    """Bag-of-items F1 between a reference sequence and a predicted sequence.

    Items are matched as multisets, so order is ignored and duplicates count.
    When either side is empty the score is 1 if both are equal, else 0.
    """
    # Multiset intersection: each item counts as often as it occurs on both sides.
    overlap = collections.Counter(a_gold) & collections.Counter(a_pred)
    num_same = sum(overlap.values())

    n_gold, n_pred = len(a_gold), len(a_pred)
    if n_gold == 0 or n_pred == 0:
        return int(a_gold == a_pred)
    if num_same == 0:
        return 0

    precision = num_same / n_pred
    recall = num_same / n_gold
    return (2 * precision * recall) / (precision + recall)

In [61]:
len(data)

4994

In [73]:
from tqdm import tqdm

# Per-example scores over the test set: exact match and F1 for the generated
# words, and the same two metrics for the predicted tag ids.
exact_match = []
f1 = []
exact_match_tag = []
f1_tag = []

for i in tqdm(range(len(data))):
    # data[i][1] supplies the input words and the reference tags; data[i][0]
    # supplies the reference words.
    x = ' '.join(d[0] for d in data[i][1])
    y = [d[0] for d in data[i][0]]
    tag = [d[1] for d in data[i][1]]

    # Encode behind the task prefix and generate with the default decoding
    # strategy, keeping the decoder hidden states for the tagging head.
    input_ids = [
        {'input_ids': tokenizer.encode(f'kesalahan tatabahasa:{x}', return_tensors='pt')[0]}
    ]
    padded = tokenizer.pad(input_ids, padding='longest')
    outputs = model.generate(
        **padded,
        max_length=256,
        output_attentions=True,
        output_hidden_states=True,
        output_scores=True,
        return_dict_in_generate=True,
    )

    # Stack the last hidden-state element of every step, then put the sequence
    # axis first before running the tagging head.
    last_layer = torch.stack([o[-1] for o in outputs.decoder_hidden_states])[:, :, 0]
    last_layer = last_layer.transpose(0, 1)
    t = model.classification_head(last_layer).detach().numpy().argmax(axis=-1)

    # Generated token strings, without the leading position of the sequence.
    s = tokenizer.convert_ids_to_tokens(outputs.sequences[:, 1:][0].detach().numpy())
    merged = merge_sentencepiece_tokens_tagging(
        s, t[0], rejected=tokenizer.all_special_tokens
    )

    exact_match.append(compute_exact(y, merged[0]))
    exact_match_tag.append(compute_exact(tag, merged[1]))
    f1.append(compute_f1(y, merged[0]))
    f1_tag.append(compute_f1(tag, merged[1]))

100%|████████████████████████████████████████████████████████████████████████████████████████████| 4994/4994 [09:16<00:00,  8.97it/s]


In [70]:
import numpy as np

In [74]:
# Mean exact-match over the test set: (generated words, predicted tags).
np.mean(exact_match), np.mean(exact_match_tag)

(0.7665198237885462, 0.8740488586303564)

In [75]:
# Mean F1 over the test set: (generated words, predicted tags).
np.mean(f1), np.mean(f1_tag)

(0.9709082299141384, 0.9878723587004393)