# 04 Transformer Training

В этом ноутбуке будет осуществлена тренировка нейросетевой архитектуры на основе трансформера (esm_classifier) для задачи предсказания вторичной структуры белка по аминокислотным последовательностям.


In [None]:
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
from sklearn.metrics import classification_report
import pickle
import os


In [3]:
## Загрузка предобработанных данных


In [None]:
base_path = r"/trinity/home/e.bulavko/a.khokhlov/data/processed"

Xtrain = np.load(f'{base_path}/X_train.npy')
ytrain = np.load(f'{base_path}/y_train.npy')

Xval = np.load(f'{base_path}/X_val.npy')
yval = np.load(f'{base_path}/y_val.npy')

Xtest = np.load(f'{base_path}/X_test.npy')
ytest = np.load(f'{base_path}/y_test.npy')

mask_train = np.load(f'{base_path}/mask_train.npy')
mask_val = np.load(f'{base_path}/mask_val.npy')
mask_test = np.load(f'{base_path}/mask_test.npy')

class_weights = np.load(f'{base_path}/class_weights.npy')


In [None]:
def subsample_data(X, y, mask, factor=4):
    # new length будет в factor раз меньше
    new_length = X.shape[0] // factor

    X_subsampled = X[:new_length]
    y_subsampled = y[:new_length]
    mask_subsampled = mask[:new_length]

    return X_subsampled, y_subsampled, mask_subsampled

Xtrain, ytrain, mask_train = subsample_data(Xtrain, ytrain, mask_train, factor=1)
Xval, yval, mask_val = subsample_data(Xval, yval, mask_val, factor=1)
Xtest, ytest, mask_test = subsample_data(Xtest, ytest, mask_test, factor=1)


In [28]:
## Кастомный Dataset для PyTorch


In [29]:
class ProteinDataset(Dataset):
    def __init__(self, X, y, mask):
        self.X = X
        self.y = y
        self.mask = mask

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return {
            'input': torch.LongTensor(self.X[idx]),
            'target': torch.LongTensor(self.y[idx]),
            'mask': torch.BoolTensor(self.mask[idx]),
        }


## DataLoader'ы


In [None]:
batch_size = 64

train_dataset = ProteinDataset(Xtrain, ytrain, mask_train)
val_dataset = ProteinDataset(Xval, yval, mask_val)
test_dataset = ProteinDataset(Xtest, ytest, mask_test)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)

print('Данные загружены')


## Определение архитектуры трансформера (esm_classifier)


In [31]:
class ESMClassifier(nn.Module):
    def __init__(self, vocab_size, num_labels, d_model=128, nhead=8, num_layers=2, dim_feedforward=256, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=0)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation='relu'
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.classifier = nn.Linear(d_model, num_labels)

    def forward(self, x, mask):
        x = self.embedding(x)
        x = x.transpose(0, 1)
        key_padding_mask = ~mask
        x = self.transformer(x, src_key_padding_mask=key_padding_mask)
        x = x.transpose(0, 1)
        logits = self.classifier(x)
        return logits


Инициализатор


In [None]:
def init_weights_normal(model, std=0.02):
    for name, param in model.named_parameters():
        if 'weight' in name:
            nn.init.normal_(param, mean=0.0, std=std)
        elif 'bias' in name:
            nn.init.constant_(param, 0.0)

In [32]:
## Подготовка к обучению: функция потерь, модель, оптимизатор


In [None]:
num_labels = 4
vocab_size = 21
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = ESMClassifier(vocab_size, num_labels).to(device)
init_weights_normal(model) #добавили инициализацию
weights = np.concatenate(([0.0], class_weights))
#loss_fn = nn.CrossEntropyLoss(weight=torch.tensor(weights, dtype=torch.float32, device=device), ignore_index=0)
loss_fn = nn.CrossEntropyLoss(ignore_index=0)
#optimizer = optim.Adam(model.parameters(), lr=1e-5)
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-5)

print('Модель подготовлена')




In [34]:
## Цикл обучения с валидацией


In [None]:
import time

model.train()
start = time.time()
for i, batch in enumerate(train_loader):
    inputs = batch['input'].to(device)
    targets = batch['target'].to(device)
    mask = batch['mask'].to(device)

    optimizer.zero_grad()
    outputs = model(inputs, mask)
    outputs = outputs.view(-1, num_labels)
    targets_flat = targets.view(-1).long()
    loss = loss_fn(outputs, targets_flat)
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) 
    optimizer.step()

    if i == 0:  # только одна итерация для оценки времени
        break
end = time.time()
print(f"Время обработки одного батча: {end - start:.2f} секунд")


Время обработки одного батча: 0.85 секунд


In [None]:
n_epochs = 20
best_val_f1 = 0

patience = 5
patience_counter = 0

for epoch in range(n_epochs):
    model.train()
    train_loss = 0
    for batch in train_loader:
        inputs = batch['input'].to(device)
        targets = batch['target'].to(device)
        mask = batch['mask'].to(device)

        optimizer.zero_grad()
        outputs = model(inputs, mask)
        outputs = outputs.view(-1, num_labels)
        targets_flat = targets.view(-1)
        loss = loss_fn(outputs, targets_flat)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0) #добавил от безысходности
        optimizer.step()
        train_loss += loss.item()

    train_loss /= len(train_loader)

    model.eval()
    all_preds = []
    all_targets = []

    with torch.no_grad():
        for batch in val_loader:
            inputs = batch['input'].to(device)
            targets = batch['target'].to(device)
            mask = batch['mask'].to(device)
            outputs = model(inputs, mask)
            preds = outputs.argmax(dim=-1)
            for i in range(inputs.size(0)):
                seq_mask = mask[i].cpu().numpy()
                all_preds.extend(preds[i][seq_mask].cpu().numpy())
                all_targets.extend(targets[i][seq_mask].cpu().numpy())

    print(f'Epoch {epoch+1}, Train loss: {train_loss:.4f}')
    print(classification_report(all_targets, all_preds, digits=3))

    f1 = classification_report(all_targets, all_preds, output_dict=True, zero_division=0)['weighted avg']['f1-score']
    if f1 > best_val_f1:  # Изменил >= на > для более строгого критерия
        best_val_f1 = f1
        patience_counter = 0  # ← сбросить счетчик
        torch.save(model.state_dict(), 'best_esm_classifier.pth')
        print('Model saved!')
    else:
        patience_counter += 1  # ← увеличить счетчик
        if patience_counter >= patience:  # ← проверить лимит
            print("Early stopping!")
            break

print('Best validation F1:', best_val_f1)


In [None]:
## Тестирование лучшей модели


In [None]:

if os.path.exists('best_esm_classifier.pth'):
    model.load_state_dict(torch.load('best_esm_classifier.pth'))
    print("Загружена лучшая модель.")
else:
    print("Внимание: Лучшая модель не найдена (возможно, обучение не сошлось). Тестирование будет на текущих весах.")

model.eval()
all_test_preds = []
all_test_targets = []

with torch.no_grad():
    for batch in test_loader:
        inputs = batch['input'].to(device)
        targets = batch['target'].to(device)
        mask = batch['mask'].to(device)
        outputs = model(inputs, mask)
        preds = outputs.argmax(dim=-1)
        for i in range(inputs.size(0)):
            seq_mask = mask[i].cpu().numpy()
            all_test_preds.extend(preds[i][seq_mask].cpu().numpy())
            all_test_targets.extend(targets[i][seq_mask].cpu().numpy())

print(classification_report(all_test_targets, all_test_preds, digits=3))


FileNotFoundError: [Errno 2] No such file or directory: 'best_esm_classifier.pth'