In [1]:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

In [2]:
from tqdm import tqdm
import json
import random
import torch
import numpy as np
from transformers import AutoTokenizer, T5Config
from malaya.torch_model.t5 import T5ForSequenceClassification

  self.tok = re.compile(r'({})'.format('|'.join(pipeline)))
  self.tok = re.compile(r'({})'.format('|'.join(pipeline)))


In [3]:
config = T5Config.from_pretrained('mesolitica/nanot5-small-malaysian-cased')
config.num_labels = 2
config.vocab = ['contradiction', 'entailment']

In [4]:
model = T5ForSequenceClassification.from_pretrained('mesolitica/nanot5-small-malaysian-cased', config = config)
_ = model.cuda()

Some weights of T5ForSequenceClassification were not initialized from the model checkpoint at mesolitica/nanot5-small-malaysian-cased and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [5]:
tokenizer = AutoTokenizer.from_pretrained('mesolitica/nanot5-small-malaysian-cased')

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [6]:
trainable_parameters = [param for param in model.parameters() if param.requires_grad]
trainer = torch.optim.AdamW(trainable_parameters, lr = 2e-4)

In [7]:
train_X, train_Y = [], []
with open('shuffled-train.json') as fopen:
    for l in fopen:
        l = json.loads(l)
        train_X.append(l['src'])
        train_Y.append(l['label'])

In [8]:
test_X, test_Y = [], []
with open('shuffled-test.json') as fopen:
    for l in fopen:
        l = json.loads(l)
        test_X.append(l['src'])
        test_Y.append(l['label'])
        
len(test_X)

26768

In [9]:
batch_size = 8
epoch = 100

In [10]:
best_dev_acc = -np.inf
patient = 1
current_patient = 0

for e in range(epoch):
    pbar = tqdm(range(0, len(train_X), batch_size))
    losses = []
    for i in pbar:
        trainer.zero_grad()
        x = train_X[i: i + batch_size]
        y = np.array(train_Y[i: i + batch_size])
        
        padded = tokenizer(x, truncation = True, padding = True, return_tensors = 'pt', max_length = 1024)
        padded['labels'] = torch.from_numpy(y)
        for k in padded.keys():
            padded[k] = padded[k].cuda()
        
        padded.pop('token_type_ids', None)
            
        loss, pred = model(**padded)
        loss.backward()
        
        grad_norm = torch.nn.utils.clip_grad_norm_(trainable_parameters, 1.0)
        trainer.step()
        losses.append(float(loss))
        
        
    dev_predicted = []
    for i in range(0, len(test_X[:10000]), batch_size):
        x = test_X[i: i + batch_size]
        y = np.array(test_Y[i: i + batch_size])
        padded = tokenizer(x, truncation = True, padding = True, return_tensors = 'pt', max_length = 1024)
        padded['labels'] = torch.from_numpy(y)
        for k in padded.keys():
            padded[k] = padded[k].cuda()
            
        padded.pop('token_type_ids', None)
        
        loss, pred = model(**padded)
        dev_predicted.append((pred.argmax(axis = 1).detach().cpu().numpy() == y).mean())
        
    dev_predicted = np.mean(dev_predicted)
    
    print(f'epoch: {e}, loss: {np.mean(losses)}, dev_predicted: {dev_predicted}')
    
    if dev_predicted >= best_dev_acc:
        best_dev_acc = dev_predicted
        current_patient = 0
        model.save_pretrained('small')
    else:
        current_patient += 1
    
    if current_patient >= patient:
        break

100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 131433/131433 [1:01:54<00:00, 35.39it/s]


epoch: 0, loss: 0.3635921087945055, dev_predicted: 0.861


100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 131433/131433 [1:02:03<00:00, 35.30it/s]


epoch: 1, loss: 0.2836592824724388, dev_predicted: 0.8671


100%|███████████████████████████████████████████████████████████████████████████████████████████████████| 131433/131433 [1:01:59<00:00, 35.33it/s]


epoch: 2, loss: 0.2419077000522983, dev_predicted: 0.8599


In [11]:
real_Y = []
for i in tqdm(range(0, len(test_X), batch_size)):
    x = test_X[i: i + batch_size]
    y = np.array(test_Y[i: i + batch_size])
    padded = tokenizer(x, padding = 'longest', return_tensors = 'pt')
    padded['labels'] = torch.from_numpy(y)
    for k in padded.keys():
        padded[k] = padded[k].cuda()
        
    padded.pop('token_type_ids', None)

    loss, pred = model(**padded)
    real_Y.extend(pred.argmax(axis = 1).detach().cpu().numpy().tolist())

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 3346/3346 [00:23<00:00, 143.90it/s]


In [12]:
from sklearn import metrics

print(
    metrics.classification_report(
        real_Y, test_Y,
        digits = 5
    )
)

              precision    recall  f1-score   support

           0    0.86010   0.84621   0.85310     13063
           1    0.85563   0.86881   0.86217     13705

    accuracy                        0.85778     26768
   macro avg    0.85787   0.85751   0.85763     26768
weighted avg    0.85781   0.85778   0.85774     26768



In [13]:
tokenizer.model_input_names = ['input_ids', 'attention_mask']

In [14]:
tokenizer.push_to_hub('mesolitica/finetune-mnli-nanot5-small', safe_serialization = True)

CommitInfo(commit_url='https://huggingface.co/mesolitica/finetune-mnli-nanot5-small/commit/46570f0fa6956806acd089bce96045df7642cac0', commit_message='Upload tokenizer', commit_description='', oid='46570f0fa6956806acd089bce96045df7642cac0', pr_url=None, pr_revision=None, pr_num=None)

In [15]:
model.push_to_hub('mesolitica/finetune-mnli-nanot5-small', safe_serialization = True)

model.safetensors:   0%|          | 0.00/148M [00:00<?, ?B/s]

CommitInfo(commit_url='https://huggingface.co/mesolitica/finetune-mnli-nanot5-small/commit/5d13fa686c2ff9e158d75413e77bc2990dfc4ffa', commit_message='Upload T5ForSequenceClassification', commit_description='', oid='5d13fa686c2ff9e158d75413e77bc2990dfc4ffa', pr_url=None, pr_revision=None, pr_num=None)