In [None]:
# coding:utf-8
from __future__ import print_function
import numpy as np
from torch.utils.data import Dataset
from PIL import Image
import argparse
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
from utils import progress_bar
import os
from torchvision import models
from efficientnet_pytorch import EfficientNet
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import data_preprocess
from correlation import corr
from sklearn.metrics import roc_auc_score

# Expose only GPU 0 to this process.
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Command-line configuration: the learning rate is the only tunable option.
parser = argparse.ArgumentParser(description='PyTorch EfficientNet Training')
parser.add_argument('--lr', default=5e-5, type=float, help='learning rate')
# parse_known_args() rather than parse_args(): when this code runs inside a
# Jupyter kernel, sys.argv carries kernel-specific options (e.g. "-f <file>")
# that parse_args() rejects by exiting the process. Unknown options are ignored.
args = parser.parse_known_args()[0]

def adjust_learning_rate(optimizer, decay_rate=.5):
    """Multiply the learning rate of every parameter group by ``decay_rate``.

    Args:
        optimizer: Object exposing ``param_groups``, an iterable of dicts that
            each contain an ``'lr'`` entry. The dicts are modified in place.
        decay_rate: Factor applied to each group's current ``'lr'``.

    Returns:
        None. A short notice is printed once all groups have been updated.
    """
    for group in optimizer.param_groups:
        group.update(lr=group['lr'] * decay_rate)
    print("changing lr rate")
        
# Progress banner emitted before the dataset class is defined and the data is loaded.
print('==> Preparing data..')

In [None]:
class MyDataset(Dataset):
    """Index-aligned pairs of CC and MLO samples.

    Item ``i`` combines the ``i``-th CC image/label with the ``i``-th MLO
    image/label. The four containers are stored as given and only indexed
    lazily in ``__getitem__``.
    """

    def __init__(self, imagesCC,labelsCC,imagesMLO,labelsMLO):
        """Keep references to the image and label containers of both views."""
        self.imagesCC, self.labelsCC = imagesCC, labelsCC
        self.imagesMLO, self.labelsMLO = imagesMLO, labelsMLO

    def __getitem__(self, index):
        """Return ``(imgCC, targetCC, imgMLO, targetMLO)`` for ``index``.

        Images are wrapped with ``torch.Tensor``; labels are returned exactly
        as stored.
        """
        cc_pair = (torch.Tensor(self.imagesCC[index]), self.labelsCC[index])
        mlo_pair = (torch.Tensor(self.imagesMLO[index]), self.labelsMLO[index])
        return cc_pair + mlo_pair

    def __len__(self):
        """Dataset size, defined by the number of CC images."""
        n_samples = len(self.imagesCC)
        return n_samples

In [None]:
# input image dimensions
img_rows, img_cols = 256, 256
img_channels = 1

# Train/test arrays and labels for both views, as returned by the project's
# data_preprocess.loaddata() (MLO first, then CC).
trXMLO, y_trainMLO, teXMLO, y_testMLO, trXCC, y_trainCC, teXCC, y_testCC = data_preprocess.loaddata()

# Column-vector (N, 1) views of the label arrays
trYMLO = y_trainMLO.reshape((y_trainMLO.shape[0],1))
teYMLO = y_testMLO.reshape((y_testMLO.shape[0],1))
trYCC = y_trainCC.reshape((y_trainCC.shape[0],1))
teYCC = y_testCC.reshape((y_testCC.shape[0],1))

# Mean of the MLO training labels (the positive fraction if labels are 0/1 --
# TODO confirm against loaddata()).
ratio = trYMLO.sum()*1./trYMLO.shape[0]*1.

train_len = len(trXMLO)
test_len = len(teXMLO)
print('tr ratio'+str(ratio))

# Per-class loss weights: class 0 gets `ratio`, class 1 gets `1 - ratio`.
weights = np.array((ratio,1-ratio))
weights = torch.Tensor(weights)
# Only move to the GPU when one exists; an unconditional .cuda() raises on
# CPU-only machines even though the rest of the script supports a CPU fallback.
if torch.cuda.is_available():
    weights = weights.cuda()

