In [None]:
# -*- coding: utf-8 -*-
"""
Created on Fri July 29 2022
Last revised on Sun June 16 2023

A Jupyter Notebook for Multi-input Vision Transformer with Similarity Matching

@author: Anonymous

Multi-input Vision Transformer with Similarity Matching
1) Backbone network: ViT-L/32, ResNet50
2) Cosine similarity

"""

In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
import torch.backends.cudnn as cudnn
import os.path
from glob import glob
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import cv2
import pydicom
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import StratifiedKFold, train_test_split
from sklearn.metrics import roc_curve, auc, confusion_matrix
import seaborn as sns
import timm

%matplotlib inline

In [None]:
# Seed every RNG the notebook touches (Python, NumPy, PyTorch CPU + CUDA) and
# force deterministic cuDNN kernels so repeated runs produce the same numbers.
random.seed(0)
np.random.seed(0)
for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(256)
cudnn.deterministic = True
cudnn.benchmark = False

In [None]:
code1 = []
cropped1 = []

img_path = 'path/to/class0'

# Sorted so the sample order (and hence the seeded train/test split later on)
# does not depend on the filesystem's directory enumeration order.
for fname in sorted(os.listdir(img_path)):
    # Join the path directly: glob() returns a *list* of matches, which pydicom
    # cannot open. Read each file once with dcmread (read_file is a deprecated alias).
    path = os.path.join(img_path, fname)
    ds = pydicom.dcmread(path)
    # original image size: 384*384 for ViT-L/32
    img = ds.pixel_array

    # (0028,0004) Photometric Interpretation. MONOCHROME1 displays the minimum
    # value as white, so invert to the MONOCHROME2 convention.
    imgnum = ds[0x0028, 0x0004].value
    print(imgnum.lower())
    if imgnum.lower() == 'monochrome1':
        img = np.invert(img)
        print('converted')

    # Full view: resize to the ViT input size, then standardize.
    # NOTE(review): StandardScaler treats each image *column* as a feature, so
    # this standardizes column by column, not over the whole image.
    img2 = cv2.resize(img, (384, 384), interpolation = cv2.INTER_LANCZOS4)
    scaled = StandardScaler().fit_transform(img2)
    full = torch.tensor(scaled)
    # Replicate the single grey channel into 3 channels for the ViT.
    code1.append(torch.stack([full, full, full], dim=0))

    # ROI-cropped image: upsample to 2048x2048, keep the [200:2000) window,
    # standardize, then resize back to 384x384.
    img3 = cv2.resize(img, (2048, 2048), interpolation = cv2.INTER_LANCZOS4)
    img4 = img3[200:2000, 200:2000]
    img5 = StandardScaler().fit_transform(img4)
    img5 = cv2.resize(img5, (384, 384), interpolation = cv2.INTER_LANCZOS4)
    roi = torch.tensor(img5)
    cropped1.append(torch.stack([roi, roi, roi], dim=0))

In [None]:
code2 = []
cropped2 = []

img_path = 'path/to/class1'

# Sorted so the sample order (and hence the seeded train/test split later on)
# does not depend on the filesystem's directory enumeration order.
for fname in sorted(os.listdir(img_path)):
    # Join the path directly: glob() returns a *list* of matches, which pydicom
    # cannot open. Read each file once with dcmread (read_file is a deprecated alias).
    path = os.path.join(img_path, fname)
    ds = pydicom.dcmread(path)
    # original image size: 384*384 for ViT-L/32
    img = ds.pixel_array

    # (0028,0004) Photometric Interpretation. MONOCHROME1 displays the minimum
    # value as white, so invert to the MONOCHROME2 convention.
    imgnum = ds[0x0028, 0x0004].value
    print(imgnum.lower())
    if imgnum.lower() == 'monochrome1':
        img = np.invert(img)
        print('converted')

    # Full view: resize to the ViT input size, then standardize.
    # NOTE(review): StandardScaler treats each image *column* as a feature, so
    # this standardizes column by column, not over the whole image.
    img2 = cv2.resize(img, (384, 384), interpolation = cv2.INTER_LANCZOS4)
    scaled = StandardScaler().fit_transform(img2)
    full = torch.tensor(scaled)
    # Replicate the single grey channel into 3 channels for the ViT.
    code2.append(torch.stack([full, full, full], dim=0))

    # ROI-cropped image: upsample to 2048x2048, keep the [200:2000) window,
    # standardize, then resize back to 384x384.
    img3 = cv2.resize(img, (2048, 2048), interpolation = cv2.INTER_LANCZOS4)
    img4 = img3[200:2000, 200:2000]
    img5 = StandardScaler().fit_transform(img4)
    img5 = cv2.resize(img5, (384, 384), interpolation = cv2.INTER_LANCZOS4)
    roi = torch.tensor(img5)
    cropped2.append(torch.stack([roi, roi, roi], dim=0))

