## Importing libraries

In [None]:
import os
import torchvision
import torch
import numpy as np

from PIL import Image
from matplotlib import pyplot as plt

# torch.manual_seed(0)

print('Torch version: ', torch.__version__)

In [None]:
def load_data(image_dir):
    """Collect image filenames and class labels for one dataset split.

    ``image_dir`` must contain an ``images`` folder with ``.jpg`` files and a
    ``labels`` folder with one same-named ``.txt`` file per image. The label
    of an image is the first whitespace-separated token on the first line of
    its label file, parsed as an ``int``.

    Args:
        image_dir: Path of the split directory (parent of ``images`` and
            ``labels``).

    Returns:
        A tuple ``(images, labels)`` of equal-length lists. ``images[i]`` is a
        filename inside ``<image_dir>/images`` and ``labels[i]`` is the class
        id read for it. Images whose label file starts with an empty line are
        left out of both lists.

    Raises:
        FileNotFoundError: If a ``.jpg`` has no matching label file.
        ValueError: If the first label token is not an integer.
    """
    images = []
    labels = []

    image_root = os.path.join(image_dir, "images")
    label_root = os.path.join(image_dir, "labels")

    # os.listdir returns entries in arbitrary order; sort so that the
    # index -> sample mapping is reproducible between runs.
    for filename in sorted(os.listdir(image_root)):
        # Non-jpg entries are not samples. They must be skipped entirely,
        # otherwise `images` grows without a matching entry in `labels`.
        if not filename.endswith('.jpg'):
            continue

        # Swap only the extension (str.replace would touch every '.jpg').
        stem = os.path.splitext(filename)[0]
        label_path = os.path.join(label_root, stem + '.txt')

        with open(label_path, 'r') as label_file:
            line = label_file.readline().strip()

        # An empty first line means there is no annotation: drop the image.
        if not line:
            continue

        labels.append(int(line.split()[0]))
        images.append(filename)

    return images, labels
        
# Training split of the dataset, at the path where it is mounted on Kaggle.
dataset_dir = '/kaggle/input/medical-image-dataset-brain-tumor-detection/Brain Tumor Detection/train'

# Load the data
images, labels = load_data(dataset_dir)
# Number of distinct class ids found in the training labels; used later as
# the output size of the classifier head.
num_classes = len(np.unique(labels))
num_classes  # bare expression: shown as the notebook cell output

## Creating custom dataset

