In [1]:
import torch
from tqdm import tqdm

from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score
from utils.dataprovider import DataProvider
from torch.utils.data import DataLoader
from config import config

import logging
import time
import os

In [2]:
def _get_training_logger(log_file):
    """Return a dedicated logger that writes training records to ``log_file``.

    A private logger is used instead of ``logging.basicConfig`` so that the
    root logger (and therefore every other library's log records) is left
    untouched, and so that calling this again with a different ``log_file``
    actually switches the destination.
    """
    logger = logging.getLogger('train_model')
    logger.setLevel(logging.INFO)
    # Keep the records out of whatever handlers the root logger may have.
    logger.propagate = False

    target = os.path.abspath(log_file)
    attached = False
    for handler in list(logger.handlers):
        if isinstance(handler, logging.FileHandler) and handler.baseFilename == target:
            attached = True
        else:
            # Stale handler from an earlier call with another destination.
            logger.removeHandler(handler)
            handler.close()

    if not attached:
        os.makedirs(os.path.dirname(target), exist_ok=True)
        file_handler = logging.FileHandler(target)
        file_handler.setFormatter(logging.Formatter('%(asctime)s - %(message)s'))
        logger.addHandler(file_handler)
    return logger


def train_model(model, train_loader, val_loader, criterion, optimizer, device, log_file, num_epochs, patience, model_path, regularize=False):
    """Train ``model`` with per-epoch validation, checkpointing and early stopping.

    For every epoch the function runs one optimisation pass over
    ``train_loader``, evaluates on ``val_loader`` under ``torch.no_grad()``,
    appends one line with timing, accuracy and loss to ``log_file`` and saves
    the weights to ``f"{model_path}-epoch_{epoch}.pt"``. Whenever the mean
    validation loss improves, the weights are also saved to
    ``f"{model_path}-best.pt"``.

    Args:
        model: Module to train. When ``regularize`` is true it must provide
            ``regularize(loss, device)``, whose return value replaces the loss
            before ``backward()``.
        train_loader: Sized iterable yielding ``(inputs, targets)`` batches.
        val_loader: Sized iterable yielding ``(inputs, targets)`` batches.
        criterion: Callable ``criterion(outputs, targets)`` returning the loss.
        optimizer: Optimizer driving the parameter updates.
        device: Device every batch is moved to.
        log_file: Path of the file the per-epoch records are appended to.
        num_epochs: Maximum number of epochs.
        patience: Number of consecutive epochs without validation-loss
            improvement after which training stops.
        model_path: Checkpoint path prefix; its directory is created if needed.
        regularize: Whether to pass the training loss through
            ``model.regularize``.

    Returns:
        None.

    Raises:
        ValueError: If either loader yields no batches.

    Notes:
        Accuracy thresholds the raw model outputs at 0.5, i.e. it assumes the
        model emits probabilities -- TODO confirm against the configured
        model/criterion. The logged training loss includes the regularisation
        term when ``regularize`` is set; the validation loss never does.
    """
    if len(train_loader) == 0 or len(val_loader) == 0:
        raise ValueError('train_loader and val_loader must each yield at least one batch')

    logger = _get_training_logger(log_file)

    # Create the checkpoint directory up front: the per-epoch save below runs
    # before any "best model" bookkeeping and must not fail on a missing folder.
    checkpoint_dir = os.path.dirname(model_path)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    best_loss = float('inf')
    epochs_no_improve = 0
    early_stop = False

    for epoch in range(num_epochs):
        # The flag is raised at the end of the previous epoch, so the stop is
        # reported with the index of the epoch that is being skipped.
        if early_stop:
            print(f'Early stopping at epoch {epoch}')
            break

        start_time = time.time()
        epoch_loss = 0
        val_loss = 0
        correct_train = 0
        total_train = 0
        correct_val = 0
        total_val = 0

        model.train()
        with tqdm(train_loader, unit="batch") as tepoch:
            for batch in tepoch:
                tepoch.set_description(f"Epoch {epoch+1}/{num_epochs}")

                batch_matrices, affinities = batch
                batch_matrices = batch_matrices.to(device)
                affinities = affinities.to(device)

                optimizer.zero_grad()
                outputs = model(batch_matrices)
                loss = criterion(outputs, affinities)

                # Optional model-specific regularisation of the loss.
                if regularize:
                    loss = model.regularize(loss, device)
                loss.backward()
                optimizer.step()

                batch_loss = loss.item()
                tepoch.set_postfix(loss=batch_loss)
                epoch_loss += batch_loss

                # Accuracy: count element-wise matches, so the denominator is
                # the number of target elements rather than only the batch size.
                predicted = (outputs > 0.5).float()
                correct_train += (predicted == affinities).sum().item()
                total_train += affinities.numel()

        avg_train_loss = epoch_loss / len(train_loader)
        train_acc = correct_train / total_train

        model.eval()
        with torch.no_grad():
            for batch in val_loader:
                batch_matrices, affinities = batch
                batch_matrices = batch_matrices.to(device)
                affinities = affinities.to(device)

                outputs = model(batch_matrices)
                loss = criterion(outputs, affinities)
                val_loss += loss.item()

                # Accuracy, computed the same way as for the training pass.
                predicted = (outputs > 0.5).float()
                correct_val += (predicted == affinities).sum().item()
                total_val += affinities.numel()

        avg_val_loss = val_loss / len(val_loader)
        val_acc = correct_val / total_val

        epoch_time = int(time.time() - start_time)
        logger.info(
            '[Training process]: [base_model0]-[model0]-[Epoch %04d] - time: %3d s, '
            'train_acc: %.5f, val_acc: %.5f, train_loss: %.5f, val_loss: %.5f',
            epoch, epoch_time, train_acc, val_acc, avg_train_loss, avg_val_loss,
        )
        torch.save(model.state_dict(), f"{model_path}-epoch_{epoch}.pt")

        # Early stopping and best model saving
        if avg_val_loss < best_loss:
            best_loss = avg_val_loss
            torch.save(model.state_dict(), f"{model_path}-best.pt")
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                early_stop = True