In [None]:
# Readable aliases: class 0 = "limited", class 1 = "good"
# (full-view images and their ROI crops respectively).
limited, good = code1, code2
limited_c, good_c = cropped1, cropped2

In [None]:
# Binary labels aligned with `limited + good`: all "limited" samples get 0,
# followed by all "good" samples with 1.
target = [0] * len(limited) + [1] * len(good)

# Positional indices 0..N-1, used to split samples into train/test.
idx_list = list(range(len(target)))

In [None]:
# Hold out 10% of sample indices as a fixed test set (seeded shuffle).
# NOTE(review): the split is not stratified by `target`; a small test set could be class-imbalanced.
idx_train, idx_test = train_test_split(
    idx_list,
    shuffle=True,
    random_state=256,
    test_size=0.1,
)

In [None]:
# Concatenate in the same order as `target` (class 0 first, then class 1).
data, data_c = limited + good, limited_c + good_c

In [None]:
# Gather full images, ROI crops and labels for each split, in split-index order.
traindata = [data[k] for k in idx_train]
traindata_c = [data_c[k] for k in idx_train]
trainidx = np.array([target[k] for k in idx_train])

testdata = [data[k] for k in idx_test]
testdata_c = [data_c[k] for k in idx_test]
testidx = np.array([target[k] for k in idx_test])

In [None]:
def cosine_similarity(a, b, eps=1e-8):
    """Cosine similarity between two tensors treated as flattened vectors.

    The previous implementation used ``torch.inner``. For 2-D inputs that gives
    a matrix of row-by-row dot products, whose mean is not the cosine
    similarity. For column-standardized images it is ~0 regardless of content.
    Results for 1-D inputs are unchanged.

    Args:
        a, b: tensors (or array-likes such as numpy arrays) with the same
            number of elements, floating dtype.
        eps: lower bound on the norm product, guarding against division by
            zero for all-zero inputs.

    Returns:
        0-d tensor in [-1, 1].
    """
    # as_tensor accepts numpy arrays (the CV loop passes numpy slices) and is a
    # no-op for tensors, so autograd is preserved.
    a = torch.as_tensor(a).flatten()
    b = torch.as_tensor(b).flatten()
    dot = torch.sum(a * b)
    denom = (torch.linalg.norm(a) * torch.linalg.norm(b)).clamp_min(eps)
    return dot / denom

In [None]:
def l1_norm(a, b):
    """L1 (Manhattan) distance between two tensors: ``sum(|a - b|)`` as a 0-d tensor."""
    return torch.abs(a - b).sum()

In [None]:
def l2_norm(a, b):
    """Euclidean (L2) distance between two tensors, treated as flattened vectors.

    Computes ``sqrt(sum((a - b) ** 2))``. The previous version took the square
    root element-wise before summing. That equals ``sum(|a - b|)``, i.e. the L1
    distance.

    Returns:
        0-d tensor.
    """
    diff = a - b
    return torch.sqrt(torch.sum(diff * diff))

In [None]:
def _frozen_vit_with_sigmoid_head():
    """Build a pretrained ViT-L/32 (384 px) on the GPU.

    All pretrained weights are frozen, and the classifier is replaced with
    Linear(features -> 1) followed by Sigmoid.

    Returns the model and the feature width of the replaced head.
    """
    vit = timm.create_model('vit_large_patch32_384', pretrained=True).cuda()
    for weight in vit.parameters():
        weight.requires_grad = False
    width = vit.head.in_features
    vit.head = nn.Sequential(nn.Linear(width, 1), nn.Sigmoid())
    return vit.cuda(), width


