In [1]:
import os
import shutil
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, models, datasets
from sklearn.model_selection import train_test_split
from collections import Counter
from PIL import Image
import gc

In [2]:
torch.cuda.empty_cache()
gc.collect()

5

In [3]:
Dataset= r"C:\Users\krish\Downloads\bone break dataset\Bone Break Classification\Bone Break Classification Processed"
processed_data = Dataset

In [4]:
def augment_image(image):
    augmentations = [
        transforms.RandomRotation(30),
        transforms.RandomAffine(0, shear=15, scale=(0.7, 1.3)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomVerticalFlip(),
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2),
        transforms.RandomPerspective(distortion_scale=0.5, p=0.5),
        transforms.RandomGrayscale(p=0.2)
    ]
    transform = transforms.Compose([transforms.RandomChoice(augmentations), transforms.ToTensor()])
    return transform(image)


In [5]:
def balance_classes():
    train_path = os.path.join(processed_data, 'train')
    if not os.path.exists(train_path):
        raise FileNotFoundError(f"Train directory not found: {train_path}")
    
    class_counts = {}
    for cls in os.listdir(train_path):
        cls_path = os.path.join(train_path, cls)
        if os.path.isdir(cls_path):
            images = [img for img in os.listdir(cls_path) if img.endswith(('png', 'jpg', 'jpeg'))]
            class_counts[cls] = len(images)
    
    if not class_counts:
        raise ValueError("No images found in training dataset.")
    
    max_count = max(class_counts.values())
    print("Class distribution before balancing:", class_counts)

In [6]:
balance_classes()

Class distribution before balancing: {'Avulsion fracture': 109, 'Comminuted fracture': 134, 'Fracture Dislocation': 137, 'Greenstick fracture': 106, 'Hairline Fracture': 101, 'Impacted fracture': 75, 'Longitudinal fracture': 68, 'Oblique fracture': 69, 'Pathological fracture': 116, 'Spiral Fracture': 74}


In [7]:
transform = {
    'train': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(20),
        transforms.ColorJitter(brightness=0.2, contrast=0.2),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ]),
    'test': transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])
}

In [8]:
dataset = {split: datasets.ImageFolder(root=os.path.join(processed_data, split), transform=transform[split]) 
           for split in ['train', 'test']}

dataloader = {split: DataLoader(dataset[split], batch_size=16, shuffle=True, num_workers=2)
              for split in ['train', 'test']}

In [9]:
class BoneFractureModel(nn.Module):
    def __init__(self, num_classes=10):
        super(BoneFractureModel, self).__init__()
        self.model = models.resnet18(pretrained=True)
        self.model.fc = nn.Sequential(
            nn.Dropout(0.3),  # Added dropout for regularization
            nn.Linear(self.model.fc.in_features, num_classes)
        )
        #self.model.fc = nn.Linear(self.model.fc.in_features, num_classes)
    
    def forward(self, x):
        return self.model(x)

In [10]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BoneFractureModel(num_classes=10).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
#optimizer = optim.Adam(model.parameters(), lr=0.01)  # Start with high LR
#scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)  # Reduce LR every 5 epochs



In [11]:
def train_model(epochs=80):
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        correct = 0
        total = 0
        for inputs, labels in dataloader['train']:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
        train_accuracy = 100 * correct / total
        print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(dataloader['train'])}, Accuracy: {train_accuracy:.2f}%")
        

In [12]:
def test_model():
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in dataloader['test']:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = torch.max(outputs, 1)
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    test_accuracy = 100 * correct / total
    print(f"Final Test Accuracy: {test_accuracy:.2f}%")

In [13]:
train_model()
test_model()

Epoch 1/80, Loss: 2.3799915640584883, Accuracy: 15.77%
Epoch 2/80, Loss: 2.2158697882006244, Accuracy: 19.51%
Epoch 3/80, Loss: 2.1579455887117693, Accuracy: 24.87%
Epoch 4/80, Loss: 2.1125929663258214, Accuracy: 26.90%
Epoch 5/80, Loss: 2.093701602951173, Accuracy: 26.59%
Epoch 6/80, Loss: 2.026901823859061, Accuracy: 29.93%
Epoch 7/80, Loss: 1.9643663552499586, Accuracy: 32.25%
Epoch 8/80, Loss: 1.9057581059394344, Accuracy: 35.19%
Epoch 9/80, Loss: 1.8839016383694065, Accuracy: 35.69%
Epoch 10/80, Loss: 1.8312791316739974, Accuracy: 34.48%
Epoch 11/80, Loss: 1.7782299230175633, Accuracy: 39.43%
Epoch 12/80, Loss: 1.753892396726916, Accuracy: 39.64%
Epoch 13/80, Loss: 1.700439541570602, Accuracy: 41.05%
Epoch 14/80, Loss: 1.6251674063744084, Accuracy: 44.59%
Epoch 15/80, Loss: 1.5743282081619385, Accuracy: 47.02%
Epoch 16/80, Loss: 1.5494909622976858, Accuracy: 47.02%
Epoch 17/80, Loss: 1.4986407766419072, Accuracy: 49.04%
Epoch 18/80, Loss: 1.3909481927271812, Accuracy: 53.49%
Epoch