### Project intro and imports of the necessary modules

In [None]:
from torchinfo import summary
import warnings
import yaml
import os
import sys

from src.data import IWSLT2017DataLoader, Multi30kDataLoader
from utils.logger import get_logger
from src.transformer import Seq2SeqTransformer
from src.trainer import Trainer, EarlyStopper
from utils.config import SharedConfig, DataLoaderConfig, TransformerConfig, TrainerConfig
from src.processor import Processor
# Suppress every UserWarning process-wide so notebook cell output stays readable.
warnings.filterwarnings("ignore", category=UserWarning)

### Section Load Dataloader

In [None]:
path_to_config = './configs/multi30k-small.yaml'
run_id = 'multi30k-small'
device = 'cuda'

logger = get_logger("Main")

# Parse the config BEFORE touching the filesystem: if the file is missing or
# malformed, no empty run directory is left behind that would block the next
# attempt with "Run ID already exists!".
with open(path_to_config, encoding="utf-8") as stream:
    config = yaml.safe_load(stream)

# Create the run's metrics directory, refusing to reuse an existing run ID so
# earlier results are never overwritten. Letting makedirs raise performs the
# existence check and the creation in one step (no check-then-create gap).
run_metrics_dir = f'./models/{run_id}/metrics'
try:
    os.makedirs(run_metrics_dir)
except FileExistsError:
    logger.error('Run ID already exists!')
    sys.exit(1)

shared_conf = SharedConfig(run_id=run_id)
dl_conf = DataLoaderConfig(**config['dataloader'])

# Choose the dataloader named in the config; any value other than
# "iwslt2017" falls back to Multi30k.
if dl_conf.dataset == "iwslt2017":
    dataloader = IWSLT2017DataLoader.new_instance(dl_conf, shared_conf)
else:
    dataloader = Multi30kDataLoader.new_instance(dl_conf, shared_conf)

# Expose the splits and tokenizer as notebook globals for the following cells.
train_dataloader = dataloader.train_dataloader
test_dataloader = dataloader.test_dataloader
val_dataloader = dataloader.val_dataloader
tokenizer = dataloader.tokenizer
val_dataset = dataloader.val_dataset

### Section Load Model and Processor

In [None]:
# Source and target text both go through the same tokenizer, so both
# vocabulary sizes are read from it.
SRC_VOCAB_SIZE = tokenizer.get_vocab_size()
TGT_VOCAB_SIZE = tokenizer.get_vocab_size()

model_conf = TransformerConfig(
    **config['transformer'],
    src_vocab_size=SRC_VOCAB_SIZE,
    tgt_vocab_size=TGT_VOCAB_SIZE,
)

transformer = Seq2SeqTransformer(model_conf)
# Pair the freshly built model with the tokenizer for translation.
translator = Processor.from_instance(transformer, tokenizer, device)

trainer_conf = TrainerConfig(
    **config['trainer'],
    device=device,
    batch_size=dl_conf.batch_size,
)

# Dummy input sizes for the architecture summary.
# NOTE(review): the six entries presumably correspond to the forward() inputs
# (source/target tokens, attention masks, padding masks) at a fixed sequence
# length of 256 — confirm against Seq2SeqTransformer.forward.
summary_seq_len = 256
summary_batch = dl_conf.batch_size
summary(
    transformer,
    [
        (summary_seq_len, summary_batch), (summary_seq_len, summary_batch),
        (summary_seq_len, summary_seq_len), (summary_seq_len, summary_seq_len),
        (summary_batch, summary_seq_len), (summary_batch, summary_seq_len),
    ],
    depth=3,
)

### Section Train model

In [None]:
# Early-stopping policy; the exact meaning of warmup/patience/min_delta is
# defined by src.trainer.EarlyStopper.
early_stopper = EarlyStopper(
    warmup=17,
    patience=7,
    min_delta=0,
)

trainer = Trainer.new_instance(
    transformer,
    translator,
    train_dataloader,
    test_dataloader,
    val_dataloader,
    tokenizer,
    early_stopper,
    trainer_conf,
    device,
    run_id,
)

trainer.train()

### Section Evaluate

In [None]:
# Only execute if no training was executed yet in this session: rebuilds the
# validation split from the default dataloader configuration.
shared_conf = SharedConfig()
dl_conf = DataLoaderConfig()

# Select the dataloader the same way the training cell does, so the evaluation
# data follows the configured dataset instead of always being IWSLT2017.
if dl_conf.dataset == "iwslt2017":
    dataloader = IWSLT2017DataLoader.new_instance(dl_conf, shared_conf)
