# **[심층신경망개론] Group 1 DeiT 구현**




In [7]:
import torch
import torch.nn as nn
from functools import partial

from timm.models.vision_transformer import VisionTransformer, _cfg
from timm.models.registry import register_model
from timm.models.layers import trunc_normal_

__all__ = [
    'deit_tiny_patch16_224', 'deit_small_patch16_224', 'deit_base_patch16_224',
    'deit_tiny_distilled_patch16_224', 'deit_small_distilled_patch16_224',
    'deit_base_distilled_patch16_224', 'deit_base_patch16_384',
    'deit_base_distilled_patch16_384',
]


class DistilledVisionTransformer(VisionTransformer):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
        num_patches = self.patch_embed.num_patches
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 2, self.embed_dim))
        self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()

        trunc_normal_(self.dist_token, std=.02)
        trunc_normal_(self.pos_embed, std=.02)
        self.head_dist.apply(self._init_weights)

    def forward_features(self, x):
        # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py
        # with slight modifications to add the dist_token
        B = x.shape[0]
        x = self.patch_embed(x)

        cls_tokens = self.cls_token.expand(B, -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
        dist_token = self.dist_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, dist_token, x), dim=1)

        x = x + self.pos_embed
        x = self.pos_drop(x)

        for blk in self.blocks:
            x = blk(x)

        x = self.norm(x)
        return x[:, 0], x[:, 1]

    def forward(self, x):
        x, x_dist = self.forward_features(x)
        x = self.head(x)
        x_dist = self.head_dist(x_dist)
        if self.training:
            return x, x_dist
        else:
            # during inference, return the average of both classifier predictions
            return (x + x_dist) / 2


