# 04 Transformer Training

В этом ноутбуке будет осуществлена тренировка нейросетевой архитектуры на основе трансформера (esm_classifier) для задачи предсказания вторичной структуры белка по аминокислотным последовательностям.


In [1]:
import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim
from sklearn.metrics import classification_report
import pickle
torch.autograd.set_detect_anomaly(True)

<torch.autograd.anomaly_mode.set_detect_anomaly at 0x26bdd1e3230>

In [2]:
## Загрузка предобработанных данных


In [None]:
base_path = r"/trinity/home/e.bulavko/a.khokhlov/data/processed"

Xtrain = np.load(f'{base_path}/X_train.npy')
ytrain = np.load(f'{base_path}/y_train.npy')

Xval = np.load(f'{base_path}/X_val.npy')
yval = np.load(f'{base_path}/y_val.npy')

Xtest = np.load(f'{base_path}/X_test.npy')
ytest = np.load(f'{base_path}/y_test.npy')

mask_train = np.load(f'{base_path}/mask_train.npy')
mask_val = np.load(f'{base_path}/mask_val.npy')
mask_test = np.load(f'{base_path}/mask_test.npy')

class_weights = np.load(f'{base_path}/class_weights.npy')


In [None]:
## Кастомный Dataset для PyTorch


In [7]:
class ProteinDataset(Dataset):
    def __init__(self, X, y, mask):
        self.X = X
        self.y = y
        self.mask = mask

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return {
            'input': torch.LongTensor(self.X[idx]),
            'target': torch.LongTensor(self.y[idx]),
            'mask': torch.BoolTensor(self.mask[idx]),
        }


In [4]:
def diagnose_data(X, y, mask, vocab_size=21, num_labels=4):
    """Проверяет корректность данных"""
    
    print("=== DATA DIAGNOSTICS ===")
    
    # 1. Формы
    print(f"1. Shapes: X={X.shape}, y={y.shape}, mask={mask.shape}")
    assert X.shape == y.shape == mask.shape, "Shape mismatch!"
    
    # 2. Диапазоны
    print(f"2. Value ranges:")
    print(f"   X: min={X.min()}, max={X.max()}, expected=[0, {vocab_size})")
    print(f"   y: min={y.min()}, max={y.max()}, expected=[0, {num_labels})")
    print(f"   mask: min={mask.min()}, max={mask.max()}, expected=[0, 1]")
    
    assert X.min() >= 0 and X.max() < vocab_size, "X out of range!"
    assert y.min() >= 0 and y.max() < num_labels, "y out of range!"
    assert set(np.unique(mask)) <= {0, 1}, "mask not binary!"
    
    # 3. Типы данных
    print(f"3. Data types: X={X.dtype}, y={y.dtype}, mask={mask.dtype}")
    
    # 4. Маски не все нули
    mask_per_seq = mask.sum(axis=1)
    print(f"4. Valid positions per sequence: min={mask_per_seq.min()}, max={mask_per_seq.max()}")
    
    if (mask_per_seq == 0).any():
        print("   WARNING: Some sequences are ALL padding!")
    all_padding = (mask.sum(axis=1) == 0).sum()
    print(f"Sequences with ALL padding: {all_padding} / {len(mask)} ({100*all_padding/len(mask):.2f}%)")

    if all_padding > 0:
        print("WARNING: These sequences should be REMOVED or handled specially!")
    
    # 5. Распределение классов
    y_valid = y[mask == 1]
    if len(y_valid) > 0:
        unique, counts = np.unique(y_valid, return_counts=True)
        print(f"5. Class distribution (only valid positions):")
        for cls, cnt in zip(unique, counts):
            print(f"   Class {cls}: {cnt} ({100*cnt/len(y_valid):.1f}%)")
    
    print("=== ALL CHECKS PASSED ===")

# Использование:
diagnose_data(Xtrain, ytrain, mask_train)
diagnose_data(Xval, yval, mask_val)
diagnose_data(Xtest, ytest, mask_test)

