In [1]:
import os
from transformers import BertTokenizer
from torch.utils.data import Dataset, DataLoader
import random
import torch

import sentencepiece as spm
import csv
import os
import re
import json
import glob
import numpy as np
import matplotlib.pyplot as plt
import unicodedata

import torch
import random
from sklearn.model_selection import train_test_split

from IPython.display import display

from tqdm.notebook import tqdm

from datetime import datetime

In [2]:
SEED = 1234
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

#device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device = torch.device("cpu")

## Prepare dataset

In [3]:
data_dir = "../data"

all_files_path = [os.path.join(data_dir, file_name) for file_name in os.listdir(data_dir)]

print(all_files_path)

['../data/text_52.txt', '../data/text_5.txt', '../data/text_14.txt', '../data/text_33.txt', '../data/text_47.txt', '../data/text_10.txt', '../data/text_42.txt', '../data/text_24.txt', '../data/text_25.txt', '../data/text_0.txt', '../data/text_67.txt', '../data/text_30.txt', '../data/text_29.txt', '../data/text_74.txt', '../data/text_55.txt', '../data/text_9.txt', '../data/text_70.txt', '../data/text_20.txt', '../data/text_59.txt', '../data/text_17.txt', '../data/text_50.txt', '../data/text_54.txt', '../data/text_2.txt', '../data/text_61.txt', '../data/text_3.txt', '../data/text_77.txt', '../data/text_75.txt', '../data/text_1.txt', '../data/text_72.txt', '../data/text_64.txt', '../data/text_53.txt', '../data/text_73.txt', '../data/text_46.txt', '../data/text_32.txt', '../data/text_58.txt', '../data/text_15.txt', '../data/text_49.txt', '../data/text_28.txt', '../data/text_39.txt', '../data/text_68.txt', '../data/text_78.txt', '../data/text_45.txt', '../data/text_62.txt', '../data/text_27

In [4]:
from nltk.tokenize import sent_tokenize

lines = []
with open(all_files_path[0], "r", encoding="utf-8") as f:
    for line in f:
        lines.append(line)

sent_tokenize(lines[0])

['only a smattering of bodies filled the rows of chairs set up facing the dais , and most of the folks sitting in them only rested there to fan their feet before launching themselves into the fair again .']

In [5]:
models_dir = '../models/tokenizer'
tokenizer = BertTokenizer.from_pretrained(f'{models_dir}/bert_tokenizer-vocab.txt', local_files_only=True)



In [None]:
class BERT_dataset(Dataset):
    def __init__(self, files_path, tokenizer, seq_len):
        self.files_path = files_path
        self.tokenizer = tokenizer
        self.seq_len = seq_len

        self.pairs = []

        for i in range(len(files_path)):
            lines = []
            with open(all_files_path[i], "r", encoding="utf-8") as f:
                count = 0
                for line in f:
                    lines.append(line)
                    count += 1
                    if count == 10000:
                        break
            
            for j in range(len(lines)-1):
                self.pairs.append((lines[j], lines[j+1]))

    def __len__(self):
        return len(self.pairs)
    
    def __getitem__(self, index):
        # Step 1: get a random sentence pair, either negative or positive (saved as is_next_label)
        #         is_next=1 means the second sentence comes after the first one in the conversation.
        s1, s2, is_next = self.get_pair(index)

        # Step 2: replace random words in sentence with mask / random words
        masked_numericalized_s1, s1_mask = self.mask_sentence(s1)
        masked_numericalized_s2, s2_mask = self.mask_sentence(s2)

        # Step 3: Adding CLS and SEP tokens to the start and end of sentences
        # Adding PAD token for labels
        t1 = [self.tokenizer.vocab['[CLS]']] + masked_numericalized_s1 + [self.tokenizer.vocab['[SEP]']]
        t2 = masked_numericalized_s2 + [self.tokenizer.vocab['[SEP]']]
        t1_mask = [self.tokenizer.vocab['[PAD]']] + s1_mask + [self.tokenizer.vocab['[PAD]']]
        t2_mask = s2_mask + [self.tokenizer.vocab['[PAD]']]

        # Step 4: combine sentence 1 and 2 as one input
        # adding PAD tokens to make the sentence same length as seq_len
        segment_ids = ([0 for _ in range(len(t1))] + [1 for _ in range(len(t2))])[:self.seq_len]
        bert_input = (t1 + t2)[:self.seq_len]
        bert_label = (t1_mask + t2_mask)[:self.seq_len]
        padding = [self.tokenizer.vocab['[PAD]'] for _ in range(self.seq_len - len(bert_input))]
        bert_input.extend(padding), bert_label.extend(padding), segment_ids.extend(padding)

        # output = {
        #     "bert_input": bert_input,
        #     "bert_label": bert_label,
        #     "segment_ids": segment_ids,
        #     "is_next": is_next
        # }
        # return {key: torch.tensor(value) for key, value in output.items()}

        return (torch.tensor(bert_input),
                torch.tensor(bert_label),
                torch.tensor(segment_ids),
                torch.tensor(is_next))
        
    
    def get_pair(self, index):
        s1, s2 = self.pairs[index]
        is_next = 1
        if random.random() > 0.5:
            random_index = random.randrange(len(self.pairs))
            s2 = self.pairs[random_index][1]
            is_next = 0
        return s1, s2, is_next
    
    def mask_sentence(self, s):
        words = s.split()
        masked_numericalized_s = []
        mask = []
        for word in words:
            prob = random.random()
            token_ids = self.tokenizer(word)['input_ids'][1:-1]     # remove cls and sep token
            if prob < 0.15:                              # Mask out 15% of the words in the input
                prob /= 0.15
                for token_id in token_ids:                         # Iterate through token ids regardless of masking decision
                    if prob < 0.8:                          # Among 15 %, 80% will be replaced with the token 'Mask'
                        masked_numericalized_s.append(self.tokenizer.vocab['[MASK]'])
                    elif prob < 0.9:                        # Among 15%, 10% will be replaced with a random token
                        masked_numericalized_s.append(random.randrange(len(self.tokenizer.vocab)))
                    else:                                   # Among 15%, 10% will be left unchanged
                        masked_numericalized_s.append(token_id)   # Adding unchanged tokens
                    mask.append(token_id)                          # Mask label added for each token
            else:
                masked_numericalized_s.extend(token_ids)    # Adding tokens directly if not masked
                mask.extend([0] * len(token_ids))           # Corresponding unmasked labels

        assert len(masked_numericalized_s) == len(mask)
        return masked_numericalized_s, mask

In [151]:
test_dataset = BERT_dataset(all_files_path[:1], tokenizer, 100)

result = test_dataset.__getitem__(1)

print("bert_input :", result[0].size(), result[0])
print("bert_label :", result[1])
print("segment_ids :", result[2])
print("is_next :", result[-1])

print(tokenizer.convert_ids_to_tokens(result[0]))
print(tokenizer.convert_ids_to_tokens(result[1]))

NameError: name 'BERT_dataset' is not defined

In [8]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data.dataloader
import torch.distributed as dist
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

import math

import os

In [9]:
for i in range(torch.cuda.device_count()):
    print(f"GPU {i}: {torch.cuda.get_device_name(i)}")

GPU 0: NVIDIA GeForce GTX 970
GPU 1: NVIDIA GeForce GTX 970


In [9]:
seq_len = 175#100

n_enc_vocab = tokenizer.vocab_size
n_dec_vocab = n_enc_vocab
n_output = n_enc_vocab

n_enc_seq = seq_len           # json_encode_length
n_seg_type = 2
n_layers  = 6
hid_dim   = 256 # Taille des embeddings
pf_dim    = 1024
i_pad     = 0
n_heads   = 4#8
d_head    = 64
dropout   = 0.3
layer_norm_epsilon = 1e-12

batch_size = 20
learning_rate = 5e-5
num_epochs = 10

In [11]:
data_dir = "../data"

def get_batch(split, batch_size=batch_size, seq_len=seq_len):
    all_files_path = [os.path.join(data_dir, file_name) for file_name in os.listdir(data_dir)]
    if split == "train":
        data = BERT_dataset(all_files_path[:1], tokenizer, seq_len)
    else:
        data = BERT_dataset(all_files_path[1:2], tokenizer, seq_len)

    random_index = torch.randint(0, len(data) - batch_size, (batch_size,))

    batch = [data[i] for i in random_index]

    bert_inputs, bert_labels, segment_ids, is_next = zip(*batch)

    bert_inputs = torch.stack(bert_inputs)
    bert_labels = torch.stack(bert_labels)
    segment_ids = torch.stack(segment_ids)
    is_next = torch.stack(is_next)

    # TODO : Modification à apport à BERTdataset, pour que ce soit plus performant
    # bert_inputs = torch.stack([torch.tensor(data[i:i + seq_len][0]) for i in random_index])
    # bert_labels = torch.stack([torch.tensor(data[i + 1:i + 1 + seq_len][1]) for i in random_index])

    # if "caca" == 'cuda':
    #     #x, y = x.pin_memory().to(device, non_blocking=True), y.pin_memory().to(device, non_blocking=True)
    #     bert_inputs = bert_inputs.pin_memory().to(device, non_blocking=True)
    #     bert_labels = bert_labels.pin_memory().to(device, non_blocking=True)
    #     segment_ids = segment_ids.pin_memory().to(device, non_blocking=True)
    #     is_next = is_next.pin_memory().to(device, non_blocking=True)
    # else:
    #     bert_inputs = bert_inputs.to(device)
    #     bert_labels = bert_labels.to(device)
    #     segment_ids = segment_ids.to(device)
    #     is_next = is_next.to(device)

    return [bert_inputs, bert_labels, segment_ids, is_next]

## Modelling

In [12]:
class ScaledDotProductAttention(nn.Module):
    def __init__(self):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, query, key, value, mask):

        matmul_qk = torch.matmul(query, torch.transpose(key,2,3))

        dk = key.shape[-1]
        scaled_attention_logits = matmul_qk / math.sqrt(dk)

        if mask is not None:
            scaled_attention_logits += (mask * -1e9)

        attention_weights = F.softmax(scaled_attention_logits, dim=-1)
        
        output = torch.matmul(attention_weights, value)

        return output, attention_weights