@register_model
def deit_tiny_patch16_224(pretrained=False, **kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), num_classes=num_classes,
        drop_rate=drop_rate, drop_path_rate=drop_path_rate, ** 0.5, **kwargs
    )
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_small_patch16_224(pretrained=False, **kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), num_classes = num_classes, drop_rate=drop_rate, drop_path_rate=drop_path_rate,  **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_base_patch16_224(pretrained=False, **kwargs):
    model = VisionTransformer(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), num_classes = num_classes, drop_rate=drop_rate, drop_path_rate=drop_path_rate, **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_tiny_distilled_patch16_224(pretrained=False, num_classes = 5, drop_rate= 0.2, drop_path_rate=0.1, **kwargs):
    model = DistilledVisionTransformer(
        patch_size=16, embed_dim=192, depth=12, num_heads=3, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), num_classes = num_classes, drop_rate=drop_rate, drop_path_rate=drop_path_rate, **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_small_distilled_patch16_224(pretrained=False, num_classes = 5, drop_rate= 0.2, drop_path_rate=0.1, **kwargs):
    model = DistilledVisionTransformer(
        patch_size=16, embed_dim=384, depth=12, num_heads=6, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), num_classes = num_classes, drop_rate=drop_rate, drop_path_rate=drop_path_rate, **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_base_distilled_patch16_224(pretrained=False, num_classes = 5, drop_rate= 0.2, drop_path_rate=0.1, **kwargs):
    model = DistilledVisionTransformer(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), num_classes = num_classes, drop_rate=drop_rate, drop_path_rate=drop_path_rate, **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_base_patch16_384(pretrained=False, **kwargs):
    model = VisionTransformer(
        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model


@register_model
def deit_base_distilled_patch16_384(pretrained=False, **kwargs):
    model = DistilledVisionTransformer(
        img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
    model.default_cfg = _cfg()
    if pretrained:
        checkpoint = torch.hub.load_state_dict_from_url(
            url="https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth",
            map_location="cpu", check_hash=True
        )
        model.load_state_dict(checkpoint["model"])
    return model

  @register_model
  @register_model
  @register_model
  @register_model
  @register_model
  @register_model
  @register_model
  @register_model


In [8]:
!pip install wfdb



In [9]:
import pandas as pd
import numpy as np
import wfdb
import ast
import os
from sklearn.preprocessing import MultiLabelBinarizer

def load_raw_data(df, sampling_rate, path):
    """df.index를 기준으로 데이터를 로드"""
    if sampling_rate == 100:
        data = [wfdb.rdsamp(os.path.join(path, f)) for f in df['filename_lr']]
    else:
        data = [wfdb.rdsamp(os.path.join(path, f)) for f in df['filename_hr']]
    # df.index에 있는 데이터만 로드
    data = np.array([signal for signal, meta in data])
    return data

# 데이터 경로 설정
path = "./"
sampling_rate = 100

# PTB-XL 데이터베이스 로드
df = pd.read_csv(os.path.join(path, 'ptbxl_database.csv'), index_col='ecg_id')
df.scp_codes = df.scp_codes.apply(lambda x: ast.literal_eval(x))

# 진단 정보 로드
agg_df = pd.read_csv(os.path.join(path, 'scp_statements.csv'), index_col=0)
agg_df = agg_df[agg_df.diagnostic == 1]

def aggregate_diagnostic(y_dic):
    """진단 클래스를 매핑하는 함수"""
    tmp = []
    for key in y_dic.keys():
        if key in agg_df.index:
            tmp.append(agg_df.loc[key].diagnostic_class)
    return list(set(tmp))

# 진단 클래스 매핑
df['diagnostic_superclass'] = df.scp_codes.apply(aggregate_diagnostic)

# 빈 클래스 제거
df = df[df['diagnostic_superclass'].apply(lambda x: len(x) > 0)]

# Raw data 로드
X = load_raw_data(df, sampling_rate, path)

# 크기 확인
assert len(X) == len(df), "X와 df의 크기가 일치하지 않습니다."

# 데이터셋 분리
test_fold = 10
val_fold = 9

train_filter = (df.strat_fold != test_fold) & (df.strat_fold != val_fold)
val_filter = df.strat_fold == val_fold
test_filter = df.strat_fold == test_fold

X_train = X[train_filter]
y_train = list(df[train_filter]['diagnostic_superclass'])

X_val = X[val_filter]
y_val = list(df[val_filter]['diagnostic_superclass'])

X_test = X[test_filter]
y_test = list(df[test_filter]['diagnostic_superclass'])

# 다중 라벨 이진화
mlb = MultiLabelBinarizer()
y_train_bin = mlb.fit_transform(y_train)
y_val_bin = mlb.transform(y_val)
y_test_bin = mlb.transform(y_test)

print(f"Train Data Shape: {X_train.shape}, Labels: {y_train_bin.shape}")
print(f"Validation Data Shape: {X_val.shape}, Labels: {y_val_bin.shape}")
print(f"Test Data Shape: {X_test.shape}, Labels: {y_test_bin.shape}")



Train Data Shape: (17084, 1000, 12), Labels: (17084, 5)
Validation Data Shape: (2146, 1000, 12), Labels: (2146, 5)
Test Data Shape: (2158, 1000, 12), Labels: (2158, 5)


In [10]:
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import matplotlib.pyplot as plt

# Custom Dataset for ECG Data
class ECGDataset(Dataset):
    def __init__(self, data, labels, transform=None):
        self.data = data
        self.labels = labels
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        label = self.labels[idx]
        if self.transform:
            # ECG 데이터를 3채널로 확장
            sample = self.transform(sample)
        return sample.float(), torch.tensor(label, dtype=torch.float32)

transform = transforms.Compose([
    transforms.ToTensor(),  # Numpy 배열 -> Tensor
    transforms.Resize((224, 224)),  # 이미지 크기 조정
    transforms.Lambda(lambda x: x.expand(3, -1, -1)),  # 1채널 데이터를 3채널로 확장
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])  # 정규화
])


# 데이터셋 정의
train_dataset = ECGDataset(X_train, y_train_bin, transform=transform)
val_dataset = ECGDataset(X_val, y_val_bin, transform=transform)
test_dataset = ECGDataset(X_test, y_test_bin, transform=transform)

