In [None]:
import importlib.util
import os
import sys

# Project root. Set the KNEE_PROJECT_DIR environment variable to run this notebook
# from another machine/checkout instead of editing the hard-coded default.
PROJECT_DIR = os.environ.get(
    "KNEE_PROJECT_DIR",
    "/Users/apple/Desktop/PG/Summer-24/image-DL/knee-arthritis-detection-algo",
)
MODULES_DIR = os.path.join(PROJECT_DIR, "modules")

# Put the project root (NOT modules/) on the import path. Guarded so re-running the
# cell does not keep appending duplicate entries.
if PROJECT_DIR not in sys.path:
    sys.path.append(PROJECT_DIR)

# Module name and file locations for the model code in modules/. Because modules/
# itself is not on sys.path, these files have to be loaded by file path.
medvit_module_name = "MedViT"
medvit_file_path = os.path.join(MODULES_DIR, "MedViT.py")
utils_file_path = os.path.join(MODULES_DIR, "utils.py")

In [None]:
# torch / nn are needed below; importing them here keeps this cell runnable on a
# fresh kernel (the main import cell comes later in the notebook).
import torch
import torch.nn as nn

# Load utils.py FIRST and register it as "utils" in sys.modules. Only the project
# root is on sys.path, not modules/, so an `import utils` performed by MedViT.py
# (presumably `from utils import merge_pre_bn` — NOTE(review): confirm in
# modules/MedViT.py) can only find modules/utils.py through this registration.
utils_spec = importlib.util.spec_from_file_location("utils", utils_file_path)
utils_module = importlib.util.module_from_spec(utils_spec)
sys.modules["utils"] = utils_module
utils_spec.loader.exec_module(utils_module)

# Load MedViT.py by path and register it so the regular import below resolves.
spec = importlib.util.spec_from_file_location(medvit_module_name, medvit_file_path)
medvit_module = importlib.util.module_from_spec(spec)
sys.modules[medvit_module_name] = medvit_module
spec.loader.exec_module(medvit_module)

from MedViT import MedViT_small

# Instantiate the backbone. NOTE(review): confirm that MedViT_small actually loads
# weights when pretrained=True; that cannot be verified from this notebook.
model = MedViT_small(pretrained=True)

# Replace the classification head with one producing `num_classes` logits, keeping
# the input width of the original head's first Linear layer.
num_classes = 5
model.proj_head = nn.Sequential(
    nn.Linear(model.proj_head[0].in_features, num_classes)
)

# Move the model to the GPU when available. Assigning the result (instead of leaving
# `model.to(device)` as the last expression) avoids dumping the full model repr.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)


In [None]:
# Re-load utils.py and MedViT.py by file path. utils is registered in sys.modules as
# "utils" *before* MedViT.py executes. That is what actually makes an
# `import utils` inside MedViT.py (presumably for merge_pre_bn — NOTE(review):
# confirm) resolve to modules/utils.py, because modules/ is not on sys.path.
utils_file_path = "/Users/apple/Desktop/PG/Summer-24/image-DL/knee-arthritis-detection-algo/modules/utils.py"
utils_spec = importlib.util.spec_from_file_location("utils", utils_file_path)
utils_module = importlib.util.module_from_spec(utils_spec)
sys.modules["utils"] = utils_module
utils_spec.loader.exec_module(utils_module)

medvit_file_path = "/Users/apple/Desktop/PG/Summer-24/image-DL/knee-arthritis-detection-algo/modules/MedViT.py"
spec = importlib.util.spec_from_file_location("MedViT", medvit_file_path)
medvit_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(medvit_module)
# NOTE(review): this medvit_module is a fresh module object that is NOT registered in
# sys.modules. It is therefore distinct from the module `MedViT_small` was imported
# from earlier, and nothing in the visible notebook uses it. Consider deleting this cell.


In [None]:
import os
import torch
from torch.utils.data import DataLoader, Dataset
import cv2
import numpy as np
import torch.nn as nn
import torch.optim as optim