# Reshape and convert to 3-channel
def prepare_images(images, rows=None, cols=None, channels=None):
    """Reshape images to NCHW and replicate the first plane into 3 channels.

    Args:
        images: Array-like whose total size is a multiple of
            ``channels * rows * cols``.
        rows: Image height. Defaults to the module-level ``img_rows``.
        cols: Image width. Defaults to the module-level ``img_cols``.
        channels: Channel count of the input. Defaults to the module-level
            ``img_channels``.

    Returns:
        ``float32`` array of shape ``(N, 3, rows, cols)`` in which all three
        channels hold the same plane.
    """
    rows = img_rows if rows is None else rows
    cols = img_cols if cols is None else cols
    channels = img_channels if channels is None else channels

    images = np.asarray(images).reshape(-1, channels, rows, cols)
    # The previous per-sample np.resize(images[i], (rows, cols)) kept the first
    # rows*cols values of each sample, i.e. channel 0; slice it directly and
    # convert once to float32 instead of filling a float64 scratch buffer in a
    # Python loop.
    first_plane = images[:, :1, :, :].astype('float32')
    return np.repeat(first_plane, 3, axis=1)

# Convert every split of both views to 3-channel float32 arrays
# (processed in the order: MLO train, MLO test, CC train, CC test).
X_trainMLO, X_testMLO, X_trainCC, X_testCC = [
    prepare_images(raw_split) for raw_split in (trXMLO, teXMLO, trXCC, teXCC)
]

print('X_train shape:', X_trainMLO.shape)
print(X_trainMLO.shape[0], 'train samples')
print(X_testMLO.shape[0], 'test samples')

# Paired CC/MLO datasets; labels are the 1-D arrays returned by loaddata().
trainset = MyDataset(
    imagesCC=X_trainCC, labelsCC=y_trainCC,
    imagesMLO=X_trainMLO, labelsMLO=y_trainMLO,
)
testset = MyDataset(
    imagesCC=X_testCC, labelsCC=y_testCC,
    imagesMLO=X_testMLO, labelsMLO=y_testMLO,
)

# Training batches are shuffled; evaluation runs one sample at a time in order.
trainloader = torch.utils.data.DataLoader(
    dataset=trainset, batch_size=8, shuffle=True, num_workers=2
)
testloader = torch.utils.data.DataLoader(
    dataset=testset, batch_size=1, shuffle=False, num_workers=2
)

In [None]:
# Progress banner emitted before the network classes are defined and instantiated.
print('==> Building model..')

class EfficientNetBackbone(nn.Module):
    """Pretrained EfficientNet-B0 feature extractor with a two-class head.

    The convolutional feature map is obtained through the library's
    ``extract_features`` method. Wrapping ``efficientnet.children()`` in an
    ``nn.Sequential`` does not work: the children include an ``nn.ModuleList``
    of blocks, which has no ``forward`` and raises when called, and dropping
    the last child is not a reliable way to remove the stock classifier.

    Attributes are ``efficientnet`` (the pretrained network), ``avgpool``
    (global average pooling) and ``classifier`` (1280 -> 512 -> 2 MLP with
    ReLU and dropout).
    """

    def __init__(self):
        super(EfficientNetBackbone, self).__init__()
        self.efficientnet = EfficientNet.from_pretrained('efficientnet-b0')

        # Collapse the spatial dimensions of the feature map to 1x1.
        self.avgpool = nn.AdaptiveAvgPool2d(1)

        # Classification head; 1280 matches the channel count of the
        # EfficientNet-B0 feature map consumed in forward().
        self.classifier = nn.Sequential(
            nn.Linear(1280, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 2)
        )

    def forward(self, x):
        """Classify ``x`` and also return its convolutional feature map.

        Args:
            x: Image batch accepted by EfficientNet (3-channel, NCHW).

        Returns:
            ``(predictions, features)`` where ``predictions`` has shape
            ``(batch, 2)`` and ``features`` is the 4-D feature map produced by
            ``extract_features``.
        """
        # Final convolutional feature map of the pretrained network.
        features = self.efficientnet.extract_features(x)
        pooled = self.avgpool(features)
        flattened = torch.flatten(pooled, 1)

        predictions = self.classifier(flattened)

        return predictions, features

