<a href="https://colab.research.google.com/github/mobarakol/AI_Medical_Imaging/blob/main/Active_Learning_Training_ReseNet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import gdown

url = 'https://drive.google.com/uc?id=1Oms9X0Vpid_kN8jiSgz-3MhRA5BcmivE'
gdown.download(url,'braintumor.zip',quiet=True)
!unzip -q braintumor.zip -d braintumor

In [None]:
import os
import sys
import argparse
from torchvision import datasets
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import models
import torchvision.transforms as transforms
from collections import Counter

# Seed torch's global RNG up front so the random operations that follow
# (e.g. the train/validation split) are repeatable.
seed = 42
torch.manual_seed(seed)

# Per-channel mean/std used to standardize the image tensors
# (these are the statistics commonly used with ImageNet-pretrained models).
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225],
)

# Preprocessing applied to every image: resize -> tensor -> normalize.
image_transform = transforms.Compose([
    transforms.Resize((224,224)),
    transforms.ToTensor(),
    normalize,
])

# ImageFolder derives the class labels from the sub-directory names under root.
dataset_ = datasets.ImageFolder(
    root='/content/braintumor/Training',
    transform=image_transform,
)

# Share of the data used for training; the validation subset receives
# whatever is left over, so the two sizes always add up to the full dataset.
train_ratio = 0.8
val_ratio = 0.2

# Turn the ratio into concrete sample counts
total_size = len(dataset_)
train_size = int(train_ratio * total_size)
val_size = total_size - train_size
split_sizes = [train_size, val_size]

# Randomly partition the dataset into two non-overlapping subsets
train_split, val_split = torch.utils.data.random_split(dataset_, lengths=split_sizes)

# Report how many samples ended up in each subset
print(f"Total samples: {len(dataset_)}")
print(f"Training samples: {len(train_split)}")
print(f"Validation samples: {len(val_split)}")

def _class_frequencies(samples, indices, idx_to_class):
    """Count how often each class occurs among the selected samples.

    Args:
        samples: sequence of ``(path, class_index)`` pairs, as exposed by
            ``ImageFolder.samples``.
        indices: positions in ``samples`` that belong to the split.
        idx_to_class: mapping from class index to class name.

    Returns:
        ``(labels, counts, frequencies)`` where ``labels`` is the list of class
        indices for the split, ``counts`` is a ``Counter`` keyed by class
        index, and ``frequencies`` maps class name to count, ordered by class
        index so every split lists its classes in the same order.
    """
    labels = [samples[idx][1] for idx in indices]
    counts = Counter(labels)
    frequencies = {idx_to_class[idx]: counts[idx] for idx in sorted(counts)}
    return labels, counts, frequencies


def _print_class_frequencies(split_name, frequencies):
    """Print one ``Class '<name>': <n> images`` line per class of a split."""
    print(f"\nClass frequencies in {split_name}:")
    for class_name, count in frequencies.items():
        print(f"Class '{class_name}': {count} images")


# Mapping of class indices to class names
class_to_idx = dataset_.class_to_idx
idx_to_class = {v: k for k, v in class_to_idx.items()}

# Class distribution of the training subset
train_labels, class_counts, class_frequencies = _class_frequencies(
    dataset_.samples, train_split.indices, idx_to_class)
_print_class_frequencies("train_split", class_frequencies)

# Class distribution of the validation subset.
# class_counts / class_frequencies are rebound here, so after this cell they
# describe the validation subset.
valid_labels, class_counts, class_frequencies = _class_frequencies(
    dataset_.samples, val_split.indices, idx_to_class)
_print_class_frequencies("val_split", class_frequencies)

Total samples: 5712
Training samples: 4569
Validation samples: 1143

Class frequencies in train_split:
Class 'glioma': 1051 images
Class 'notumor': 1276 images
Class 'pituitary': 1166 images
Class 'meningioma': 1076 images

Class frequencies in val_split:
Class 'meningioma': 263 images
Class 'glioma': 270 images
Class 'pituitary': 291 images
Class 'notumor': 319 images


In [None]:
import os
import sys
import argparse
from torchvision import datasets
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import models
import torchvision.transforms as transforms


