In [1]:
# Code credits: Adapted bits and pieces from https://github.com/webdataset/webdataset/blob/master/docs/gettingstarted.ipynb

import sys
sys.path.append('..')

import gc
import json
import os
from itertools import islice
from datetime import datetime
import pytz
from pytz import timezone
import numpy as np
from sklearn.metrics import precision_recall_fscore_support
import matplotlib.pyplot as plt
import skimage.transform as st
import tqdm

import torch
import torch.optim as optim
from torchvision import transforms
import webdataset as wds

from model.baseline_3d_cnn import *
from utils.model_utils import *
from utils.model_run import *

%load_ext autoreload
%autoreload 2

In [2]:
# Data locations, relative to this notebook's directory.
data_dir = '../data'
shards_dir = os.path.join(data_dir, 'shards_new')

# Shared hyperparameters for all experiments.
with open('../parameters.json') as params_fh:
    parameters = json.load(params_fh)

# NOTE: the 'batch_size' / 'shard_size' entries of parameters.json are deliberately
# overridden here. shard_size is later used to derive iterations per epoch, so it
# presumably must equal the number of samples written per shard — TODO confirm.
batch_size, shard_size = 4, 4

{'batch_size': 16, 'shard_size': 16}

In [3]:
# Collect every .tar shard. Sorting matters: os.listdir returns entries in an
# arbitrary, filesystem-dependent order, so without it the train/val/test split
# below would differ between machines (or even between runs).
urls = sorted(
    os.path.join(shards_dir, fname)
    for fname in os.listdir(shards_dir)
    if fname.endswith('.tar')
)
assert urls, f"No .tar shards found in {shards_dir}"

# Try to overfit on smaller data
# urls = urls[:round(len(urls)*0.3)]

# A continuation directory ('shards_new_cont') exists but cannot be mixed in:
# its sample keys would collide with those in shards_new because the data was refreshed.
# shards_dir2 = os.path.join(data_dir, 'shards_new_cont')
# urls += [os.path.join(shards_dir2, it) for it in os.listdir(shards_dir2) if it.endswith('.tar')]

# All the data: split by shard into ~70% train / 15% validation / 15% test.
n_urls = len(urls)
train_end = round(n_urls * 0.7)
val_end = round(n_urls * 0.85)
train_urls = urls[:train_end]
val_urls = urls[train_end:val_end]
test_urls = urls[val_end:]

# Smaller data just to run model once
# train_urls = urls[:2]
# val_urls = urls[2:3]
# test_urls = urls[3:]

print("Number of train shards:", len(train_urls))
print("Number of validation shards:", len(val_urls))
print("Number of test shards:", len(test_urls))

Number of train shards: 67
Number of validation shards: 15
Number of test shards: 14


In [8]:
def iters_per_epoch(n_shards, shard_size, batch_size):
    """Return the number of full batches in one pass over ``n_shards`` shards.

    Assumes every shard holds exactly ``shard_size`` samples (TODO confirm against
    the shard-writing code). Floor division means a trailing partial batch is not
    counted.
    """
    return n_shards * shard_size // batch_size


def make_shard_loader(urls, iternum, shard_size, batch_size):
    """Build the WebDataset pipeline and DataLoader shared by every split.

    The pipeline reads ``urls``, applies ``.shuffle(shard_size)``, decodes with the
    'torch' decoder, yields (volumes, labels, studynames) tuples and batches them
    ``batch_size`` at a time. Because batching happens inside the pipeline, the
    DataLoader gets ``batch_size=None`` to disable its own batching.

    Args:
        urls: shard paths for this split.
        iternum: value passed as WebDataset's ``length``.
        shard_size: argument to ``.shuffle``.
        batch_size: number of samples per pipeline batch.

    Returns:
        (dataset, loader) tuple.
    """
    dataset = (
        wds
        .WebDataset(urls, length=iternum)
        .shuffle(shard_size)
        .decode('torch')
        .to_tuple('volumes.pyd', 'labels.pyd', 'studynames.pyd')
        .batched(batch_size)
        # .map_tuple(pre_transforms, identity, identity)
    )
    loader = torch.utils.data.DataLoader(dataset, num_workers=0, batch_size=None)
    return dataset, loader