In [None]:
class BrainTumorDataset(torch.utils.data.Dataset):
    """Image-classification dataset backed by ``images``/``labels`` folders.

    ``image_dir`` must contain an ``images`` folder with ``.jpg`` files and a
    ``labels`` folder with one same-named ``.txt`` file per image. The class
    id of an image is the first whitespace-separated token on the first line
    of its label file.

    Attributes are plain lists/values set in ``__init__``:
    ``image_dir``, ``images`` (filenames), ``labels`` (ints, index-aligned
    with ``images``) and ``transform``.
    """

    def __init__(self, image_dir, transform=None):
        """Index the split at ``image_dir``.

        Args:
            image_dir: Path of the split directory (parent of ``images`` and
                ``labels``).
            transform: Optional callable applied to each loaded PIL image in
                ``__getitem__``. When ``None`` the PIL image is returned as is.

        Raises:
            FileNotFoundError: If a ``.jpg`` has no matching label file.
            ValueError: If the first label token is not an integer.
        """
        self.image_dir = image_dir
        self.images, self.labels = self._load_data(self.image_dir)
        self.transform = transform

    @staticmethod
    def _load_data(image_dir):
        """Return index-aligned ``(filenames, labels)`` lists for a split."""
        images = []
        labels = []

        image_root = os.path.join(image_dir, "images")
        label_root = os.path.join(image_dir, "labels")

        # os.listdir returns entries in arbitrary order; sort so that a given
        # index always refers to the same sample.
        for filename in sorted(os.listdir(image_root)):
            # Non-jpg entries are not samples. Skipping them entirely keeps
            # `images` and `labels` the same length.
            if not filename.endswith('.jpg'):
                continue

            # Swap only the extension (str.replace would touch every '.jpg').
            stem = os.path.splitext(filename)[0]
            label_path = os.path.join(label_root, stem + '.txt')

            with open(label_path, 'r') as label_file:
                line = label_file.readline().strip()

            # An empty first line means there is no annotation: drop the image.
            if not line:
                continue

            labels.append(int(line.split()[0]))
            images.append(filename)

        return images, labels

    def __len__(self):
        """Return the number of labelled images."""
        return len(self.images)

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(image, label)``.

        The image is opened with PIL, converted to RGB and, if a transform
        was supplied, passed through it.
        """
        image_path = os.path.join(self.image_dir, "images", self.images[index])
        image = Image.open(image_path).convert('RGB')
        # `transform` defaults to None; calling None would raise TypeError.
        if self.transform is not None:
            image = self.transform(image)
        return image, self.labels[index]

## Image transformation

In [None]:
# Training-time preprocessing: resize to 224x224, apply a random horizontal
# flip as augmentation, then convert to a tensor.
# NOTE(review): there is no Normalize step; pretrained ImageNet weights are
# usually paired with mean/std normalization - confirm the omission is intended.
train_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(size=(224, 224)),
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor()
])

In [None]:
# Evaluation-time preprocessing: same resize and tensor conversion as the
# training pipeline, but without the random flip.
test_transform = torchvision.transforms.Compose([
    torchvision.transforms.Resize(size=(224, 224)),
    torchvision.transforms.ToTensor()
])

## Prepare dataloader

In [None]:
# Training split, loaded with the augmenting transform.
train_dir = '/kaggle/input/medical-image-dataset-brain-tumor-detection/Brain Tumor Detection/train'
train_dataset = BrainTumorDataset(train_dir, transform=train_transform)

In [None]:
# Validation split: evaluated periodically during training.
valid_dir = '/kaggle/input/medical-image-dataset-brain-tumor-detection/Brain Tumor Detection/valid'
valid_dataset = BrainTumorDataset(valid_dir, transform=test_transform)

# Test split: used for the prediction previews.
test_dir = '/kaggle/input/medical-image-dataset-brain-tumor-detection/Brain Tumor Detection/test'
test_dataset = BrainTumorDataset(test_dir, transform=test_transform)

In [None]:
# Report how many samples each split exposes (as given by Dataset.__len__).
print('Number of training examples: ', len(train_dataset))
print('Number of validation examples: ', len(valid_dataset))
print('Number of testing examples: ', len(test_dataset))

In [None]:
# Samples per batch, shared by all three loaders.
batch_size = 128

# NOTE(review): the validation and test loaders are shuffled too. This makes
# `next(iter(dl_test))` yield a different batch on each call - presumably
# intended so the prediction preview varies; confirm.
dl_train = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
dl_valid = torch.utils.data.DataLoader(valid_dataset, batch_size=batch_size, shuffle=True)
dl_test = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

# len() of a DataLoader is the number of batches it yields per pass.
print('Number of training batches: ', len(dl_train))
print('Number of validation batches: ', len(dl_valid))
print('Number of testing batches: ', len(dl_test))

## Data visualization

In [None]:
def show_images(images, labels, preds):
    """Plot a single row of images annotated with true and predicted class.

    The x-label of each image shows the true class and the y-label shows the
    predicted class, coloured green when it equals the true label and red
    otherwise. A truthy label is displayed as 'Tumor', a falsy one as
    'Non-Tumor'.

    Args:
        images: Iterable of image tensors. Each one is converted with
            ``.numpy()`` and its axes are reordered ``(1, 2, 0)`` for
            ``imshow``, i.e. channel-first input is assumed.
        labels: Indexable of true labels, aligned with ``images``.
        preds: Indexable of predicted labels, aligned with ``images``.
    """
    images = list(images)
    # Keep the original 5-column layout, but widen the grid when more than
    # five images are passed instead of failing inside plt.subplot.
    n_cols = max(5, len(images))

    plt.figure(figsize=(10, 7))
    for i, image in enumerate(images):
        plt.subplot(1, n_cols, i + 1, xticks=[], yticks=[])
        image = image.numpy().transpose((1, 2, 0))
        # imshow expects float pixel values inside [0, 1].
        image = np.clip(image, 0., 1.)
        plt.imshow(image)

        col = 'green' if preds[i] == labels[i] else 'red'
        plt.xlabel('Tumor' if labels[i] else 'Non-Tumor')
        plt.ylabel('Tumor' if preds[i] else 'Non-Tumor', color=col)
    plt.tight_layout()
    plt.show()

In [None]:
# Preview the first five samples of one training batch. The true labels are
# passed as "predictions" too, so every caption is drawn in green.
images, labels = next(iter(dl_train))
show_images(images[:5], labels[:5], labels[:5])

## Creating model

In [None]:
# Load ResNet-18 with pretrained ImageNet weights. The weights are named
# explicitly: passing a bare boolean to `weights` is the deprecated legacy
# form and resolves to this same IMAGENET1K_V1 checkpoint with a warning.
model = torchvision.models.resnet18(weights="IMAGENET1K_V1")
# Freeze every pretrained parameter; only layers created afterwards train.
for param in model.parameters():
    param.requires_grad = False
print(model)

In [None]:
# Replace the classifier head with a fresh layer sized for this dataset.
# The input width is read from the existing head instead of being hard-coded,
# and a newly created Linear layer has requires_grad=True by default, so it
# is the only part of the network that trains.
model.fc = torch.nn.Linear(in_features=model.fc.in_features, out_features=num_classes)
loss_fn = torch.nn.CrossEntropyLoss()
# All parameters are handed to Adam; the frozen ones never receive gradients
# and are therefore left untouched by optimizer.step().
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

In [None]:
def show_preds():
    """Show model predictions for the first five images of a test batch.

    Uses the module-level ``model`` and ``dl_test``. The model is switched to
    eval mode and left there; callers that continue training must call
    ``model.train()`` afterwards.
    """
    model.eval()
    images, labels = next(iter(dl_test))
    # Inference only: skip building the autograd graph.
    with torch.no_grad():
        outputs = model(images)
    # Index of the largest logit per sample is the predicted class.
    _, preds = torch.max(outputs, 1)
    show_images(images[:5], labels[:5], preds[:5])

In [None]:
# Preview predictions before any training of the new classifier head.
show_preds()

## Training the model

In [None]:
def _validate_once():
    """Run one pass over ``dl_valid``.

    Returns:
        ``(mean_loss, accuracy, batch_losses)`` where ``mean_loss`` is the
        average of the per-batch losses, ``accuracy`` is the number of correct
        predictions divided by ``len(valid_dataset)`` and ``batch_losses`` is
        the list of per-batch loss values.
    """
    model.eval()
    batch_losses = []
    correct = 0
    # Evaluation needs no gradients.
    with torch.no_grad():
        for val_images, val_labels in dl_valid:
            outputs = model(val_images)
            batch_losses.append(loss_fn(outputs, val_labels).item())
            _, preds = torch.max(outputs, 1)
            correct += (preds == val_labels).sum().item()

    # max(..., 1) guards against an empty validation set.
    mean_loss = sum(batch_losses) / max(len(batch_losses), 1)
    accuracy = correct / max(len(valid_dataset), 1)
    return mean_loss, accuracy, batch_losses


def train(epochs):
    """Train the module-level ``model`` and evaluate it every 20 steps.

    Uses the module-level ``model``, ``optimizer``, ``loss_fn``, ``dl_train``,
    ``dl_valid`` and ``valid_dataset``. After each evaluation a prediction
    preview is shown via ``show_preds()``. Training stops early as soon as the
    validation accuracy reaches 0.95.

    Args:
        epochs: Maximum number of passes over ``dl_train``.

    Returns:
        A tuple ``(train_loss_list, val_loss_list, acc_list)`` covering the
        whole run: per-step training losses, per-batch validation losses from
        every evaluation, and one accuracy value per evaluation. The same
        tuple is returned on early stopping.
    """
    # Histories span the whole run, so they are created once, outside the
    # epoch loop.
    train_loss_list = []
    val_loss_list = []
    acc_list = []

    print('Training start...')
    for e in range(0, epochs):
        print('='*40)
        print(f'Starting epoch {e + 1}/{epochs}')
        print('='*40)

        train_loss = 0
        num_steps = 0

        model.train()

        # NOTE: batches are used on whatever device the loader yields them;
        # no device transfer is performed here.
        for train_step, (images, labels) in enumerate(dl_train):
            optimizer.zero_grad()
            outputs = model(images)
            loss = loss_fn(outputs, labels)
            loss.backward()
            optimizer.step()

            step_loss = loss.item()
            train_loss_list.append(step_loss)
            train_loss += step_loss
            num_steps += 1

            if train_step % 20 == 0:
                print('Evaluating at step ', train_step)

                # Each evaluation starts from zero, so the reported loss is
                # the mean over this pass only.
                val_loss, acc, batch_losses = _validate_once()
                val_loss_list.extend(batch_losses)
                acc_list.append(acc)
                print(f'Val loss: {val_loss:.4f}, Acc: {acc:.4f}')
                show_preds()

                # Evaluation and the preview leave the model in eval mode.
                model.train()

                if acc >= 0.95:
                    print('Performance achieved')
                    # Return the histories so callers can always unpack
                    # three values.
                    return train_loss_list, val_loss_list, acc_list

        # max(..., 1) guards against an empty training loader.
        train_loss /= max(num_steps, 1)
        print(f'Training loss: {train_loss:.4f}')
    print('Training complete')
    return train_loss_list, val_loss_list, acc_list

In [None]:
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# print("the device type is", device)
# model = model.to(device)

In [None]:
# Train for at most 6 epochs and keep the three values returned by train().
train_loss, val_loss, acc = train(epochs=6)