<a href="https://colab.research.google.com/github/dominicwhite/DeepFossil/blob/master/notebooks/02-Simple-CNN-Segmentation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [0]:
from google.colab import drive
drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset

from torchvision import datasets, transforms

import os
import numpy as np

from PIL import Image

In [0]:
class FlattenImage():
    def __call__(self, t):
        return t.view((1,-1)).squeeze()

class BoneOnly():
    def __call__(self, t):
        return t #torch.clamp(t, 1, 2)

class SegmentationToTensor():
    def __call__(self, t):
        t = torch.from_numpy(np.asarray(t)).unsqueeze(0)
        return torch.clamp(t, 1, 2) - 1

img_transform = transforms.Compose([
    # you can add other transformations in this list
    transforms.Grayscale(1),
    transforms.ToTensor(),
])
lbl_transform = transforms.Compose([
    # you can add other transformations in this list
#     transforms.Grayscale(1),
    SegmentationToTensor(),
#     BoneOnly()
])


class CTSegmentationDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.slice_locations = {}
        count = 0
        for vol in os.listdir(self.root_dir):
            im_dir = os.path.join(self.root_dir, vol, "images")
            for vol_idx, im in enumerate(os.listdir(im_dir)):
                self.slice_locations[count] = (vol_idx, vol)
                count += 1
        self.num_slices = count
        self.transform = transform
        
    def __len__(self):
        return self.num_slices
    
    def __getitem__(self, idx):
        slice_idx, vol = self.slice_locations[idx]
        vol_dir = os.path.join(self.root_dir, vol)
        im_name = os.path.join(vol_dir, "images", f"slice-{slice_idx}.png")
        image = Image.open(im_name)
        label_name = os.path.join(vol_dir, "labels", f"label-{slice_idx}.png")
        label = Image.open(label_name)
#         label = np.clip(label, 1, 2) - 1
        
        if self.transform:
            image = self.transform['image'](image)
            label = self.transform['label'](label)
            
        return image, label

dataset = CTSegmentationDataset(
    '/content/gdrive/My Drive/Colab Notebooks/data/CT/simulated_volumes/128', 
    transform={'image': img_transform, 'label': lbl_transform})

dataloader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True)

In [0]:
images, classes = next(iter(dataloader))

In [0]:
im = images[1]
print(im.shape)
print(images.max())

torch.Size([1, 128, 128])
tensor(1.)


In [0]:
lb = classes[0]
print(lb.shape)
print(classes.max())

torch.Size([1, 128, 128])
tensor(1, dtype=torch.uint8)


In [0]:
train_on_gpu = torch.cuda.is_available()

if not train_on_gpu:
    print('CUDA is not available.  Training on CPU ...')
else:
    print('CUDA is available!  Training on GPU ...')

CUDA is available!  Training on GPU ...


In [0]:
# define the CNN architecture
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, 3, padding=1)
        self.conv2 = nn.Conv2d(16, 32, 3, padding=1)
        self.conv3 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv4 = nn.Conv2d(64, 128, 3, padding=1)
        self.convtc5 = nn.Conv2d(128, 64, 3, padding=1)
        self.convtc6 = nn.Conv2d(64, 32, 3, padding=1)
        self.convtc7 = nn.Conv2d(32, 16, 3, padding=1)
        self.convtc8 = nn.Conv2d(16, 1, 3, padding=1)
        self.mp = nn.MaxPool2d(2,2)

    def forward(self, x):
        x = self.mp(F.relu(self.conv1(x)))
        x = self.mp(F.relu(self.conv2(x)))
        x = self.mp(F.relu(self.conv3(x)))
        x = self.mp(F.relu(self.conv4(x)))
        x = F.relu(self.convtc5(F.interpolate(x,scale_factor=2)))
        x = F.relu(self.convtc6(F.interpolate(x,scale_factor=2)))
        x = F.relu(self.convtc7(F.interpolate(x,scale_factor=2)))
        x = self.convtc8(F.interpolate(x,scale_factor=2))
#         x = torch.sigmoid(x)
        return x

model = Net()
print(model)
output = model.forward(images)
print(output.shape)
print(output[0])
print(classes.shape)
print(classes[0])
print(classes.max())

