<a href="https://colab.research.google.com/github/dominicwhite/DeepFossil/blob/master/notebooks/03-UNet_1channel_input.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/gdrive')

Drive already mounted at /content/gdrive; to attempt to forcibly remount, call drive.mount("/content/gdrive", force_remount=True).


In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset

from torchvision import datasets, transforms

import os
import numpy as np

from PIL import Image

In [0]:
class SegmentationToTensor():
    def __call__(self, t):
        t = torch.from_numpy(np.asarray(t)).unsqueeze(0)
        return t.long() # torch.clamp(t, 1, 2) - 1

img_transform = transforms.Compose([
    transforms.Grayscale(1),
    transforms.ToTensor(),
])
lbl_transform = transforms.Compose([
    SegmentationToTensor(),
])


class CTSegmentationDataset(Dataset):
    def __init__(self, root_dir, transform=None):
        self.root_dir = root_dir
        self.slice_locations = {}
        count = 0
        for vol in os.listdir(self.root_dir):
            im_dir = os.path.join(self.root_dir, vol, "images")
            for vol_idx, im in enumerate(os.listdir(im_dir)):
                self.slice_locations[count] = (vol_idx, vol)
                count += 1
        self.num_slices = count
        self.transform = transform
        
    def __len__(self):
        return self.num_slices
    
    def __getitem__(self, idx):
        slice_idx, vol = self.slice_locations[idx]
        vol_dir = os.path.join(self.root_dir, vol)
        im_name = os.path.join(vol_dir, "images", f"slice-{slice_idx}.png")
        image = Image.open(im_name)
        label_name = os.path.join(vol_dir, "labels", f"label-{slice_idx}.png")
        label = Image.open(label_name)
#         label = np.clip(label, 1, 2) - 1
        
        if self.transform:
            image = self.transform['image'](image)
            label = self.transform['label'](label)
            
        return image, label

dataset = CTSegmentationDataset(
    '/content/gdrive/My Drive/Colab Notebooks/data/CT/simulated_volumes/128', 
    transform={'image': img_transform, 'label': lbl_transform})

dataloader = torch.utils.data.DataLoader(dataset, batch_size=8, shuffle=True)

In [0]:
images, classes = next(iter(dataloader))

In [5]:
im = images[1]
print(im.shape)
print(images.min(), images.max())

torch.Size([1, 128, 128])
tensor(0.) tensor(0.9647)


In [6]:
lb = classes[0]
print(lb.shape)
print(classes.min(), classes.max())

torch.Size([1, 128, 128])
tensor(0) tensor(2)


In [7]:
train_on_gpu = torch.cuda.is_available()

if not train_on_gpu:
    print('CUDA is not available.  Training on CPU ...')
else:
    print('CUDA is available!  Training on GPU ...')

CUDA is available!  Training on GPU ...


In [8]:
# define the CNN architecture

class DownConv(nn.Module):
    
    def __init__(self, in_layers, mid_layers, out_layers, filter_size = 3, pool=True):
        super(DownConv, self).__init__()
        self.pool = pool
        self.dc1 = nn.Conv2d(in_layers, mid_layers, filter_size, padding=1)
        self.dc2 = nn.Conv2d(mid_layers, out_layers, filter_size, padding=1)
    
    def forward(self, x):
        if self.pool == True:
            x = F.max_pool2d(x, 2)
#             print("After pool, shape:", x.shape)
        x = F.relu(self.dc1(x))
        x = F.relu(self.dc2(x))
        return x

    
class UnetUpsample(nn.Module):
    
    def __init__(self, in_layers, out_layers):
        super(UnetUpsample, self).__init__()
        self.up = nn.Upsample(mode='bilinear', scale_factor=2)
        self.conv = nn.Conv2d(in_layers, out_layers, 1)
    
    def forward(self, x):
        return self.conv(self.up(x))
    

