In [21]:
import numpy as np
import pandas as pd
import matplotlib as mp
import matplotlib.pyplot as plt
import time

from pathlib import Path
import torch
from torch.utils.data import Dataset, DataLoader, sampler
from torch import nn

from DatasetLoader import DatasetLoader
from Unet2D import Unet2D


In [22]:
def train(model, train_dl, valid_dl, loss_fn, optimizer, acc_fn, epochs=1):
    """Train ``model`` and evaluate it on the validation loader after each epoch.

    Every epoch runs a 'train' phase over ``train_dl`` (with gradient updates)
    followed by a 'valid' phase over ``valid_dl`` (no gradients). The model and
    every batch are moved to the GPU when one is available, otherwise the CPU.

    Args:
        model: torch ``nn.Module`` to optimise; moved to the selected device.
        train_dl: iterable yielding ``(x, y)`` tensor batches for training.
        valid_dl: iterable yielding ``(x, y)`` tensor batches for validation.
        loss_fn: callable invoked as ``loss_fn(outputs, y.long())``; must
            return a scalar tensor.
        optimizer: optimizer already bound to ``model``'s parameters.
        acc_fn: callable invoked as ``acc_fn(outputs, y)``; must return a
            scalar (tensor or number).
        epochs: number of passes over both loaders.

    Returns:
        ``(train_loss, valid_loss)``: two lists with one Python float per
        epoch, each the sample-weighted mean loss of that phase.
    """
    start = time.time()

    # Fall back to the CPU instead of crashing in .cuda() on GPU-less machines.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model.to(device)

    train_loss, valid_loss = [], []

    for epoch in range(epochs):
        print('Epoch {}/{}'.format(epoch, epochs - 1))
        print('-' * 10)

        for phase in ['train', 'valid']:
            is_train = phase == 'train'
            model.train(is_train)  # training mode on/off (off == evaluate mode)
            dataloader = train_dl if is_train else valid_dl

            running_loss = 0.0
            running_acc = 0.0
            n_samples = 0
            step = 0

            # iterate over data
            for x, y in dataloader:
                x = x.to(device)
                y = y.to(device)
                step += 1

                if is_train:
                    # zero the gradients left over from the previous step
                    optimizer.zero_grad()

                # Gradients are only tracked in the training phase. The target
                # is cast to long in both phases so the loss is computed the
                # same way for training and validation.
                with torch.set_grad_enabled(is_train):
                    outputs = model(x)
                    loss = loss_fn(outputs, y.long())

                if is_train:
                    loss.backward()
                    optimizer.step()

                # stats - whatever is the phase
                with torch.no_grad():
                    acc = acc_fn(outputs, y)

                # Convert to plain floats before accumulating: summing the
                # tensors would keep every batch's autograd graph reachable
                # and hand (possibly CUDA) tensors back to the caller.
                loss_val = loss.item()
                acc_val = float(acc)

                # Weight by the real batch size; the last batch may be smaller
                # than dataloader.batch_size.
                batch_n = x.size(0)
                running_loss += loss_val * batch_n
                running_acc += acc_val * batch_n
                n_samples += batch_n

                if step % 100 == 0:
                    if device.type == 'cuda':
                        alloc_mb = torch.cuda.memory_allocated() / 1024 / 1024
                    else:
                        alloc_mb = 0.0
                    print('Current step: {}  Loss: {}  Acc: {}  AllocMem (Mb): {}'.format(step, loss_val, acc_val, alloc_mb))

            # Guard against an empty loader so the summary never divides by zero.
            denom = max(n_samples, 1)
            epoch_loss = running_loss / denom
            epoch_acc = running_acc / denom

            print('Epoch {}/{}'.format(epoch, epochs - 1))
            print('-' * 10)
            print('{} Loss: {:.4f} Acc: {}'.format(phase, epoch_loss, epoch_acc))
            print('-' * 10)

            if is_train:
                train_loss.append(epoch_loss)
            else:
                valid_loss.append(epoch_loss)

    time_elapsed = time.time() - start
    print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))

    return train_loss, valid_loss

In [23]:
def acc_metric(predb, yb):
    """Return the fraction of positions where the predicted class equals ``yb``.

    The predicted class is the argmax of ``predb`` along dimension 1. The
    target is moved to the device ``predb`` lives on, so the comparison works
    on CPU as well as on GPU.

    Args:
        predb: tensor of per-class scores with the class axis at dimension 1.
        yb: tensor of class labels, comparable with ``predb.argmax(dim=1)``.

    Returns:
        A 0-dim float tensor with the mean of the element-wise matches.
    """
    predicted = predb.argmax(dim=1)
    return (predicted == yb.to(predb.device)).float().mean()

def batch_to_img(xb, idx):
    """Return sample ``idx`` of batch ``xb`` as a channels-last NumPy array.

    Takes at most the first three channels of the sample and moves the leading
    (channel) axis to the end, i.e. ``(C, H, W) -> (H, W, C)``.

    Args:
        xb: batch indexed as ``xb[idx, channel, ...]``; a torch tensor (on any
            device, with or without grad) or a NumPy array.
        idx: index of the sample inside the batch.

    Returns:
        A NumPy array with the axes reordered by ``transpose((1, 2, 0))``.
    """
    sample = xb[idx, 0:3]
    if torch.is_tensor(sample):
        # NumPy cannot read CUDA tensors or tensors that require grad.
        sample = sample.detach().cpu()
    return np.array(sample).transpose((1, 2, 0))

