In [1]:
import sys, os
import torch
import torch.nn as nn
import numpy as np
from torchvision import transforms
from torch.utils.data import DataLoader
import tqdm

In [2]:
import matplotlib.pyplot as plt

In [3]:
# Put the parent directory at the end of the module search path so the
# project-local modules imported in the following cells can be found.
sys.path += ["../"]

In [4]:
from Transformers import ChannelsFirst, ToTensor, Cut, Rescale, splitter, splitter_train_val_test

In [5]:
from DataSets import UNetDataSetFromNpz, UNetDatasetFromFolders

In [6]:
from cUNet_pytorch_pooling import cUNet, dice_loss

In [7]:
import torch.optim as optim

In [8]:
# Absolute root of the data storage; `data_dir` is the alias used by the
# rest of the notebook.
data_dir = DATA_DIR_DEEPTHOUGHT = "/storage/yw18581/data"

# Folder holding the train/validation and test .npz archives loaded below.
train_test = os.path.join(data_dir, "train_validation_test")

#### Import data from the npz file to train the model

In [9]:
# Open the train+validation archive and pull out the three arrays it stores
# under the keys "x", "y" and "dist".
data = np.load(
    os.path.join(train_test, "Xy_train+val_clean_300_24_10_25.npz")
)
x, y, dist = data["x"], data["y"], data["dist"]

In [10]:
# Per-sample preprocessing pipeline; the project transforms are applied in
# the order listed.
composed_npz = transforms.Compose([
    Rescale(.25),
    ChannelsFirst(),
    ToTensor(),
])

In [11]:
# Wrap the in-memory arrays in the project dataset class. `dist` gets an
# extra trailing axis before being handed over.
dataset_train = UNetDataSetFromNpz(
    x,
    y,
    transform=composed_npz,
    dist=dist[..., np.newaxis],
)

In [12]:
# Split the dataset into training and validation parts. Both returned objects
# are indexed by phase name ('train' / 'val') in the training loop.
train_loaders, train_lengths = splitter(
    dataset_train,
    validation_split=0.2,
    batch=16,
    workers=4,
)

In [13]:
# Use the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [14]:
# Network to be trained, built from the project's cUNet class. The training
# loop unpacks two outputs from it: a mask prediction and a distance prediction.
# NOTE(review): the exact meaning of out_size is defined in cUNet_pytorch_pooling — confirm.
model = cUNet(out_size=1)

In [15]:
# One loss per model output: mean-squared error for the distance output and
# the project's dice loss for the mask output.
criterion_dist = nn.MSELoss()
criterion_mask = dice_loss

In [16]:
# Move the model's parameters to the selected device. As the last expression
# of the cell, the returned module's repr is displayed.
model.to(device=device)

