In [1]:
import math
import random
import numpy as np
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

# Path Settings
# Both directories share one dataset root. The trailing slashes are kept because
# the loaders below build file names by plain string concatenation.
_DATA_ROOT = 'data/1-NSCLC-TCIA/'
img_path = _DATA_ROOT + '3-NSCLC-TCIA-MV/'   # prefix of the per-sample .npy files
data_path = _DATA_ROOT + 'TCIA-325/'         # prefix of the dataset_fold_<j>.txt files


def get_k_fold(k, i):
    """Load the k-fold split files and return ``(train_list, valid_list)``.

    Fold ``j`` is read from ``data_path + 'dataset_fold_<j>.txt'`` as a
    whitespace-delimited table of strings. Fold ``i`` becomes the validation
    set; the remaining folds are concatenated, in fold order, into the
    training set.

    Args:
        k: Total number of folds (files ``dataset_fold_0`` .. ``dataset_fold_<k-1>``).
        i: Index of the held-out validation fold, ``0 <= i < k``.

    Returns:
        Tuple ``(train_list, valid_list)`` of 2-D string arrays with one row per
        record. ``train_list`` is ``None`` when ``k == 1`` because no fold is
        left for training.

    Raises:
        ValueError: If ``i`` is not a valid fold index for ``k``.
    """
    # Validate before touching the disk. Previously an out-of-range ``i`` only
    # surfaced as an UnboundLocalError on ``valid_list`` after every fold file
    # had already been loaded.
    if not 0 <= i < k:
        raise ValueError(f'fold index i={i} is out of range for k={k}')

    def load_fold(j):
        # ndmin=2 keeps a single-row fold file as one (1, n_columns) record
        # instead of collapsing it into a 1-D array of fields.
        return np.loadtxt(data_path + 'dataset_fold_' + str(j) + '.txt',
                          delimiter=None, dtype=str, ndmin=2)

    valid_list = load_fold(i)
    train_folds = [load_fold(j) for j in range(k) if j != i]
    train_list = np.concatenate(train_folds, axis=0) if train_folds else None
    return train_list, valid_list