def get_arguments(argv=None):
    """Build the training configuration from command-line style options.

    Args:
        argv: optional list of argument strings to parse. When ``None``
            (the default), an empty list is parsed if running under an
            IPython kernel (the kernel's own ``sys.argv`` entries are not
            meant for this parser), otherwise ``sys.argv`` is parsed as usual.

    Returns:
        ``argparse.Namespace`` with ``lr``, ``lr_schedule``, ``batch_size``,
        ``test_batch_size``, ``num_epoch`` and ``num_classes``.
    """
    parser = argparse.ArgumentParser(description='Brain Tumor Classification Training')
    parser.add_argument('--lr', default=1e-4, type=float, help='learning rate')
    parser.add_argument('--lr_schedule', default=0, type=int, help='lr scheduler')
    parser.add_argument('--batch_size', default=128, type=int, help='training batch size')
    parser.add_argument('--test_batch_size', default=256, type=int, help='evaluation batch size')
    parser.add_argument('--num_epoch', default=10, type=int, help='epoch number')
    parser.add_argument('--num_classes', type=int, default=4, help='number classes')
    if argv is not None:
        # Explicit arguments win over any environment detection.
        return parser.parse_args(argv)
    if 'ipykernel' in sys.modules:
        # Inside a notebook: ignore the kernel's argv and use the defaults.
        args = parser.parse_args([])
    else:
        args = parser.parse_args()
    return args

def train(model, trainloader, criterion, optimizer, device=None):
    """Run one optimization pass over ``trainloader``.

    Args:
        model: network to optimize; it is switched to training mode.
        trainloader: iterable yielding ``(inputs, targets)`` batches.
        criterion: callable ``criterion(outputs, targets)`` returning the loss.
        optimizer: optimizer whose ``zero_grad``/``step`` run once per batch.
        device: device the batches are moved to. When ``None`` it is taken
            from the model's first parameter (CPU if the model has none), so
            the function does not rely on a global ``device`` variable.
    """
    if device is None:
        first_param = next(model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')

    model.train()
    for inputs, targets in trainloader:
        inputs, targets = inputs.to(device), targets.to(device)
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        # Clear gradients left over from the previous batch before backprop.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

def test(model, testloader):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    return correct / total

if __name__ == '__main__':
    # Fine-tune an ImageNet-pretrained ResNet-34 on train_split and keep a
    # checkpoint every time validation accuracy improves.
    args = get_arguments()
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Shuffle only the training data; evaluation order does not matter.
    trainloader = DataLoader(train_split, batch_size=args.batch_size, shuffle=True, num_workers=2)
    testloader = DataLoader(val_split, batch_size=args.test_batch_size, shuffle=False, num_workers=2)
    print('Training on:', device, 'train sample size:', len(train_split), 'test sample size:', len(val_split))

    # NOTE(review): recent torchvision releases prefer the `weights=` argument
    # over `pretrained=True`; kept as-is to avoid assuming a torchvision version.
    model = models.resnet34(pretrained=True)
    # Replace the classification head so it outputs args.num_classes scores.
    model.fc = nn.Linear(model.fc.in_features, args.num_classes)
    model.to(device)
    optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, betas=(0.9, 0.999), weight_decay=0.1)
    criterion = nn.CrossEntropyLoss()

    # best_epoch is an epoch index (int), best_acc an accuracy (float).
    best_epoch, best_acc = 0, 0.0
    for epoch in range(args.num_epoch):
        train(model, trainloader, criterion, optimizer)
        accuracy = test(model, testloader)
        if accuracy > best_acc:
            best_acc = accuracy
            best_epoch = epoch
            # One checkpoint file per improving epoch; earlier ones are kept.
            torch.save(model.state_dict(), 'best_model_{}.pth.tar'.format(epoch))
        # The format string now has a placeholder for the learning rate, which
        # was previously passed to format() but never printed.
        print('epoch: {}  acc: {:.4f}  best epoch: {}  best acc: {:.4f}  lr: {:.2e}'.format(
                epoch, accuracy, best_epoch, best_acc, optimizer.param_groups[0]['lr']))


Training on: cuda train sample size: 4569 test sample size: 1143