In [13]:
class MultiHeadAttentionLayer(nn.Module):
    
    def __init__(self):
        super(MultiHeadAttentionLayer, self).__init__()
        
        # Define dense layers corresponding to WQ, WK, and WV
        self.query = nn.Linear(hid_dim, n_heads * d_head)
        self.key = nn.Linear(hid_dim, n_heads * d_head)
        self.value = nn.Linear(hid_dim, n_heads * d_head)
        self.scaled_dot_attn = ScaledDotProductAttention()
        self.dense = nn.Linear(n_heads * d_head, hid_dim)
        self.dropout = nn.Dropout(dropout)
    
    def forward(self, Q, K, V, attn_mask):
        batch_size = Q.size(0)

        # 1. Pass through the dense layer corresponding to WQ
        # q : (bs, n_heads, n_q_seq, d_head)
        query = self.query(Q).view(batch_size, -1, n_heads, d_head).transpose(1,2)

        # 2. Pass through the dense layer corresponding to WK
        # k : (bs, n_heads, n_k_seq, d_head)
        key   = self.key(K).view(batch_size, -1, n_heads, d_head).transpose(1,2)
        
        # 3. Pass through the dense layer corresponding to WV
        # v : (bs, n_heads, n_v_seq, d_head)
        value = self.value(V).view(batch_size, -1, n_heads, d_head).transpose(1,2)

        # 4. Scaled Dot Product Attention. Using the previously implemented function
        # (bs, n_heads, n_q_seq, n_k_seq)
        attn_mask = attn_mask.unsqueeze(1).repeat(1, n_heads, 1, 1)

        # (bs, n_heads, n_q_seq, d_head), (bs, n_heads, n_q_seq, n_k_seq)
        scaled_attention, attn_prob = self.scaled_dot_attn(query, key, value, attn_mask)
        
        # 5. Concatenate the heads
        # (bs, n_heads, n_q_seq, h_head * d_head)
        concat_attention = scaled_attention.transpose(1, 2).contiguous().view(batch_size, -1, n_heads * d_head)
        
        # 6. Pass through the dense layer corresponding to WO
        # (bs, n_heads, n_q_seq, e_embd)
        outputs = self.dense(concat_attention)
        outputs = self.dropout(outputs)
        # (bs, n_q_seq, hid_dim), (bs, n_heads, n_q_seq, n_k_seq)
        return outputs, attn_prob

