In [1]:
#!pip install torch torchvision timm


In [13]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import timm
from torchvision.datasets import ImageFolder

In [2]:
# --- Experiment configuration ------------------------------------------------
num_classes = 10      # passed to timm.create_model as the classifier size
batch_size = 16       # mini-batch size handed to both DataLoaders
# NOTE(review): confirm the optimizer / training cells actually consume the
# two values below rather than their own literals.
learning_rate = 1e-4
epochs = 20

# Prefer CUDA when PyTorch reports it as available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [15]:
class ImageFolderWithPaths(datasets.ImageFolder):
    """``ImageFolder`` variant whose items also carry the image's file path.

    Every item is the tuple produced by ``datasets.ImageFolder.__getitem__``
    with the path of the underlying file appended as one extra trailing
    element.
    """

    def __getitem__(self, index):
        """Return the parent's item for ``index`` extended with its file path."""
        base_item = super().__getitem__(index)
        # Each entry of ``self.imgs`` starts with the file path of that sample.
        file_path = self.imgs[index][0]
        return (*base_item, file_path)


In [16]:
# Per-channel statistics handed to Normalize in both pipelines.
_norm_mean = [0.485, 0.456, 0.406]
_norm_std = [0.229, 0.224, 0.225]

# Training pipeline: random crop / flip / colour jitter / rotation, then
# tensor conversion and normalisation.
_train_steps = [
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(mean=_norm_mean, std=_norm_std),
]
train_transforms = transforms.Compose(_train_steps)

# Validation pipeline: no random operations, just resize + centre crop to
# 224 followed by the same tensor conversion and normalisation.
_val_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=_norm_mean, std=_norm_std),
]
val_transforms = transforms.Compose(_val_steps)

In [17]:
# Training split: plain ImageFolder with the augmenting transform pipeline.
train_dataset = ImageFolder(
    root='Dataset/training',
    transform=train_transforms,
)

# Validation split: the path-returning subclass, so each item also carries
# the file path of its image.
val_dataset = ImageFolderWithPaths(
    root='Dataset/validation',
    transform=val_transforms,
)

In [18]:
# Training batches are reshuffled on every pass over the loader.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
)

# Validation batches keep the dataset order.
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=batch_size,
    shuffle=False,
)

In [6]:
# Create the timm model registered as 'vit_large_patch16_224' with
# pretrained=True and `num_classes` outputs, then move it onto `device`.
model = timm.create_model(
    'vit_large_patch16_224',
    pretrained=True,
    num_classes=num_classes,
).to(device)

To support symlinks on Windows, you either need to activate Developer Mode or to run Python as an administrator. In order to see activate developer mode, see this article: https://docs.microsoft.com/en-us/windows/apps/get-started/enable-your-device-for-development


In [19]:
from torch.optim.lr_scheduler import CosineAnnealingLR, StepLR

# Cross-entropy loss with label smoothing of 0.1.
criterion = torch.nn.CrossEntropyLoss(label_smoothing=0.1)

# AdamW over every model parameter.
# NOTE(review): the learning rate is the literal 3e-4 here, while the config
# cell defines `learning_rate = 1e-4` — confirm which value is intended.
optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)

# Exactly one scheduler is attached to the optimizer. Previously a
# CosineAnnealingLR(T_max=10, eta_min=1e-6) was built first and `scheduler`
# was then immediately rebound to this StepLR, so the cosine schedule never
# took effect; the StepLR that was actually in use is kept.
# NOTE(review): switch to CosineAnnealingLR here if that was the intent.
scheduler = StepLR(optimizer, step_size=7, gamma=0.1)

In [20]:
# Class names taken from the training dataset's `classes` attribute.
class_names = train_dataset.classes

def _model_device(model):
    """Return the device of the model's first parameter (CPU if it has none)."""
    first_param = next(iter(model.parameters()), None)
    if first_param is None:
        return torch.device("cpu")
    return first_param.device


def _set_requires_grad(parameters, flag):
    """Set ``requires_grad`` to ``flag`` on every parameter in ``parameters``."""
    for param in parameters:
        param.requires_grad = flag


def _label_name(class_names, index):
    """Return ``class_names[index]``, or ``str(index)`` when it is out of range."""
    if 0 <= index < len(class_names):
        return class_names[index]
    return str(index)