# Two independent backbones: model1 for full images, model2 for ROI crops.
model1, num_features = _frozen_vit_with_sigmoid_head()
model2, num_features = _frozen_vit_with_sigmoid_head()

In [None]:
class TwoInputsNet(nn.Module):
    """Two-branch classifier that fuses a full image and its ROI crop.

    Each branch is an nn.Sequential over the children of a frozen ViT
    (the global ``model1`` / ``model2``), without its last two children. The two
    branch outputs are concatenated along dim 2, flattened per sample, and fed to
    an MLP head ending in a Sigmoid, i.e. one probability per sample.
    """
    def __init__(self):
        super(TwoInputsNet, self).__init__()
        # model1, model2: pre-trained ViT-L/32
        # NOTE(review): nn.Sequential over children() only replays submodules.
        # Parameters registered directly on the ViT (presumably timm's cls_token
        # and pos_embed) are never applied, so the branches likely see patch
        # embeddings without position embeddings or a class token. Which two
        # children [:-2] drops also depends on the timm version — confirm.
        self.model1 = torch.nn.Sequential(*list(model1.children())[:-2])
        self.model2 = torch.nn.Sequential(*list(model2.children())[:-2])
        # 294912 = 144 tokens x (1024 + 1024) features, consistent with 12x12
        # patches (384 / 32) and no class token — TODO confirm against timm output.
        # NOTE(review): BatchNorm1d raises in training mode on a batch of size 1,
        # so a final batch of one sample would fail.
        self.fc = nn.Sequential(
            nn.Linear(294912, 1000),
            nn.BatchNorm1d(1000),
            nn.Dropout(0.3),
            nn.Linear(1000, 100),
            nn.BatchNorm1d(100),
            nn.Dropout(0.3),
            nn.Linear(100, 1),
            nn.Sigmoid()
        )

    def forward(self, input1, input2):
        # input1: full-image batch, input2: ROI-crop batch
        # (assumed (B, 3, 384, 384) — TODO confirm)
        c = self.model1(input1)
        f = self.model2(input2)
        # Join the branch features per token (dim 2, assuming (B, tokens, features) outputs)
        combined = torch.cat([c, f], dim = 2)
        # Flatten everything except the batch dimension for the MLP head
        combined2 = combined.reshape(c.shape[0], -1)
        out = self.fc(combined2)
        return out

model_merged = TwoInputsNet().cuda()

In [None]:
# Training configuration.
batchsize = 64
epochs = 100
# The model ends in a Sigmoid, so it outputs probabilities -> plain BCE.
criterion = nn.BCELoss()
optimizer = torch.optim.Adam(params=model_merged.parameters(), lr=1e-3)
# Multiplicative decay: lr = 1e-3 * 0.95 ** (number of scheduler.step() calls)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda step: 0.95 ** step)

In [None]:
# Test pairs are (full image, its own ROI crop, label), kept in a fixed order.
testfinal = [(full, crop, lab) for full, crop, lab in zip(testdata, testdata_c, testidx)]
testloader = DataLoader(testfinal, shuffle = False, batch_size = batchsize)

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
def train(model, train_loader, optimizer, log_interval):
    """Run one training epoch over ``train_loader``, then step the LR scheduler.

    Reads the module-level ``device``, ``criterion`` and ``scheduler``, and
    ``epoch`` (used only in the progress message). Prints the batch loss every
    ``log_interval`` batches.
    """
    model.train()

    for step, (first, second, labels) in enumerate(train_loader):
        first = first.to(device, dtype=torch.float)
        second = second.to(device, dtype=torch.float)
        labels = labels.to(device)

        optimizer.zero_grad()
        prediction = model(first, second).squeeze(dim=1)
        loss = criterion(prediction.to(torch.float32), labels.to(torch.float32))
        loss.backward()
        optimizer.step()

        if step % log_interval != 0:
            continue
        seen = step * len(first)
        progress = 100. * step / len(train_loader)
        print("Train Epoch: {} [{}/{} ({:.0f}%)]\tTrain Loss: {:.6f}".format(
            epoch, seen, len(train_loader.dataset), progress, loss.item()))

    scheduler.step()

