In [0]:
!pip3 install torch 
!pip3 install torchvision
!pip3 install tqdm



In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms, utils, datasets
from tqdm import tqdm

from torch.nn.parameter import Parameter
import pdb

import torchvision
import os
import gzip
import tarfile
import gc
from IPython.core.ultratb import AutoFormattedTB
# Verbose, light-background traceback printer. `scope()` calls it from its
# exception handler to print whatever error interrupted training.
__ITB__ = AutoFormattedTB(mode = 'Verbose',color_scheme='LightBg', tb_offset = 1)

# Explicit check rather than `assert`: asserts are stripped under `python -O`,
# and a bare AssertionError does not tell the reader how to fix the problem.
if not torch.cuda.is_available():
  raise RuntimeError('No CUDA device available. Request a GPU from Runtime > Change Runtime Type.')

In [0]:
class CancerDataset(Dataset):
  """Paired (image, mask) dataset for the cancer-detection task.

  If `download` is set and `<root>/cancer_data` does not exist yet, the archive
  http://liftothers.org/cancer_data.tar.gz is downloaded into `root` and
  unpacked. Images are read from `<root>/cancer_data/cancer_data/inputs_<split>`
  and masks from the matching `outputs_<split>` folder, where `<split>` is
  'train' or 'test'.

  Args:
    root: directory that holds (or will receive) the downloaded data.
    download: fetch and unpack the archive when it is not present yet.
    size: argument to `transforms.Resize`, applied to images and masks alike.
    train: use the `*_train` folders when True, the `*_test` folders otherwise.
  """

  def __init__(self, root, download=True, size=512, train=True):
    if download and not os.path.exists(os.path.join(root, 'cancer_data')):
      datasets.utils.download_url('http://liftothers.org/cancer_data.tar.gz', root, 'cancer_data.tar.gz', None)
      self.extract_gzip(os.path.join(root, 'cancer_data.tar.gz'))
      self.extract_tar(os.path.join(root, 'cancer_data.tar'))

    postfix = 'train' if train else 'test'
    root = os.path.join(root, 'cancer_data', 'cancer_data')
    # Resize and ToTensor keep no per-call state, so one pipeline serves both folders.
    transform = transforms.Compose([transforms.Resize(size), transforms.ToTensor()])
    # NOTE(review): images and masks are paired purely by index, which assumes
    # both ImageFolders enumerate their files in corresponding order -- confirm
    # that file names in inputs_* and outputs_* match one-to-one.
    self.dataset_folder = torchvision.datasets.ImageFolder(os.path.join(root, 'inputs_' + postfix), transform=transform)
    self.label_folder = torchvision.datasets.ImageFolder(os.path.join(root, 'outputs_' + postfix), transform=transform)

  @staticmethod
  def extract_gzip(gzip_path, remove_finished=False):
    """Decompress `gzip_path` next to itself, dropping the trailing '.gz'.

    The data is streamed in chunks instead of being read into memory at once.

    Args:
      gzip_path: path to a file whose name ends in '.gz'.
      remove_finished: delete the '.gz' file after a successful extraction.

    Raises:
      ValueError: if `gzip_path` does not end in '.gz' (the output path would
        otherwise coincide with the input and truncate it).
    """
    import shutil
    if not gzip_path.endswith('.gz'):
      raise ValueError('Expected a .gz file, got {}'.format(gzip_path))
    print('Extracting {}'.format(gzip_path))
    # Strip only the suffix; str.replace would also rewrite '.gz' in parent dirs.
    out_path = gzip_path[:-len('.gz')]
    with gzip.open(gzip_path, 'rb') as zip_f, open(out_path, 'wb') as out_f:
      shutil.copyfileobj(zip_f, out_f)
    if remove_finished:
      os.unlink(gzip_path)

  @staticmethod
  def extract_tar(tar_path):
    """Unpack `tar_path` into a directory named like it without the '.tar'.

    Raises:
      ValueError: if `tar_path` does not end in '.tar', or if an archive member
        would be written outside the destination directory.
    """
    if not tar_path.endswith('.tar'):
      raise ValueError('Expected a .tar file, got {}'.format(tar_path))
    print('Untarring {}'.format(tar_path))
    out_dir = tar_path[:-len('.tar')]
    dest = os.path.realpath(out_dir)
    # `with` guarantees the archive handle is closed even if extraction fails.
    with tarfile.open(tar_path) as tar:
      # The archive is downloaded over the network: refuse members such as
      # '../x' or absolute paths that would escape `out_dir`.
      for member in tar.getmembers():
        target = os.path.realpath(os.path.join(dest, member.name))
        if os.path.commonpath([dest, target]) != dest:
          raise ValueError('Refusing to extract {!r} outside {}'.format(member.name, out_dir))
      if hasattr(tarfile, 'data_filter'):
        # Python versions with extraction filters: also reject unsafe links/modes.
        tar.extractall(out_dir, filter='data')
      else:
        tar.extractall(out_dir)

  def __getitem__(self, index):
    """Return `(image, mask)` for sample `index`.

    `image` is the ToTensor output multiplied by 255. `mask` is channel 0 of the
    mask's ToTensor output. ImageFolder's class index is discarded for both.
    NOTE(review): the training loop casts masks with `.long()`, which truncates;
    any resized-mask value strictly below 1.0 becomes class 0 -- confirm intended.
    """
    img, _ = self.dataset_folder[index]
    label, _ = self.label_folder[index]
    return img * 255, label[0]

  def __len__(self):
    """Number of samples, taken from the image folder."""
    return len(self.dataset_folder)
    