# DataLoader 정의
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# 클래스 확인
class_names = mlb.classes_
print(f"Classes: {class_names}")


Classes: ['CD' 'HYP' 'MI' 'NORM' 'STTC']


In [11]:
def train_model(
    model,
    train_loader,
    val_loader=None,
    num_epochs=100,
    patience=3,
    learning_rate=0.001,
    checkpoint_path='deit_checkpoint.pth'
):

    criterion = nn.BCEWithLogitsLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

    # 학습 상태 초기화
    best_loss = float('inf')
    epochs_no_improve = 0
    logs = {'train_loss': [], 'train_accuracy': [], 'train_f1': []}

    if val_loader:
        logs['val_loss'] = []
        logs['val_accuracy'] = []
        logs['val_f1'] = []

    for epoch in range(num_epochs):
        print(f"Starting epoch {epoch+1}/{num_epochs}...")
        model.train()
        running_loss = 0.0
        all_labels = []
        all_preds = []

        # Training loop
        for batch_idx, (inputs, labels) in enumerate(train_loader):
            inputs, labels = inputs.to(device), labels.float().to(device)
            optimizer.zero_grad()

            outputs = model(inputs)  # 모델 forward
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item() * inputs.size(0)

            preds = (torch.sigmoid(outputs) > 0.5).int().cpu().numpy()
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds)

        # Train metrics
        epoch_loss = running_loss / len(train_loader.dataset)
        accuracy = accuracy_score(np.vstack(all_labels), np.vstack(all_preds))
        f1 = f1_score(np.vstack(all_labels), np.vstack(all_preds), average='macro')
        print(f'Epoch {epoch+1}, Train Loss: {epoch_loss:.4f}, Accuracy: {accuracy:.4f}, F1 Score: {f1:.4f}')

        logs['train_loss'].append(epoch_loss)
        logs['train_accuracy'].append(accuracy)
        logs['train_f1'].append(f1)

        # Validation loop (if provided)
        if val_loader:
            val_loss, val_accuracy, val_f1 = evaluate_model(model, val_loader, criterion)
            logs['val_loss'].append(val_loss)
            logs['val_accuracy'].append(val_accuracy)
            logs['val_f1'].append(val_f1)

            print(f'Epoch {epoch+1}, Val Loss: {val_loss:.4f}, Val Accuracy: {val_accuracy:.4f}, Val F1: {val_f1:.4f}')

        # Checkpoint and early stopping
        if epoch_loss < best_loss:
            best_loss = epoch_loss
            epochs_no_improve = 0
            torch.save(model.state_dict(), checkpoint_path)
            print(f"Checkpoint saved at epoch {epoch+1}.")
        else:
            epochs_no_improve += 1

        if epochs_no_improve >= patience:
            print(f"Early stopping triggered after {epoch+1} epochs.")
            break

    # Best model 로드
    model.load_state_dict(torch.load(checkpoint_path))
    print("Best model loaded.")
    return logs

def evaluate_model(model, dataloader, criterion):
    """모델 검증 및 평가."""
    model.eval()
    running_loss = 0.0
    all_labels = []
    all_preds = []

    with torch.no_grad():
        for inputs, labels in dataloader:
            inputs, labels = inputs.to(device), labels.float().to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            running_loss += loss.item() * inputs.size(0)

            preds = (torch.sigmoid(outputs) > 0.5).int().cpu().numpy()
            all_labels.extend(labels.cpu().numpy())
            all_preds.extend(preds)

    epoch_loss = running_loss / len(dataloader.dataset)
    accuracy = accuracy_score(np.vstack(all_labels), np.vstack(all_preds))
    f1 = f1_score(np.vstack(all_labels), np.vstack(all_preds), average='macro')

    return epoch_loss, accuracy, f1


