In [1]:
import os
from pathlib import Path
import numpy as np
import pandas as pd
import sys

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.nn.utils.rnn import pad_sequence

import time

DATASET_PATH = Path("./data/text")

VOC_SIZE = 1000

def load_data(datapath, max_size=None):
    texts_files = list(datapath.glob("*.txt"))
    texts = []  
    for files in texts_files:
        with open(files, "r", encoding='utf8') as files:
            text = files.readlines()
            texts += text
    texts = list(set(texts))
    
    return texts

texts = load_data(DATASET_PATH)

from tokenizers import Tokenizer
# from transformers import AutoTokenizer
from transformers import BertTokenizer, BertForMaskedLM
from tokenizers.models import WordPiece
from tokenizers.trainers import WordPieceTrainer
from tokenizers.pre_tokenizers import Whitespace


tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# model_checkpoint = "distilgpt2"
# tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)
tokenizer.pre_tokenizer = Whitespace()
model = BertForMaskedLM.from_pretrained('bert-base-uncased')

inputs = tokenizer(texts, return_tensors='pt', max_length=100
                   , truncation=True, padding='max_length')

inputs['labels'] = inputs.input_ids.detach().clone()

rand = torch.rand(inputs.input_ids.shape)

BertForMaskedLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.
  - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).
  - If you are not the owner of the model architecture class, please contact the model code owner to update it.
Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['bert.pooler.dense.bias', 'bert.pooler.dense.weight', 'cls.seq_relationship.bias', 'cls.seq_relationship.weight']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another archite

In [2]:
mask_arr = (rand < 0.15) * (inputs.input_ids != 101) * (inputs.input_ids != 102) * (inputs.input_ids != 0)

In [3]:
import random as rd

In [4]:
selection = []

for i in range(mask_arr.shape[0]):
    selection.append(
        torch.flatten(mask_arr[i].nonzero()).tolist()
        )

selection[:5]

for i in range(mask_arr.shape[0]):
    inputs.input_ids[i, selection[i]] = 103 # application du token [MASK]

class CustomDataset(torch.utils.data.Dataset):
    def __init__(self, encodings, idx):
        self.encodings = encodings
        self.idx = idx
        self.encodings = {key: [val[i] for i in self.idx] for key, val in self.encodings.items()}
        
    def __getitem__(self, idx):
        return {key : torch.tensor(val[idx]) for key, val in self.encodings.items()}
    
    def __len__(self):
        return len(self.encodings['input_ids'])

sample_idx = [i for i in range(len(inputs.input_ids))]

shuffled_sample_idx = rd.sample(sample_idx, len(sample_idx))

train_idx = shuffled_sample_idx[:int(0.70*len(shuffled_sample_idx))]
val_idx = shuffled_sample_idx[int(0.70*len(shuffled_sample_idx)):int(0.85*len(shuffled_sample_idx))]
test_idx = shuffled_sample_idx[int(0.85*len(shuffled_sample_idx)):]
                                
dataset_train = CustomDataset(inputs, train_idx)
dataset_val = CustomDataset(inputs, val_idx)
dataset_test = CustomDataset(inputs, test_idx)

train_dataloaded = torch.utils.data.DataLoader(dataset_train, batch_size=16, shuffle=True)
val_dataloaded = torch.utils.data.DataLoader(dataset_val, batch_size=16, shuffle=True)
test_dataloaded = torch.utils.data.DataLoader(dataset_test, batch_size=16, shuffle=True)

In [5]:
#class MLM_model(nn.Module):
#    def __init__(self, model):
#        super(MLM_model, self).__init__()
#        self.history = {"epochs":[], "test":[]}
#        self.model = model
    
#    def parameters(self):
#        return self.model.parameters()

#    def forward(self, x, attention_mask, labels):
#        return self.model(x, attention_mask, labels)
    
#    def train_log(self, train_batch_losses, val_batch_losses, train_loss, validation_loss):
#        self.history["epochs"].append({"train_batch_losses":train_batch_losses, 
#                                "val_batch_losses":val_batch_losses, 
#                                "train_loss":train_loss, 
#                                "validation_loss":validation_loss})
    
