In [None]:
import torch
import torch.nn as nn
from torchvision import datasets, transforms, models
from torch.utils.data import DataLoader

In [None]:
from google.colab import drive

# Mount Google Drive, then derive both split directories from a single archive root
# so relocating the dataset only requires editing DATA_ROOT.
MOUNT_POINT = '/content/drive'
drive.mount(MOUNT_POINT)

DATA_ROOT = f'{MOUNT_POINT}/MyDrive/archive'  # folder holding the Training/ and Testing/ splits
train_path = f'{DATA_ROOT}/Training'
test_path = f'{DATA_ROOT}/Testing'


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [None]:
# 1. Transform for grayscale MRI scans
# (datasets / transforms / DataLoader are already imported in the first cell.)

# ImageNet channel statistics: the ResNet18 backbone used below is pretrained on
# ImageNet inputs normalized with these values, so normalizing the MRI scans the
# same way feeds the pretrained filters the input distribution they were trained on.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]
IMAGE_SIZE = (224, 224)

transform = transforms.Compose([
    transforms.Resize(IMAGE_SIZE),
    # Replicate the single grayscale channel so the 3-channel pretrained stem accepts it.
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),  # PIL image -> float tensor with values in [0, 1]
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

In [None]:
# 2. Load the MRI scans: ImageFolder assigns one label per class-named subfolder
BATCH_SIZE = 8

train_data = datasets.ImageFolder(root=train_path, transform=transform)
test_data = datasets.ImageFolder(root=test_path, transform=transform)

# ImageFolder builds label indices from each split's sorted folder names
# independently; if the splits' folders differ, the same index would mean
# different classes and the reported test accuracy would be silently wrong.
assert train_data.class_to_idx == test_data.class_to_idx, (
    f"class mismatch: {train_data.class_to_idx} vs {test_data.class_to_idx}"
)

train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(test_data, batch_size=BATCH_SIZE)  # fixed order for evaluation


In [None]:
# Optional: list the class names discovered in each split (they should match)
for split_name, split_data in (("Training", train_data), ("Testing", test_data)):
    print(f"{split_name} classes:", split_data.classes)


Training classes: ['glioma', 'meningioma', 'notumor', 'pituitary']
Testing classes: ['glioma', 'meningioma', 'notumor', 'pituitary']


In [None]:
# 3. Define MRI classifier (ResNet18 fine-tuned)
class MRINet(nn.Module):
    """ResNet18 backbone whose final layer is replaced by a `num_classes`-way classifier.

    Args:
        num_classes: Number of output logits. Defaults to 4, the number of
            tumor classes in this dataset.
        pretrained: When True (default), load torchvision's default ImageNet
            weights for the backbone; when False, start from random init.
    """

    def __init__(self, num_classes: int = 4, pretrained: bool = True):
        super().__init__()
        # `weights=` replaces the `pretrained=` flag deprecated in torchvision 0.13;
        # ResNet18_Weights.DEFAULT is the same ImageNet checkpoint `pretrained=True` loaded.
        weights = models.ResNet18_Weights.DEFAULT if pretrained else None
        self.base = models.resnet18(weights=weights)
        # Swap the ImageNet classification head for a freshly initialized layer sized to our classes.
        self.base.fc = nn.Linear(self.base.fc.in_features, num_classes)

    def forward(self, x):
        """Return raw (unnormalized) class logits for the image batch `x`."""
        return self.base(x)

In [None]:
# 4. Setup: seed, device, model, optimizer, loss
SEED = 42
LEARNING_RATE = 1e-3

# Seed the global torch RNG before building the model. This makes the new fc layer's
# initialization reproducible. It also fixes the shuffle order of a DataLoader created
# without an explicit generator, because its sampler draws its seed from this RNG.
# GPU cuDNN kernels may still be nondeterministic.
torch.manual_seed(SEED)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = MRINet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
loss_fn = nn.CrossEntropyLoss()  # expects raw logits and integer class labels



In [None]:
# 5. Training function
def train(net=None, loader=None, opt=None, criterion=None):
    """Run one training epoch and return the mean per-sample training loss.

    Each argument falls back to the notebook-level object with the same role
    (`model`, `train_loader`, `optimizer`, `loss_fn`) when left as None, so the
    zero-argument call `train()` keeps working.

    Returns:
        Mean per-sample loss over the epoch as a float, or 0.0 if the loader
        yields no samples.
    """
    net = model if net is None else net
    loader = train_loader if loader is None else loader
    opt = optimizer if opt is None else opt
    criterion = loss_fn if criterion is None else criterion

    # Send each batch to wherever the model's weights live.
    target_device = next(net.parameters()).device

    net.train()
    running_loss, seen = 0.0, 0
    for images, labels in loader:
        images, labels = images.to(target_device), labels.to(target_device)
        opt.zero_grad()
        loss = criterion(net(images), labels)
        loss.backward()
        opt.step()
        # The criterion returns a batch mean; weight it by batch size so a smaller
        # final batch does not skew the epoch average.
        batch_size = labels.size(0)
        running_loss += loss.item() * batch_size
        seen += batch_size
    return running_loss / seen if seen else 0.0

In [None]:
# 6. Evaluation function
def evaluate(net=None, loader=None):
    """Print and return the classification accuracy of `net` over `loader`.

    `net` and `loader` default to the notebook-level `model` and `test_loader`
    when left as None, so the zero-argument call `evaluate()` keeps working.

    Returns:
        Fraction of correctly classified samples as a float, or NaN when the
        loader yields no samples (instead of raising ZeroDivisionError).
    """
    net = model if net is None else net
    loader = test_loader if loader is None else loader

    # Send each batch to wherever the model's weights live.
    target_device = next(net.parameters()).device

    net.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(target_device), labels.to(target_device)
            preds = torch.argmax(net(images), dim=1)  # highest logit = predicted class
            correct += (preds == labels).sum().item()
            total += labels.size(0)

    if total == 0:
        print("Test Accuracy: n/a (empty test set)")
        return float("nan")
    accuracy = correct / total
    print(f"Test Accuracy: {accuracy:.2f}")
    return accuracy

In [None]:
# 7. Run training: one optimization pass per epoch, then report test accuracy
NUM_EPOCHS = 5

for epoch_idx in range(NUM_EPOCHS):
    train()
    print(f"Epoch {epoch_idx + 1}")  # epochs are reported 1-based
    evaluate()

Epoch 1
Test Accuracy: 0.91
Epoch 2
Test Accuracy: 0.93
Epoch 3
Test Accuracy: 0.93
Epoch 4
Test Accuracy: 0.94
Epoch 5
Test Accuracy: 0.95