In [0]:
import PIL.Image

def register_extension(id, extension):
  """Record `extension` (lower-cased) as a file extension of format `id` (upper-cased).

  Writes straight into PIL.Image's module-level EXTENSION table.
  """
  fmt = id.upper()
  PIL.Image.EXTENSION[extension.lower()] = fmt


def register_extensions(id, extensions):
  """Register each extension in the iterable `extensions` for format `id`."""
  for ext in extensions:
    register_extension(id, ext)


# NOTE(review): this overrides PIL's own registration helpers with the
# functions above -- presumably a workaround for the Pillow version installed
# in this runtime; confirm whether it is still required.
PIL.Image.register_extension = register_extension
PIL.Image.register_extensions = register_extensions

In [0]:
# U Net Class
class CancerDetection(nn.Module):
  """U-Net mapping a 3-channel image to 2-channel per-pixel logits.

  Encoder: four levels of (3x3 conv + ReLU) x 2 followed by 2x2 max-pooling,
  growing channels 64 -> 128 -> 256 -> 512, then a 1024-channel bottleneck.
  Decoder: four 2x transposed-conv upsamplings, each concatenated with the
  encoder map of the same resolution and followed by (3x3 conv + ReLU) x 2.
  A final 1x1 conv produces 2 channels. All 3x3 convs use padding=1, and the
  upsampled maps are padded to the skip size, so the output has the same
  height/width as the input (inputs must be at least 16x16).

  Layer attribute names (conv01 ... conv45) are unchanged so existing
  state_dicts still load.

  Args:
    dataset: unused; kept so existing callers keep working.
  """

  def __init__(self, dataset):
    super(CancerDetection, self).__init__()

    self.conv01 = nn.Conv2d(3, 64, 3, padding=1)
    self.relu02 = nn.ReLU()
    self.conv03 = nn.Conv2d(64, 64, 3, padding=1)
    self.relu04 = nn.ReLU()

    self.pool05 = nn.MaxPool2d(2, stride=2)
    self.conv06 = nn.Conv2d(64, 128, 3, padding=1)
    self.relu07 = nn.ReLU()
    self.conv08 = nn.Conv2d(128, 128, 3, padding=1)
    self.relu09 = nn.ReLU()

    self.pool10 = nn.MaxPool2d(2, stride=2)
    self.conv11 = nn.Conv2d(128, 256, 3, padding=1)
    self.relu12 = nn.ReLU()
    self.conv13 = nn.Conv2d(256, 256, 3, padding=1)
    self.relu14 = nn.ReLU()

    self.pool15 = nn.MaxPool2d(2, stride=2)
    self.conv16 = nn.Conv2d(256, 512, 3, padding=1)
    self.relu17 = nn.ReLU()
    self.conv18 = nn.Conv2d(512, 512, 3, padding=1)
    self.relu19 = nn.ReLU()

    self.pool20 = nn.MaxPool2d(2, stride=2)
    self.conv21 = nn.Conv2d(512, 1024, 3, padding=1)
    self.relu22 = nn.ReLU()
    self.conv23 = nn.Conv2d(1024, 1024, 3, padding=1)
    self.relu24 = nn.ReLU()
    self.tran25 = nn.ConvTranspose2d(1024, 512, 2, stride=2)

    self.conv26 = nn.Conv2d(1024, 512, 3, padding=1)
    self.relu27 = nn.ReLU()
    self.conv28 = nn.Conv2d(512, 512, 3, padding=1)
    self.relu29 = nn.ReLU()
    self.tran30 = nn.ConvTranspose2d(512, 256, 2, stride=2)

    self.conv31 = nn.Conv2d(512, 256, 3, padding=1)
    self.relu32 = nn.ReLU()
    self.conv33 = nn.Conv2d(256, 256, 3, padding=1)
    self.relu34 = nn.ReLU()
    self.tran35 = nn.ConvTranspose2d(256, 128, 2, stride=2)

    self.conv36 = nn.Conv2d(256, 128, 3, padding=1)
    self.relu37 = nn.ReLU()
    self.conv38 = nn.Conv2d(128, 128, 3, padding=1)
    self.relu39 = nn.ReLU()
    self.tran40 = nn.ConvTranspose2d(128, 64, 2, stride=2)

    self.conv41 = nn.Conv2d(128, 64, 3, padding=1)
    self.relu42 = nn.ReLU()
    self.conv43 = nn.Conv2d(64, 64, 3, padding=1)
    self.relu44 = nn.ReLU()

    self.conv45 = nn.Conv2d(64, 2, 1)

  @staticmethod
  def _match_size(upsampled, skip):
    """Zero-pad `upsampled` on the bottom/right to `skip`'s height and width.

    Max-pooling floors odd sizes, so the upsampled map can be one pixel short
    of its skip connection; when sizes already agree this returns the input.
    """
    dh = skip.size(2) - upsampled.size(2)
    dw = skip.size(3) - upsampled.size(3)
    if dh == 0 and dw == 0:
      return upsampled
    # F.pad order for the last two dims is (left, right, top, bottom).
    return F.pad(upsampled, (0, dw, 0, dh))

  def forward(self, image):
    """Return logits of shape (N, 2, H, W) for `image` of shape (N, 3, H, W)."""
    # Encoder. Only the per-level outputs needed as skips are kept as names, so
    # pre-activation intermediates can be freed as soon as they are consumed.
    skip1 = self.relu04(self.conv03(self.relu02(self.conv01(image))))
    skip2 = self.relu09(self.conv08(self.relu07(self.conv06(self.pool05(skip1)))))
    skip3 = self.relu14(self.conv13(self.relu12(self.conv11(self.pool10(skip2)))))
    skip4 = self.relu19(self.conv18(self.relu17(self.conv16(self.pool15(skip3)))))

    # Bottleneck.
    x = self.relu24(self.conv23(self.relu22(self.conv21(self.pool20(skip4)))))

    # Decoder: upsample, concatenate [skip, upsampled] on channels, two conv+ReLU.
    x = torch.cat([skip4, self._match_size(self.tran25(x), skip4)], dim=1)
    x = self.relu29(self.conv28(self.relu27(self.conv26(x))))

    x = torch.cat([skip3, self._match_size(self.tran30(x), skip3)], dim=1)
    x = self.relu34(self.conv33(self.relu32(self.conv31(x))))

    x = torch.cat([skip2, self._match_size(self.tran35(x), skip2)], dim=1)
    x = self.relu39(self.conv38(self.relu37(self.conv36(x))))

    x = torch.cat([skip1, self._match_size(self.tran40(x), skip1)], dim=1)
    x = self.relu44(self.conv43(self.relu42(self.conv41(x))))

    # 1x1 conv to per-pixel class logits.
    return self.conv45(x)

