<a target="_blank" href="https://colab.research.google.com/github/cswamy/pytorch/blob/main/notebooks/Summarization_finetuned_mt5_amznreviews.ipynb">
  <img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/>
</a>

### **Notes**

Notebook to finetune a Mt5-small model for summarization in English and Spanish. Trained using the Amazon reviews dataset for books and digital books categories.

App: https://huggingface.co/spaces/cswamy/summarization_en_and_es_text

Resources:

Hugging face checkpoint: https://huggingface.co/google/mt5-small

Original MT5 paper: https://arxiv.org/abs/2010.11934

KDE4 dataset: https://huggingface.co/datasets/amazon_reviews_multi

Inspired by hugging face tutorial: https://huggingface.co/learn/nlp-course/chapter7/5?fw=pt

### **Setup**

In [1]:
import torch

device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

'cuda'

In [2]:
try:
  import transformers
  print("[INFO] Hugging face transformers imported successfully!")
except:
  !pip install -q transformers
  import transformers
  print("[INFO] Hugging face transformers installed and imported successfully!")

[INFO] Hugging face transformers imported successfully!


In [3]:
try:
  import datasets
  print("[INFO] Hugging face datasets imported successfully!")
except:
  !pip install -q datasets
  import datasets
  print("[INFO] Hugging face datasets installed and imported successfully!")

[INFO] Hugging face datasets imported successfully!


### **Download dataset**

In [4]:
from datasets import load_dataset

spanish_dataset = load_dataset('amazon_reviews_multi', 'es')
english_dataset = load_dataset('amazon_reviews_multi', 'en')

In [5]:
english_dataset

DatasetDict({
    train: Dataset({
        features: ['review_id', 'product_id', 'reviewer_id', 'stars', 'review_body', 'review_title', 'language', 'product_category'],
        num_rows: 200000
    })
    validation: Dataset({
        features: ['review_id', 'product_id', 'reviewer_id', 'stars', 'review_body', 'review_title', 'language', 'product_category'],
        num_rows: 5000
    })
    test: Dataset({
        features: ['review_id', 'product_id', 'reviewer_id', 'stars', 'review_body', 'review_title', 'language', 'product_category'],
        num_rows: 5000
    })
})

### **Prep datasets**

In [6]:
# Check number of reviews per product category
reviews_per_category = {}
for sample in english_dataset['train']:
  category = sample['product_category']
  if category in reviews_per_category.keys():
    reviews_per_category[category] += 1
  else:
    reviews_per_category[category] = 1

reviews_per_category

{'furniture': 2984,
 'home_improvement': 7136,
 'home': 17679,
 'wireless': 15717,
 'pc': 6401,
 'industrial_supplies': 1994,
 'kitchen': 10382,
 'apparel': 15951,
 'automotive': 7506,
 'camera': 2139,
 'lawn_and_garden': 7327,
 'watch': 761,
 'beauty': 12091,
 'pet_products': 7082,
 'drugstore': 11730,
 'electronics': 6186,
 'toy': 8745,
 'digital_ebook_purchase': 6749,
 'book': 3756,
 'jewelry': 2747,
 'sports': 8277,
 'other': 13418,
 'baby_product': 3150,
 'video_games': 775,
 'office_product': 5521,
 'grocery': 4730,
 'digital_video_download': 1364,
 'luggage': 1328,
 'shoes': 5197,
 'musical_instruments': 1102,
 'personal_care_appliances': 75}

In [7]:
# Define function to filter books and digital ebook categories only
def filter_books(example):
  return(
      example['product_category'] == 'book' or
      example['product_category'] == 'digital_ebook_purchase'
  )

In [8]:
# Get data subsets
english_books = english_dataset.filter(filter_books)
spanish_books = spanish_dataset.filter(filter_books)

In [9]:
english_books

DatasetDict({
    train: Dataset({
        features: ['review_id', 'product_id', 'reviewer_id', 'stars', 'review_body', 'review_title', 'language', 'product_category'],
        num_rows: 10505
    })
    validation: Dataset({
        features: ['review_id', 'product_id', 'reviewer_id', 'stars', 'review_body', 'review_title', 'language', 'product_category'],
        num_rows: 231
    })
    test: Dataset({
        features: ['review_id', 'product_id', 'reviewer_id', 'stars', 'review_body', 'review_title', 'language', 'product_category'],
        num_rows: 278
    })
})

