In [1]:
import sys, os
import torch
import torch.nn as nn
import numpy as np
from torch.utils.data import DataLoader
from torchvision import transforms

In [2]:
sys.path.append("../")

In [3]:
from Transformers import UNetDataset, ChannelsFirst, ToTensor, Rescale

In [4]:
DATA_DIR_DEEPTHOUGHT="/storage/yw18581/data"
data_dir = DATA_DIR_DEEPTHOUGHT
train_test = os.path.join(data_dir, "train_validation_test")

In [5]:
data = np.load("/storage/yw18581/data/train_validation_test/Xy_train+val_clean_300_24_10_25.npz")
x = data["x"]
y = data['y']
dist = np.unique(data['dist'], return_inverse=True)[1]

In [6]:
composed = transforms.Compose([Rescale(.25), ChannelsFirst(), ToTensor()])

In [7]:
train_dataset = UNetDataset(x, y, transform=composed, dist = dist)

In [8]:
data_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4)

In [9]:
torch.cuda.empty_cache() 

In [10]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [11]:
from cUNet_pytorch_pooling import cUNet, dice_loss

In [12]:
model = cUNet()

In [13]:
criterion_mask = dice_loss

In [14]:
criterion_class = nn.CrossEntropyLoss()

In [15]:
model.to(device)

cUNet(
  (conv_block_down1): Sequential(
    (0): Conv2d(1, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(16, 16, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
  )
  (conv_block_down2): Sequential(
    (0): Conv2d(16, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
  )
  (conv_block_down3): Sequential(
    (0): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
  )
  (conv_block_down4): Sequential(
    (0): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
  )
  (conv_block_d

In [16]:
import torch.optim as optim


optimizer = optim.Adam(model.parameters(), lr=1e-5)

In [17]:
epochs = 100

In [None]:
for epoch in range(epochs):
    print("Epoch {}\n".format(epoch+1))
    running_loss = 0.0
    for i, batch in enumerate(data_loader):
        inputs = batch['image'].float().to(device)
        labels_mask = batch['mask'].float().to(device)
        labels_dist = batch['dist'].long().to(device)
        #print(inputs.is_cuda, labels_mask.is_cuda, labels_dist.is_cuda)
        optimizer.zero_grad()
        out_mask, out_class  = model(inputs)
        #print(out_mask.is_cuda)
        #print(out_class.is_cuda)
        loss_mask = criterion_mask(out_mask, labels_mask)
        loss_class = criterion_class(out_class, labels_dist)
        loss = loss_mask + loss_class
        #print(loss_mask, loss_class)
        loss.backward()
        optimizer.step()
        
        # print statistics
        running_loss += loss.item()
        if i % 60 == 59:    # print every 2000 mini_batches
            print('[%d, %5d] loss: %.3f' % 
                 (epoch + 1, i + 1, running_loss / 60))
print('Finished Training')

Epoch 1

[1,    60] loss: 1.359
Epoch 2

[2,    60] loss: 1.358
Epoch 3



In [None]:
torch.save(model.state_dict(), "../model/trained_cUNet_pytorch_{}epochs.pkl".format(epochs))

In [None]:
model_inference = cUNet()

In [None]:
model_inference.load_state_dict(torch.load("../model/trained_cUNet_pytorch_{}epochs.pkl".format(epochs)))

In [None]:
test_data = np.load("/storage/yw18581/data/train_validation_test/Xy_test_clean_300_24_10_25.npz")
x_test = test_data["x"]
y_test = test_data['y']
dist_test = test_data['dist']
dist_test_idx = np.unique(dist_test, return_inverse=True)[1]

In [None]:
test_dataset = UNetDataset(x_test, y_test, transform=composed, dist = dist_test_idx)

test_data_loader = DataLoader(test_dataset, batch_size=16, shuffle=True, num_workers=4)

In [None]:
model_inference = model.eval()

In [None]:
model_inference.to(device)

In [None]:
import matplotlib.pyplot as plt

In [None]:
dist_dict = {str(idx):cl  for (cl, idx) in 
             zip(np.unique(dist_test), np.unique(dist_test_idx))}

In [None]:
for i, batch in enumerate(test_data_loader):

    true_images, true_masks, true_dists = batch['image'], batch['mask'], batch['dist']
    pred_masks, pred_classes = model_inference(true_images.float().to(device))
    print("batch {}".format(i+1))
    for j, (img, tr_msk, tr_cl, pr_msk, pr_cl) in enumerate(zip(true_images,
                                                 true_masks, 
                                                 true_dists.cpu().detach().numpy(),
                                                 pred_masks.cpu().detach().numpy(), 
                                                 pred_classes.cpu().detach().numpy())):
    
        true_dist = dist_dict[str(tr_cl)]
        pred_dist = dist_dict[str(np.argmax(pr_cl))]
        print("{}: true_dist: {}, pred_dist: {}".format(j+1, true_dist, pred_dist))
        
        f = plt.figure(figsize=(10,5))
        f.add_subplot(1,3, 1)
        plt.imshow(img[0,...], cmap='gray')
        f.add_subplot(1,3, 2)
        plt.imshow(tr_msk[0,...], cmap='gray')
        f.add_subplot(1,3, 3)
        plt.imshow(pr_msk[0,...], cmap='gray')
        plt.show(block=True)
        
    if i==1:
        break

In [None]:
y_test = []
y_pred = []
for i, batch in enumerate(test_data_loader):
    true_images, true_dists = batch['image'], batch['dist']
    _, pred_classes = model_inference(true_images.float().to(device))
    print("batch {}".format(i+1))
    for j, (img, tr_cl, pr_cl) in enumerate(zip(true_images,
                                                 true_dists.cpu().detach().numpy(), 
                                                 pred_classes.cpu().detach().numpy())):
    
        true_dist = dist_dict[str(tr_cl)]
        pred_dist = dist_dict[str(np.argmax(pr_cl))]
        y_test.append(true_dist)
        y_pred.append(pred_dist)

In [None]:
15*16

In [None]:
from sklearn.metrics import accuracy_score, confusion_matrix

In [None]:
confusion_matrix(y_test, y_pred)

In [None]:
accuracy_score(y_test, y_pred)