In [None]:
def evaluate(model, test_loader, loss_fn=None, eval_device=None):
    """Evaluate a two-input model on ``test_loader`` without gradient tracking.

    Args:
        model: network called as ``model(image1, image2)``. Its output is
            squeezed on dim 1 and treated as per-sample probabilities.
        test_loader: yields ``(image1, image2, label)`` batches.
        loss_fn: loss with mean reduction. Defaults to the module-level ``criterion``.
        eval_device: device for inputs. Defaults to the module-level ``device``.

    Returns:
        (mean_loss, accuracy_percent, labels, preds, probs), where the last three
        are per-batch lists of tensors. ``mean_loss`` is the per-sample mean,
        and both it and ``accuracy_percent`` are Python floats. They are NaN for
        an empty dataset.
    """
    if loss_fn is None:
        loss_fn = criterion
    if eval_device is None:
        eval_device = device

    model.eval()
    total_loss = 0.0
    correct = 0
    testlabel = []
    testpred = []
    testprob = []

    with torch.no_grad():
        for image1, image2, label in test_loader:
            image1 = image1.to(eval_device, dtype = torch.float)
            image2 = image2.to(eval_device, dtype = torch.float)
            label = label.to(eval_device)
            testlabel.append(label)
            output = model(image1, image2).squeeze(dim=1).to(torch.float32)
            testprob.append(output)
            # loss_fn averages over the batch; weight by batch size so the final
            # division gives the per-sample mean, even for a short last batch.
            total_loss += loss_fn(output, label.to(torch.float32)).item() * label.size(0)
            # Threshold into a NEW tensor: output.cpu() is a no-op on CPU, so the
            # old in-place thresholding overwrote the probabilities stored above.
            pred = (output >= 0.5).to(output.dtype).cpu()
            correct += int(pred.eq(label.cpu()).sum().item())
            testpred.append(pred)

    n_samples = len(test_loader.dataset)
    if n_samples == 0:
        return float('nan'), float('nan'), testlabel, testpred, testprob
    test_loss = total_loss / n_samples
    test_accuracy = 100. * correct / n_samples
    return test_loss, test_accuracy, testlabel, testpred, testprob

In [None]:
# 10-fold stratified cross-validation over the training split (seeded).
kfold = 10
skf = StratifiedKFold(n_splits = kfold, random_state = 256, shuffle = True)
# Stack the per-sample tensors into numpy arrays so folds can fancy-index them.
traindata1 = np.array([np.asarray(t) for t in traindata])
traindata2 = np.array([np.asarray(t) for t in traindata_c])

In [None]:
testlabel = []
testpred = []
testloss = []
testacc = []
testprob = []


def _match_least_similar(originals, croppeds):
    """Greedily pair each original image with the least-similar unused crop.

    Similarity is ``cosine_similarity`` on channel 0 of each image. Rows are
    processed in order: row r takes the column with the lowest remaining
    similarity. That column is then masked with +inf in every row, so no crop is
    used twice.

    Args:
        originals: full images of one class, each indexable as ``image[0]``.
        croppeds: ROI crops of the same class (same length as ``originals``).

    Returns:
        (order, sim): ``order[r]`` is the index into ``croppeds`` paired with
        ``originals[r]``. ``sim`` is the similarity matrix after masking.
    """
    sim = np.array(
        [[cosine_similarity(torch.as_tensor(orig[0]), torch.as_tensor(crop[0])).item()
          for crop in croppeds]
         for orig in originals],
        dtype=float,
    )
    order = []
    for row in range(len(originals)):
        col = int(np.argmin(sim[row]))
        order.append(col)
        sim[:, col] = np.inf
    return order, sim


# The backbones are frozen, so the fusion head `fc` is the only part of
# `model_merged` that training changes. Snapshot it, together with the configured
# learning rate and LR schedule, so every fold restarts from the same untrained
# head. Otherwise fold k+1 would continue from fold k's weights, which already
# trained on fold k+1's validation samples, and the LR would keep decaying.
initial_fc_state = {k: v.detach().clone() for k, v in model_merged.fc.state_dict().items()}
base_lr = optimizer.defaults['lr']
lr_lambda = scheduler.lr_lambdas[0]

