In [73]:
import torch
from torch import nn
import torch.nn.functional as F
import torchvision
import torchvision.models as models
from torchvision.datasets import ImageFolder
from torchvision.transforms import transforms
from torch.utils.data import DataLoader
import numpy as np
import os
import matplotlib.pyplot as plt
from sklearn.metrics import classification_report, log_loss, accuracy_score
from sklearn.model_selection import train_test_split
from tqdm import tqdm
import transformers

In [74]:
####################
# NOTE: this notebook assumes the brain tumor data is in a directory at the root
# of this repository called 'BrainTumorData'.
#
# To make this work:
# 1 - go to https://www.kaggle.com/datasets/sartajbhuvaji/brain-tumor-classification-mri
# 2 - download the .zip file to this directory
# 3 - unzip the archive
# 4 - rename the expanded directory to BrainTumorData


data_directory = os.path.join(os.getcwd(), 'BrainTumorData')
train_dir = os.path.join(data_directory, 'Training')
test_dir = os.path.join(data_directory, 'Testing')

train_imgs = []
train_labels = []

test_imgs = []
test_labels = []


# get labels from the file structure
for label in os.listdir(train_dir):
    for img in os.listdir(os.path.join(train_dir, label)):
        train_imgs.append (img) 
        train_labels.append(label)
        
for label in os.listdir(train_dir):
    for img in os.listdir(os.path.join(train_dir, label)):
        train_imgs.append (img) 
        train_labels.append(label)


In [75]:
# dataloaders only support tensors, np arrays, lists, dicts, and numbers
# define this function to cast PIL images to tensors
transform_fn = transforms.Compose([
    transforms.Resize((150,150)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),  
    transforms.Normalize([0.5,0.5,0.5], 
                        [0.5,0.5,0.5])
])

# create a validation set from the test sset
x_train, x_val, y_train, y_val = train_test_split(np.array(train_imgs), np.array(train_labels), test_size=0.1)

# create dataloaders
train_loader=DataLoader(
    torchvision.datasets.ImageFolder(train_dir, transform=transform_fn),
    batch_size=64,
    shuffle=True,
)
test_loader=DataLoader(
    torchvision.datasets.ImageFolder(test_dir, transform=transform_fn),
    batch_size=32,
    shuffle=False,
)

In [76]:
# get class names
class_names = [label for label in os.listdir(train_dir)]

# support GPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [77]:
# define the model and send it to the device
model = models.resnet18().to(device)

In [87]:
# some deep learning overhead to help the model learn better
# can tweak these if you want but don't need to
warmup_steps = 5000 # delays using learning rate
num_epochs = 10 # arbitrary for now
learning_rate = 5e-3 # intial LR used after warmup
weight_decay = 0.0 # throttles optimizer -- 0 for now

num_update_steps_per_epoch = len(train_loader)
max_train_steps = num_epochs * num_update_steps_per_epoch

optimizer = torch.optim.AdamW(
    model.parameters(), lr=learning_rate, weight_decay=weight_decay,
)

lr_scheduler = transformers.get_scheduler(
    name='linear',
    optimizer=optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=max_train_steps,
)

In [None]:
def evaluate_model(model, test_loader, device):
    # set the model to evaluation mode
    model.eval()
    for batch in tqdm(dataloader, desc="Evaluation"):
        # do some math
        # return some numbers
        pass
        # Mike TODO

In [91]:
# training loop

for epoch in tqdm(range(num_epochs), desc='training epochs'):
    # set the model to train mode
    model.train()
    
    for images, labels in tqdm(train_loader, desc='train batches'):
        images = images.to(device)
        labels = labels.to(device)
                
        preds = model(images).to(device)
        
        loss = F.cross_entropy(preds,labels)
        
        loss.backward()
        optimizer.step()
        lr_scheduler.step()
        optimizer.zero_grad()
        
    # evaluate at the end of each epoch
    # Mike TODO


# save the model checkpoint after training
# in case we want to use it again for class demo
output_dir = 'output'
try:
    os.mkdir(output_dir)
except:
    # we don't care if it already exists... move on
    pass

model.save_pretrained(output_dir)

