In [None]:
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import os
import numpy as np
import cv2 as cv
import matplotlib.pyplot as plt

from torchvision import models
from torchvision.models.vgg import VGG

# Device used for the model and every batch in this notebook.
# 'cuda:1' is the original hard-coded choice; fall back to the first GPU or
# the CPU so the notebook still runs on machines with fewer than two GPUs.
if torch.cuda.is_available() and torch.cuda.device_count() > 1:
    device = 'cuda:1'
elif torch.cuda.is_available():
    device = 'cuda:0'
else:
    device = 'cpu'

In [None]:
# just a test: read a single sample as a one-channel (grayscale) array
# NOTE(review): this path lives under /disk1 while the dataset cell reads
# from /disk2 -- confirm which location is current.

image = cv.imread(
    "/disk1/jklu/ClearSpineData/CAT_only_256_512/images/00001_AP.png",
    cv.IMREAD_GRAYSCALE,
)

In [None]:
# self-defined dataset

class SpineDataSet(torch.utils.data.Dataset):
    """In-memory dataset of grayscale images, landmark coordinates and heat maps.

    ``path`` must contain an ``images/`` and a ``txt/`` directory. File names
    in both directories must start with a five-digit integer; each listing is
    sorted by that integer and images and annotations are paired by position.

    Every item is the tuple ``(image, label, heatmap)``:

    * ``image``   -- float32 tensor of shape ``(1, rows, cols)``.
    * ``label``   -- float64 numpy array holding the whitespace-separated
      numbers of the annotation file, read as flat ``x, y`` pairs.
    * ``heatmap`` -- float64 numpy array of shape ``(n_landmarks, 512, 256)``
      produced by :meth:`generate_heatmap` with ``sigma=8``.
    """

    def __init__(self, path):
        """Load every image and annotation under ``path`` into memory.

        Raises:
            IOError: if an image file cannot be decoded by ``cv.imread``.
            AssertionError: if the number of images and annotations differ.
        """
        image_dir = os.path.join(path, "images")
        txt_dir = os.path.join(path, "txt")

        # Sort numerically on the five-digit prefix so both lists line up.
        image_files = sorted(os.listdir(image_dir), key=lambda name: int(name[:5]))
        txt_files = sorted(os.listdir(txt_dir), key=lambda name: int(name[:5]))

        self.images = []
        self.labels = []
        self.heatmaps = []

        for image_name in image_files:
            image_path = os.path.join(image_dir, image_name)
            image = cv.imread(image_path, cv.IMREAD_GRAYSCALE)
            if image is None:
                # cv.imread reports failure by returning None, not by raising.
                raise IOError("could not read image: " + image_path)

            # Add the leading channel axis: (rows, cols) -> (1, rows, cols).
            image = np.reshape(image, (1, image.shape[0], image.shape[1]))
            self.images.append(torch.from_numpy(image).type(torch.FloatTensor))

        for txt_name in txt_files:
            # `with` guarantees the file is closed even if parsing fails.
            with open(os.path.join(txt_dir, txt_name), 'r') as src_txt:
                # np.float was removed from NumPy; np.float64 is the same dtype.
                points = np.array(src_txt.read().split()).astype(np.float64)

            self.labels.append(points)
            # Heat maps are generated for a fixed 256 (width) x 512 (height) canvas.
            self.heatmaps.append(self.generate_heatmap(256, 512, points, 8))

        assert len(self.images) == len(self.labels), \
            "number of images and annotation files differ"

    def __getitem__(self, index):
        """Return ``(image, label, heatmap)`` for the sample at ``index``."""
        img = self.images[index]
        label = self.labels[index]
        heatmap = self.heatmaps[index]
        return img, label, heatmap

    def __len__(self):
        """Number of loaded samples."""
        return len(self.images)

    def generate_heatmap(self, width, height, coordinate_array, sigma):
        """Render one Gaussian blob per landmark.

        Args:
            width, height: size of each heat map in pixels.
            coordinate_array: flat sequence ``[x0, y0, x1, y1, ...]``; a
                trailing unpaired value is ignored.
            sigma: integer standard deviation of the Gaussian, in pixels.

        Returns:
            float64 array of shape ``(len(coordinate_array) // 2, height, width)``.
            The background is -128. Inside the square window of half-width
            ``3 * sigma`` centred on ``(int(x), int(y))`` the value is
            ``256 * exp(-d^2 / (2 * sigma^2)) - 128`` where ``d`` is the
            distance to the exact (non-truncated) landmark position. Parts of
            the window that fall outside the map are dropped.
        """
        landmark_num = len(coordinate_array) // 2
        heat_map = np.full((landmark_num, height, width), -128.0, dtype=np.float64)
        radius = 3 * sigma

        for index in range(landmark_num):
            x = coordinate_array[2 * index]
            y = coordinate_array[2 * index + 1]

            # Window bounds clipped to the map; the upper bounds are exclusive.
            x_lo = max(int(x) - radius, 0)
            x_hi = min(int(x) + radius + 1, width)
            y_lo = max(int(y) - radius, 0)
            y_hi = min(int(y) + radius + 1, height)
            if x_lo >= x_hi or y_lo >= y_hi:
                # The whole window lies outside the map: leave the background.
                continue

            xs = np.arange(x_lo, x_hi)
            ys = np.arange(y_lo, y_hi)
            # Broadcast to a (rows, cols) grid of squared distances.
            sq_dist = (x - xs[np.newaxis, :]) ** 2 + (y - ys[:, np.newaxis]) ** 2
            heat_map[index, y_lo:y_hi, x_lo:x_hi] = (
                256 * np.exp(-sq_dist / (2 * sigma ** 2)) - 128
            )

        return heat_map
            