cUNet(
  (conv_block_down1): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
  )
  (conv_block_down2): Sequential(
    (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
  )
  (conv_block_down3): Sequential(
    (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
  )
  (conv_block_down4): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
  )
  (conv_block_d

In [18]:
# Adam over all model parameters with a learning rate of 1e-4.
optimizer = optim.Adam(
    params=model.parameters(),
    lr=1e-4,
)

In [19]:
# epochs: number of passes over the training data.
# coeff_mask: weight of the mask loss in the combined loss; the distance
# loss is weighted by (1 - coeff_mask).
epochs, coeff_mask = 50, 0.75

#### Train on npz dataset for 50 epochs

In [19]:
# Joint training / validation loop.
# Every epoch runs a 'train' pass (weights are updated) followed by a 'val'
# pass (evaluation only) over the matching loader. The optimised quantity is
# the weighted sum  coeff_mask * mask_loss + (1 - coeff_mask) * dist_loss.
for epoch in tqdm.tqdm(range(epochs)):
    print("Epoch {}/{}\n".format(epoch+1, epochs))
    print('-'* 10)

    for phase in ['train', 'val']:
        is_train = phase == 'train'
        # Switches the module between training and evaluation mode.
        model.train(is_train)

        running_loss = 0.0
        for i, batch in enumerate(train_loaders[phase]):
            inputs = batch['image'].float().to(device)
            labels_mask = batch['mask'].float().to(device)
            labels_dist = batch['dist'].float().to(device)

            if is_train:
                optimizer.zero_grad()

            # Track gradients only while training: during validation no
            # backward pass is made, so building the autograd graph would
            # only cost memory and time.
            with torch.set_grad_enabled(is_train):
                out_mask, out_dist = model(inputs)
                loss_mask = criterion_mask(out_mask, labels_mask)
                loss_dist = criterion_dist(out_dist, labels_dist)
                loss = coeff_mask * loss_mask + (1-coeff_mask) * loss_dist
                if is_train:
                    loss.backward()
                    optimizer.step()

            running_loss += loss.item()
        # NOTE(review): one loss value is accumulated per batch; whether
        # train_lengths[phase] counts batches or samples depends on
        # `splitter` — confirm the normalisation is the intended one.
        epoch_loss = running_loss / train_lengths[phase]
        print('{} Loss: {:.4f}'.format(phase, epoch_loss))
print('Finished Training')

  0%|          | 0/50 [00:00<?, ?it/s]

Epoch 1/50

----------
train Loss: 1.6995


  2%|▏         | 1/50 [00:45<37:21, 45.75s/it]

val Loss: 1.3765
Epoch 2/50

----------
train Loss: 1.2527


  4%|▍         | 2/50 [01:31<36:37, 45.77s/it]

val Loss: 1.4168
Epoch 3/50

----------
train Loss: 1.2584


  6%|▌         | 3/50 [02:17<35:50, 45.75s/it]

val Loss: 1.3818
Epoch 4/50

----------
train Loss: 1.2625


  8%|▊         | 4/50 [03:03<35:04, 45.75s/it]

val Loss: 1.4266
Epoch 5/50

----------
train Loss: 1.2743


 10%|█         | 5/50 [03:48<34:13, 45.64s/it]

val Loss: 1.4020
Epoch 6/50

----------
train Loss: 1.2519


 12%|█▏        | 6/50 [04:34<33:33, 45.76s/it]

val Loss: 1.3808
Epoch 7/50

----------
train Loss: 1.2445


 14%|█▍        | 7/50 [05:19<32:43, 45.66s/it]

val Loss: 1.4009
Epoch 8/50

----------
train Loss: 1.2377


 16%|█▌        | 8/50 [06:05<31:51, 45.52s/it]

val Loss: 1.3580
Epoch 9/50

----------
train Loss: 1.2264


 18%|█▊        | 9/50 [06:50<31:08, 45.57s/it]

val Loss: 1.3645
Epoch 10/50

----------
train Loss: 1.2329


 20%|██        | 10/50 [07:37<30:32, 45.80s/it]

val Loss: 1.5644
Epoch 11/50

----------
train Loss: 1.2659


 22%|██▏       | 11/50 [08:22<29:44, 45.76s/it]

val Loss: 1.3494
Epoch 12/50

----------
train Loss: 1.2297


 24%|██▍       | 12/50 [09:07<28:48, 45.48s/it]

val Loss: 1.3549
Epoch 13/50

----------
train Loss: 1.2093


 26%|██▌       | 13/50 [09:52<27:54, 45.24s/it]

val Loss: 1.3399
Epoch 14/50

----------
train Loss: 1.2179


KeyboardInterrupt: 

In [None]:
# Checkpoint path; the file name records the configured number of epochs and
# the mask-loss weight.
model_name = (
    "../model/trained_cUNet_pytorch_regression_{n_epochs}epochs"
    "_coeff_mask{mask_weight}_validation_on_npz_notranspose.pkl"
).format(n_epochs=epochs, mask_weight=coeff_mask)

In [None]:
# Persist only the learned parameters (state dict), not the whole module.
# The target directory is created first so that saving does not fail when
# it is missing.
os.makedirs(os.path.dirname(model_name) or ".", exist_ok=True)
torch.save(model.state_dict(), model_name)

#### Test on npz test data

In [None]:
# Fresh cUNet instance, built with the same out_size=1 as the training model,
# into which the saved weights are loaded for inference.
model_inference = cUNet(out_size=1)

In [None]:
# Restore the saved parameters. map_location remaps the stored tensors onto
# `device`, so a checkpoint written during a GPU run can also be loaded on a
# CPU-only machine.
model_inference.load_state_dict(torch.load(model_name, map_location=device))

In [None]:
# Open the held-out test archive and pull out the arrays stored under the
# keys "x", "y" and "dist".
test_data = np.load(
    os.path.join(train_test, "Xy_test_clean_300_24_10_25.npz")
)
x_test, y_test, dist_test = test_data["x"], test_data["y"], test_data["dist"]

In [None]:
# Test dataset built with the same transform pipeline as the training data;
# `dist_test` gets an extra trailing axis before being handed over.
test_dataset = UNetDataSetFromNpz(
    x_test,
    y_test,
    transform=composed_npz,
    dist=dist_test[..., np.newaxis],
)

# Shuffled batches of 16 samples, loaded by 4 worker processes.
test_data_loader = DataLoader(
    test_dataset,
    batch_size=16,
    shuffle=True,
    num_workers=4,
)

In [None]:
# Put the checkpoint-loaded network into evaluation mode and move it to the
# selected device. eval() must be called on `model_inference` itself:
# calling it on `model` and assigning the result would rebind the name to
# the in-memory training model and discard the loaded weights.
model_inference = model_inference.eval()
model_inference.to(device)

In [None]:
# Visual check on the test data: for the first two batches, print the true
# and predicted distance of every sample and show the input image, the true
# mask and the predicted mask side by side.
for i, batch in enumerate(test_data_loader):

    true_images, true_masks, true_dists = batch['image'], batch['mask'], batch['dist']
    # Inference only: skip building the autograd graph for the forward pass.
    with torch.no_grad():
        pred_masks, pred_classes = model_inference(true_images.float().to(device))
    print("batch {}".format(i+1))
    for j, (img, tr_msk, tr_cl, pr_msk, pr_cl) in enumerate(zip(true_images,
                                                 true_masks,
                                                 true_dists.cpu().detach().numpy(),
                                                 pred_masks.cpu().detach().numpy(),
                                                 pred_classes.cpu().detach().numpy())):

        # The second model output is the regressed distance.
        true_dist = tr_cl
        pred_dist = pr_cl
        print("{}: true_dist: {}, pred_dist: {}".format(j+1, true_dist, pred_dist))

        f = plt.figure(figsize=(10,5))
        # Only the first channel of each image / mask is displayed.
        f.add_subplot(1,3, 1)
        plt.imshow(img[0,...], cmap='gray')
        plt.title("image")
        f.add_subplot(1,3, 2)
        plt.imshow(tr_msk[0,...], cmap='gray')
        plt.title("true mask")
        f.add_subplot(1,3, 3)
        plt.imshow(pr_msk[0,...], cmap='gray')
        plt.title("predicted mask")
        plt.show(block=True)

    # Stop after the second batch to keep the output short.
    if i==1:
        break