class UpConv(nn.Module):
    
    def __init__(self, in_layers, mid_layers, out_layers, upsample_layers=0, filter_size = 3, interp=True):
        super(UpConv, self).__init__()
        self.interp = interp
        self.trans_conv1 = nn.ConvTranspose2d(in_layers, mid_layers, filter_size, padding=1)
        self.trans_conv2 = nn.ConvTranspose2d(mid_layers, out_layers, filter_size, padding=1)
        if self.interp == True:
            self.up = UnetUpsample(out_layers, upsample_layers)
    
    def forward(self, x):
        x = F.relu(self.trans_conv1(x))
        x = F.relu(self.trans_conv2(x))
        if self.interp == True:
            x = self.up(x)
        return x
        
        
        
        
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = DownConv(1, 64, 64, pool=False)
        self.conv2 = DownConv(64, 128, 128)
        self.conv3 = DownConv(128, 256, 256)
        self.conv4 = DownConv(256, 512, 512)
        self.conv5 = DownConv(512, 1024, 1024)
        
        self.up_conv5 = UnetUpsample(1024, 512)
        self.up_conv6 = UpConv(1024, 512, 512, 256)
        self.up_conv7 = UpConv(512, 256, 256, 128)
        self.up_conv8 = UpConv(256, 128, 128, 64)
        self.up_conv9 = UpConv(128, 64, 64, interp=False)
        
        self.final = nn.Conv2d(64, 3, 1)

    def forward(self, x):
        c1 = self.conv1(x)
        c2 = self.conv2(c1)
        c3 = self.conv3(c2)
        c4 = self.conv4(c3)
        c5 = self.conv5(c4)
        
        u6 = torch.cat([c4, self.up_conv5(c5)], dim=1)
        u7 = torch.cat([c3, self.up_conv6(u6)], dim=1)
        u8 = torch.cat([c2, self.up_conv7(u7)], dim=1)
        u9 = torch.cat([c1, self.up_conv8(u8)], dim=1)
        u10 = self.up_conv9(u9)
        fin = self.final(u10)

        return fin

model = Net()
for n, p in model.named_parameters():
    print(n, p.shape)
# print(model)
output = model.forward(images)
print(output.shape)
# print(output[0])
classes = classes #.long()
# print(classes.shape)
# print(classes[0])
# print(classes.max())

# criterion = nn.CrossEntropyLoss()
# loss = criterion(output, classes.squeeze())
# print(loss.item())

# lsoftmax = F.softmax(output, dim=1)
# print("log_softmax of output:", lsoftmax.shape)
# pred1 = lsoftmax[0]
# print(" and just of image 1:", pred1.shape)
# print(pred1)
# print(torch.sum(pred1, 0))
# idx1 = torch.argmax(pred1, 0)
# idxall = torch.argmax(lsoftmax, dim=1)
# print("all:", idxall.shape)
# print("all[0]:", idxall[0].shape)
# print(idxall[0])
# print("indexes:", idx1.shape)
# print(idx1)


conv1.dc1.weight torch.Size([64, 1, 3, 3])
conv1.dc1.bias torch.Size([64])
conv1.dc2.weight torch.Size([64, 64, 3, 3])
conv1.dc2.bias torch.Size([64])
conv2.dc1.weight torch.Size([128, 64, 3, 3])
conv2.dc1.bias torch.Size([128])
conv2.dc2.weight torch.Size([128, 128, 3, 3])
conv2.dc2.bias torch.Size([128])
conv3.dc1.weight torch.Size([256, 128, 3, 3])
conv3.dc1.bias torch.Size([256])
conv3.dc2.weight torch.Size([256, 256, 3, 3])
conv3.dc2.bias torch.Size([256])
conv4.dc1.weight torch.Size([512, 256, 3, 3])
conv4.dc1.bias torch.Size([512])
conv4.dc2.weight torch.Size([512, 512, 3, 3])
conv4.dc2.bias torch.Size([512])
conv5.dc1.weight torch.Size([1024, 512, 3, 3])
conv5.dc1.bias torch.Size([1024])
conv5.dc2.weight torch.Size([1024, 1024, 3, 3])
conv5.dc2.bias torch.Size([1024])
up_conv5.conv.weight torch.Size([512, 1024, 1, 1])
up_conv5.conv.bias torch.Size([512])
up_conv6.trans_conv1.weight torch.Size([1024, 512, 3, 3])
up_conv6.trans_conv1.bias torch.Size([512])
up_conv6.trans_conv2.we

  "See the documentation of nn.Upsample for details.".format(mode))


torch.Size([8, 3, 128, 128])


In [9]:
bone_pixels = 0
rock_pixels = 0
num_im = 0
for images, labels in dataloader:
    bone = np.where(labels.numpy() == 2, 1, 0)
    rock = np.where(labels.numpy() == 1, 1, 0)
    bone_pixels += np.sum(bone)
    rock_pixels += np.sum(rock)
    num_im += bone.shape[0]
