In [1]:
!nvidia-smi

Fri Jul 15 12:20:38 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.103.01   Driver Version: 470.103.01   CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Quadro RTX 6000     Off  | 00000000:D8:00.0 Off |                  Off |
| 58%   71C    P8    19W / 260W |   2001MiB / 24220MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [2]:
import torch
torch.version.cuda

'10.2'

In [3]:
x_ones = torch.ones(3)
x_ones.to("cuda")

tensor([1., 1., 1.], device='cuda:0')

In [4]:
from T5FineTuner import RPDataset
from utils import get_folds
import torch
import argparse
from transformers import T5Tokenizer, AutoTokenizer, AutoModelForSeq2SeqLM
import pytorch_lightning as pl
from pytorch_lightning.loggers import WandbLogger
import wandb

DATASETS = ["RP-Crowd-3", "RP-Crowd-2", "RP-Mod"]
# suffix = "german-t5-oscar-ep1-prompted-germanquad"
model_names = ["google/mt5-small", "google/mt5-base"]
# MODEL_NAME_OR_PATH = f"GermanT5/{suffix}"
# WANDB_PROJECT_NAME = f"all-datasets-{suffix}"
# OUTPUT_DIR = f"./{suffix}"

# auto_tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME_OR_PATH)
# auto_model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME_OR_PATH)


[nltk_data] Downloading package punkt to /home/dobby/nltk_data...
[nltk_data]   Package punkt is already up-to-date!
[nltk_data] Downloading package punkt to /home/dobby/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


In [5]:
import argparse
import glob
import os
import json
import time
import logging
import random
import re
from itertools import chain
from string import punctuation

import nltk
nltk.download('punkt')
from nltk.tokenize import sent_tokenize
import csv

import pandas as pd
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl

# eval packages
import textwrap
from tqdm.auto import tqdm
from sklearn import metrics

from torch.optim import AdamW

from transformers import (
    T5ForConditionalGeneration,
    T5Tokenizer,
    get_linear_schedule_with_warmup
)

def set_seed(seed):
  random.seed(seed)
  np.random.seed(seed)
  torch.manual_seed(seed)
  if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

set_seed(42)



