In [9]:
!pip uninstall -ignore-installed blinker
!pip install transformers sentencepiece datasets blinker flask


Usage:   
  pip uninstall [options] <package> ...
  pip uninstall [options] -r <requirements file> ...

no such option: -i
Collecting flask
  Using cached flask-3.0.3-py3-none-any.whl.metadata (3.2 kB)
Collecting blinker
  Using cached blinker-1.8.2-py3-none-any.whl.metadata (1.6 kB)
Using cached flask-3.0.3-py3-none-any.whl (101 kB)
Using cached blinker-1.8.2-py3-none-any.whl (9.5 kB)
Installing collected packages: blinker, flask
  Attempting uninstall: blinker
    Found existing installation: blinker 1.4
[31mERROR: Cannot uninstall 'blinker'. It is a distutils installed project and thus we cannot accurately determine which files belong to it which would lead to only a partial uninstall.[0m[31m
[0m

In [1]:
from flask import Flask, render_template, request, jsonify
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

Defining the website

In [2]:
# Global variables
# Flask app: the application object. Note that `app` is created again in the
# "Flask functions" section, right before the routes are registered.
app = Flask(__name__)
# Language token mappings: language code -> the '<code>' tag that is prepended
# to the input text to select the translation target language.
LANG_TOKEN_MAPPING = {
    code: f'<{code}>'
    for code in ('en', 'fil', 'hi', 'id', 'ja')
}

Loading the model and the tokenizer

In [37]:
model_repo = "google/mt5-base"
MODEL_PATH = "./app/models/mt5_translator_best.pt"

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Download the mT5 tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_repo, legacy=True)

# Register the language tags (e.g. '<en>') as additional special tokens
special_tokens = { 'additional_special_tokens': list(LANG_TOKEN_MAPPING.values()) }
tokenizer.add_special_tokens(special_tokens)

# Download the base model
model = AutoModelForSeq2SeqLM.from_pretrained(model_repo)
# Grow the embedding matrix to the enlarged vocabulary; presumably this is
# needed so the shapes match the fine-tuned state dict loaded below.
model.resize_token_embeddings(len(tokenizer))
# predict() sends input_ids to `device`, so the model has to live there too;
# without this a CUDA machine fails with a device-mismatch error in generate().
model.to(device)
# Load the fine-tuned weights (state dict) on top of the base model.
# NOTE(review): torch.load unpickles the file -- only load checkpoints from a
# trusted source (consider weights_only=True on PyTorch versions that have it).
model.load_state_dict(torch.load(MODEL_PATH, map_location=device))



<All keys matched successfully>

Necessary Functions to Predict

In [44]:
# tokenizes and numericalizes input string
def encode_input_str(text, target_lang, tokenizer, seq_len,
                     lang_token_map=LANG_TOKEN_MAPPING):
  """Encode ``text`` as a fixed-length tensor of token ids.

  The tag for ``target_lang`` is looked up in ``lang_token_map`` and
  prepended to ``text``. The tagged string is passed to
  ``tokenizer.encode`` with padding to, and truncation at, ``seq_len``
  tokens, and the result is returned as a PyTorch tensor
  (``return_tensors='pt'``).

  Raises:
    KeyError: if ``target_lang`` is not a key of ``lang_token_map``.
  """
  tagged_text = lang_token_map[target_lang] + text

  encode_options = dict(
      return_tensors='pt',
      padding='max_length',
      truncation=True,
      max_length=seq_len,
  )
  return tokenizer.encode(text=tagged_text, **encode_options)

def predict(text, target_lang, tokenizer, seq_len=20, num_beams=10, max_length=30):
    """Translate ``text`` into ``target_lang`` using the module-level ``model``.

    Args:
        text: Source sentence to translate.
        target_lang: Key of the language tag mapping used by
            ``encode_input_str`` (e.g. ``'hi'``).
        tokenizer: Tokenizer used both to encode the input and to decode
            the generated ids.
        seq_len: Length the encoded input is padded/truncated to. Longer
            inputs are cut off at this many tokens.
        num_beams: Beam width passed to ``model.generate``.
        max_length: ``max_length`` passed to ``model.generate``.

    The defaults (20 / 10 / 30) are the values that used to be hard-coded.

    Returns:
        The decoded text of the first generated sequence, with special
        tokens stripped.
    """
    # Encode the input string
    input_ids = encode_input_str(text, target_lang, tokenizer, seq_len)
    # Generate the output on the configured device
    output = model.generate(
        input_ids.to(device),
        num_beams=num_beams,
        max_length=max_length,
    )
    # Decode the first (best) sequence back to text
    return tokenizer.decode(output[0], skip_special_tokens=True)

In [45]:
# Smoke test: translate an English sentence into Hindi ("hi")
a = predict("I'm going to mars", "hi", tokenizer)
# Confirm that predict() returns a plain string
type(a)

str

In [46]:
# Display the text returned by predict() in the previous cell
a

'मुझे यह सुनिश्चित करने के लिए कहा गया कि मैं'

Flask functions

In [87]:
# Create the Flask application object that the route handlers in the next
# cell are registered on (this rebinds the `app` created in the globals cell).
app = Flask(__name__)

In [88]:
@app.route('/')
def translate():
    """Serve the page rendered from the 'translate.html' template at the site root."""
    page = render_template('translate.html')
    return page
# Define a route for prediction
@app.route('/predict', methods=['POST'])
def prediction():
    """Translate the text sent in a JSON POST body.

    Expected body: ``{"text": <str>, "target_lang": <key of LANG_TOKEN_MAPPING>}``.

    Returns:
        ``{"prediction": <translated text>}`` on success, or
        ``{"error": <message>}`` with HTTP status 400 when the body is not a
        JSON object, ``text`` is missing/blank, or ``target_lang`` is not a
        supported language code.
    """
    # silent=True gives None instead of raising when the body is missing or
    # is not valid JSON, so the client gets a clear 400 response below.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({'error': 'Request body must be a JSON object.'}), 400

    text = data.get('text')
    if not isinstance(text, str) or not text.strip():
        return jsonify({'error': "'text' must be a non-empty string."}), 400

    target_lang = data.get('target_lang')
    if not isinstance(target_lang, str) or target_lang not in LANG_TOKEN_MAPPING:
        return jsonify({
            'error': "'target_lang' must be one of: "
                     + ', '.join(LANG_TOKEN_MAPPING),
        }), 400

    # Replaces the former debug print of the raw input
    app.logger.debug('Translation request: %r -> %s', text, target_lang)

    translated_text = predict(text, target_lang, tokenizer)
    return jsonify({'prediction': translated_text})

In [89]:
# run the app
# Starts Flask's built-in server with its default settings. The call keeps
# this cell busy serving requests until the server is interrupted.
app.run()

 * Serving Flask app '__main__'
 * Debug mode: off


 * Running on http://127.0.0.1:5000
Press CTRL+C to quit
127.0.0.1 - - [06/Jul/2024 20:32:07] "GET / HTTP/1.1" 200 -
127.0.0.1 - - [06/Jul/2024 20:32:07] "GET /style.css HTTP/1.1" 404 -
127.0.0.1 - - [06/Jul/2024 20:32:13] "GET /?text=hello&source=en&target=id HTTP/1.1" 200 -
127.0.0.1 - - [06/Jul/2024 20:32:13] "GET /style.css HTTP/1.1" 404 -
127.0.0.1 - - [06/Jul/2024 20:32:37] "GET /?text=hello+this+is+test&source=en&target=id HTTP/1.1" 200 -
127.0.0.1 - - [06/Jul/2024 20:32:37] "GET /style.css HTTP/1.1" 404 -
