In [25]:
from sklearn.metrics import precision_score, recall_score, f1_score, roc_auc_score, average_precision_score
from sklearn.preprocessing import label_binarize

In [29]:
def validate(model, device, val_loader, criterion):
    """Evaluate ``model`` on ``val_loader`` and compute classification metrics.

    Runs the model in eval mode under ``torch.no_grad()``. It accumulates the
    loss, the predicted labels (argmax over dim 1 of the outputs) and the
    softmax probabilities, then derives summary metrics.

    Args:
        model: network whose output is assumed to be per-class scores shaped
            (batch, n_classes) — inferred from the argmax/softmax over dim 1.
        device: device the features and labels are moved to.
        val_loader: iterable of ``(features, labels)`` batches supporting
            ``len()`` (used to average the per-batch loss).
        criterion: loss function called as ``criterion(outputs, labels)``.

    Returns:
        Tuple ``(val_loss, val_accuracy, precision, recall, f1, roc_auc, aupr)``:
        - val_loss: mean of the per-batch ``criterion`` values.
        - val_accuracy: fraction of evaluated samples predicted correctly.
        - precision / recall / f1: macro averages with ``zero_division=0``.
        - roc_auc / aupr: for a 2-column output, the scores of class 1 (the
          positive class). Otherwise, macro one-vs-rest averages over the
          classes that have both positive and negative samples. ``nan`` when
          no class qualifies.

    Raises:
        ValueError: if ``val_loader`` yields no samples.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    all_predicted = []
    all_labels = []
    all_probs = []

    with torch.no_grad():
        for features, labels in val_loader:
            features, labels = features.to(device), labels.to(device)
            outputs = model(features)
            loss = criterion(outputs, labels)
            total_loss += loss.item()
            predicted = outputs.argmax(dim=1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
            all_predicted.append(predicted.cpu().numpy())
            all_labels.append(labels.cpu().numpy())
            all_probs.append(torch.softmax(outputs, dim=1).cpu().numpy())

    if total == 0:
        raise ValueError("validation loader yielded no samples")

    all_predicted = np.concatenate(all_predicted)
    all_labels = np.concatenate(all_labels)
    all_probs = np.concatenate(all_probs)

    val_loss = total_loss / len(val_loader)
    # Divide by the samples actually evaluated, not len(dataset): the two
    # differ with drop_last=True, a custom sampler, or a subset.
    val_accuracy = correct / total
    precision = precision_score(all_labels, all_predicted, average='macro', zero_division=0)
    recall = recall_score(all_labels, all_predicted, average='macro', zero_division=0)
    f1 = f1_score(all_labels, all_predicted, average='macro', zero_division=0)
    roc_auc, aupr = _ranking_scores(all_labels, all_probs)

    return val_loss, val_accuracy, precision, recall, f1, roc_auc, aupr


def _ranking_scores(labels, probs):
    """Return ``(roc_auc, average_precision)`` computed from class probabilities.

    ``labels`` is a 1-D array of integer class ids; ``probs`` is (n_samples,
    n_classes). With two columns, only class 1 is scored (sklearn's binary
    convention). Otherwise, each class is scored one-vs-rest and the scores
    are macro-averaged. Classes without both positive and negative samples
    are skipped, because ROC AUC is undefined for them and sklearn raises.
    Returns ``(nan, nan)`` if no class can be scored.
    """
    n_classes = probs.shape[1]
    classes = [1] if n_classes == 2 else range(n_classes)
    aucs, aps = [], []
    for c in classes:
        y_true = (labels == c).astype(int)
        if y_true.min() == y_true.max():
            # Class c is absent, or every sample is class c: metric undefined.
            continue
        aucs.append(roc_auc_score(y_true, probs[:, c]))
        aps.append(average_precision_score(y_true, probs[:, c]))
    if not aucs:
        return float('nan'), float('nan')
    return float(np.mean(aucs)), float(np.mean(aps))


In [30]:
# Train for num_epochs epochs, reporting validation metrics after each one.
for epoch in tqdm(range(num_epochs)):
    train_loss = train(model, device, train_loader, optimizer, criterion)
    val_loss, val_accuracy, precision, recall, f1, roc_auc, aupr = validate(model, device, val_loader, criterion)
    # tqdm.write prints above the progress bar; a plain print() inside a
    # tqdm loop interleaves with the bar and garbles it.
    tqdm.write(
        f'Epoch {epoch+1}, Train Loss: {train_loss:.4f}, '
        f'Validation Loss: {val_loss:.4f}, Validation Accuracy: {val_accuracy:.4f}, '
        f'Precision: {precision:.4f}, recall: {recall:.4f}, F1: {f1:.4f}, '
        f'AUC: {roc_auc:.4f}, AUPR: {aupr:.4f}'
    )
