### Генерация аккордов в последовательности

Задание заключается в написании модели BERT для генерации гитарных аккордов.

В качестве данных взят датасет https://huggingface.co/datasets/lluccardoner/melodyGPT-song-chords-text-1.

В этом ноутбуке присутствует:

1) Очистка данных 

2) Аугментация

3) Токенизатор

4) Обучение модели

In [None]:
# !pip install transformers datasets torch scikit-learn matplotlib seaborn streamlit plotly

In [None]:
import numpy as np
import polars as pl
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from tqdm import tqdm
import re
from typing import List, Tuple, Optional, Mapping,  Self, NamedTuple
import pandas as pd
from datasets import load_dataset

import warnings
warnings.filterwarnings('ignore')


In [None]:
def load_and_clean_dataset():

    dataset = load_dataset("lluccardoner/melodyGPT-song-chords-text-1")
    df = dataset['train'].to_pandas()

    def clean_chords(text):
        if not isinstance(text, str):
            return ""

        text = re.sub(r'\t', ' ', text)
        text = re.sub(r'\b(INTRO|VERSE|CHORUS|BRIDGE|OUTRO|SOLO|PRE-CHORUS)\b', '', text, flags=re.IGNORECASE)

        chords = re.findall(r'[A-G][#b]?(?:m|maj|min|aug|dim|sus2|sus4)?(?:add)?[0-9]*(?:\([0-9]+\))?(?:\/[A-G][#b]?)?', text.upper())

        return ' '.join(chords)

    df['cleaned_chords'] = df['chords_str'].apply(clean_chords)

    df = df[df['cleaned_chords'].str.len() > 0]

    return df


class ChordTransposer:

    def __init__(self):
        self.chord_map = {
            'C': 0, 'C#': 1, 'Db': 1, 'D': 2, 'D#': 3, 'Eb': 3,
            'E': 4, 'F': 5, 'F#': 6, 'Gb': 6, 'G': 7, 'G#': 8,
            'Ab': 8, 'A': 9, 'A#': 10, 'Bb': 10, 'B': 11
        }
        self.reverse_map = {v: k for k, v in self.chord_map.items()}

        self.chord_patterns = [
            # Аккорды с добавлениями в скобках: A5(9), C7(11)
            r'([A-G][#b]?)(.*)(\(.*\))',
            # Аккорды с add: Cadd9, Dmadd11
            r'([A-G][#b]?)(.*)(add.*)',
            # Аккорды с sus: Dsus4, Asus2
            r'([A-G][#b]?)(.*)(sus.*)',
            # Септаккорды и расширения: Cmaj7, G9, Am11
            r'([A-G][#b]?)(.*)(\d+)',
            # Базовые аккорды: C, Dm, E7
            r'([A-G][#b]?)(.*)'
        ]

    def transpose_chord(self, chord, steps):

        for pattern in self.chord_patterns:
            match = re.match(pattern, chord)
            if match:
                groups = match.groups()
                tone = groups[0]

                if tone in self.chord_map:

                    new_tone_num = (self.chord_map[tone] + steps) % 12
                    new_tone = self.reverse_map[new_tone_num]

                    if len(groups) == 3:
                        quality = groups[1] or ''
                        extension = groups[2] or ''
                        return new_tone + quality + extension
                    else:
                        quality = groups[1] or ''
                        return new_tone + quality

                break

        return chord

    def transpose_sequence(self, chord_sequence, steps):
        chords = chord_sequence.split()
        transposed = [self.transpose_chord(chord, steps) for chord in chords]
        return ' '.join(transposed)

def augment_dataset(df, num_transpositions=5):
    transposer = ChordTransposer()
    augmented_data = []

    for _, row in df.iterrows():
        chords = row['cleaned_chords']
        if len(chords.split()) < 3:
            continue

        augmented_data.append({
            'cleaned_chords': chords,
            'genres': row.get('genres'),
            'artist_name': row.get('artist_name', 'unknown'),
            'song_name': row.get('song_name', 'unknown')
        })

        for steps in range(1, num_transpositions + 1):
            transposed = transposer.transpose_sequence(chords, steps)
            augmented_data.append({
                'cleaned_chords': transposed,
                'genres': row.get('genres'),
                'artist_name': row.get('artist_name', 'unknown'),
                'song_name': row.get('song_name', 'unknown')
            })

    return pd.DataFrame(augmented_data)

In [3]:
dataset = load_and_clean_dataset()

Аугментированный датасет, для обучения берется обычный датасет

In [4]:
augmented_dataset = augment_dataset(dataset)

### Токенайзер

In [None]:
class ChordTokenizer:
    def __init__(self):
        self._padding_token = "[PAD]"
        self._unknown_token = "[UNK]"
        self._cls_token = "[CLS]"
        self._sep_token = "[SEP]"
        self._mask_token = "[MASK]"
        
        # Special tokens IDs
        self._padding_id = 0
        self._cls_id = 1
        self._sep_id = 2
        self._mask_token_id = 3
        self._unknown_token_id = 4
        
        # Музыкальные элементы
        self.notes = ['C', 'C#', 'D', 'D#', 'E', 'F', 'F#', 'G', 'G#', 'A', 'A#', 'B']
        self.moods = ['m', 'maj', 'min', 'aug', 'dim', 'sus2', 'sus4', 'sus']
        self.extensions = [
            '5', '6', '7', '9', '11', '13', 
            'add9', 'add11', 'add13'
        ]
        self.symbols = ['/', 'b', '#', '(', ')', ' ']
        
        # Сложные аккорды для добавления в словарь
        self.complex_chords = [
            'A5(9)', 'Cadd9', 'Dsus4', 'Emadd9', 'G5(11)',
            'Fmaj7', 'G9', 'Am11', 'C7(9)', 'Dsus2',
            'Cmaj9', 'F#m7', 'Bbmaj7', 'E7sus4', 'Aadd9'
        ]
        
        self._init_vocab()

    @property
    def vocab(self) -> Mapping[int, str]:
        return self._vocab
    
    @property
    def reverse_vocab(self) -> Mapping[str, int]:
        return {token: idx for idx, token in self._vocab.items()}
    
    @property
    def cls_id(self) -> int:
        return self._cls_id
    
    @property
    def mask_token_id(self) -> int:
        return self._mask_token_id
    
    @property
    def padding_id(self) -> int:
        return self._padding_id
    
    @property
    def sep_id(self) -> int:
        return self._sep_id
    
    @property
    def unknown_token_id(self) -> int:
        return self._unknown_token_id

    def _init_vocab(self) -> None:
        """Инициализация словаря с специальными токенами"""
        self._vocab = {
            self._padding_id: self._padding_token,
            self._cls_id: self._cls_token,
            self._sep_id: self._sep_token,
            self._mask_token_id: self._mask_token,
            self._unknown_token_id: self._unknown_token,
        }
    
    def fit(self, corpus: List[str]) -> Self:
        """Создание словаря на основе корпуса"""
        self._init_vocab()
        
        # Добавляем базовые музыкальные элементы
        all_elements = (self.notes + self.moods + self.extensions + 
                       self.symbols + self.complex_chords)
        
        for element in all_elements:
            if element not in self._vocab.values():
                self._vocab[len(self._vocab)] = element
        
        # Обрабатываем корпус для извлечения дополнительных аккордов
        for text in corpus:
            chords = text.split()
            for chord in chords:
                if chord not in self.reverse_vocab and chord not in self._vocab.values():
                    self._vocab[len(self._vocab)] = chord
        
        return self
    
    def tokenize_text(self, text: str | List[str]) -> List[str] | List[List[str]]:
        """Токенизация текста в строковые токены"""
        if isinstance(text, str):
            return self._tokenize_text(text)
        assert isinstance(text, list), "`text` should be str or List[str]"
        return [self._tokenize_text(chunk) for chunk in text]
 
    def tokenize_ids(self, text: str | List[str]) -> List[int] | List[List[int]]:
        """Токенизация текста в ID токенов"""
        if isinstance(text, str):
            return self._tokenize_ids(text)
        assert isinstance(text, list), "`text` should be str or List[str]"
        return [self._tokenize_ids(chunk) for chunk in text]
    
    def decode(self, tokens: List[int]) -> str:
        """Декодирование ID токенов обратно в строку"""
        content = []
        reverse_vocab = self.reverse_vocab
        
        for token_id in tokens:
            if token_id in [self._padding_id, self._cls_id, self._sep_id, self._mask_token_id]:
                continue
            
            token = self._vocab.get(token_id, self._unknown_token)
            if token == self._unknown_token:
                continue
                
            content.append(token)
        
        # Собираем аккорды из токенов
        result = []
        current_chord = []
        
        for token in content:
            if token == ' ':
                if current_chord:
                    result.append(''.join(current_chord))
                    current_chord = []
            else:
                current_chord.append(token)
        
        if current_chord:
            result.append(''.join(current_chord))
            
        return ' '.join(result)

    def _tokenize_text(self, text: str) -> List[str]:
        """Внутренний метод для токенизации строки в текстовые токены"""
        tokens = [self._cls_token]
        reverse_vocab = self.reverse_vocab
        
        chords = text.split()
        
        for i, chord in enumerate(chords):
            # Пытаемся найти целый аккорд в словаре
            if chord in reverse_vocab:
                tokens.append(chord)
            else:
                # Разбиваем аккорд на составляющие
                chord_parts = self._split_chord(chord)
                for part in chord_parts:
                    if part in reverse_vocab:
                        tokens.append(part)
                    else:
                        tokens.append(self._unknown_token)
            
            # Добавляем пробел между аккордами (кроме последнего)
            if i < len(chords) - 1:
                tokens.append(' ')
        
        tokens.append(self._sep_token)
        return tokens
    
    def _tokenize_ids(self, text: str) -> List[int]:
        """Внутренний метод для токенизации строки в ID токенов"""
        text_tokens = self._tokenize_text(text)
        reverse_vocab = self.reverse_vocab
        return [reverse_vocab.get(token, self._unknown_token_id) for token in text_tokens]
    
    def _split_chord(self, chord: str) -> List[str]:
        """Разбивает аккорд на составляющие элементы"""
        # Регулярное выражение для разбора аккордов
        pattern = r'[A-G][#b]?|[a-z]+|\d+|[\/\(\)#b]'
        parts = re.findall(pattern, chord)
        return parts
    
    def __len__(self) -> int:
        return len(self._vocab)