In [14]:
class PositionwiseFeedforwardLayer(nn.Module):
    def __init__(self):
        super(PositionwiseFeedforwardLayer, self).__init__()
        self.linear_1 = nn.Linear(hid_dim, pf_dim)
        self.linear_2 = nn.Linear(pf_dim, hid_dim)

    def forward(self, attention):
        output = self.linear_1(attention)
        output = F.relu(output)
        output = self.linear_2(output)
        return output

In [15]:
class EncoderLayer(nn.Module):
    def __init__(self):
        super(EncoderLayer, self).__init__()
        
        self.attention = MultiHeadAttentionLayer()
        self.ffn = PositionwiseFeedforwardLayer()
        
        self.layernorm1 = nn.LayerNorm(hid_dim)
        self.layernorm2 = nn.LayerNorm(hid_dim)
        
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, inputs, padding_mask):
        
        # 1. Encoder mutihead attention is defined
        attention, attn_prob = self.attention(inputs, inputs, inputs, padding_mask)
        attention   = self.dropout1(attention)
        
        # 2. 1 st residual layer
        attention   = self.layernorm1(inputs + attention)  # (batch_size, input_seq_len, hid_dim)
        
        # 3. Feed Forward Network
        ffn_outputs = self.ffn(attention)  # (batch_size, input_seq_len, hid_dim)
        
        ffn_outputs = self.dropout2(ffn_outputs)
        
        # 4. 2 nd residual layer
        ffn_outputs = self.layernorm2(attention + ffn_outputs)  # (batch_size, input_seq_len, hid_dim)

        # 5. Encoder output of each encoder layer
        return ffn_outputs, attn_prob