total_pixels = num_im * 128 * 128
bone_fraction = bone_pixels/total_pixels
print("Bone:", bone_fraction)
rock_fraction = rock_pixels/total_pixels
print("Rock:", rock_fraction)
air_fraction = (total_pixels - bone_pixels - rock_pixels)/total_pixels
print("Air: ", air_fraction)
class_weights = torch.tensor([1/air_fraction, 1/rock_fraction, 1/bone_fraction], dtype=torch.float).to('cuda')
print("Class weights:", class_weights)

Bone: 0.005372142791748047
Rock: 0.12730979919433594
Air:  0.867318058013916
Class weights: tensor([  1.1530,   7.8549, 186.1455], device='cuda:0')


In [11]:
# create a complete CNN
model = Net()
print(model)

# move tensors to GPU if CUDA is available
if train_on_gpu:
    model.cuda()

from torch import nn, optim

# class_weights = torch.tensor([1/0.00537], dtype=torch.float).to('cuda')
print("Class weights:", class_weights)

criterion = nn.CrossEntropyLoss(class_weights) #nn.BCEWithLogitsLoss(pos_weight=class_weights)
optimizer = optim.Adam(model.parameters(), lr=0.0003)

epochs = 20

for e in range(epochs):
    print("Starting epoch", e+1)
    running_loss = 0
    batch = 1
    for images, labels in dataloader:
#         if batch % 10 == 0: print("On batch:", batch)
        images, labels = images.to('cuda', dtype=torch.float), labels.to('cuda')
        optimizer.zero_grad()
        output = model(images)
        loss = criterion(output, labels.squeeze())
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        batch += 1
#         break
    else:
        print(f' Loss: {running_loss}')
        ## TODO: Implement the validation pass and print out the validation accuracy
        with torch.no_grad():
            incorrect = 0
            num_im = 0
            total_bone_predictions = 0
            total_actual_bone = 0
            total_pixels = 0
            for images, labels in dataloader:
                images = images.to('cuda', dtype=torch.float)
                prob_predictions = F.log_softmax(model(images), dim=1).cpu()
                
                class_predictions = torch.argmax(prob_predictions, dim=1).numpy()
                bone_predictions = np.where(class_predictions == 2, 1, 0)
                actual_bone = np.where(labels.numpy() == 2, 1, 0)
                diff = np.sum(np.abs(bone_predictions - actual_bone))
                
                total_bone_predictions += np.sum(bone_predictions)
                total_actual_bone += np.sum(actual_bone)
                total_pixels += actual_bone.shape[0]*actual_bone.shape[2]*actual_bone.shape[3]
                incorrect += diff
                num_im += actual_bone.shape[0]
        print(" Predicted bone fraction", total_bone_predictions/total_pixels)
        print(f" % Incorrect bone: {100 * incorrect / total_actual_bone}%")
#     break

Net(
  (conv1): DownConv(
    (dc1): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (dc2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (conv2): DownConv(
    (dc1): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (dc2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (conv3): DownConv(
    (dc1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (dc2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (conv4): DownConv(
    (dc1): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (dc2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (conv5): DownConv(
    (dc1): Conv2d(512, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (dc2): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  )
  (up_conv5): UnetUpsample(
    (up): Upsample(scale_factor=2.0, mode=bilinear)
   

  "See the documentation of nn.Upsample for details.".format(mode))


 Loss: 50.72528086602688
 Predicted bone fraction 0.13449687957763673
 % Incorrect bone: 20021.856526601692%
Starting epoch 2
 Loss: 36.023095175623894
 Predicted bone fraction 0.07365789413452148
 % Incorrect bone: 11319.891356446717%
Starting epoch 3
 Loss: 35.2511814981699
 Predicted bone fraction 0.08170995712280274
 % Incorrect bone: 12435.955335428094%
Starting epoch 4
 Loss: 32.03461328148842
 Predicted bone fraction 0.025095748901367187
 % Incorrect bone: 4304.184196978573%
Starting epoch 5
 Loss: 28.314626518636942
 Predicted bone fraction 0.06274356842041015
 % Incorrect bone: 9700.44203014326%
Starting epoch 6
 Loss: 28.219299025833607
 Predicted bone fraction 0.04072093963623047
 % Incorrect bone: 6523.413395821129%
Starting epoch 7
 Loss: 25.22037947177887
 Predicted bone fraction 0.062123680114746095
 % Incorrect bone: 9590.53096873835%
Starting epoch 8
 Loss: 22.383606500923634
 Predicted bone fraction 0.01806917190551758
 % Incorrect bone: 3224.3524879728748%
Starting e