In [4]:
# Teacher model (BEiT) evaluation — this first cell builds the shared test-set DataLoader
from torchvision import transforms, models
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import os

# Define the dataset class
class FoldDataset(Dataset):
    """Image dataset pooled from every fold under ``root_dir`` for one split.

    Expected layout (a fold whose ``<section>/<label>`` directory is missing is
    skipped for that label)::

        root_dir/<fold>/<section>/<label>/<image file>

    Each sample is an ``(image, label)`` pair where ``label`` is the index of the
    image's class directory in ``class_names`` (by default ``'normal'`` -> 0,
    ``'abnormal'`` -> 1).

    Args:
        root_dir: Directory with all the folds.
        section: One of 'train', 'test', 'val'.
        transform: Optional callable (e.g. torchvision transforms) applied to the
            RGB PIL image in ``__getitem__``.
        class_names: Class directory names, in label-index order.
    """

    def __init__(self, root_dir, section, transform=None, class_names=('normal', 'abnormal')):
        self.root_dir = root_dir
        self.section = section
        self.transform = transform
        self.class_names = tuple(class_names)
        self.samples = []  # list of (image path, integer label)

        # os.listdir order is arbitrary and platform-dependent; sorting makes the
        # sample order (and therefore batch composition) reproducible.
        for fold in sorted(os.listdir(self.root_dir)):
            section_path = os.path.join(self.root_dir, fold, self.section)
            for label_idx, label in enumerate(self.class_names):
                label_path = os.path.join(section_path, label)
                if not os.path.isdir(label_path):
                    continue
                for fname in sorted(os.listdir(label_path)):
                    img_path = os.path.join(label_path, fname)
                    # Skip hidden files (e.g. .DS_Store) and sub-directories, which
                    # Image.open cannot read and would crash __getitem__.
                    if fname.startswith('.') or not os.path.isfile(img_path):
                        continue
                    self.samples.append((img_path, label_idx))

    def __len__(self):
        """Number of collected image samples."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` as an RGB image (transformed if configured) and its label."""
        img_path, label = self.samples[idx]
        # The context manager closes the file handle; convert() returns a new,
        # fully-loaded image, so it remains valid after the file is closed.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        if self.transform is not None:
            image = self.transform(image)
        return image, label

# Preprocessing: resize to the 224x224 model input size, convert to a tensor, and
# normalize with the standard ImageNet per-channel mean/std.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Root directory containing one sub-directory per fold: <fold>/<section>/<label>/<image>.
base_dir = '../data'

# Test split pooled across all folds; shuffle=False so the loader adds no randomness.
test_dataset = FoldDataset(root_dir=base_dir, section='test', transform=transform)
if len(test_dataset) == 0:
    # Fail here with a clear message instead of a ZeroDivisionError during evaluation.
    raise FileNotFoundError(
        f"No test images found under {os.path.abspath(base_dir)!r}; check base_dir."
    )
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

# Output to confirm setup
print(f"Test DataLoader setup complete with {len(test_dataset)} images.")


Test DataLoader setup complete with 6204 images.


In [6]:
import torch
from transformers import BeitForImageClassification

# Build the BEiT architecture, then overwrite its weights with the fine-tuned teacher checkpoint.
model = BeitForImageClassification.from_pretrained('microsoft/beit-base-patch16-224')
model_path = '../trained_beit_model.pth'
# Fall back to CPU so the notebook still runs on a machine without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# map_location remaps tensors that were saved on CUDA onto `device`.
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)

def evaluate_model(model, data_loader, device):
    """Compute, print, and return the top-1 accuracy of ``model`` over ``data_loader``.

    Accepts both Hugging Face models, whose forward returns an output object with a
    ``.logits`` attribute, and plain ``nn.Module``s that return the logits tensor
    directly.

    Args:
        model: Classifier to evaluate; it is switched to eval mode.
        data_loader: Iterable of ``(images, labels)`` tensor batches.
        device: Device each batch is moved to (should match the model's device).

    Returns:
        Accuracy as a float in [0, 1].

    Raises:
        ValueError: If ``data_loader`` yields no samples.
    """
    model.eval()  # disable dropout / use running batch-norm statistics
    total = 0
    correct = 0

    with torch.no_grad():  # inference only: skip autograd bookkeeping
        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            # HF models wrap logits in an output object; torchvision models return the tensor.
            logits = getattr(outputs, 'logits', outputs)
            # Predicted class = index of the largest logit. Labels from FoldDataset are 0/1;
            # if the head has more outputs, predictions outside {0, 1} simply count as wrong.
            predicted = logits.argmax(dim=1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError("data_loader yielded no samples; cannot compute accuracy.")

    accuracy = correct / total
    print(f"Accuracy: {accuracy * 100:.2f}%")
    return accuracy

# Score the BEiT teacher on the pooled test split.
evaluate_model(model=model, data_loader=test_loader, device=device)


Accuracy: 99.44%


In [10]:
# Student model (MobileNetV3-Small CNN) trained via knowledge distillation (KD)

import torch
from torch import nn
from transformers import BeitForImageClassification

# No ImageNet download needed: load_state_dict (strict by default) below overwrites every
# parameter and buffer with the distilled student checkpoint. `weights=None` also replaces
# the deprecated `pretrained=` argument.
model = models.mobilenet_v3_small(weights=None)
# Explicit 1000-way final layer; presumably mirrors how the student was built for KD — TODO confirm.
model.classifier[3] = nn.Linear(model.classifier[3].in_features, 1000)
model_path = '../student_model_weights.pth'
# Fall back to CPU so the notebook still runs on a machine without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# map_location remaps tensors that were saved on CUDA onto `device`.
# NOTE: torch.load unpickles the file — only load checkpoints from a trusted source.
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)

def evaluate_model(model, data_loader, device):
    """Compute, print, and return the top-1 accuracy of ``model`` over ``data_loader``.

    Accepts both Hugging Face models, whose forward returns an output object with a
    ``.logits`` attribute, and plain ``nn.Module``s (like the MobileNet student) that
    return the logits tensor directly.

    Args:
        model: Classifier to evaluate; it is switched to eval mode.
        data_loader: Iterable of ``(images, labels)`` tensor batches.
        device: Device each batch is moved to (should match the model's device).

    Returns:
        Accuracy as a float in [0, 1].

    Raises:
        ValueError: If ``data_loader`` yields no samples.
    """
    model.eval()  # disable dropout / use running batch-norm statistics
    total = 0
    correct = 0

    with torch.no_grad():  # inference only: skip autograd bookkeeping
        for images, labels in data_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            # HF models wrap logits in an output object; torchvision models return the tensor.
            logits = getattr(outputs, 'logits', outputs)
            # Predicted class = index of the largest logit. Labels from FoldDataset are 0/1;
            # with a 1000-way head, predictions outside {0, 1} simply count as wrong.
            predicted = logits.argmax(dim=1)

            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        raise ValueError("data_loader yielded no samples; cannot compute accuracy.")

    accuracy = correct / total
    print(f"Accuracy: {accuracy * 100:.2f}%")
    return accuracy

# Score the distilled MobileNet student on the same pooled test split.
evaluate_model(model=model, data_loader=test_loader, device=device)


Accuracy: 95.41%
