# Novel Data: Training PFENet

For a description of how the [PFENet](https://github.com/dvlab-research/PFENet) works, please refer to [Understanding PFENet](./Understanding%20PFENet.ipynb).

In [1]:
import matplotlib.pyplot as plt
import numpy as np
import cv2
import os
from util import dataset
from util import transform

def render(img):
    I = np.array(img)
    I = I - I.min((1,2), keepdims=1)
    I = I / I.max((1,2), keepdims=1)
    I = I.transpose((1,2, 0))
    return I
    
def overlay(img, mask, chan=1, weights=[0.5, 0.5]):
    T = np.zeros_like(img)
    T[..., chan] = mask
    return img * weights[0] + T * weights[1]

def val_transform(img, label, __cache={}):
    if not __cache:
        value_scale = 255
        mean = [0.485, 0.456, 0.406]
        mean = [item * value_scale for item in mean]
        std = [0.229, 0.224, 0.225]
        std = [item * value_scale for item in std]
        __cache['val_transform'] = transform.Compose([
            transform.test_Resize(size=473),
            transform.ToTensor(),
            transform.Normalize(mean=mean, std=std),
        ])
    f = __cache['val_transform']
    return f(img, label)

def fetch_ivus_triplet(fi, fm, fl):
    with open(fl) as f:
        pts_lum = np.array([
            list(map(float, line.strip().split(',')))
            for line in f.readlines()
        ], np.int32)
        pts_lum = pts_lum.reshape((-1, 1, 2))

    with open(fm) as f:
        pts_med = np.array([
            list(map(float, line.strip().split(',')))
            for line in f.readlines()
        ], np.int32)
        pts_med = pts_med.reshape((-1, 1, 2))
    
    img = cv2.imread(fi, cv2.IMREAD_COLOR)
    img = img[..., [0, 0, 0]]
    

    # Inside of lumen
    lab_lumen = np.zeros(img.shape[:-1])
    cv2.fillPoly(lab_lumen, [pts_lum], 1)
    
    # Inside of media
    lab_media = np.zeros_like(lab_lumen)
    cv2.fillPoly(lab_media, [pts_med], 1)
    #lab_media -= lab_lumen # remove points in lumen
    
    # Border of lumen
    lab_lumen = cv2.polylines(lab_lumen, [pts_lum], True, 255, 2)
    
    # Border of media
    #lab_media = cv2.polylines(lab_media, [pts_lum], True, 255, 2)
    lab_media = cv2.polylines(lab_media, [pts_med], True, 255, 2)
    
    img = cv2.resize(img, (473, 473))
    lab_media = cv2.resize(lab_media, (473, 473))
    lab_lumen = cv2.resize(lab_lumen, (473, 473))
    lab_media = np.where((lab_media == 1) | (lab_media == 255), lab_media, 0)
    lab_lumen = np.where((lab_lumen == 1) | (lab_lumen == 255), lab_lumen, 0)
    
    # Use the model transforms
    _  , lab_media = val_transform(img, lab_media)
    img, lab_lumen = val_transform(img, lab_lumen)

    return img.numpy(), lab_media.numpy(), lab_lumen.numpy()

def load_data(location='../dataset/Data_set_B', batch=None, batchsize=10):

    import os
    import gc
    from pathlib import Path
    LOCATION = Path(location)

    lum_files = sorted([f for f in os.listdir(LOCATION / 'LABELS_obs2_v2/') if 'lum' in f])
    med_files = sorted([f for f in os.listdir(LOCATION / 'LABELS_obs2_v2/') if 'med' in f])

    gc.collect()
    torch.cuda.empty_cache() # just in case
    images = []
    medias = []
    lumens = []
    for i, (luf, mef) in enumerate(zip(lum_files, med_files)):
        if batch is not None:
            if i < batch * batchsize:
                continue
            if i >= (batch + 1) * batchsize:
                break
        assert luf[3:] == mef[3:], (luf, mef)
        _, _, patient, frame, _ = luf.split('_')
        imf = LOCATION/ f'DCM/frame_{patient}_{frame}_003.png'
        luf = os.path.join(LOCATION / 'LABELS_obs2_v2/', luf)
        mef = os.path.join(LOCATION / 'LABELS_obs2_v2/', mef)
        imf, luf, mef = map(str, [imf, luf, mef])
        assert os.path.exists(imf), imf
        img, med, lum = fetch_ivus_triplet(imf, mef, luf)
        images.append(img)
        medias.append(med)
        lumens.append(lum)
    images = np.stack(images)
    medias = np.stack(medias)
    medias = np.where(medias == 1, 2, medias)
    lumens = np.stack(lumens)
    torch.cuda.empty_cache()
    return images, medias, lumens

In [2]:
from model.PFENet import PFENet
import torch
from torch import nn
SHOT = 1
def get_model(shot, checkpoint='exp/ivus/split0_resnet50/model/final.pth'):
    torch.cuda.empty_cache()
    model = PFENet(layers=50, classes=2, zoom_factor=8, \
        criterion=nn.CrossEntropyLoss(ignore_index=255), BatchNorm=nn.BatchNorm2d, \
        pretrained=True, shot=shot, ppm_scales=[60, 30, 15, 8])

    checkpoint = torch.load(checkpoint)
    model = torch.nn.DataParallel(model.cuda())
    model.load_state_dict(checkpoint['state_dict'])
    model.train(mode=False)
    print('Allocated:', torch.cuda.memory_allocated())
    return model

In [3]:
def show3(x, y, p, imshowargs=None, **kwargs):
    _imshowargs = dict(cmap='gray')
    if imshowargs:
        _imshowargs.update(imshowargs)
    
    fig, ax = plt.subplots(1, 3, **kwargs)
    
    for a in ax:
        a.set_yticks([])
        a.set_xticks([])
    
    ax[0].set_title('Query')
    ax[0].imshow(x, **_imshowargs)
    
    ax[1].set_title('Target')
    ax[1].imshow(y, **_imshowargs)
    
    ax[2].set_title('Prediction')
    ax[2].imshow(p, **_imshowargs)
    
    return fig, ax

In [5]:
import os
OUTDIR = 'predictions'
os.makedirs(OUTDIR, exist_ok=True)

In [7]:
images, medias, lumens = load_data()
#torch.device('cpu')
model = get_model(shot=1)

[60, 30, 15, 8]
INFO: Using ResNet 50
Allocated: 327864832


**Warning!** The following cells are very resource intensive, if you dont have a *huge* graphics card it will probably crash. Consider splitting the training in sub-datasets.

In [8]:
x = torch.tensor(images[:, ...])
y = torch.tensor(medias[:, ...])
s_x = torch.tile(torch.tensor(images[:1, ...]), (len(x), 1, 1, 1, 1))
s_y = torch.tile(torch.tensor(medias[:1, ...]), (len(x), 1, 1, 1))

output1 = model(x=x, s_x=s_x, s_y=s_y).cpu().detach().numpy()
#torch.cuda.empty_cache()

In [10]:
for i, p in enumerate(output.argmax(1)):
    num = str(i + n).zfill(4)
    labname = f'lumen-{num}.png' # we start at the 1st image and
    print('Saving', labname)
    plt.imshow(p)
    plt.show()
    #assert cv2.imwrite(os.path.join(OUTDIR, labname), p)    

NameError: name 'output' is not defined