In [3]:
# Build the data provider from the configured epitope / HLA sources.
data_cfg = config["Data"]
data_provider = DataProvider(
    epi_path=data_cfg["epi_path"],
    epi_args=data_cfg["epi_args"],
    hla_path=data_cfg["hla_path"],
    hla_args=data_cfg["hla_args"],
)

# Encode the provider's contents into features and labels with the configured encoder.
encoder = config["encoder"]
x, y = encoder(data_provider)

# Split into training and validation sets, stratified on y to keep the class ratio.
x_train, x_val, y_train, y_val = train_test_split(
    x,
    y,
    test_size=data_cfg["val_size"],
    stratify=y,
    random_state=42,
)


def _float_tensor_unsqueezed(values):
    """Convert ``values`` to a float32 tensor and insert a new axis at position 1."""
    return torch.tensor(values, dtype=torch.float32).unsqueeze(1)


# Convert the split arrays to tensors.
x_train_tensor = _float_tensor_unsqueezed(x_train)
y_train_tensor = _float_tensor_unsqueezed(y_train)
x_val_tensor = _float_tensor_unsqueezed(x_val)
y_val_tensor = _float_tensor_unsqueezed(y_val)

# Wrap the tensors in TensorDataset objects.
train_dataset = torch.utils.data.TensorDataset(x_train_tensor, y_train_tensor)
val_dataset = torch.utils.data.TensorDataset(x_val_tensor, y_val_tensor)
print(f"Samples in train set: {len(train_dataset)}")
print(f"Samples in validation set: {len(val_dataset)}")

# Batch the datasets with DataLoader; only the training loader shuffles.
batch_size = config["Train"]["batch_size"]
loader_kwargs = dict(batch_size=batch_size, num_workers=4)
train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, shuffle=False, **loader_kwargs)

Number of HLA alleles: 7044
Number of samples: 238429
Number of samples: 238429
Samples in train set: 190743
Samples in validation set: 47686


In [4]:
# Pick the GPU when available, then instantiate the configured model on it.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = config["model"]()
model.to(device)
print(f'Model loaded on {device}')

# Build the loss function and the optimizer from the training configuration.
train_cfg = config["Train"]
criterion = train_cfg["criterion"]()
optimizer = train_cfg["optimizer"](model.parameters(), **train_cfg["optimizer_args"])

# Checkpoint files are written under this path prefix.
checkpoint_prefix = os.path.join(config["chkp_path"], config["chkp_name"])

# Train the model.
train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    criterion=criterion,
    optimizer=optimizer,
    device=device,
    log_file=train_cfg["log_file"],
    num_epochs=train_cfg["num_epochs"],
    patience=train_cfg["patience"],
    regularize=train_cfg["regularize"],
    model_path=checkpoint_prefix,
)

Model loaded on cuda


Epoch 1/100:  38%|███▊      | 71/187 [00:16<00:26,  4.38batch/s, loss=0.534]