epochs...:   0%|                                                                                                                                                    | 0/10 [00:00<?, ?it/s]
batches:   0%|                                                                                                                                                      | 0/45 [00:00<?, ?it/s][A

tensor(7.0976, grad_fn=<NllLossBackward0>)



batches:   2%|███▏                                                                                                                                          | 1/45 [00:05<04:09,  5.66s/it][A

tensor(7.0643, grad_fn=<NllLossBackward0>)



batches:   4%|██████▎                                                                                                                                       | 2/45 [00:10<03:50,  5.36s/it][A

tensor(7.0665, grad_fn=<NllLossBackward0>)



batches:   7%|█████████▍                                                                                                                                    | 3/45 [00:16<03:42,  5.29s/it][A

tensor(7.0307, grad_fn=<NllLossBackward0>)



batches:   9%|████████████▌                                                                                                                                 | 4/45 [00:21<03:37,  5.32s/it][A

tensor(6.8890, grad_fn=<NllLossBackward0>)



batches:  11%|███████████████▊                                                                                                                              | 5/45 [00:26<03:33,  5.34s/it][A

tensor(6.9116, grad_fn=<NllLossBackward0>)



batches:  13%|██████████████████▉                                                                                                                           | 6/45 [00:31<03:26,  5.29s/it][A

tensor(6.9605, grad_fn=<NllLossBackward0>)



batches:  16%|██████████████████████                                                                                                                        | 7/45 [00:37<03:19,  5.25s/it][A

tensor(6.8722, grad_fn=<NllLossBackward0>)



batches:  18%|█████████████████████████▏                                                                                                                    | 8/45 [00:42<03:13,  5.24s/it][A

tensor(6.7804, grad_fn=<NllLossBackward0>)



batches:  20%|████████████████████████████▍                                                                                                                 | 9/45 [00:47<03:08,  5.25s/it][A

tensor(6.8737, grad_fn=<NllLossBackward0>)



batches:  22%|███████████████████████████████▎                                                                                                             | 10/45 [00:52<03:04,  5.26s/it][A

tensor(6.6826, grad_fn=<NllLossBackward0>)



batches:  24%|██████████████████████████████████▍                                                                                                          | 11/45 [00:58<03:02,  5.37s/it][A

tensor(6.7489, grad_fn=<NllLossBackward0>)



batches:  27%|█████████████████████████████████████▌                                                                                                       | 12/45 [01:04<03:01,  5.49s/it][A

tensor(6.4059, grad_fn=<NllLossBackward0>)



batches:  29%|████████████████████████████████████████▋                                                                                                    | 13/45 [01:09<02:53,  5.43s/it][A

tensor(6.2750, grad_fn=<NllLossBackward0>)



batches:  31%|███████████████████████████████████████████▊                                                                                                 | 14/45 [01:14<02:47,  5.42s/it][A

tensor(6.3792, grad_fn=<NllLossBackward0>)



batches:  33%|███████████████████████████████████████████████                                                                                              | 15/45 [01:20<02:48,  5.60s/it][A

tensor(6.2911, grad_fn=<NllLossBackward0>)



batches:  36%|██████████████████████████████████████████████████▏                                                                                          | 16/45 [01:26<02:41,  5.57s/it][A

tensor(6.2698, grad_fn=<NllLossBackward0>)



batches:  38%|█████████████████████████████████████████████████████▎                                                                                       | 17/45 [01:32<02:36,  5.58s/it][A

tensor(6.2802, grad_fn=<NllLossBackward0>)



batches:  40%|████████████████████████████████████████████████████████▍                                                                                    | 18/45 [01:37<02:28,  5.51s/it][A

tensor(5.8371, grad_fn=<NllLossBackward0>)



batches:  42%|███████████████████████████████████████████████████████████▌                                                                                 | 19/45 [01:42<02:21,  5.45s/it][A

tensor(6.2512, grad_fn=<NllLossBackward0>)



batches:  44%|██████████████████████████████████████████████████████████████▋                                                                              | 20/45 [01:48<02:15,  5.42s/it][A

tensor(6.1932, grad_fn=<NllLossBackward0>)



batches:  47%|█████████████████████████████████████████████████████████████████▊                                                                           | 21/45 [01:53<02:08,  5.36s/it][A

tensor(5.8592, grad_fn=<NllLossBackward0>)



batches:  49%|████████████████████████████████████████████████████████████████████▉                                                                        | 22/45 [01:58<02:03,  5.39s/it][A

tensor(5.8625, grad_fn=<NllLossBackward0>)



batches:  51%|████████████████████████████████████████████████████████████████████████                                                                     | 23/45 [02:04<02:00,  5.47s/it][A

tensor(5.6643, grad_fn=<NllLossBackward0>)


batches:  51%|████████████████████████████████████████████████████████████████████████                                                                     | 23/45 [02:10<02:04,  5.66s/it]
epochs...:   0%|                                                                                                                                                    | 0/10 [02:10<?, ?it/s]


KeyboardInterrupt: 