In [None]:
class DualViewEfficientNet(nn.Module):
    """Runs one shared backbone on both views and scores their feature agreement.

    Both the CC and the MLO batch go through the same ``model_backbone``
    instance, so the two views share all weights.
    """

    def __init__(self):
        super(DualViewEfficientNet, self).__init__()
        self.model_backbone = EfficientNetBackbone()

    def forward(self, CC, MLO):
        """Return ``(CC_predict, MLO_predict, correlation)``.

        ``correlation`` is the mean, over every spatial position of the 4-D
        feature maps, of ``corr`` applied to the ``[batch, channels]`` slices
        of the two views at that position.
        """
        CC_predict, CC_feature = self.model_backbone(CC)
        MLO_predict, MLO_feature = self.model_backbone(MLO)

        # Spatial extent is read from the CC feature map: [batch, channels, height, width].
        height, width = CC_feature.size(2), CC_feature.size(3)

        # Row-major accumulation of the per-position correlation, starting from 0.
        corr_sum = sum(
            corr(CC_feature[:, :, row, col], MLO_feature[:, :, row, col])
            for row in range(height)
            for col in range(width)
        )
        correlation = corr_sum / (height * width)

        return CC_predict, MLO_predict, correlation

# Initialize model
net = DualViewEfficientNet()

# Prefer the GPU when one is visible, otherwise run on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)
net = net.to(device)

if device == 'cuda':
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = True

# criterion1: unweighted loss used for evaluation.
# criterion2: class-weighted loss used for training. The weight tensor is moved
# to `device` explicitly so it always lives on the same device as the model
# outputs, regardless of where it was created.
criterion1 = nn.CrossEntropyLoss()
criterion2 = nn.CrossEntropyLoss(weight=weights.to(device))
optimizer = optim.Adam(net.parameters(), lr=args.lr, weight_decay=1e-5)

Loss_list = []
Accuracy_list = []

# Best-so-far evaluation metrics; updated by test() when a checkpoint is saved.
start_epoch = 0
best_accCC = 0
best_accMLO = 0
best_accAVG = 0
best_aucCC = 0
best_aucMLO = 0
best_aucAVG = 0