In [16]:
""" attention pad mask """
def create_padding_mask(seq_q, seq_k, i_pad):
    batch_size, len_q = seq_q.size()
    batch_size, len_k = seq_k.size()
    mask = seq_k.data.eq(i_pad).unsqueeze(1).expand(batch_size, len_q, len_k)  # <pad>
    return mask

In [None]:
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        
        self.word_embeddings = nn.Embedding(n_enc_vocab, hid_dim)
        self.position_embeddings = nn.Embedding(n_enc_seq + 1, hid_dim)
        self.token_type_embeddings = nn.Embedding(n_seg_type, hid_dim) #<------------------

        self.layer = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])
    
    def forward(self, inputs, segments):
        print("Encoder :",inputs.size(), segments.size())
        positions = torch.arange(inputs.size(1), device=inputs.device, dtype=inputs.dtype).expand(inputs.size(0), inputs.size(1)).contiguous() + 1
        pos_mask = inputs.eq(i_pad)
        positions.masked_fill_(pos_mask, 0)

        assert torch.all(inputs < self.word_embeddings.num_embeddings), \
            f"Indices in inputs exceed embedding size: max={inputs.max()}, num_embeddings={self.word_embeddings.num_embeddings}"
        assert torch.all(positions < self.position_embeddings.num_embeddings), \
            f"Indices in positions exceed embedding size: max={positions.max()}, num_embeddings={self.position_embeddings.num_embeddings}"
        assert torch.all(segments < self.token_type_embeddings.num_embeddings), \
            f"Indices in segments exceed embedding size: max={segments.max()}, num_embeddings={self.token_type_embeddings.num_embeddings}"

        # (bs, ENCODER_LEN, hid_dim)
        outputs = self.word_embeddings(inputs) + self.position_embeddings(positions)  + self.token_type_embeddings(segments)

        # (bs, ENCODER_LEN, ENCODER_LEN)
        attn_mask = create_padding_mask(inputs, inputs, i_pad)

        attn_probs = []
        for l in self.layer:
            # (bs, ENCODER_LEN, hid_dim), (bs, n_heads, ENCODER_LEN, ENCODER_LEN)
            outputs, attn_prob = l(outputs, attn_mask)
            attn_probs.append(attn_prob)
        # (bs, ENCODER_LEN, hid_dim), [(bs, n_heads, ENCODER_LEN, ENCODER_LEN)]
        return outputs, attn_probs