In [10]:
# Concatenate datasets into one DatasetDict
from datasets import concatenate_datasets, DatasetDict

books_dataset = DatasetDict()

for split in english_books.keys():
  books_dataset[split] = concatenate_datasets(
      [english_books[split], spanish_books[split]]
  )
  books_dataset[split] = books_dataset[split].shuffle(seed=42)

books_dataset['train'][:5]['review_body']

["I gave it 4 stars because I felt the beginning chapter was a bit uneven, but as I got into the story, it gripped me. The characters, location, and the twists kept me reading well into the night. After posting this review, I'm going back to purchase the next book in the series.",
 'Good book to read for business students',
 'Enjoyable read will more than likely purchase other books by this author.',
 'No está mal, pero más para niños de 9 años',
 'I grew up reading Koontz, and years ago, I stopped,convinced i had "outgrown" him. Still,when a friend was looking for something suspenseful too read, I suggested Koontz. She found Strangers. The excitement art how good it was startled me. I was sure i had recommended something else. I ordered a copy for myself. This was a great reintroduction to an old favorite writer -- a novel full of fully developed characters that are totally relatable. People you could care about with full back stories. A mystery that had to be solved. I expected it to

In [11]:
# Filter out reviews with less than 2 words in title
books_dataset = books_dataset.filter(lambda x: len(x['review_title'].split()) > 2)
books_dataset

DatasetDict({
    train: Dataset({
        features: ['review_id', 'product_id', 'reviewer_id', 'stars', 'review_body', 'review_title', 'language', 'product_category'],
        num_rows: 9672
    })
    validation: Dataset({
        features: ['review_id', 'product_id', 'reviewer_id', 'stars', 'review_body', 'review_title', 'language', 'product_category'],
        num_rows: 238
    })
    test: Dataset({
        features: ['review_id', 'product_id', 'reviewer_id', 'stars', 'review_body', 'review_title', 'language', 'product_category'],
        num_rows: 245
    })
})

### **Tokenize data**

In [12]:
try:
  import sentencepiece
  print("[INFO] Sentencepiece imported successfully!")
except:
  !pip install -q sentencepiece
  import sentencepiece
  print("[INFO] Sentencepiece installed and imported successfully!")

[INFO] Sentencepiece imported successfully!


In [13]:
from transformers import AutoTokenizer

checkpoint = 'google/mt5-small'
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. If you see this, DO NOT PANIC! This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565


In [14]:
# Define tokenization function
def tokenize_func(examples):
  max_input_length = 512
  max_target_length = 30

  model_inputs = tokenizer(examples['review_body'],
                           max_length=max_input_length,
                           truncation=True)

  labels = tokenizer(examples['review_title'],
                     max_length=max_target_length,
                     truncation=True)
  model_inputs['labels'] = labels['input_ids']
  return model_inputs

In [15]:
tokenized_datasets = books_dataset.map(tokenize_func,
                                       batched=True,
                                       remove_columns=books_dataset["train"].column_names)

tokenized_datasets.set_format('torch')

Map:   0%|          | 0/9672 [00:00<?, ? examples/s]

Map:   0%|          | 0/238 [00:00<?, ? examples/s]

Map:   0%|          | 0/245 [00:00<?, ? examples/s]

In [16]:
tokenized_datasets

DatasetDict({
    train: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 9672
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 238
    })
    test: Dataset({
        features: ['input_ids', 'attention_mask', 'labels'],
        num_rows: 245
    })
})

In [17]:
tokenized_datasets['train'][:1]

{'input_ids': tensor([[  653,  1957,  1314,   261,  2757,  1280,   435,   259, 29166,   263,
            269,   774,  5547,     1]]),
 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]),
 'labels': tensor([[ 298,  259, 5994,  269,  774, 5547,    1]])}

### **Setup ROUGE metric**

In [18]:
try:
  from torchmetrics.text import ROUGEScore
  rouge_fn = ROUGEScore()
  print(f"[INFO] ROUGE metric setup completed!")
