In [1]:
import os

os.environ['CUDA_VISIBLE_DEVICES'] = ''

In [2]:
from glob import glob

checkpoints = sorted(glob('finetune-t5-base-standard-bahasa-cased/checkpoint-*'))
checkpoints

['finetune-t5-base-standard-bahasa-cased/checkpoint-160000',
 'finetune-t5-base-standard-bahasa-cased/checkpoint-170000',
 'finetune-t5-base-standard-bahasa-cased/checkpoint-180000',
 'finetune-t5-base-standard-bahasa-cased/checkpoint-190000',
 'finetune-t5-base-standard-bahasa-cased/checkpoint-200000']

In [3]:
from transformers import T5Tokenizer, T5ForConditionalGeneration

tokenizer = T5Tokenizer.from_pretrained('mesolitica/t5-base-standard-bahasa-cased')

In [4]:
from transformers.models.bart.modeling_bart import shift_tokens_right
from transformers.modeling_outputs import Seq2SeqSequenceClassifierOutput
from transformers import T5Model, T5Config
from typing import List, Optional, Tuple, Union
from torch import nn
import torch

In [5]:
class BartClassificationHead(nn.Module):
    """Head for sentence-level classification tasks."""

    def __init__(
        self,
        input_dim: int,
        inner_dim: int,
        num_classes: int,
        pooler_dropout: float,
    ):
        super().__init__()
        self.dense = nn.Linear(input_dim, inner_dim)
        self.dropout = nn.Dropout(p=pooler_dropout)
        self.out_proj = nn.Linear(inner_dim, num_classes)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.dense(hidden_states)
        hidden_states = torch.tanh(hidden_states)
        hidden_states = self.dropout(hidden_states)
        hidden_states = self.out_proj(hidden_states)
        return hidden_states