Net(
  (conv1): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (convtc5): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (convtc6): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (convtc7): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (convtc8): Conv2d(16, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (mp): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
torch.Size([32, 1, 128, 128])
tensor([[[-0.0510, -0.0560, -0.0552,  ..., -0.0564, -0.0566, -0.0564],
         [-0.0605, -0.0632, -0.0628,  ..., -0.0645, -0.0640, -0.0591],
         [-0.0589, -0.0610, -0.0597,  ..., -0.0622, -0.0618, -0.0566],
         ...,
         [-0.0599, -0.0615, -

In [0]:
# create a complete CNN
model = Net()
print(model)

# move tensors to GPU if CUDA is available
if train_on_gpu:
    model.cuda()

from torch import nn, optim

class_weights = torch.tensor([1/0.00537], dtype=torch.float).to('cuda')
print(class_weights)

criterion = nn.BCEWithLogitsLoss(pos_weight=class_weights)
optimizer = optim.Adam(model.parameters(), lr=0.0003)

epochs = 50

for e in range(epochs):
    print("Starting epoch", e+1)
    running_loss = 0
    for images, labels in dataloader:
#         print(images.shape)
#         images, labels = images.to('cuda', dtype=torch.float), labels.to('cuda', dtype=torch.long)
        images, labels = images.to('cuda', dtype=torch.float), labels.to('cuda', dtype=torch.float)
        optimizer.zero_grad()
        
        output = model(images)
#         print("output", output.shape)
#         print("labels", labels.shape)
        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
#         break
    else:
        print(f' Loss: {running_loss}')
        ## TODO: Implement the validation pass and print out the validation accuracy
        with torch.no_grad():
            correct = []
            num_im = 0
#             total = torch.zeros(1)
#             img, lbl = dataset[182]
#             img = img.unsqueeze(0)
#             print("img", img.shape, img.sum(), img.min(), img.max())
#             lbl = lbl.unsqueeze(0)
#             img = img.to('cuda', dtype=torch.float)
#             ps = torch.sigmoid(model(img)).cpu().numpy()
#             print("ps", ps.shape, ps.sum(), ps.min(), ps.max())
#             preds = np.where(ps < 0.5, 0, 1)
#             print("preds", preds.shape, preds.sum())
#             print("lbl", lbl.shape, lbl.sum())
            total_bone_predictions = 0
            total_pixels = 0
            for images, labels in dataloader:
                images = images.to('cuda', dtype=torch.float)
#                 ps = torch.sigmoid(model(images)).cpu().numpy()
                ps = torch.sigmoid(model(images)).cpu().numpy()
#                 print("ps", ps.shape)
                preds = np.where(ps < 0.5, 0, 1)
#                 print("preds", preds.shape, preds.min(), preds.max())
                np_labels = labels.numpy()
                diff = np.sum(np.abs(preds - np_labels))
#                 print("  diff", diff)
                total_bone_predictions += np.sum(preds)
                total_pixels += ps.shape[0]*ps.shape[2]*ps.shape[3]
                correct.append(diff)
                num_im += ps.shape[0]
#                 top_p, top_class = ps.topk(1, dim=1)
#                 print("top_p", top_p.shape)
#                 print("top_p", top_p[0])
#                 print("top_class", top_class.shape)
#                 print("top_class", top_class[0])
#                 equals = top_class == labels.view(*top_class.shape)
#                 correct += torch.sum(equals)
#                 total += len(equals)
# #                 accuracy = accuracy + torch.mean(equals.type(torch.FloatTensor))
#             accuracy = correct / total
#         print("  sum of correct:", sum(correct))
        print(" Predicted bone fraction", total_bone_predictions/total_pixels)
        print(f' Accuracy: {sum(correct)*100/(ps.shape[2]*ps.shape[3]*num_im)}%')
#     break

Net(
  (conv1): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv2): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv3): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (conv4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (convtc5): Conv2d(128, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (convtc6): Conv2d(64, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (convtc7): Conv2d(32, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (convtc8): Conv2d(16, 1, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (mp): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
tensor([186.2197], device='cuda:0')
Starting epoch 1
 Loss: 25.72144877910614
 Predicted bone fraction 0.24804344177246093
 Accuracy: 24.276037216186523%
Starting epoch 2
 Loss: 15.136053383350372
 Predicted bone fraction 0.20147018432617186
 Accuracy: 19.62508201599121%
Starting epoch