for i, (train_index, val_index) in enumerate(skf.split(traindata1, trainidx)):

    print('[Fold %d/%d]' % (i + 1, kfold))

    # Fresh head weights, optimizer state and LR schedule for this fold.
    # `train` reads the module-level `scheduler`, so rebinding it here is enough.
    model_merged.fc.load_state_dict(initial_fc_state)
    optimizer = torch.optim.Adam(model_merged.parameters(), lr=base_lr)
    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lr_lambda)

    X_train, X_valid = traindata1[train_index], traindata1[val_index]
    y_train, y_valid = trainidx[train_index], trainidx[val_index]
    X_train_c = traindata2[train_index]

    # Split the original and ROI-cropped training images by class, so that
    # similarity matching never pairs images from different classes.
    original_0 = [img for img, y in zip(X_train, y_train) if y == 0]
    original_1 = [img for img, y in zip(X_train, y_train) if y != 0]
    cropped_0 = [img for img, y in zip(X_train_c, y_train) if y == 0]
    cropped_1 = [img for img, y in zip(X_train_c, y_train) if y != 0]

    # Pair each original with its least-similar same-class crop. The masked
    # similarity matrices are kept as metric_*_chunk for later inspection.
    order_0, metric_0_chunk = _match_least_similar(original_0, cropped_0)
    order_1, metric_1_chunk = _match_least_similar(original_1, cropped_1)

    # Reassemble the crops in training order: the k-th class-c sample gets the
    # crop matched to the k-th class-c original.
    matched_0 = iter([cropped_0[k] for k in order_0])
    matched_1 = iter([cropped_1[k] for k in order_1])
    new_traindata2 = np.array([next(matched_0) if y == 0 else next(matched_1) for y in y_train])

    # Similarity matching applies to training pairs only. Validation (and test)
    # pairs keep each image with its own ROI crop.
    X_train2, X_valid2 = new_traindata2, traindata2[val_index]

    trainfinal = list(zip(X_train, X_train2, y_train))
    valfinal = list(zip(X_valid, X_valid2, y_valid))

    # Reshuffle training pairs every epoch (seeded via torch's global RNG).
    trainloader = DataLoader(trainfinal, batch_size = batchsize, shuffle = True)
    validloader = DataLoader(valfinal, batch_size = batchsize, shuffle = False)

    print('[Fold %d/%d Prediciton:]' % (i + 1, kfold))

    # Train and Validation Process
    epochval = []
    valloss = []
    valacc = []

    # `epoch` stays a module-level name: train() reads it for its log message.
    for epoch in range(1, epochs + 1):
        train(model_merged, trainloader, optimizer, log_interval = 5)
        valid_loss, valid_accuracy, _, _, _ = evaluate(model_merged, validloader)
        epochval.append(epoch)
        valloss.append(valid_loss)
        valacc.append(valid_accuracy)
        print("\n[EPOCH: {}], \tValidation Loss: {:.6f}, \tValidation Accuracy: {:.6f} % \n".format(
            epoch, valid_loss, valid_accuracy))

    # Test Process: every fold's model is scored on the same held-out test set.
    test_loss, test_accuracy, label, pred, prob = evaluate(model_merged, testloader)
    testlabel.append(label)
    testpred.append(pred)
    testloss.append(test_loss)
    testacc.append(test_accuracy)
    testprob.append(prob)

    print("\nTest Loss: {:.4f}, \tTest Accuracy: {:.4f} % \n".format(test_loss, test_accuracy))

In [None]:
# Sanity check: greedy matching uses every cropped image exactly once, so every
# entry of the last fold's class-0 similarity matrix should now be +inf.
# Assert this instead of printing the whole n x n matrix for manual inspection.
metric_0_matrix = np.asarray(metric_0_chunk, dtype=float)
assert np.isinf(metric_0_matrix).all(), 'similarity matching left unmasked entries'
print(metric_0_matrix.shape)

In [None]:
# Batch indices of the (fixed) test loader; the same for every fold.
num_list = list(range(len(testprob[1])))