class T5ForSequenceClassification(T5Model):
    def __init__(self, config: T5Config, **kwargs):
        super().__init__(config, **kwargs)
        self.classification_head = BartClassificationHead(
            config.d_model,
            config.d_model,
            config.num_labels,
            config.dropout_rate,
        )
        self._init_weights(self.classification_head.dense)
        self._init_weights(self.classification_head.out_proj)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        decoder_input_ids: Optional[torch.LongTensor] = None,
        decoder_attention_mask: Optional[torch.LongTensor] = None,
        head_mask: Optional[torch.Tensor] = None,
        decoder_head_mask: Optional[torch.Tensor] = None,
        cross_attn_head_mask: Optional[torch.Tensor] = None,
        encoder_outputs: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):

        decoder_input_ids = shift_tokens_right(input_ids,
                                               self.config.pad_token_id, self.config.decoder_start_token_id)
        outputs = super().forward(
            input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            head_mask=head_mask,
            decoder_head_mask=decoder_head_mask,
            cross_attn_head_mask=cross_attn_head_mask,
            encoder_outputs=encoder_outputs,
            inputs_embeds=inputs_embeds,
            decoder_inputs_embeds=decoder_inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = outputs[0]  # last hidden state

        eos_mask = input_ids.eq(self.config.eos_token_id)

        if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
            raise ValueError("All examples must have the same number of <eos> tokens.")
        sentence_representation = hidden_states[eos_mask, :].view(hidden_states.size(0), -1, hidden_states.size(-1))[
            :, -1, :
        ]
        logits = self.classification_head(sentence_representation)

        loss = None

        if not return_dict:
            output = (logits,) + outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return Seq2SeqSequenceClassifierOutput(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            decoder_hidden_states=outputs.decoder_hidden_states,
            decoder_attentions=outputs.decoder_attentions,
            cross_attentions=outputs.cross_attentions,
            encoder_last_hidden_state=outputs.encoder_last_hidden_state,
            encoder_hidden_states=outputs.encoder_hidden_states,
            encoder_attentions=outputs.encoder_attentions,
        )

In [6]:
model = T5ForSequenceClassification.from_pretrained(checkpoints[0])

In [12]:
s1 = 'tlong order foodpanda'
s2 = 'This text is about makanan'
s = f'ayat1: {s1} ayat2: {s2}'
strings = []
strings.append(s)

input_ids = [{'input_ids': tokenizer.encode(s, return_tensors='pt')[0]} for s in strings]
padded = tokenizer.pad(input_ids, padding='longest')
outputs = model(**padded, return_dict = True)

In [13]:
entail_contradiction_logits = outputs.logits[:,[0,1]]
probs = entail_contradiction_logits.softmax(dim=1)
probs[:,1].detach().numpy()

array([0.96643597], dtype=float32)

In [14]:
s1 = 'gov macam bengong, kami nk pilihan raya, gov backdoor, sakai'
tags = ['najib razak', 'mahathir', 'kerajaan', 'PRU', 'anarki']
strings = []
for t in tags:
    s = f'ayat1: {s1} ayat2: This text is about {t}'
    strings.append(s)
input_ids = [{'input_ids': tokenizer.encode(s, return_tensors='pt')[0]} for s in strings]
padded = tokenizer.pad(input_ids, padding='longest')
outputs = model(**padded, return_dict = True)
entail_contradiction_logits = outputs.logits[:,[0,1]]
probs = entail_contradiction_logits.softmax(dim=1)
probs[:,1].detach().numpy()

array([0.21508282, 0.05425181, 0.9331222 , 0.9351567 , 0.05064638],
      dtype=float32)

In [15]:
s1 = 'kerajaan sebenarnya sangat prihatin dengan rakyat, bagi duit bantuan'
tags = ['makan', 'makanan', 'buku', 'kerajaan', 'food delivery',
                                       'kerajaan jahat', 'kerajaan prihatin', 'bantuan rakyat',
       'kerajaan islam', 'gov syg kami']
strings = []
for t in tags:
    s = f'ayat1: {s1} ayat2: ayat ini berkaitan tentang {t}'
    strings.append(s)
input_ids = [{'input_ids': tokenizer.encode(s, return_tensors='pt')[0]} for s in strings]
padded = tokenizer.pad(input_ids, padding='longest')
outputs = model(**padded, return_dict = True)
entail_contradiction_logits = outputs.logits[:,[0,1]]
probs = entail_contradiction_logits.softmax(dim=1)
probs[:,1].detach().numpy()

array([0.00112138, 0.01928574, 0.04843272, 0.9965913 , 0.08923698,
       0.02016815, 0.99809355, 0.9951243 , 0.22032733, 0.94214803],
      dtype=float32)

In [17]:
model.push_to_hub('finetune-mnli-t5-base-standard-bahasa-cased', organization='mesolitica')

Cloning https://huggingface.co/mesolitica/finetune-mnli-t5-base-standard-bahasa-cased into local empty directory.


Upload file pytorch_model.bin:   0%|          | 32.0k/853M [00:00<?, ?B/s]

remote: Scanning LFS files for validity, may be slow...        
remote: LFS file scan complete.        
To https://huggingface.co/mesolitica/finetune-mnli-t5-base-standard-bahasa-cased
   3723c0b..59411ed  main -> main



'https://huggingface.co/mesolitica/finetune-mnli-t5-base-standard-bahasa-cased/commit/59411edc43fa5c28d35e030b526835b77a3bf0ff'

In [18]:
tokenizer.push_to_hub('finetune-mnli-t5-base-standard-bahasa-cased', organization='mesolitica')

Upload file spiece.model:   4%|4         | 32.0k/784k [00:00<?, ?B/s]

remote: Scanning LFS files for validity, may be slow...        
remote: LFS file scan complete.        
To https://huggingface.co/mesolitica/finetune-mnli-t5-base-standard-bahasa-cased
   59411ed..c3b64d7  main -> main



'https://huggingface.co/mesolitica/finetune-mnli-t5-base-standard-bahasa-cased/commit/c3b64d73d79fa80c2c75fa31a69e68e93e5a1592'

In [8]:
import json
import numpy as np
from tqdm import tqdm

predicted, actual = [], []
count = 0
with open('translated-mnli-dev_matched.jsonl') as fopen:
    for l in tqdm(fopen):
        data = json.loads(l)
        if data['gold_label'] == '-':
            continue
        s = f"ayat1: {data['translate'][0]} ayat2: {data['translate'][1]}"
        input_ids = tokenizer.encode(s, return_tensors='pt')
        logits = model(input_ids = input_ids, return_dict = True).logits[0].detach().numpy()
        predicted.append(model.config.id2label[np.argmax(logits)])
        actual.append(data['gold_label'])
        
        count += 1

10000it [12:23, 13.44it/s]


In [10]:
from sklearn import metrics

print(
    metrics.classification_report(
        predicted, actual,
        digits = 5
    )
)

               precision    recall  f1-score   support

contradiction    0.79365   0.82658   0.80978      3085
   entailment    0.79534   0.82894   0.81179      3338
      neutral    0.77810   0.71639   0.74597      3392

     accuracy                        0.78930      9815
    macro avg    0.78903   0.79064   0.78918      9815
 weighted avg    0.78885   0.78930   0.78841      9815