else:
    dataloader = Multi30kDataLoader.new_instance(dl_conf, shared_conf)

val_dataset = dataloader.val_dataset

### Prepare translator, metrics and dataset

In [None]:
from evaluate import load as load_metric
import torch

# To evaluate a model that was NOT trained through this notebook, set the
# checkpoint and tokenizer paths explicitly and comment out the automatic
# lookup below, e.g.:
# path_to_checkpoint = ""
# path_to_tokenizer = ""

model_dir = f"./models/{run_id}"
path_to_tokenizer = f"{model_dir}/tokenizer.json"

# Prefer the best checkpoint, then the last one, then the plain checkpoint.
checkpoint_candidates = [
    f"{model_dir}/best_checkpoint_scripted.pt",
    f"{model_dir}/last_checkpoint_scripted.pt",
    f"{model_dir}/checkpoint_scripted.pt",
]
path_to_checkpoint = next(
    (path for path in checkpoint_candidates if os.path.isfile(path)),
    None,
)
if path_to_checkpoint is None:
    # Fail here with a clear message instead of handing a non-existent path
    # to the checkpoint loader.
    raise FileNotFoundError(
        f"No scripted checkpoint found in {model_dir!r}; looked for: "
        + ", ".join(checkpoint_candidates)
    )

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

translator = Processor.from_checkpoint(
    model_checkpoint=path_to_checkpoint,
    tokenizer=path_to_tokenizer,
    device=device,
)

bleu = load_metric("bleu")
sacre_bleu = load_metric("sacrebleu")
rouge = load_metric("rouge")
meteor = load_metric("meteor")

outputs = []
# Each validation item is indexed as (source sentence, reference translation).
sources = [x[0] for x in val_dataset]
targets = [x[1] for x in val_dataset]

### Generate samples

In [None]:
# Translate every validation source sentence. `outputs` is rebuilt here so
# re-running this cell cannot append duplicate translations, which would
# misalign the predictions with `targets` in the scoring cell.
outputs = []
total = len(sources)
for idx, src in enumerate(sources, start=1):
    outputs.append(translator.translate(src))
    # '\r' keeps the progress counter on a single, overwritten line.
    print(f"{idx}/{total}", end='\r')

### Compute Scores

In [None]:
import json

# Score the generated translations against the references with every loaded
# metric, in the order bleu, sacre_bleu, rouge, meteor.
metrics = {
    name: metric.compute(predictions=outputs, references=targets)
    for name, metric in (
        ('bleu', bleu),
        ('sacre_bleu', sacre_bleu),
        ('rouge', rouge),
        ('meteor', meteor),
    )
}
bleu_score = metrics['bleu']
sacre_bleu_score = metrics['sacre_bleu']
rouge_score = metrics['rouge']
meteor_score = metrics['meteor']

# Convert and write JSON object to file.
# NOTE(review): the file name has no separator between the source language and
# "test" (e.g. "entest-de-metrics.json") — confirm this is intended.
# NOTE(review): mode "x" raises FileExistsError if the file already exists, so
# re-running this cell fails after the metrics have been computed.
metrics_path = f"./{shared_conf.src_language}test-{shared_conf.tgt_language}-metrics.json"
with open(metrics_path, "x") as outfile:
    json.dump(metrics, outfile, indent=4)

### Print Scores

In [None]:
# One report line per metric, preceded by two blank lines.
score_lines = [
    f"Evaluation: bleu_score - {bleu_score}",
    f"Evaluation: rouge_score - {rouge_score}",
    f"Evaluation: sacre_bleu_score - {sacre_bleu_score}",
    f"Evaluation: meteor_score - {meteor_score}",
]
print("\n\n" + "\n".join(score_lines))

# Quick qualitative check on a fixed sentence.
TEST_SEQUENCE = "The quick brown fox jumped over the lazy dog and then ran away quickly."
output = translator.translate(TEST_SEQUENCE)

print(f'Input: {TEST_SEQUENCE}, Output: {output}')

### Section Demo

### Necessary Imports

In [None]:
import gradio as gr
from src.translate import check_device
from utils.demo_model_config import ModelConfig

### Load utils

In [None]:
# Resolve the compute device for the demo (CPU is requested).
device = check_device('cpu')

# Model configuration object whose callbacks both demo tabs wire to.
model_config = ModelConfig(device)

# Set up Gradio theme
theme = gr.themes.Default()

# Example sentences offered in the demo tabs.
en_examples = [
    "The quick brown fox jumps over the lazy dog.",
    "She sells seashells by the seashore.",
    "Technology is rapidly changing the way we live and work.",
    "Can you recommend a good restaurant nearby?",
    "Despite the rain, they decided to go for a hike.",
]

