In [16]:
import numpy as np
import polars as pl
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from tqdm import tqdm
from transformers import AutoModel

import re
from typing import List, Dict, Any, Tuple, Optional, Mapping, Set, Self, NamedTuple, TypedDict
from utils import *

In [17]:
def load_tokenizer(filepath):
    """Загружает токенайзер из файла"""
    tokenizer_data = torch.load(filepath)

    # Создаем новый токенайзер
    tokenizer = ChordTokenizer()

    # Восстанавливаем состояние
    tokenizer._vocab = tokenizer_data['_vocab']
    tokenizer.notes = tokenizer_data['notes']
    tokenizer.moods = tokenizer_data['moods']
    tokenizer.extensions = tokenizer_data['extensions']
    tokenizer.symbols = tokenizer_data['symbols']
    tokenizer.complex_chords = tokenizer_data['complex_chords']

    print(f"Токенайзер загружен из {filepath}")
    return tokenizer
loaded_tokenizer = load_tokenizer('chord_tokenizer.pth')

Токенайзер загружен из chord_tokenizer.pth


In [18]:
bert_model = AutoModel.from_pretrained("prajjwal1/bert-mini")

hidden_dim = 128
num_classes = 15

model = ChordBERTMiniLSTMClassifier(
    bert_model=bert_model,
    hidden_dim=hidden_dim,
    num_classes=num_classes,
    dropout=0.3
)

MODEL_PATH = 'bert_model_genre.pt'
state_dict = torch.load(MODEL_PATH, map_location="cpu")
model.load_state_dict(state_dict)

model.eval()

ChordBERTMiniLSTMClassifier(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(30522, 256, padding_idx=0)
      (position_embeddings): Embedding(512, 256)
      (token_type_embeddings): Embedding(2, 256)
      (LayerNorm): LayerNorm((256,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-3): 4 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=256, out_features=256, bias=True)
              (key): Linear(in_features=256, out_features=256, bias=True)
              (value): Linear(in_features=256, out_features=256, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=256, out_features=256, bias=True)
              (LayerNorm): LayerNorm((256,), eps=1e-12,

In [19]:
from transformers import AutoModel
bert_model = AutoModel.from_pretrained("prajjwal1/bert-mini")

hidden_dim = 128
num_classes = 15

model = ChordBERTMiniLSTMClassifier(
    bert_model=bert_model,
    hidden_dim=hidden_dim,
    num_classes=num_classes,
    dropout=0.3
)

MODEL_PATH = 'bert_model_genre.pt'
state_dict = torch.load(MODEL_PATH, map_location="cpu")
model.load_state_dict(state_dict)

model.eval()

def predict_genres(model, tokenizer, chord_sequence, threshold=0.5):
    """
    Делает предсказание жанров на основе последовательности аккордов.
    Модель — multilabel classifier.
    """
    model.eval()

    inputs = tokenizer(
        chord_sequence,
        return_tensors='pt',
        padding=True,
        truncation=True,
        max_length=128
    )

    with torch.no_grad():
        logits = model(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"]
        )

    logits = logits[0]             
    probs = torch.sigmoid(logits) 

    predicted = (probs > threshold).int().tolist()
    all_genres = ['children / family', 'classical', 'electronic / edm', 'folk / country', 'hip hop / rap',
                    'jazz / blues', 'latin / world', 'metal', 'other / misc', 'pop', 'punk / hardcore', 'r&b / soul',
            'religious / worship', 'rock', 'soundtrack / score / instrumental']
    predicted_genres = [all_genres[i] for i in range(len(all_genres)) if predicted[i] == 1]
    return predicted_genres

In [20]:
genres = predict_genres(
    model=model,
    tokenizer=ChordTokenizerHF(loaded_tokenizer),
    chord_sequence="D C D F B B F B B F B ? F D F C D D C D D F"
)

print("Predicted genres:", genres)

Predicted genres: ['pop', 'rock']