In [None]:
# Training function
def train(epoch):
    """Run one training epoch over ``trainloader`` and print epoch metrics.

    Uses the module-level ``net``, ``trainloader``, ``optimizer``,
    ``criterion2`` and ``device``. The optimised objective is
    ``lossCC + lossMLO - correlation``, i.e. the two weighted classification
    losses minus the feature correlation returned by the network.

    Args:
        epoch: Epoch index, used only for logging.

    Returns:
        None. Average losses, accuracies and ROC-AUC per view are printed.
    """
    print('\nEpoch: %d' % epoch)
    net.train()
    train_lossCC = 0
    train_lossMLO = 0
    train_losscorr = 0
    train_losstotal = 0
    correctCC = 0
    totalCC = 0
    correctMLO = 0
    totalMLO = 0
    # Per-batch class-1 scores and labels; concatenated once after the loop
    # instead of re-allocating a growing array with np.append on every batch.
    scoresCC, labelsCC = [], []
    scoresMLO, labelsMLO = [], []

    for batch_idx, (inputsCC, targetsCC, inputsMLO, targetsMLO) in enumerate(trainloader):
        inputsCC, targetsCC = inputsCC.to(device), targetsCC.to(device)
        inputsMLO, targetsMLO = inputsMLO.to(device), targetsMLO.to(device)

        optimizer.zero_grad()
        outputsCC, outputsMLO, correlation = net(inputsCC, inputsMLO)

        lossCC = criterion2(outputsCC, targetsCC.long())
        lossMLO = criterion2(outputsMLO, targetsMLO.long())
        losscorr = correlation
        losstotal = lossCC + lossMLO - losscorr
        # The graph is back-propagated exactly once per step, so it does not
        # need to be retained (retain_graph=True only held on to its buffers).
        losstotal.backward()
        optimizer.step()

        train_lossCC += lossCC.item()
        train_lossMLO += lossMLO.item()
        train_losscorr += losscorr.item()
        train_losstotal += losstotal.item()
        _, predictedCC = outputsCC.max(1)
        _, predictedMLO = outputsMLO.max(1)

        totalCC += targetsCC.size(0)
        correctCC += predictedCC.eq(targetsCC.long()).sum().item()
        totalMLO += targetsMLO.size(0)
        correctMLO += predictedMLO.eq(targetsMLO.long()).sum().item()

        # ROC-AUC needs a continuous score, not the thresholded argmax:
        # use the softmax probability of class 1.
        scoresCC.append(torch.softmax(outputsCC.detach(), dim=1)[:, 1].cpu().numpy().reshape(-1))
        labelsCC.append(targetsCC.detach().cpu().numpy().reshape(-1))

        scoresMLO.append(torch.softmax(outputsMLO.detach(), dim=1)[:, 1].cpu().numpy().reshape(-1))
        labelsMLO.append(targetsMLO.detach().cpu().numpy().reshape(-1))

    epoch_aucCC = roc_auc_score(np.concatenate(labelsCC), np.concatenate(scoresCC))
    epoch_aucMLO = roc_auc_score(np.concatenate(labelsMLO), np.concatenate(scoresMLO))

    print('trainLossCC: %.3f, trainAccuCC: %.3f%% (%d/%d)' % (train_lossCC/(batch_idx+1), 100.*correctCC/totalCC, correctCC, totalCC))
    print('trainLossMLO: %.3f, trainAccuMLO: %.3f%% (%d/%d)' % (train_lossMLO/(batch_idx+1), 100.*correctMLO/totalMLO, correctMLO, totalMLO))
    print('trainepoch_aucCC: %.3f,trainepoch_aucMLO: %.3f' % (epoch_aucCC, epoch_aucMLO))
    print('trainLosscorr: %.3f,trainLosstotal: %.3f' % (train_losscorr/(batch_idx+1), train_losstotal/(batch_idx+1)))
    print('trainLosstotal: %.3f' % (train_losstotal/(batch_idx+1)))