def setup_seed(seed):
    """Seed every random number generator used in this notebook.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch
    (CPU and all CUDA devices), and configures cuDNN for repeatable runs.

    Args:
        seed: Integer seed shared by all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # With benchmark mode enabled cuDNN may select a different convolution
    # algorithm from run to run, which defeats the deterministic flag above.
    torch.backends.cudnn.benchmark = False


setup_seed(2000)

In [2]:
# Dataset Preparation

class Dataset_nsclc(Dataset):
    """Dataset of three-view arrays stored as ``<name>.npy`` files.

    Every sample is loaded from ``img_path + name + '.npy'`` and split along
    its last axis into three single-channel views (indices 0, 1 and 2).

    Args:
        img_path: Prefix of the ``.npy`` files. It is concatenated with the
            sample name as-is, so it must end with a path separator.
        datalist: Iterable of records whose first field is the sample name and
            whose second field is the integer class label.
        is_transform: Augmentation mode.
            * falsy (default): no augmentation;
            * ``'feature'``: always flip along axis 1 (horizontal flip);
            * any other truthy value: pick axis 0 or axis 1 at random and flip
              along it with probability 0.5.
    """

    def __init__(self, img_path, datalist, is_transform=None):
        self.path = img_path
        self.transforms = is_transform

        names = []
        labels = []
        for record in datalist:
            names.append(record[0])
            labels.append(int(record[1]))

        self.name = names
        self.label = torch.tensor(labels)

    def get_img(self, index):
        """Load sample ``index`` as a float tensor and apply the configured flip.

        The flips operate directly on the tensor. The earlier implementation
        referenced a ``transforms`` name that is not imported in the visible
        notebook cells and pushed the views-last array through a PIL round
        trip; it also applied both the random and the deterministic flip in
        ``'feature'`` mode. Values are no longer rescaled by augmentation.
        """
        data = np.load(self.path + self.name[index] + '.npy')
        data = torch.from_numpy(data).type(torch.FloatTensor)

        # Axes 0 and 1 are the spatial axes: __getitem__ slices the views off
        # the last axis and feeds each remaining 2-D plane to the network.
        if self.transforms == 'feature':
            data = torch.flip(data, dims=[1])
        elif self.transforms:
            flip_axis = int(torch.randint(0, 2, (1,)))
            if torch.rand(1).item() < 0.5:
                data = torch.flip(data, dims=[flip_axis])

        return data

    def __getitem__(self, index):
        """Return ``(view_z, view_y, view_x, label, name)`` for sample ``index``.

        Each view is the corresponding slice of the last axis with a leading
        channel dimension of size one added.
        """
        data = self.get_img(index)
        data_z = data[:, :, 0].unsqueeze(0)
        data_y = data[:, :, 1].unsqueeze(0)
        data_x = data[:, :, 2].unsqueeze(0)
        return data_z, data_y, data_x, self.label[index], self.name[index]

    def __len__(self):
        """Number of samples in the dataset."""
        return len(self.label)

In [3]:
from torch import nn
from torch.nn import functional as F

In [4]:
# Tool Functions

class CMD(nn.Module):
    """Central-moment discrepancy between two batches.

    Adapted from https://github.com/wzell/cmd/blob/master/models/domain_regularizer.py
    """

    def __init__(self):
        super().__init__()

    def forward(self, x1, x2, n_moments):
        """Sum the distance of the means and of central moments 2..n_moments.

        Means and moments are taken over dimension 0 of ``x1`` and ``x2``.
        """
        mean1 = torch.mean(x1, 0)
        mean2 = torch.mean(x2, 0)
        centered1 = x1 - mean1
        centered2 = x2 - mean2

        total = self.matchnorm(mean1, mean2)
        for order in range(2, n_moments + 1):
            total = total + self.scm(centered1, centered2, order)
        return total

    def matchnorm(self, x1, x2):
        """Euclidean distance between ``x1`` and ``x2`` over all elements."""
        squared_gap = torch.pow(x1 - x2, 2)
        return torch.sum(squared_gap) ** (0.5)

    def scm(self, sx1, sx2, k):
        """Distance between the k-th moments of two already-centered batches."""
        moment1 = torch.mean(torch.pow(sx1, k), 0)
        moment2 = torch.mean(torch.pow(sx2, k), 0)
        return self.matchnorm(moment1, moment2)


class DiffLoss(nn.Module):
    """Mean squared cross-correlation between two flattened feature batches."""

    def __init__(self):
        super().__init__()

    @staticmethod
    def _unit_rows(features):
        # Scale every row by its L2 norm; the norm is detached so it acts as a
        # constant in the backward pass, and 1e-6 guards against division by zero.
        row_norm = torch.norm(features, p=2, dim=1, keepdim=True).detach()
        return features.div(row_norm.expand_as(features) + 1e-6)

    def forward(self, input1, input2):
        """Return ``mean((A^T B)^2)`` for the centered, row-normalized inputs.

        Both inputs are flattened to ``(batch, -1)`` using the batch size of
        ``input1``, centered over the batch dimension and row-normalized.
        """
        n_rows = input1.size(0)
        flat1 = input1.view(n_rows, -1)
        flat2 = input2.view(n_rows, -1)

        # Zero mean over the batch dimension
        flat1 = flat1 - torch.mean(flat1, dim=0, keepdim=True)
        flat2 = flat2 - torch.mean(flat2, dim=0, keepdim=True)

        unit1 = self._unit_rows(flat1)
        unit2 = self._unit_rows(flat2)

        correlation = unit1.t().mm(unit2)
        return torch.mean(correlation.pow(2))


In [5]:
class DoubleConv(nn.Module):
    """Two stacked 3x3 convolutions (padding 1), each followed by BatchNorm and ReLU."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = []
        # First stage maps in_ch -> out_ch, second stage keeps out_ch.
        for stage_in in (in_ch, out_ch):
            layers.extend([
                nn.Conv2d(stage_in, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ])
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply both conv-BN-ReLU stages to ``x``."""
        return self.conv(x)

class Conv_1_1(nn.Module):
    """1x1 convolution followed by BatchNorm and ReLU."""

    def __init__(self, in_ch, out_ch):
        super().__init__()
        pointwise = nn.Conv2d(in_ch, out_ch, 1)
        norm = nn.BatchNorm2d(out_ch)
        activation = nn.ReLU(inplace=True)
        self.conv = nn.Sequential(pointwise, norm, activation)

    def forward(self, x):
        """Project ``x`` from ``in_ch`` to ``out_ch`` channels."""
        return self.conv(x)

In [6]:
class Residual(nn.Module):
    """Residual block: two 3x3 conv + BatchNorm stages plus a skip connection.

    Args:
        input_channels: Channels of the incoming tensor.
        num_channels: Channels produced by the block.
        use_1x1conv: If true, the skip path goes through a 1x1 convolution
            (with the same stride) so that it matches the main path's shape.
        strides: Stride of the first 3x3 convolution and of the 1x1 skip
            convolution.
    """

    def __init__(self, input_channels, num_channels,
                 use_1x1conv=False, strides=1):
        super().__init__()
        self.conv1 = nn.Conv2d(input_channels, num_channels,
                               kernel_size=3, padding=1, stride=strides)
        self.conv2 = nn.Conv2d(num_channels, num_channels,
                               kernel_size=3, padding=1)
        self.conv3 = (
            nn.Conv2d(input_channels, num_channels, kernel_size=1, stride=strides)
            if use_1x1conv else None
        )
        self.bn1 = nn.BatchNorm2d(num_channels)
        self.bn2 = nn.BatchNorm2d(num_channels)

    def forward(self, X):
        """Return ``relu(main_path(X) + skip(X))``."""
        shortcut = X if self.conv3 is None else self.conv3(X)
        out = F.relu(self.bn1(self.conv1(X)))
        out = self.bn2(self.conv2(out))
        out += shortcut
        return F.relu(out)


def resnet_block(input_channels, num_channels, num_residuals,
                 first_block=False):
    """Build the list of ``Residual`` blocks that make up one ResNet stage.

    Unless ``first_block`` is set, the first block of the stage changes the
    channel count from ``input_channels`` to ``num_channels`` with stride 2 and
    a 1x1 skip convolution. Every other block keeps ``num_channels``.

    Returns:
        A list with ``num_residuals`` ``Residual`` modules.
    """
    def build(position):
        downsample = position == 0 and not first_block
        if downsample:
            return Residual(input_channels, num_channels,
                            use_1x1conv=True, strides=2)
        return Residual(num_channels, num_channels)

    return [build(position) for position in range(num_residuals)]


class Resnet18(nn.Module):
    """ResNet-18-style feature extractor for single-channel images.

    A stem (3x3 stride-2 convolution, BatchNorm, ReLU, stride-2 max-pool) is
    followed by four residual stages with 64, 128, 256 and 512 channels.

    Args:
        block_num: Number of residual blocks in each stage.
    """

    def __init__(self, block_num=2):
        super().__init__()
        stem_layers = [
            nn.Conv2d(1, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
        ]
        self.reslayer1 = nn.Sequential(*stem_layers)

        # (attribute name, input channels, output channels, is first stage)
        stage_specs = (
            ('reslayer2', 64, 64, True),
            ('reslayer3', 64, 128, False),
            ('reslayer4', 128, 256, False),
            ('reslayer5', 256, 512, False),
        )
        for attr_name, c_in, c_out, is_first in stage_specs:
            blocks = resnet_block(c_in, c_out, block_num, first_block=is_first)
            setattr(self, attr_name, nn.Sequential(*blocks))

    def forward(self, x):
        """Run ``x`` through the stem and the four residual stages."""
        stages = (self.reslayer1, self.reslayer2, self.reslayer3,
                  self.reslayer4, self.reslayer5)
        for stage in stages:
            x = stage(x)
        return x  # 512 channels (output of reslayer5)

In [7]:
class Decoder(nn.Module):
    """Upsampling decoder that maps a feature map back to a one-channel image.

    A 1x1 projection to 512 channels is followed by five stages, each a
    stride-2 transposed convolution (halving the channels, doubling the
    spatial size) and a ``DoubleConv``. A final 1x1 convolution and a sigmoid
    produce the single-channel output.

    Args:
        in_ch: Channels of the incoming feature map.
    """

    _WIDTHS = (512, 256, 128, 64, 32, 16)

    def __init__(self, in_ch):
        super().__init__()
        self.conv_in = Conv_1_1(in_ch, self._WIDTHS[0])
        stage_widths = zip(self._WIDTHS[:-1], self._WIDTHS[1:])
        for stage, (wide, narrow) in enumerate(stage_widths, start=1):
            setattr(self, f'up{stage}', nn.ConvTranspose2d(wide, narrow, 2, stride=2))
            setattr(self, f'dconv{stage}', DoubleConv(narrow, narrow))
        self.conv_out = nn.Conv2d(self._WIDTHS[-1], 1, 1)

    def forward(self, x):
        """Decode ``x`` into a one-channel map with values in (0, 1)."""
        x = self.conv_in(x)
        for stage in range(1, len(self._WIDTHS)):
            x = getattr(self, f'up{stage}')(x)
            x = getattr(self, f'dconv{stage}')(x)
        return torch.sigmoid(self.conv_out(x))

In [13]:
class CARL(nn.Module):
    """Three-view classifier with a shared encoder and view-private encoders.

    Every input view is encoded twice: by ``encoder_share`` (the same weights
    for all views) and by its own private encoder (``encoder_z`` /
    ``encoder_y`` / ``encoder_x``). The sum of both encodings is decoded back
    into the view. For classification, the shared encodings and the private
    encodings are each max-fused across the views, refined by ``encoder_com``
    and ``encoder_spe`` respectively, globally average-pooled, concatenated
    (1024 + 1024 = 2048 features) and passed to a two-class head.

    ``forward`` stores the per-view encodings (``z_com``, ``z_spe``, ...) and
    the reconstructions (``r_z``, ``r_y``, ``r_x``) on the module; the training
    loop reads them to build the auxiliary losses.

    Args:
        block_num: Residual blocks per stage in every encoder.
        dropout: Dropout probability in the classifier head.
        use_cmd_sim: Stored on the module; not read anywhere in this class.
        loss_weight: Three weights stored as ``a``, ``b`` and ``c``. The
            training loop uses them to scale the reconstruction, consistency
            and difference losses.

    Raises:
        ValueError: If ``loss_weight`` does not hold exactly three values.
    """

    def __init__(self,
                 block_num=2,
                 dropout=0.4,
                 use_cmd_sim=True,
                 loss_weight=(1, 0.1, 1)
                 ):
        super(CARL, self).__init__()

        if len(loss_weight) != 3:
            raise ValueError('loss_weight must contain exactly three values')

        self.res_block_num = block_num
        self.dropout_rate = dropout
        self.use_cmd_sim = use_cmd_sim

        self.a, self.b, self.c = loss_weight

        ##########################################
        # View-private Encoders
        ##########################################
        self.encoder_z = Resnet18(block_num)
        self.encoder_y = Resnet18(block_num)
        self.encoder_x = Resnet18(block_num)
        self.encoder_spe = nn.Sequential(*resnet_block(512, 1024, block_num))

        ##########################################
        # View-common Encoders
        ##########################################
        self.encoder_share = Resnet18(block_num)
        self.encoder_com = nn.Sequential(*resnet_block(512, 1024, block_num))

        ##########################################
        # Decoders
        ##########################################
        self.decoder_z = Decoder(512)
        self.decoder_y = Decoder(512)
        self.decoder_x = Decoder(512)

        ##########################################
        # Classifier
        ##########################################
        # Input: 1024 pooled common features + 1024 pooled specific features.
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.out = nn.Sequential(
            nn.Linear(2048, 256), nn.ReLU(True), nn.Dropout(self.dropout_rate),
            nn.Linear(256, 2), nn.Softmax(dim=1)
        )

    def _fuse(self, views, encoder):
        """Max-fuse ``views`` element-wise, refine with ``encoder`` and pool to ``(batch, C)``."""
        stacked = torch.stack(views, dim=1)
        fused = torch.max(stacked, 1)[0]
        fused = encoder(fused)
        return self.avgpool(fused).view(fused.size()[0], -1)

    def forward(self, data_z, data_y, data_x):
        """Classify a batch of three-view samples.

        Args:
            data_z, data_y, data_x: One tensor per view.

        Returns:
            Class probabilities of shape ``(batch, 2)`` (the head ends in a
            softmax over dimension 1).
        """
        # View-invariant representations
        self.z_com = self.encoder_share(data_z)
        self.y_com = self.encoder_share(data_y)
        self.x_com = self.encoder_share(data_x)
        # View-specific representations
        self.z_spe = self.encoder_z(data_z)
        self.y_spe = self.encoder_y(data_y)
        self.x_spe = self.encoder_x(data_x)

        # Reconstruction
        self.r_z = self.decoder_z(self.z_com + self.z_spe)
        self.r_y = self.decoder_y(self.y_com + self.y_spe)
        self.r_x = self.decoder_x(self.x_com + self.x_spe)

        # Fusion
        f_com = self._fuse([self.z_com, self.y_com, self.x_com], self.encoder_com)
        # The specific branch has its own refinement encoder. It was previously
        # routed through encoder_com, which left encoder_spe without any use.
        f_spe = self._fuse([self.z_spe, self.y_spe, self.x_spe], self.encoder_spe)

        fused_features = torch.cat([f_com, f_spe], dim=1)

        # Classifier
        cla_out = self.out(fused_features)

        return cla_out

In [18]:
from tqdm import tqdm
from sklearn import metrics

def try_gpu(i=0):
    """Return ``cuda:i`` when at least ``i + 1`` CUDA devices exist, otherwise the CPU device."""
    device_name = f'cuda:{i}' if i + 1 <= torch.cuda.device_count() else 'cpu'
    return torch.device(device_name)


# Mini-batch size shared by all data loaders, and the device used for training.
batch_size = 8
DEVICE = try_gpu(0)

# Defines training and evaluation.
def train_model(model, optimizer, train_loader):
    """Train ``model`` for one epoch over ``train_loader``.

    The total loss of every batch is
    ``l_cla + model.a * l_rec + model.b * l_con + model.c * l_dif`` where

    * ``l_cla``: classification loss on the model output,
    * ``l_rec``: mean MSE between each view and its reconstruction,
    * ``l_con``: mean CMD (5 moments) between the pairs of common encodings,
    * ``l_dif``: sum of ``DiffLoss`` over common/specific pairs of each view
      and over the pairs of specific encodings.

    Args:
        model: Module that returns class probabilities and exposes ``r_*``,
            ``*_com``, ``*_spe`` and the weights ``a``, ``b``, ``c`` after a
            forward pass.
        optimizer: Optimizer over ``model``'s parameters.
        train_loader: Yields ``(data_z, data_y, data_x, label, name)`` batches.

    Batches are moved to the module-level ``DEVICE``.
    """
    model.train()

    # Loss Functions
    # The model's head ends in a softmax, so its output is already a
    # probability distribution. nn.CrossEntropyLoss expects raw logits and
    # would apply log-softmax a second time. NLL on the log-probabilities is
    # the intended cross-entropy.
    loss_cla = nn.NLLLoss()
    loss_recon = nn.MSELoss()
    loss_cmd = CMD()
    loss_diff = DiffLoss()
    n_moments = 5

    for data_z, data_y, data_x, y, _ in train_loader:
        optimizer.zero_grad()
        data_z, data_y, data_x, y = (
            data_z.to(DEVICE), data_y.to(DEVICE), data_x.to(DEVICE), y.to(DEVICE)
        )
        probs = model(data_z, data_y, data_x)

        # Loss Calculation
        # Clamp before the log so that an underflowed probability of exactly 0
        # cannot produce -inf.
        l_cla = loss_cla(torch.log(probs.clamp_min(1e-12)), y)

        recon_pairs = (
            (model.r_z, data_z),
            (model.r_y, data_y),
            (model.r_x, data_x),
        )
        l_rec = sum(loss_recon(recon, target) for recon, target in recon_pairs) / 3.0

        common_pairs = (
            (model.z_com, model.y_com),
            (model.z_com, model.x_com),
            (model.x_com, model.y_com),
        )
        l_con = sum(loss_cmd(first, second, n_moments) for first, second in common_pairs) / 3.0

        diff_pairs = (
            (model.z_com, model.z_spe),
            (model.y_com, model.y_spe),
            (model.x_com, model.x_spe),
            (model.z_spe, model.y_spe),
            (model.z_spe, model.x_spe),
            (model.x_spe, model.y_spe),
        )
        l_dif = sum(loss_diff(first, second) for first, second in diff_pairs)

        total_loss = l_cla + model.a * l_rec + model.b * l_con + model.c * l_dif

        total_loss.backward()
        optimizer.step()


def eval_model(model, valid_loader):
    """Evaluate ``model`` on ``valid_loader``.

    Column 1 of the model output is used as the score for class 1 and the
    arg-max over dimension 1 as the predicted class.

    Args:
        model: Module returning a ``(batch, 2)`` tensor of class scores.
        valid_loader: Yields ``(data_z, data_y, data_x, label, name)`` batches.

    Returns:
        Tuple ``(auc, acc, sen, spe)``: ROC AUC, accuracy, sensitivity
        (``tp / (tp + fn)``) and specificity (``tn / (tn + fp)``). Sensitivity
        or specificity is ``nan`` when the corresponding class has no samples.

    Raises:
        ValueError: If the loader yields no batches. ``roc_auc_score`` may also
            raise when the labels contain a single class.

    Batches are moved to the module-level ``DEVICE``.
    """
    model.eval()

    # Collect per-batch arrays and concatenate once at the end instead of
    # re-copying the growing arrays on every batch.
    score_chunks = []
    pred_chunks = []
    label_chunks = []

    with torch.no_grad():
        for data_z, data_y, data_x, y, _ in valid_loader:
            data_z, data_y, data_x = data_z.to(DEVICE), data_y.to(DEVICE), data_x.to(DEVICE)
            cla_out = model(data_z, data_y, data_x)
            score_chunks.append(cla_out[:, 1].cpu().numpy())
            # argmax without keepdim stays 1-D even for a final batch of size
            # one. The former keepdim=True + squeeze() produced a 0-d array in
            # that case, which np.concatenate rejects.
            pred_chunks.append(cla_out.argmax(dim=1).cpu().numpy())
            label_chunks.append(y.cpu().numpy())

    if not label_chunks:
        raise ValueError('valid_loader yielded no batches')

    score = np.concatenate(score_chunks, axis=0)
    pred = np.concatenate(pred_chunks, axis=0).astype(int)
    label = np.concatenate(label_chunks, axis=0).astype(int)

    auc = metrics.roc_auc_score(label, score, average='macro', sample_weight=None)
    # Fix the label set so the matrix is always 2x2, even when a class is
    # absent from both the labels and the predictions.
    tn, fp, fn, tp = metrics.confusion_matrix(label, pred, labels=[0, 1]).ravel()
    acc = (tp + tn) / (tn + fp + fn + tp)
    sen = tp / (tp + fn) if (tp + fn) else float('nan')
    spe = tn / (tn + fp) if (tn + fp) else float('nan')

    return auc, acc, sen, spe

In [17]:
def train_manual(fold_num, DEVICE, test_path=None, test_list=None):
    """Train and evaluate a ``CARL`` model on one cross-validation fold.

    The fold split comes from ``get_k_fold(5, fold_num)``. After every epoch
    the model is evaluated on the training and validation loaders and, when a
    test set is supplied, on the test loader.

    Args:
        fold_num: Index of the held-out validation fold (0..4).
        DEVICE: Device the model is moved to. ``train_model`` and
            ``eval_model`` move the batches to the module-level ``DEVICE``, so
            this argument should be that same device.
        test_path: Optional prefix of the test ``.npy`` files.
        test_list: Optional test records (same layout as the fold lists).
            The earlier version read ``test_path`` / ``test_list`` as globals
            that no visible cell defines. When either is ``None`` the test
            evaluation is skipped and its metrics are reported as ``nan``.

    Returns:
        List of CSV lines: a header followed by one line per epoch holding the
        epoch index and the train / validation / test AUC, accuracy,
        sensitivity and specificity.
    """
    torch.cuda.empty_cache()

    train_list, valid_list = get_k_fold(5, fold_num)
    train_data = Dataset_nsclc(img_path, train_list)
    valid_data = Dataset_nsclc(img_path, valid_list)
    train_iter = DataLoader(train_data, batch_size, shuffle=True)
    # Evaluation does not depend on sample order, so only training is shuffled.
    valid_iter = DataLoader(valid_data, batch_size, shuffle=False)

    test_iter = None
    if test_path is not None and test_list is not None:
        test_data = Dataset_nsclc(test_path, test_list)
        test_iter = DataLoader(test_data, batch_size, shuffle=False)
    else:
        print('No test set supplied; test metrics are reported as nan.')

    block_num = 2
    dropout = 0.3
    # Was written as `1.659-05`, which Python evaluates as the subtraction
    # 1.659 - 5 = -3.341 (an invalid, negative learning rate).
    lr = 1.659e-05
    n_epochs = 20
    a = 0.01
    b = 0.005
    c = 83.75
    loss_weight = [a, b, c]

    # loss_weight must be passed by keyword: CARL's third positional parameter
    # is use_cmd_sim, so a positional call silently kept the default weights.
    model = CARL(block_num, dropout, loss_weight=loss_weight).to(DEVICE)

    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    result = []
    result.append('epoch, tr_auc,tr_acc,tr_sen,tr_spe, va_auc,va_acc,va_sen,va_spe, te_auc,te_acc,te_sen,te_spe')

    missing_metrics = (float('nan'),) * 4

    for epoch in tqdm(range(n_epochs)):

        train_model(model, optimizer, train_iter)

        tr_auc, tr_acc, tr_sen, tr_spe = eval_model(model, train_iter)
        va_auc, va_acc, va_sen, va_spe = eval_model(model, valid_iter)
        if test_iter is not None:
            te_auc, te_acc, te_sen, te_spe = eval_model(model, test_iter)
        else:
            te_auc, te_acc, te_sen, te_spe = missing_metrics

        row = (epoch,
               tr_auc, tr_acc, tr_sen, tr_spe,
               va_auc, va_acc, va_sen, va_spe,
               te_auc, te_acc, te_sen, te_spe)
        result.append(','.join(str(value) for value in row))

        print('epoch:{} | tr_auc:{:.3f} | va_auc:{:.3f} | ts_auc:{:.3f}\n'.format(epoch, tr_auc, va_auc, te_auc))
        print('epoch:{} | tr_acc:{:.3f} | va_acc:{:.3f} | ts_acc:{:.3f}\n'.format(epoch, tr_acc, va_acc, te_acc))

    return result
        

In [None]:
# 5-fold cross-validation: one full training run per held-out fold.
for fold_index in range(5):
    train_manual(fold_index, DEVICE)