<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/jayralencar/pictoBERT/blob/main/Fine_tuning_PictoBERT_(colourful_semantics).ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/jayralencar/pictoBERT/blob/main/Fine_tuning_PictoBERT_(colourful_semantics).ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
</table>

# Fine-tuning PictoBERT for Pictogram Prediction Based on a Grammatical Structure

This notebook presents the procedures for adapting and fine-tuning PictoBERT to perform pictogram prediction based on a grammatical structure (cf. Section 5.2.1 in the paper).

## Check that you are using a GPU

For fine-tuning, we recommend using a GPU, which makes training considerably faster.

In [None]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Select the Runtime > "Change runtime type" menu to enable a GPU accelerator, ')
  print('and then re-execute this cell.')
else:
  print(gpu_info)

Thu Mar 24 00:30:13 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla K80           Off  | 00000000:00:04.0 Off |                    0 |
| N/A   53C    P8    33W / 149W |      0MiB / 11441MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

## Install dependencies

In [None]:
!pip install pytorch_lightning==1.2.10 transformers tokenizers

Collecting pytorch_lightning==1.2.10
  Downloading pytorch_lightning-1.2.10-py3-none-any.whl (841 kB)
     |████████████████████████████████| 841 kB 4.3 MB/s 
Collecting transformers
  Downloading transformers-4.17.0-py3-none-any.whl (3.8 MB)
     |████████████████████████████████| 3.8 MB 39.0 MB/s 
Collecting tokenizers
  Downloading tokenizers-0.11.6-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.5 MB)
     |████████████████████████████████| 6.5 MB 36.7 MB/s 
Collecting future>=0.17.1
  Downloading future-0.18.2.tar.gz (829 kB)
     |████████████████████████████████| 829 kB 52.7 MB/s 
Collecting PyYAML!=5.4.*,>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
     |████████████████████████████████| 596 kB 46.3 MB/s 
Collecting fsspec[http]>=0.8.1
  Downloading fsspec-2022.2.0-py3-none-any.whl (134 kB)
     |████████████████████████████████| 134 kB 51

## Download files

### Download PictoBERT versions

In [None]:
!wget http://jayr.clubedosgeeks.com.br/pictobert/pictobert-large-contextual.zip
!wget http://jayr.clubedosgeeks.com.br/pictobert/pictobert-large-gloss.zip

!unzip pictobert-large-contextual.zip
!unzip pictobert-large-gloss.zip

--2022-03-24 00:30:40--  http://jayr.clubedosgeeks.com.br/pictobert/pictobert-large-contextual.zip
Resolving jayr.clubedosgeeks.com.br (jayr.clubedosgeeks.com.br)... 192.185.214.132
Connecting to jayr.clubedosgeeks.com.br (jayr.clubedosgeeks.com.br)|192.185.214.132|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1180295214 (1.1G) [application/zip]
Saving to: ‘pictobert-large-contextual.zip’


2022-03-24 00:31:57 (14.8 MB/s) - ‘pictobert-large-contextual.zip’ saved [1180295214/1180295214]

--2022-03-24 00:31:57--  http://jayr.clubedosgeeks.com.br/pictobert/pictobert-large-gloss.zip
Resolving jayr.clubedosgeeks.com.br (jayr.clubedosgeeks.com.br)... 192.185.214.132
Connecting to jayr.clubedosgeeks.com.br (jayr.clubedosgeeks.com.br)|192.185.214.132|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1180231318 (1.1G) [application/zip]
Saving to: ‘pictobert-large-gloss.zip’


2022-03-24 00:33:13 (14.9 MB/s) - ‘pictobert-large-gloss.zip’ saved [1

### Download PictoBERT Tokenizer

In [None]:
!wget http://jayr.clubedosgeeks.com.br/pictobert/childes_all_new.json

--2022-03-24 00:33:41--  http://jayr.clubedosgeeks.com.br/pictobert/childes_all_new.json
Resolving jayr.clubedosgeeks.com.br (jayr.clubedosgeeks.com.br)... 192.185.214.132
Connecting to jayr.clubedosgeeks.com.br (jayr.clubedosgeeks.com.br)|192.185.214.132|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 332233 (324K) [application/json]
Saving to: ‘childes_all_new.json’


2022-03-24 00:33:43 (435 KB/s) - ‘childes_all_new.json’ saved [332233/332233]



### Download dataset

In [None]:
!wget http://jayr.clubedosgeeks.com.br/pictobert/sem_childes_uk_clean_2.txt

--2022-03-24 00:33:43--  http://jayr.clubedosgeeks.com.br/pictobert/sem_childes_uk_clean_2.txt
Resolving jayr.clubedosgeeks.com.br (jayr.clubedosgeeks.com.br)... 192.185.214.132
Connecting to jayr.clubedosgeeks.com.br (jayr.clubedosgeeks.com.br)|192.185.214.132|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 4796378 (4.6M) [text/plain]
Saving to: ‘sem_childes_uk_clean_2.txt’


2022-03-24 00:33:45 (3.03 MB/s) - ‘sem_childes_uk_clean_2.txt’ saved [4796378/4796378]



### Download ARES embeddings

In [None]:
!wget http://jayr.clubedosgeeks.com.br/pictobert/ares_1024_gloss.bin
!wget http://jayr.clubedosgeeks.com.br/pictobert/ares_1024.bin

--2022-03-24 00:33:45--  http://jayr.clubedosgeeks.com.br/pictobert/ares_1024_gloss.bin
Resolving jayr.clubedosgeeks.com.br (jayr.clubedosgeeks.com.br)... 192.185.214.132
Connecting to jayr.clubedosgeeks.com.br (jayr.clubedosgeeks.com.br)|192.185.214.132|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 852260167 (813M) [application/octet-stream]
Saving to: ‘ares_1024_gloss.bin’


2022-03-24 00:34:42 (14.4 MB/s) - ‘ares_1024_gloss.bin’ saved [852260167/852260167]

--2022-03-24 00:34:42--  http://jayr.clubedosgeeks.com.br/pictobert/ares_1024.bin
Resolving jayr.clubedosgeeks.com.br (jayr.clubedosgeeks.com.br)... 192.185.214.132
Connecting to jayr.clubedosgeeks.com.br (jayr.clubedosgeeks.com.br)|192.185.214.132|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 852260167 (813M) [application/octet-stream]
Saving to: ‘ares_1024.bin’


2022-03-24 00:35:39 (14.5 MB/s) - ‘ares_1024.bin’ saved [852260167/852260167]



## Create tokenizer

In [None]:
import re
special_tokens = ["[UNK]","[SEP]", "[CLS]", "[PAD]", "[MASK]"]
examples = open("./sem_childes_uk_clean_2.txt",'r').readlines()
examples = [l.rstrip() for l in examples]
sentences = [[j for j in re.sub(r'\s+', ' ', l).split(" ") if j not in special_tokens ] for l in examples]
len(sentences)

86692

In [None]:
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import WhitespaceSplit
from tokenizers.processors import BertProcessing

special_tokens = ["[UNK]","[SEP]", "[CLS]", "[PAD]", "[MASK]"]

sense_tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))
sense_tokenizer.add_special_tokens(special_tokens)
sense_tokenizer.pre_tokenizer = WhitespaceSplit()

sep_token = "[SEP]"
cls_token = "[CLS]"
pad_token = "[PAD]"
unk_token = "[UNK]"
sep_token_id = sense_tokenizer.token_to_id(str(sep_token))
cls_token_id = sense_tokenizer.token_to_id(str(cls_token))
pad_token_id = sense_tokenizer.token_to_id(str(pad_token))
unk_token_id = sense_tokenizer.token_to_id(str(unk_token))

sense_tokenizer.post_processor = BertProcessing(
                (str(sep_token), sep_token_id), (str(cls_token), cls_token_id)
            )


In [None]:
from tokenizers.trainers import WordLevelTrainer
g = WordLevelTrainer(special_tokens=["[UNK]"])
sense_tokenizer.train_from_iterator(sentences, trainer=g)
print("Vocab size: ", sense_tokenizer.get_vocab_size())

Vocab size:  4960
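
Conceptually, the `WordLevel` tokenizer trained above is a lookup table from whole whitespace-separated tokens to integer ids, with `[UNK]` as the fallback. The stdlib-only sketch below mimics that idea on a toy corpus. The tokens are made up, and the ordering rule (special tokens first, then descending frequency, ties broken alphabetically) is an assumption for illustration, not necessarily the exact id assignment of `WordLevelTrainer`.

```python
import re
from collections import Counter

SPECIAL_TOKENS = ["[UNK]", "[SEP]", "[CLS]", "[PAD]", "[MASK]"]

def build_word_level_vocab(lines):
    """Assign an id to every whitespace-separated token; special tokens come first."""
    counts = Counter()
    for line in lines:
        # Same normalisation as the notebook: collapse whitespace, drop special tokens
        tokens = [t for t in re.sub(r"\s+", " ", line.strip()).split(" ")
                  if t and t not in SPECIAL_TOKENS]
        counts.update(tokens)
    vocab = {tok: i for i, tok in enumerate(SPECIAL_TOKENS)}
    # Assumed ordering for the remaining tokens: most frequent first, ties alphabetical
    for tok, _ in sorted(counts.items(), key=lambda kv: (-kv[1], kv[0])):
        vocab[tok] = len(vocab)
    return vocab

def encode(vocab, line):
    """Map each token to its id, falling back to [UNK] for unseen tokens."""
    return [vocab.get(t, vocab["[UNK]"]) for t in line.split()]

# Toy corpus with made-up tokens
corpus = ["i  want [MASK] ball", "i eat"]
vocab = build_word_level_vocab(corpus)
print(encode(vocab, "i want cake"))  # [5, 8, 0]  ("cake" is unseen -> [UNK])
```

Unlike BERT's WordPiece, a word-level model never splits a token: anything absent from the training corpus maps to `[UNK]` as a whole.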


In [None]:
sense_tokenizer.save("./cs_tokenizer.json")

## Dataset preparation

In [None]:
TEST_SIZE = 0.2
from sklearn.model_selection import train_test_split
train_idx, val_idx = train_test_split(list(range(len(examples))), test_size=TEST_SIZE, random_state=8)
test_idx, val_idx = train_test_split(val_idx, test_size=0.5, random_state=8)

import numpy as np
train_examples = np.array(examples).take(train_idx)
val_examples = np.array(examples).take(val_idx)
test_examples = np.array(examples).take(test_idx)
len(train_examples),len(val_examples), len(test_examples)

(69353, 8670, 8669)
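
The cell above produces an 80/10/10 split in two steps. Assuming scikit-learn's rounding rule for a float `test_size` (the held-out share is rounded up and the remainder goes to the first output), the printed sizes can be reproduced by hand. Because the second call is unpacked as `test_idx, val_idx`, the validation set receives the rounded-up half.

```python
import math

def split_sizes(n_samples, test_size):
    """(n_train, n_test) for a float test_size, rounding the held-out part up."""
    n_test = math.ceil(test_size * n_samples)
    return n_samples - n_test, n_test

n_train, n_holdout = split_sizes(86692, 0.2)  # first call: train vs. held-out 20%
n_test, n_val = split_sizes(n_holdout, 0.5)   # second call is unpacked as (test_idx, val_idx)
print(n_train, n_val, n_test)  # 69353 8670 8669
```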

In [None]:
TOKENIZER_PATH = "./cs_tokenizer.json" # you can change this path to use your custom tokenizer

from transformers import PreTrainedTokenizerFast

loaded_tokenizer = PreTrainedTokenizerFast(tokenizer_file=TOKENIZER_PATH)
loaded_tokenizer.pad_token = "[PAD]"
loaded_tokenizer.sep_token = "[SEP]"
loaded_tokenizer.mask_token = "[MASK]"
loaded_tokenizer.cls_token = "[CLS]"
loaded_tokenizer.unk_token = "[UNK]"

In [None]:
max_len = 9

def tokenize_function(tokenizer,examples):
    # Remove empty lines
    examples = [line for line in examples if len(line) > 0 and not line.isspace()]
    bert = tokenizer(
        examples,
        padding="max_length",
        max_length=max_len,
        return_special_tokens_mask=True,
        truncation=True
    )
    # Also flag [PAD] positions as special so the MLM collator never masks them
    for i, data in enumerate(bert['input_ids']):
      a = np.array(data)
      special_tokens_mask = np.array(bert['special_tokens_mask'][i])
      special_tokens_mask[a == tokenizer.pad_token_id] = 1
      bert['special_tokens_mask'][i] = special_tokens_mask
    return bert
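
With `max_len = 9`, every example becomes `[CLS]`, at most seven sense tokens, `[SEP]`, and then `[PAD]` up to length 9. The sketch below spells out that layout and the two masks `tokenize_function` returns, assuming the usual behaviour of a BERT-style post-processor with right padding and truncation. The ids are hypothetical; the real ones come from `cs_tokenizer.json`.

```python
# Hypothetical ids that mirror the order of `special_tokens` in the tokenizer cell
CLS_ID, SEP_ID, PAD_ID = 2, 1, 3
MAX_LEN = 9

def encode_fixed_length(token_ids, max_len=MAX_LEN):
    """Truncate, wrap in [CLS]/[SEP], right-pad, and build both masks."""
    body = token_ids[: max_len - 2]                   # leave room for [CLS] and [SEP]
    input_ids = [CLS_ID] + body + [SEP_ID]
    n_pad = max_len - len(input_ids)
    attention_mask = [1] * len(input_ids) + [0] * n_pad
    # [CLS], [SEP] and every [PAD] are special, so the MLM collator never masks them
    special_tokens_mask = [1] + [0] * len(body) + [1] * (1 + n_pad)
    input_ids = input_ids + [PAD_ID] * n_pad
    return input_ids, attention_mask, special_tokens_mask

ids, attn, special = encode_fixed_length([10, 11, 12])
print(ids)      # [2, 10, 11, 12, 1, 3, 3, 3, 3]
print(special)  # [1, 0, 0, 0, 1, 1, 1, 1, 1]
```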

In [None]:
train_tokenized_examples = tokenize_function(loaded_tokenizer,train_examples)
val_tokenized_examples = tokenize_function(loaded_tokenizer,val_examples)
test_tokenized_examples = tokenize_function(loaded_tokenizer,test_examples)

In [None]:
from torch import tensor
def make_dict(examples):
  return {
      "input_ids": examples.input_ids,
      "attention_mask":examples.attention_mask,
      "special_tokens_mask":examples.special_tokens_mask,
  }

In [None]:
import pickle

TRAIN_DATA_PATH = "./CS_new_train_data.pt"
TEST_DATA_PATH = "./CS_new_test_data.pt"
VAL_DATA_PATH = "./CS_new_val_data.pt"

# Write each split to its own path (validation -> VAL_DATA_PATH, test -> TEST_DATA_PATH)
for split, path in [(train_tokenized_examples, TRAIN_DATA_PATH),
                    (val_tokenized_examples, VAL_DATA_PATH),
                    (test_tokenized_examples, TEST_DATA_PATH)]:
  with open(path, 'wb') as f:
    pickle.dump(make_dict(split), f)

## PictoBERT adaptation

### Load PictoBERT

In [None]:
from transformers import BertForMaskedLM
pictobert_version = "contextual" #@param ["contextual","gloss"]

if pictobert_version == "contextual":
  pictobert = BertForMaskedLM.from_pretrained("./pictobert")
else:
  pictobert = BertForMaskedLM.from_pretrained("./pictobert-gloss")

### Load BERT
We also need to load BERT, because its embedding matrix is used to compute the input embeddings of tokens that are in neither the PictoBERT nor the ARES vocabulary.

In [None]:
import torch
from transformers import BertTokenizer

pretrained_w = 'bert-large-uncased'
tokenizer_bert = BertTokenizer.from_pretrained(pretrained_w)
bert = BertForMaskedLM.from_pretrained(pretrained_w)


Some weights of the model checkpoint at bert-large-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


### Load PictoBERT tokenizer

In [None]:
TOKENIZER_PATH = "./childes_all_new.json" # you can change this path to use your custom tokenizer

from transformers import PreTrainedTokenizerFast

pictobert_tokenizer = PreTrainedTokenizerFast(tokenizer_file=TOKENIZER_PATH)
pictobert_tokenizer.pad_token = "[PAD]"
pictobert_tokenizer.sep_token = "[SEP]"
pictobert_tokenizer.mask_token = "[MASK]"
pictobert_tokenizer.cls_token = "[CLS]"
pictobert_tokenizer.unk_token = "[UNK]"

### Update embeddings layer

In [None]:
new_vocab = loaded_tokenizer.get_vocab()
pictobert_vocab = pictobert_tokenizer.get_vocab()

In [None]:
in_pictobert = []
not_in = []
for w,idx_new in new_vocab.items():
  idx_old = pictobert_vocab.get(w, -1)
  if idx_old >= 0:
    in_pictobert.append(w)
  else:
    not_in.append(w)
  
print("New vocab size:",len(new_vocab))
print("PictoBERT vocab size:", len(pictobert_vocab))
print("Common:",len(in_pictobert))
print("New tokens:",len(not_in))

New vocab size: 4959
PictoBERT vocab size: 13583
Common: 4273
New tokens: 686


In [None]:
bert_embeddings = bert.get_input_embeddings()
pictobert_embeddings = pictobert.get_input_embeddings()

In [None]:
import torch
from gensim.models import KeyedVectors

if pictobert_version == "contextual":
  ares = KeyedVectors.load_word2vec_format("/content/ares_1024.bin", binary=True)
else:
  ares = KeyedVectors.load_word2vec_format("/content/ares_1024_gloss.bin", binary=True)
ares_contextual_mean = torch.tensor(ares.vectors).mean(0)

In [None]:
bert_vocab = tokenizer_bert.get_vocab()

bert_wgts = bert.get_input_embeddings().weight.clone().detach()
pictobert_wgts = pictobert.get_input_embeddings().weight.clone().detach()

new_vocab_size = len(loaded_tokenizer.vocab)

In [None]:
new_wgts = pictobert_wgts.new_zeros(new_vocab_size,pictobert_wgts.size(1))

# Fill each row of the new embedding matrix, in order of preference, with:
# 1) the existing PictoBERT embedding, 2) the ARES sense vector, or
# 3) the mean of BERT's word-piece embeddings for the token text
with torch.no_grad():
  for w,idx_new in new_vocab.items():
    idx_pictobert = pictobert_vocab.get(w,-1)
    if idx_pictobert >= 0:
      new_wgts[idx_new] = pictobert_wgts[idx_pictobert]
    elif w in ares:
      new_wgts[idx_new] = torch.tensor(ares[w])
    else:
      w_ = " ".join(w.split("_"))  # underscores -> spaces before BERT tokenization
      tokenized = tokenizer_bert(w_,add_special_tokens=False,return_tensors='pt')
      v_ = bert_embeddings(tokenized.input_ids[0]).mean(0)
      new_wgts[idx_new] = v_
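
The cell above initializes each row of the new embedding matrix with a three-level fallback: reuse the PictoBERT row if the token exists there, otherwise take the ARES sense vector, otherwise average BERT's word-piece embeddings for the token text with underscores replaced by spaces. The sketch below reproduces that priority order with plain dictionaries and tiny 2-dimensional vectors. All names and values are toy stand-ins, and `bert_pieces` is a hypothetical replacement for calling `tokenizer_bert` and `bert_embeddings`.

```python
def init_row(token, pictobert_vecs, ares_vecs, bert_piece_vecs, bert_pieces):
    """Return (vector, source) using the notebook's priority order."""
    if token in pictobert_vecs:                 # 1) token already known to PictoBERT
        return pictobert_vecs[token], "pictobert"
    if token in ares_vecs:                      # 2) sense vector available in ARES
        return ares_vecs[token], "ares"
    # 3) mean of BERT word-piece vectors of the token text, underscores -> spaces
    vecs = [bert_piece_vecs[p] for p in bert_pieces(token.replace("_", " "))]
    return [sum(col) / len(vecs) for col in zip(*vecs)], "bert"

# Toy 2-dimensional stand-ins for the real 1024-dimensional matrices
pictobert_vecs = {"dog": [1.0, 0.0]}
ares_vecs = {"cat": [0.0, 1.0]}
bert_piece_vecs = {"ice": [2.0, 4.0], "cream": [4.0, 0.0]}

def bert_pieces(text):
    """Hypothetical stand-in for tokenizer_bert: split on spaces only."""
    return text.split()

for tok in ["dog", "cat", "ice_cream"]:
    print(tok, init_row(tok, pictobert_vecs, ares_vecs, bert_piece_vecs, bert_pieces))
```

Averaging word pieces is a cheap way to give an unseen multi-word token a vector that lives in a related space; the fine-tuning run is then expected to refine it.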

In [None]:
from torch import nn

# Build the new input-embedding layer directly from the weights assembled above
new_wte = nn.Embedding.from_pretrained(new_wgts, freeze=False)

# Resize the model to the new vocabulary size, swap in the new input embeddings,
# and re-tie the MLM output layer to them (BERT shares these weights)
pictobert.resize_token_embeddings(len(loaded_tokenizer))
pictobert.set_input_embeddings(new_wte)
pictobert.tie_weights()

In [None]:
MODEL_NAME = "./pictobert-CS-{0}".format(pictobert_version)
pictobert.save_pretrained(MODEL_NAME)

## Train Model

### Define constants

In [None]:
TOKENIZER_PATH = "./cs_tokenizer.json"
LOGS_PATH = "./logs"
CHECKPOINTS_PATH = "./checkpoints"

TRAIN_DATASET_PATH = "./CS_new_train_data.pt"
VAL_DATASET_PATH = "./CS_new_val_data.pt"
TEST_DATASET_PATH = "./CS_new_test_data.pt"

MAX_EPOCHS = 10
WARMUP_STEPS = int(MAX_EPOCHS * 0.15)
BATCH_SIZE = 32
NUM_WORKERS = 4
GPUS = 1
LEARNING_RATE = 1e-06
ACCUMULATE_GRAD_BATCHES = 4
LOGGER_VERSION = '1e06'
LOGGER_INFO = "first_run"
FREEZE_TO = None
MLM_PROBABILITY= 0.15
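
`WARMUP_STEPS = int(MAX_EPOCHS * 0.15)` evaluates to 1. Because `configure_optimizers` passes `num_training_steps=MAX_EPOCHS` and sets no `interval`, the scheduler is stepped once per epoch (PyTorch Lightning's default, as far as we can tell). The sketch below re-implements, as we understand it, the multiplier behind `transformers.get_polynomial_decay_schedule_with_warmup` with its default `power=1.0` (linear decay); it illustrates the schedule and is not the library code. If this reading is right, the first epoch runs at a learning rate of 0, since the warmup factor at step 0 is 0/1.

```python
def poly_decay_lr(step, lr_init, warmup_steps, total_steps, lr_end=1e-9, power=1.0):
    """Learning rate at `step`: linear warmup, then polynomial decay down to lr_end."""
    if step < warmup_steps:
        return lr_init * step / max(1, warmup_steps)
    if step > total_steps:
        return lr_end
    decay_steps = total_steps - warmup_steps
    pct_remaining = 1 - (step - warmup_steps) / decay_steps
    return (lr_init - lr_end) * pct_remaining ** power + lr_end

MAX_EPOCHS = 10
WARMUP_STEPS = int(MAX_EPOCHS * 0.15)  # -> 1
LEARNING_RATE = 1e-06

# One scheduler step per epoch, so "step" and "epoch" coincide here
for epoch in range(MAX_EPOCHS):
    print(epoch, f"{poly_decay_lr(epoch, LEARNING_RATE, WARMUP_STEPS, MAX_EPOCHS):.3e}")
```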

### Load Data

In [None]:
from torch.utils.data import Dataset, Subset
from torch import tensor
from sklearn.model_selection import train_test_split
import pickle

class MyDataset(Dataset):
  def __init__(self, examples):
    
    self.input_ids = examples['input_ids']
    self.attention_mask = examples['attention_mask']
    self.special_tokens_mask = examples['special_tokens_mask']
    self.labels = None
    if 'labels' in examples:
      self.labels = examples['labels']
  
  def __len__(self):
    return len(self.input_ids)
  
  def __getitem__(self, idx):
    input_ids = tensor(self.input_ids[idx])
    attention_mask = tensor(self.attention_mask[idx])
    special_tokens_mask = tensor(self.special_tokens_mask[idx])

    out_dict = {
      "input_ids":input_ids, 
      "attention_mask":attention_mask, 
      "special_tokens_mask":special_tokens_mask
    }

    if self.labels is not None:
      out_dict['labels'] = self.labels[idx]

    return out_dict


tds = pickle.load(open(TRAIN_DATASET_PATH,'rb'))
train_dataset = MyDataset(tds)

vds = pickle.load(open(VAL_DATASET_PATH,'rb'))
val_dataset = MyDataset(vds)

tsds = pickle.load(open(TEST_DATASET_PATH,'rb'))
test_dataset = MyDataset(tsds)


In [None]:
from transformers import DataCollatorForLanguageModeling
data_collator = DataCollatorForLanguageModeling(tokenizer=loaded_tokenizer, mlm_probability=MLM_PROBABILITY)
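
`DataCollatorForLanguageModeling` builds the MLM targets on the fly. Per the Hugging Face documentation, it selects each non-special position with probability `mlm_probability` (0.15 here), replaces 80% of the selected tokens with `[MASK]`, 10% with a random token, and leaves 10% unchanged; every unselected position gets the label -100 so the loss ignores it. The stdlib sketch below applies that recipe to a single sequence; the ids and vocabulary size are hypothetical.

```python
import random

def mlm_mask(input_ids, special_tokens_mask, mask_id, vocab_size,
             mlm_probability=0.15, rng=random):
    """Return (corrupted_ids, labels) following the 80/10/10 BERT masking recipe."""
    corrupted, labels = list(input_ids), []
    for i, (tok, special) in enumerate(zip(input_ids, special_tokens_mask)):
        if special or rng.random() >= mlm_probability:
            labels.append(-100)                       # not selected: ignored by the loss
            continue
        labels.append(tok)                            # selected: predict the original id
        r = rng.random()
        if r < 0.8:
            corrupted[i] = mask_id                    # 80%: replace with [MASK]
        elif r < 0.9:
            corrupted[i] = rng.randrange(vocab_size)  # 10%: replace with a random token
        # remaining 10%: keep the original token
    return corrupted, labels

ids = [2, 10, 11, 12, 1, 3, 3, 3, 3]
special = [1, 0, 0, 0, 1, 1, 1, 1, 1]
corrupted, labels = mlm_mask(ids, special, mask_id=4, vocab_size=4960, rng=random.Random(0))
```

This also explains why the notebook's accuracy metrics, which only look at positions whose input id equals `mask_token_id`, ignore the selected positions that were randomized or left unchanged.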

## Model

In [None]:
from argparse import ArgumentParser
import math
import torch
import torch.nn as nn
from transformers import get_polynomial_decay_schedule_with_warmup
from transformers import AdamW
import pytorch_lightning as pl
from sklearn.metrics import accuracy_score
from pytorch_lightning.callbacks import ModelCheckpoint
from scipy import stats
import numpy as np
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

percentiles = [
  {
    "percentil":'1%',
    "z": stats.norm.ppf(1-0.01),
  },
  {
    "percentil":'5%',
    "z": stats.norm.ppf(1-0.05),
  },
  {
    "percentil":'10%',
    "z": stats.norm.ppf(1-0.1),
  },
  {
    "percentil":'15%',
    "z": stats.norm.ppf(1-0.15),
  },
  {
    "percentil":'20%',
    "z": stats.norm.ppf(1-0.2),
  },
]

firsts = [6, 12, 24, 32, 40]

from transformers import BertLMHeadModel
from transformers import BertForMaskedLM

class LitBertClassifier(pl.LightningModule):
    def __init__(self, pretrained_model_name='bert-large-uncased'):
        super().__init__()
        self.save_hyperparameters()
        self.batch_size = 128
        self.lr = LEARNING_RATE
        self.train_dataset = train_dataset
        self.val_dataset = val_dataset
        self.test_dataset = test_dataset
        self.bert = BertForMaskedLM.from_pretrained(pretrained_model_name)
      
    
    def freeze_to(self, layers):
      for param in self.bert.bert.encoder.layer[:layers].parameters():
        param.requires_grad = False


    def forward(self, input_ids, attention_mask, labels=None):
        # When labels is None, BertForMaskedLM returns the logits without a loss
        return self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
        )

    def training_step(self, batch, batch_idx):
        outputs = self._shared_step(batch, batch_idx)
        loss = outputs[0]

        self.log("train_loss", loss, on_epoch=True, prog_bar=True,)

        return loss

    def train_dataloader(self):
      return DataLoader(
          self.train_dataset,
          batch_size=self.batch_size,
          num_workers=NUM_WORKERS,
          pin_memory=True,
          collate_fn=data_collator,
          drop_last = True,
          shuffle=True
      )
    
    def test_dataloader(self):
      # The collator is required here as well: it builds the masked inputs and
      # the `labels` that _shared_step and test_step read
      return  DataLoader(
          self.test_dataset,
          batch_size=self.batch_size,
          num_workers=NUM_WORKERS,
          pin_memory=True,
          collate_fn=data_collator,
          drop_last = True,
      )
    
    def val_dataloader(self):
      return  DataLoader(
          self.val_dataset,
          batch_size=self.batch_size,
          num_workers=NUM_WORKERS,
          pin_memory=True,
          collate_fn=data_collator,
          drop_last = True,
      )

    def get_accuracy(self, batch, results):
        # Accuracy of the argmax prediction at [MASK] positions (helper, unused in this notebook)
        is_masked = batch["input_ids"] == loaded_tokenizer.mask_token_id
        if is_masked.sum() == 0:
            return None
        y_true = batch["labels"][is_masked]
        y_pred = results.argmax(-1)[is_masked]
        return (y_pred == y_true).float().mean()

    def validation_step(self, batch, batch_idx):
        with torch.no_grad():
          result = self._shared_step(batch, batch_idx)
          loss = result[0].detach()
          
          predictions = result[1]
          labels = batch['labels']
          masked = batch['input_ids']
          n = masked.detach().cpu().numpy()
          predicted = predictions.detach().cpu().numpy()[n == loaded_tokenizer.mask_token_id]
          accuracy = accuracy_score(labels[n == loaded_tokenizer.mask_token_id].detach().cpu(), np.argmax(predicted, axis=1))
          self.log("val_acc", accuracy, on_epoch=True, prog_bar=True,)


          return {
              "val_loss":loss
          }
    
    def validation_epoch_end(self, outputs):
        val_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        self.log("val_loss", val_loss, on_epoch=True, prog_bar=True,)
    
    
    def test_step(self, batch, batch_idx):
        with torch.no_grad():
          result = self._shared_step(batch, batch_idx)
          loss = result[0].detach()
          perplexity = torch.exp(loss)
          self.log("test_ppl", perplexity, on_epoch=True, prog_bar=True,)
          self.log("test_loss", loss, on_epoch=True, prog_bar=True,)

          # Keep only the [MASK] positions: their probability distributions and gold ids
          predictions = F.softmax(result[1], dim=-1)
          is_masked = batch['input_ids'].detach().cpu().numpy() == loaded_tokenizer.mask_token_id
          predicted = predictions.detach().cpu().numpy()[is_masked]
          gold = batch['labels'].detach().cpu().numpy()[is_masked]

          accuracy = accuracy_score(gold, np.argmax(predicted, axis=1))
          self.log("test_acc", accuracy, on_epoch=True, prog_bar=True,)

          # Percentile metrics: hit if the gold id's probability exceeds mean + z * std
          for percentil in percentiles:
            z = percentil['z']
            count = 0
            for probs, label in zip(predicted, gold):
              positive = probs[probs > 0]
              if probs[label] > (z * positive.std()) + positive.mean():
                count += 1
            isin = count/predicted.shape[0]
            self.log("test_"+percentil['percentil'], isin, on_epoch=True, prog_bar=True,)

          # Top-k metrics: hit if the gold id is among the k most probable ids
          ids = np.argsort(-1*predicted, axis=1)
          for first in firsts:
            count = sum(1 for top, label in zip(ids, gold) if label in top[:first])
            isin = count/predicted.shape[0]
            self.log("test_"+str(first), isin, on_epoch=True, prog_bar=True,)

          return {
              "test_ppl":perplexity,
              "test_loss":loss,
              "test_count":count,
          }
    

    def _shared_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]

        outputs = self.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )

        return outputs

    def configure_optimizers(self):
      optimizer = AdamW(self.parameters(), lr=self.lr,betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01)
      scheduler = {
          'scheduler': get_polynomial_decay_schedule_with_warmup(optimizer, num_warmup_steps=WARMUP_STEPS,num_training_steps=MAX_EPOCHS,lr_end=1e-09),
          'name': 'lr'
      }
      return [optimizer],[scheduler]
    
    def backward(self, loss, optimizer, idx):
        loss.backward()
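
`test_step` reports two families of hit rates over the masked positions. For the percentile metrics, `z = stats.norm.ppf(1 - p)`, and a position counts as a hit when the gold pictogram's probability exceeds `mean + z * std` of that position's output distribution. For the top-k metrics (`firsts`), it counts as a hit when the gold id is among the k most probable ids. The stdlib sketch below applies both rules to one toy distribution; `statistics.NormalDist().inv_cdf` stands in for SciPy's `norm.ppf`, and `pstdev` matches NumPy's default population standard deviation (`ddof=0`).

```python
from statistics import NormalDist, fmean, pstdev

def percentile_hit(probs, gold, p):
    """Hit if probs[gold] lies above mean + z*std, with z the (1 - p) normal quantile."""
    z = NormalDist().inv_cdf(1 - p)
    positive = [x for x in probs if x > 0]  # softmax outputs are all positive anyway
    return probs[gold] > fmean(positive) + z * pstdev(positive)

def top_k_hit(probs, gold, k):
    """Hit if gold is among the k highest-probability ids."""
    ranked = sorted(range(len(probs)), key=lambda i: -probs[i])
    return gold in ranked[:k]

probs = [0.5, 0.2, 0.1, 0.1, 0.05, 0.05]  # toy distribution over 6 pictograms
print(percentile_hit(probs, gold=0, p=0.05))  # True
print(percentile_hit(probs, gold=0, p=0.01))  # False
print(top_k_hit(probs, gold=1, k=2))          # True
```

The per-batch metric is then the fraction of masked positions that are hits, which is what `count/predicted.shape[0]` computes.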
    


### Logger

In [None]:
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import LearningRateMonitor

tb_logger = pl_loggers.TensorBoardLogger(LOGS_PATH,name=LOGGER_INFO, version=LOGGER_VERSION)
lr_monitor = LearningRateMonitor(logging_interval='epoch')

### Checkpointing

In [None]:
checkpoint_callback = ModelCheckpoint(
    dirpath=CHECKPOINTS_PATH,
    filename='bert-large-{epoch:02d}-{val_loss:.2f}',
    monitor='val_loss',  # without a monitored metric, mode='min' has no effect
    mode='min',
)


In [None]:
trainer = pl.Trainer(
      accelerator='ddp',
      max_epochs=MAX_EPOCHS,
      logger=tb_logger,
      gpus=GPUS,
      callbacks=[checkpoint_callback, lr_monitor],
      precision=16,
      auto_scale_batch_size="binsearch"
  )

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Using native 16bit precision.
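
`auto_scale_batch_size="binsearch"` asks `trainer.tune` to adjust the module's `batch_size` attribute (which is why `LitBertClassifier` defines one): the batch size grows while training steps still fit in memory, and the search then bisects between the last size that fit and the first that failed. The sketch below illustrates only that double-then-bisect strategy; it is not Lightning's implementation, its starting value `init` is a parameter of the sketch, and `fits` is a hypothetical predicate standing in for "a few training steps ran without a CUDA out-of-memory error".

```python
def binsearch_batch_size(fits, init=2, max_trials=25):
    """Largest batch size accepted by `fits`, found by doubling and then bisecting."""
    low, high = 0, None  # largest size known to fit, smallest size known to fail
    size = init
    for _ in range(max_trials):
        if fits(size):
            low = size
            if high is None:
                size *= 2                  # no failure seen yet: keep doubling
                continue
        else:
            high = size
        if high - low <= 1:                # the boundary has been located
            break
        size = (low + high) // 2           # bisect between success and failure
    return low

# Hypothetical memory limit: batches of up to 300 examples fit on the GPU
print(binsearch_batch_size(lambda n: n <= 300))
```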


In [None]:
to_train = LitBertClassifier(MODEL_NAME)

In [None]:
trainer.tune(to_train)

In [None]:
trainer.fit(to_train)