In [None]:
# Testing function
def test(epoch):
    """Evaluate ``net`` on ``testloader`` and checkpoint the best model.

    Uses the module-level ``net``, ``testloader``, ``criterion1`` and
    ``device``. A checkpoint is written when the mean CC/MLO accuracy
    improves, or when it ties the best value and the mean AUC improves; the
    module-level ``best_*`` trackers are updated at the same time.

    Args:
        epoch: Epoch index, stored in the checkpoint.

    Returns:
        None. Losses, accuracies and ROC-AUC values are printed.
    """
    global best_accCC, best_accMLO, best_accAVG, best_aucCC, best_aucMLO, best_aucAVG
    print("best_accCC:", best_accCC, "best_accMLO:", best_accMLO)
    print("best_accAVG:", best_accAVG)
    print("best_aucCC:", best_aucCC, "best_aucMLO:", best_aucMLO)
    print("best_aucAVG:", best_aucAVG)

    net.eval()
    test_lossCC = 0
    test_lossMLO = 0
    test_losstotal = 0
    correctCC = 0
    correctMLO = 0
    totalCC = 0
    totalMLO = 0
    # Per-batch class-1 scores and labels, concatenated once after the loop.
    scoresCC, labelsCC = [], []
    scoresMLO, labelsMLO = [], []

    with torch.no_grad():
        for batch_idx, (inputsCC, targetsCC, inputsMLO, targetsMLO) in enumerate(testloader):
            inputsCC, targetsCC = inputsCC.to(device), targetsCC.to(device)
            inputsMLO, targetsMLO = inputsMLO.to(device), targetsMLO.to(device)

            # The correlation term is not part of the evaluation loss.
            outputsCC, outputsMLO, _ = net(inputsCC, inputsMLO)
            lossCC = criterion1(outputsCC, targetsCC.long())
            lossMLO = criterion1(outputsMLO, targetsMLO.long())
            losstotal = lossCC + lossMLO
            test_lossCC += lossCC.item()
            test_lossMLO += lossMLO.item()
            test_losstotal += losstotal.item()
            _, predictedCC = outputsCC.max(1)
            _, predictedMLO = outputsMLO.max(1)

            totalCC += targetsCC.size(0)
            correctCC += predictedCC.eq(targetsCC.long()).sum().item()
            totalMLO += targetsMLO.size(0)
            correctMLO += predictedMLO.eq(targetsMLO.long()).sum().item()

            # Score every sample of the batch (indexing [0, 1] silently kept
            # only the first one) with the class-1 softmax probability; a raw
            # class-1 logit is not comparable across samples because it
            # ignores the class-0 logit.
            scoresCC.append(torch.softmax(outputsCC, dim=1)[:, 1].cpu().numpy().reshape(-1))
            labelsCC.append(targetsCC.cpu().numpy().reshape(-1))
            scoresMLO.append(torch.softmax(outputsMLO, dim=1)[:, 1].cpu().numpy().reshape(-1))
            labelsMLO.append(targetsMLO.cpu().numpy().reshape(-1))

    epoch_aucCC = roc_auc_score(np.concatenate(labelsCC), np.concatenate(scoresCC))
    epoch_aucMLO = roc_auc_score(np.concatenate(labelsMLO), np.concatenate(scoresMLO))
    epoch_aucAVG = 0.5*(epoch_aucCC + epoch_aucMLO)

    print('testLossCC: %.3f, testAccuCC: %.3f%% (%d/%d)' % (test_lossCC/(batch_idx+1), 100.*correctCC/totalCC, correctCC, totalCC))
    print('testLossMLO: %.3f, testAccuMLO: %.3f%% (%d/%d)' % (test_lossMLO/(batch_idx+1), 100.*correctMLO/totalMLO, correctMLO, totalMLO))
    print('testepoch_aucCC: %.3f,testepoch_aucMLO: %.3f,testepoch_aucAVG: %.3f' % (epoch_aucCC, epoch_aucMLO, epoch_aucAVG))
    print('testLosstotal: %.3f' % (test_losstotal/(batch_idx+1)))

    # Save checkpoint
    accCC = 100.*correctCC/totalCC
    accMLO = 100.*correctMLO/totalMLO
    accAVG = 0.5*(accCC + accMLO)

    # Better mean accuracy wins; mean AUC breaks exact ties.
    if((accAVG > best_accAVG) or (accAVG == best_accAVG and epoch_aucAVG > best_aucAVG)):
        print('Saving...')
        state = {
            'net': net.state_dict(),
            'accCC': accCC,
            'accMLO': accMLO,
            'accAVG': accAVG,
            'epoch_aucCC': epoch_aucCC,
            'epoch_aucMLO': epoch_aucMLO,
            'epoch_aucavg': epoch_aucAVG,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        # NOTE(review): the 'checkpoint' directory is created above but the
        # file is written to the working directory; path kept unchanged so
        # existing code that loads "best.pth" keeps working -- confirm intent.
        torch.save(state, "best.pth")
        best_accCC = accCC
        best_accMLO = accMLO
        best_accAVG = accAVG
        best_aucCC = epoch_aucCC
        best_aucMLO = epoch_aucMLO
        best_aucAVG = epoch_aucAVG

In [None]:
# Training loop
# Run 500 epochs starting at start_epoch; each epoch trains, then evaluates.
for epoch in range(start_epoch, start_epoch+500):
    # Every 30th epoch (never epoch 0) shrink the learning rate to 90%.
    decay_due = epoch != 0 and epoch % 30 == 0
    if decay_due:
        adjust_learning_rate(optimizer, decay_rate=0.9)
    train(epoch)
    test(epoch)