In [49]:
class BERT(nn.Module):
    def __init__(self):
        super(BERT, self).__init__()
        
        self.encoder = Encoder()

        self.linear = nn.Linear(hid_dim, hid_dim)
        self.activation = torch.tanh
    
    def forward(self, inputs, segments):
        # (bs, n_seq, hid_dim), [(bs, n_heads, n_enc_seq, n_enc_seq)]
        outputs, self_attn_probs = self.encoder(inputs, segments)
        # (bs, hid_dim)
        outputs_cls = outputs[:, 0].contiguous()
        outputs_cls = self.linear(outputs_cls)
        outputs_cls = self.activation(outputs_cls)
        # (bs, n_enc_seq, n_enc_vocab), (bs, hid_dim), [(bs, n_heads, n_enc_seq, n_enc_seq)]
        return outputs, outputs_cls, self_attn_probs
    
    def save(self, epoch, loss, path):
        torch.save({
            "epoch": epoch,
            "loss": loss,
            "state_dict": self.state_dict()
        }, path)
    
    def load(self, path):
        save = torch.load(path)
        self.load_state_dict(save["state_dict"])
        return save["epoch"], save["loss"]

In [50]:
""" Define Language Model Head """
class Language_Model_Head(nn.Module):
    def __init__(self):
        super().__init__()
        
        self.bert = BERT()
        # classfier
        self.projection_cls = nn.Linear(hid_dim, 2, bias=False)
        # lm
        self.projection_lm = nn.Linear(hid_dim, n_output, bias=False)
        self.projection_lm.weight = self.bert.encoder.word_embeddings.weight
    
    def forward(self, inputs, segments):
        # (bs, n_enc_seq, hid_dim), (bs, hid_dim), [(bs, n_heads, n_enc_seq, n_enc_seq)]
        print("LM :", inputs.size(), segments.size())
        print(inputs.device.type)
        print(segments.device.type)

        # assert inputs.device.type == 'cuda'
        # assert segments.device.type == 'cuda'

        
        assert inputs.max() < n_enc_vocab
        
        
        outputs, outputs_cls, attn_probs = self.bert(inputs, segments)
        # (bs, 2)
        logits_cls = self.projection_cls(outputs_cls)
        # (bs, n_enc_seq, n_enc_vocab)
        logits_lm = self.projection_lm(outputs)
        # (bs, n_enc_vocab), (bs, n_enc_seq, n_enc_vocab), [(bs, n_heads, n_enc_seq, n_enc_seq)]
        return logits_cls, logits_lm, attn_probs

## Training

In [51]:
def L_M_collate(inputs):
    bert_input, bert_label, segment_ids, is_next = list(zip(*inputs))

    bert_input  = torch.nn.utils.rnn.pad_sequence(bert_input, batch_first=True, padding_value=0)
    bert_label = torch.nn.utils.rnn.pad_sequence(bert_label, batch_first=True, padding_value=0)
    segment_ids = torch.nn.utils.rnn.pad_sequence(segment_ids, batch_first=True, padding_value=0)
    # src_inputs, trg_outputs  = torch.nn.utils.rnn.pad_sequence([src_inputs,trg_outputs], batch_first=True, padding_value=0)

    batch = [
        bert_input,
        bert_label,
        segment_ids,
        torch.tensor(is_next)
    ]
    return batch

In [10]:
training_dataset = BERT_dataset(all_files_path[:1], tokenizer, seq_len)
training_loader = torch.utils.data.DataLoader(training_dataset, batch_size=batch_size, shuffle=True, collate_fn=L_M_collate)

validation_dataset = BERT_dataset(all_files_path[1:2], tokenizer, seq_len)
validation_loader = torch.utils.data.DataLoader(validation_dataset, batch_size=batch_size, shuffle=True, collate_fn=L_M_collate)

NameError: name 'L_M_collate' is not defined

In [53]:
# logits_lm contient les scores bruts (non normalisés) pour chaque tokens du vocabulaire, avec l'application softmax on obtient des probabilités de chaque token

def custom_loss(logits_cls, logits_lm, bert_label, is_next):
    loss_lm = criterion(logits_lm.view(-1, logits_lm.size(-1)), bert_label.view(-1))
    loss_nsp = criterion(logits_cls, is_next)

    return loss_lm + loss_nsp