def _train_one_epoch(model, train_loader, criterion, optimizer):
    """Run one optimisation pass over ``train_loader``.

    Returns:
        tuple: ``(mean per-batch loss, accuracy in percent)``. Both are 0.0
        when the loader yields no batches.
    """
    # Inputs are moved to wherever the model's weights live, so the function
    # does not depend on a notebook-level `device` variable.
    target_device = _model_device(model)
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0

    for images, labels in train_loader:
        images, labels = images.to(target_device), labels.to(target_device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()
        num_batches += 1

    mean_loss = running_loss / num_batches if num_batches else 0.0
    accuracy = 100 * correct / total if total else 0.0
    return mean_loss, accuracy


def _validate_and_report(model, val_loader, criterion, best_acc, class_names=None):
    """Evaluate on ``val_loader``, print a report and checkpoint on improvement.

    ``val_loader`` must yield ``(images, labels, paths)`` batches. When the
    measured accuracy is strictly greater than ``best_acc`` the model's
    ``state_dict`` is written to ``'best_model.pth'``.

    Returns:
        tuple: ``(accuracy in percent, list of misclassified image paths)``.
    """
    target_device = _model_device(model)
    model.eval()
    val_loss = 0.0
    correct = 0
    total = 0
    num_batches = 0
    misclassified = []  # (path, true class index, predicted class index)

    with torch.no_grad():
        for images, labels, paths in val_loader:
            images, labels = images.to(target_device), labels.to(target_device)

            outputs = model(images)
            val_loss += criterion(outputs, labels).item()
            num_batches += 1

            _, predicted = outputs.max(1)
            batch_size = labels.size(0)
            total += batch_size

            # One transfer per batch instead of one tensor comparison per sample.
            wrong_flags = predicted.ne(labels).tolist()
            true_indices = labels.tolist()
            predicted_indices = predicted.tolist()
            correct += batch_size - sum(wrong_flags)
            for idx, is_wrong in enumerate(wrong_flags):
                if is_wrong:
                    misclassified.append((paths[idx], true_indices[idx], predicted_indices[idx]))

    mean_loss = val_loss / num_batches if num_batches else 0.0
    accuracy = 100 * correct / total if total else 0.0
    print(f'Validation Loss: {mean_loss:.4f}, Accuracy: {accuracy:.2f}%')

    if accuracy > best_acc:
        torch.save(model.state_dict(), 'best_model.pth')

    if misclassified:
        print("\nMisclassified Images:")
        for img_path, true_idx, pred_idx in misclassified:
            if class_names is None:
                print(f"Misclassified Image: {img_path}")
            else:
                true_name = _label_name(class_names, true_idx)
                pred_name = _label_name(class_names, pred_idx)
                print(f"Misclassified Image: {img_path} (true: {true_name}, predicted: {pred_name})")

    return accuracy, [img_path for img_path, _, _ in misclassified]


def train_model(model, train_loader, val_loader, criterion, optimizer, scheduler, total_epochs=20, initial_epochs=5, class_names=None):
    """Two-phase fine-tuning: head-only warm-up, then the whole model.

    Phase 1 freezes every parameter except those of ``model.head`` and trains
    for ``initial_epochs`` epochs. Phase 2 re-enables gradients on all
    parameters and trains epochs ``initial_epochs`` .. ``total_epochs - 1``.
    After every epoch the model is validated, the best validation accuracy
    seen so far is carried into the next validation (so ``'best_model.pth'``
    is only rewritten on a real improvement), and ``scheduler.step()`` is
    called once.

    Args:
        model: network exposing a ``head`` sub-module.
        train_loader: iterable of ``(images, labels)`` batches.
        val_loader: iterable of ``(images, labels, paths)`` batches.
        criterion: loss callable ``criterion(outputs, labels)``.
        optimizer: optimizer stepped once per training batch.
        scheduler: learning-rate scheduler stepped once per epoch.
        total_epochs: epoch index at which phase 2 stops.
        initial_epochs: number of head-only epochs in phase 1.
        class_names: optional sequence of class names used only to annotate
            the misclassification printout.

    Returns:
        None.
    """
    best_acc = 0.0

    def run_phase(epoch_indices, displayed_total):
        """Train + validate for each epoch index, tracking the best accuracy."""
        nonlocal best_acc
        for epoch in epoch_indices:
            train_loss, train_acc = _train_one_epoch(model, train_loader, criterion, optimizer)
            print(f"Epoch [{epoch+1}/{displayed_total}], Loss: {train_loss:.4f}, Accuracy: {train_acc:.2f}%")
            val_acc, _ = _validate_and_report(model, val_loader, criterion, best_acc, class_names)
            best_acc = max(best_acc, val_acc)
            scheduler.step()

    # Phase 1: only the classification head receives gradients.
    _set_requires_grad(model.parameters(), False)
    _set_requires_grad(model.head.parameters(), True)
    run_phase(range(initial_epochs), initial_epochs)

    # Phase 2: unfreeze everything and continue up to total_epochs.
    _set_requires_grad(model.parameters(), True)
    run_phase(range(initial_epochs, total_epochs), total_epochs)


def validate_model(model, val_loader, criterion, best_acc, best_model, class_names=None):
    """Validate ``model`` and return the paths of the misclassified images.

    Prints the validation loss / accuracy and the misclassified paths, and
    saves the model's ``state_dict`` to ``'best_model.pth'`` when the accuracy
    is strictly greater than ``best_acc``.

    Args:
        model: network to evaluate.
        val_loader: iterable of ``(images, labels, paths)`` batches.
        criterion: loss callable ``criterion(outputs, labels)``.
        best_acc: accuracy (percent) the current run has to beat to be saved.
        best_model: unused; accepted so existing five-argument calls keep
            working.
        class_names: optional sequence of class names used only to annotate
            the misclassification printout.

    Returns:
        list: file paths of the misclassified validation images.
    """
    _, misclassified_images = _validate_and_report(model, val_loader, criterion, best_acc, class_names)
    return misclassified_images

In [21]:
# Run the two-phase fine-tuning. The epoch budget comes from the `epochs`
# config constant instead of a second hard-coded 20.
train_model(
    model,
    train_loader,
    val_loader,
    criterion,
    optimizer,
    scheduler,
    total_epochs=epochs,
    initial_epochs=5,
)


KeyboardInterrupt: 

In [11]:
# Write the weights currently held by `model` to disk.
final_state = model.state_dict()
torch.save(final_state, 'vit_medical_image_classification_large.pth')