# Адаптер для Hugging Face (обновленная версия)
class ChordTokenizerHF:
    def __init__(self, chord_tokenizer: ChordTokenizer):
        self.chord_tokenizer = chord_tokenizer

    def __call__(self, texts, padding=True, truncation=True, max_length=128, return_tensors=None):
        if isinstance(texts, str):
            texts = [texts]

        input_ids = []
        attention_masks = []

        for text in texts:
            token_ids = self.chord_tokenizer.tokenize_ids(text)
            
            # Обрезаем если нужно
            if truncation and len(token_ids) > max_length:
                token_ids = token_ids[:max_length]

            attention_mask = [1] * len(token_ids)

            # Добавляем паддинг если нужно
            if padding:
                padding_length = max_length - len(token_ids)
                token_ids = token_ids + [self.chord_tokenizer.padding_id] * padding_length
                attention_mask = attention_mask + [0] * padding_length

            input_ids.append(token_ids)
            attention_masks.append(attention_mask)

        output = {
            'input_ids': input_ids,
            'attention_mask': attention_masks
        }

        if return_tensors == 'pt':
            import torch
            output['input_ids'] = torch.tensor(output['input_ids'])
            output['attention_mask'] = torch.tensor(output['attention_mask'])

        return output

    def decode(self, token_ids: List[int]) -> str:
        """Декодирование ID токенов обратно в строку"""
        return self.chord_tokenizer.decode(token_ids)

Обучаем токенайзер

In [16]:
tokenizer = ChordTokenizer()
corpus = dataset['cleaned_chords'].to_list()
tokenizer.fit(corpus)

<__main__.ChordTokenizer at 0x7f53fb817890>

### Подготовка данных к обучению

In [None]:
class TrainingBatch(NamedTuple):
    input: torch.LongTensor
    label: torch.LongTensor
    segment_label: torch.LongTensor
    is_next: torch.LongTensor