def predb_to_mask(predb, idx):
    """Return the predicted class mask for sample ``idx`` as a CPU tensor.

    Applies a softmax over dimension 0 of ``predb[idx]`` (the class axis of a
    single sample) and takes the argmax along that same dimension.

    Args:
        predb: batch of per-class scores, indexed as ``predb[idx]``.
        idx: index of the sample inside the batch.

    Returns:
        A CPU tensor of class indices, one per remaining position.
    """
    # torch.softmax is the public entry point; the previous
    # torch.functional.F was only an import alias inside torch.functional.
    probs = torch.softmax(predb[idx], dim=0)
    return probs.argmax(dim=0).cpu()


In [26]:
def main ():
    """Train a 2-class Unet2D on the CAMUS data and visualise the result.

    Loads the dataset, splits it 2/3 train and 1/3 validation, trains with
    Adam and cross-entropy, then (when ``visual_debug`` is set) plots the loss
    curves and the predicted masks for one validation batch.
    """
    #enable if you want to see some plotting
    visual_debug = True

    #batch size
    bs = 12

    #epochs
    epochs_val = 50

    #learning rate
    learn_rate = 0.01

    #sets the matplotlib display backend (most likely not needed);
    #only forced when something is actually going to be plotted
    if visual_debug:
        mp.use('TkAgg', force=True)

    #load the training data
    # NOTE(review): this path is relative ('./home/...'), so it depends on the
    # working directory - confirm it is not meant to be '/home/...'.
    base_path = Path('./home/gkiss/Data/CAMUS_resized')
    data = DatasetLoader(base_path/'train_gray', 
                        base_path/'train_gt')
    n_total = len(data)
    print(n_total)

    #split the training dataset (2/3 train, 1/3 validation) and initialize the
    #data loaders; for the 450-sample set this is the original 300/150 split
    n_train = (2 * n_total) // 3
    n_valid = n_total - n_train
    train_dataset, valid_dataset = torch.utils.data.random_split(data, (n_train, n_valid))
    train_data = DataLoader(train_dataset, batch_size=bs, shuffle=True)
    valid_data = DataLoader(valid_dataset, batch_size=bs, shuffle=True)

    if visual_debug:
        #clamp the preview index so smaller datasets do not raise
        preview_idx = min(150, n_total - 1)
        fig, ax = plt.subplots(1,2)
        ax[0].imshow(data.open_as_array(preview_idx))
        ax[1].imshow(data.open_mask(preview_idx))
        plt.show()

    xb, yb = next(iter(train_data))
    print (xb.shape, yb.shape)

    # build the Unet2D with one channel as input and 2 channels as output
    unet = Unet2D(1,2)

    #loss function and optimizer
    loss_fn = nn.CrossEntropyLoss()
    opt = torch.optim.Adam(unet.parameters(), lr=learn_rate)

    #do some training
    train_loss, valid_loss = train(unet, train_data, valid_data, loss_fn, opt, acc_metric, epochs=epochs_val)

    #plot training and validation losses
    if visual_debug:
        #float() accepts plain numbers as well as 0-dim (CUDA) tensors, which
        #matplotlib cannot convert on its own
        plt.figure(figsize=(10,8))
        plt.plot([float(v) for v in train_loss], label='Train loss')
        plt.plot([float(v) for v in valid_loss], label='Valid loss')
        plt.legend()
        plt.show()

    #predict on a held-out validation batch; a training batch would show
    #samples the network has already been fitted to
    xb, yb = next(iter(valid_data))
    device = next(unet.parameters()).device
    unet.eval()
    with torch.no_grad():
        predb = unet(xb.to(device))

    #show the predicted segmentations
    if visual_debug:
        #size the grid from the batch actually drawn (it may hold fewer than
        #bs samples); squeeze=False keeps ax two-dimensional for one row
        n_show = len(xb)
        fig, ax = plt.subplots(n_show,3, figsize=(15,n_show*5), squeeze=False)
        for i in range(n_show):
            ax[i,0].imshow(batch_to_img(xb,i))
            ax[i,1].imshow(yb[i])
            ax[i,2].imshow(predb_to_mask(predb, i))

        plt.show()

In [27]:
# Run the full pipeline only when this module is the entry point
# (a notebook cell or a script run), never on import.
if __name__ == '__main__':
    main()

450
torch.Size([12, 1, 384, 384]) torch.Size([12, 384, 384])
Epoch 0/49
----------
Epoch 0/49
----------
train Loss: 0.5206 Acc: 0.902328610420227
----------
Epoch 0/49
----------
valid Loss: 0.3820 Acc: 0.9709836840629578
----------
Epoch 1/49
----------
Epoch 1/49
----------
train Loss: 0.2258 Acc: 0.9317498803138733
----------
Epoch 1/49
----------
valid Loss: 0.1807 Acc: 0.9701107740402222
----------
Epoch 2/49
----------
Epoch 2/49
----------
train Loss: 0.1574 Acc: 0.9317498207092285
----------
Epoch 2/49
----------
valid Loss: 0.1979 Acc: 0.9701799750328064
----------
Epoch 3/49
----------
Epoch 3/49
----------
train Loss: 0.1378 Acc: 0.9317498207092285
----------
Epoch 3/49
----------
valid Loss: 0.1608 Acc: 0.9710049629211426
----------
Epoch 4/49
----------
Epoch 4/49
----------
train Loss: 0.1291 Acc: 0.9317498207092285
----------
Epoch 4/49
----------
valid Loss: 0.1298 Acc: 0.97055983543396
----------
Epoch 5/49
----------
Epoch 5/49
----------
train Loss: 0.1228 Acc: 0.93

KeyboardInterrupt: 