In [None]:
# initialize the train and test data loader

dataset = SpineDataSet("/disk2/jklu/ClearSpineData/CAT_only_256_512")

batch = 4
train_num = 400
# random_split requires the lengths to add up to len(dataset), so hold out
# whatever is left after the training split instead of hard-coding the count.
test_num = len(dataset) - train_num

train_data, test_data = torch.utils.data.random_split(dataset, [train_num, test_num])

train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch,
                                         shuffle=True, num_workers=2)

# No shuffling for evaluation: the reported test loss is a mean of per-batch
# means, so reshuffling would change it from call to call when the last batch
# is smaller than the others.
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch,
                                         shuffle=False, num_workers=2)


print("number of batches(training): ", len(train_loader))
print("number of batches(testing): ",len(test_loader))

In [None]:
# For every supported VGG variant: the (begin, end) index ranges into the
# `features` Sequential that VGGNet.forward runs as its five stages.
ranges = dict(
    vgg11=((0, 3), (3, 6), (6, 11), (11, 16), (16, 21)),
    vgg13=((0, 5), (5, 10), (10, 15), (15, 20), (20, 25)),
    vgg16=((0, 5), (5, 10), (10, 17), (17, 24), (24, 31)),
    vgg19=((0, 5), (5, 10), (10, 19), (19, 28), (28, 37)),
)

