In [1]:
import pytorch_lightning as pl
from torchvision import transforms
from torch.utils.data import Dataset
from torch.utils.data import DataLoader, random_split
import torch
import kagglehub
from PIL import Image
import os
from torch import nn
from torchmetrics import Accuracy
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import EarlyStopping
import torchvision

data_path = kagglehub.dataset_download("navoneel/brain-mri-images-for-brain-tumor-detection")

In [2]:
class TumorDataset(Dataset):
    def __init__(self, data_path, transforms = None):

        self.transforms = transforms
        imgs = []
        labels = []
        for label in ['no', 'yes']:
            for img in os.listdir(data_path + '/' + label ):
                img = Image.open(data_path + '/' +label +'/'+ img)
                imgs.append(img)
                labels.append({'no':0,'yes':1}[label])
        self.X = imgs
        self.y = labels

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):

        image, label = self.X[idx], self.y[idx]

        if self.transforms is not None:
            image = self.transforms(image)

        return image, label


In [3]:
class Dataset(pl.LightningDataModule):

    def __init__(self):

        super().__init__()
        self.transforms = transforms.Compose([transforms.Grayscale(num_output_channels=1),
                                              transforms.Resize((256, 256)),
                                              transforms.ToTensor(),
                                              transforms.Lambda(lambda x: x.expand(3, -1, -1))])
    def prepare_data(self):

        self.data_path = kagglehub.dataset_download("navoneel/brain-mri-images-for-brain-tumor-detection")

    def setup(self, stage):

        tumor_dataset = TumorDataset(self.data_path, self.transforms)

        self.train, self.val, self.test = random_split(
            tumor_dataset, [190, 31, 32], generator=torch.Generator().manual_seed(42)
        )

    def train_dataloader(self):
        return DataLoader(self.train, batch_size=8, num_workers=8)

    def val_dataloader(self):
        return DataLoader(self.val, batch_size=8, num_workers=8)

    def test_dataloader(self):
        return DataLoader(self.test, batch_size=8, num_workers=8)

In [5]:
import torchvision.models as models

# Load pretrained ResNet (e.g., ResNet18)
model = models.resnet18(pretrained=True)

# Modify the final fully connected layer for brain tumor classification
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)  # Set num_classes (e.g., 2 for tumor/no-tumor)

Downloading: "https://download.pytorch.org/models/resnet18-f37072fd.pth" to /root/.cache/torch/hub/checkpoints/resnet18-f37072fd.pth
100%|██████████| 44.7M/44.7M [00:00<00:00, 184MB/s] 


In [6]:
class Resnet(pl.LightningModule):

    def __init__(self, model, *args, **kwargs):
        super().__init__()

        self.train_acc = Accuracy(task="binary")
        self.valid_acc = Accuracy(task="binary")
        self.test_acc = Accuracy(task="binary")

        self.model = model

    def forward(self, x):

        x = self.model(x)

        return x


    def _common_step(self, x, y):

            logits = self(x)
            loss = nn.functional.cross_entropy(logits, y)
            preds = torch.argmax(logits, dim=1)
        
            return loss, preds

    def training_step(self, batch, batch_idx):
    
        x, y = batch
        loss, preds = self._common_step(x, y)
        self.train_acc.update(preds, y)
        
        self.log("train_loss", loss, prog_bar=True)

        return loss

    def on_training_epoch_end(self, outs):

        self.log("train_acc", self.train_acc.compute(), prog_bar=True)
        self.train_acc.reset()

    def validation_step(self, batch, batch_idx):

        x, y = batch
        loss, preds = self._common_step(x, y)
        self.valid_acc.update(preds, y)

        logits = self(x)
        probs = torch.softmax(logits, dim=1)  # Convert logits to probabilities
        
        # Log images and probabilities
        grid = torchvision.utils.make_grid(x[:4])  # Visualize first 4 images
        self.logger.experiment.add_image("val_images", grid, self.current_epoch)
        
        # Log predicted probabilities as text
        preds = torch.argmax(probs, dim=1)
        self.logger.experiment.add_text(
            "val_predictions", 
            f"Predicted: {preds.tolist()}, Probabilities: {probs.tolist()}", 
            self.current_epoch
        )
        
        self.log("val_loss", loss)

        return loss

    def on_validation_epoch_end(self):

        self.log("valid_acc", self.valid_acc.compute(), prog_bar=True)
        self.valid_acc.reset()

    def test_step(self, batch, batch_idx):

        x, y = batch
        loss, preds = self._common_step(x, y)
        self.test_acc.update(preds, y)
        
        self.log("test_loss", loss)
        self.log("test_acc", self.test_acc.compute(), prog_bar=True)

        return loss

    def configure_optimizers(self):

        optimizer = torch.optim.Adam(self.parameters(), lr=0.001)

        return optimizer

In [7]:
dataset = Dataset()
net = Resnet(model)
callbacks = [ModelCheckpoint(save_top_k=1, mode='max', monitor="valid_acc")]  # save top 1 model
trainer = pl.Trainer(max_epochs=50, callbacks=callbacks, accelerator='gpu', devices=1)

In [8]:
trainer.fit(model=net, datamodule=dataset)

2025-06-22 11:29:05.175189: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1750591745.624320      35 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1750591745.747070      35 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/usr/local/lib/python3.11/dist-packages/pytorch_lightning/loops/fit_loop.py:310: The number of training batches (24) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

In [9]:
trainer.test(model=net, datamodule=dataset, ckpt_path='best')

Testing: |          | 0/? [00:00<?, ?it/s]

[{'test_loss': 0.13797251880168915, 'test_acc': 0.8828125}]