In [None]:
def _plot_roc_line(fper, tper, color, alpha):
    """Draw one ROC curve on the current axes, joining points in the given order.

    Uses plain matplotlib instead of seaborn.lineplot. lineplot sorts by x and
    averages y over repeated x values by default. roc_curve can return repeated
    FPR values (vertical steps), and averaging them distorts the curve.
    (``ci=`` is also deprecated in recent seaborn.)
    """
    plt.plot(fper, tper, color=color, alpha=alpha)


def plot_roc_curve_fold_red(fper, tper):
    """Faint red ROC curve for a single fold."""
    _plot_roc_line(fper, tper, 'red', 0.08)

def plot_roc_curve_red(fper, tper):
    """Solid red ROC curve plus the grey dashed chance diagonal."""
    _plot_roc_line(fper, tper, 'red', 1)
    plt.plot([0, 1], [0, 1], color='gray', linestyle='--')

def plot_roc_curve_fold_blue(fper, tper):
    """Faint blue ROC curve for a single fold."""
    _plot_roc_line(fper, tper, 'blue', 0.08)

def plot_roc_curve_blue(fper, tper):
    """Solid blue ROC curve."""
    _plot_roc_line(fper, tper, 'blue', 1)

In [None]:
plt.figure(figsize = (8, 8))
auc10 = []
temptpr = []
tempfpr = []
# Common FPR grid for the mean curve. Each fold's ROC has its own set of
# threshold points (arrays of different lengths), so fold curves are
# interpolated onto this grid before averaging. They were previously averaged
# position by position and each axis sorted independently, which does not yield
# a valid mean ROC.
mean_fpr = np.linspace(0, 1, 101)
interp_tprs = []

for i in range(kfold):
    true = np.concatenate([testlabel[i][j].detach().cpu().numpy() for j in num_list])
    prob = np.concatenate([testprob[i][j].detach().cpu().numpy() for j in num_list])
    fpr, tpr, _ = roc_curve(true, prob)

    tempfpr.append(fpr)
    temptpr.append(tpr)

    plot_roc_curve_fold_red(fpr, tpr)
    currauc = auc(fpr, tpr)

    auc10.append(currauc)

    # roc_curve returns non-decreasing FPR, which np.interp requires.
    interp_tpr = np.interp(mean_fpr, fpr, tpr)
    interp_tpr[0] = 0.0  # anchor the mean curve at the origin
    interp_tprs.append(interp_tpr)

plot_roc_curve_red(mean_fpr, np.mean(interp_tprs, axis=0))
ax = plt.gca()
ax.axes.xaxis.set_visible(False)
ax.axes.yaxis.set_visible(False)

plt.grid(True)
plt.legend(labels = ['ViT-L/32'], loc='lower right', fontsize = 14)

In [None]:
sens = []
spec = []
ppv_list = []
npv_list = []

for i in range(kfold):
    y_true = np.concatenate([testlabel[i][j].detach().cpu().numpy() for j in num_list]).astype(int)
    y_pred = np.concatenate([testpred[i][j].detach().cpu().numpy() for j in num_list]).astype(int)
    # labels=[0, 1] keeps the matrix 2x2 even when a fold's labels and
    # predictions contain only one class, so ravel() always yields four counts.
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    sens.append(tp / (tp + fn))      # sensitivity (recall)
    spec.append(tn / (tn + fp))      # specificity
    ppv_list.append(tp / (tp + fp))  # positive predictive value (precision)
    npv_list.append(tn / (tn + fn))  # negative predictive value

In [None]:
# Summary over folds: accuracy, AUC, sensitivity, specificity, PPV, NPV
# (first line: means, second line: standard deviations).
# Test accuracy is stored in percent. Convert it to a fraction once, so the mean
# and the std use the same unit (the std was previously left in percent).
acc_frac = np.array([float(a) for a in testacc]) / 100
print(np.mean(acc_frac), np.mean(auc10), np.mean(sens), np.mean(spec), np.mean(ppv_list), np.mean(npv_list))
print(np.std(acc_frac), np.std(auc10), np.std(sens), np.std(spec), np.std(ppv_list), np.std(npv_list))

In [None]:
# Pickle the whole module (the last fold's model). Loading it requires the
# TwoInputsNet class definition; saving model_merged.state_dict() would be more portable.
torch.save(obj=model_merged, f='vit-similarity.pt')