#    def test_log(self, test_batch_losses, test_loss):
#        self.history["test"].append({"test_batch_losses":test_batch_losses,
#                                "test_loss":test_loss})

In [6]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
#model = MLM_model(model)
model.to(device)
print(device)

cuda


In [7]:
def train_step(module, batch, batch_idx, optimizer):
    module.train(True)
    
    inputs_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    labels = batch['labels'].to(device)
    
    outputs = module(inputs_ids, attention_mask, labels=labels)
    
    loss = outputs.loss
    print(f"\n\033[1;37mBatch loss {batch_idx+1} : {loss.item()}")
    loss.backward()
    
    torch.nn.utils.clip_grad_norm_(module.parameters(), max_norm=1.0)
    optimizer.step()
    optimizer.zero_grad()
    
    return module, loss

def eval_step(module, batch, batch_idx, optimizer=None, training=True):
    with torch.no_grad():
        
        inputs_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['labels'].to(device)
    
        outputs = module(inputs_ids, attention_mask, labels=labels)
    
        loss = outputs.loss
         
        if training:
            print(f"\n\033[1;32mValidation Batch loss {batch_idx+1} : {loss.item()}")
            return module, loss
        else:
            print(f"\n\033[1;32mTest Batch loss {batch_idx+1} : {loss.item()}")
            return module, loss, outputs, labels

def train_loop(module, EPOCHS, train_dataset, val_dataset, optimizer, lr_scheduler=None):
    for epoch in range(EPOCHS):
        deb=time.time()
        
        module.train(True)
        
        train_batch_losses = []
        for batch_idx in range(len(train_dataset)):
            batch = next(iter(train_dataset))
            module, loss = train_step(module, batch, batch_idx, optimizer)
            train_batch_losses.append(loss.item())
            
        if lr_scheduler is not None:
          lr_scheduler.step()
        train_loss = np.mean(train_batch_losses)

        module.train(False)
        val_batch_losses = []
        for batch_idx in range(len(val_dataset)):
            batch = next(iter(val_dataset))
            module, loss = eval_step(module, batch, batch_idx)
            val_batch_losses.append(loss.item())
        val_loss = np.mean(val_batch_losses)

#        module.train_log(train_batch_losses, val_batch_losses, train_loss, val_loss)
        print(f"\n\033[1;33mEpoch {epoch+1} :\n\033[1;37mTraining Loss : {train_loss}")
        print(f"\033[1;32mValidation Loss : {val_loss}")
        print(f"\033[1;31mDurée epoch : {time.time()-deb} secondes")
    return module

def evaluate(module, test_dataset):
    module.train(False)
    test_batch_losses = []
    predictions = []
    true_targets = []
    for batch_idx in range(len(test_dataset)):
        batch = next(iter(test_dataset))
        module, loss, outputs, labels = eval_step(module, batch, batch_idx, training=False)

        test_batch_losses.append(loss.item())
        predictions.append(outputs)
        true_targets.append(labels)

    test_loss = np.mean(test_batch_losses)
#    module.test_log(test_batch_losses, test_loss)
    print(f"\nTest Loss : {test_loss}")
    return predictions, true_targets

In [8]:
if __name__ == "__main__":
    EPOCHS = 1
    LR = 1e-4
    
    optimizer = torch.optim.Adam(model.parameters(), lr=LR, eps=5e-8)
    module = train_loop(module=model,
                        EPOCHS=EPOCHS, 
                        train_dataset=train_dataloaded, 
                        val_dataset=val_dataloaded,
                        optimizer=optimizer)
    predictions, true_targets = evaluate(module, 
                                         test_dataloaded)



  return {key : torch.tensor(val[idx]) for key, val in self.encodings.items()}