In [61]:
# # Test avec training loader

# model = Language_Model_Head()
# model.to(device)

# bert_input, bert_label, segments_ids, is_next = next(iter(training_loader))
# bert_input = bert_input.to(device)
# bert_label = bert_label.to(device)

# with torch.no_grad():
#     model(bert_input, segments_ids)

In [60]:
# # Test avec get_batch

# bert_input_2, bert_label_2, segments_ids_2, is_next_2 = get_batch("train")
# model = Language_Model_Head()
# model.to(device)

# bert_input_2 = bert_input_2.to(device)
# bert_label_2 = bert_label_2.to(device)
# segments_ids_2 = segments_ids_2.to(device)
# is_next_2 = is_next.to(device)

# print(bert_input_2.device.type, bert_label_2.device.type, next(model.parameters()).device.type)

# with torch.no_grad():
#     model(bert_input_2, segments_ids_2)

In [23]:
def train_one_epoch(bar):
    running_loss = 0
    last_loss = 0

    for i, batch in enumerate(training_loader):

        # Retrieve data
        bert_input, bert_label, segment_ids, is_next = [data.to(device) for data in batch]

        # Zero the gradients for every batchs
        optimizer.zero_grad()
        print("start :", bert_input.size(), bert_label.size())
        # Make predictions for this batch
        logits_cls, logits_lm, attn_probs = model(bert_input, segment_ids) # (bs, n_enc_vocab), (bs, n_enc_seq, n_enc_vocab), [(bs, n_heads, n_enc_seq, n_enc_seq)]
        
        # Compute losse(s)
        loss = custom_loss(logits_cls, logits_lm, bert_label, is_next)

        # Compute its gradients
        loss.backward()

        # Adjust learning weights
        optimizer.step()

        # Gather data and report
        running_loss += loss.item()

        bar.update(1)

        if i % 10 == 9:
            last_loss = running_loss / 10
            
            bar.set_postfix_str(f"loss: {last_loss:.3f}")

            running_loss = 0
        
        bar.set_postfix_str(f"loss: {last_loss:.3f}")

    return last_loss

def val_one_epoch(bar):
    last_vloss = 0
    running_vloss = 0

    with torch.no_grad():
        for i, vbatch in enumerate(validation_loader):
            vbert_input, vbert_label, vsegment_ids, vis_next = [vdata.to(device) for vdata in vbatch]

            vlogits_cls, vlogits_lm, attn_probs = model(vbert_input, vsegment_ids)

            vloss = custom_loss(vlogits_cls, vlogits_lm, vbert_label, vis_next)

            running_vloss += vloss
            bar.update(1)

            if i % 10 == 9:
                last_vloss = running_vloss / 10
                bar.set_postfix_str(f"val_loss: {last_vloss:.3f}")
                running_vloss = 0
    return last_vloss

In [None]:
SAVE_MODELS_PATH = r"../models/bert-chkpts"

model_files = [
    os.path.join(SAVE_MODELS_PATH, f) 
    for f in os.listdir(SAVE_MODELS_PATH)
    if os.path.isfile(os.path.join(SAVE_MODELS_PATH, f))
]

In [25]:
NEW_TRAINING = False
DISTRIBUTED = True

SAVE_MODELS_PATH = r"../models/bert-chkpts"

model = Language_Model_Head()
# if DISTRIBUTED:
#     model = torch.nn.DataParallel(model)
model.to(device)

criterion = torch.nn.CrossEntropyLoss(ignore_index=i_pad)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
patience = 3

start_epoch = 1
patience_counter = 0
best_vloss = np.inf

if not NEW_TRAINING:
    try:
        model_files = [
            os.path.join(SAVE_MODELS_PATH, f) 
            for f in os.listdir(SAVE_MODELS_PATH) 
            if os.path.isfile(os.path.join(SAVE_MODELS_PATH, f))
        ]

        lastest_model_path = max(model_files, key=os.path.getctime)
        checkpoint = torch.load(lastest_model_path, weights_only=True)

        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

        start_epoch = checkpoint["epoch"] + 1
        patience_counter = checkpoint["patience_counter"]
        best_vloss = checkpoint["best_vloss"]

    except ValueError:
        print("No checkpoint found, ... starting new training")
    except FileNotFoundError:
        print("No checkpoint found, ... starting new training")
    except Exception as e:
        print(f"Error loading checkpoint: {e}")