In [12]:
from timm.models import deit_tiny_distilled_patch16_224
from torch.optim import Adam
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class CustomDeiT_tiny(nn.Module):
    def __init__(self, pretrained=False, num_classes=5, drop_rate=0.2):
        super(CustomDeiT_tiny, self).__init__()
        self.model = deit_tiny_distilled_patch16_224(pretrained=pretrained, num_classes=num_classes)
        self.dropout = nn.Dropout(drop_rate)

    def forward(self, x):
        outputs = self.model(x)
        if isinstance(outputs, tuple):
            x1, x2 = outputs
            avg_output = (x1 + x2) / 2
        else:
            avg_output = outputs

        return self.dropout(avg_output)

deit_tiny_model = CustomDeiT_tiny(pretrained=False, num_classes=5, drop_rate=0.2).to(device)

logs_tiny = train_model(
    model=deit_tiny_model,
    train_loader=train_loader,  # 제공된 학습 데이터 로더
    val_loader=val_loader,      # 제공된 검증 데이터 로더
    num_epochs=100,              # 총 학습 epoch 수
    patience=5,                 # 조기 종료를 위한 patience
    learning_rate=0.0001,       # 최적 학습률
    checkpoint_path='best_deit_tiny_model.pth'  # 체크포인트 저장 경로
)

Starting epoch 1/100...
Epoch 1, Train Loss: 0.5578, Accuracy: 0.1330, F1 Score: 0.1335
Epoch 1, Val Loss: 0.4708, Val Accuracy: 0.3756, Val F1: 0.2422
Checkpoint saved at epoch 1.
Starting epoch 2/100...
Epoch 2, Train Loss: 0.4959, Accuracy: 0.3119, F1 Score: 0.3354
Epoch 2, Val Loss: 0.4345, Val Accuracy: 0.4334, Val F1: 0.3853
Checkpoint saved at epoch 2.
Starting epoch 3/100...
Epoch 3, Train Loss: 0.4703, Accuracy: 0.3433, F1 Score: 0.4287
Epoch 3, Val Loss: 0.4108, Val Accuracy: 0.4362, Val F1: 0.5540
Checkpoint saved at epoch 3.
Starting epoch 4/100...
Epoch 4, Train Loss: 0.4573, Accuracy: 0.3620, F1 Score: 0.4658
Epoch 4, Val Loss: 0.4054, Val Accuracy: 0.4366, Val F1: 0.4980
Checkpoint saved at epoch 4.
Starting epoch 5/100...
Epoch 5, Train Loss: 0.4488, Accuracy: 0.3745, F1 Score: 0.4869
Epoch 5, Val Loss: 0.3955, Val Accuracy: 0.4627, Val F1: 0.5414
Checkpoint saved at epoch 5.
Starting epoch 6/100...
Epoch 6, Train Loss: 0.4406, Accuracy: 0.3845, F1 Score: 0.4989
Epoch 6

  model.load_state_dict(torch.load(checkpoint_path))


In [13]:
validation_loader = DataLoader(
    ECGDataset(X_val, y_val_bin, transform=transform),
    batch_size=16,
    shuffle=False
)
print("Evaluating on validation set...")
val_loss, val_accuracy, val_f1 = evaluate_model(deit_tiny_model, validation_loader, nn.BCEWithLogitsLoss())
print(f"Validation Loss: {val_loss:.4f}")
print(f"Validation Accuracy: {val_accuracy:.4f}")
print(f"Validation F1 Score: {val_f1:.4f}")

# 테스트 데이터 평가
test_loader = DataLoader(
    ECGDataset(X_test, y_test_bin, transform=transform),
    batch_size=16,
    shuffle=False
)
print("Evaluating on test set...")
test_loss, test_accuracy, test_f1 = evaluate_model(deit_tiny_model, test_loader, nn.BCEWithLogitsLoss())
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_accuracy:.4f}")
print(f"Test F1 Score: {test_f1:.4f}")

Evaluating on validation set...
Validation Loss: 0.8121
Validation Accuracy: 0.4995
Validation F1 Score: 0.6028
Evaluating on test set...
Test Loss: 0.8141
Test Accuracy: 0.4884
Test F1 Score: 0.5955


In [14]:
from timm.models import deit_base_distilled_patch16_224
from torch.optim import Adam
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