In [0]:
def _train(model, loader, epochs, lr=1e-4):
  """Train `model` with Adam and cross-entropy; return the per-batch loss history."""
  objective = nn.CrossEntropyLoss()
  optimizer = optim.Adam(model.parameters(), lr=lr)
  losses = []
  model.train()
  for epoch in range(epochs):
    loop = tqdm(total=len(loader), position=0, leave=False)
    for batch, (x, y_truth) in enumerate(loader):
      # `non_blocking` replaces the old `async` keyword argument, which is a
      # reserved word (SyntaxError) from Python 3.7 onwards.
      x = x.cuda(non_blocking=True)
      y_truth = y_truth.cuda(non_blocking=True)

      optimizer.zero_grad()
      # NOTE(review): `.long()` truncates, so any mask value below 1.0 maps to
      # class 0 -- confirm this is the intended labelling.
      loss = objective(model(x), y_truth.long())
      loss.backward()
      optimizer.step()

      # `.item()` synchronizes with the GPU; do it once per batch.
      loss_value = loss.item()
      losses.append(loss_value)
      loop.set_description('epoch:{},batch:{}, loss:{:.4f}'.format(epoch, batch, loss_value))
      loop.update(1)
    loop.close()
  return losses


def _plot_loss(losses):
  """Plot the per-batch training loss curve."""
  plt.plot(losses, label="Training Loss")
  plt.legend(bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)
  plt.title("Loss")
  plt.xlabel("Batch")   # one point per training batch
  plt.ylabel("Loss")
  plt.show()