class ChordDataset(Dataset):
    def __init__(
        self,
        corpus: List[str],
        fitted_tokenizer: ChordTokenizer,
        max_seq_len: int = 128,
        mask_prob: float = 0.15
    ):
        self.corpus = corpus
        self.tokenizer = fitted_tokenizer
        self.max_seq_len = max_seq_len
        self.mask_prob = mask_prob
        
        # ID специальных токенов
        self._pad_id = fitted_tokenizer.padding_id
        self._cls_id = fitted_tokenizer.cls_id
        self._sep_id = fitted_tokenizer.sep_id
        self._mask_id = fitted_tokenizer.mask_token_id
        self._unk_id = fitted_tokenizer.unknown_token_id
        self._vocab_size = len(fitted_tokenizer.vocab)
        
        self._sequences = self._prepare_sequences()
        
    def _prepare_sequences(self) -> List[List[str]]:
        sequences = []
        for chord_progression in self.corpus:
            chords = chord_progression.split()
            for i in range(0, len(chords), self.max_seq_len // 2):
                sequence = chords[i:i + self.max_seq_len // 2]
                if len(sequence) >= 2:
                    sequences.append(sequence)
        return sequences
    
    def __len__(self) -> int:
        return len(self._sequences)
    
    def __getitem__(self, idx: int) -> TrainingBatch:
        seq1, seq2, is_next_label = self._get_next_sentence_pair(idx)
        
        seq1_masked, seq1_labels = self._mask_chords(seq1)
        seq2_masked, seq2_labels = self._mask_chords(seq2)
        
        bert_input, bert_labels, segment_labels = self._prepare_bert_input(
            seq1_masked, seq2_masked, seq1_labels, seq2_labels
        )
        
        return TrainingBatch(
            input=torch.LongTensor(bert_input),
            label=torch.LongTensor(bert_labels),
            segment_label=torch.LongTensor(segment_labels),
            is_next=torch.LongTensor([is_next_label]),
        )
    
    def _get_next_sentence_pair(self, idx: int) -> Tuple[List[str], List[str], int]:
        prob = random.random()
        
        if prob < 0.5:
            sequence = self._sequences[idx]
            if len(sequence) >= 4:
                split_point = len(sequence) // 2
                seq1 = sequence[:split_point]
                seq2 = sequence[split_point:split_point + len(seq1)]
                return seq1, seq2, 1
            else:
                return sequence, sequence, 1
        
        seq1 = self._sequences[idx]
        second_idx = random.randrange(len(self._sequences))
        while second_idx == idx:
            second_idx = random.randrange(len(self._sequences))
        seq2 = self._sequences[second_idx]
        
        min_len = min(len(seq1), len(seq2))
        seq1 = seq1[:min_len]
        seq2 = seq2[:min_len]
        
        return seq1, seq2, 0
    
    def _mask_chords(self, chords: List[str]) -> Tuple[List[str], List[str]]:
        masked_chords = []
        labels = []
        
        for chord in chords:
            prob = random.random()
            
            if prob < self.mask_prob:
                if random.random() < 0.8:
                    masked_chords.append(self.tokenizer._mask_token)
                elif random.random() < 0.9:
                    random_chord = random.choice(list(self.tokenizer.reverse_vocab.keys()))
                    while random_chord in [self.tokenizer._padding_token, self.tokenizer._cls_token, 
                                         self.tokenizer._sep_token, self.tokenizer._mask_token]:
                        random_chord = random.choice(list(self.tokenizer.reverse_vocab.keys()))
                    masked_chords.append(random_chord)
                else:
                    masked_chords.append(chord)
                labels.append(chord)
            else:
                masked_chords.append(chord)
                labels.append(self.tokenizer._padding_token)
        
        return masked_chords, labels
    
    def _prepare_bert_input(self, seq1: List[str], seq2: List[str], 
                          labels1: List[str], labels2: List[str]) -> Tuple[List[int], List[int], List[int], List[int]]:
        seq1_ids = self._chords_to_ids(seq1)
        seq2_ids = self._chords_to_ids(seq2)
        labels1_ids = self._chords_to_ids(labels1)
        labels2_ids = self._chords_to_ids(labels2)
        
        bert_input = [self._cls_id] + seq1_ids + [self._sep_id] + seq2_ids + [self._sep_id]
        bert_labels = [self._pad_id] + labels1_ids + [self._pad_id] + labels2_ids + [self._pad_id]
        
        segment_labels = [0] * (1 + len(seq1_ids) + 1)
        segment_labels += [1] * (len(seq2_ids) + 1)
        
        
        if len(bert_input) > self.max_seq_len:
            bert_input = bert_input[:self.max_seq_len]
            bert_labels = bert_labels[:self.max_seq_len]
            segment_labels = segment_labels[:self.max_seq_len]
        
        padding_len = self.max_seq_len - len(bert_input)
        if padding_len > 0:
            bert_input += [self._pad_id] * padding_len
            bert_labels += [self._pad_id] * padding_len
            last_segment = segment_labels[-1] if segment_labels else 0
            segment_labels += [last_segment] * padding_len
        
        return bert_input, bert_labels, segment_labels
    
    def _chords_to_ids(self, chords: List[str]) -> List[int]:
        ids = []
        for chord in chords:
            if chord in self.tokenizer.reverse_vocab:
                ids.append(self.tokenizer.reverse_vocab[chord])
            else:
                ids.append(self._unk_id)
        return ids

Далее классы из учебного ноутбука

In [147]:
class ScheduledOptim():
    '''A simple wrapper class for learning rate scheduling'''

    def __init__(
        self,
        optimizer: torch.optim.Optimizer,
        embedding_size: int,
        n_warmup_steps,
    ):
        self._optimizer = optimizer
        self.n_warmup_steps = n_warmup_steps
        self.n_current_steps = 0
        self.init_lr = np.power(embedding_size, -0.5)

    def step_and_update_lr(self):
        "Step with the inner optimizer"
        self._update_learning_rate()
        self._optimizer.step()

    def zero_grad(self):
        "Zero out the gradients by the inner optimizer"
        self._optimizer.zero_grad()

    def _get_lr_scale(self):
        return np.min([
            np.power(self.n_current_steps, -0.5),
            np.power(self.n_warmup_steps, -1.5) * self.n_current_steps])

    def _update_learning_rate(self):
        ''' Learning rate scheduling per step '''

        self.n_current_steps += 1
        lr = self.init_lr * self._get_lr_scale()

        for param_group in self._optimizer.param_groups:
            param_group['lr'] = lr

In [148]:
class RotaryPositionEmbedding(nn.Module):
    def __init__(
        self,
        embedding_size: int,
        base: int = 10_000,
    ) -> None:
        super().__init__()
        self._theta = 1 / (torch.pow(torch.tensor(base), (torch.arange(0, embedding_size, 2).float() / embedding_size)))
        self._theta = self._theta.repeat_interleave(2)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        position_ids = torch.arange(0, x.size(-2), device=x.device)
        position_matrix = torch.outer(position_ids, self._theta.to(x.device))
        cos = torch.cos(position_matrix)
        sin = torch.sin(position_matrix)
        x_odd = x[..., ::2]
        x_even = x[..., 1::2]

        _x = torch.empty_like(x, device=x.device)
        _x[..., 0::2] = -x_even
        _x[..., 1::2] = x_odd

        # x_stacked = torch.stack([-x_even, x_odd], dim=-1)
        # _x = x_stacked.flatten(start_dim=-2)
        _x = _x * sin[:x.size(-2), :]
        x = x * cos[:x.size(-2), :]
        return x + _x

class BERTEmbedding(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        embedding_size: int,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        self._embeddings = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=embedding_size,
            padding_idx=0,
        )
        self._segment_embeddings = nn.Embedding(
            num_embeddings=3,
            embedding_dim=embedding_size,
            padding_idx=0,
        )
        self.dropout = nn.Dropout(dropout)
        
    def forward(self, x: torch.LongTensor, segmet_label: torch.LongTensor) -> torch.Tensor:
        x = self._embeddings(x) + self._segment_embeddings(segmet_label)
        return self.dropout(x)

In [149]:
class RoPEMultiHeadedAttention(nn.Module):
    def __init__(
        self,
        num_heads: int,
        embedding_size: int,
        head_embedding_size: int,
        positional_embedding: RotaryPositionEmbedding,
        dropout: float = 0.1,
    ):
        super().__init__()
        self._num_heads = num_heads
        self._embedding_size = embedding_size
        self._head_embedding_size = head_embedding_size
        self._positional_embedding = positional_embedding
        self._Q = nn.Linear(self._embedding_size, self._num_heads * self._head_embedding_size)
        self._K = nn.Linear(self._embedding_size, self._num_heads * self._head_embedding_size)
        self._V = nn.Linear(self._embedding_size, self._num_heads * self._head_embedding_size)
        self._W_proj = nn.Linear(self._num_heads * self._head_embedding_size, self._embedding_size)
        self._dropout = nn.Dropout(p=dropout)
        self._layernorm = nn.LayerNorm(self._embedding_size)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        batch_size = query.size(0)

        q = self._Q.forward(query).view(batch_size, -1, self._num_heads, self._head_embedding_size).transpose(1, 2)
        k = self._K.forward(key).view(batch_size, -1, self._num_heads, self._head_embedding_size).transpose(1, 2)
        v = self._V.forward(value).view(batch_size, -1, self._num_heads, self._head_embedding_size).transpose(1, 2)

        q_rope = self._positional_embedding.forward(q)
        k_rope = self._positional_embedding.forward(k)

        attention_numerator = torch.exp(
            torch.matmul(q_rope, k_rope.transpose(-1, -2)) / torch.sqrt(torch.tensor(self._head_embedding_size))
        )
        attention_denominator = torch.exp(
            torch.matmul(q, k.transpose(-1, -2)) / torch.sqrt(torch.tensor(self._head_embedding_size))
        )
        attention_denominator = torch.sum(attention_denominator, dim=-1, keepdim=True)
        a = attention_numerator / attention_denominator
        # a = torch.matmul(q_rope, k_rope.transpose(-1, -2)) / torch.sqrt(torch.tensor(self._head_embedding_size))
        if mask is not None:
            # mask = mask.unsqueeze(1)
            a = a.masked_fill(mask == 0, -torch.inf)
        
        alpha = F.softmax(a, -1)

        z = torch.matmul(alpha, v).transpose(1, 2).contiguous().view(batch_size, -1, self._num_heads * self._head_embedding_size)
        z = self._W_proj(z)
        return self._layernorm(query + self._dropout(z))



In [150]:
class FCNNBlock(nn.Module):
    def __init__(
        self,
        embedding_size: int,
        hidden_size: int,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        self._linear1 = nn.Linear(embedding_size, hidden_size, bias=False)
        self._linear2 = nn.Linear(hidden_size, embedding_size, bias=False)
        self._activation = nn.GELU()
        self._layernorm = nn.LayerNorm(embedding_size)
        self._dropout = nn.Dropout(dropout)
        
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        z = self._linear2(self._activation(self._linear1(x)))
        return self._layernorm(x + self._dropout(z))

class EncoderLayer(nn.Module):
    def __init__(
        self,
        embedding_size: int,
        num_heads: int,
        head_embedding_size: int,
        fcnn_hidden_size: int,
        positional_embedding: RotaryPositionEmbedding,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        self._mha = RoPEMultiHeadedAttention(
            embedding_size=embedding_size,
            num_heads=num_heads,
            head_embedding_size=head_embedding_size,
            positional_embedding=positional_embedding,
            dropout=dropout,
        )
        self._fcnn = FCNNBlock(
            embedding_size=embedding_size,
            hidden_size=fcnn_hidden_size,
            dropout=dropout,
        )
        
    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        return self._fcnn(self._mha(x, x, x, mask))

In [151]:
class Encoder(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        n_layers: int,
        embedding_size: int,
        num_heads: int,
        head_embedding_size: int,
        fcnn_hidden_size: int,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        self._vocab_size = vocab_size
        self._embedding_size = embedding_size
        self._embeddings = BERTEmbedding(
            vocab_size=vocab_size,
            embedding_size=embedding_size,
            dropout=dropout,
        )
        self._positional_embeddings = RotaryPositionEmbedding(
            embedding_size=head_embedding_size,
            base=1000,
        )
        self._layers = nn.ModuleList(
            EncoderLayer(
                embedding_size=embedding_size,
                num_heads=num_heads,
                head_embedding_size=head_embedding_size,
                fcnn_hidden_size=fcnn_hidden_size,
                positional_embedding=self._positional_embeddings,
                dropout=dropout,
            )
            for _ in range(n_layers)
        )
    
    @property
    def vocab_size(self) -> int:
        return self._vocab_size
    
    @property
    def embedding_size(self) -> int:
        return self._embedding_size

    def forward(
        self,
        x: torch.LongTensor,
        segment: torch.LongTensor,
    ) -> torch.Tensor:
        mask = (x > 0).unsqueeze(1).repeat(1, x.size(1), 1).unsqueeze(1)
        z = self._embeddings(x, segment)
        for layer in self._layers:
            z = layer(z, mask)
        return z

In [None]:
class NextSentencePrediction(nn.Module):
    """
    2-class classification model : is_next, is_not_next
    """

    def __init__(self, embedding_size: int):
        """
        :param hidden: BERT model output size
        """
        super().__init__()
        self._linear = nn.Linear(embedding_size, 2)
        self._softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        # use only the first token which is the [CLS]
        return self._softmax(self._linear(x[:, 0]))

class MaskedLanguageModel(nn.Module):
    """
    predicting origin token from masked input sequence
    n-class classification problem, n-class = vocab_size
    """

    def __init__(self, embedding_size: int, vocab_size: int):
        """
        :param hidden: output size of BERT model
        :param vocab_size: total vocab size
        """
        super().__init__()
        self._linear = nn.Linear(embedding_size, vocab_size)
        self._softmax = nn.LogSoftmax(dim=-1)

    def forward(self, x):
        return self._softmax(self._linear(x))

class BERTLM(nn.Module):
    """
    BERT Language Model
    Next Sentence Prediction Model + Masked Language Model
    """

    def __init__(self, encoder: Encoder):
        """
        :param bert: BERT model which should be trained
        :param vocab_size: total vocab size for masked_lm
        """

        super().__init__()
        self._encoder = encoder
        self._next_sentence = NextSentencePrediction(self._encoder.embedding_size)
        self._mask_lm = MaskedLanguageModel(self._encoder.embedding_size, self._encoder.vocab_size)

    def forward(
        self,
        x: torch.LongTensor,
        segment_label: torch.LongTensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        x = self._encoder(x, segment_label)
        return self._next_sentence(x), self._mask_lm(x)

In [None]:
class BERTTrainer:
    def __init__(
        self, 
        model: BERTLM, 
        train_dataloader: DataLoader, 
        test_dataloader: Optional[DataLoader] = None, 
        lr: float = 1e-4,
        weight_decay: float = 0.01,
        betas: Tuple[float, float] = (0.9, 0.999),
        warmup_steps: int = 10000,
        log_freq: int = 10,
        device: str = 'cuda'
    ) -> None:
        self.device = device
        self.model = model
        self.train_data = train_dataloader
        self.test_data = test_dataloader

        self.optim = Adam(
            self.model.parameters(),
            lr=lr,
            betas=betas,
            weight_decay=weight_decay,
        )
        self.optim_schedule = ScheduledOptim(
            optimizer=self.optim,
            embedding_size=self.model._encoder.embedding_size,
            n_warmup_steps=warmup_steps,
        )

        self.criterion = nn.NLLLoss(ignore_index=0)
        self.log_freq = log_freq
        print("Total Parameters:", sum([p.nelement() for p in self.model.parameters()]))
    
    def train(self, epoch: int) -> None:
        self.iteration(epoch, self.train_data)

    def test(self, epoch: int) -> None:
        self.iteration(epoch, self.test_data, train=False)

    def iteration(self, epoch: int, data_loader: DataLoader, train: bool = True) -> None:
        avg_loss = 0.0
        total_correct = 0
        total_element = 0
        
        mode = "train" if train else "test"

        # progress bar
        data_iter = tqdm(
            enumerate(data_loader),
            desc="EP_%s:%d" % (mode, epoch),
            total=len(data_loader),
            bar_format="{l_bar}{r_bar}"
        )

        for i, data in data_iter:

            data: TrainingBatch
            bert_input = data.input.to(self.device)
            label = data.label.to(self.device)
            segment_label = data.segment_label.to(self.device)
            is_next = data.is_next.to(self.device).view(-1)

            next_sent_output, mask_lm_output = self.model.forward(bert_input, segment_label)

            next_loss = self.criterion(next_sent_output, is_next)
            mask_loss = self.criterion(mask_lm_output.transpose(1, 2), label)

            loss = next_loss + mask_loss

            if train:
                self.optim_schedule.zero_grad()
                loss.backward()
                self.optim_schedule.step_and_update_lr()

            correct = next_sent_output.argmax(dim=-1).eq(is_next).sum().item()
            avg_loss += loss.item()
            total_correct += correct
            total_element += is_next.nelement()

            post_fix = {
                "epoch": epoch,
                "iter": i,
                "avg_loss": avg_loss / (i + 1),
                "avg_acc": total_correct / total_element * 100,
                "loss": loss.item()
            }

            if i != 0 and i % self.log_freq == 0:
                data_iter.write(str(post_fix))
        print(
            f"EP{epoch}, {mode}: \
            avg_loss={avg_loss / len(data_iter)}, \
            total_acc={total_correct * 100.0 / total_element}"
        )

In [None]:
VOCAB_SIZE = len(tokenizer.vocab)
BATCH_SIZE = 128
MAX_SEQ_LEN = 50
N_LAYERS = 6
EMBEDDING_SIZE = 64
NUM_HEADS = 8
HEAD_EMBEDDING_SIZE = EMBEDDING_SIZE // NUM_HEADS
FCCN_HIDDEN_SIZE = EMBEDDING_SIZE * 4
n_epoch = 10

In [155]:
dataset = ChordDataset(
    corpus=corpus,
    fitted_tokenizer=tokenizer,
)

In [160]:
def collate(data: List[TrainingBatch]):   
    inputs = torch.stack([batch.input for batch in data])
    labels = torch.stack([batch.label for batch in data])
    segment_labels = torch.stack([batch.segment_label for batch in data])
    is_nexts = torch.stack([batch.is_next for batch in data])
    return TrainingBatch(
        input=torch.LongTensor(inputs),
        label=torch.LongTensor(labels),
        segment_label=torch.LongTensor(segment_labels),
        is_next=torch.LongTensor(is_nexts),
    )

In [161]:
dataloader = DataLoader(
    dataset=dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=collate
)

In [162]:
batch = next(iter(dataloader))

In [163]:
encoder = Encoder(
    vocab_size=VOCAB_SIZE,
    n_layers=N_LAYERS,
    embedding_size=EMBEDDING_SIZE,
    num_heads=NUM_HEADS,
    head_embedding_size=HEAD_EMBEDDING_SIZE,
    fcnn_hidden_size=FCCN_HIDDEN_SIZE,
)

In [164]:
bert_model = BERTLM(
    encoder=encoder
)
device = "cuda:0"
bert_trainer = BERTTrainer(
    model=bert_model,
    train_dataloader=dataloader,
    device=device,
    log_freq=200,
    warmup_steps=1000,
    lr=0.005,
)

Total Parameters: 454783


In [165]:
torch.cuda.is_available() 

True

Обучение

In [166]:
bert_model.to(device)
bert_model.train()
for epoch in range(n_epoch):
    bert_trainer.train(epoch)

EP_train:0:   0%|| 0/1757 [00:00<?, ?it/s]

EP_train:0:  11%|| 201/1757 [05:23<37:07,  1.43s/it]  

{'epoch': 0, 'iter': 200, 'avg_loss': 4.583915023661372, 'avg_acc': 49.906716417910445, 'loss': 2.8390960693359375}


EP_train:0:  23%|| 401/1757 [10:47<35:12,  1.56s/it]  

{'epoch': 0, 'iter': 400, 'avg_loss': 3.635679988195177, 'avg_acc': 49.73503740648379, 'loss': 2.6608200073242188}


EP_train:0:  34%|| 601/1757 [16:05<27:04,  1.41s/it]

{'epoch': 0, 'iter': 600, 'avg_loss': 3.2840149747750127, 'avg_acc': 49.762115224625624, 'loss': 2.6582040786743164}


EP_train:0:  46%|| 801/1757 [21:15<22:49,  1.43s/it]

{'epoch': 0, 'iter': 800, 'avg_loss': 3.0971017681555204, 'avg_acc': 49.92489856429464, 'loss': 2.5018794536590576}


EP_train:0:  57%|| 1001/1757 [26:29<18:15,  1.45s/it]

{'epoch': 0, 'iter': 1000, 'avg_loss': 2.980887124350259, 'avg_acc': 49.862637362637365, 'loss': 2.5605549812316895}


EP_train:0:  68%|| 1201/1757 [31:39<14:00,  1.51s/it]

{'epoch': 0, 'iter': 1200, 'avg_loss': 2.900857978419003, 'avg_acc': 49.834773105745214, 'loss': 2.3839354515075684}


EP_train:0:  80%|| 1401/1757 [36:49<09:34,  1.61s/it]

{'epoch': 0, 'iter': 1400, 'avg_loss': 2.8396837966600375, 'avg_acc': 49.878992683797286, 'loss': 2.479987382888794}


EP_train:0:  91%|| 1601/1757 [41:53<03:24,  1.31s/it]

{'epoch': 0, 'iter': 1600, 'avg_loss': 2.7946713488970154, 'avg_acc': 49.887765459088065, 'loss': 2.398185968399048}


EP_train:0: 100%|| 1757/1757 [45:56<00:00,  1.57s/it]


EP0, train:             avg_loss=2.7648969575906928,             total_acc=49.88990360448926


EP_train:1:  11%|| 201/1757 [05:18<36:41,  1.41s/it]  

{'epoch': 1, 'iter': 200, 'avg_loss': 2.4513491635298847, 'avg_acc': 50.310945273631845, 'loss': 2.5045316219329834}


EP_train:1:  23%|| 401/1757 [10:33<33:07,  1.47s/it]

{'epoch': 1, 'iter': 400, 'avg_loss': 2.4540527913338526, 'avg_acc': 50.1909289276808, 'loss': 2.261896848678589}


EP_train:1:  34%|| 601/1757 [15:48<27:59,  1.45s/it]

{'epoch': 1, 'iter': 600, 'avg_loss': 2.449858318747776, 'avg_acc': 50.16508943427621, 'loss': 2.4617042541503906}


EP_train:1:  46%|| 801/1757 [21:11<22:54,  1.44s/it]

{'epoch': 1, 'iter': 800, 'avg_loss': 2.4470131031136386, 'avg_acc': 50.0897315855181, 'loss': 2.557003974914551}


EP_train:1:  57%|| 1001/1757 [26:33<16:56,  1.34s/it]

{'epoch': 1, 'iter': 1000, 'avg_loss': 2.4460702440240882, 'avg_acc': 49.960196053946056, 'loss': 2.511751651763916}


EP_train:1:  68%|| 1201/1757 [31:50<14:02,  1.52s/it]

{'epoch': 1, 'iter': 1200, 'avg_loss': 2.444172134804388, 'avg_acc': 49.90242506244796, 'loss': 2.534467935562134}


EP_train:1:  80%|| 1401/1757 [37:31<09:26,  1.59s/it]

{'epoch': 1, 'iter': 1400, 'avg_loss': 2.4430144889962238, 'avg_acc': 49.92583422555318, 'loss': 2.453115224838257}


EP_train:1:  91%|| 1601/1757 [43:00<03:50,  1.48s/it]

{'epoch': 1, 'iter': 1600, 'avg_loss': 2.4407094749043243, 'avg_acc': 49.96486570893192, 'loss': 2.328341484069824}


EP_train:1: 100%|| 1757/1757 [47:19<00:00,  1.62s/it]


EP1, train:             avg_loss=2.4392070979236133,             total_acc=49.97041854423651


EP_train:2:  11%|| 201/1757 [05:38<41:07,  1.59s/it]  

{'epoch': 2, 'iter': 200, 'avg_loss': 2.4317682085938714, 'avg_acc': 50.3847947761194, 'loss': 2.4985995292663574}


EP_train:2:  23%|| 401/1757 [11:15<36:38,  1.62s/it]  

{'epoch': 2, 'iter': 400, 'avg_loss': 2.4274992235283603, 'avg_acc': 50.30782418952619, 'loss': 2.4928088188171387}


EP_train:2:  34%|| 601/1757 [17:04<38:23,  1.99s/it]

{'epoch': 2, 'iter': 600, 'avg_loss': 2.4271739762952045, 'avg_acc': 50.1169925124792, 'loss': 2.346231698989868}


EP_train:2:  46%|| 801/1757 [22:42<22:56,  1.44s/it]

{'epoch': 2, 'iter': 800, 'avg_loss': 2.428359199254849, 'avg_acc': 50.10631242197253, 'loss': 2.3193583488464355}


EP_train:2:  57%|| 1001/1757 [28:19<19:49,  1.57s/it]

{'epoch': 2, 'iter': 1000, 'avg_loss': 2.4268581040732036, 'avg_acc': 50.0975586913087, 'loss': 2.447850465774536}


EP_train:2:  68%|| 1201/1757 [33:45<14:15,  1.54s/it]

{'epoch': 2, 'iter': 1200, 'avg_loss': 2.426380247994327, 'avg_acc': 50.134653413821816, 'loss': 2.4470109939575195}


EP_train:2:  80%|| 1401/1757 [39:06<09:17,  1.56s/it]

{'epoch': 2, 'iter': 1400, 'avg_loss': 2.4260179839927245, 'avg_acc': 50.105951106352606, 'loss': 2.3856284618377686}


EP_train:2:  91%|| 1601/1757 [44:25<04:29,  1.73s/it]

{'epoch': 2, 'iter': 1600, 'avg_loss': 2.4267344503087003, 'avg_acc': 50.03903810118676, 'loss': 2.5323798656463623}


EP_train:2: 100%|| 1757/1757 [48:33<00:00,  1.66s/it]


EP2, train:             avg_loss=2.4265949390934622,             total_acc=50.02958145576349


EP_train:3:  11%|| 201/1757 [05:29<39:35,  1.53s/it]  

{'epoch': 3, 'iter': 200, 'avg_loss': 2.4150026947704712, 'avg_acc': 49.76679104477612, 'loss': 2.628065347671509}


EP_train:3:  23%|| 401/1757 [10:48<35:39,  1.58s/it]  

{'epoch': 3, 'iter': 400, 'avg_loss': 2.4174418033209824, 'avg_acc': 49.563591022443894, 'loss': 2.429598331451416}


EP_train:3:  34%|| 601/1757 [15:58<27:00,  1.40s/it]

{'epoch': 3, 'iter': 600, 'avg_loss': 2.418319446671624, 'avg_acc': 49.57752703826955, 'loss': 2.4006240367889404}


EP_train:3:  46%|| 801/1757 [21:08<23:23,  1.47s/it]

{'epoch': 3, 'iter': 800, 'avg_loss': 2.4197262747308588, 'avg_acc': 49.56109550561798, 'loss': 2.449817657470703}


EP_train:3:  57%|| 1001/1757 [26:26<18:22,  1.46s/it]

{'epoch': 3, 'iter': 1000, 'avg_loss': 2.418215110466316, 'avg_acc': 49.71903096903097, 'loss': 2.4587457180023193}


EP_train:3:  68%|| 1201/1757 [31:37<20:47,  2.24s/it]

{'epoch': 3, 'iter': 1200, 'avg_loss': 2.4184060595017685, 'avg_acc': 49.76451915070774, 'loss': 2.367908000946045}


EP_train:3:  80%|| 1401/1757 [37:02<11:15,  1.90s/it]

{'epoch': 3, 'iter': 1400, 'avg_loss': 2.418870943261418, 'avg_acc': 49.70556745182012, 'loss': 2.498136281967163}


EP_train:3:  91%|| 1601/1757 [42:20<03:49,  1.47s/it]

{'epoch': 3, 'iter': 1600, 'avg_loss': 2.4194897180196273, 'avg_acc': 49.78431449094316, 'loss': 2.3225836753845215}


EP_train:3: 100%|| 1757/1757 [46:22<00:00,  1.58s/it]


EP3, train:             avg_loss=2.4196073837964183,             total_acc=49.7915953078918


EP_train:4:  11%|| 201/1757 [05:20<37:23,  1.44s/it]  

{'epoch': 4, 'iter': 200, 'avg_loss': 2.422714774288348, 'avg_acc': 50.108830845771145, 'loss': 2.376873016357422}


EP_train:4:  23%|| 401/1757 [10:49<45:13,  2.00s/it]  

{'epoch': 4, 'iter': 400, 'avg_loss': 2.4157274470959518, 'avg_acc': 50.15001558603491, 'loss': 2.4925031661987305}


EP_train:4:  34%|| 601/1757 [16:16<28:16,  1.47s/it]

{'epoch': 4, 'iter': 600, 'avg_loss': 2.4135962369636372, 'avg_acc': 50.02079866888519, 'loss': 2.510922908782959}


EP_train:4:  46%|| 801/1757 [21:39<30:37,  1.92s/it]

{'epoch': 4, 'iter': 800, 'avg_loss': 2.4140640742769848, 'avg_acc': 50.12386860174781, 'loss': 2.4050257205963135}


EP_train:4:  57%|| 1001/1757 [27:05<20:18,  1.61s/it]

{'epoch': 4, 'iter': 1000, 'avg_loss': 2.4154752865656985, 'avg_acc': 50.24038461538461, 'loss': 2.346954345703125}


EP_train:4:  68%|| 1201/1757 [32:29<14:03,  1.52s/it]

{'epoch': 4, 'iter': 1200, 'avg_loss': 2.4151299496078176, 'avg_acc': 50.26410283097419, 'loss': 2.4232616424560547}


EP_train:4:  80%|| 1401/1757 [37:54<08:51,  1.49s/it]

{'epoch': 4, 'iter': 1400, 'avg_loss': 2.4159295722299774, 'avg_acc': 50.2018647394718, 'loss': 2.435267925262451}


EP_train:4:  91%|| 1601/1757 [43:28<04:02,  1.56s/it]

{'epoch': 4, 'iter': 1600, 'avg_loss': 2.4158846599321526, 'avg_acc': 50.2108057464085, 'loss': 2.353776454925537}


EP_train:4: 100%|| 1757/1757 [47:47<00:00,  1.63s/it]


EP4, train:             avg_loss=2.415271248407641,             total_acc=50.23509472738353


EP_train:5:  11%|| 201/1757 [05:27<38:28,  1.48s/it]  

{'epoch': 5, 'iter': 200, 'avg_loss': 2.4205935475838123, 'avg_acc': 49.261504975124375, 'loss': 2.391791820526123}


EP_train:5:  23%|| 401/1757 [10:43<34:19,  1.52s/it]  

{'epoch': 5, 'iter': 400, 'avg_loss': 2.421555832437149, 'avg_acc': 49.69996882793018, 'loss': 2.3811299800872803}


EP_train:5:  34%|| 601/1757 [16:04<30:58,  1.61s/it]

{'epoch': 5, 'iter': 600, 'avg_loss': 2.4164815266398145, 'avg_acc': 49.98700083194675, 'loss': 2.4824087619781494}


EP_train:5:  46%|| 801/1757 [21:25<24:19,  1.53s/it]

{'epoch': 5, 'iter': 800, 'avg_loss': 2.4162919467754578, 'avg_acc': 49.872230024968786, 'loss': 2.469680070877075}


EP_train:5:  57%|| 1001/1757 [26:53<22:48,  1.81s/it]

{'epoch': 5, 'iter': 1000, 'avg_loss': 2.4162168276536238, 'avg_acc': 49.72527472527473, 'loss': 2.3064308166503906}


EP_train:5:  68%|| 1201/1757 [32:16<14:12,  1.53s/it]

{'epoch': 5, 'iter': 1200, 'avg_loss': 2.416566780862165, 'avg_acc': 49.71378018318068, 'loss': 2.3835482597351074}


EP_train:5:  80%|| 1401/1757 [37:37<08:42,  1.47s/it]

{'epoch': 5, 'iter': 1400, 'avg_loss': 2.4157460430194275, 'avg_acc': 49.79869289793005, 'loss': 2.4613330364227295}


EP_train:5:  91%|| 1601/1757 [42:58<03:53,  1.50s/it]

{'epoch': 5, 'iter': 1600, 'avg_loss': 2.415535185353448, 'avg_acc': 49.788706277326675, 'loss': 2.510542154312134}


EP_train:5: 100%|| 1757/1757 [47:09<00:00,  1.61s/it]


EP5, train:             avg_loss=2.4149234992370867,             total_acc=49.827182021592236


EP_train:6:  11%|| 201/1757 [05:27<1:01:53,  2.39s/it]

{'epoch': 6, 'iter': 200, 'avg_loss': 2.4164548003258397, 'avg_acc': 49.665733830845774, 'loss': 2.4156079292297363}


EP_train:6:  23%|| 401/1757 [10:56<34:54,  1.54s/it]  

{'epoch': 6, 'iter': 400, 'avg_loss': 2.4162255736657805, 'avg_acc': 49.76426122194514, 'loss': 2.46608829498291}


EP_train:6:  34%|| 601/1757 [16:23<35:40,  1.85s/it]

{'epoch': 6, 'iter': 600, 'avg_loss': 2.414082328015675, 'avg_acc': 49.86870840266223, 'loss': 2.404184103012085}


EP_train:6:  46%|| 801/1757 [21:47<23:36,  1.48s/it]

{'epoch': 6, 'iter': 800, 'avg_loss': 2.4112733898686707, 'avg_acc': 50.10631242197253, 'loss': 2.325376510620117}


EP_train:6:  57%|| 1001/1757 [27:16<19:23,  1.54s/it]

{'epoch': 6, 'iter': 1000, 'avg_loss': 2.410376271763286, 'avg_acc': 50.01170704295704, 'loss': 2.4317941665649414}


EP_train:6:  68%|| 1201/1757 [32:45<14:26,  1.56s/it]

{'epoch': 6, 'iter': 1200, 'avg_loss': 2.41138362249268, 'avg_acc': 49.95966902581183, 'loss': 2.383100986480713}


EP_train:6:  80%|| 1401/1757 [38:07<08:48,  1.48s/it]

{'epoch': 6, 'iter': 1400, 'avg_loss': 2.41051588817463, 'avg_acc': 49.95538900785154, 'loss': 2.441837787628174}


EP_train:6:  91%|| 1601/1757 [43:39<04:15,  1.64s/it]

{'epoch': 6, 'iter': 1600, 'avg_loss': 2.4113989537541083, 'avg_acc': 49.932171299188006, 'loss': 2.332197427749634}


EP_train:6: 100%|| 1757/1757 [47:56<00:00,  1.64s/it]


EP6, train:             avg_loss=2.411682005151131,             total_acc=49.96863920855149


EP_train:7:  11%|| 201/1757 [05:30<38:11,  1.47s/it]  

{'epoch': 7, 'iter': 200, 'avg_loss': 2.3979690691724938, 'avg_acc': 49.59577114427861, 'loss': 2.3971776962280273}


EP_train:7:  23%|| 401/1757 [10:56<33:38,  1.49s/it]  

{'epoch': 7, 'iter': 400, 'avg_loss': 2.4045810354618062, 'avg_acc': 49.78374376558604, 'loss': 2.4799129962921143}


EP_train:7:  34%|| 601/1757 [16:24<34:01,  1.77s/it]

{'epoch': 7, 'iter': 600, 'avg_loss': 2.4059662315095722, 'avg_acc': 49.91030574043261, 'loss': 2.300373077392578}


EP_train:7:  46%|| 801/1757 [21:51<25:44,  1.62s/it]

{'epoch': 7, 'iter': 800, 'avg_loss': 2.4079008176829784, 'avg_acc': 49.92977528089887, 'loss': 2.3746490478515625}


EP_train:7:  57%|| 1001/1757 [27:19<20:47,  1.65s/it]

{'epoch': 7, 'iter': 1000, 'avg_loss': 2.4069193655198866, 'avg_acc': 50.050730519480524, 'loss': 2.492142677307129}


EP_train:7:  68%|| 1201/1757 [32:48<14:31,  1.57s/it]

{'epoch': 7, 'iter': 1200, 'avg_loss': 2.406442662857653, 'avg_acc': 50.08131244796004, 'loss': 2.28385853767395}


EP_train:7:  80%|| 1401/1757 [38:17<09:29,  1.60s/it]

{'epoch': 7, 'iter': 1400, 'avg_loss': 2.407244948129157, 'avg_acc': 50.012268022840836, 'loss': 2.5108680725097656}


EP_train:7:  91%|| 1601/1757 [44:01<04:35,  1.76s/it]

{'epoch': 7, 'iter': 1600, 'avg_loss': 2.4072468356740693, 'avg_acc': 50.04586976889445, 'loss': 2.3663885593414307}


EP_train:7: 100%|| 1757/1757 [48:19<00:00,  1.65s/it]


EP7, train:             avg_loss=2.407120781862132,             total_acc=50.046040310849946


EP_train:8:  11%|| 201/1757 [05:30<40:56,  1.58s/it]

{'epoch': 8, 'iter': 200, 'avg_loss': 2.3997126254276258, 'avg_acc': 50.221548507462686, 'loss': 2.3286755084991455}


EP_train:8:  23%|| 401/1757 [11:27<36:53,  1.63s/it]

{'epoch': 8, 'iter': 400, 'avg_loss': 2.402316837834004, 'avg_acc': 50.12663653366584, 'loss': 2.472522258758545}


EP_train:8:  34%|| 601/1757 [17:03<30:24,  1.58s/it]

{'epoch': 8, 'iter': 600, 'avg_loss': 2.404788430240904, 'avg_acc': 50.00779950083195, 'loss': 2.44401216506958}


EP_train:8:  46%|| 801/1757 [22:36<30:04,  1.89s/it]

{'epoch': 8, 'iter': 800, 'avg_loss': 2.406702218728417, 'avg_acc': 50.06242197253433, 'loss': 2.523512125015259}


EP_train:8:  57%|| 1001/1757 [28:13<20:16,  1.61s/it]

{'epoch': 8, 'iter': 1000, 'avg_loss': 2.406048832358895, 'avg_acc': 50.07414460539461, 'loss': 2.323251962661743}


EP_train:8:  68%|| 1201/1757 [33:40<15:09,  1.64s/it]

{'epoch': 8, 'iter': 1200, 'avg_loss': 2.407300380743314, 'avg_acc': 50.12164342214821, 'loss': 2.3453943729400635}


EP_train:8:  80%|| 1401/1757 [39:16<10:55,  1.84s/it]

{'epoch': 8, 'iter': 1400, 'avg_loss': 2.4068508280931753, 'avg_acc': 50.09647127052106, 'loss': 2.458913803100586}


EP_train:8:  91%|| 1601/1757 [44:46<03:57,  1.52s/it]

{'epoch': 8, 'iter': 1600, 'avg_loss': 2.4068189946507603, 'avg_acc': 50.07807620237351, 'loss': 2.4273414611816406}


EP_train:8: 100%|| 1757/1757 [49:00<00:00,  1.67s/it]


EP8, train:             avg_loss=2.4068693675028716,             total_acc=50.122106911384634


EP_train:9:  11%|| 201/1757 [05:33<50:47,  1.96s/it]

{'epoch': 9, 'iter': 200, 'avg_loss': 2.402234783220054, 'avg_acc': 50.36536069651741, 'loss': 2.369131565093994}


EP_train:9:  23%|| 401/1757 [11:12<37:34,  1.66s/it]

{'epoch': 9, 'iter': 400, 'avg_loss': 2.406831335248495, 'avg_acc': 50.20067019950125, 'loss': 2.4412736892700195}


EP_train:9:  34%|| 601/1757 [16:40<31:59,  1.66s/it]

{'epoch': 9, 'iter': 600, 'avg_loss': 2.4073949959036116, 'avg_acc': 50.12349209650583, 'loss': 2.363548755645752}


EP_train:9:  46%|| 801/1757 [22:19<27:33,  1.73s/it]

{'epoch': 9, 'iter': 800, 'avg_loss': 2.4076826384897982, 'avg_acc': 50.11606585518103, 'loss': 2.4103188514709473}


EP_train:9:  57%|| 1001/1757 [27:49<19:15,  1.53s/it]

{'epoch': 9, 'iter': 1000, 'avg_loss': 2.4092423318030236, 'avg_acc': 50.164679070929076, 'loss': 2.357564926147461}


EP_train:9:  68%|| 1201/1757 [33:15<15:10,  1.64s/it]

{'epoch': 9, 'iter': 1200, 'avg_loss': 2.4082632017175323, 'avg_acc': 50.13140091590341, 'loss': 2.438218116760254}


EP_train:9:  80%|| 1401/1757 [38:48<10:17,  1.73s/it]

{'epoch': 9, 'iter': 1400, 'avg_loss': 2.407226068645099, 'avg_acc': 50.09814418272662, 'loss': 2.3929646015167236}


EP_train:9:  91%|| 1601/1757 [44:11<03:57,  1.52s/it]

{'epoch': 9, 'iter': 1600, 'avg_loss': 2.4079777207097584, 'avg_acc': 50.04977357901311, 'loss': 2.4873321056365967}


EP_train:9: 100%|| 1757/1757 [48:29<00:00,  1.66s/it]

EP9, train:             avg_loss=2.407972543566622,             total_acc=50.046929978692454





Сохраняем модель

In [None]:
def save_model(model, tokenizer, filepath):
    """Сохраняет модель и токенайзер"""
    checkpoint = {
        'model_state_dict': model.model.state_dict(),
        'tokenizer_vocab': tokenizer.vocab,
        'tokenizer_reverse_vocab': tokenizer.reverse_vocab,
        'model_config': {
            'vocab_size': model.model._encoder.embedding_size,
            'embedding_size': model.model._encoder.embedding_size,
        }
    }
    
    torch.save(checkpoint, filepath)
    print(f"Модель сохранена в {filepath}")

save_model(bert_trainer, tokenizer, 'chord_bert_model.pth')

Модель сохранена в chord_bert_model.pth


Отдельно еще сохраняем токенайзер

In [None]:
def save_tokenizer(tokenizer, filepath):
    """Сохраняет словарь токенайзера"""
    tokenizer_data = {
        '_vocab': tokenizer.vocab,
        'notes': tokenizer.notes,
        'moods': tokenizer.moods,
        'extensions': tokenizer.extensions,
        'symbols': tokenizer.symbols,
        'complex_chords': tokenizer.complex_chords
    }
    
    torch.save(tokenizer_data, filepath)
    print(f"Токенайзер сохранен в {filepath}")

def load_tokenizer(filepath):
    """Загружает токенайзер из файла"""
    tokenizer_data = torch.load(filepath)
    

    tokenizer = ChordTokenizer()

    tokenizer._vocab = tokenizer_data['_vocab']
    tokenizer.notes = tokenizer_data['notes']
    tokenizer.moods = tokenizer_data['moods']
    tokenizer.extensions = tokenizer_data['extensions']
    tokenizer.symbols = tokenizer_data['symbols']
    tokenizer.complex_chords = tokenizer_data['complex_chords']
    
    print(f"Токенайзер загружен из {filepath}")
    return tokenizer

save_tokenizer(tokenizer, 'chord_tokenizer.pth')
loaded_tokenizer = load_tokenizer('chord_tokenizer.pth')

Токенайзер сохранен в chord_tokenizer.pth
Токенайзер загружен из chord_tokenizer.pth


Переключаем модельку в режим эвала

In [177]:
bert_model.eval()

BERTLM(
  (_encoder): Encoder(
    (_embeddings): BERTEmbedding(
      (_embeddings): Embedding(1213, 64, padding_idx=0)
      (_segment_embeddings): Embedding(3, 64, padding_idx=0)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (_positional_embeddings): RotaryPositionEmbedding()
    (_layers): ModuleList(
      (0-5): 6 x EncoderLayer(
        (_mha): RoPEMultiHeadedAttention(
          (_positional_embedding): RotaryPositionEmbedding()
          (_Q): Linear(in_features=64, out_features=64, bias=True)
          (_K): Linear(in_features=64, out_features=64, bias=True)
          (_V): Linear(in_features=64, out_features=64, bias=True)
          (_W_proj): Linear(in_features=64, out_features=64, bias=True)
          (_dropout): Dropout(p=0.1, inplace=False)
          (_layernorm): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        )
        (_fcnn): FCNNBlock(
          (_linear1): Linear(in_features=64, out_features=256, bias=False)
          (_linear2): Linear(in_f

Пример предикта

In [None]:
def predict_masked_chord(model, tokenizer, chord_sequence):
    """
    Предсказывает аккорд на месте знака вопроса в последовательности
    Пример: ["A", "B", "?", "B", "D"] -> предсказывает аккорд вместо '?'
    """
    print(f'Input {chord_sequence}')
    chord_sequence = chord_sequence.split()
    masked_sequence = [chord if chord != '?' else tokenizer._mask_token for chord in chord_sequence]
    if tokenizer._mask_token not in masked_sequence:
        masked_sequence.append(tokenizer._mask_token)

    input_ids = tokenizer.tokenize_ids(masked_sequence)
    input_ids = [i[1] for i in input_ids]


    inputs = torch.tensor([input_ids])


    mask_index = input_ids.index(tokenizer.mask_token_id)
    segment_label = torch.zeros_like(inputs)

    with torch.no_grad():
        outputs = model(inputs, segment_label)
        predictions = outputs.logits if hasattr(outputs, 'logits') else outputs


    predicted_index = torch.argmax(predictions[1][0][mask_index]).item()
    predicted_chord = tokenizer.vocab.get(predicted_index, tokenizer._unknown_token)
    masked_sequence[mask_index] = predicted_chord
    separator = " "
    result = separator.join(masked_sequence)
    print(f'Result {result}')
    return predicted_chord


seqs = [
'? C D F B B F B B F B B F D F C D D C D F',
'D ? D F B B F B B F B B F D F C D D C D F',
'D C ? F B B F B B F B B F D F C D D C D F',
'D C D ? B B F B B F B B F D F C D D C D F',
'D C D F ? B F B B F B B F D F C D D C D F',
'D C D F B ? F B B F B B F D F C D D C D F',
'D C D F B B ? B B F B B F D F C D D C D F',
'D C D F B B F ? B F B B F D F C D D C D F',
'D C D F B B F B ? F B B F D F C D D C D F',
'D C D F B B F B B ? B B F D F C D D C D F',
'D C D F B B F B B F ? B F D F C D D C D F',
'D C D F B B F B B F B ? F D F C D D C D F',
'D C D F B B F B B F B B ? D F C D D C D F',
'D C D F B B F B B F B B F ? F C D D C D F',
'D C D F B B F B B F B B F D ? C D D C D F',
'D C D F B B F B B F B B F D F ? D D C D F',
'D C D F B B F B B F B B F D F C ? D C D F',
'D C D F B B F B B F B B F D F C D ? C D F',
'D C D F B B F B B F B B F D F C D D ? D F',
'D C D F B B F B B F B B F D F C D D C ? F',
'D C D F B B F B B F B B F D F C D D C D ?'
]
for sequence in seqs:
    predicted = predict_masked_chord(bert_model, tokenizer, sequence)
    print(f"Предсказанный аккорд: {predicted}")

Предсказанный аккорд: C
Предсказанный аккорд: G
Предсказанный аккорд: C
Предсказанный аккорд: G
Предсказанный аккорд: G
Предсказанный аккорд: G
Предсказанный аккорд: G
Предсказанный аккорд: G
Предсказанный аккорд: G
Предсказанный аккорд: G
Предсказанный аккорд: G
Предсказанный аккорд: G
Предсказанный аккорд: G
Предсказанный аккорд: C
Предсказанный аккорд: G
Предсказанный аккорд: G
Предсказанный аккорд: C
Предсказанный аккорд: C
Предсказанный аккорд: G
Предсказанный аккорд: C
Предсказанный аккорд: G