[1;37mBatch loss 1 : 11.209537506103516

[1;37mBatch loss 2 : 7.477960109710693

[1;37mBatch loss 3 : 5.865095615386963

[1;37mBatch loss 4 : 5.1089959144592285

[1;37mBatch loss 5 : 4.225607395172119

[1;37mBatch loss 6 : 3.7812817096710205

[1;37mBatch loss 7 : 3.0029916763305664

[1;37mBatch loss 8 : 2.6055808067321777

[1;37mBatch loss 9 : 2.048874855041504

[1;37mBatch loss 10 : 1.6938650608062744

[1;37mBatch loss 11 : 1.3665978908538818

[1;37mBatch loss 12 : 1.118104100227356

[1;37mBatch loss 13 : 0.947327196598053

[1;37mBatch loss 14 : 0.797278881072998

[1;37mBatch loss 15 : 0.6209564805030823

[1;37mBatch loss 16 : 0.47918128967285156

[1;37mBatch loss 17 : 0.4530010223388672

[1;37mBatch loss 18 : 0.38324710726737976

[1;37mBatch loss 19 : 0.3482261598110199

[1;37mBatch loss 20 : 0.39186739921569824

[1;37mBatch loss 21 : 0.3628880977630615

[1;37mBatch loss 22 : 0.16519016027450562

[1;37mBatch loss 23 : 0.2534430921077728

[1;37mBatch loss 24 :


[1;37mBatch loss 187 : 0.12530499696731567

[1;37mBatch loss 188 : 0.15366123616695404

[1;37mBatch loss 189 : 0.0313909687101841

[1;37mBatch loss 190 : 0.07675524801015854

[1;37mBatch loss 191 : 0.050781283527612686

[1;37mBatch loss 192 : 0.03863641992211342

[1;37mBatch loss 193 : 0.014467110857367516

[1;37mBatch loss 194 : 0.10684517025947571

[1;37mBatch loss 195 : 0.022875027731060982

[1;37mBatch loss 196 : 0.06312664598226547

[1;37mBatch loss 197 : 0.07000399380922318

[1;37mBatch loss 198 : 0.050838932394981384

[1;37mBatch loss 199 : 0.03009534440934658

[1;37mBatch loss 200 : 0.06679247319698334

[1;37mBatch loss 201 : 0.09468068927526474

[1;37mBatch loss 202 : 0.08365478366613388

[1;37mBatch loss 203 : 0.082374706864357

[1;37mBatch loss 204 : 0.10345688462257385

[1;37mBatch loss 205 : 0.0798397958278656

[1;37mBatch loss 206 : 0.06753335148096085

[1;37mBatch loss 207 : 0.08342394977807999

[1;37mBatch loss 208 : 0.04932736977934837

[1;37mBa


[1;37mBatch loss 369 : 0.05187772214412689

[1;37mBatch loss 370 : 0.05368535965681076

[1;37mBatch loss 371 : 0.06365494430065155

[1;37mBatch loss 372 : 0.07361477613449097

[1;37mBatch loss 373 : 0.07380446791648865

[1;37mBatch loss 374 : 0.03696180880069733

[1;37mBatch loss 375 : 0.06576213985681534

[1;37mBatch loss 376 : 0.08993874490261078

[1;37mBatch loss 377 : 0.03941306471824646

[1;37mBatch loss 378 : 0.0406225211918354

[1;37mBatch loss 379 : 0.03402959927916527

[1;37mBatch loss 380 : 0.03676323592662811

[1;37mBatch loss 381 : 0.04405699670314789

[1;37mBatch loss 382 : 0.03346351161599159

[1;37mBatch loss 383 : 0.029890011996030807

[1;37mBatch loss 384 : 0.04372113570570946

[1;37mBatch loss 385 : 0.030780578032135963

[1;37mBatch loss 386 : 0.07434301823377609

[1;37mBatch loss 387 : 0.034444235265254974

[1;37mBatch loss 388 : 0.05857133865356445

[1;37mBatch loss 389 : 0.07760950922966003

[1;37mBatch loss 390 : 0.039407290518283844

[1;37


[1;37mBatch loss 550 : 0.021675825119018555

[1;37mBatch loss 551 : 0.066909559071064

[1;37mBatch loss 552 : 0.062127407640218735

[1;37mBatch loss 553 : 0.028630787506699562

[1;37mBatch loss 554 : 0.01588056981563568

[1;37mBatch loss 555 : 0.030163416638970375

[1;37mBatch loss 556 : 0.02417643368244171

[1;37mBatch loss 557 : 0.01817144826054573

[1;37mBatch loss 558 : 0.08265826851129532

[1;37mBatch loss 559 : 0.04748556762933731

[1;37mBatch loss 560 : 0.05598694831132889

[1;37mBatch loss 561 : 0.025976872071623802

[1;37mBatch loss 562 : 0.01728919893503189

[1;37mBatch loss 563 : 0.05843299254775047

[1;37mBatch loss 564 : 0.02100163884460926

[1;37mBatch loss 565 : 0.024281740188598633

[1;37mBatch loss 566 : 0.03241710364818573

[1;37mBatch loss 567 : 0.03737618774175644

[1;37mBatch loss 568 : 0.04886927083134651

[1;37mBatch loss 569 : 0.03541106358170509

[1;37mBatch loss 570 : 0.04875022545456886

[1;37mBatch loss 571 : 0.025541499257087708

[1;


[1;37mBatch loss 731 : 0.055251315236091614

[1;37mBatch loss 732 : 0.018841160461306572

[1;37mBatch loss 733 : 0.02429475262761116

[1;37mBatch loss 734 : 0.07487328350543976

[1;37mBatch loss 735 : 0.008272629231214523

[1;37mBatch loss 736 : 0.01778574287891388

[1;37mBatch loss 737 : 0.07151919603347778

[1;37mBatch loss 738 : 0.03248186782002449

[1;37mBatch loss 739 : 0.019238600507378578

[1;37mBatch loss 740 : 0.06709632277488708

[1;37mBatch loss 741 : 0.02545974776148796

[1;37mBatch loss 742 : 0.01780567690730095

[1;37mBatch loss 743 : 0.02981787919998169

[1;37mBatch loss 744 : 0.0291164368391037

[1;37mBatch loss 745 : 0.018368778750300407

[1;37mBatch loss 746 : 0.04206512123346329

[1;37mBatch loss 747 : 0.031603943556547165

[1;37mBatch loss 748 : 0.04350588843226433

[1;37mBatch loss 749 : 0.04231448099017143

[1;37mBatch loss 750 : 0.016338685527443886

[1;37mBatch loss 751 : 0.05896590277552605

[1;37mBatch loss 752 : 0.020958945155143738

[


[1;32mValidation Batch loss 120 : 0.06307360529899597

[1;32mValidation Batch loss 121 : 0.04680122807621956

[1;32mValidation Batch loss 122 : 0.09780380874872208

[1;32mValidation Batch loss 123 : 0.06536748260259628

[1;32mValidation Batch loss 124 : 0.024821382015943527

[1;32mValidation Batch loss 125 : 0.0932290181517601

[1;32mValidation Batch loss 126 : 0.025788627564907074

[1;32mValidation Batch loss 127 : 0.022515052929520607

[1;32mValidation Batch loss 128 : 0.047151703387498856

[1;32mValidation Batch loss 129 : 0.04077548906207085

[1;32mValidation Batch loss 130 : 0.023479558527469635

[1;32mValidation Batch loss 131 : 0.005070294253528118

[1;32mValidation Batch loss 132 : 0.06516806036233902

[1;32mValidation Batch loss 133 : 0.024918172508478165

[1;32mValidation Batch loss 134 : 0.06111948937177658

[1;32mValidation Batch loss 135 : 0.050059959292411804

[1;32mValidation Batch loss 136 : 0.053953852504491806

[1;32mValidation Batch loss 137 : 0.02


[1;37mBatch loss 125 : 0.015208082273602486

[1;37mBatch loss 126 : 0.04468028247356415

[1;37mBatch loss 127 : 0.017639774829149246

[1;37mBatch loss 128 : 0.02004462666809559

[1;37mBatch loss 129 : 0.04660146310925484

[1;37mBatch loss 130 : 0.04439306631684303

[1;37mBatch loss 131 : 0.04511604458093643

[1;37mBatch loss 132 : 0.0393248125910759

[1;37mBatch loss 133 : 0.03694678843021393

[1;37mBatch loss 134 : 0.017247909680008888

[1;37mBatch loss 135 : 0.036656517535448074

[1;37mBatch loss 136 : 0.006502432748675346

[1;37mBatch loss 137 : 0.043192170560359955

[1;37mBatch loss 138 : 0.02335420809686184

[1;37mBatch loss 139 : 0.04179644212126732

[1;37mBatch loss 140 : 0.030298003926873207

[1;37mBatch loss 141 : 0.06303588300943375

[1;37mBatch loss 142 : 0.01520103681832552

[1;37mBatch loss 143 : 0.012697921134531498

[1;37mBatch loss 144 : 0.02362086810171604

[1;37mBatch loss 145 : 0.013784656301140785

[1;37mBatch loss 146 : 0.030276959761977196




[1;37mBatch loss 306 : 0.024215171113610268

[1;37mBatch loss 307 : 0.017945365980267525

[1;37mBatch loss 308 : 0.02387925423681736

[1;37mBatch loss 309 : 0.0337100587785244

[1;37mBatch loss 310 : 0.03290671855211258

[1;37mBatch loss 311 : 0.03529256209731102

[1;37mBatch loss 312 : 0.02170977182686329

[1;37mBatch loss 313 : 0.002133465139195323

[1;37mBatch loss 314 : 0.024184202775359154

[1;37mBatch loss 315 : 0.022522881627082825

[1;37mBatch loss 316 : 0.011701328679919243

[1;37mBatch loss 317 : 0.02926434576511383

[1;37mBatch loss 318 : 0.05143081769347191

[1;37mBatch loss 319 : 0.01325247809290886

[1;37mBatch loss 320 : 0.013901717960834503

[1;37mBatch loss 321 : 0.01847996935248375

[1;37mBatch loss 322 : 0.01997978240251541

[1;37mBatch loss 323 : 0.014888517558574677

[1;37mBatch loss 324 : 0.017300819978117943

[1;37mBatch loss 325 : 0.06813787668943405

[1;37mBatch loss 326 : 0.03943759575486183

[1;37mBatch loss 327 : 0.01522519439458847




[1;37mBatch loss 487 : 0.013446156866848469

[1;37mBatch loss 488 : 0.017804134637117386

[1;37mBatch loss 489 : 0.026222912594676018

[1;37mBatch loss 490 : 0.003959294408559799

[1;37mBatch loss 491 : 0.006850842386484146

[1;37mBatch loss 492 : 0.01488605048507452

[1;37mBatch loss 493 : 0.03676365688443184

[1;37mBatch loss 494 : 0.0116676464676857

[1;37mBatch loss 495 : 0.022349173203110695

[1;37mBatch loss 496 : 0.004219611641019583

[1;37mBatch loss 497 : 0.03632960468530655

[1;37mBatch loss 498 : 0.02508026920258999

[1;37mBatch loss 499 : 0.043021880090236664

[1;37mBatch loss 500 : 0.005732561461627483

[1;37mBatch loss 501 : 0.043354764580726624

[1;37mBatch loss 502 : 0.00412213196977973

[1;37mBatch loss 503 : 0.00475327717140317

[1;37mBatch loss 504 : 0.0047193728387355804

[1;37mBatch loss 505 : 0.021712999790906906

[1;37mBatch loss 506 : 0.02310110256075859

[1;37mBatch loss 507 : 0.002991974353790283

[1;37mBatch loss 508 : 0.010022648610174


[1;37mBatch loss 667 : 0.012398394756019115

[1;37mBatch loss 668 : 0.010435440577566624

[1;37mBatch loss 669 : 0.02401369996368885

[1;37mBatch loss 670 : 0.0383731834590435

[1;37mBatch loss 671 : 0.04592292383313179

[1;37mBatch loss 672 : 0.01859920471906662

[1;37mBatch loss 673 : 0.0030653325375169516

[1;37mBatch loss 674 : 0.025405151769518852

[1;37mBatch loss 675 : 0.02512255124747753

[1;37mBatch loss 676 : 0.030344026163220406

[1;37mBatch loss 677 : 0.03798877075314522

[1;37mBatch loss 678 : 0.0030785587150603533

[1;37mBatch loss 679 : 0.011666737496852875

[1;37mBatch loss 680 : 0.0186435729265213

[1;37mBatch loss 681 : 0.02460787259042263

[1;37mBatch loss 682 : 0.03706829622387886

[1;37mBatch loss 683 : 0.030917629599571228

[1;37mBatch loss 684 : 0.04480523243546486

[1;37mBatch loss 685 : 0.015188842080533504

[1;37mBatch loss 686 : 0.01159312017261982

[1;37mBatch loss 687 : 0.0016787860076874495

[1;37mBatch loss 688 : 0.01489298976957798


[1;32mValidation Batch loss 67 : 0.03046923689544201

[1;32mValidation Batch loss 68 : 0.028673682361841202

[1;32mValidation Batch loss 69 : 0.019565805792808533

[1;32mValidation Batch loss 70 : 0.03403336554765701

[1;32mValidation Batch loss 71 : 0.035757094621658325

[1;32mValidation Batch loss 72 : 0.015925507992506027

[1;32mValidation Batch loss 73 : 0.06021106615662575

[1;32mValidation Batch loss 74 : 0.04014259949326515

[1;32mValidation Batch loss 75 : 0.05877732113003731

[1;32mValidation Batch loss 76 : 0.029679028317332268

[1;32mValidation Batch loss 77 : 0.018179161474108696

[1;32mValidation Batch loss 78 : 0.0364040769636631

[1;32mValidation Batch loss 79 : 0.06498680263757706

[1;32mValidation Batch loss 80 : 0.06645414978265762

[1;32mValidation Batch loss 81 : 0.06210671737790108

[1;32mValidation Batch loss 82 : 0.03697787597775459

[1;32mValidation Batch loss 83 : 0.030623607337474823

[1;32mValidation Batch loss 84 : 0.0270727276802063

[1;

OutOfMemoryError: CUDA out of memory. Tried to allocate 188.00 MiB. GPU 0 has a total capacity of 4.00 GiB of which 0 bytes is free. Of the allocated memory 10.16 GiB is allocated by PyTorch, and 336.40 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [10]:
print(torch.cuda.memory_stats())

OrderedDict({'active.all.allocated': 2189543, 'active.all.current': 759, 'active.all.freed': 2188784, 'active.all.peak': 1024, 'active.large_pool.allocated': 1020003, 'active.large_pool.current': 277, 'active.large_pool.freed': 1019726, 'active.large_pool.peak': 381, 'active.small_pool.allocated': 1169540, 'active.small_pool.current': 482, 'active.small_pool.freed': 1169058, 'active.small_pool.peak': 721, 'active_bytes.all.allocated': 8415742742016, 'active_bytes.all.current': 10911057408, 'active_bytes.all.freed': 8404831684608, 'active_bytes.all.peak': 10917021184, 'active_bytes.large_pool.allocated': 8239346608128, 'active_bytes.large_pool.current': 10908502528, 'active_bytes.large_pool.freed': 8228438105600, 'active_bytes.large_pool.peak': 10913417728, 'active_bytes.small_pool.allocated': 176396133888, 'active_bytes.small_pool.current': 2554880, 'active_bytes.small_pool.freed': 176393579008, 'active_bytes.small_pool.peak': 13091328, 'allocated_bytes.all.allocated': 8415742742016, '

In [None]:
print(inputs.input_ids.max())
print(inputs.input_ids.min())