class CustomDeiT_base(nn.Module):
    def __init__(self, pretrained=False, num_classes=5, drop_rate=0.2):
        super(CustomDeiT_base, self).__init__()
        self.model = deit_base_distilled_patch16_224(pretrained=pretrained, num_classes=num_classes)
        self.dropout = nn.Dropout(drop_rate)

    def forward(self, x):
        outputs = self.model(x)
        if isinstance(outputs, tuple):
            x1, x2 = outputs
            avg_output = (x1 + x2) / 2
        else:
            avg_output = outputs

        return self.dropout(avg_output)

deit_base_model = CustomDeiT_base(pretrained=False, num_classes=5, drop_rate=0.2).to(device)

logs_base = train_model(
    model=deit_base_model,
    train_loader=train_loader,  # 제공된 학습 데이터 로더
    val_loader=val_loader,      # 제공된 검증 데이터 로더
    num_epochs=100,              # 총 학습 epoch 수
    patience=5,                 # 조기 종료를 위한 patience
    learning_rate=0.0001,       # 최적 학습률
    checkpoint_path='best_deit_base_model.pth'  # 체크포인트 저장 경로
)

Starting epoch 1/100...
Epoch 1, Train Loss: 0.5447, Accuracy: 0.2016, F1 Score: 0.2109
Epoch 1, Val Loss: 0.4553, Val Accuracy: 0.3388, Val F1: 0.3649
Checkpoint saved at epoch 1.
Starting epoch 2/100...
Epoch 2, Train Loss: 0.4851, Accuracy: 0.3234, F1 Score: 0.3867
Epoch 2, Val Loss: 0.4153, Val Accuracy: 0.4380, Val F1: 0.4802
Checkpoint saved at epoch 2.
Starting epoch 3/100...
Epoch 3, Train Loss: 0.4700, Accuracy: 0.3416, F1 Score: 0.4280
Epoch 3, Val Loss: 0.4108, Val Accuracy: 0.4669, Val F1: 0.5290
Checkpoint saved at epoch 3.
Starting epoch 4/100...
Epoch 4, Train Loss: 0.4585, Accuracy: 0.3565, F1 Score: 0.4594
Epoch 4, Val Loss: 0.3980, Val Accuracy: 0.4716, Val F1: 0.5598
Checkpoint saved at epoch 4.
Starting epoch 5/100...
Epoch 5, Train Loss: 0.4519, Accuracy: 0.3702, F1 Score: 0.4776
Epoch 5, Val Loss: 0.3913, Val Accuracy: 0.4949, Val F1: 0.5261
Checkpoint saved at epoch 5.
Starting epoch 6/100...
Epoch 6, Train Loss: 0.4483, Accuracy: 0.3775, F1 Score: 0.4860
Epoch 6

  model.load_state_dict(torch.load(checkpoint_path))


In [15]:
validation_loader = DataLoader(
    ECGDataset(X_val, y_val_bin, transform=transform),
    batch_size=16,
    shuffle=False
)
print("Evaluating on validation set...")
val_loss, val_accuracy, val_f1 = evaluate_model(deit_base_model, validation_loader, nn.BCEWithLogitsLoss())
print(f"Validation Loss: {val_loss:.4f}")
print(f"Validation Accuracy: {val_accuracy:.4f}")
print(f"Validation F1 Score: {val_f1:.4f}")

# 테스트 데이터 평가
test_loader = DataLoader(
    ECGDataset(X_test, y_test_bin, transform=transform),
    batch_size=16,
    shuffle=False
)
print("Evaluating on test set...")
test_loss, test_accuracy, test_f1 = evaluate_model(deit_base_model, test_loader, nn.BCEWithLogitsLoss())
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_accuracy:.4f}")
print(f"Test F1 Score: {test_f1:.4f}")

Evaluating on validation set...
Validation Loss: 0.6458
Validation Accuracy: 0.4986
Validation F1 Score: 0.6064
Evaluating on test set...
Test Loss: 0.6835
Test Accuracy: 0.4852
Test F1 Score: 0.5932