=== DATA DIAGNOSTICS ===
1. Shapes: X=(381769, 700), y=(381769, 700), mask=(381769, 700)
2. Value ranges:
   X: min=0, max=20, expected=[0, 21)
   y: min=0, max=3, expected=[0, 4)
   mask: min=0, max=1, expected=[0, 1]
3. Data types: X=int64, y=int64, mask=int64
4. Valid positions per sequence: min=0, max=700
Sequences with ALL padding: 1142 / 381769 (0.30%)
5. Class distribution (only valid positions):
   Class 1: 33488580 (34.0%)
   Class 2: 20365789 (20.7%)
   Class 3: 44653636 (45.3%)
=== ALL CHECKS PASSED ===
=== DATA DIAGNOSTICS ===
1. Shapes: X=(47668, 700), y=(47668, 700), mask=(47668, 700)
2. Value ranges:
   X: min=0, max=20, expected=[0, 21)
   y: min=0, max=3, expected=[0, 4)
   mask: min=0, max=1, expected=[0, 1]
3. Data types: X=int64, y=int64, mask=int64
4. Valid positions per sequence: min=0, max=700
Sequences with ALL padding: 153 / 47668 (0.32%)
5. Class distribution (only valid positions):
   Class 1: 4185893 (33.9%)
   Class 2: 2563599 (20.8%)
   Class 3: 5591449 (4

In [5]:
def remove_all_padding_sequences(X, y, mask):
    """Удаляет последовательности, которые целиком padding"""
    # Вычисляем количество валидных позиций в каждой последовательности
    valid_count = mask.sum(axis=1)
    
    # Оставляем только последовательности с хотя бы одной валидной позицией
    valid_idx = valid_count > 0
    
    n_removed = (~valid_idx).sum()
    print(f"Removing {n_removed} all-padding sequences")
    
    return X[valid_idx], y[valid_idx], mask[valid_idx]

# Применяем фильтр ко всем наборам
print("\n=== FILTERING ALL-PADDING SEQUENCES ===")

Xtrain, ytrain, mask_train = remove_all_padding_sequences(Xtrain, ytrain, mask_train)
Xval, yval, mask_val = remove_all_padding_sequences(Xval, yval, mask_val)
Xtest, ytest, mask_test = remove_all_padding_sequences(Xtest, ytest, mask_test)

print(f"Train: {Xtrain.shape[0]} sequences")
print(f"Val: {Xval.shape[0]} sequences")
print(f"Test: {Xtest.shape[0]} sequences")



=== FILTERING ALL-PADDING SEQUENCES ===
Removing 1142 all-padding sequences
Removing 153 all-padding sequences
Removing 141 all-padding sequences
Train: 380627 sequences
Val: 47515 sequences
Test: 47575 sequences


## DataLoader'ы


In [None]:
batch_size = 4

train_dataset = ProteinDataset(Xtrain, ytrain, mask_train)
val_dataset = ProteinDataset(Xval, yval, mask_val)
test_dataset = ProteinDataset(Xtest, ytest, mask_test)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size)
test_loader = DataLoader(test_dataset, batch_size=batch_size)


## Определение архитектуры трансформера (esm_classifier)


In [9]:
class ESMClassifier(nn.Module):
    def __init__(self, vocab_size, num_labels, d_model=128, nhead=8, num_layers=2, dim_feedforward=256, dropout=0.1):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model, padding_idx=0)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation='relu'
        )
        self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.classifier = nn.Linear(d_model, num_labels)

    def forward(self, x, mask):
        x = self.embedding(x)
        x = x.transpose(0, 1)
        key_padding_mask = ~mask
        x = self.transformer(x, src_key_padding_mask=key_padding_mask)
        x = x.transpose(0, 1)
        logits = self.classifier(x)
        return logits


In [10]:
## Подготовка к обучению: функция потерь, модель, оптимизатор


