In [2]:
pip install torchvision

Defaulting to user installation because normal site-packages is not writeable

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m24.0[0m[39;49m -> [0m[32;49m24.1.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49m/Library/Developer/CommandLineTools/usr/bin/python3 -m pip install --upgrade pip[0m
Note: you may need to restart the kernel to use updated packages.


In [3]:
import os
import torch
from torch.utils.data import DataLoader, Dataset
import torchvision.transforms as transforms
from PIL import Image
import numpy as np

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
class KneeDataset(Dataset):
    def __init__(self, data_path, categories, img_size=224):
        self.data_path = data_path
        self.categories = categories
        self.img_size = img_size        
        self.data = []
        self.labels = []
        self.label_dict = {category: i for i, category in enumerate(categories)}
        self.transform = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.RandomResizedCrop(img_size, scale=(0.8, 1.0)),  # Zoom augmentation
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])
        self._load_data()

    def _load_data(self):
        for category in self.categories:
            folder_path = os.path.join(self.data_path, category)
            img_names = os.listdir(folder_path)
            for img_name in img_names:
                img_path = os.path.join(folder_path, img_name)
                img = Image.open(img_path).convert('RGB')
                img = self.transform(img)
                self.data.append(img)
                self.labels.append(self.label_dict[category])

        self.data = torch.stack(self.data)
        self.labels = torch.tensor(self.labels)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return self.data[idx], self.labels[idx]


In [5]:
data_path = '/Users/apple/Desktop/PG/Summer-24/image-DL/knee-arthritis-detection-algo/Training'
categories = ['1Doubtful', '4Severe', '2Mild', '0Normal', '3Moderate']
img_size = 224

dataset = KneeDataset(data_path, categories, img_size)
train_size = int(0.9 * len(dataset))
test_size = len(dataset) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(dataset, [train_size, test_size])

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

In [6]:
import torch.nn as nn
import torchvision.models as models

# Load the ResNet-50 model
model = models.resnet50(pretrained=True)

# Modify the final layer to match the number of classes
num_classes = len(categories)
model.fc = nn.Linear(model.fc.in_features, num_classes)

# Send the model to the appropriate device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)


Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /Users/apple/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [03:11<00:00, 537kB/s] 


In [7]:
import torch.optim as optim

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.9)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10, gamma=0.1)

num_epochs = 50
best_accuracy = 0

for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * inputs.size(0)
    epoch_loss = running_loss / len(train_loader.dataset)

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = correct / total
    if accuracy > best_accuracy:
        best_accuracy = accuracy
        torch.save(model.state_dict(), 'best_model.pth')
    print(f'Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}, Accuracy: {accuracy:.4f}, Best Accuracy: {best_accuracy:.4f}')

    scheduler.step()


Epoch 1/50, Loss: 1.3793, Accuracy: 0.3879, Best Accuracy: 0.3879
Epoch 2/50, Loss: 0.9123, Accuracy: 0.6000, Best Accuracy: 0.6000
Epoch 3/50, Loss: 0.6525, Accuracy: 0.5455, Best Accuracy: 0.6000
Epoch 4/50, Loss: 0.4270, Accuracy: 0.7030, Best Accuracy: 0.7030
Epoch 5/50, Loss: 0.2732, Accuracy: 0.6727, Best Accuracy: 0.7030
Epoch 6/50, Loss: 0.1863, Accuracy: 0.7091, Best Accuracy: 0.7091
Epoch 7/50, Loss: 0.1259, Accuracy: 0.7091, Best Accuracy: 0.7091
Epoch 8/50, Loss: 0.0929, Accuracy: 0.6848, Best Accuracy: 0.7091
Epoch 9/50, Loss: 0.0654, Accuracy: 0.7030, Best Accuracy: 0.7091
Epoch 10/50, Loss: 0.0542, Accuracy: 0.6970, Best Accuracy: 0.7091
Epoch 11/50, Loss: 0.0412, Accuracy: 0.7030, Best Accuracy: 0.7091
Epoch 12/50, Loss: 0.0387, Accuracy: 0.7030, Best Accuracy: 0.7091
Epoch 13/50, Loss: 0.0318, Accuracy: 0.7091, Best Accuracy: 0.7091
Epoch 14/50, Loss: 0.0296, Accuracy: 0.6970, Best Accuracy: 0.7091
Epoch 15/50, Loss: 0.0294, Accuracy: 0.7030, Best Accuracy: 0.7091
Epoc

KeyboardInterrupt: 

In [None]:
# Best accuracy was 80.61 when using lr=1e-2, step size = 30 and no. of epochs = 25
# Best accuracy was 78.18 when using lr=1e-3, step size = 25 and no. of epochs = 25
# Best accuracy was 81.21 when using lr=1e-2, step size = 50 and no. of epochs = 20
# Best accuracy was 80.61 when using lr=1e-2, step size = 75 and no. of epochs = 20
# Best accuracy was 70.91 when using lr=1e-3, step size = 10 and no. of epochs = 20 (was supposed to be 50)