
Commit

improve sentiment and translation
huseinzol05 committed Oct 12, 2022
1 parent 9391d44 commit 02264f1
Showing 7 changed files with 620 additions and 144 deletions.
167 changes: 108 additions & 59 deletions docs/load-sentiment.ipynb

Large diffs are not rendered by default.

167 changes: 108 additions & 59 deletions example/sentiment/load-sentiment.ipynb

Large diffs are not rendered by default.

14 changes: 0 additions & 14 deletions malaya/sentiment.py
@@ -51,20 +51,6 @@
'macro recall': 0.92589,
'macro f1-score': 0.92198,
},
'fastformer': {
'Size (MB)': 458,
'Quantized Size (MB)': 116,
'macro precision': 0.96882,
'macro recall': 0.96832,
'macro f1-score': 0.96836,
},
'tiny-fastformer': {
'Size (MB)': 77.3,
'Quantized Size (MB)': 19.7,
'macro precision': 0.90655,
'macro recall': 0.89819,
'macro f1-score': 0.90196,
},
}


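With the fastformer entries gone, only the remaining transformer models are reported in this table. A minimal sketch of how the table is typically consumed, assuming malaya's usual `available_transformer`/`transformer` entry points (the `'bert'` model name and the input sentence are only examples):

```
import malaya

# Show the remaining sentiment models and their macro precision/recall/F1.
print(malaya.sentiment.available_transformer())

# Load one of the remaining models; 'bert' is only an illustrative choice.
model = malaya.sentiment.transformer(model='bert', quantized=False)
print(model.predict(['ini sangat bagus']))
```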
5 changes: 1 addition & 4 deletions malaya/supervised/huggingface.py
@@ -21,7 +21,4 @@ def load_automodel(model, model_class, huggingface_class=None, **kwargs):


def load_generator(model, initial_text, **kwargs):
model_ = Generator.from_pretrained(model)
model_.load_tokenizer(model)
model_._initial_text = initial_text
return model_
return Generator(model, initial_text)
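The loader now builds the wrapper in a single step instead of mutating a pretrained instance. A minimal sketch of the new call path; the checkpoint name and prompt prefix below are placeholders, not from this commit:

```
from malaya.supervised.huggingface import load_generator

# load_generator simply returns Generator(model, initial_text);
# the checkpoint name and prefix here are illustrative placeholders.
generator = load_generator(
    'some-org/some-seq2seq-checkpoint',
    initial_text='terjemah Inggeris ke Melayu: ',
)
```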
18 changes: 10 additions & 8 deletions malaya/torch_model/huggingface.py
@@ -1,16 +1,18 @@
from transformers import T5Tokenizer, T5ForConditionalGeneration
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from malaya_boilerplate.torch_utils import to_tensor_cuda, to_numpy
from herpetologist import check_type
from typing import List


class Generator(T5ForConditionalGeneration):
def __init__(self, config):
super().__init__(config)
class Generator:
def __init__(self, model, initial_text, use_fast_tokenizer=False):
self.tokenizer = AutoTokenizer.from_pretrained(model, use_fast=use_fast_tokenizer)
self.model = AutoModelForSeq2SeqLM.from_pretrained(model)
self._initial_text = initial_text

def load_tokenizer(self, model_name=None, use_fast_tokenizer=False):
self.tokenizer = T5Tokenizer.from_pretrained(
model_name or self.config._name_or_path, use_fast=use_fast_tokenizer)
def cuda(self):
self.model.cuda()

@check_type
def generate(self, strings: List[str], **kwargs):
@@ -27,13 +29,13 @@ def generate(self, strings: List[str], **kwargs):
-------
result: List[str]
"""
cuda = next(self.parameters()).is_cuda
cuda = next(self.model.parameters()).is_cuda
input_ids = [{'input_ids': self.tokenizer.encode(f'{self._initial_text}{s}', return_tensors='pt')[
0]} for s in strings]
padded = self.tokenizer.pad(input_ids, padding='longest')
for k in padded.keys():
padded[k] = to_tensor_cuda(padded[k], cuda)
outputs = super().generate(**padded, **kwargs)
outputs = self.model.generate(**padded, **kwargs)
results = []
for o in outputs:
results.append(self.tokenizer.decode(o, skip_special_tokens=True))
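For reference, a minimal usage sketch of the refactored wrapper; the checkpoint name and prompt prefix are placeholders, and extra keyword arguments are forwarded to the underlying `model.generate`:

```
from malaya.torch_model.huggingface import Generator

# Placeholder checkpoint and prefix, for illustration only.
generator = Generator(
    'some-org/some-t5-checkpoint',
    initial_text='terjemah Inggeris ke Melayu: ',
)

# kwargs such as max_length are passed through to model.generate.
print(generator.generate(['Hello, how are you?'], max_length=256))
```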
73 changes: 73 additions & 0 deletions pretrained-model/long-t5/README.md
@@ -0,0 +1,73 @@
# Long-T5-Bahasa

Thanks to Google for open-sourcing most of the source code used to develop Long T5, https://github.com/google-research/longt5, and to HuggingFace for translating it to PyTorch, https://huggingface.co/docs/transformers/model_doc/longt5.

**This directory has very few comments; understanding Tensorflow, Tensorflow Estimator and Tensorflow Dataset is really helpful**.

## Objective

1. Provide Long T5 for Bahasa.

## Acknowledgement

Thanks to [Mesolitica](https://mesolitica.com/) for sponsoring GPU clouds to train Long T5 for Bahasa.

## How-to

TGLOBAL BASE model,
```
WANDB_DISABLED=true \
python3 run_t5.py \
--model_name_or_path ./ms-long-t5-tglobal-base \
--num_train_epochs 10 \
--logging_steps 100 \
--eval_steps 10000 \
--save_steps 10000 \
--evaluation_strategy steps \
--save_total_limit 10 \
--do_train \
--do_eval \
--source_lang src \
--target_lang tgt \
--train_file train-longer.json \
--validation_file test-longer.json \
--output_dir translation-long-t5-tglobal-base \
--per_device_train_batch_size=16 \
--per_device_eval_batch_size=4 \
--predict_with_generate \
--ignore_data_skip \
--max_source_length 1024 \
--max_target_length 1024 \
--warmup_steps 100000 \
--weight_decay 0.1 \
--gradient_checkpointing true
```
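Both training runs read the same `train-longer.json` / `test-longer.json` files. A hedged sketch of a single training record, assuming `run_t5.py` follows the HuggingFace `run_translation.py` data format with `--source_lang src --target_lang tgt`:

```
import json

# One jsonlines record; the field names mirror --source_lang/--target_lang above,
# and the sentences are illustrative placeholders.
record = {'translation': {'src': 'example source text ...',
                          'tgt': 'contoh teks sasaran ...'}}
print(json.dumps(record, ensure_ascii=False))
```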

LOCAL BASE model,
```
WANDB_DISABLED=true \
python3 run_t5.py \
--model_name_or_path ./ms-long-t5-local-base \
--num_train_epochs 10 \
--logging_steps 100 \
--eval_steps 10000 \
--save_steps 10000 \
--evaluation_strategy steps \
--save_total_limit 10 \
--do_train \
--do_eval \
--source_lang src \
--target_lang tgt \
--train_file train-longer.json \
--validation_file test-longer.json \
--output_dir translation-long-t5-local-base-v2 \
--per_device_train_batch_size=16 \
--per_device_eval_batch_size=4 \
--predict_with_generate \
--ignore_data_skip \
--max_source_length 1024 \
--max_target_length 1024 \
--warmup_steps 100000 \
--weight_decay 0.1 \
--gradient_checkpointing true
```
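After training, the directory passed to `--output_dir` can be loaded back with the standard HuggingFace auto classes. A minimal inference sketch, assuming the checkpoint was exported by the run above (the input sentence is only an example):

```
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Load the exported checkpoint (directory name from --output_dir above).
tokenizer = AutoTokenizer.from_pretrained('translation-long-t5-local-base-v2')
model = AutoModelForSeq2SeqLM.from_pretrained('translation-long-t5-local-base-v2')

# Translate one example sentence; long inputs up to --max_source_length are supported.
input_ids = tokenizer('Hello, how are you?', return_tensors='pt').input_ids
outputs = model.generate(input_ids, max_length=1024)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```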