In [None]:
num_labels = 4
vocab_size = 21
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = ESMClassifier(vocab_size, num_labels).to(device)
weights = np.concatenate(([0.0], class_weights))
loss_fn = nn.CrossEntropyLoss(weight=torch.tensor(weights, dtype=torch.float32, device=device), ignore_index=0)
optimizer = optim.Adam(model.parameters(), lr=1e-4)




In [33]:
## Цикл обучения с валидацией


In [34]:
import time

model.train()
start = time.time()
for i, batch in enumerate(train_loader):
    inputs = batch['input'].to(device)
    targets = batch['target'].to(device)
    mask = batch['mask'].to(device)

    optimizer.zero_grad()
    outputs = model(inputs, mask)
    outputs = outputs.view(-1, num_labels)
    targets_flat = targets.view(-1).long()
    loss = loss_fn(outputs, targets_flat)
    loss.backward()
    optimizer.step()

    if i == 0:  # только одна итерация для оценки времени
        break
end = time.time()
print(f"Время обработки одного батча: {end - start:.2f} секунд")


RuntimeError: weight tensor should be defined either for all 4 classes or no classes but got weight tensor of shape: [5]

In [None]:
print(np.any(np.isnan(Xtrain)))
print(np.any(np.isnan(ytrain)))
print(np.any(np.isnan(mask_train)))

False
False
False


In [None]:
n_epochs = 5
best_val_f1 = 0

for epoch in range(n_epochs):
    model.train()
    train_loss = 0
    for batch in train_loader:
        inputs = batch['input'].to(device)
        targets = batch['target'].to(device)
        mask = batch['mask'].to(device)

        optimizer.zero_grad()
        outputs = model(inputs, mask)
        outputs = outputs.view(-1, num_labels)
        targets_flat = targets.view(-1)
        loss = loss_fn(outputs, targets_flat)
        loss.backward()
        optimizer.step()
        train_loss += loss.item()
    train_loss /= len(train_loader)

    model.eval()
    all_preds = []
    all_targets = []

    with torch.no_grad():
        for batch in val_loader:
            inputs = batch['input'].to(device)
            targets = batch['target'].to(device)
            mask = batch['mask'].to(device)
            outputs = model(inputs, mask)
            preds = outputs.argmax(dim=-1)
            for i in range(inputs.size(0)):
                seq_mask = mask[i].cpu().numpy()
                all_preds.extend(preds[i][seq_mask].cpu().numpy())
                all_targets.extend(targets[i][seq_mask].cpu().numpy())
    if mask.sum() == 0: print('zero mask') 
    print(f'Epoch {epoch+1}, Train loss: {train_loss:.4f}')
    print(classification_report(all_targets, all_preds, digits=3))

    f1 = classification_report(all_targets, all_preds, output_dict=True, zero_division=0)['weighted avg']['f1-score']
    if f1 > best_val_f1:
        best_val_f1 = f1
        torch.save(model.state_dict(), 'best_esm_classifier.pth')
        print('Model saved!')

print('Best validation F1:', best_val_f1)


KeyboardInterrupt: 

In [None]:
## Тестирование лучшей модели


In [None]:
model.load_state_dict(torch.load('best_esm_classifier.pth'))
model.eval()
all_test_preds = []
all_test_targets = []

with torch.no_grad():
    for batch in test_loader:
        inputs = batch['input'].to(device)
        targets = batch['target'].to(device)
        mask = batch['mask'].to(device)
        outputs = model(inputs, mask)
        preds = outputs.argmax(dim=-1)
        for i in range(inputs.size(0)):
            seq_mask = mask[i].cpu().numpy()
            all_test_preds.extend(preds[i][seq_mask].cpu().numpy())
            all_test_targets.extend(targets[i][seq_mask].cpu().numpy())

print(classification_report(all_test_targets, all_test_preds, digits=3))


FileNotFoundError: [Errno 2] No such file or directory: 'best_esm_classifier.pth'