# Create dataset objects
train_iternum = iters_per_epoch(len(train_urls), shard_size, batch_size)
val_iternum = iters_per_epoch(len(val_urls), shard_size, batch_size)
test_iternum = iters_per_epoch(len(test_urls), shard_size, batch_size)

print("Number of iterations per train epoch:", train_iternum)

train_dataset, loader_train = make_shard_loader(train_urls, train_iternum, shard_size, batch_size)
val_dataset, loader_val = make_shard_loader(val_urls, val_iternum, shard_size, batch_size)
test_dataset, loader_test = make_shard_loader(test_urls, test_iternum, shard_size, batch_size)

Number of iterations per train epoch: 67


In [9]:
# tmp  = (
#     wds
#     .WebDataset(sorted(train_urls)[:2])
# )
# for i, sample in enumerate(tmp):
#     print("IIIIIIIIII:", i)
#     for key, value in sample.items():
#         print(key, repr(value)[:50])
#     print()

### Original-resolution images

In [10]:
# Release unreferenced Python objects before moving on; the cell displays how many were collected.
freed = gc.collect()
freed

108

In [11]:
# patient_num = 1

# for t, (x, y, z) in enumerate(loader_train):
#     if t > 0:
#         break
#     tmp = x[patient_num, :, :, :].detach().numpy()
# del x, y, z

In [12]:
# fig, axs = plt.subplots(5,8, figsize=(15, 6))
# axs = axs.ravel()
# for i in range(40):
#     axs[i].imshow(tmp[i,:,:])

### Downsampled images (resized to 256×256)

In [13]:
# fig, axs = plt.subplots(5,8, figsize=(15, 6))
# axs = axs.ravel()
# for i in range(40):
#     axs[i].imshow(st.resize(tmp[i,:,:], (256,256)))

In [14]:
# del tmp
# gc.collect()

In [15]:
# Use the GPU when requested and available; otherwise fall back to the CPU.
USE_GPU = True
dtype = torch.float  # alias of torch.float32

use_cuda = USE_GPU and torch.cuda.is_available()
device = torch.device('cuda' if use_cuda else 'cpu')

print(device)
print(dtype)

cuda
torch.float32


In [16]:
# Make log directory and checkpoint directory, named by the current US/Pacific time.
dir_nm = datetime.now(tz=pytz.utc).astimezone(timezone('US/Pacific')).strftime('%Y-%m-%d_%H-%M-%S')
# dir_nm = "first_mini_c2fc2"
log_dir = os.path.join('../runs/baseline', dir_nm)
# os.makedirs also creates ../runs/baseline (and log_dir) when missing; plain os.mkdir
# raised FileNotFoundError on a fresh checkout. exist_ok stays False so re-running this
# cell within the same second fails loudly instead of mixing two runs' checkpoints.
os.makedirs(os.path.join(log_dir, 'Checkpoints'))

# Model, optimizer, criterion
LEARNING_RATE = 1e-4
model = baseline_3DCNN(in_num_ch=1)
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
# BCEWithLogitsLoss applies the sigmoid internally, so the model is expected to emit raw logits.
criterion = torch.nn.BCEWithLogitsLoss()

In [17]:
# Free memory left over from setup before training; the cell displays the collected-object count.
n_freed = gc.collect()
n_freed

22

In [19]:
# Run training; `train` comes from one of the star-imported project modules.
# NOTE(review): whether val_every counts epochs or iterations is defined by train() — confirm there.
train_loss_dict, val_loss_dict = train(
    model,
    optimizer,
    criterion,
    loader_train,
    loader_val,
    log_dir,
    device=device,
    epochs=10,
    val_every=5,
)

Epoch 1:   3%|▎         | 2/67 [00:51<27:49, 25.69s/batch, loss=0.663]


KeyboardInterrupt: 

### Loading a checkpoint

In [None]:
# ckpt_path = os.path.join(log_dir, 'Checkpoints', 'ep_0_iter_3_ckpt.pt')
# ckpt = torch.load(ckpt_path)

# ckpt_model = baseline_3DCNN(in_num_ch=1)
# ckpt_model.load_state_dict(ckpt['model_state_dict'])
# ckpt_model.state_dict()