de_examples = [
    "Die schnelle braune Katze sprang über den hohen Zaun.",
    "Er spielte den ganzen Tag Videospiele.",
    "Das neue Museum in der Stadt ist einen Besuch wert.",
    "Kannst du mir helfen, dieses Problem zu lösen?",
    "Obwohl sie müde war, arbeitete sie bis spät in die Nacht.",
]

### T5 Demo Tab

In [None]:
def t5_model_tab():
    """Build the "T5 Model" tab of the demo inside the current ``gr.Blocks``.

    The tab contains a debug log and a button that loads the T5 model through
    ``model_config.set_t5_model``. It also has an English input box, a German
    output box, and clickable English examples. Translation goes through
    ``model_config.translate`` with the hidden model id ``"t5"``.
    """
    with gr.Tab(label="T5 Model"), gr.Column():
        # Receives the messages returned by the model loader.
        with gr.Accordion("Debug Log", open=True):
            log_box = gr.TextArea(label="", lines=7, max_lines=12)

        with gr.Group():
            gr.Button("Load T5 model").click(
                fn=model_config.set_t5_model,
                outputs=[log_box],
            )

        with gr.Group():
            with gr.Row():
                source_box = gr.Textbox(label="English Sequence", max_lines=2)
                # Invisible field telling model_config.translate which model to use.
                model_id_box = gr.Textbox(value="t5", visible=False)

            with gr.Row():
                target_box = gr.Textbox(label="German Sequence", max_lines=3)

            with gr.Row():
                translate_button = gr.Button("Translate")
                translate_button.click(
                    fn=model_config.translate,
                    inputs=[source_box, model_id_box],
                    outputs=[target_box],
                )
                gr.ClearButton(components=[source_box, target_box, log_box])

        with gr.Accordion(label="Examples", open=True):
            gr.Examples(examples=en_examples, inputs=[source_box], label="English Sequences")

### Custom Model Demo Tab

In [None]:
def custom_model_tab():
    """Build the "Custom Model" tab of the demo inside the current ``gr.Blocks``.

    Users upload a ``.pt`` model file and a ``.json`` tokenizer file, which
    are loaded through ``model_config.set_custom_model``. Translation goes
    through ``model_config.translate`` with the hidden model id ``"custom"``.
    Both English and German example sentences are offered.
    """
    with gr.Tab(label="Custom Model"), gr.Column():
        # Receives the messages returned by the model loader.
        with gr.Accordion("Debug Log", open=True):
            log_box = gr.TextArea(label="", lines=7, max_lines=12)

        with gr.Group():
            with gr.Row():
                model_file = gr.File(label="Model", file_types=['.pt'], min_width=200)
                tokenizer_file = gr.File(label="Tokenizer", file_types=['.json'], min_width=200)

            with gr.Row():
                gr.Button("Load custom model").click(
                    fn=model_config.set_custom_model,
                    inputs=[model_file, tokenizer_file],
                    outputs=[log_box],
                )

        with gr.Group():
            with gr.Row():
                source_box = gr.Textbox(label="Input Sequence", max_lines=2)
                # Invisible field telling model_config.translate which model to use.
                model_id_box = gr.Textbox(value="custom", visible=False)

            with gr.Row():
                target_box = gr.Textbox(label="Output Sequence", max_lines=3)

            with gr.Row():
                translate_button = gr.Button("Translate")
                translate_button.click(
                    fn=model_config.translate,
                    inputs=[source_box, model_id_box],
                    outputs=[target_box],
                )
                gr.ClearButton(components=[source_box, target_box, log_box])

        with gr.Accordion(label="Examples", open=True):
            for sentences, heading in (
                (en_examples, "English Sequences"),
                (de_examples, "German Sequences"),
            ):
                gr.Examples(examples=sentences, inputs=[source_box], label=heading)

### Launch Demo

In [None]:
with gr.Blocks(theme=theme) as demo:
    # Page header, rendered top to bottom.
    for header_text in (
        "# KI in den Life Sciences: Machine Translation Demo",
        "by [Nico Fuchs](https://github.com/nico-byte) and [Matthias Laton](https://github.com/20DragonSlayer01)",
        "---",
        "### This demo uses a T5 model to translate English to German. You can also load your own model and tokenizer.",
    ):
        gr.Markdown(header_text)

    # One tab per translation backend.
    t5_model_tab()
    custom_model_tab()

# Launch the Gradio demo
demo.launch()