def _show_mask(mask, title):
  """Display a 2-D CPU tensor as an image with `title` and no axes."""
  plt.axis("off")
  plt.title(title)
  plt.imshow(mask.numpy())
  plt.show()


def _show_example(model, dataset, index):
  """Show the ground-truth mask and the model's predicted mask for dataset[index]."""
  image, truth = dataset[index]           # load the sample once
  # Both masks are shown inverted (1 - value), as in the original display code.
  _show_mask(1 - truth, "Test Image Truth")

  model.eval()
  with torch.no_grad():                   # inference only: build no autograd graph
    logits = model(image.cuda().unsqueeze(0))
  pred = logits.squeeze(0).argmax(0).cpu()
  _show_mask(1 - pred, "Test Image Output")


def scope():
  """Train the U-Net, plot its loss curve, and show one test prediction.

  Everything lives inside this function, so its locals (model, optimizer, CUDA
  tensors) are released when it returns. Any exception -- including a
  KeyboardInterrupt from stopping the cell -- is caught and printed with
  `__ITB__` instead of propagating.
  """
  try:
    EPOCHS = 6
    BATCH_SIZE = 10
    EXAMPLE_INDEX = 172                   # test sample displayed after training

    data_train = CancerDataset('/tmp/cancer', train=True, size=256)
    data_test = CancerDataset('/tmp/cancer', train=False, size=256)

    model = CancerDetection(data_train).cuda()

    loader_train = DataLoader(data_train,
                              batch_size=BATCH_SIZE,
                              pin_memory=True,
                              num_workers=3,
                              shuffle=True)

    gc.collect()
    print(torch.cuda.memory_allocated(0) / 1e9)   # GB currently held by tensors

    loss_train = _train(model, loader_train, EPOCHS)
    _plot_loss(loss_train)
    _show_example(model, data_test, EXAMPLE_INDEX)

  except BaseException:
    # Deliberately as broad as a bare `except:` (KeyboardInterrupt included):
    # print the traceback and return normally.
    __ITB__()


scope()

  0%|          | 0/135 [00:00<?, ?it/s]

0.12412928


epoch:0,batch:62, loss:0.3041:  47%|████▋     | 63/135 [02:33<02:53,  2.41s/it]Process Process-5:
Process Process-6:
Traceback (most recent call last):
  File "/usr/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
Process Process-4:
  File "/usr/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
Traceback (most recent call last):
Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 96, in _worker_loop
    r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)
  File "/usr/lib/python3.6/multiprocessing/queues.py", line 104, in get
    if not self._poll(timeout):
  File "/usr/lib/python3.6/multiprocessing/connection.py", line 257, in poll
    return self._poll(timeout)
  File "/usr/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/usr/lib/python3.6/multiprocessing/process.py", line 258, in _b

[0;31m---------------------------------------------------------------------------[0m
[0;31mKeyboardInterrupt[0m                         Traceback (most recent call last)
[0;31mKeyboardInterrupt[0m: 


Exception ignored in: <bound method _DataLoaderIter.__del__ of <torch.utils.data.dataloader._DataLoaderIter object at 0x7fc89ce0c9b0>>
Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 397, in __del__
    def __del__(self):
  File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 227, in handler
    _error_if_any_worker_fails()
RuntimeError: DataLoader worker (pid 257) exited unexpectedly with exit code 1. Details are lost due to multiprocessing. Rerunning with num_workers=0 may give better error trace.
