# PictoBERT: Transformers for Pictogram Prediction (ARASAAC fine-tuning)

This notebook presents the procedures for fine-tuning PictoBERT to make predictions based on the vocabulary from [ARASAAC](https://arasaac.org/) portal. This notebook refers to section 5.2.2 of the paper.

## Verify if you are using a GPU

For fine-tuning, we suggest using a GPU, which can allow you to train fast.

In [1]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Select the Runtime > "Change runtime type" menu to enable a GPU accelerator, ')
  print('and then re-execute this cell.')
else:
  print(gpu_info)

Thu Mar 24 20:36:26 2022       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.32.03    Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla K80           Off  | 00000000:00:04.0 Off |                    0 |
| N/A   37C    P8    28W / 149W |      0MiB / 11441MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

## Install dependencies

In [2]:
!pip install transformers tokenizers pytorch_lightning==1.2.10

Collecting transformers
  Downloading transformers-4.17.0-py3-none-any.whl (3.8 MB)
[K     |████████████████████████████████| 3.8 MB 5.2 MB/s 
[?25hCollecting tokenizers
  Downloading tokenizers-0.11.6-cp37-cp37m-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (6.5 MB)
[K     |████████████████████████████████| 6.5 MB 26.8 MB/s 
[?25hCollecting pytorch_lightning==1.2.10
  Downloading pytorch_lightning-1.2.10-py3-none-any.whl (841 kB)
[K     |████████████████████████████████| 841 kB 24.6 MB/s 
[?25hCollecting future>=0.17.1
  Downloading future-0.18.2.tar.gz (829 kB)
[K     |████████████████████████████████| 829 kB 44.1 MB/s 
Collecting torchmetrics==0.2.0
  Downloading torchmetrics-0.2.0-py3-none-any.whl (176 kB)
[K     |████████████████████████████████| 176 kB 43.9 MB/s 
[?25hCollecting PyYAML!=5.4.*,>=5.1
  Downloading PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl (596 kB)
[K     |████████████████████████████████| 

## Download data

### Download ARASAAC all pictograms

You can download ARASAAC all pictograms by using the API provided by the portal: https://arasaac.org/developers/api

Or a dump from fev 2, 2022: http://jayr.clubedosgeeks.com.br/pictobert/ARASAAC_All_pictograms.json 

In [3]:
!wget http://jayr.clubedosgeeks.com.br/pictobert/ARASAAC_All_pictograms.json

--2022-03-24 20:36:59--  http://jayr.clubedosgeeks.com.br/pictobert/ARASAAC_All_pictograms.json
Resolving jayr.clubedosgeeks.com.br (jayr.clubedosgeeks.com.br)... 192.185.214.132
Connecting to jayr.clubedosgeeks.com.br (jayr.clubedosgeeks.com.br)|192.185.214.132|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 9402637 (9.0M) [application/json]
Saving to: ‘ARASAAC_All_pictograms.json’


2022-03-24 20:36:59 (25.9 MB/s) - ‘ARASAAC_All_pictograms.json’ saved [9402637/9402637]



### Download pictogram to sense mappings

In [4]:
!wget http://jayr.clubedosgeeks.com.br/pictobert/arasaac_mapping.csv

--2022-03-24 20:36:59--  http://jayr.clubedosgeeks.com.br/pictobert/arasaac_mapping.csv
Resolving jayr.clubedosgeeks.com.br (jayr.clubedosgeeks.com.br)... 192.185.214.132
Connecting to jayr.clubedosgeeks.com.br (jayr.clubedosgeeks.com.br)|192.185.214.132|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 489601 (478K) [text/csv]
Saving to: ‘arasaac_mapping.csv’


2022-03-24 20:37:00 (2.54 MB/s) - ‘arasaac_mapping.csv’ saved [489601/489601]



### Download SemCHILDES

In [5]:
!wget http://jayr.clubedosgeeks.com.br/pictobert/all_mt_2.txt

--2022-03-24 20:37:00--  http://jayr.clubedosgeeks.com.br/pictobert/all_mt_2.txt
Resolving jayr.clubedosgeeks.com.br (jayr.clubedosgeeks.com.br)... 192.185.214.132
Connecting to jayr.clubedosgeeks.com.br (jayr.clubedosgeeks.com.br)|192.185.214.132|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 52489765 (50M) [text/plain]
Saving to: ‘all_mt_2.txt’


2022-03-24 20:37:01 (39.8 MB/s) - ‘all_mt_2.txt’ saved [52489765/52489765]



### Download corpus already adapted for ARASAAC

In [6]:
!wget http://jayr.clubedosgeeks.com.br/pictobert/corpus_arasaac.txt

--2022-03-24 20:37:01--  http://jayr.clubedosgeeks.com.br/pictobert/corpus_arasaac.txt
Resolving jayr.clubedosgeeks.com.br (jayr.clubedosgeeks.com.br)... 192.185.214.132
Connecting to jayr.clubedosgeeks.com.br (jayr.clubedosgeeks.com.br)|192.185.214.132|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 2834625 (2.7M) [text/plain]
Saving to: ‘corpus_arasaac.txt’


2022-03-24 20:37:02 (10.3 MB/s) - ‘corpus_arasaac.txt’ saved [2834625/2834625]



### Download already trained tokenizer

In [7]:
!wget http://jayr.clubedosgeeks.com.br/pictobert/tokenizer_arasaac.txt

--2022-03-24 20:37:02--  http://jayr.clubedosgeeks.com.br/pictobert/tokenizer_arasaac.txt
Resolving jayr.clubedosgeeks.com.br (jayr.clubedosgeeks.com.br)... 192.185.214.132
Connecting to jayr.clubedosgeeks.com.br (jayr.clubedosgeeks.com.br)|192.185.214.132|:80... connected.
HTTP request sent, awaiting response... 404 Not Found
2022-03-24 20:37:03 ERROR 404: Not Found.



### Download PictoBERT versions

In [8]:
!wget http://jayr.clubedosgeeks.com.br/pictobert/pictobert-large-contextual.zip
!wget http://jayr.clubedosgeeks.com.br/pictobert/pictobert-large-gloss.zip

!unzip pictobert-large-contextual.zip
!unzip pictobert-large-gloss.zip

--2022-03-24 20:37:03--  http://jayr.clubedosgeeks.com.br/pictobert/pictobert-large-contextual.zip
Resolving jayr.clubedosgeeks.com.br (jayr.clubedosgeeks.com.br)... 192.185.214.132
Connecting to jayr.clubedosgeeks.com.br (jayr.clubedosgeeks.com.br)|192.185.214.132|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1180295214 (1.1G) [application/zip]
Saving to: ‘pictobert-large-contextual.zip’


2022-03-24 20:37:19 (71.9 MB/s) - ‘pictobert-large-contextual.zip’ saved [1180295214/1180295214]

--2022-03-24 20:37:19--  http://jayr.clubedosgeeks.com.br/pictobert/pictobert-large-gloss.zip
Resolving jayr.clubedosgeeks.com.br (jayr.clubedosgeeks.com.br)... 192.185.214.132
Connecting to jayr.clubedosgeeks.com.br (jayr.clubedosgeeks.com.br)|192.185.214.132|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1180231318 (1.1G) [application/zip]
Saving to: ‘pictobert-large-gloss.zip’


2022-03-24 20:37:34 (75.9 MB/s) - ‘pictobert-large-gloss.zip’ saved [1

### Download PictoBERT tokenizer

In [9]:
!wget http://jayr.clubedosgeeks.com.br/pictobert/childes_all_new.json

--2022-03-24 20:38:02--  http://jayr.clubedosgeeks.com.br/pictobert/childes_all_new.json
Resolving jayr.clubedosgeeks.com.br (jayr.clubedosgeeks.com.br)... 192.185.214.132
Connecting to jayr.clubedosgeeks.com.br (jayr.clubedosgeeks.com.br)|192.185.214.132|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 332233 (324K) [application/json]
Saving to: ‘childes_all_new.json’


2022-03-24 20:38:02 (2.47 MB/s) - ‘childes_all_new.json’ saved [332233/332233]



### Download ARES embeddings

In [10]:
!wget http://jayr.clubedosgeeks.com.br/pictobert/ares_1024_gloss.bin
!wget http://jayr.clubedosgeeks.com.br/pictobert/ares_1024.bin

--2022-03-24 20:38:02--  http://jayr.clubedosgeeks.com.br/pictobert/ares_1024_gloss.bin
Resolving jayr.clubedosgeeks.com.br (jayr.clubedosgeeks.com.br)... 192.185.214.132
Connecting to jayr.clubedosgeeks.com.br (jayr.clubedosgeeks.com.br)|192.185.214.132|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 852260167 (813M) [application/octet-stream]
Saving to: ‘ares_1024_gloss.bin’


2022-03-24 20:38:12 (82.0 MB/s) - ‘ares_1024_gloss.bin’ saved [852260167/852260167]

--2022-03-24 20:38:12--  http://jayr.clubedosgeeks.com.br/pictobert/ares_1024.bin
Resolving jayr.clubedosgeeks.com.br (jayr.clubedosgeeks.com.br)... 192.185.214.132
Connecting to jayr.clubedosgeeks.com.br (jayr.clubedosgeeks.com.br)|192.185.214.132|:80... connected.
HTTP request sent, awaiting response... 200 OK
Length: 852260167 (813M) [application/octet-stream]
Saving to: ‘ares_1024.bin’


2022-03-24 20:38:23 (76.7 MB/s) - ‘ares_1024.bin’ saved [852260167/852260167]



## Map ARASAAC to WordNET 3.0 word-senses

This process can take several minutes. You can download the already processed file from http://jayr.clubedosgeeks.com.br/pictobert/arasaac_mappings.csv

In [1]:
import nltk
nltk.download("wordnet")

[nltk_data] Downloading package wordnet to /root/nltk_data...
[nltk_data]   Package wordnet is already up-to-date!


True

In [None]:
import requests, json
from tqdm import tqdm
from nltk.corpus import wordnet as wn
import pandas as pd

def wordnet_map(wn_id):
    try:
        url = "http://wordnet-rdf.princeton.edu/json/id/"+wn_id
        r = requests.get(url)
        response_json = r.json()
        if len(response_json[0]['old_keys']) > 0:
            return response_json[0]['old_keys']['pwn30'][0]
        else:
            return None
    except:
        return None

pictograms = json.load(open('ARASAAC_All_pictograms.json'))

pictograms_dic = []

for i,pictogram in enumerate(tqdm(pictograms)):
    pictogram_id = pictogram['_id']
    if "personal pronoun" in pictogram['categories']:
        # download_pictogram(pictogram_id)
        for keyword in pictogram['keywords']:
            pictograms_dic.append({
                "word": keyword['keyword'],
                "pictogram_id": pictogram['_id'],
                "synset": keyword['keyword']
            })
    for s in pictogram['synsets']:
        synset = wordnet_map(s)
        if synset is not None:
          wn_ss = wn.of2ss(synset)
          lemma = wn_ss.lemmas()[0].key()
        else:
          lemma = None
        for keyword in pictogram['keywords']:
            if lemma is None:
              lemma = keyword['keyword']
            pictograms_dic.append({
                "word": keyword['keyword'],
                "pictogram_id": pictogram['_id'],
                "synset": synset,
                "word_senses": lemma
            })
df = pd.DataFrame(pictograms_dic)
df.to_csv("arasaac_mappings.csv")

 21%|██▏       | 2443/11406 [04:04<10:55, 13.67it/s]

## Adapt SemCHILDES to fit ARASAAC vocabulary

In [2]:
import pandas as pd

pic_map = pd.read_csv("./arasaac_mapping.csv")
vocab = ['_'.join(w.split(" ")) for w in list(pic_map['word_senses'])]
pic_map.head()

Unnamed: 0.2,Unnamed: 0,Unnamed: 0.1,word,pictogram_id,synset,word_senses,Unnamed: 6
0,0,0,pavement,2247.0,04215402-n,pavement%1:06:01::,
1,1,1,sidewalk,2247.0,04215402-n,sidewalk%1:06:00::,
2,4,4,carpet,2249.0,04118021-n,carpet%1:06:00::,
3,5,5,rug,2249.0,04118021-n,rug%1:06:00::,
4,6,6,pillow,2250.0,03938244-n,pillow%1:06:00::,


In [3]:
from tqdm import tqdm
f = open("./all_mt_2.txt",'r')
f2 = open("./newcorpus.txt",'w')
new_sentences = []
for l in tqdm(f.readlines()):
  write = True
  for token in l.rstrip().split(" "):
    if token not in vocab and token not in ['.',',',';','?']:
      write = False
  if write and len(l.rstrip().split(" ")) > 3:
    new_sentences.append(l)
    f2.write(l)
  
len(new_sentences)

  6%|▌         | 59262/955489 [00:44<11:17, 1322.26it/s]


KeyboardInterrupt: ignored

In [4]:
len(vocab)

8259

In [5]:
sentences = [s.rstrip() for s in new_sentences]
len(sentences)

3880

## Create Tokenizer

In [6]:
from tokenizers import Tokenizer
from tokenizers.models import WordLevel
from tokenizers.pre_tokenizers import WhitespaceSplit
from tokenizers.processors import BertProcessing

sense_tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"
  ))
sense_tokenizer.add_special_tokens(["[SEP]", "[CLS]", "[PAD]", "[MASK]","[UNK]"])
sense_tokenizer.pre_tokenizer = WhitespaceSplit()

sep_token = "[SEP]"
cls_token = "[CLS]"
pad_token = "[PAD]"
unk_token = "[UNK]"
sep_token_id = sense_tokenizer.token_to_id(str(sep_token))
cls_token_id = sense_tokenizer.token_to_id(str(cls_token))
pad_token_id = sense_tokenizer.token_to_id(str(pad_token))
unk_token_id = sense_tokenizer.token_to_id(str(unk_token))


sense_tokenizer.post_processor = BertProcessing(
                (str(sep_token), sep_token_id), (str(cls_token), cls_token_id)
            )

In [7]:
from tokenizers.trainers import WordLevelTrainer
g = WordLevelTrainer(special_tokens=["[UNK]"])
sense_tokenizer.train_from_iterator(sentences+vocab, trainer=g)
print("Vocab size: ", sense_tokenizer.get_vocab_size())

Vocab size:  8266


In [8]:
tokenizer_vocab = [w for w,i in sense_tokenizer.get_vocab().items()]
difference = list(set(vocab).difference(set(tokenizer_vocab)))
difference

[]

In [9]:
sense_tokenizer.save("./tokenizer_arasaac.json")

## Dataset preparation

In [10]:
TEST_SIZE = 0.2
from sklearn.model_selection import train_test_split
train_idx, val_idx = train_test_split(list(range(len(sentences))), test_size=TEST_SIZE, random_state=8)
test_idx, val_idx = train_test_split(val_idx, test_size=0.5, random_state=8)

import numpy as np
train_examples = np.array(sentences).take(train_idx)
val_examples = np.array(sentences).take(val_idx)
test_examples = np.array(sentences).take(test_idx)
len(train_examples),len(val_examples), len(test_examples)

(3104, 388, 388)

In [11]:
from transformers import PreTrainedTokenizerFast
TOKENIZER_PATH = "./tokenizer_arasaac.json"
loaded_tokenizer = PreTrainedTokenizerFast(tokenizer_file=TOKENIZER_PATH)
loaded_tokenizer.pad_token = "[PAD]"
loaded_tokenizer.sep_token = "[SEP]"
loaded_tokenizer.mask_token = "[MASK]"
loaded_tokenizer.cls_token = "[CLS]"
loaded_tokenizer.unk_token = "[UNK]"

In [12]:
max_len = 16

def tokenize_function(tokenizer,examples):
    # Remove empty lines
    examples = [line for line in examples if len(line) > 0 and not line.isspace()]
    bert = tokenizer(
        examples,
        padding="max_length",
        max_length=max_len,
        return_special_tokens_mask=True,
        truncation=True
    )
    return bert

In [13]:
train_tokenized_examples = tokenize_function(loaded_tokenizer,train_examples)
val_tokenized_examples = tokenize_function(loaded_tokenizer,val_examples)
test_tokenized_examples = tokenize_function(loaded_tokenizer,test_examples)

In [14]:
from torch import tensor
def make_dict(examples):
  return {
      "input_ids": examples.input_ids,
      "attention_mask":examples.attention_mask,
      "special_tokens_mask":examples.special_tokens_mask
  }

In [15]:
!mkdir data

In [16]:
import pickle

TRAIN_DATA_PATH = "./data/train_data.pt"
TEST_DATA_PATH = "./data/test_data.pt"
VAL_DATA_PATH = "./data/val_data.pt"

pickle.dump(make_dict(train_tokenized_examples),open(TRAIN_DATA_PATH,'wb'))
pickle.dump(make_dict(val_tokenized_examples),open(TEST_DATA_PATH,'wb'))
pickle.dump(make_dict(test_tokenized_examples),open(VAL_DATA_PATH ,'wb'))

## PictoBERT adaptation

### Load PictoBERT

In [17]:
from transformers import BertForMaskedLM
pictobert_version = "contextual" #@param ["contextual","gloss"]

if pictobert_version == "contextual":
  pictobert = BertForMaskedLM.from_pretrained("./pictobert")
else:
  pictobert = BertForMaskedLM.from_pretrained("./pictobert-gloss")

### Load BERT

At this point, we need to also load BERT as its embeddings matrix may be used to calculate the input embeddings of tokens that are not on PictoBERT or ARES vocabulary.

In [18]:
import torch
from transformers import BertTokenizer

pretrained_w = 'bert-large-uncased'
tokenizer_bert = BertTokenizer.from_pretrained(pretrained_w)
bert = BertForMaskedLM.from_pretrained(pretrained_w)

Downloading:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/571 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.25G [00:00<?, ?B/s]

Some weights of the model checkpoint at bert-large-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


### Load PictoBERT tokenizer

In [19]:
TOKENIZER_PATH = "./childes_all_new.json" # you can change this path to use your custom tokenizer

from transformers import PreTrainedTokenizerFast

pictobert_tokenizer = PreTrainedTokenizerFast(tokenizer_file=TOKENIZER_PATH)
pictobert_tokenizer.pad_token = "[PAD]"
pictobert_tokenizer.sep_token = "[SEP]"
pictobert_tokenizer.mask_token = "[MASK]"
pictobert_tokenizer.cls_token = "[CLS]"
pictobert_tokenizer.unk_token = "[UNK]"

### Update embeddings layer

In [20]:
new_vocab = loaded_tokenizer.get_vocab()
pictobert_vocab = pictobert_tokenizer.get_vocab()

In [21]:
in_pictobert = []
not_in = []
for w,idx_new in new_vocab.items():
  idx_old = pictobert_vocab.get(w, -1)
  if idx_old >= 0:
    in_pictobert.append(w)
  else:
    not_in.append(w)
  
print("New vocab size:",len(new_vocab))
print("PictoBERT vocab size:", len(pictobert_vocab))
print("Commom:",len(in_pictobert))
print("New tokens:",len(not_in))

New vocab size: 8265
PictoBERT vocab size: 13583
Commom: 3513
New tokens: 4752


In [22]:
bert_embeddings = bert.get_input_embeddings()
pictobert_embeddings = pictobert.get_input_embeddings()

In [23]:
import torch
from gensim.models import KeyedVectors

if pictobert_version == "contextual":
  ares = KeyedVectors.load_word2vec_format("/content/ares_1024.bin", binary=True)
else:
  ares = KeyedVectors.load_word2vec_format("/content/ares_1024_gloss.bin", binary=True)
ares_contextual_mean = torch.tensor(ares.vectors).mean(0)

In [24]:
bert_vocab = tokenizer_bert.get_vocab()
# new_vocab = cs_tokenizer.get_vocab()

bert_wgts = bert.get_input_embeddings().weight.clone().detach()
pictobert_wgts = pictobert.get_input_embeddings().weight.clone().detach()

new_vocab_size = len(loaded_tokenizer.vocab)

In [25]:
new_wgts = pictobert_wgts.new_zeros(new_vocab_size,pictobert_wgts.size(1))

same_tokens_list = list()
different_tokens_list = list()
received_mean = list()

for w,idx_new in new_vocab.items():
  idx_pictobert = pictobert_vocab.get(w,-1)
  if idx_pictobert >= 0:
    new_wgts[idx_new] = pictobert_wgts[idx_pictobert]
  elif w in ares:
    new_wgts[idx_new] = torch.tensor(ares[w])
  else:
    w_ = " ".join(w.split("_"))
    tokenized = tokenizer_bert(w_,add_special_tokens=False,return_tensors='pt')
    v_ = bert_embeddings(tokenized.input_ids[0]).mean(0)
    new_wgts[idx_new] = v_

In [26]:
from torch import nn

new_wte = nn.Embedding(new_vocab_size,pictobert_wgts.size(1)) # new embeddings
new_wte.weight.data.normal_(mean=0,std=pictobert.config.initializer_range) 
new_wte.weight.data = new_wgts

pictobert.resize_token_embeddings(len(loaded_tokenizer))
pictobert.set_input_embeddings(new_wte)

In [27]:
MODEL_NAME = "./pictobert-ARASAAC-{0}".format(pictobert_version)
pictobert.save_pretrained(MODEL_NAME)

## Train Model

### Define constants

In [28]:
TOKENIZER_PATH = "./cs_tokenizer.json"
LOGS_PATH = "./logs"
CHECKPOINTS_PATH = "./checkpoints"

TRAIN_DATASET_PATH = "./data/train_data.pt"
VAL_DATASET_PATH = "./data/val_data.pt"
TEST_DATASET_PATH = "./data/test_data.pt"

MAX_EPOCHS = 10
WARMUP_STEPS = int(MAX_EPOCHS * 0.15)
BATCH_SIZE = 32
NUM_WORKERS = 4
GPUS = 1
LEARNING_RATE = 1e-06
ACCUMULATE_GRAD_BATCHES = 4
LOGGER_VERSION = '1e06'
LOGGER_INFO = "first_run"
FREEZE_TO = None
MLM_PROBABILITY= 0.15

### Load Data

In [29]:
from torch.utils.data import Dataset, Subset
from torch import tensor
from sklearn.model_selection import train_test_split
import pickle

class MyDataset(Dataset):
  def __init__(self, examples):
    
    self.input_ids = examples['input_ids']
    self.attention_mask = examples['attention_mask']
    self.special_tokens_mask = examples['special_tokens_mask']
    self.labels = None
    if 'labels' in examples:
      self.labels = examples['labels']
  
  def __len__(self):
    return len(self.input_ids)
  
  def __getitem__(self, idx):
    input_ids = tensor(self.input_ids[idx])
    attention_mask = tensor(self.attention_mask[idx])
    special_tokens_mask = tensor(self.special_tokens_mask[idx])

    out_dict = {
      "input_ids":input_ids, 
      "attention_mask":attention_mask, 
      "special_tokens_mask":special_tokens_mask
    }

    if self.labels is not None:
      out_dict['labels'] = self.labels[idx]

    return out_dict


tds = pickle.load(open(TRAIN_DATASET_PATH,'rb'))
train_dataset = MyDataset(tds)

vds = pickle.load(open(VAL_DATASET_PATH,'rb'))
val_dataset = MyDataset(vds)

tsds = pickle.load(open(TEST_DATASET_PATH,'rb'))
test_dataset = MyDataset(tsds)


In [30]:
from transformers import DataCollatorForLanguageModeling
data_collator = DataCollatorForLanguageModeling(tokenizer=loaded_tokenizer, mlm_probability=MLM_PROBABILITY)

In [31]:
def top_n_data_collator(examples):
  
  batch = {
      "input_ids" : torch.stack([example['input_ids'] for example in examples]),
      "attention_mask": torch.stack([example['attention_mask'] for example in examples]),
  }
  special_tokens_mask = torch.stack([example['special_tokens_mask'] for example in examples])
  masked_indices = special_tokens_mask.bool()

  mask_labels = special_tokens_mask.clone()
  probability_matrix = mask_labels
  probability_matrix.masked_fill_(torch.tensor(special_tokens_mask, dtype=torch.bool), value=0.0)
  masked_indices = probability_matrix.bool()

  labels = batch['input_ids'].clone()
  
  for i, row in enumerate(special_tokens_mask.bool()):
    mask_id = torch.where(row==False)[0][-2]
    masked_indices[i][mask_id] = True
    batch['input_ids'][i][mask_id] = loaded_tokenizer.mask_token_id


  labels[~masked_indices] =  -100
  
  batch['labels'] = labels

  return batch

In [32]:
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    collate_fn=data_collator,
    drop_last = True,
    shuffle=True,
    
)

val_dataloader = DataLoader(
    val_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=True,
    collate_fn=data_collator,
    drop_last = True
)

test_dataloader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    collate_fn=top_n_data_collator,
    pin_memory=True,
    drop_last = True
)

  cpuset_checked))


### Model

In [33]:
from argparse import ArgumentParser
import math
import torch
import torch.nn as nn
from transformers import get_polynomial_decay_schedule_with_warmup
from transformers import AdamW
import pytorch_lightning as pl
from sklearn.metrics import accuracy_score
from pytorch_lightning.callbacks import ModelCheckpoint
from scipy import stats
import numpy as np
import torch.nn.functional as F

percentiles = [
  {
    "percentil":'1%',
    "z": stats.norm.ppf(1-0.01),
  },
  {
    "percentil":'5%',
    "z": stats.norm.ppf(1-0.05),
  },
  {
    "percentil":'10%',
    "z": stats.norm.ppf(1-0.1),
  },
  {
    "percentil":'15%',
    "z": stats.norm.ppf(1-0.15),
  },
  {
    "percentil":'20%',
    "z": stats.norm.ppf(1-0.2),
  },
]

firsts = [1,9,18,	25,	36]

from transformers import BertLMHeadModel
from transformers import BertForMaskedLM

class LitBertClassifier(pl.LightningModule):
    def __init__(self, pretrained_model_name='bert-large-uncased'):
        super().__init__()
        self.save_hyperparameters()
        self.batch_size = 128
        self.lr = LEARNING_RATE
        self.train_dataset = train_dataset
        self.bert = BertForMaskedLM.from_pretrained(pretrained_model_name)
      
    
    def freeze_to(self, layers):
      for param in self.bert.bert.encoder.layer[:layers].parameters():
        param.requires_grad = False


    def forward(self, input_ids, attention_mask, labels=None):
        if labels == None:
            return self.bert(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )    
        return self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels = labels
        )

    def training_step(self, batch, batch_idx):
        outputs = self._shared_step(batch, batch_idx)
        loss = outputs[0]

        self.log("train_loss", loss, on_epoch=True, prog_bar=True,)

        return loss


    def train_dataloader(self):
      # return DataLoader(self.train_dataset,batch_size=self.batch_size,num_workers=NUM_WORKERS,pin_memory=True,collate_fn=data_collator)
      train_dataloader = DataLoader(
          train_dataset,
          batch_size=self.batch_size,
          num_workers=NUM_WORKERS,
          pin_memory=True,
          collate_fn=data_collator,
          drop_last = True,
          shuffle=True
      )
      return train_dataloader
      # return DataLoader(self.train_dataset, batch_size=self.batch_size)

    def validation_step(self, batch, batch_idx):
        with torch.no_grad():
          result = self._shared_step(batch, batch_idx)
          loss = result[0].detach()
          
          return {
              "val_loss":loss
          }
    
    def validation_epoch_end(self, outputs):
        val_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
        self.log("val_loss", val_loss, on_epoch=True, prog_bar=True,)
    
    
    def test_step(self, batch, batch_idx):
        with torch.no_grad():
          result = self._shared_step(batch, batch_idx)
          loss = result[0].detach()
          perplexity = torch.exp(loss)
          self.log("test_ppl", perplexity, on_epoch=True, prog_bar=True,)
          self.log("test_loss", loss, on_epoch=True, prog_bar=True,)

          predictions = F.softmax(result[1], dim=-1)
          labels = batch['labels']
          masked = batch['input_ids']
          n = masked.detach().cpu().numpy()
          predicted = predictions.detach().cpu().numpy()[n == loaded_tokenizer.mask_token_id]

          ids = np.argsort(-1*predicted,axis=1)
          for first in firsts:
            count = 0
            for i, data in enumerate(ids):
              if labels[masked == loaded_tokenizer.mask_token_id][i].item() in data[:first]:
                count += 1
            isin = count/predicted.shape[0]

            self.log("test_"+str(first), isin, on_epoch=True, prog_bar=True,)
          
          return {
              "test_ppl":perplexity,
              "test_loss":loss,
          }


    def _shared_step(self, batch, batch_idx):
        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        labels = batch["labels"]

        outputs = self.forward(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )

        return outputs

    def configure_optimizers(self):
      optimizer = AdamW(self.parameters(), lr=self.lr,betas=(0.9, 0.999), eps=1e-08, weight_decay=0.01)
      scheduler = {
          'scheduler': get_polynomial_decay_schedule_with_warmup(optimizer, num_warmup_steps=WARMUP_STEPS,num_training_steps=MAX_EPOCHS,lr_end=1e-09),
          'name': 'lr'
      }
      return [optimizer],[scheduler]
    
    def backward(self, loss, optimizer, idx):
        loss.backward()
    




### Logger and Checkpointing

In [34]:
from pytorch_lightning import loggers as pl_loggers
from pytorch_lightning.callbacks import LearningRateMonitor

tb_logger = pl_loggers.TensorBoardLogger(LOGS_PATH,name=LOGGER_INFO, version=LOGGER_VERSION)
lr_monitor = LearningRateMonitor(logging_interval='epoch')

In [35]:
checkpoint_callback = ModelCheckpoint(
    dirpath=CHECKPOINTS_PATH,
    filename='bert-large-{epoch:02d}-{val_loss:.2f}',
    mode='min',
)

### Trainer

In [36]:
trainer = pl.Trainer(
    accelerator='ddp',
    max_epochs=MAX_EPOCHS,
    logger=tb_logger,
    gpus=GPUS,
    callbacks=[checkpoint_callback, lr_monitor],
    precision=16,
    auto_scale_batch_size="binsearch"
)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
Using native 16bit precision.


In [37]:
to_train = LitBertClassifier(MODEL_NAME)

In [38]:
trainer.tune(to_train)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
initializing ddp: GLOBAL_RANK: 0, MEMBER: 1/1
  cpuset_checked))
Batch size 2 succeeded, trying batch size 4
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Batch size 4 succeeded, trying batch size 8
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Batch size 8 succeeded, trying batch size 16
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Batch size 16 succeeded, trying batch size 32
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Batch size 32 succeeded, trying batch size 64
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Batch size 64 succeeded, trying batch size 128
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Batch size 128 succeeded, trying batch size 256
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Batch size 256 failed, trying batch size 192
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Batch size 192 succeeded, trying batch size 224
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Batch size 224 failed, trying batch size 208
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Batch size 208 

In [None]:
trainer.fit(to_train, train_dataloader,val_dataloader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name | Type            | Params
-----------------------------------------
0 | bert | BertForMaskedLM | 312 M 
-----------------------------------------
312 M     Trainable params
0         Non-trainable params
312 M     Total params
1,249.444 Total estimated model params size (MB)
  cpuset_checked))


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]