# Data Preparation
class KneeDataset(Dataset):
    """In-memory image dataset with one sub-folder per class.

    Every file under ``data_path/<category>/`` that OpenCV can decode is resized to
    ``img_size`` x ``img_size`` and rescaled by :meth:`_normalize`. Files for which
    ``cv2.imread`` returns ``None`` are skipped.

    Args:
        data_path: Root directory that contains one folder per category.
        categories: Folder names; a category's label is its index in this list.
        img_size: Side length, in pixels, that every image is resized to.
        pixel_scale: Magnitude of the normalized output. Pixel 128 maps to 0 and
            pixel 0 maps to ``-pixel_scale``. Defaults to 1024, the original behavior.
            NOTE(review): pretrained backbones are commonly trained on inputs of
            roughly unit scale; a +/-1024 range is worth double-checking.

    Items are ``(image, label)`` pairs: a float32 tensor of shape (3, H, W) and a
    scalar int64 tensor.
    """

    def __init__(self, data_path, categories, img_size=224, pixel_scale=1024.0):
        self.data_path = data_path
        self.categories = categories
        self.img_size = img_size
        self.pixel_scale = pixel_scale
        self.data = []
        self.labels = []
        self.label_dict = {category: i for i, category in enumerate(categories)}
        self._load_data()

    @staticmethod
    def _normalize(img, scale=1024.0):
        """Return ``(img - 128) / 128 * scale`` as float32.

        The cast happens BEFORE the subtraction. Subtracting a Python int from a uint8
        array stays uint8 in NumPy and wraps around (0 - 128 -> 128), which scrambled
        intensities in the original expression.
        """
        img = img.astype(np.float32)
        return (img - 128.0) / 128.0 * scale

    def _load_data(self):
        """Read, resize and normalize all images; fill ``self.data`` and ``self.labels``."""
        images = []
        labels = []
        for category in self.categories:
            folder_path = os.path.join(self.data_path, category)
            # os.listdir order is arbitrary; sort so sample order is reproducible.
            for img_name in sorted(os.listdir(folder_path)):
                img = cv2.imread(os.path.join(folder_path, img_name))
                if img is None:
                    continue  # not an image OpenCV can decode
                # NOTE(review): cv2.imread yields BGR channel order; confirm whether
                # the pretrained weights expect RGB.
                img = cv2.resize(img, (self.img_size, self.img_size))
                images.append(self._normalize(img, self.pixel_scale))
                labels.append(self.label_dict[category])

        # float32 storage: the model consumes float32 anyway, and it halves the memory
        # the former float64 array used.
        self.data = np.array(images, dtype=np.float32)
        self.labels = np.array(labels, dtype=np.int64)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Stored images are HxWxC; PyTorch convolutions expect CxHxW.
        img = torch.tensor(np.transpose(self.data[idx], (2, 0, 1)), dtype=torch.float32)
        label = torch.tensor(self.labels[idx], dtype=torch.long)
        return img, label

data_path = '/Users/apple/Desktop/PG/Summer-24/image-DL/knee-arthritis-detection-algo/Training'
# NOTE(review): a class's label is its index in this list, so '0Normal' -> 3 and
# '4Severe' -> 1, which does not match the numeric folder prefixes. Kept unchanged so
# existing checkpoints stay compatible; consider sorted(...) for fresh training runs.
categories = ['1Doubtful', '4Severe', '2Mild', '0Normal', '3Moderate']
img_size = 224
TRAIN_FRACTION = 0.9  # share of samples used for training; the rest is held out
SEED = 42             # fixes the train/test split

dataset = KneeDataset(data_path, categories, img_size)
train_size = int(TRAIN_FRACTION * len(dataset))
test_size = len(dataset) - train_size
# A seeded generator gives the same split on every run, given the same sample order.
# Accuracies and the "best" checkpoint are then comparable across runs.
split_generator = torch.Generator().manual_seed(SEED)
train_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [train_size, test_size], generator=split_generator
)

In [None]:
# Training configuration: tune these here rather than inside the loop.
BATCH_SIZE = 32
LEARNING_RATE = 1e-3
MOMENTUM = 0.9
LR_STEP_SIZE = 10  # epochs between learning-rate decays
LR_GAMMA = 0.1     # factor the learning rate is multiplied by at each decay
CHECKPOINT_PATH = 'best_model.pth'
num_epochs = 50

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=MOMENTUM)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=LR_STEP_SIZE, gamma=LR_GAMMA)


def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``.

    Returns:
        The mean per-sample loss over the epoch.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    model.train()
    running_loss = 0.0
    n_samples = 0
    for inputs, labels in loader:
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()
        # criterion returns the batch mean; re-weight by batch size for an epoch mean.
        running_loss += loss.item() * inputs.size(0)
        n_samples += inputs.size(0)
    if n_samples == 0:
        raise ValueError("training loader yielded no samples")
    return running_loss / n_samples


def evaluate(model, loader, device):
    """Return the top-1 accuracy of ``model`` on ``loader``, as a fraction in [0, 1].

    Raises:
        ValueError: if ``loader`` yields no samples. This replaces the former
            ZeroDivisionError on an empty held-out split.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            predicted = model(inputs).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    if total == 0:
        raise ValueError("evaluation loader yielded no samples; cannot compute accuracy")
    return correct / total


# NOTE(review): the checkpoint is selected on test_loader, so the reported best accuracy
# is optimistically biased. A separate validation split would avoid that leak.
best_accuracy = float('-inf')  # guarantees a checkpoint is written after epoch 1
for epoch in range(num_epochs):
    epoch_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
    accuracy = evaluate(model, test_loader, device)
    if accuracy > best_accuracy:
        best_accuracy = accuracy
        torch.save(model.state_dict(), CHECKPOINT_PATH)
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {accuracy:.4f}, Best Accuracy: {best_accuracy:.4f}')

    scheduler.step()