# Vgg-Net config
# Layer specification per variant, consumed by make_layers: an integer is the
# output channel count of a 3x3 convolution, 'M' is a 2x2 max-pooling layer.
cfg = dict(
    vgg11=[64, 'M', 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    vgg13=[64, 64, 'M', 128, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M'],
    vgg16=[64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M',
           512, 512, 512, 'M', 512, 512, 512, 'M'],
    vgg19=[64, 64, 'M', 128, 128, 'M', 256, 256, 256, 256, 'M',
           512, 512, 512, 512, 'M', 512, 512, 512, 512, 'M'],
)

# make layers using Vgg-Net config(cfg)
# (builds the VGG feature stack from a cfg layer list)
def make_layers(cfg, batch_norm=False, in_channels=1):
    """Build the convolutional feature stack described by ``cfg``.

    Args:
        cfg: iterable of layer specs. ``'M'`` adds a 2x2/stride-2 max-pool;
            any other entry is taken as the output channel count of a
            3x3, padding-1 convolution followed by an in-place ReLU.
        batch_norm: when true, insert ``BatchNorm2d`` between each
            convolution and its ReLU.
        in_channels: number of channels of the network input. Defaults to 1,
            the value that was previously hard-coded here.

    Returns:
        ``nn.Sequential`` containing the layers in ``cfg`` order.
    """
    layers = []
    for v in cfg:
        if v == 'M':
            layers.append(nn.MaxPool2d(kernel_size=2, stride=2))
            continue

        layers.append(nn.Conv2d(in_channels, v, kernel_size=3, padding=1))
        if batch_norm:
            layers.append(nn.BatchNorm2d(v))
        layers.append(nn.ReLU(inplace=True))
        # The next convolution consumes what this one produced.
        in_channels = v
    return nn.Sequential(*layers)


class VGGNet(VGG):
    """VGG feature extractor that exposes the activation after each stage.

    Args:
        pretrained: load the torchvision weights for ``model``.
        model: key into ``cfg`` / ``ranges`` ('vgg11', 'vgg13', 'vgg16', 'vgg19').
        requires_grad: when false, freeze every parameter.
        remove_fc: delete the fully-connected ``classifier`` head, which
            ``forward`` never uses.
        show_params: print the name and size of every remaining parameter.
    """

    def __init__(self, pretrained=False, model='vgg16', requires_grad=True, remove_fc=True, show_params=False):
        super().__init__(make_layers(cfg[model]))
        self.ranges = ranges[model]

        if pretrained:
            # Look the constructor up by name rather than exec()-ing a string.
            state = getattr(models, model)(pretrained=True).state_dict()

            # make_layers builds a 1-channel first convolution by default,
            # while the downloaded weights may have more input channels; a
            # plain load_state_dict would then fail on the shape mismatch.
            # Summing the filters over the input-channel axis gives the same
            # response as feeding the gray image replicated on every channel.
            first_conv = "features.0.weight"
            if self.features[0].in_channels == 1 and state[first_conv].shape[1] != 1:
                state[first_conv] = state[first_conv].sum(dim=1, keepdim=True)

            self.load_state_dict(state)

        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

        # delete redundant fully-connected layer params, can save memory
        # (removes the final fully-connected layers, i.e. the VGG classifier)
        if remove_fc:
            del self.classifier

        if show_params:
            for name, param in self.named_parameters():
                print(name, param.size())

    def forward(self, x):
        """Return ``{'x1': ..., ..., 'x5': ...}``, the output of each stage.

        Each stage is the slice of ``self.features`` given by one
        ``(begin, end)`` pair in ``self.ranges``; 'x1' is the first stage.
        """
        output = {}
        # get the output of each maxpooling layer (5 maxpool in VGG net)
        # e.g. vgg16: self.ranges = ((0, 5), (5, 10), (10, 17), (17, 24), (24, 31))
        for idx, (begin, end) in enumerate(self.ranges):
            for layer in range(begin, end):
                x = self.features[layer](x)
            output["x%d" % (idx + 1)] = x

        return output
    
    

class FCNs(nn.Module):
    """Upsampling head on top of a multi-stage backbone.

    ``pretrained_net`` is called on the input and must return a mapping with
    the keys 'x1' .. 'x5'. Starting from 'x5', five stride-2 transposed
    convolutions (each followed by ReLU and batch norm) double the spatial
    size step by step; after each of the first four steps the matching
    backbone feature ('x4', 'x3', 'x2', 'x1') is added. A final 1x1
    convolution maps the 32 remaining channels to ``n_class`` channels.
    """

    def __init__(self, pretrained_net, n_class):
        super().__init__()
        self.n_class = n_class
        self.pretrained_net = pretrained_net
        self.relu = nn.ReLU(inplace=True)

        def upsample(c_in, c_out):
            # 3x3 transposed conv whose stride/padding exactly doubles H and W
            return nn.ConvTranspose2d(c_in, c_out, kernel_size=3, stride=2,
                                      padding=1, dilation=1, output_padding=1)

        # Attribute names and creation order are kept as deconvN / bnN.
        self.deconv1, self.bn1 = upsample(512, 512), nn.BatchNorm2d(512)
        self.deconv2, self.bn2 = upsample(512, 256), nn.BatchNorm2d(256)
        self.deconv3, self.bn3 = upsample(256, 128), nn.BatchNorm2d(128)
        self.deconv4, self.bn4 = upsample(128, 64), nn.BatchNorm2d(64)
        self.deconv5, self.bn5 = upsample(64, 32), nn.BatchNorm2d(32)
        # 1x1 conv: reduces the 32 channels to n_class
        self.classifier = nn.Conv2d(32, n_class, kernel_size=1)

    def forward(self, x):
        feats = self.pretrained_net(x)
        x5, x4, x3, x2, x1 = [feats[key] for key in ('x5', 'x4', 'x3', 'x2', 'x1')]

        # (transposed conv, batch norm, skip feature added afterwards)
        skip_stages = (
            (self.deconv1, self.bn1, x4),
            (self.deconv2, self.bn2, x3),
            (self.deconv3, self.bn3, x2),
            (self.deconv4, self.bn4, x1),
        )

        score = x5
        for deconv, bn, skip in skip_stages:
            score = bn(self.relu(deconv(score)))
            score = score + skip

        # The last upsampling step has no skip connection.
        score = self.bn5(self.relu(self.deconv5(score)))
        return self.classifier(score)


In [None]:
# Backbone with trainable, non-pretrained weights; parameter names and sizes
# are printed while it is built.
vgg_model = VGGNet(requires_grad=True, show_params=True)

# Full network with four output channels, moved to the selected device.
# Alternative: restore a previously saved model instead of building a new one.
#net = torch.load('/disk1/jklu/models/VGG-FCN.pth')
net = FCNs(pretrained_net=vgg_model, n_class=4).to(device)

In [None]:
# Visualize the data: pull one training batch and check the network output shape

#heatmap = dataset.generate_heatmap(256, 512, [200, 300], 10)
#print(heatmap.shape)
#plt.imshow(heatmap[0])

# get some random training images
dataiter = iter(train_loader)
# Use the built-in next(); DataLoader iterators no longer expose a .next() method.
images, labels, heatmaps = next(dataiter)

images = images.to(device)
labels = labels.to(device)

# Shape check only, so no autograd graph is needed.
with torch.no_grad():
    out = net(images)
print(out.shape)

# show images
#plt.imshow(labels[0, 0] + labels[0, 1] + labels[0, 2] + 3*labels[0, 3])
#plt.imshow(images[0].numpy().squeeze(0))

In [None]:
# here define loss function

import torch.optim as optim

# Mean-squared error between network output and target heat maps.
criterion = nn.MSELoss()
criterion = criterion.to(device)

# Adam over every parameter of the network (backbone included).
optimizer = optim.Adam(params=net.parameters(), lr=1e-3)

# Learning-rate schedulers that were tried; currently disabled.
#lr_func = lambda epoch: epoch * 1
#scheduler = optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_func)
#scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,mode='min',factor=0.98)

In [None]:
# function to compute loss

def compute_loss(net, data_loader, loss_fn=None, target_device=None):
    """Average per-batch loss of ``net`` over ``data_loader`` without training.

    Args:
        net: the model to evaluate.
        data_loader: sized iterable yielding ``(images, labels, heatmaps)``
            batches; the labels are ignored.
        loss_fn: loss callable; defaults to the notebook-level ``criterion``.
        target_device: device the batches are moved to; defaults to the
            notebook-level ``device``.

    Returns:
        Mean of the per-batch loss values as a float, or NaN when the loader
        yields no batches.

    The model is switched to eval mode and run under ``torch.no_grad()`` so
    that evaluation neither updates batch-norm running statistics nor builds
    autograd graphs. The previous train/eval mode is restored afterwards.
    """
    if loss_fn is None:
        loss_fn = criterion
    if target_device is None:
        target_device = device

    num_batches = len(data_loader)
    if num_batches == 0:
        # Nothing to average; avoid a ZeroDivisionError.
        return float("nan")

    was_training = net.training
    net.eval()
    loss_sum = 0.0
    try:
        with torch.no_grad():
            for images, _, heatmaps in data_loader:
                images = images.to(target_device)
                heatmaps = heatmaps.to(target_device)
                outputs = net(images)

                loss_sum += loss_fn(outputs.float(), heatmaps.float()).item()
    finally:
        # Hand the model back in the mode the caller left it in.
        net.train(was_training)

    return loss_sum / num_batches

In [None]:
# training

epoch_num = 70

# loss1: mean training loss of each epoch; loss2: test loss of each epoch
loss1 = []
loss2 = []


for epoch in range(epoch_num):  # one pass over the training loader per epoch
    # The test loss is taken before this epoch's weight updates.
    test_loss = compute_loss(net, test_loader)
    loss2.append(test_loss)

    loss_sum = 0
    for inputs, _, heatmaps in train_loader:
        # move the batch to the training device (labels are not needed here)
        inputs, heatmaps = inputs.to(device), heatmaps.to(device)

        # clear gradients accumulated by the previous step
        optimizer.zero_grad()

        # forward pass and loss on float32 tensors
        outputs = net(inputs)
        loss = criterion(outputs.float(), heatmaps.float())

        # backward pass and parameter update
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()

    #scheduler.step(loss.item())

    # Running average over the epoch rather than a second pass:
    #train_loss = compute_loss(net, train_loader)
    train_loss = loss_sum / len(train_loader)
    loss1.append(train_loss)

    print("epoch number", epoch+1, "train_loss", train_loss, "test_loss", test_loss)

print('Finished Training')

In [None]:
# Plot both loss curves against the 1-based epoch number and save the figure.
epoch_axis = range(1, epoch_num+1)

ax = plt.gca()
for curve, curve_name in ((loss1, "training loss"), (loss2, "testing loss")):
    ax.plot(epoch_axis, curve, label=curve_name)

ax.set_xlabel("epoch number")
#plt.xticks(np.linspace(1,epoch_num,epoch_num))
ax.set_ylabel("loss")
ax.set_title("loss graph")
ax.legend()

# Save before show() so the written file contains the drawn figure.
plt.savefig("./VGG-FCN.jpg")
plt.show()

In [None]:
# Serialize the whole model object (not only its state_dict); the commented
# torch.load call in the model-construction cell reads this same file.
torch.save(obj=net, f='/disk1/jklu/models/VGG-FCN.pth')

In [None]:
def CVimshow2pltimshow(cv_img):
    """
    Convert an OpenCV-ordered (BGR) image into an RGB integer image for plt.imshow.

    cv_img: [height, width, 3], BGR, numpy array, values 0-255
    returns: [height, width, 3], RGB, new numpy array with the default
             integer dtype (fractional values are truncated)
    raises: ValueError if cv_img is not a 3-channel image
    """
    if cv_img.ndim != 3 or cv_img.shape[2] != 3:
        raise ValueError("expected an image of shape [height, width, 3]")
    # Reversing the channel axis turns BGR into RGB. `int` replaces the
    # removed `np.int` alias and names the same dtype.
    return cv_img[:, :, ::-1].astype(int)

In [None]:
# visualize result (heat map):
# left column of each pair: input image with the annotated segments drawn on it
# right column of each pair: sum of the predicted heat-map channels

import matplotlib as mpl
mpl.rcParams['figure.figsize'] = (15, 15)

dataiter = iter(test_loader)
# Use the built-in next(); DataLoader iterators no longer expose a .next() method.
images, labels, heatmaps = next(dataiter)
# Separate name so the global batch size `batch` is not overwritten.
num_shown = images.shape[0]

# Inference only: use eval-mode layers, skip autograd, then restore the mode.
was_training = net.training
net.eval()
with torch.no_grad():
    ret = net(images.to(device)).to('cpu').numpy()
net.train(was_training)

# Two panels per sample, four panels per row (two rows for up to four samples).
n_rows = max(2, (num_shown + 1) // 2)

for i in range(num_shown):
    sample = images[i].numpy().squeeze(0)
    sample_BGR = cv.cvtColor(sample, cv.COLOR_GRAY2BGR)

    # Every four consecutive label values are drawn as one segment
    # (x1, y1) -> (x2, y2).
    label = labels[i]
    for j in range(len(label) // 4):
        cv.line(sample_BGR, (int(label[4*j]), int(label[4*j+1])),
                (int(label[4*j+2]), int(label[4*j+3])), (255, 0, 0), 1)
    sample_RGB = CVimshow2pltimshow(sample_BGR)
    plt.subplot(n_rows, 4, 2*i+1)
    plt.imshow(sample_RGB)

    # Collapse all predicted channels into a single map for display.
    img = ret[i].sum(axis=0)

    plt.subplot(n_rows, 4, 2*i+2)
    plt.imshow(img)

plt.show()

In [None]:
# Scratch check of cv.normalize min-max scaling on a tiny array.
# float32 source: the default integer dtype of np.array (int64 on most
# platforms) is not an element type cv.normalize accepts.
img = np.array([[1,2,3,4],[5,6,7,8]], dtype=np.float32)

# Passing None lets OpenCV allocate the destination. The result is a plain
# ndarray, which has no .get() method, so it is used directly.
norm_heatmap = cv.normalize(img, None, 0, 256, norm_type=cv.NORM_MINMAX)

print(norm_heatmap)