In [None]:
%matplotlib inline

import os

import torch
from PIL import Image

# Global switch: move models and batches to CUDA whenever a GPU is visible.
IS_GPU = bool(torch.cuda.is_available())

In [None]:
# TODO: show sample images (this cell is currently empty)

In [None]:
class DDSM(torch.utils.data.Dataset):
    """Image-classification dataset described by a plain-text label file.

    Each non-blank line of ``label_file`` has the form
    ``<image path relative to root> <integer label>``. Images are opened with
    PIL, converted to RGB and passed through ``transform``.

    Args:
        root: Directory that ``label_file`` and the image paths are relative to.
        label_file: Label file path, relative to ``root``.
        transform: Callable applied to each RGB PIL image. If ``None``, the
            PIL image is returned unchanged.

    Raises:
        ValueError: If a non-blank line does not contain a path and an
            integer label.
    """

    def __init__(self, root, label_file, transform=None):
        self.root = root
        self.transform = transform
        # (relative image path, int label) pairs, in label-file order.
        self.image_list = []

        with open(os.path.join(root, label_file), 'r') as f:
            for line_no, line in enumerate(f, start=1):
                line = line.strip()
                if not line:
                    # Tolerate blank lines, e.g. a trailing newline at EOF.
                    continue
                self.image_list.append(self._parse_line(line, line_no))

    @staticmethod
    def _parse_line(line, line_no):
        """Split ``'<path> <label>'`` into ``(path, int(label))``."""
        # Split on the LAST whitespace run so paths containing spaces still
        # work; any whitespace (spaces or tabs) is accepted as the separator.
        parts = line.rsplit(None, 1)
        if len(parts) != 2:
            raise ValueError(
                'Malformed label line {}: {!r}'.format(line_no, line))
        img_name, label = parts
        return img_name, int(label)

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        img_name, label = self.image_list[idx]
        # The context manager closes the source file once the RGB copy exists.
        with Image.open(os.path.join(self.root, img_name)) as raw:
            img = raw.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, label

In [None]:
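# Compute the per-channel mean and std of the training images once, in advance
# (the resulting values are hard-coded in `transforms.Normalize` in the dataloader cell).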

def cal_mean_std(data_root=None, label_file='images/dataset-v2/train.txt',
                 batch_size=1024, num_workers=4):
    """Compute the exact per-channel pixel mean and std of a DDSM split.

    Running sums of x and x**2 are accumulated over every pixel of every
    image. The result is therefore the true statistic of the whole split.
    Averaging per-batch values would not be: it over-weights the smaller last
    batch, and the mean of per-batch stds is not the dataset std.

    Args:
        data_root: Dataset root. Defaults to the notebook-level ``root``.
        label_file: Label file, relative to ``data_root``.
        batch_size: DataLoader batch size.
        num_workers: DataLoader worker processes.

    Returns:
        ``(mean, std)`` as NumPy arrays with one entry per channel. Both are
        also printed. The std is the population std (ddof=0, like ``np.std``).

    Raises:
        ValueError: If the split contains no images.
    """
    import numpy as np
    import torchvision.transforms as transforms

    if data_root is None:
        data_root = root  # notebook-level dataset root from the dataloader cell

    dataset = DDSM(data_root, label_file, transforms.Compose([
        transforms.Resize(224),
        # Resize(224) only fixes the shorter side; the crop makes every tensor
        # 224x224 so images of different aspect ratios can be batched.
        transforms.CenterCrop(224),
        transforms.ToTensor(),
    ]))
    loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=True)

    channel_sum = None
    channel_sq_sum = None
    pixel_count = 0

    for images, _ in loader:
        # float64 avoids precision loss when summing very many pixel values.
        images = images.double()
        batch_sum = images.sum(dim=(0, 2, 3))
        batch_sq_sum = (images * images).sum(dim=(0, 2, 3))
        if channel_sum is None:
            channel_sum, channel_sq_sum = batch_sum, batch_sq_sum
        else:
            channel_sum += batch_sum
            channel_sq_sum += batch_sq_sum
        pixel_count += images.size(0) * images.size(2) * images.size(3)

    if pixel_count == 0:
        raise ValueError('No images found in {}'.format(label_file))

    mean = channel_sum / pixel_count
    # E[x^2] - E[x]^2; clamp tiny negative values caused by rounding.
    var = (channel_sq_sum / pixel_count - mean * mean).clamp(min=0.0)
    data_mean = mean.numpy()
    data_std = np.sqrt(var.numpy())

    print('mean: {}'.format(data_mean))
    print('std: {}'.format(data_std))
    return data_mean, data_std

In [None]:
# Dataloader

import os
import torchvision.transforms as transforms

root = './'

# Per-channel statistics of the training split, computed with cal_mean_std.
# They are identical on every channel, presumably because .convert('RGB')
# replicates single-channel images -- TODO confirm.
DATA_MEAN = [0.504, 0.504, 0.504]
DATA_STD = [0.172, 0.172, 0.172]
normalize = transforms.Normalize(mean=DATA_MEAN, std=DATA_STD)

# Evaluation preprocessing. Resize(224) only fixes the shorter side, so
# CenterCrop(224) guarantees square tensors that can always be batched.
# For square inputs the crop is a no-op.
transform = transforms.Compose([
    transforms.Resize(224),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    normalize,
])

# Training preprocessing: random flips, arbitrary rotation and a mild random crop.
augmented_transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(360),
    transforms.RandomResizedCrop(224, scale=(0.85, 1.0)),
    transforms.ToTensor(),
    normalize,
])

BATCH_SIZE = 32
NUM_WORKERS = 4

train_dataset = DDSM(root, 'images/dataset-v2/train.txt', augmented_transform)
# Pinned memory only speeds up host-to-GPU copies, so enable it only with a GPU.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                                           num_workers=NUM_WORKERS, pin_memory=IS_GPU)

test_dataset = DDSM(root, 'images/dataset-v2/test.txt', transform)
# Evaluation metrics are order-independent counts; a fixed order keeps runs reproducible.
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False,
                                          num_workers=NUM_WORKERS, pin_memory=IS_GPU)

In [None]:
# Number of batches per epoch: training loader first, then test loader.
print(*map(len, (train_loader, test_loader)), sep='\n')

In [None]:
# Baseline CNN

import torch.nn as nn
import torch.nn.functional as F

class BaseNet(nn.Module):
    """Small baseline CNN producing 2 class logits.

    Three [Conv(k3, s2, p1) -> BatchNorm -> ReLU -> MaxPool(2)] stages shrink
    a 224x224 input as 224 -> 112 -> 56 -> 28 -> 14 -> 7 -> 3. That leaves
    32 x 3 x 3 = 288 features, which an MLP maps to 2 logits. The hard-coded
    288 therefore assumes 3x224x224 inputs.
    """

    def __init__(self):
        super(BaseNet, self).__init__()
        self.conv_net = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(num_features=32),
            nn.ReLU(inplace=True),

            nn.MaxPool2d(2, 2),

            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(num_features=32),
            nn.ReLU(inplace=True),

            nn.MaxPool2d(2, 2),

            nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(num_features=32),
            nn.ReLU(inplace=True),

            nn.MaxPool2d(2, 2),
        )

        # Without the ReLUs the three Linear layers would compose into a single
        # affine map, making the hidden layers pointless.
        self.fc_net = nn.Sequential(
            nn.Linear(in_features=288, out_features=512),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=512, out_features=128),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=128, out_features=2),
        )

    def forward(self, x):
        """Map a batch of images (N, 3, 224, 224) to logits (N, 2)."""
        x = self.conv_net(x)
        x = x.view(x.size(0), -1)  # flatten to (N, 288)
        x = self.fc_net(x)

        return x


In [None]:
# Baseline CNN

def get_basenet():
    """Create a BaseNet, move it to CUDA when IS_GPU is set, and return it in eval mode."""
    net = BaseNet()
    net = net.cuda() if IS_GPU else net
    net.eval()
    return net

In [None]:
# AlexNet

def get_alexnet(pretrained=False):
    """Build a torchvision AlexNet whose last classifier layer outputs 2 classes.

    Args:
        pretrained: Forwarded to torchvision to load its pretrained weights
            before the final layer is replaced.

    Returns:
        The model in eval mode. When IS_GPU is set it is on CUDA, with its
        ``features`` sub-network wrapped in ``DataParallel``.
    """
    import torchvision.models as models
    import torch.nn as nn

    model = models.alexnet(pretrained=pretrained)

    # Swap the final classifier layer for a 2-way one, keeping its input width.
    model.classifier[6] = nn.Linear(model.classifier[6].in_features, 2)

    if IS_GPU:
        # Only the convolutional features are wrapped in DataParallel.
        model.features = torch.nn.DataParallel(model.features)
        model.cuda()

    model.eval()

    return model

In [None]:
# ResNet-18

def get_resnet(pretrained=False):
    """Build a torchvision ResNet-18 with a 2-class fully connected head.

    Args:
        pretrained: Forwarded to torchvision's ``resnet18``.

    Returns:
        The model in eval mode. When IS_GPU is set, the whole network is
        wrapped in ``DataParallel`` and moved to CUDA.
    """
    import torch.nn as nn
    import torchvision.models as models

    net = models.resnet18(pretrained=pretrained)
    # Keep the backbone; replace only the final layer with a 2-way classifier.
    net.fc = nn.Linear(net.fc.in_features, 2)

    if IS_GPU:
        net = torch.nn.DataParallel(net)
        net.cuda()

    net.eval()
    return net

In [None]:
# Define optimizer

def get_optimizer(model, lr=0.01, momentum=0.9, weight_decay=0.0):
    """Return an SGD optimizer over all parameters of ``model``.

    Args:
        model: Network whose parameters are optimised.
        lr: Learning rate.
        momentum: SGD momentum factor.
        weight_decay: L2 penalty. The default 0.0 means no weight decay.

    Returns:
        A ``torch.optim.SGD`` instance. The loss function is not built here.
    """
    import torch.optim as optim

    return optim.SGD(model.parameters(), lr=lr, momentum=momentum,
                     weight_decay=weight_decay)

In [None]:
# Train Model

def train_model(model, optimizer, start_epoch=1, epoch_num=20, filename=None, save_model_name=None,
                criterion=None, train_data=None, test_data=None):
    """Train ``model`` for ``epoch_num`` epochs, evaluating after each one.

    Epochs are numbered ``start_epoch .. start_epoch + epoch_num - 1``. The
    numbers are only used in log output. Progress is printed after every
    epoch and, if ``filename`` is given, appended to ``./evaluation/<filename>``.

    Args:
        model: Network to train; updated in place.
        optimizer: Optimizer over ``model``'s parameters.
        start_epoch: Number of the first epoch, for logging.
        epoch_num: Number of epochs to run.
        filename: Optional log file name inside ``./evaluation/``.
        save_model_name: If given, the whole model is pickled to
            ``./models/<save_model_name>.pth`` after the last epoch.
        criterion: Loss function. Defaults to ``nn.CrossEntropyLoss()``.
        train_data: Training loader. Defaults to the notebook's ``train_loader``.
        test_data: Evaluation loader. Defaults to the notebook's ``test_loader``.
    """
    import os
    import time
    import torch.nn as nn

    if criterion is None:
        criterion = nn.CrossEntropyLoss()
    if train_data is None:
        train_data = train_loader
    if test_data is None:
        test_data = test_loader

    start_ts = time.time()

    best_test_acc = 0.0
    best_epoch = 0

    for epoch in range(start_epoch, start_epoch + epoch_num):
        epoch_loss, train_acc = _train_one_epoch(model, optimizer, criterion, train_data)

        print('Time: {}'.format(int(time.time() - start_ts)))

        # Evaluate with BatchNorm running statistics and dropout disabled.
        model.eval()
        test_acc, precision, recall, F1 = evaluate(model, test_data, epoch)

        if test_acc > best_test_acc:
            best_test_acc = test_acc
            best_epoch = epoch

        print('Epoch {} '.format(epoch) + 'tarin accuracy: ' + str(train_acc))
        print('Epoch {} '.format(epoch) + 'loss: ' + str(epoch_loss))
        print('Epoch {} '.format(best_epoch) + 'has the best test accuracy: ' + str(best_test_acc))
        print()

        if filename:
            _append_epoch_log(filename, epoch, best_epoch, train_acc, epoch_loss,
                              test_acc, precision, recall, F1, best_test_acc)

    if save_model_name:
        os.makedirs('./models', exist_ok=True)
        torch.save(model, './models/{}.pth'.format(save_model_name))


def _train_one_epoch(model, optimizer, criterion, loader):
    """Run one optimisation pass over ``loader``.

    Returns:
        ``(sum of per-batch losses, training accuracy)``. Accuracy is 0.0 for
        an empty loader.
    """
    # Training mode: BatchNorm uses batch statistics and dropout is active.
    model.train()

    epoch_loss = 0.0
    sample_count = 0
    correct_count = 0

    for images, labels in loader:
        if IS_GPU:
            images = images.cuda()
            labels = labels.cuda()

        optimizer.zero_grad()

        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        preds = outputs.detach().argmax(dim=1)

        epoch_loss += loss.item()
        sample_count += len(labels)
        correct_count += (preds == labels).sum().item()

    train_acc = correct_count / sample_count if sample_count else 0.0
    return epoch_loss, train_acc


def _append_epoch_log(filename, epoch, best_epoch, train_acc, epoch_loss,
                      test_acc, precision, recall, F1, best_test_acc):
    """Append one epoch's metrics to ``./evaluation/<filename>``, creating the directory if needed."""
    import os

    os.makedirs('./evaluation', exist_ok=True)
    with open('./evaluation/{}'.format(filename), 'a') as f:
        f.write('Evaluate epoch {} '.format(epoch) + 'accuracy: ' + str(test_acc) + '\n')
        f.write('Evaluate epoch {} '.format(epoch) + 'precision: ' + str(precision) + '\n')
        f.write('Evaluate epoch {} '.format(epoch) + 'recall: ' + str(recall) + '\n')
        f.write('Evaluate epoch {} '.format(epoch) + 'F1: ' + str(F1) + '\n')
        f.write('epoch {} '.format(epoch) + 'tarin accuracy: ' + str(train_acc) + '\n')
        f.write('epoch {} '.format(epoch) + 'loss: ' + str(epoch_loss) + '\n')
        f.write('epoch {} '.format(best_epoch) + 'has the best test acc: ' + str(best_test_acc) + '\n\n')
        

In [None]:
# Resume training for 100 more epochs, numbered 451-550 in the printed logs.
# NOTE(review): relies on `model` and `optimizer` already existing in the kernel.
train_model(model=model, optimizer=optimizer, epoch_num=100, start_epoch=451)

In [None]:
# ResNet-18 with pretrained=False (the get_resnet default), trained for epochs
# 1-400 and logged to ./evaluation/resnet-no-tl.txt.
model = get_resnet()
optimizer = get_optimizer(model)
train_model(model, optimizer, epoch_num=400, start_epoch=1, filename='resnet-no-tl.txt')

In [None]:
# Continue the same run for 50 more epochs, logged as epochs 601-650.
# NOTE(review): the cell above logs epochs 1-400, so the numbering here skips ahead -- confirm intent.
train_model(model, optimizer, filename='resnet-no-tl.txt', epoch_num=50, start_epoch=601)

In [None]:
# model = get_basenet()
# optimizer = get_optimizer(model, lr=0.01)
# train_model(model, optimizer, start_epoch=101, epoch_num=100, filename='basenet-lr-1-2.txt')

In [None]:
# Evaluation

def evaluate(model, data_loader, epoch):
    """Compute accuracy, precision, recall and F1 of ``model`` on ``data_loader``.

    Class 1 is the positive class. The model runs in eval mode without
    gradients, and its previous train/eval mode is restored afterwards.
    Batches are moved to the device of the model's parameters.

    Args:
        model: Classifier returning per-class scores of shape (N, C).
        data_loader: Iterable of ``(images, labels)`` batches.
        epoch: Epoch number, used only in the printed messages.

    Returns:
        ``(accuracy, precision, recall, F1)`` as floats. Any metric whose
        denominator is zero is 0.0, e.g. precision when nothing is predicted
        positive, or all metrics for an empty loader.
    """
    was_training = model.training
    model.eval()
    device = _model_device(model)

    sample_count = 0
    correct_count = 0

    TP = 0
    FP = 0
    FN = 0

    try:
        with torch.no_grad():
            for images, labels in data_loader:
                images = images.to(device)
                labels = labels.to(device)

                preds = model(images).argmax(dim=1)

                sample_count += len(labels)
                correct_count += (preds == labels).sum().item()

                # Vectorised confusion counts: one device sync per batch,
                # instead of one per sample.
                TP += ((labels == 1) & (preds == 1)).sum().item()
                FN += ((labels == 1) & (preds == 0)).sum().item()
                FP += ((labels == 0) & (preds == 1)).sum().item()
    finally:
        model.train(was_training)

    # Guard each ratio separately so one zero denominator does not hide the others.
    accuracy = _safe_div(correct_count, sample_count)
    precision = _safe_div(TP, TP + FP)
    recall = _safe_div(TP, TP + FN)
    F1 = _safe_div(2 * precision * recall, precision + recall)

    print('Evaluate epoch {} accuracy: '.format(epoch) + str(accuracy))
    print('Evaluate epoch {} precision: '.format(epoch) + str(precision))
    print('Evaluate epoch {} recall: '.format(epoch) + str(recall))
    print('Evaluate epoch {} F1: '.format(epoch) + str(F1))

    return accuracy, precision, recall, F1


def _safe_div(num, den):
    """Return ``num / den``, or 0.0 when ``den`` is zero."""
    return num / den if den else 0.0


def _model_device(model):
    """Device of the model's first parameter; CPU for parameter-less models."""
    for param in model.parameters():
        return param.device
    return torch.device('cpu')


# Visualization

In [None]:
def compute_saliency_maps(X, y, model):
    """Compute vanilla-gradient saliency maps for the target class of each image.

    For every pixel, the saliency is the largest absolute gradient, over the
    colour channels, of the score of class ``y[i]`` with respect to that
    pixel of ``X[i]``.

    Args:
        X: Input images; Tensor of shape (N, 3, H, W).
        y: Target class per image; LongTensor of shape (N,). A 0-dim tensor
            is accepted for a single image.
        model: Trained network. It is switched to eval mode, and gradients
            accumulate on its parameters as a side effect.

    Returns:
        Saliency Tensor of shape (N, H, W). Size-1 dimensions are squeezed
        away, so a single image yields an (H, W) map.
    """
    # Make sure the model is in "test" mode.
    model.eval()

    # Input leaf that records gradients (shares storage with X).
    X_var = X.detach().requires_grad_(True)

    # Forward pass, then pick each sample's score for its target class: (N,).
    scores = model(X_var)
    scores = scores.gather(1, y.view(-1, 1).to(scores.device)).squeeze(1)

    # Backprop of the sum equals backprop with a vector of ones, for any
    # batch size N.
    scores.sum().backward()

    # Largest |gradient| over the channel axis -> (N, H, W).
    saliency = X_var.grad.abs().max(dim=1)[0]
    return saliency.squeeze()

In [None]:
from flashtorch.utils import load_image
import matplotlib.pyplot as plt

# Sample image used for the saliency-map visualisation.
image_path = './images/malignant_mass/malignant_mass_1807.png'

image = load_image(image_path)

plt.imshow(image)
plt.axis('off')
plt.title(os.path.basename(image_path))
plt.show()  # render the figure instead of echoing the last call's return value

In [None]:
# Saliency for class 0 on the sample image.
# NOTE(review): plot_saliency_map is defined in a later cell; run that cell first.
plot_saliency_map(image, label=0, model=model)

In [None]:
# Load a fully pickled model (architecture + weights), presumably written by
# train_model's ./models/<name>.pth save.
# NOTE: torch.load unpickles arbitrary objects -- only load trusted checkpoints.
# map_location lets a checkpoint saved on GPU load on a CPU-only machine.
model = torch.load('./models/res18_aug_lr_0_001_epoch_250.pth',
                   map_location=None if IS_GPU else 'cpu')
model.eval()
evaluate(model, test_loader, 0)

In [None]:
def plot_saliency_map(image, label, model, preprocess=None):
    """Display the saliency map of ``model`` for class ``label`` on ``image``.

    Args:
        image: PIL image, e.g. from flashtorch's ``load_image``.
        label: Index of the class whose score is explained.
        model: Trained network.
        preprocess: Optional callable mapping the PIL image to a (C, H, W)
            tensor, e.g. the notebook's ``transform``. Defaults to flashtorch's
            ``apply_transforms``.
            NOTE(review): apply_transforms appears to normalise with ImageNet
            statistics rather than the DATA_MEAN/DATA_STD used in training --
            verify, and prefer passing ``transform``.
    """
    from flashtorch.utils import apply_transforms

    if preprocess is None:
        # Used directly as a batch, i.e. assumed to be (1, C, H, W).
        input_image = apply_transforms(image)
    else:
        input_image = preprocess(image).unsqueeze(0)

    # Put the inputs on the model's device instead of assuming CUDA.
    device = next(model.parameters()).device
    saliency_map = compute_saliency_maps(input_image.to(device),
                                         torch.tensor(label, device=device), model)
    saliency_map = saliency_map.cpu().numpy()

    plt.imshow(saliency_map)
    plt.axis('off')
    plt.title('Saliency for class {}'.format(label))