for epoch in range(start_epoch, start_epoch + num_epochs + 1):
    timestamp = datetime.now().strftime("%Y%m%d_%H%M")
    
    with tqdm(total=len(training_loader), desc=f"Epoch {epoch}/{start_epoch + num_epochs}", unit="batch") as train_bar:
        model.train()
        avg_tloss = train_one_epoch(train_bar)

        with tqdm(total=len(validation_loader), desc=f"Validation", unit="batch", leave=False) as val_bar:
            model.eval()
            avg_vloss = val_one_epoch(val_bar)
        train_bar.set_postfix_str(f"loss: {avg_tloss:.3f} - val_loss: {avg_vloss:.3f}")

    if avg_vloss < best_vloss:
        best_vloss = avg_vloss
        model_path = f"bert-small_epoch={epoch}_{timestamp}.pth"
        torch.save({
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "patience_counter": patience_counter,
            "best_vloss": best_vloss
        }, os.path.join(SAVE_MODELS_PATH, model_path))
    else:
        patience_counter += 1
        if patience_counter >= patience:
            break

No checkpoint found, ... starting new training


Epoch 1/11:   0%|          | 0/500 [00:00<?, ?batch/s]

start : torch.Size([20, 175]) torch.Size([20, 175])
LM : torch.Size([20, 175]) torch.Size([20, 175])
cuda
cuda
Encoder : torch.Size([20, 175]) torch.Size([20, 175])
start : torch.Size([20, 175]) torch.Size([20, 175])
LM : torch.Size([20, 175]) torch.Size([20, 175])
cuda
cuda
Encoder : torch.Size([20, 175]) torch.Size([20, 175])
start : torch.Size([20, 175]) torch.Size([20, 175])
LM : torch.Size([20, 175]) torch.Size([20, 175])
cuda
cuda
Encoder : torch.Size([20, 175]) torch.Size([20, 175])
start : torch.Size([20, 175]) torch.Size([20, 175])
LM : torch.Size([20, 175]) torch.Size([20, 175])
cuda
cuda
Encoder : torch.Size([20, 175]) torch.Size([20, 175])
start : torch.Size([20, 175]) torch.Size([20, 175])
LM : torch.Size([20, 175]) torch.Size([20, 175])
cuda
cuda
Encoder : torch.Size([20, 175]) torch.Size([20, 175])
start : torch.Size([20, 175]) torch.Size([20, 175])
LM : torch.Size([20, 175]) torch.Size([20, 175])
cuda
cuda
Encoder : torch.Size([20, 175]) torch.Size([20, 175])
start : to

KeyboardInterrupt: 

In [49]:
for name, param in model.named_parameters():
    print(f"{name}: {param.shape}")

module.bert.encoder.word_embeddings.weight: torch.Size([30000, 256])
module.bert.encoder.position_embeddings.weight: torch.Size([176, 256])
module.bert.encoder.token_type_embeddings.weight: torch.Size([2, 256])
module.bert.encoder.layer.0.attention.query.weight: torch.Size([256, 256])
module.bert.encoder.layer.0.attention.query.bias: torch.Size([256])
module.bert.encoder.layer.0.attention.key.weight: torch.Size([256, 256])
module.bert.encoder.layer.0.attention.key.bias: torch.Size([256])
module.bert.encoder.layer.0.attention.value.weight: torch.Size([256, 256])
module.bert.encoder.layer.0.attention.value.bias: torch.Size([256])
module.bert.encoder.layer.0.attention.dense.weight: torch.Size([256, 256])
module.bert.encoder.layer.0.attention.dense.bias: torch.Size([256])
module.bert.encoder.layer.0.ffn.linear_1.weight: torch.Size([1024, 256])
module.bert.encoder.layer.0.ffn.linear_1.bias: torch.Size([1024])
module.bert.encoder.layer.0.ffn.linear_2.weight: torch.Size([256, 1024])
module.be