except:
  !pip install -q torchmetrics
  from torchmetrics.text import ROUGEScore
  rouge_fn = ROUGEScore()
  print(f"[INFO] ROUGE metric installed and setup completed!")

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/764.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m122.9/764.8 kB[0m [31m3.5 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━[0m [32m481.3/764.8 kB[0m [31m7.0 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m757.8/764.8 kB[0m [31m8.1 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m764.8/764.8 kB[0m [31m7.0 MB/s[0m eta [36m0:00:00[0m
[?25h[INFO] ROUGE metric installed and setup completed!


### **Setup lead-3 baseline**

In [19]:
# Extract first 3 sentence from review body using nltk
import nltk
from nltk.tokenize import sent_tokenize

nltk.download('punkt')

def three_sentence_summary(text):
    return type("\n".join(sent_tokenize(text)[:3]))

print(three_sentence_summary(books_dataset["train"][1]["review_body"]))

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Unzipping tokenizers/punkt.zip.


<class 'str'>


In [20]:
# Setup function to create baseline
import nltk
from nltk.tokenize import sent_tokenize
nltk.download('punkt')

def evaluate_baseline(dataset,
                      rouge_fn=rouge_fn):
  """
  Function to create a ROUGE lead-3 baseline
  """
  summaries = ["\n".join(sent_tokenize(text)[:3]) for text in dataset['review_body']]
  all_scores = rouge_fn(preds=summaries, target=dataset['review_title'])
  f_measures = {k: round(v.item()*100, 2) for k, v in all_scores.items() if k.endswith('fmeasure')}
  return f_measures

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [21]:
scores = evaluate_baseline(books_dataset['validation'], rouge_fn)
scores

{'rouge1_fmeasure': 16.82,
 'rouge2_fmeasure': 8.89,
 'rougeL_fmeasure': 15.57,
 'rougeLsum_fmeasure': 16.0}

### **Setup dataloaders**

In [22]:
# Setup data collator
from transformers import AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq

tmp_model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer,
                                       model=tmp_model)

Downloading pytorch_model.bin:   0%|          | 0.00/1.20G [00:00<?, ?B/s]

Downloading (…)neration_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

In [23]:
from torch.utils.data import DataLoader

BATCH_SIZE = 8

train_dataloader = DataLoader(dataset=tokenized_datasets['train'],
                              batch_size=BATCH_SIZE,
                              shuffle=True,
                              collate_fn=data_collator)

val_dataloader = DataLoader(dataset=tokenized_datasets['validation'],
                            batch_size=BATCH_SIZE,
                            shuffle=False,
                            collate_fn=data_collator)

test_dataloader = DataLoader(dataset=tokenized_datasets['test'],
                             batch_size=BATCH_SIZE,
                             shuffle=False,
                             collate_fn=data_collator)

len(train_dataloader), len(val_dataloader), len(test_dataloader)

(1209, 30, 31)

### **Train and eval**

#### Setup training

In [24]:
from transformers import AutoModelForSeq2SeqLM

EPOCHS = 8
LEARNING_RATE = 2e-5

# Setup model on device
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint).to(device)

# Setup optimiser
optimiser = torch.optim.AdamW(params=model.parameters(),
                              lr=LEARNING_RATE)

# Setup scheduler
lr_scheduler = torch.optim.lr_scheduler.LinearLR(optimizer=optimiser)

#### Explore model with torchinfo

In [25]:
try:
  from torchinfo import summary
except:
  !pip install -q torchinfo
  from torchinfo import summary

summary(model)

Layer (type:depth-idx)                                       Param #
MT5ForConditionalGeneration                                  --
├─Embedding: 1-1                                             128,057,344
├─MT5Stack: 1-2                                              128,057,344
│    └─Embedding: 2-1                                        (recursive)
│    └─ModuleList: 2-2                                       --
│    │    └─MT5Block: 3-1                                    2,360,512
│    │    └─MT5Block: 3-2                                    2,360,320
│    │    └─MT5Block: 3-3                                    2,360,320
│    │    └─MT5Block: 3-4                                    2,360,320
│    │    └─MT5Block: 3-5                                    2,360,320
│    │    └─MT5Block: 3-6                                    2,360,320
│    │    └─MT5Block: 3-7                                    2,360,320
│    │    └─MT5Block: 3-8                                    2,360,320
│    └─MT5LayerN

#### Train loop

In [26]:
from tqdm.auto import tqdm

train_loss = 0
for epoch in tqdm(range(EPOCHS)):
  for batch_num, batch in enumerate(train_dataloader):
    # Send data to device
    batch = {k: v.to(device) for k, v in batch.items()}

    # Forward pass
    outputs = model(**batch)

    # Accumulate loss
    loss = outputs.loss
    train_loss += loss

    # Zero gradients
    optimiser.zero_grad()

    # Backpropagate loss
    loss.backward()

    # Step optimiser and scheduler
    optimiser.step()
    lr_scheduler.step()

    # Track where training is
    if batch_num % 100 == 0:
      print(f"Epoch: {epoch+1} | Batch: {batch_num}")

  # Average loss across batches
  train_loss /= len(train_dataloader)

  # Print loss
  print(f"Epoch: {epoch+1} | Training loss: {train_loss:.4f}")

  0%|          | 0/8 [00:00<?, ?it/s]

You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch: 1 | Batch: 0
Epoch: 1 | Batch: 100
Epoch: 1 | Batch: 200
Epoch: 1 | Batch: 300
Epoch: 1 | Batch: 400
Epoch: 1 | Batch: 500
Epoch: 1 | Batch: 600
Epoch: 1 | Batch: 700
Epoch: 1 | Batch: 800
Epoch: 1 | Batch: 900
Epoch: 1 | Batch: 1000
Epoch: 1 | Batch: 1100
Epoch: 1 | Batch: 1200
Epoch: 1 | Training loss: 6.8137
Epoch: 2 | Batch: 0
Epoch: 2 | Batch: 100
Epoch: 2 | Batch: 200
Epoch: 2 | Batch: 300
Epoch: 2 | Batch: 400
Epoch: 2 | Batch: 500
Epoch: 2 | Batch: 600
Epoch: 2 | Batch: 700
Epoch: 2 | Batch: 800
Epoch: 2 | Batch: 900
Epoch: 2 | Batch: 1000
Epoch: 2 | Batch: 1100
Epoch: 2 | Batch: 1200
Epoch: 2 | Training loss: 3.4874
Epoch: 3 | Batch: 0
Epoch: 3 | Batch: 100
Epoch: 3 | Batch: 200
Epoch: 3 | Batch: 300
Epoch: 3 | Batch: 400
Epoch: 3 | Batch: 500
Epoch: 3 | Batch: 600
Epoch: 3 | Batch: 700
Epoch: 3 | Batch: 800
Epoch: 3 | Batch: 900
Epoch: 3 | Batch: 1000
Epoch: 3 | Batch: 1100
Epoch: 3 | Batch: 1200
Epoch: 3 | Training loss: 3.0023
Epoch: 4 | Batch: 0
Epoch: 4 | Batch: 10

#### Define function to ROUGE metric

In [41]:
import numpy as np
import nltk
from nltk.tokenize import sent_tokenize
nltk.download('punkt')

def process_for_rogue(preds:torch.tensor,
                      labels:torch.tensor):
  """
  Function to process preds and labels for ROUGE scores.
  Args:
    preds(torch.tensor): batch of predicted tokens from model.
    labels(torch.tensor): batch of truth label tokens.
  Returns:
    Lists of decoded preds and labels for batch.
  """

  # Convert tensors to numpy
  preds = preds.cpu().numpy()
  labels = labels.cpu().numpy()

  # Decode tokens
  decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
  labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
  decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

  # Setup for ROUGE
  decoded_preds = [pred.strip() for pred in decoded_preds]
  decoded_labels = [label.strip() for label in decoded_labels]
  rouge_preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in decoded_preds]
  rouge_labels = [["\n".join(nltk.sent_tokenize(label)) for label in decoded_labels]]

  return rouge_preds, rouge_labels

[nltk_data] Downloading package punkt to /root/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


#### Eval loop function

In [53]:
def eval(model:torch.nn.Module,
         dataloader:torch.utils.data.DataLoader,
         device:torch.device,
         process_for_rogue,
         rouge_fn):
  """
  Function that runs eval either on val_dataloader or test_dataloader
  """
  val_loss = 0

  rouge_fn = rouge_fn.to(device)
  fmeasures_list = []
  rouge1_fmeasure = 0
  rouge2_fmeasure = 0
  rougeL_fmeasure = 0
  rougeLsum_fmeasure = 0

  model.to(device)
  model.eval()
  with torch.inference_mode():
    for batch_num, batch in enumerate(dataloader):
      # Send data to device
      batch = {k: v.to(device) for k, v in batch.items()}

      # Forward pass, calculate loss
      val_loss += model(**batch).loss

      # Generate, calculate ROUGE
      preds = model.generate(batch['input_ids'],
                             attention_mask=batch['attention_mask'],
                             max_length=30)
      rouge_preds, rouge_labels = process_for_rogue(preds, batch['labels'])
      batch_scores = rouge_fn(preds=rouge_preds, target=rouge_labels)
      f_measures = {k: round(v.item()*100, 2) for k, v in batch_scores.items() if k.endswith('fmeasure')}
      fmeasures_list.append(f_measures)

      # Track eval loop
      #if batch_num % 10 == 0:
      #  print(f"Batch: {batch_num}")

    # Average loss and fmeasures across batches
    val_loss /= len(val_dataloader)
    for fmeasures_dict in fmeasures_list:
      rouge1_fmeasure += fmeasures_dict['rouge1_fmeasure']
      rouge2_fmeasure += fmeasures_dict['rouge2_fmeasure']
      rougeL_fmeasure += fmeasures_dict['rougeL_fmeasure']
      rougeLsum_fmeasure += fmeasures_dict['rougeLsum_fmeasure']
    rouge1_fmeasure /= len(dataloader)
    rouge2_fmeasure /= len(dataloader)
    rougeL_fmeasure /= len(dataloader)
    rougeLsum_fmeasure /= len(dataloader)

  # Print output
  print(f"Loss: {val_loss:.4f}\n")

  print(f"---Rouge f1 measures---")
  print(f"rouge1_fmeasure = {rouge1_fmeasure:.4f}")
  print(f"rouge2_fmeasure = {rouge2_fmeasure:.4f}")
  print(f"rougeL_fmeasure = {rougeL_fmeasure:.4f}")
  print(f"rougeLsum_fmeasure = {rougeLsum_fmeasure:.4f}")

#### Eval on validation dataset

In [54]:
eval(model=model,
     dataloader=val_dataloader,
     device=device,
     process_for_rogue=process_for_rogue,
     rouge_fn=rouge_fn)

Loss: 3.1098

---Rouge f1 measures---
rouge1_fmeasure = 24.6017
rouge2_fmeasure = 7.0993
rougeL_fmeasure = 22.9367
rougeLsum_fmeasure = 22.9367


#### Eval on test dataset

In [55]:
eval(model=model,
     dataloader=test_dataloader,
     device=device,
     process_for_rogue=process_for_rogue,
     rouge_fn=rouge_fn)

Loss: 3.2337

---Rouge f1 measures---
rouge1_fmeasure = 22.0581
rouge2_fmeasure = 5.1742
rougeL_fmeasure = 20.0494
rougeLsum_fmeasure = 20.0494


### **Test model**

In [69]:
text = "The ball hit the splice a lot and sent a fizzing sensation up the handle and into the bottom hand, so I adapted at each session by playing softer and softer, later and later. I found it very difficult to get down the pitch and meet the ball as it landed and so persuaded myself to play back more. It occurred to me that a better player would manage the shimmy down the pitch with more skill and faster footwork, and that the good sweepers would have to take him on in the way that Kevin Pietersen managed so successfully on occasions."
# text = "Todo muy bien, cumple con lo esperado. Lo único malo es que: se calienta un poco y la batería no dura 8h. A una persona le ha parecido esto útil"

inputs = tokenizer(text,
                   max_length=512,
                   truncation=True,
                   return_tensors='pt')

inputs = {k: v.to(device) for k, v in inputs.items()}

output = model.generate(inputs['input_ids'],
                        attention_mask=inputs['attention_mask'],
                        max_length=30)
tokenizer.batch_decode(output, skip_special_tokens=True)

['Very difficult to get down the pitch']

### **Save model**

In [75]:
!git clone https://github.com/cswamy/pytorch

Cloning into 'pytorch'...
remote: Enumerating objects: 92, done.[K
remote: Counting objects: 100% (92/92), done.[K
remote: Compressing objects: 100% (79/79), done.[K
remote: Total 92 (delta 29), reused 31 (delta 8), pack-reused 0[K
Receiving objects: 100% (92/92), 67.46 KiB | 776.00 KiB/s, done.
Resolving deltas: 100% (29/29), done.


In [76]:
from pytorch.scripts import utils
utils.save_model(model=model,
                 target_dir='models',
                 model_name='mt5_amzn_enes_reviews_summarization.pth')

[INFO] Saving model to: models/mt5_amzn_enes_reviews_summarization.pth


### **Deploy to hugging face**

In [77]:
from pathlib import Path

# Create folders
demo_path = Path("demos/mt5_amznreviews")
demo_path.mkdir(parents=True, exist_ok=True)

In [78]:
# Move model to demo folder
!mv models/mt5_amzn_enes_reviews_summarization.pth demos/mt5_amznreviews

In [79]:
%%writefile demos/mt5_amznreviews/model.py
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

def create_mt5_small():
  """
  Initializes model and tokenizer.
  """
  checkpoint = 'google/mt5-small'
  tokenizer = AutoTokenizer.from_pretrained(checkpoint, return_tensors='pt')
  model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

  return model, tokenizer

Writing demos/mt5_amznreviews/model.py


In [80]:
%%writefile demos/mt5_amznreviews/app.py
import torch
import gradio as gr

from model import create_mt5_small

# Setup model and tokenizer
model, tokenizer = create_mt5_small()

# Load state dict from model
model.load_state_dict(
    torch.load(
        f="mt5_amzn_enes_reviews_summarization.pth",
        map_location=torch.device("cpu")
    ))

# Predict function
def predict(text:str):

  # Tokenize inputs and get model outputs
  input = tokenizer(text,
                    max_length=512,
                    truncation=True,
                    return_tensors='pt')
  output_tokens = model.generate(input['input_ids'],
                                 attention_mask=input['attention_mask'],
                                 max_length=30)
  output_text = tokenizer.batch_decode(output_tokens,
                                       skip_special_tokens=True)

  return output_text[0]

# Create examples list
examples_list = ["The ball hit the splice a lot and sent a fizzing sensation up the handle and into the bottom hand, so I adapted at each session by playing softer and softer, later and later. I found it very difficult to get down the pitch and meet the ball as it landed and so persuaded myself to play back more. It occurred to me that a better player would manage the shimmy down the pitch with more skill and faster footwork, and that the good sweepers would have to take him on in the way that Kevin Pietersen managed so successfully on occasions.",
                 "Todo muy bien, cumple con lo esperado. Lo único malo es que: se calienta un poco y la batería no dura 8h. A una persona le ha parecido esto útil"]

# Create gradio app
title = "Summarizer for English and Spanish inputs"
description = "MT5-small model finetuned for summarization on English or Spanish text trained on the Amazon reviews dataset."

demo = gr.Interface(fn=predict,
                    inputs=gr.inputs.Textbox(label="Input",
                                             placeholder="Enter sentences here in English or Spanish..."),
                    outputs="text",
                    examples=examples_list,
                    title=title,
                    description=description)

# Launch gradio
demo.launch()

Writing demos/mt5_amznreviews/app.py


In [81]:
%%writefile demos/mt5_amznreviews/requirements.txt
torch==1.12.0
gradio==3.44.1
transformers==4.33.1
sentencepiece==0.1.99

Writing demos/mt5_amznreviews/requirements.txt


In [82]:
!cd demos/mt5_amznreviews && zip -r ../mt5_amznreviews.zip *

  adding: app.py (deflated 54%)
  adding: model.py (deflated 42%)
  adding: mt5_amzn_enes_reviews_summarization.pth (deflated 24%)
  adding: requirements.txt (deflated 6%)


In [None]:
try:
  from google.colab import files
  files.download("demos/mt5_amznreviews.zip")
except:
  print(f"Download failed!")