class T5FineTuner(pl.LightningModule):
  def __init__(self, hparams):
    super(T5FineTuner, self).__init__()
    # self.hparams.save_hyperparameters(hparams)
    self.save_hyperparameters(hparams)
    
    self.model = hparams.model
    self.tokenizer = hparams.tokenizer
    # self.train_dataset = train_dataset
    # self.val_dataset = val_dataset
    
  def is_logger(self):
    return self.trainer.global_rank <= 0
  

  def forward(
      self, input_ids, attention_mask=None, decoder_input_ids=None, decoder_attention_mask=None, labels=None
  ):
    return self.model(
        input_ids,
        attention_mask=attention_mask,
        decoder_input_ids=decoder_input_ids,
        decoder_attention_mask=decoder_attention_mask,
        labels=labels,
        # lm_labels=lm_labels,
    )

  def _step(self, batch):
    labels = batch["target_ids"]
    labels[labels[:, :] == self.tokenizer.pad_token_id] = -100

    outputs = self(
        input_ids=batch["source_ids"],
        attention_mask=batch["source_mask"],
        labels=labels,
        # lm_labels=lm_labels,
        decoder_attention_mask=batch['target_mask']
    )

    loss = outputs[0]

    return loss
  
  def get_accuracy(self, batch):
    labels = batch["target_ids"]

    outs = self.model.generate(input_ids=batch["source_ids"], 
                                attention_mask=batch["source_mask"], 
                                max_length=2)

    dec = [self.tokenizer.decode(ids) for ids in outs]
    target = [self.tokenizer.decode(label) for label in labels]

    new_outputs = [s[6:] for s in dec]
    new_targets = [s[:-4] for s in target]

    accuracy_score = metrics.accuracy_score(new_targets, new_outputs)
    # f1 = metrics.f1_score(new_targets, new_outputs, labels=["problematisch"], average=None)
    # recall = metrics.recall_score(new_targets, new_outputs, labels=["problematisch"], average=None)
    # precision = metrics.precision_score(new_targets, new_outputs, labels=["problematisch"], average=None)
    # rec = metrics.recall_score(new_targets, new_outputs, average="micros")
    return accuracy_score, 0, 0, 0
    #  f1, recall, precision
  # def computer_accuracy(self):
  #   self.model.model.eval()

  def training_step(self, batch, batch_idx):
    loss = self._step(batch)
    self.log("train/loss", loss)
    return {"loss": loss}
  
  def training_epoch_end(self, outputs):
    avg_train_loss = torch.stack([x["loss"] for x in outputs]).mean()
    tensorboard_logs = {"avg_train_loss": avg_train_loss}
    self.log("avg_train_loss", avg_train_loss)
    
  def validation_step(self, batch, batch_idx):
    
    loss = self._step(batch)
    accuracy, f1, recall, precision = [torch.tensor(val) for val in self.get_accuracy(batch)]
    self.log("val/loss", loss, logger=True)
    self.log("val/accuracy", accuracy, logger=True)
    # self.log("val/f1", f1, logger=True)
    # self.log("val/recall", recall, logger=True)
    # self.log("val/precision", precision, logger=True)
    return {"val/loss": loss, "val/accuracy": accuracy, 
    # "val/f1": f1, "val/recall": recall, "val/precision": precision
    }
  
  def validation_epoch_end(self, outputs):
    avg_loss = torch.stack([x["val/loss"] for x in outputs]).mean()
    avg_accuracy = torch.stack([x["val/accuracy"] for x in outputs]).mean()
    # avg_f1 = torch.stack([x["val/f1"] for x in outputs]).mean()
    # avg_recall = torch.stack([x["val/recall"] for x in outputs]).mean()
    # avg_precision = torch.stack([x["val/precision"] for x in outputs]).mean()
    # tensorboard_logs = {"val_loss": avg_loss}
    self.log("avg_val_loss", avg_loss)
    self.log("avg_val_accuracy", avg_accuracy)
    # self.log("avg_val_f1", avg_f1)
    # self.log("avg_val_recall", avg_recall)
    # self.log("avg_val_precision", avg_precision)
    # self.log("log", tensorboard_logs)
    # self.log("progress_bar", tensorboard_logs)
    # self.log({"avg_val_loss": avg_loss, 
    #           "log": tensorboard_logs,
    #           'progress_bar': tensorboard_logs}, logger=True, prog_bar=True)
    return {"avg_val_loss": avg_loss, "avg_val_accuracy": avg_accuracy}

  def configure_optimizers(self):
    "Prepare optimizer and schedule (linear warmup and decay)"

    model = self.model
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
            "weight_decay": self.hparams.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
            "weight_decay": 0.0,
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=self.hparams.learning_rate, \
                      eps=self.hparams.adam_epsilon, betas=self.hparams.adam_betas)
    self.opt = optimizer
    return [optimizer]
  
  def optimizer_step(self, 
                      epoch,
                      batch_idx,
                      optimizer,
                      optimizer_idx,
                      second_order_closure=None,
                      on_tpu=None,
                      using_native_amp=None,
                      using_lbfgs=None):
    # if self.trainer.use_tpu:
    #   xm.optimizer_step(optimizer)
    # else:
    optimizer.step(closure=second_order_closure)
    optimizer.zero_grad()
    self.lr_scheduler.step()
  
  # def closure(self):
  #   return "closure"
  
  def get_tqdm_dict(self):
    tqdm_dict = {"loss": "{:.3f}".format(self.trainer.avg_loss), "lr": self.lr_scheduler.get_last_lr()[-1]}

    return tqdm_dict

  def train_dataloader(self):
    # train_dataset = get_dataset(tokenizer=self.tokenizer, type_path="train", args=self.hparams)
    dataloader = DataLoader(self.hparams.train_dataset, batch_size=self.hparams.train_batch_size, drop_last=True, shuffle=True, num_workers=4)
    t_total = (
        (len(dataloader.dataset) // (self.hparams.train_batch_size * max(1, self.hparams.n_gpu)))
        // self.hparams.gradient_accumulation_steps
        * float(self.hparams.num_train_epochs)
    )
    scheduler = get_linear_schedule_with_warmup(
        self.opt, num_warmup_steps=self.hparams.warmup_steps, num_training_steps=t_total
    )
    self.lr_scheduler = scheduler
    return dataloader

  def val_dataloader(self):
    # val_dataset = get_dataset(tokenizer=self.tokenizer, type_path="val", args=self.hparams)
    return DataLoader(self.hparams.val_dataset, batch_size=self.hparams.eval_batch_size, num_workers=4)

[nltk_data] Downloading package punkt to /home/dobby/nltk_data...
[nltk_data]   Package punkt is already up-to-date!


### Setting up Automatic Hyperparameter tuning

### Normal Training Loop

In [6]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

In [7]:
name = "google/mt5-small"
name.replace("/", "-")

'google-mt5-small'

In [8]:
from pytorch_lightning.callbacks import EarlyStopping
DATASETS = ["RP-Crowd-3", "RP-Crowd-2", "RP-Mod"]
# suffix = "german-t5-oscar-ep1-prompted-germanquad"
model_names = ["google/mt5-small", "google/mt5-base"]
learning_rates = [1e-4, 1e-5, 1e-6]
weight_decays = [0.01, 0.001, 0]
for name in model_names:
    # MODEL_NAME_OR_PATH = "GermanT5/t5-base-german-3e"
    
    OUTPUT_DIR = f"./{name}"

    auto_tokenizer = AutoTokenizer.from_pretrained(name)
    
    for dataset in DATASETS:
        temp = name.replace("/", "-")
        WANDB_PROJECT_NAME = f"{dataset}-{temp}-hyperparameter-search"
        source = f"./Datasets/{dataset}-folds.csv"
        train_inputs, train_targets, val_inputs, val_targets = get_folds(source)

        train_dataset = RPDataset(auto_tokenizer, train_inputs, train_targets)
        valid_dataset = RPDataset(auto_tokenizer, val_inputs, val_targets)
        for wd in weight_decays:
            for lr in learning_rates:
                run_name = f"{dataset}-lr-{lr}-wd-{wd}"

                wandb.finish()

                
                wandb_logger = WandbLogger(project=WANDB_PROJECT_NAME, 
                name=run_name)

                wandb.define_metric("val/accuracy", summary="max")
                # wandb.define_metric("val/f1", summary="max")

                checkpoint_callback = pl.callbacks.ModelCheckpoint(
                    dirpath=OUTPUT_DIR + run_name, filename="{epoch}-{val/accuracy:.2f}-{val/loss:.2f}", monitor="val/accuracy", mode="max", save_top_k=5
                )

                auto_model = AutoModelForSeq2SeqLM.from_pretrained(name, from_flax=True)
                early_stop_callback = EarlyStopping(monitor="val/accuracy", patience=5, mode="max")
                args_dict = dict(
                        data_dir="", # path for data files
                        output_dir=f"./GermanT5-RP-Mod/t5-efficient-oscar-german-small-el32/", # path to save the checkpoints
                        model_name_or_path=name,
                        tokenizer_name_or_path=name,
                        dataset_name=dataset,
                        max_seq_length=512,
                        learning_rate=lr,
                        weight_decay=wd,
                        adam_epsilon=1e-8,
                        adam_betas=(0.9,0.999),
                        warmup_steps=0,
                        train_batch_size=4,
                        eval_batch_size=2,
                        num_train_epochs=10,
                        gradient_accumulation_steps=1,
                        n_gpu=1,
                        early_stop_callback=False,
                        fp_16=False, # if you want to enable 16-bit training then install apex and set this to true
                        opt_level='O1', # you can find out more on optimisation levels here https://nvidia.github.io/apex/amp.html#opt-levels-and-properties
                        max_grad_norm=0.5, # if you enable 16-bit training then set this to a sensible value, 0.5 is a good default
                        seed=42,
                        train_dataset=train_dataset, 
                        val_dataset=valid_dataset, 
                        model=auto_model,
                        tokenizer=auto_tokenizer
                    )
                
                args = argparse.Namespace(**args_dict)

                train_params = dict(
                            accumulate_grad_batches=args.gradient_accumulation_steps,
                            auto_lr_find=True,
                            gpus=args.n_gpu,
                            max_epochs=args.num_train_epochs,
                            default_root_dir=f"/home/dobby/{name}",
                            # early_stop_callback=False,
                            precision= 16 if args.fp_16 else 32,
                            amp_level=args.opt_level,
                            gradient_clip_val=args.max_grad_norm,
                            # checkpoint_callback=checkpoint_callback,
                            logger=wandb_logger,
                            enable_checkpointing=checkpoint_callback,
                            callbacks=[early_stop_callback, checkpoint_callback],
                            # callbacks=[raytuner_callback],
                            # callbacks=[LoggingCallback()],
                            amp_backend="apex"
                        )
                
                model = T5FineTuner(args)
                trainer = pl.Trainer(**train_params)

                trainer.fit(model)

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33misadoraw[0m. Use [1m`wandb login --relogin`[0m to force relogin


  pt_model_dict[flax_key] = torch.from_numpy(flax_tensor)
All Flax model weights were used when initializing MT5ForConditionalGeneration.

Some weights of MT5ForConditionalGeneration were not initialized from the Flax model and are newly initialized: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  rank_zero_warn(
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                        | Params
------------------------------------------------------
0 | model | MT5ForConditionalGeneration | 300 M 
------------------------------------------------------
300 M     Trainable params
0         Non-trainable params
300 M     Total params
1,200.707 Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
avg_train_loss,█▁▁▁▁▁▁▁▁▁
avg_val_accuracy,▁▆▆▇▇█▇▇▇█
avg_val_loss,▂▁▂▃▅▅▇█▇█
epoch,▁▁▁▁▂▂▂▂▃▃▃▃▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇████
train/loss,█▂▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▁▁▁▁▁▁
trainer/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
val/accuracy,▁▆▆▇▇█▇▇▇█
val/loss,▂▁▂▃▅▅▇█▇█

0,1
avg_train_loss,0.22509
avg_val_accuracy,0.78254
avg_val_loss,0.54874
epoch,9.0
train/loss,0.55121
trainer/global_step,12599.0
val/loss,0.54874


All Flax model weights were used when initializing MT5ForConditionalGeneration.

Some weights of MT5ForConditionalGeneration were not initialized from the Flax model and are newly initialized: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  rank_zero_warn(
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                        | Params
------------------------------------------------------
0 | model | MT5ForConditionalGeneration | 300 M 
------------------------------------------------------
300 M     Trainable params
0         Non-trainable params
300 M     Total params
1,200.707 Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
avg_train_loss,█▃▁▁▁▁▁▁▁▁
avg_val_accuracy,▁▇▇███████
avg_val_loss,█▁▁▁▁▁▁▁▁▁
epoch,▁▁▁▁▂▂▂▂▃▃▃▃▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇████
train/loss,█▆▇▄▆▃▃▂▂▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▂▁▁▁▂▁▁▁▁▁▁▁▁
trainer/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
val/accuracy,▁▇▇███████
val/loss,█▁▁▁▁▁▁▁▁▁

0,1
avg_train_loss,0.57291
avg_val_accuracy,0.56508
avg_val_loss,0.33821
epoch,9.0
train/loss,0.53524
trainer/global_step,12599.0
val/loss,0.33821


All Flax model weights were used when initializing MT5ForConditionalGeneration.

Some weights of MT5ForConditionalGeneration were not initialized from the Flax model and are newly initialized: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  rank_zero_warn(
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                        | Params
------------------------------------------------------
0 | model | MT5ForConditionalGeneration | 300 M 
------------------------------------------------------
300 M     Trainable params
0         Non-trainable params
300 M     Total params
1,200.707 Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
avg_train_loss,█▆▄▂▁▁
avg_val_accuracy,▁▁▁▁▁▁
avg_val_loss,█▆▄▃▂▁
epoch,▁▁▁▁▁▁▁▂▂▂▂▂▂▂▄▄▄▄▄▄▅▅▅▅▅▅▅▇▇▇▇▇▇▇██████
train/loss,▃▃█▄▅▅▅▃▂▃▅▂▁▃▂▂▃▃▂▁▂▄▁▄▄▂▃▁▄▂▃▂▃▁▂▄▂▅▁▂
trainer/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇▇▇███
val/accuracy,▁▁▁▁▁▁
val/loss,█▆▄▃▂▁

0,1
avg_train_loss,18.72716
avg_val_accuracy,0.0
avg_val_loss,13.02845
epoch,5.0
train/loss,30.14358
trainer/global_step,7559.0
val/loss,13.02845


All Flax model weights were used when initializing MT5ForConditionalGeneration.

Some weights of MT5ForConditionalGeneration were not initialized from the Flax model and are newly initialized: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  rank_zero_warn(
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                        | Params
------------------------------------------------------
0 | model | MT5ForConditionalGeneration | 300 M 
------------------------------------------------------
300 M     Trainable params
0         Non-trainable params
300 M     Total params
1,200.707 Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
avg_train_loss,█▂▁▁▁▁▁▁▁▁
avg_val_accuracy,▁▆▇▆███▇██
avg_val_loss,▂▁▁▇▅▅▆█▇█
epoch,▁▁▁▁▂▂▂▂▃▃▃▃▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇████
train/loss,█▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁
trainer/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
val/accuracy,▁▆▇▆███▇██
val/loss,▂▁▁▇▅▅▆█▇█

0,1
avg_train_loss,0.20735
avg_val_accuracy,0.78571
avg_val_loss,0.58751
epoch,9.0
train/loss,0.0561
trainer/global_step,12599.0
val/loss,0.58751


All Flax model weights were used when initializing MT5ForConditionalGeneration.

Some weights of MT5ForConditionalGeneration were not initialized from the Flax model and are newly initialized: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  rank_zero_warn(
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                        | Params
------------------------------------------------------
0 | model | MT5ForConditionalGeneration | 300 M 
------------------------------------------------------
300 M     Trainable params
0         Non-trainable params
300 M     Total params
1,200.707 Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

VBox(children=(Label(value='0.000 MB of 0.000 MB uploaded (0.000 MB deduped)\r'), FloatProgress(value=1.0, max…

0,1
avg_train_loss,█▃▁▁▁▁▁▁▁▁
avg_val_accuracy,▁▇▇███████
avg_val_loss,█▁▁▁▁▁▁▁▁▁
epoch,▁▁▁▁▂▂▂▂▃▃▃▃▃▃▃▃▄▄▄▄▅▅▅▅▆▆▆▆▆▆▆▆▇▇▇▇████
train/loss,█▆▆▅▃▃▂▁▂▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▃▁▁▁▁▁
trainer/global_step,▁▁▁▂▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
val/accuracy,▁▇▇███████
val/loss,█▁▁▁▁▁▁▁▁▁

0,1
avg_train_loss,0.65644
avg_val_accuracy,0.55952
avg_val_loss,0.33873
epoch,9.0
train/loss,0.20803
trainer/global_step,12599.0
val/loss,0.33873


All Flax model weights were used when initializing MT5ForConditionalGeneration.

Some weights of MT5ForConditionalGeneration were not initialized from the Flax model and are newly initialized: ['encoder.embed_tokens.weight', 'decoder.embed_tokens.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  rank_zero_warn(
GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type                        | Params
------------------------------------------------------
0 | model | MT5ForConditionalGeneration | 300 M 
------------------------------------------------------
300 M     Trainable params
0         Non-trainable params
300 M     Total params
1,200.707 Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

In [None]:
%pip install flax

Defaulting to user installation because normal site-packages is not writeable
Collecting flax
  Downloading flax-0.5.2-py3-none-any.whl (197 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m197.1/197.1 kB[0m [31m10.4 MB/s[0m eta [36m0:00:00[0m
Collecting rich~=11.1
  Downloading rich-11.2.0-py3-none-any.whl (217 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m217.3/217.3 kB[0m [31m32.8 MB/s[0m eta [36m0:00:00[0m
Collecting optax
  Downloading optax-0.1.3-py3-none-any.whl (145 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m145.1/145.1 kB[0m [31m36.3 MB/s[0m eta [36m0:00:00[0m
Collecting colorama<0.5.0,>=0.4.0
  Downloading colorama-0.4.5-py2.py3-none-any.whl (16 kB)
Collecting chex>=0.0.4
  Downloading chex-0.1.3-py3-none-any.whl (72 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m72.2/72.2 kB[0m [31m25.8 MB/s[0m eta [36m0:00:00[0m
Collecting dm-tree>=0.1.5
  Downloading dm_tree-0.1.7-cp39-cp39

In [None]:
print(torch.version.cuda)

10.2
